Spaces:
Sleeping
Sleeping
import os | |
import time | |
import spaces | |
import cv2 | |
import gradio as gr | |
import torch | |
from basicsr.archs.srvgg_arch import SRVGGNetCompact | |
from realesrgan.utils import RealESRGANer | |
from basicsr.archs.rrdbnet_arch import RRDBNet | |
# Log the installed package versions to the Space's build/run log.
os.system("pip freeze")

# Loaded upsamplers keyed by model name; `enhance` looks models up here.
upsampler_map = {}

# Half precision is only enabled when a CUDA device is available.
half = torch.cuda.is_available()


def _ensure_weights(url):
    """Download the weights file at *url* into the working directory if missing.

    The file is first written to ``<name>.part`` and renamed only after the
    download completes, so an interrupted download is never mistaken for a
    valid weights file by the existence check on the next start. Download
    errors propagate (``urllib.error.URLError`` / ``HTTPError``) instead of
    being silently ignored.

    Args:
        url: Direct download URL; its basename becomes the local filename.

    Returns:
        The local filename (relative to the current working directory).
    """
    import urllib.request

    filename = os.path.basename(url)
    if not os.path.exists(filename):
        partial_path = filename + '.part'
        urllib.request.urlretrieve(url, partial_path)
        os.replace(partial_path, filename)
    return filename


# General-purpose x4 model (compact SRVGG architecture).
model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
model_path = _ensure_weights('https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth')
upsampler_map['RealESR-General-x4v3'] = RealESRGANer(scale=4, model_path=model_path, model=model, tile=0, tile_pad=10, pre_pad=0, half=half)

# x2 RRDBNet model.
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
model_path = _ensure_weights('https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth')
upsampler_map['RealESRGAN_x2plus'] = RealESRGANer(scale=2, model_path=model_path, model=model, tile=0, tile_pad=10, pre_pad=0, half=half)

# `enhance` writes its results under this directory.
os.makedirs('output', exist_ok=True)
def enhance(
    img_path: str,
    version: str = 'x4v3',
    scale: float = 2,
):
    """Upscale one image with the selected Real-ESRGAN model.

    Args:
        img_path: Filesystem path of the uploaded image (the input component
            uses ``type="filepath"``); ``None`` if nothing was uploaded.
        version: UI key of the model: ``'x4v3'``, ``'x4plus'``, ``'x2plus'``
            or ``'net_x4plus'``. Only versions whose upsampler is present in
            ``upsampler_map`` can actually be used.
        scale: Rescaling factor forwarded as ``outscale`` to
            ``RealESRGANer.enhance``. Must be positive.

    Returns:
        ``(image, save_path, time_cost_str)``: the result as an RGB array
        (RGBA when the input had an alpha channel) for display, the path the
        result was written to (``output/out.png`` for images with alpha,
        ``output/out.jpg`` otherwise), and the per-step timing trace built by
        ``get_time_cost``.

    Raises:
        gr.Error: with a user-facing message when no image was provided, the
            scale is not positive, the selected model is not loaded, the file
            cannot be decoded, or the result cannot be written.
    """
    run_task_time, time_cost_str = get_time_cost(0, '')

    if img_path is None:
        raise gr.Error('Please upload an input image.')
    if scale is None or scale <= 0:
        raise gr.Error('Rescaling factor must be a positive number.')

    # UI version key -> key the upsampler is registered under in upsampler_map.
    model_keys = {
        'x4v3': 'RealESR-General-x4v3',
        'x4plus': 'RealESRGAN_x4plus',
        'x2plus': 'RealESRGAN_x2plus',
        'net_x4plus': 'RealESRNet_x4plus',
    }
    upsampler = upsampler_map.get(model_keys.get(version))
    if upsampler is None:
        # Unknown version, or a version whose model was not registered at startup.
        raise gr.Error(f'Model version {version!r} is not available in this demo.')

    img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
    if img is None:
        raise gr.Error('Could not read the input image.')

    has_alpha = len(img.shape) == 3 and img.shape[2] == 4
    if len(img.shape) == 2:
        # Grayscale input: expand to 3 channels.
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

    h, w = img.shape[0:2]
    if h < 300:
        # Inputs shorter than 300 px get a 2x Lanczos pre-upscale before the model runs.
        img = cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)

    output, _ = upsampler.enhance(img, outscale=scale)

    # Images with alpha are saved as PNG; JPEG cannot store an alpha channel.
    extension = 'png' if has_alpha else 'jpg'
    save_path = f'output/out.{extension}'
    if not cv2.imwrite(save_path, output):
        raise gr.Error('Failed to save the output image.')

    # OpenCV arrays are BGR(A); the Gradio image component expects RGB(A).
    # Keep a 4th channel when present rather than dropping it with BGR2RGB.
    if len(output.shape) == 3 and output.shape[2] == 4:
        output = cv2.cvtColor(output, cv2.COLOR_BGRA2RGBA)
    else:
        output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)

    run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
    return output, save_path, time_cost_str
def get_time_cost(run_task_time, time_cost_str):
    """Take a millisecond timestamp and extend a step-timing trace.

    Pass ``run_task_time=0`` to start a new trace, which is reset to
    ``'start'``. On later calls pass back the timestamp returned by the
    previous call. The milliseconds elapsed since then are appended to the
    trace, preceded by ``'-->'`` when the trace is non-empty.

    Args:
        run_task_time: Timestamp in ms (from ``time.time()``) returned by the
            previous call, or ``0`` to begin a new trace.
        time_cost_str: The trace string built so far.

    Returns:
        ``(now_ms, trace)``: the current wall-clock time in milliseconds and
        the updated trace string.
    """
    current_ms = int(time.time() * 1000)
    if run_task_time == 0:
        return current_ms, 'start'
    separator = '-->' if time_cost_str != '' else ''
    return current_ms, time_cost_str + separator + str(current_ms - run_task_time)
def create_demo() -> gr.Blocks:
    """Build the Gradio interface for the upscaler.

    The first row holds the model/scale controls and the Enhance button.
    The second row places the input image beside the outputs: the restored
    image, a downloadable file and the per-step timing. Clicking the button
    runs ``enhance``.

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` app.
    """
    with gr.Blocks() as demo:
        # Controls: model choice and rescale factor on the left, trigger on the right.
        with gr.Row():
            with gr.Column():
                version_choice = gr.Radio(
                    ['x4v3', 'x4plus', 'x2plus', 'net_x4plus'],
                    label="Version",
                    value='x4v3',
                )
                scale_input = gr.Number(label="Rescaling factor", value=2)
            with gr.Column():
                enhance_button = gr.Button("Enhance")

        # I/O: the upload on the left, every result on the right.
        with gr.Row():
            with gr.Column():
                source_image = gr.Image(label="Input Image", type="filepath")
            with gr.Column():
                result_image = gr.Image(label="Restored Image", type="numpy", interactive=False)
                result_file = gr.File(label="Download the output image", interactive=False)
                timing_box = gr.Textbox(label="Time cost by step (ms):", visible=True, interactive=False)

        # Input/output order must match enhance(img_path, version, scale)
        # and its (image, save_path, time_cost_str) return tuple.
        enhance_button.click(
            fn=enhance,
            inputs=[source_image, version_choice, scale_input],
            outputs=[result_image, result_file, timing_box],
        )
    return demo