import os |
import cv2 |
import json |
import torch |
import torchvision.transforms as transforms |
from CPNet_model import LiteAWBISPNet |
import torchvision |
import numpy as np |
from Utiles import white_balance,apply_color_space_transform, transform_xyz_to_srgb, apply_gamma,fix_orientation,binning,Four2One,One2Four |
import time |
from net.mwrcanet import Net |
import torch.nn as nn |
from PIL import Image |
import torch.nn.functional as F |
# Input folder scanned for .png raw captures (each with a same-stem .json).
Rpath = './Input'
# Filled by the file-collection step / the inference loop respectively.
image_files = []
infer_times = []
# Flat 9-element colour matrix handed to apply_color_space_transform.
# NOTE(review): presumably a row-major 3x3 camera->XYZ matrix; confirm the
# expected layout in Utiles.apply_color_space_transform.
color_matrix = [
    1.06835938, -0.29882812, -0.14257812,
    -0.43164062, 1.35546875, 0.05078125,
    -0.1015625, 0.24414062, 0.5859375,
]
# Network input transform: HWC array -> CHW tensor, then resize to
# (height, width) = (768, 1024).
transform = transforms.Compose([transforms.ToTensor(), transforms.Resize([768, 1024])])
# Plain array -> tensor conversion (no resize) used for the denoiser input.
transforms_ = [transforms.ToTensor()]
transformo = transforms.Compose(transforms_)
# Colour-correction network applied to the processed sRGB image.
model = LiteAWBISPNet()
model.cuda()
model.load_state_dict(torch.load('./model_zoo/CC2.pth'))
# This script only runs inference: put every submodule into eval mode so
# layers such as dropout / batch-norm (if the network has any) do not use
# training-time behaviour.
model.eval()

# Raw-domain denoiser. It is wrapped in DataParallel before loading, so its
# state_dict keys carry DataParallel's 'module.' prefix.
# NOTE(review): presumably the checkpoint was saved from a DataParallel-wrapped
# model, which is why keys are matched against the wrapped model.
last_ckpt = './model_zoo/dn_mwrcanet_raw_c1.pth'
dn_net = Net()
dn_model = nn.DataParallel(dn_net).cuda()
tmp_ckpt = torch.load(last_ckpt)
pretrained_dict = tmp_ckpt['state_dict']
model_dict = dn_model.state_dict()
pretrained_dict_update = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# Require an exact one-to-one key match between checkpoint and model. These
# are explicit checks rather than `assert`, which `python -O` strips out.
if len(pretrained_dict) != len(pretrained_dict_update):
    unexpected = sorted(set(pretrained_dict) - set(model_dict))
    raise RuntimeError(f"Checkpoint {last_ckpt} has keys not present in the model: {unexpected}")
if len(pretrained_dict_update) != len(model_dict):
    missing = sorted(set(model_dict) - set(pretrained_dict_update))
    raise RuntimeError(f"Checkpoint {last_ckpt} is missing model keys: {missing}")
model_dict.update(pretrained_dict_update)
dn_model.load_state_dict(model_dict)
# Inference only (eval() recurses into the wrapped Net).
dn_model.eval()
# Gather every .png in Rpath (extension compared case-insensitively),
# preserving os.listdir order, into the module-level image_files list.
image_files.extend(
    entry
    for entry in os.listdir(Rpath)
    if os.path.splitext(entry)[-1].lower() == ".png"
)
# Process every collected image: normalise the raw data, bin and resize it,
# denoise with dn_model, apply white balance / colour transform / gamma, then
# run the colour-correction `model`. Each result is written to ./Output/ under
# the input's basename, and per-image wall-clock time is recorded in
# infer_times.
out_dir = './Output'
# cv2.imwrite does not create directories and only signals failure through
# its return value, so make sure the destination exists before the loop.
os.makedirs(out_dir, exist_ok=True)
with torch.no_grad():
    for fp in image_files:
        fp = os.path.join(Rpath, fp)
        # Metadata sidecar: same path stem as the image, .json extension.
        mf = os.path.splitext(fp)[0] + '.json'
        # -1 == cv2.IMREAD_UNCHANGED: keep the file's original bit depth.
        raw_image = cv2.imread(fp, -1)
        if raw_image is None:
            raise OSError(f"Failed to read image: {fp}")
        with open(mf, 'r') as fh:
            data = json.load(fh)
        time_BL_S = time.time()
        # Subtract 256 and scale by (4095 - 256), then clip to [0, 1].
        # NOTE(review): presumably black level 256 / 12-bit white level 4095.
        raw_image = (raw_image.astype(np.float32) - 256.)
        raw_image = raw_image / (4095. - 256.)
        raw_image = np.clip(raw_image, 0.0, 1.0)
        raw_image = binning(raw_image, data)
        raw_image = cv2.resize(raw_image, (1024, 768))  # dsize is (width, height)
        # Denoising pass. NOTE(review): Four2One/One2Four presumably convert
        # between the packed multi-channel layout and the single plane the
        # denoiser expects; confirm in Utiles.
        Temp_I = Four2One(raw_image)
        Temp_I = transformo(Temp_I).unsqueeze(0).cuda()
        Temp_I = dn_model(Temp_I)
        Temp_I = np.asarray(Temp_I.squeeze(0).squeeze(0).cpu())
        raw_image = One2Four(Temp_I)
        raw_image = white_balance(raw_image, data['as_shot_neutral'])
        raw_image = apply_color_space_transform(raw_image, color_matrix)
        raw_image = transform_xyz_to_srgb(raw_image)
        raw_image = apply_gamma(raw_image)
        Source = transform(raw_image).unsqueeze(0).float().cuda()
        Out = model(Source)
        Out = Out.clip(0, 1)
        # .cpu() is a blocking device->host copy, so the timing below
        # includes the GPU work.
        OA = np.asarray(Out.squeeze(0).cpu()).transpose(1, 2, 0).astype(np.float32)
        OA = OA * 255.
        OA = OA.astype(np.uint8)
        OA = fix_orientation(OA, data["orientation"])
        time_Save_F = time.time()
        # Network output is RGB; OpenCV writes BGR.
        OA = cv2.cvtColor(OA, cv2.COLOR_RGB2BGR)
        out_path = os.path.join(out_dir, os.path.basename(fp))
        if not cv2.imwrite(out_path, OA):
            raise OSError(f"Failed to write image: {out_path}")
        infer_times.append(time_Save_F - time_BL_S)
# np.mean of an empty list is nan (with a RuntimeWarning); report that case
# explicitly instead.
if infer_times:
    print(f"Average inference time: {np.mean(infer_times)} seconds")
else:
    print(f"No .png images found in {Rpath}")