turbo_inversion / app_gfp.py
zhiweili
test gfp
0a74034
raw
history blame
5.54 kB
import os
import time
import spaces
import cv2
import gradio as gr
import torch
from gfpgan.utils import GFPGANer
from basicsr.archs.srvgg_arch import SRVGGNetCompact
from realesrgan.utils import RealESRGANer
# Print the installed package list to the process output.
os.system("pip freeze")

# Pretrained weights in download order: local file name -> release URL.
# A file is fetched with wget only when it is not already in the working dir.
_WEIGHT_URLS = {
    'realesr-general-x4v3.pth': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth',
    'GFPGANv1.2.pth': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.2.pth',
    'GFPGANv1.3.pth': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
    'GFPGANv1.4.pth': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth',
    'RestoreFormer.pth': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth',
    'CodeFormer.pth': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/CodeFormer.pth',
}
for _weight_file, _weight_url in _WEIGHT_URLS.items():
    if not os.path.exists(_weight_file):
        os.system(f"wget {_weight_url} -P .")

# Background enhancer with RealESRGAN; half precision is used only when CUDA
# is available.
model_path = 'realesr-general-x4v3.pth'
model = SRVGGNetCompact(
    num_in_ch=3,
    num_out_ch=3,
    num_feat=64,
    num_conv=32,
    upscale=4,
    act_type='prelu',
)
half = torch.cuda.is_available()
upsampler = RealESRGANer(
    scale=4,
    model_path=model_path,
    model=model,
    tile=0,
    tile_pad=10,
    pre_pad=0,
    half=half,
)
# NOTE(review): the upsampler built above is discarded right away, so
# `enhance` passes bg_upsampler=None to GFPGANer. Presumably background
# upsampling was switched off on purpose -- confirm before removing either line.
upsampler = None

# Directory that receives the restored image written by `enhance`.
os.makedirs('output', exist_ok=True)
# Supported restoration models: version key -> (weights file, GFPGANer arch).
_FACE_MODELS = {
    'v1.2': ('GFPGANv1.2.pth', 'clean'),
    'v1.3': ('GFPGANv1.3.pth', 'clean'),
    'v1.4': ('GFPGANv1.4.pth', 'clean'),
    'RestoreFormer': ('RestoreFormer.pth', 'RestoreFormer'),
    'CodeFormer': ('CodeFormer.pth', 'CodeFormer'),
    'RealESR-General-x4v3': ('realesr-general-x4v3.pth', 'realesr-general'),
}


@spaces.GPU(duration=10)
def enhance(
    img_path: str,
    version: str = '1.4',
    scale: int = 2,
    upscale: int = 2,
):
    """Restore the face in an image file and save the result under ``output/``.

    Args:
        img_path: Path of the image to restore.
        version: Key of the model to use (``'v1.2'``, ``'v1.3'``, ``'v1.4'``,
            ``'RestoreFormer'``, ``'CodeFormer'`` or
            ``'RealESR-General-x4v3'``). A bare number such as ``'1.4'`` is
            accepted and treated as ``'v1.4'``.
        scale: Rescaling factor. When it differs from 2, the restored image
            is resized to ``scale / 2`` times the size of the image that was
            fed to the restorer.
        upscale: Upscale factor forwarded to ``GFPGANer``.

    Returns:
        A tuple ``(image, save_path, time_cost_str)``: the restored image
        converted to RGB channel order, the path of the file written to disk,
        and the timing string produced by ``get_time_cost``.

    Raises:
        gr.Error: If no image path is given, the version is not supported, or
            the file cannot be decoded as an image.
    """
    if not img_path:
        raise gr.Error('Please provide an input image.')

    # Accept both 'v1.4' and the bare '1.4' used as this function's default.
    model_key = version if version in _FACE_MODELS else f'v{version}'
    if model_key not in _FACE_MODELS:
        supported = ', '.join(_FACE_MODELS)
        raise gr.Error(f'Unsupported version {version!r}; choose one of: {supported}.')

    run_task_time, time_cost_str = get_time_cost(0, '')

    # cv2.imread signals failure by returning None rather than raising.
    img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
    if img is None:
        raise gr.Error(f'Could not read an image from {img_path!r}.')

    if len(img.shape) == 3 and img.shape[2] == 4:
        img_mode = 'RGBA'
    else:
        img_mode = None
        if len(img.shape) == 2:  # for gray inputs
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

    # Images shorter than 300 px are doubled in size before restoration.
    h, w = img.shape[0:2]
    if h < 300:
        img = cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)

    weights_path, arch = _FACE_MODELS[model_key]
    face_enhancer = GFPGANer(
        model_path=weights_path,
        upscale=upscale,
        arch=arch,
        channel_multiplier=2,
        bg_upsampler=upsampler,
    )
    _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=True, paste_back=True)

    if scale != 2:
        # Shrinking uses area interpolation; enlarging uses Lanczos.
        interpolation = cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4
        # Re-read the size: img may have been doubled above.
        h, w = img.shape[0:2]
        output = cv2.resize(output, (int(w * scale / 2), int(h * scale / 2)), interpolation=interpolation)

    # RGBA images should be saved in png format
    extension = 'png' if img_mode == 'RGBA' else 'jpg'
    save_path = f'output/out.{extension}'
    cv2.imwrite(save_path, output)

    # The file on disk is written in OpenCV's BGR order; the returned array is RGB.
    output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
    run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
    return output, save_path, time_cost_str
def get_time_cost(run_task_time, time_cost_str):
    """Advance a step timer and extend its human-readable log.

    Args:
        run_task_time: Millisecond timestamp of the previous step, or ``0``
            to start a new timing log.
        time_cost_str: Log accumulated so far.

    Returns:
        A tuple ``(now_ms, log)``. On the first call (``run_task_time == 0``)
        the log is reset to ``'start'``. Otherwise the milliseconds elapsed
        since ``run_task_time`` are appended, separated from any existing
        text by ``'-->'``.
    """
    current_ms = int(time.time() * 1000)
    if run_task_time == 0:
        return current_ms, 'start'

    elapsed_ms = current_ms - run_task_time
    separator = '-->' if time_cost_str != '' else ''
    return current_ms, time_cost_str + separator + f'{elapsed_ms}'
def create_demo() -> gr.Blocks:
    """Build the Gradio UI for the face-enhancement demo.

    Returns:
        The ``gr.Blocks`` app, with the "Enhance" button wired to ``enhance``.
    """
    with gr.Blocks() as demo:
        # Top row: model/scale controls on the left, upscale + trigger on the right.
        with gr.Row():
            with gr.Column():
                # Only the three GFPGAN versions are offered in the UI.
                version = gr.Radio(['v1.2', 'v1.3', 'v1.4'], type="value", value='v1.4', label='version')
                scale = gr.Number(label="Rescaling factor", value=2)
            with gr.Column():
                upscale = gr.Number(label="Upscale factor", value=2)
                g_btn = gr.Button("Enhance")
        # Bottom row: input image on the left, results on the right.
        with gr.Row():
            with gr.Column():
                # type="filepath" hands `enhance` a path string, which it reads with cv2.
                input_image = gr.Image(label="Input Image", type="filepath")
            with gr.Column():
                restored_image = gr.Image(label="Restored Image", type="numpy", interactive=False)
                download_path = gr.File(label="Download the output image", interactive=False)
                restored_cost = gr.Textbox(label="Time cost by step (ms):", visible=True, interactive=False)
        # Inputs and outputs follow the parameter and return order of `enhance`.
        g_btn.click(
            fn=enhance,
            inputs=[input_image, version, scale, upscale],
            outputs=[restored_image, download_path, restored_cost],
        )
    return demo