# NPRC24 / IIR-Lab / final_test.py
import os
import time
import argparse
import torch
import torch.backends.cudnn as cudnn
from utils_ours.util import setup_logger, print_args
from torch.utils.data import DataLoader
from dataloader.dataset import imageSet
from models.archs.NAF_arch import NAF_Video
from torch.nn.parallel import DistributedDataParallel
import numpy as np
import torch.nn.functional as F
from collections import OrderedDict
import torch.nn as nn
from models.utils import chunkV3
import pdb
from ISP_pipeline import process_pngs_isp
import json
import cv2
from skimage import io
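# Noise-profile parameters a and b at five reference ISO values,
# used below to fit their ISO dependence.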
ISO = [50, 125, 320, 640, 800]
a = [0.00025822882, 0.000580020745, 0.00141667975, 0.00278965863, 0.00347614807]
b = [2.32350645e-06, 3.1125155625e-06, 8.328992952e-06, 3.3315971808e-05, 5.205620595e-05]
# Fit a linear model for a(ISO) and a quadratic model for b(ISO)
coeff_a = np.polyfit(ISO, a, 1)
coeff_b = np.polyfit(ISO, b, 2)

def main():
    parser = argparse.ArgumentParser(description='imageTest')
    parser.add_argument('--frame', default=1, type=int)
    parser.add_argument('--test_dir', default='/data/', type=str)
    parser.add_argument('--model_type', type=str, default='NAF_Video')
    parser.add_argument('--save_folder', default='/data/', type=str)
    parser.add_argument('--resume', default='', type=str)
    parser.add_argument('--testoption', default='image', type=str)
    parser.add_argument('--chunk', action='store_true')
    parser.add_argument('--debug', action='store_true')
    args = parser.parse_args()
    args.src_save_folder = '/data/'
    print(args.src_save_folder, '**********************')
    if not os.path.exists(args.src_save_folder):
        os.makedirs(args.src_save_folder)
        print(args.src_save_folder)
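    # One denoiser checkpoint per ISO range; the matching one is loaded per scene below.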
    low_iso_model = "denoise_model/low_iso.pth"
    mid_iso_model = "denoise_model/mid_iso.pth"
    high_mid_iso_model = "denoise_model/high_mid_iso.pth"
    high_iso_model = "denoise_model/high_iso.pth"
    network = NAF_Video(args).cuda()
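    # Strip the 'module.' prefix that (Distributed)DataParallel adds to parameter names,
    # so each state dict loads into the bare model.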
    load_low_iso_net = torch.load(low_iso_model, map_location=torch.device('cuda'))
    load_low_iso_net_clean = OrderedDict()
    for k, v in load_low_iso_net.items():
        if k.startswith('module.'):
            load_low_iso_net_clean[k[7:]] = v
        else:
            load_low_iso_net_clean[k] = v
    load_mid_iso_net = torch.load(mid_iso_model, map_location=torch.device('cpu'))
    load_mid_iso_net_clean = OrderedDict()
    for k, v in load_mid_iso_net.items():
        if k.startswith('module.'):
            load_mid_iso_net_clean[k[7:]] = v
        else:
            load_mid_iso_net_clean[k] = v
    load_high_mid_iso_net = torch.load(high_mid_iso_model, map_location=torch.device('cpu'))
    load_high_mid_iso_net_clean = OrderedDict()
    for k, v in load_high_mid_iso_net.items():
        if k.startswith('module.'):
            load_high_mid_iso_net_clean[k[7:]] = v
        else:
            load_high_mid_iso_net_clean[k] = v
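    # The high-ISO checkpoint is assumed to be saved without the 'module.' prefix and is used as-is.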
    load_high_iso_net_clean = torch.load(high_iso_model, map_location=torch.device('cpu'))
    cudnn.benchmark = True
    test_dataset = imageSet(args)
    test_dataloader = DataLoader(test_dataset, batch_size=1, num_workers=0, shuffle=False)
    inference_time = []
    with torch.no_grad():
        for data in test_dataloader:
            noise = data['input'].cuda()
            json_path = data['json_path'][0]
            scene_name = os.path.splitext(os.path.basename(json_path))[0]
            # now process the ISP module
            json_cfa = process_pngs_isp.readjson(json_path)
            num_k = json_cfa['noise_profile']
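            # Estimate the scene ISO by inverting the linear fit a(ISO) = coeff_a[0] * ISO + coeff_a[1]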
            iso = (num_k[0] - coeff_a[1]) / coeff_a[0]
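            # Select the checkpoint matching the estimated ISO range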
            if iso < 900:
                network.load_state_dict(load_low_iso_net_clean, strict=True)
                network.eval()
            elif iso < 1800:
                network.load_state_dict(load_mid_iso_net_clean, strict=True)
                network.eval()
            elif iso < 5600:
                network.load_state_dict(load_high_mid_iso_net_clean, strict=True)
                network.eval()
            else:
                network.load_state_dict(load_high_iso_net_clean, strict=True)
                network.eval()
            t0 = time.perf_counter()
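            # Tiled inference: chunkV3 is expected to run the network on 1024x1024 patches to bound GPU memory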
            out = chunkV3(network, noise, args.testoption, patch_h=1024, patch_w=1024)
            out = torch.clamp(out, 0., 1.)
            # name_rgb = os.path.join(args.src_save_folder, scene_name + '_' + str(int(iso)) + '.jpg')
            name_rgb = os.path.join(args.src_save_folder, scene_name + '.jpg')
            if not os.path.exists(os.path.dirname(name_rgb)):
                os.makedirs(os.path.dirname(name_rgb))
            out = out[0]
            del noise
            torch.cuda.empty_cache()
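            # Convert the denoised raw output to an 8-bit RGB image with the ISP pipeline
            # (demosaic, white balance, colour transforms, gamma, refinement, resize, orientation fix)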
            img_pro = process_pngs_isp.isp_night_imaging(out, json_cfa, iso,
                                                         do_demosaic=True,           # H/2, W/2
                                                         do_channel_gain_white_balance=True,
                                                         do_xyz_transform=True,
                                                         do_srgb_transform=True,
                                                         do_gamma_correct=True,
                                                         do_refinement=True,         # 32 bit
                                                         do_to_uint8=True,
                                                         do_resize_using_pil=True,   # H/8, W/8
                                                         do_fix_orientation=True)
            t1 = time.perf_counter()
            inference_time.append(t1 - t0)
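            # OpenCV expects BGR channel order when writing images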
            img_pro = cv2.cvtColor(img_pro, cv2.COLOR_RGB2BGR)
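            # Note: the PNG compression flag has no effect here since the output path ends in .jpg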
            cv2.imwrite(name_rgb, img_pro, [cv2.IMWRITE_PNG_COMPRESSION, 0])
            print("Inference {} in {:.3f}s".format(scene_name, t1 - t0))
    print(f"Average inference time: {np.mean(inference_time)} seconds")


if __name__ == '__main__':
    main()