File size: 5,672 Bytes
6721043 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
import os
import time
import argparse
import torch
import torch.backends.cudnn as cudnn
from utils_ours.util import setup_logger, print_args
from torch.utils.data import DataLoader
from dataloader.dataset import imageSet
from models.archs.NAF_arch import NAF_Video
from torch.nn.parallel import DistributedDataParallel
import numpy as np
import torch.nn.functional as F
from collections import OrderedDict
import torch.nn as nn
from models.utils import chunkV3
import pdb
from ISP_pipeline import process_pngs_isp
import os
import json
import cv2
from skimage import io
# Noise calibration table, one row per measured ISO: (ISO, a, b).
# NOTE(review): `a` appears to be compared against noise_profile[0] in main();
# the physical meaning of `a` and `b` is not established here — confirm.
_NOISE_CALIBRATION = (
    (50, 0.00025822882, 2.32350645e-06),
    (125, 0.000580020745, 3.1125155625e-06),
    (320, 0.00141667975, 8.328992952e-06),
    (640, 0.00278965863, 3.3315971808e-05),
    (800, 0.00347614807, 5.205620595e-05),
)
# Split the rows into the three parallel lists ISO, a and b.
ISO, a, b = (list(column) for column in zip(*_NOISE_CALIBRATION))

# Polynomial fits against ISO (coefficients ordered highest power first):
# `a` is fitted as a line, `b` as a quadratic.
coeff_a = np.polyfit(ISO, a, deg=1)
coeff_b = np.polyfit(ISO, b, deg=2)
def _strip_module_prefix(state_dict):
    """Return a copy of ``state_dict`` with a leading ``'module.'`` removed from each key.

    Checkpoints saved from a ``DataParallel``/``DistributedDataParallel``-wrapped
    model carry that prefix, while the bare network built in ``main`` does not.
    Keys without the prefix are kept unchanged; insertion order is preserved.
    """
    prefix = 'module.'
    clean = OrderedDict()
    for key, value in state_dict.items():
        clean[key[len(prefix):] if key.startswith(prefix) else key] = value
    return clean


def _load_clean_checkpoint(path):
    """Load the state dict stored at ``path`` onto the CPU with ``'module.'`` prefixes stripped.

    All checkpoints go through this single path so that every one of them is
    normalised the same way before the ``strict=True`` load in ``main``.
    """
    return _strip_module_prefix(torch.load(path, map_location=torch.device('cpu')))


def _iso_band(iso):
    """Map an estimated ISO value to a checkpoint band key used in ``main``.

    Bands: < 900 -> 'low', < 1800 -> 'mid', < 5600 -> 'high_mid', otherwise 'high'.
    """
    if iso < 900:
        return 'low'
    if iso < 1800:
        return 'mid'
    if iso < 5600:
        return 'high_mid'
    return 'high'


def main():
    """Denoise each test sample with an ISO-matched model and save the ISP-rendered JPEG.

    For every sample from ``imageSet``, the ISO is estimated by inverting the
    linear fit ``coeff_a`` of ``noise_profile[0]`` against ISO. The matching
    checkpoint is selected, the input is run through ``chunkV3`` in 1024x1024
    patches, and the clamped output is passed through
    ``process_pngs_isp.isp_night_imaging``. The result is written to
    ``<save_folder>/<scene_name>.jpg``. Per-image and average timings are printed.
    """
    parser = argparse.ArgumentParser(description='imageTest')
    parser.add_argument('--frame', default=1, type=int)
    parser.add_argument('--test_dir', default="/data/", type=str)
    parser.add_argument('--model_type', type=str, default='NAF_Video')
    parser.add_argument('--save_folder', default='/data/', type=str)
    parser.add_argument('--resume', default='', type=str)
    parser.add_argument('--testoption', default='image', type=str)
    parser.add_argument('--chunk', action='store_true')
    parser.add_argument('--debug', action='store_true')
    args = parser.parse_args()

    # Results go to --save_folder (default '/data/'). The attribute is kept on
    # args in case NAF_Video / imageSet read it.
    args.src_save_folder = args.save_folder
    if args.src_save_folder:
        os.makedirs(args.src_save_folder, exist_ok=True)
    print(f"Saving results to {args.src_save_folder}")

    network = NAF_Video(args).cuda()
    # Load every band's weights up front (on the CPU) so that a missing or
    # corrupt checkpoint fails before any image is processed.
    checkpoints = {
        'low': _load_clean_checkpoint("denoise_model/low_iso.pth"),
        'mid': _load_clean_checkpoint("denoise_model/mid_iso.pth"),
        'high_mid': _load_clean_checkpoint("denoise_model/high_mid_iso.pth"),
        'high': _load_clean_checkpoint("denoise_model/high_iso.pth"),
    }
    cudnn.benchmark = True

    test_dataset = imageSet(args)
    test_dataloader = DataLoader(test_dataset, batch_size=1, num_workers=0, shuffle=False)
    inference_time = []
    loaded_band = None  # band whose weights currently live in `network`
    with torch.no_grad():
        for data in test_dataloader:
            noise = data['input'].cuda()
            json_path = data['json_path'][0]
            scene_name = os.path.splitext(os.path.basename(json_path))[0]

            # Estimate ISO by inverting the linear fit noise_profile[0] = coeff_a[0] * ISO + coeff_a[1].
            json_cfa = process_pngs_isp.readjson(json_path)
            num_k = json_cfa['noise_profile']
            iso = (num_k[0] - coeff_a[1]) / coeff_a[0]

            # Swap weights only when the band changes; reloading identical
            # weights for every image is pure overhead.
            band = _iso_band(iso)
            if band != loaded_band:
                network.load_state_dict(checkpoints[band], strict=True)
                network.eval()
                loaded_band = band

            t0 = time.perf_counter()
            out = chunkV3(network, noise, args.testoption, patch_h=1024, patch_w=1024)
            out = torch.clamp(out, 0., 1.)[0]  # drop the batch dimension (batch_size=1)
            del noise
            torch.cuda.empty_cache()
            img_pro = process_pngs_isp.isp_night_imaging(out, json_cfa, iso,
                                                         do_demosaic=True,  # H/2 W/2
                                                         do_channel_gain_white_balance=True,
                                                         do_xyz_transform=True,
                                                         do_srgb_transform=True,
                                                         do_gamma_correct=True,
                                                         do_refinement=True,  # 32 bit
                                                         do_to_uint8=True,
                                                         do_resize_using_pil=True,  # H/8, W/8
                                                         do_fix_orientation=True
                                                         )
            t1 = time.perf_counter()
            inference_time.append(t1 - t0)

            # scene_name is a bare basename, so the file lands directly in the
            # (already created) output folder.
            name_rgb = os.path.join(args.src_save_folder, scene_name + '.jpg')
            img_pro = cv2.cvtColor(img_pro, cv2.COLOR_RGB2BGR)
            # The output is a JPEG, so request maximum JPEG quality; a PNG
            # compression flag would be ignored for this format.
            if not cv2.imwrite(name_rgb, img_pro, [cv2.IMWRITE_JPEG_QUALITY, 100]):
                print(f"Warning: failed to write {name_rgb}")
            print("Inference {} in {:.3f}s".format(scene_name, t1 - t0))

    if inference_time:
        print(f"Average inference time: {np.mean(inference_time)} seconds")
    else:
        print("No images were processed.")
# Script entry point: run inference only when executed directly, not on import.
if __name__ == '__main__':
    main()