IAT_enhancement / app.py
SakuraD's picture
init
cabb20f
raw
history blame
No virus
3.71 kB
import os
import torch
import torch.nn.functional as F
from torchvision.transforms import Compose, ToTensor, Scale, Normalize, ConvertImageDtype
import numpy as np
import cv2
import gradio as gr
from huggingface_hub import hf_hub_download
from model import IAT
def set_example_image(example: list) -> dict:
return gr.Image.update(value=example[0])
def dark_inference(img):
model = IAT()
checkpoint_file_path = './checkpoint/best_Epoch_lol.pth'
state_dict = torch.load(checkpoint_file_path, map_location='cpu')
model.load_state_dict(state_dict)
model.eval()
print(f'Load model from {checkpoint_file_path}')
transform = Compose([
ToTensor(),
Scale(384),
Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
ConvertImageDtype(torch.float)
])
input_img = transform(img)
print(f'Image shape: {input_img.shape}')
enhanced_img = model(input_img.unsqueeze(0))
return enhanced_img[0].permute(1, 2, 0).detach().numpy()
def exposure_inference(img):
model = IAT()
checkpoint_file_path = './checkpoint/best_Epoch_exposure.pth'
state_dict = torch.load(checkpoint_file_path, map_location='cpu')
model.load_state_dict(state_dict)
model.eval()
print(f'Load model from {checkpoint_file_path}')
transform = Compose([
ToTensor(),
Scale(384),
ConvertImageDtype(torch.float)
])
input_img = transform(img)
print(f'Image shape: {input_img.shape}')
enhanced_img = model(input_img.unsqueeze(0))
return enhanced_img[0].permute(1, 2, 0).detach().numpy()
demo = gr.Blocks()
with demo:
gr.Markdown(
"""
# IAT
Gradio demo for <a href='https://github.com/cuiziteng/Illumination-Adaptive-Transformer' target='_blank'>IAT</a>: To use it, simply upload your image, or click one of the examples to load them. Read more at the links below.
"""
)
with gr.Box():
with gr.Row():
with gr.Column():
with gr.Row():
input_image = gr.Image(label='Input Image', type='numpy')
with gr.Row():
dark_button = gr.Button('Low-light Enhancement')
with gr.Row():
exposure_button = gr.Button('Exposure Correction')
with gr.Column():
res_image = gr.Image(type='numpy', label='Resutls')
with gr.Row():
dark_example_images = gr.Dataset(
components=[input_image],
samples=[['dark_imgs/1.jpg'], ['dark_imgs/2.jpg'], ['dark_imgs/3.jpg']]
)
with gr.Row():
exposure_example_images = gr.Dataset(
components=[input_image],
samples=[['exposure_imgs/1.jpg'], ['exposure_imgs/2.jpg'], ['exposure_imgs/3.jpeg']]
)
gr.Markdown(
"""
<p style='text-align: center'><a href='https://arxiv.org/abs/2205.14871' target='_blank'>You Only Need 90K Parameters to Adapt Light: A Light Weight Transformer for Image Enhancement and Exposure Correction</a> | <a href='https://github.com/cuiziteng/Illumination-Adaptive-Transformer' target='_blank'>Github Repo</a></p>
"""
)
dark_button.click(fn=dark_inference, inputs=input_image, outputs=res_image)
exposure_button.click(fn=exposure_inference, inputs=input_image, outputs=res_image)
dark_example_images.click(fn=set_example_image, inputs=dark_example_images, outputs=dark_example_images.components)
exposure_example_images.click(fn=set_example_image, inputs=exposure_example_images, outputs=exposure_example_images.components)
demo.launch(enable_queue=True)