import argparse
import json
import os
import pdb
import time
from collections import OrderedDict

import cv2
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
from skimage import io
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader

from dataloader.dataset import imageSet
from ISP_pipeline import process_pngs_isp
from models.archs.NAF_arch import NAF_Video
from models.utils import chunkV3
from utils_ours.util import setup_logger, print_args
|
# Calibration table: ISO settings and, for each one, two measured values.
# NOTE(review): `a` and `b` look like per-ISO noise-model parameters
# (e.g. signal-dependent / signal-independent terms) -- confirm with the
# calibration source; only their use in the fits below is visible here.
ISO = [50, 125, 320, 640, 800]
a = [0.00025822882, 0.000580020745, 0.00141667975, 0.00278965863, 0.00347614807]
b = [2.32350645e-06, 3.1125155625e-06, 8.328992952e-06, 3.3315971808e-05, 5.205620595e-05]

# Polynomial fits of the table above, highest-order coefficient first:
#   a(ISO) ~= coeff_a[0] * ISO + coeff_a[1]
#   b(ISO) ~= coeff_b[0] * ISO**2 + coeff_b[1] * ISO + coeff_b[2]
coeff_a = np.polyfit(ISO, a, 1)
coeff_b = np.polyfit(ISO, b, 2)
|
def _strip_module_prefix(state_dict):
    """Return a copy of *state_dict* whose keys have a leading ``module.`` removed."""
    prefix = 'module.'
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        cleaned[key[len(prefix):] if key.startswith(prefix) else key] = value
    return cleaned


def _load_state_dict(path):
    """Load the checkpoint at *path* onto the CPU and normalise its key names."""
    # Tensors stay on the CPU here; load_state_dict() copies them into the
    # network's own (CUDA) parameters when a checkpoint is activated.
    return _strip_module_prefix(torch.load(path, map_location=torch.device('cpu')))


def _iso_tier(iso):
    """Map an estimated ISO value to the name of the checkpoint tier to use."""
    if iso < 900:
        return 'low'
    elif iso < 1800:
        return 'mid'
    elif iso < 5600:
        return 'high_mid'
    else:
        return 'high'


def main():
    """Denoise every sample of the test set and write the rendered JPEGs.

    For each sample the ISO is estimated from the first entry of the JSON
    ``noise_profile`` by inverting the linear fit ``coeff_a``; the estimate
    selects one of four denoiser checkpoints.  The denoised output is passed
    through ``process_pngs_isp.isp_night_imaging`` and saved as
    ``<save_folder>/<scene_name>.jpg``.  Per-image and average timings
    (denoising + ISP, excluding checkpoint switching and file writing) are
    printed.
    """
    parser = argparse.ArgumentParser(description='imageTest')

    parser.add_argument('--frame', default=1, type=int)
    parser.add_argument('--test_dir', default="/data/", type=str)
    parser.add_argument('--model_type', type=str, default='NAF_Video')
    parser.add_argument('--save_folder', default='/data/', type=str)
    parser.add_argument('--resume', default='', type=str)
    parser.add_argument('--testoption', default='image', type=str)
    parser.add_argument('--chunk', action='store_true')
    parser.add_argument('--debug', action='store_true')

    args = parser.parse_args()
    # Honour --save_folder (default '/data/') instead of a hard-coded path.
    args.src_save_folder = args.save_folder

    print(args.src_save_folder, '**********************')
    os.makedirs(args.src_save_folder, exist_ok=True)
    print(args.src_save_folder)

    network = NAF_Video(args).cuda()

    # One checkpoint per ISO tier, all loaded and key-normalised the same way.
    state_dicts = {
        'low': _load_state_dict("denoise_model/low_iso.pth"),
        'mid': _load_state_dict("denoise_model/mid_iso.pth"),
        'high_mid': _load_state_dict("denoise_model/high_mid_iso.pth"),
        'high': _load_state_dict("denoise_model/high_iso.pth"),
    }
    active_tier = None  # tier whose weights are currently in `network`

    cudnn.benchmark = True

    test_dataset = imageSet(args)
    test_dataloader = DataLoader(test_dataset, batch_size=1, num_workers=0, shuffle=False)
    inference_time = []

    with torch.no_grad():
        for data in test_dataloader:
            noise = data['input'].cuda()
            json_path = data['json_path'][0]
            scene_name = os.path.splitext(os.path.basename(json_path))[0]

            json_cfa = process_pngs_isp.readjson(json_path)
            num_k = json_cfa['noise_profile']
            # Invert a(ISO) = coeff_a[0] * ISO + coeff_a[1] to estimate the ISO.
            iso = (num_k[0] - coeff_a[1]) / coeff_a[0]

            # Swap weights only when the tier actually changes between images.
            tier = _iso_tier(iso)
            if tier != active_tier:
                network.load_state_dict(state_dicts[tier], strict=True)
                active_tier = tier
            network.eval()

            t0 = time.perf_counter()

            out = chunkV3(network, noise, args.testoption, patch_h=1024, patch_w=1024)
            out = torch.clamp(out, 0., 1.)

            name_rgb = os.path.join(args.src_save_folder, scene_name + '.jpg')

            out = out[0]
            # Release the input before the ISP stage runs.
            del noise
            torch.cuda.empty_cache()

            img_pro = process_pngs_isp.isp_night_imaging(
                out, json_cfa, iso,
                do_demosaic=True,
                do_channel_gain_white_balance=True,
                do_xyz_transform=True,
                do_srgb_transform=True,
                do_gamma_correct=True,
                do_refinement=True,
                do_to_uint8=True,
                do_resize_using_pil=True,
                do_fix_orientation=True,
            )

            t1 = time.perf_counter()
            inference_time.append(t1 - t0)

            img_pro = cv2.cvtColor(img_pro, cv2.COLOR_RGB2BGR)
            # NOTE(review): the file is written as .jpg but the flag below is the
            # PNG compression level, which appears to have no effect on JPEG
            # output -- confirm the intended JPEG quality before changing it.
            cv2.imwrite(name_rgb, img_pro, [cv2.IMWRITE_PNG_COMPRESSION, 0])

            print("Inference {} in {:.3f}s".format(scene_name, t1 - t0))

    if inference_time:
        print(f"Average inference time: {np.mean(inference_time)} seconds")
    else:
        print("No images were processed.")
|
# Run the inference pipeline only when executed as a script, not on import.
if __name__ == '__main__':
    main()