# NPRC24 / MiAlgo / main_test.py
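# Night raw-to-JPEG rendering pipeline:
#   load raw + metadata -> pre-process (pack / normalize / lens shading)
#   -> raw denoise -> FC4 white-balance ensemble -> RGB conversion
#   -> NN enhancement -> saturation post-processing -> sky enhancement
#   -> autocontrast -> save JPEG.
# Stage functions mutate each image's meta dict in place and also return it,
# so they work both in-process and through multiprocessing.Pool.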
import os
import multiprocessing as mp
import numpy as np
import cv2
import cv2 as cv  # both the cv and cv2 aliases are used below
from tqdm import tqdm
from glob import glob
from utils import *
from grayness_index import GraynessIndex
import torch
import torch.nn.functional as F
from time import time
from network_raw_denoise import sc_net_1f
from network import MWRCANv4 as NET
from classes.fc4.ModelFC4 import ModelFC4
def load_img(img_path):
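    """Load one raw PNG and its JSON metadata sidecar into a meta dict."""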
meta_all = {}
meta_all['img_path'] = img_path
# load meta
metadata = json_read(img_path.replace(".png", ".json"), object_hook=fraction_from_json)
meta_all['meta'] = metadata
# load image
img = cv.imread(img_path, cv.IMREAD_UNCHANGED)
meta_all['img'] = img
return meta_all
def pre_process(meta_all):
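    """Pack the Bayer raw into 4 channels, resize to 768x1024, normalize by
    black/white level, apply lens shading correction to low-noise frames,
    and record the metadata (as-shot) white-balance gains."""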
img = meta_all['img']
metadata = meta_all['meta']
cfa_pattern = metadata['cfa_pattern']
cfa_pattern_ = ""
for tt in cfa_pattern:
if tt == 0:
cfa_pattern_ += "r"
elif tt == 1:
cfa_pattern_ += "g"
elif tt == 2:
cfa_pattern_ += "b"
        else:
            raise ValueError("unexpected CFA pattern value: {}".format(tt))
offsets = bayer_to_offsets(cfa_pattern_)
img = pack_raw_to_4ch(img, offsets)
    if img.shape[:2] != (768, 1024):
        img = cv.resize(img, (1024, 768), interpolation=cv.INTER_AREA)
bl_fix = np.clip((float(metadata["noise_profile"][0])-0.005) * 1000, 0, 10)
img = normalize(img, metadata['black_level'], metadata['white_level'], bl_fix).astype(np.float32)
noise_profile = float(metadata["noise_profile"][0])
noise_list = [0.00025822882, 0.000580020745, 0.00141667975, 0.00278965863, 0.00347614807]
if noise_profile < 0.005:
if noise_profile < noise_list[0]:
weight1 = noise_profile / noise_list[0]
final_lsc = lsc_npy[0] * weight1
linear_idx1, linear_idx2 = 0, 0
elif noise_profile > noise_list[-1]:
final_lsc = lsc_npy[-1]
linear_idx1, linear_idx2 = -1, -1
else:
for idx, nn in enumerate(noise_list):
if noise_profile < nn:
linear_idx1 = idx - 1
linear_idx2 = idx
break
            # linear interpolation between the two neighbouring shading maps
            weight2 = (noise_profile - noise_list[linear_idx1]) / (noise_list[linear_idx2] - noise_list[linear_idx1])
            weight1 = 1 - weight2
            final_lsc = lsc_npy[linear_idx1] * weight1 + lsc_npy[linear_idx2] * weight2
ones = np.ones_like(final_lsc)
final_lsc = final_lsc * 0.6 + ones * 0.4
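        # mirror the right half of the shading map onto the left half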
final_lsc[:, :512, :] = final_lsc[:, 1024:511:-1, :]
img = img * final_lsc
img = np.clip(img, 0.0, 1.0)
meta_all["img"] = img
rgb_gain = metadata['as_shot_neutral']
ra, ga, ba = rgb_gain
ra, ga, ba = 1/ra, 1/ga, 1/ba
meta_all['r_gains'] = [ra]
meta_all['g_gains'] = [ga]
meta_all['b_gains'] = [ba]
return meta_all
def raw_denoise(results):
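    """Denoise the packed raw with sc_net_1f; the network predicts a noise
    residual that is added back onto the unclipped input."""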
checkpoint_path = "checkpoint/raw_denoise.pth"
device = torch.device("cuda")
model = get_net(sc_net_1f, checkpoint_path, device)
for meta_all in tqdm(results):
img = meta_all['img']
img = np.expand_dims(img, axis=0)
ori_inp = img.copy()
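        # clip strong highlights for the network input; the predicted noise
        # residual is added back onto the unclipped original below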
clip_min = max(np.mean(img)*3, 0.9)
img = np.clip(img, 0, clip_min)
img = torch.from_numpy(img.transpose(0, 3, 1, 2)).cuda()
with torch.no_grad():
output = model(img)
output = output.detach().cpu().numpy().transpose(0, 2, 3, 1)
img = ori_inp + output
img = np.clip(img, 0, 1)
img = np.squeeze(img)
meta_all['img'] = img
def predict_white_balance(results):
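    """Estimate the illuminant with a 3-fold FC4 ensemble and append one set
    of per-channel white-balance gains per fold to each image."""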
model = ModelFC4()
for model_index in [0, 1, 2]:
path_to_pretrained = os.path.join("./trained_models", "fc4_cwp", "fold_{}".format(model_index))
model.load(path_to_pretrained)
model.evaluation_mode()
for meta_all in tqdm(results):
img = meta_all['img'].copy()
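            # average the two green channels and drop the extra one -> 3-channel RGB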
img[:, :, 1] = (img[:, :, 1] + img[:, :, 3]) / 2
img = img[:, :, :-1]
img = torch.from_numpy(img.transpose(2, 0, 1)).unsqueeze(0).cuda()
img_tmp = torch.pow(img, 1.0 / 2.2)
with torch.no_grad():
pred = model.predict(img_tmp, return_steps=False)
pred = pred.detach().cpu().squeeze(0).numpy()
            # convert the illuminant estimate to WB gains: normalize by green, then invert
            r, g, b = pred
            r, g, b = g / r, 1.0, g / b
meta_all['r_gains'].append(r)
meta_all['g_gains'].append(g)
meta_all['b_gains'].append(b)
def convert_to_rgb(meta_all):
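    """Convert packed raw to display RGB: white balance, color correction,
    sRGB conversion, tone mapping, gamma, contrast enhancement, and a final
    grayness-index AWB touch-up."""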
img = meta_all['img']
img[:, :, 1] = (img[:, :, 1] + img[:, :, 3]) / 2
img = img[:, :, :-1]
# WB
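    # four candidate gains per channel: one from metadata plus three FC4 folds;
    # take trimmed means (drop the largest red gain and the smallest blue gain)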
r_gains = sorted(meta_all['r_gains'])
b_gains = sorted(meta_all['b_gains'])
r_final = (r_gains[0] + r_gains[1] + r_gains[2]) / 3
g_final = 1
b_final = (b_gains[1] + b_gains[2] + b_gains[3]) / 3
img[:, :, 0] *= r_final
img[:, :, 1] *= g_final
img[:, :, 2] *= b_final
img = np.clip(img, 0, 1)
# CC
img = apply_color_space_transform(img, color_matrix, color_matrix)
# convert RGB
img = transform_xyz_to_srgb(img)
# shading fix
if float(meta_all['meta']["noise_profile"][0]) > 0.005:
lsc_m = lsc ** ((float(meta_all['meta']["noise_profile"][0])-0.005) * 100)
lsc_inv = 1 / lsc
lsc_inv = np.mean(lsc_inv, axis=-1, keepdims=True)
gray = cv.cvtColor(img.astype(np.float32), cv.COLOR_RGB2GRAY)
gray = gray[:, :, np.newaxis]
lsc_inv = lsc_inv * np.clip(gray*10, 0, 1) * np.clip((2 - (float(meta_all['meta']["noise_profile"][0])-0.005) * 100), 1, 2)
lsc_inv = np.clip(lsc_inv, 0.4, 1)
img = img * lsc_inv + gray * (1-lsc_inv)
img = img / lsc_m
# tonemaping
img = apply_tone_map(img)
# gamma
img = apply_gamma(img).astype(np.float32)
img = np.clip(img, 0, 1)
# contrast enhancement
mm = np.mean(img)
meta_all['mm'] = mm
    if mm <= 0.1:
        pass  # very dark image: skip contrast enhancement
elif float(meta_all['meta']["noise_profile"][0]) > 0.01:
        yuv = cv.cvtColor(img, cv.COLOR_RGB2YUV)  # img is RGB at this point
        y, u, v = cv.split(yuv)
        y = autocontrast_using_pil(y)
        yuv = np.stack([y, u, v], axis=-1)
        rgb = cv.cvtColor(yuv, cv.COLOR_YUV2RGB)
rgb = np.clip(rgb, 0, 1)
img = img * 0.5 + rgb * 0.5
img = np.clip(img*255, 0, 255).round().astype(np.uint8)
if float(meta_all['meta']["noise_profile"][0]) > 0.02:
noise_params = 6
else:
noise_params = 3
img = cv.fastNlMeansDenoisingColored(img, None, noise_params, noise_params, 7, 21)
img = img.astype(np.float32) / 255.
img = usm_sharp(img)
else:
img = autocontrast_using_pil(img)
    # selective brightening: blend in a gamma-lifted copy, weighted by max-channel brightness
img = np.clip(img, 0, 1)
img_con = img ** (1/1.5)
gray = np.max(img_con, axis=-1, keepdims=True) # - 0.1
gray = np.clip(gray, 0.3, 1)
img = img_con * gray + img * (1-gray)
# AWB again
img = img[:, :, ::-1] # BGR
gi = GraynessIndex()
pred_illum = gi.apply(img)
    pred_illum = pred_illum / pred_illum[1]  # normalize by green
    r, g, b = pred_illum
if r < 1:
img = white_balance(img, pred_illum) # BGR
img = img[:, :, ::-1]
img = np.clip(img, 0, 1) # RGB
# fix orientation
img = fix_orientation(img, meta_all['meta']["orientation"])
meta_all['img'] = img # RGB
return meta_all
def nn_enhancement(results):
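    """Run the MWRCANv4 enhancement network on float RGB in [0, 1] and
    convert the result to uint8 RGB."""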
checkpoint_path1 = "checkpoint/nn_enhance.pth"
device = torch.device("cuda")
model = get_net(NET, checkpoint_path1, device)
for meta_all in tqdm(results):
# mm = meta_all['mm']
# if mm <= 0.1 or float(meta_all['meta']["noise_profile"][0]) > 0.01:
# meta_all['img'] = meta_all['img'] * 255
# continue
img = meta_all['img']
img = img.astype(np.float32)
img = torch.from_numpy(img.copy().transpose(2, 0, 1)).unsqueeze(0).to(device)
with torch.no_grad():
img = model(img)
img = img.detach().cpu().squeeze(0).numpy().transpose(1, 2, 0)
img = np.clip(img, 0, 1)
img = img * 255.
img = img.round()
img = img.astype(np.uint8)
meta_all['img'] = img # RGB U8
def post_process(meta_all):
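    """Boost global saturation, then selectively desaturate green/cyan and
    purple/magenta hue ranges in HLS space."""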
# color fix
img = meta_all['img'] # RGB U8
# increase saturation
    increment = 0.5  # saturation boost strength
ori_img = img.copy() # RGB U8
hls = cv2.cvtColor(img, cv2.COLOR_RGB2HLS).astype(np.float32)
_, L, S = cv2.split(hls)
S = S / 255.
img = img.astype(np.float32)
temp = increment + S
    mask_2 = temp > 1  # pixels whose boosted saturation would exceed 1
alpha_1 = S
alpha_2 = 1 - increment
alpha = alpha_1 * mask_2 + alpha_2 * (1 - mask_2)
L = L[:, :, np.newaxis]
alpha = alpha[:, :, np.newaxis]
alpha = 1/alpha -1
img = img + (img - L) * alpha
img = np.clip(img, 0, 255)
ori_img = ori_img.astype(np.float32)
mask = ori_img[:, :, 2] / 255.
# mask = np.max(ori_img, axis=-1) / 255.
mask = mask[:, :, np.newaxis]
mask = np.clip(mask - 0.1, 0, 1)
img = img * mask + ori_img * (1-mask)
img = np.clip(img, 0, 255).round().astype(np.uint8)
    # decrease saturation of selected hue ranges (HLS: channel 0 = hue, 2 = saturation)
    hls = cv2.cvtColor(img, cv2.COLOR_RGB2HLS)
    hls = hls.astype(np.float32)
    # green / cyan hues
    mmax = 105
    mmin = 40
    alpha = 1  # larger = stronger effect, but more collateral damage
    beta = 4  # larger = stronger effect
    gamma = 0.1  # smaller = stronger effect
    mid = mmin + ((mmax - mmin) / 2)
    green_weight = np.abs(hls[:, :, 0] - mid) / ((mmax - mmin) / 2)
    green_weight = np.clip(green_weight, 0, 1)
    # green_weight = np.tanh(green_weight/alpha)
    green_weight = green_weight**beta + gamma
    green_weight = np.clip(green_weight, 0, 1)
    green_weight = cv2.blur(green_weight, (11, 11))
    hls[:, :, 2] = hls[:, :, 2] * green_weight
    # purple / magenta hues
    mmax = 180
    mmin = 130
    alpha = 1  # larger = stronger effect, but more collateral damage
    beta = 8  # larger = stronger effect
    gamma = -0.5  # smaller = stronger effect
    mid = mmin + ((mmax - mmin) / 2)
    green_weight = np.abs(hls[:, :, 0] - mid) / ((mmax - mmin) / 2)
    green_weight = np.clip(green_weight, 0, 1)
    # green_weight = np.tanh(green_weight/alpha)
    green_weight = (green_weight**beta + gamma) * 2
    green_weight = np.clip(green_weight, 0, 1)
    green_weight = cv2.blur(green_weight, (11, 11))
    hls[:, :, 2] = hls[:, :, 2] * green_weight
    hls = np.clip(hls, 0, 255)
    hls = hls.round().astype(np.uint8)
    img = cv2.cvtColor(hls, cv2.COLOR_HLS2RGB)  # RGB U8
meta_all['img'] = img # RGB U8
return meta_all
def sky_enhancement(results):
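    """For low-noise frames, segment the sky, relight/desaturate gray skies
    or cool down blue skies in LAB, then pyramid-blend the adjusted sky back
    into the scene."""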
model_path = "sky_seg.pt"
model = torch.load(model_path)
model.cuda()
model.eval()
for meta_all in tqdm(results):
if float(meta_all['meta']["noise_profile"][0]) >= 0.005:
continue
ori_img = meta_all['img'].copy().astype(np.float32) # RGB 0-255 U8
img = ori_img.copy()
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        # sky segmentation
scene_image = img.copy().astype(np.float32) # 0-255, bgr
img = img / 255.
lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
        # cool down the color temperature in LAB
        lab[:, :, 1] = lab[:, :, 1] - (lab[:, :, 2] + 127) * 0.03
        lab[:, :, 2] = lab[:, :, 2] - (lab[:, :, 2] + 127) * 0.1
        # convert back from LAB
img = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)
img = img * 255
img = np.clip(img, 0, 255).round().astype(np.float32)
img_mean = 0
img_std = 255.0
size = (512, 512)
        img_h, img_w = img.shape[:2]
img = cv2.resize(img, size)
img = (img - img_mean) / img_std
img = np.transpose(img, [2, 0, 1])
img = np.expand_dims(img, axis=0)
img = torch.from_numpy(img).cuda()
with torch.no_grad():
mask = model(img)
        mask = mask.detach().cpu()
        mask = mask.permute((0, 3, 1, 2))
        mask = F.interpolate(mask, size=[img_h, img_w], mode='bilinear')
        mask = mask[0].permute((1, 2, 0))
        sky_mask = torch.argmax(mask, dim=2).numpy().astype(np.float32)
        if sky_mask.max() < 0.1:
            continue  # no sky detected
        # blueness of the sky region: mean of (B - max(R, G)) over sky pixels
        img = ori_img.copy()  # RGB
        mask = img[:, :, 2] - np.max(img[:, :, :2], axis=-1)
        mask[sky_mask == 0] = 0
        a = np.sum(mask)
        b = np.sum(sky_mask)
        ratio_blue = a / b
# print(meta_all['img_path'], "blue ratio", ratio_blue)
        # not a blue sky
        if ratio_blue < 10:
            # brightness of the sky region: mean of R and G over sky pixels
            img = ori_img.copy()
            mask = np.mean(img[:, :, :2], axis=-1)
            mask[sky_mask == 0] = 0
            a = np.sum(mask)
            b = np.sum(sky_mask)
            ratio_light = a / b
            # print(meta_all['img_path'], "light ratio", ratio_light)
            # dark sky: darken it slightly
            if ratio_light < 50:
img = ori_img.copy()
img = img * 0.88
img = np.clip(img, 0, 255) # RGB
            # medium brightness: brighten slightly
            elif ratio_light < 200:
img = ori_img.copy()
img = img * 1.1
img = np.clip(img, 0, 255) # RGB
else:
pass
            hsv = cv2.cvtColor(img.astype(np.uint8), cv2.COLOR_RGB2HSV).astype(np.float32)
            hsv[:, :, 1] = hsv[:, :, 1] * 0.4  # desaturate the gray sky
hsv = np.clip(hsv, 0, 255).astype(np.uint8)
img = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB).astype(np.float32)
        # blue sky
else:
# LAB
img = ori_img.copy()
img = img / 255.
lab = cv2.cvtColor(img, cv2.COLOR_RGB2LAB)
            # cool down the color temperature in LAB
            lab[:, :, 1] = lab[:, :, 1] - (lab[:, :, 2] + 127) * 0.03
            lab[:, :, 2] = lab[:, :, 2] - (lab[:, :, 2] + 127) * 0.1
            # convert back from LAB to RGB
img = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)
img = img * 255
img = np.clip(img, 0, 255).round().astype(np.float32)
sky_image = img.copy().astype(np.float32) # 0-255, RGB
sky_image = cv2.cvtColor(sky_image, cv2.COLOR_RGB2BGR) # BGR 0-255 F32
sky_mask_ori = sky_mask.copy()
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (11, 11))
sky_mask_ori = cv2.erode(sky_mask_ori, kernel)
sky_mask_ori = sky_mask_ori > 0.9
if np.sum(sky_mask_ori) > 0:
h, w = sky_mask.shape
sky_mask = cv2.resize(sky_mask, None, fx=0.1, fy=0.1, interpolation=cv2.INTER_NEAREST)
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
sky_mask = cv2.dilate(sky_mask, kernel)
sky_mask_blur = cv2.blur(sky_mask, (21, 21))
sky_mask_blur[sky_mask>0.5] = sky_mask[sky_mask>0.5]
sky_mask = sky_mask_blur
sky_mask = cv2.resize(sky_mask, (w, h), interpolation=cv2.INTER_LINEAR)
sky_mask = np.clip(sky_mask, 0.1, 1)
sky_area_img = np.zeros_like(sky_image)
sky_area_img[sky_mask_ori] = sky_image[sky_mask_ori]
sky_area_img = cv2.cvtColor(sky_area_img, cv2.COLOR_BGR2GRAY)
sky_area_img_mean = np.sum(sky_area_img) / np.sum(sky_mask_ori)
if sky_area_img_mean > 20:
res = pyrblend(scene_image, sky_image, sky_mask)
res = np.clip(res, 0, 255) # 0-255, bgr
res = res.round().astype(np.uint8)
res = cv2.cvtColor(res, cv2.COLOR_BGR2RGB) # RGB 0-255 U8
meta_all['img'] = res
def post_process2(meta_all):
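    """Apply PIL-style autocontrast to the luma channel, blended back in
    with a per-pixel brightness mask."""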
img = meta_all['img'].copy() # RGB U8
img = img.astype(np.float32) / 255.
yuv = cv.cvtColor(img, cv.COLOR_RGB2YUV)
y, u, v = cv.split(yuv)
y = autocontrast_using_pil(y)
yuv = np.stack([y, u, v], axis=-1)
rgb = cv.cvtColor(yuv, cv.COLOR_YUV2RGB)
rgb = np.clip(rgb, 0, 1)
img = rgb
    img = np.clip(img * 255, 0, 255)  # float RGB in [0, 255]
ori_img = meta_all['img'].copy().astype(np.float32)
mask = np.mean(ori_img, axis=-1) / 255.
mask = mask[:, :, np.newaxis]
mask = np.clip(mask - 0.1, 0, 1)
img = img * mask + ori_img * (1-mask)
img = np.clip(img, 0, 255)
img = img.round().astype(np.uint8)
meta_all['img'] = img
return meta_all
def save_jpg(meta_all):
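    """Write the final RGB image as a quality-100 JPEG into output_path."""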
img = meta_all['img'] # RGB U8
out_path = os.path.join(output_path, meta_all['img_path'].split("/")[-1].split(".")[0] + ".jpg")
cv.imwrite(out_path, img[:, :, ::-1], [cv.IMWRITE_JPEG_QUALITY, 100])
if __name__ == "__main__":
num_worker = 4
all_time = time()
input_path = "/data"
output_path = "/data"
# input_path = "/ssd/ntire24/nightrender/data/data"
# output_path = "/ssd/ntire24/nightrender/data/data"
os.makedirs(output_path, exist_ok=True)
# load img
s_time = time()
    input_list = sorted(glob(os.path.join(input_path, "*.png")))
if num_worker > 1:
with mp.Pool(num_worker) as pool:
results = list(tqdm(pool.imap(load_img, input_list), total=len(input_list)))
else:
results = []
for p in tqdm(input_list):
results.append(load_img(p))
load_time = time()-s_time
print("load_img time is: ", load_time)
# preprocess
s_time = time()
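    # per-ISO lens shading maps; list indices align with noise_list in pre_process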
iso_list = [50, 125, 320, 640, 800]
lsc_npy = [np.load("./lsc_npy/{}.npy".format(iso)) for iso in iso_list]
if num_worker > 1:
with mp.Pool(num_worker) as pool:
results = list(tqdm(pool.imap(pre_process, results), total=len(results)))
else:
        for r in tqdm(results):
            pre_process(r)  # mutates the dict in place
del lsc_npy
print("pre_process time is: ", time()-s_time)
# raw denoise
s_time = time()
raw_denoise(results)
print("raw_denoise time is: ", time()-s_time)
# awb
s_time = time()
predict_white_balance(results)
print("predict_white_balance time is: ", time()-s_time)
# convert
s_time = time()
    # flattened 3x3 color correction matrix (applied in convert_to_rgb)
    color_matrix = [1.06835938, -0.29882812, -0.14257812,
                    -0.43164062, 1.35546875, 0.05078125,
                    -0.1015625, 0.24414062, 0.5859375]
lsc = np.load("lsc.npy")
if num_worker > 1:
with mp.Pool(num_worker) as pool:
results = list(tqdm(pool.imap(convert_to_rgb, results), total=len(results)))
else:
        for r in tqdm(results):
            convert_to_rgb(r)  # mutates the dict in place
del lsc
print("convert_to_rgb time is: ", time()-s_time)
# NN_enhancement
s_time = time()
nn_enhancement(results)
print("nn_enhancement time is: ", time()-s_time)
# colorfix & sat enhance
s_time = time()
if num_worker > 1:
with mp.Pool(num_worker) as pool:
results = list(tqdm(pool.imap(post_process, results), total=len(results)))
else:
        for r in tqdm(results):
            post_process(r)  # mutates the dict in place
print("post_process time is: ", time()-s_time)
# sky_enhancement
s_time = time()
sky_enhancement(results)
print("sky_enhancement time is: ", time()-s_time)
# PIL autocontrast
s_time = time()
if num_worker > 1:
with mp.Pool(num_worker) as pool:
results = list(tqdm(pool.imap(post_process2, results), total=len(results)))
else:
        for r in tqdm(results):
            post_process2(r)  # mutates the dict in place
print("post_process2 time is: ", time()-s_time)
# save jpg
s_time = time()
if num_worker > 1:
with mp.Pool(num_worker) as pool:
_ = list(tqdm(pool.imap(save_jpg, results), total=len(results)))
else:
for r in tqdm(results):
save_jpg(r)
save_time = time()-s_time
print("save_jpg time is: ", save_time)
total_time = time()-all_time
total_time_without_load_save = total_time - load_time - save_time
print("per image inference time (without load and save) is: ", total_time_without_load_save / len(results), "s")