"""Run the trained fusion network over every three-frame input in a folder.

For each file matching ``<input_dir>*_1.png`` the script loads the sibling
``_1.png`` / ``_2.png`` / ``_3.png`` images, concatenates them along the
channel axis, pads the result, feeds it through ``fusion_net`` (weights read
from ``checkpoint_dir/checkpoint_gen.pth``), applies one of two fixed 3x3
channel-mixing matrices chosen by the ``noise_profile`` value in the matching
JSON file under ``input_dir2``, and writes ``<name>.jpg`` into ``result_dir``.

All work happens at module level: running or importing this file starts
inference immediately.
"""
import torch
import argparse
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.utils import save_image as imwrite
import os
import time
import re
from torchvision import transforms
from test_dataset_for_testing import dehaze_test_dataset
from model_convnext2_hdr import fusion_net
import glob
import scipy.io
import torch.optim as optim
import cv2
import matplotlib.image
from PIL import Image
import random
import math
import numpy as np
import sys
import json
# Needed by fraction_from_json(); without it that hook raises NameError.
from fractions import Fraction

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

#run python test_05_hdr.py ./data/ ./result/ ./daylight_isp_03/ 1 2 4
# NOTE: the code visible in this file does not read command-line arguments;
# the directories below are hard-coded.
input_dir2 = '../data/'      # directory holding the per-image .json metadata
input_dir = '../1/mid/'      # directory holding the *_1.png/_2.png/_3.png inputs
result_dir = '../data/'      # directory the .jpg results are written to
checkpoint_dir = './result_low_light_hdr/'

# get train IDs
train_fns = glob.glob(input_dir + '*_1.png')
train_ids = [os.path.basename(train_fn) for train_fn in train_fns]

# exist_ok avoids the check-then-create race of exists() + mkdir().
os.makedirs(result_dir, exist_ok=True)


def json_read(fname, **kwargs):
    """Load and return the JSON document stored in the file *fname*.

    Extra keyword arguments are forwarded unchanged to ``json.load`` (the
    test loop below passes ``object_hook``).
    """
    with open(fname) as j:
        data = json.load(j, **kwargs)
    return data


def fraction_from_json(json_object):
    """``json.load`` object hook that rebuilds ``Fraction`` values.

    A decoded JSON object containing the key ``'Fraction'`` is replaced by
    ``Fraction(*json_object['Fraction'])``; every other object is returned
    unchanged.
    """
    if 'Fraction' in json_object:
        return Fraction(*json_object['Fraction'])
    return json_object


def fractions2floats(fractions):
    """Return a list with ``numerator / denominator`` of each item as a float."""
    return [float(fraction.numerator) / fraction.denominator
            for fraction in fractions]


def _mix_channels(image, matrix):
    """Apply a 3x3 channel-mixing *matrix* to the first three channels of *image*.

    *image* is indexed as ``image[:, :, k]`` for ``k`` in 0..2, so it must be
    3-D with at least three channels.  *matrix* is three rows of three plain
    Python floats; row ``i`` gives the weights of input channels 0, 1, 2 for
    output channel ``i``.

    The result has the same shape as *image* and NumPy's default float dtype;
    it is NOT clipped, and any channel beyond the third stays zero.
    """
    mixed = np.zeros(image.shape)
    c0, c1, c2 = image[:, :, 0], image[:, :, 1], image[:, :, 2]
    # Coefficients are kept as Python floats (not NumPy scalars/arrays) and
    # summed left to right so the arithmetic matches the original hand-written
    # expressions exactly, including the dtype of the intermediates.
    for channel, (k0, k1, k2) in enumerate(matrix):
        mixed[:, :, channel] = c0 * k0 + c1 * k1 + c2 * k2
    return mixed


# Matrix applied when noise_profile < 0.02 (see the test loop).
_MATRIX_LOW_NOISE = (
    (1.9021, -1.1651, 0.2630),
    (-0.3189, 1.5831, -0.2643),
    (-0.0662, -0.9350, 2.0013),
)

# Matrix applied otherwise.
_MATRIX_HIGH_NOISE = (
    (1.521689, -0.673763, 0.152074),
    (-0.145724, 1.266507, -0.120783),
    (-0.0397583, -0.561249, 1.60100734),
)


def reprocessing(input):
    """Mix the channels of *input* with the fixed ``noise_profile < 0.02`` matrix.

    Returns an unclipped float array of the same shape as *input*.  (The
    original also built a clipped ``uint8`` copy but never returned it; that
    dead computation has been dropped.)
    """
    return _mix_channels(input, _MATRIX_LOW_NOISE)


def reprocessing1(input):
    """Mix the channels of *input* with the fixed matrix used for all other
    ``noise_profile`` values.

    Returns an unclipped float array of the same shape as *input*.  (The
    original also built a clipped ``uint8`` copy but never returned it; that
    dead computation has been dropped.)
    """
    return _mix_channels(input, _MATRIX_HIGH_NOISE)


# --- Gpu device --- #
device = torch.device("cuda:0")

# --- Define the network --- #
model_g = fusion_net()
# The model is wrapped in DataParallel before load_state_dict is called on it.
model_g = nn.DataParallel(model_g)
MyEnsembleNet = model_g.to(device)
MyEnsembleNet.load_state_dict(
    torch.load(os.path.join(checkpoint_dir, 'checkpoint_gen.pth')))

# --- Test --- #
with torch.no_grad():
    MyEnsembleNet.eval()
    for ind, train_id in enumerate(train_ids):
        print(ind)
        # train_id ends with '_1.png'; dropping the last 5 characters
        # ('1.png') leaves '<name>_' so the frame index can be appended.
        in_path_in = input_dir + train_id[:-5]
        in_path_in_js = input_dir2 + train_id[:-5]

        # [:-1] removes the trailing '_' to obtain '<name>.json'.
        metadata = json_read(in_path_in_js[:-1] + '.json',
                             object_hook=fraction_from_json)
        noise_profile = float(metadata['noise_profile'][0])

        pic_in1 = np.asarray(Image.open(in_path_in + '1.png'), np.float32) / 255.
        pic_in2 = np.asarray(Image.open(in_path_in + '2.png'), np.float32) / 255.
        pic_in3 = np.asarray(Image.open(in_path_in + '3.png'), np.float32) / 255.
        pic_in = np.concatenate([pic_in1, pic_in2, pic_in3], axis=2)
        #pic_in = cv2.resize(pic_in, None, fx = 0.5, fy = 0.5, interpolation=cv2.INTER_CUBIC )

        [h, w, c] = pic_in.shape
        # Pad bottom/right up to the next multiple of 32.  The amount is
        # always in 1..32: a side that is already a multiple of 32 still
        # receives a full extra 32 pixels.  Kept as-is so the network sees
        # exactly the same input as before.
        pad_h = 32 - h % 32
        pad_w = 32 - w % 32
        pic_in = np.expand_dims(
            np.pad(pic_in, ((0, pad_h), (0, pad_w), (0, 0)), mode='reflect'),
            axis=0)

        # HWC -> CHW (with a leading batch axis of size 1) for the network.
        in_data = torch.from_numpy(pic_in).permute(0, 3, 1, 2).to(device)
        out_data = MyEnsembleNet(in_data)
        out_datass = out_data.cpu().detach().numpy().transpose((0, 2, 3, 1))
        output = np.clip(out_datass[0, :, :, :], 0, 1)

        if noise_profile < 0.02:
            output = reprocessing(output)
        else:
            output = reprocessing1(output)

        #cv2.imwrite(result_dir + train_id[:-6] + '.png', output[0:h,0:w,::-1] * 255)
        # Crop the padding away, reverse the channel order and scale to the
        # 0..255 range before writing.  train_id[:-6] strips '_1.png'.
        out_path = result_dir + train_id[:-6] + '.jpg'
        written = cv2.imwrite(out_path, output[0:h, 0:w, ::-1] * 255,
                              [cv2.IMWRITE_JPEG_QUALITY, 100])
        # cv2.imwrite reports failure through its return value rather than by
        # raising; surface it instead of silently producing no file.
        if not written:
            print('failed to write ' + out_path, file=sys.stderr)