|
import os |
|
import cv2 |
|
import json |
|
import torch |
|
import torchvision.transforms as transforms |
|
from CPNet_model import LiteAWBISPNet |
|
import torchvision |
|
import numpy as np |
|
from Utiles import white_balance,apply_color_space_transform, transform_xyz_to_srgb, apply_gamma,fix_orientation,binning,Four2One,One2Four |
|
import time |
|
from net.mwrcanet import Net |
|
import torch.nn as nn |
|
from PIL import Image |
|
import torch.nn.functional as F |
|
|
|
|
|
# Directory scanned for input .png captures; each one is expected to have a
# same-named .json metadata sidecar next to it (read in the main loop).
Rpath = './Input'

# Basenames of the .png files found in Rpath (filled in by the discovery step).
image_files = []

# Per-image processing durations in seconds, averaged at the end of the run.
infer_times = []

# 3x3 color matrix, row-major, handed to apply_color_space_transform right
# before transform_xyz_to_srgb -- presumably camera RGB -> XYZ (TODO confirm).
color_matrix = [
    1.06835938, -0.29882812, -0.14257812,
    -0.43164062, 1.35546875, 0.05078125,
    -0.1015625, 0.24414062, 0.5859375,
]
|
|
|
|
|
|
|
# Input pipeline for the color-correction network: HWC numpy array -> CHW
# tensor, then resized to 768 (height) x 1024 (width).
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Resize([768, 1024])]
)

# Input pipeline for the denoiser: array -> tensor conversion only, no resize.
transforms_ = [transforms.ToTensor()]
transformo = transforms.Compose(transforms_)
|
|
|
|
|
# Color-correction / ISP network, moved to the GPU and loaded from its
# pretrained state dict.
model = LiteAWBISPNet()
model.cuda()
model.load_state_dict(torch.load('./model_zoo/CC2.pth'))
# This script only runs inference: switch any dropout / batch-norm layers to
# their evaluation behaviour (no-op if the network has none).
model.eval()
|
|
|
|
|
# Raw-domain denoiser (net.mwrcanet.Net), wrapped in DataParallel and loaded
# from the 'state_dict' entry of its training checkpoint.
last_ckpt = './model_zoo/dn_mwrcanet_raw_c1.pth'
dn_net = Net()
# DataParallel prefixes parameter names with 'module.'; presumably the
# checkpoint was saved from a DataParallel-wrapped model as well.
dn_model = nn.DataParallel(dn_net).cuda()
tmp_ckpt = torch.load(last_ckpt)
pretrained_dict = tmp_ckpt['state_dict']
model_dict = dn_model.state_dict()

# Require the checkpoint and the model to have exactly the same parameter
# keys. Explicit check instead of `assert` so it is not stripped by `python -O`,
# and so the error names the offending keys.
missing_keys = sorted(set(model_dict) - set(pretrained_dict))
unexpected_keys = sorted(set(pretrained_dict) - set(model_dict))
if missing_keys or unexpected_keys:
    raise RuntimeError(
        f"Checkpoint {last_ckpt!r} does not match the denoiser: "
        f"missing keys {missing_keys}, unexpected keys {unexpected_keys}"
    )

model_dict.update(pretrained_dict)
dn_model.load_state_dict(model_dict)
# Inference only: put dropout / batch-norm layers (if any) in eval mode.
dn_model.eval()
|
|
|
|
|
|
|
# Collect every .png in Rpath (extension matched case-insensitively).
# os.listdir returns entries in arbitrary order, so sort them to make the
# processing / output order deterministic across runs and filesystems.
image_files.extend(
    sorted(
        filename
        for filename in os.listdir(Rpath)
        if os.path.splitext(filename)[-1].lower() == ".png"
    )
)
|
|
|
# Destination for the rendered images. cv2.imwrite neither creates missing
# directories nor raises on failure (it just returns False), so create it up front.
out_dir = './Output'
os.makedirs(out_dir, exist_ok=True)

with torch.no_grad():
    for fp in image_files:
        fp = os.path.join(Rpath, fp)
        # Metadata sidecar: same path with the extension replaced by .json.
        mn = os.path.splitext(fp)[0]
        mf = mn + '.json'

        # Flag -1 (cv2.IMREAD_UNCHANGED) keeps the file's native bit depth.
        raw_image = cv2.imread(fp, -1)
        if raw_image is None:
            raise FileNotFoundError(f"Could not read image: {fp}")
        with open(mf, 'r', encoding='utf-8') as file:
            data = json.load(file)

        time_BL_S = time.perf_counter()

        # Subtract 256 and scale by (4095 - 256), clipping to [0, 1];
        # presumably the sensor's black / white levels -- TODO confirm.
        raw_image = (raw_image.astype(np.float32) - 256.) / (4095. - 256.)
        raw_image = np.clip(raw_image, 0.0, 1.0)

        # Project helper driven by the JSON metadata (semantics defined in Utiles).
        raw_image = binning(raw_image, data)

        # cv2.resize takes dsize as (width, height): 1024 wide x 768 tall.
        raw_image = cv2.resize(raw_image, (1024, 768))

        # Denoising pass: Four2One / One2Four presumably convert between the
        # multi-channel raw layout and the single-plane layout the denoiser
        # expects -- TODO confirm against Utiles.
        Temp_I = Four2One(raw_image)
        Temp_I = transformo(Temp_I).unsqueeze(0).cuda()
        Temp_I = dn_model(Temp_I)
        Temp_I = np.asarray(Temp_I.squeeze(0).squeeze(0).cpu())
        raw_image = One2Four(Temp_I)

        # Conventional pipeline: white balance from the as-shot neutral,
        # color matrix, XYZ -> sRGB, gamma.
        raw_image = white_balance(raw_image, data['as_shot_neutral'])
        raw_image = apply_color_space_transform(raw_image, color_matrix)
        raw_image = transform_xyz_to_srgb(raw_image)
        raw_image = apply_gamma(raw_image)

        # Color-correction network on a 768x1024 batch of one.
        Source = transform(raw_image).unsqueeze(0).float().cuda()
        Out = model(Source)
        Out = Out.clip(0, 1)

        # CHW float tensor in [0, 1] -> HWC uint8 image.
        OA = np.asarray(Out.squeeze(0).cpu()).transpose(1, 2, 0).astype(np.float32)
        OA = OA * 255.
        OA = OA.astype(np.uint8)
        OA = fix_orientation(OA, data["orientation"])

        time_Save_F = time.perf_counter()

        # OpenCV writes BGR channel order.
        OA = cv2.cvtColor(OA, cv2.COLOR_RGB2BGR)
        out_path = os.path.join(out_dir, os.path.basename(fp))
        if not cv2.imwrite(out_path, OA):
            raise OSError(f"Failed to write output image: {out_path}")

        infer_times.append(time_Save_F - time_BL_S)

# Guard against an empty input directory: np.mean([]) would print nan and warn.
if infer_times:
    print(f"Average inference time: {np.mean(infer_times)} seconds")
else:
    print(f"No .png images found in {Rpath}")
|
|