# NPRC24 / SCBC / SCBC_Solution.py
# (repository listing metadata: author Artyom, branch scbc, commit f8d6c27, ~4 kB)
import os
import cv2
import json
import torch
import torchvision.transforms as transforms
from CPNet_model import LiteAWBISPNet
import torchvision
import numpy as np
from Utiles import white_balance,apply_color_space_transform, transform_xyz_to_srgb, apply_gamma,fix_orientation,binning,Four2One,One2Four
import time
from net.mwrcanet import Net
import torch.nn as nn
from PIL import Image
import torch.nn.functional as F
####### Raw input directory ###########
Rpath = './Input'
image_files = []
####### Per-image timing accumulator ###############################
infer_times = []
####### Color matrix from the baseline pipeline #############
color_matrix = [1.06835938, -0.29882812, -0.14257812,
                -0.43164062, 1.35546875, 0.05078125,
                -0.1015625, 0.24414062, 0.5859375]
####### Tensor conversion pipelines ###########################
# `transform` converts to tensor and resizes to 768x1024 (refinement-net input);
# `transformo` only converts to tensor (denoiser input keeps its spatial size).
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Resize([768, 1024])])
transformo = transforms.Compose([transforms.ToTensor()])
######## Load the pretrained refinement model ####
# LiteAWBISPNet refines the sRGB image produced by the classic ISP steps below.
model = LiteAWBISPNet()
model.cuda()
model.load_state_dict(torch.load('./model_zoo/CC2.pth'))
model.eval()  # inference mode: fixes BatchNorm/Dropout behaviour (no_grad alone is not enough)
###### Load pretrained denoising model ##############
last_ckpt = './model_zoo/dn_mwrcanet_raw_c1.pth'
dn_net = Net()
dn_model = nn.DataParallel(dn_net).cuda()
tmp_ckpt = torch.load(last_ckpt)
pretrained_dict = tmp_ckpt['state_dict']
model_dict = dn_model.state_dict()
# Keep only the checkpoint entries whose keys exist in the current model.
pretrained_dict_update = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# Explicit validation instead of `assert` (asserts are stripped under `python -O`):
# the checkpoint keys must match the model's state_dict exactly.
if not (len(pretrained_dict) == len(pretrained_dict_update) == len(model_dict)):
    raise RuntimeError('state_dict key mismatch between model and checkpoint: ' + last_ckpt)
model_dict.update(pretrained_dict_update)
dn_model.load_state_dict(model_dict)
dn_model.eval()  # inference mode for the denoiser as well
############################ Start Processing! #########
# Collect every PNG raw frame from the input directory.
for filename in os.listdir(Rpath):
    if os.path.splitext(filename)[-1].lower() == ".png":
        image_files.append(filename)
# Make sure the save target exists — cv2.imwrite fails silently otherwise.
os.makedirs('./Output', exist_ok=True)
with torch.no_grad():
    for fp in image_files:
        fp = os.path.join(Rpath, fp)
        # Per-image metadata (white balance, orientation, ...) lives in a JSON
        # sidecar sharing the PNG's stem.
        mn = os.path.splitext(fp)[-2]
        mf = str(mn) + '.json'
        raw_image = cv2.imread(fp, -1)
        with open(mf, 'r') as file:
            data = json.load(file)
        ############ Black & white level normalisation ##########
        time_BL_S = time.time()
        # Black level 256, white level 4095 -> map to [0, 1] and clip.
        raw_image = (raw_image.astype(np.float32) - 256.)
        raw_image = raw_image / (4095. - 256.)
        raw_image = np.clip(raw_image, 0.0, 1.0)
        ############# Binning ############################
        raw_image = binning(raw_image, data)
        ############# Down sample ###########################
        raw_image = cv2.resize(raw_image, (1024, 768))  # dsize is (width, height)
        ############ Raw denoise ##########################
        Temp_I = Four2One(raw_image)
        Temp_I = transformo(Temp_I).unsqueeze(0).cuda()
        Temp_I = dn_model(Temp_I)
        Temp_I = np.asarray(Temp_I.squeeze(0).squeeze(0).cpu())
        raw_image = One2Four(Temp_I)
        ############# White balance, colour matrix, gamma #########
        raw_image = white_balance(raw_image, data['as_shot_neutral'])
        raw_image = apply_color_space_transform(raw_image, color_matrix)
        raw_image = transform_xyz_to_srgb(raw_image)
        raw_image = apply_gamma(raw_image)
        ############# Refinement #############################
        Source = transform(raw_image).unsqueeze(0).float().cuda()
        Out = model(Source)
        ################# Saving #############################
        Out = Out.clip(0, 1)
        # CHW float tensor -> HWC uint8 image.
        OA = np.asarray(Out.squeeze(0).cpu()).transpose(1, 2, 0).astype(np.float32)
        OA = (OA * 255.).astype(np.uint8)
        OA = fix_orientation(OA, data["orientation"])
        time_Save_F = time.time()
        OA = cv2.cvtColor(OA, cv2.COLOR_RGB2BGR)
        # Don't rebind OA to imwrite's boolean return value.
        cv2.imwrite('./Output/' + str(os.path.basename(fp)), OA)
        infer_times.append(time_Save_F - time_BL_S)
# Guard against an empty input set: np.mean([]) is nan and warns.
if infer_times:
    print(f"Average inference time: {np.mean(infer_times)} seconds")
else:
    print(f"No PNG images found in {Rpath}; nothing processed.")