Spaces:
Running
on
T4
Running
on
T4
File size: 6,139 Bytes
9bf54b1 561c629 8f3d49d 9bf54b1 561c629 9bf54b1 561c629 720b377 561c629 9bf54b1 720b377 561c629 720b377 561c629 9bf54b1 561c629 8f3d49d 561c629 9bf54b1 561c629 d26bbd5 8f3d49d 561c629 d26bbd5 561c629 8f3d49d 561c629 72f81a3 d26bbd5 9bf54b1 720b377 9bf54b1 561c629 8f3d49d 561c629 720b377 9bf54b1 561c629 6d00ac8 561c629 d26bbd5 561c629 9bf54b1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
'''
Gradio demo (almost the same code as the one used in the Hugging Face Space)
'''
import os, sys
import cv2
import time
import datetime, pytz
import gradio as gr
import torch
import numpy as np
from torchvision.utils import save_image
# Import files from the local folder
root_path = os.path.abspath('.')
sys.path.append(root_path)
from test_code.inference import super_resolve_img
from test_code.test_utils import load_grl, load_rrdb, load_dat
APISR_RELEASE_URL = "https://github.com/Kiteretsu77/APISR/releases/download"

# Local checkpoint path -> GitHub release asset it is fetched from.
APISR_WEIGHT_URLS = {
    "pretrained/4x_APISR_RRDB_GAN_generator.pth": APISR_RELEASE_URL + "/v0.2.0/4x_APISR_RRDB_GAN_generator.pth",
    "pretrained/4x_APISR_GRL_GAN_generator.pth": APISR_RELEASE_URL + "/v0.1.0/4x_APISR_GRL_GAN_generator.pth",
    "pretrained/2x_APISR_RRDB_GAN_generator.pth": APISR_RELEASE_URL + "/v0.1.0/2x_APISR_RRDB_GAN_generator.pth",
    "pretrained/4x_APISR_DAT_GAN_generator.pth": APISR_RELEASE_URL + "/v0.3.0/4x_APISR_DAT_GAN_generator.pth",
}


def auto_download_if_needed(weight_path):
    """Fetch a known APISR checkpoint to ``weight_path`` if it is not on disk yet.

    Only the paths listed in ``APISR_WEIGHT_URLS`` can be downloaded. Any other
    missing path is left alone, and the model loader reports it.

    Args:
        weight_path: Local path of the checkpoint, e.g.
            ``"pretrained/4x_APISR_GRL_GAN_generator.pth"``.

    Raises:
        urllib.error.URLError / OSError: If the download or the final move
            fails. Previously such failures were silently ignored.
    """
    import shutil
    import tempfile
    import urllib.request

    if os.path.exists(weight_path):
        return

    url = APISR_WEIGHT_URLS.get(weight_path)
    if url is None:
        # Not a release checkpoint we know how to fetch; nothing to do here.
        return

    target_dir = os.path.dirname(weight_path) or "."
    os.makedirs(target_dir, exist_ok=True)

    # Download into a temp file beside the destination and rename only on
    # success. An interrupted download therefore never leaves a truncated
    # checkpoint that the exists() check above would later accept.
    fd, tmp_path = tempfile.mkstemp(dir=target_dir, suffix=".part")
    try:
        with os.fdopen(fd, "wb") as out_file, urllib.request.urlopen(url, timeout=60) as response:
            shutil.copyfileobj(response, out_file)
        os.replace(tmp_path, weight_path)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
def inference(img_path, model_name):
    """Super-resolve the image at ``img_path`` with the APISR model ``model_name``.

    This is the Gradio callback wired to the Submit button.

    Args:
        img_path: Filesystem path of the input image (the input component is
            ``gr.Image(type="filepath")``).
        model_name: One of ``"2xRRDB"``, ``"4xRRDB"``, ``"4xGRL"``, ``"4xDAT"``.

    Returns:
        The super-resolved image as a numpy array. It is read back with OpenCV
        and channel-swapped so Gradio receives RGB order.

    Raises:
        gr.Error: ``"We don't support such Model"`` for an unknown model name.
            Any other failure is wrapped as ``"global exception: ..."``.
    """
    import tempfile

    # model name -> (checkpoint path, loader, upscale factor)
    model_registry = {
        "4xGRL": ("pretrained/4x_APISR_GRL_GAN_generator.pth", load_grl, 4),
        "4xRRDB": ("pretrained/4x_APISR_RRDB_GAN_generator.pth", load_rrdb, 4),
        "2xRRDB": ("pretrained/2x_APISR_RRDB_GAN_generator.pth", load_rrdb, 2),
        "4xDAT": ("pretrained/4x_APISR_DAT_GAN_generator.pth", load_dat, 4),
    }
    if model_name not in model_registry:
        # Raised outside the try below so the user sees this message as-is
        # instead of it being re-wrapped as a "global exception".
        raise gr.Error("We don't support such Model")
    weight_path, loader, scale = model_registry[model_name]

    store_name = None
    try:
        weight_dtype = torch.float32

        # Load the model
        auto_download_if_needed(weight_path)
        generator = loader(weight_path, scale=scale)  # Directly use default way now
        generator = generator.to(dtype=weight_dtype)

        print("We are processing ", img_path)
        print("The time now is ", datetime.datetime.now(pytz.timezone('US/Eastern')))

        # In default, we will automatically use crop to match 4x size
        super_resolved_img = super_resolve_img(generator, img_path, output_path=None, weight_dtype=weight_dtype, downsample_threshold=720, crop_for_4x=True)

        # Round-trip through a uniquely named temporary PNG, so concurrent
        # requests cannot clobber each other's output file.
        fd, store_name = tempfile.mkstemp(suffix=".png")
        os.close(fd)
        save_image(super_resolved_img, store_name)
        outputs = cv2.imread(store_name)
        if outputs is None:
            raise RuntimeError(f"could not read back the result image {store_name}")
        # cv2.imread yields BGR; swapping the channels gives RGB for Gradio
        # (RGB2BGR and BGR2RGB perform the same swap).
        return cv2.cvtColor(outputs, cv2.COLOR_RGB2BGR)
    except Exception as error:
        raise gr.Error(f"global exception: {error}") from error
    finally:
        # Always clean up the temporary file, including on failure.
        if store_name is not None and os.path.exists(store_name):
            os.remove(store_name)
if __name__ == '__main__':
    # Header text rendered at the top of the demo page.
    MARKDOWN = """
## <p style='text-align: center'> APISR: Anime Production Inspired Real-World Anime Super-Resolution (CVPR 2024) </p>
[GitHub](https://github.com/Kiteretsu77/APISR) | [Paper](https://arxiv.org/abs/2403.01598)
APISR aims at restoring and enhancing low-quality low-resolution **anime** images and video sources with various degradations from real-world scenarios.
### Note: Due to memory restriction, all images whose short side is over 720 pixel will be downsampled to 720 pixel with the same aspect ratio. E.g., 1920x1080 -> 1280x720
### Note: Please check [Model Zoo](https://github.com/Kiteretsu77/APISR/blob/main/docs/model_zoo.md) for the description of each weight and [Here](https://imgsli.com/MjU0MjI0) for model comparisons.
### If APISR is helpful, please help star the [GitHub Repo](https://github.com/Kiteretsu77/APISR). Thanks! ###
"""

    # Model names understood by inference(), in dropdown display order.
    model_choices = [
        "2xRRDB",
        "4xRRDB",
        "4xGRL",
        "4xDAT",
    ]

    # Bundled low-resolution sample inputs, one example row per file.
    sample_files = (
        "image-00277.png",
        "image-00542.png",
        "41.png",
        "f91.jpg",
        "image-00440.png",
        "image-00164.jpg",
        "img_eva.jpeg",
        "naruto.jpg",
    )
    example_rows = [["__assets__/lr_inputs/" + fname] for fname in sample_files]

    # Gradio request queue, capped at max_size=10 pending events.
    demo = gr.Blocks().queue(max_size=10)
    with demo:
        with gr.Row():
            gr.Markdown(MARKDOWN)

        with gr.Row(elem_classes=["container"]):
            # Left column: input image, model picker and submit button.
            with gr.Column(scale=2):
                input_image = gr.Image(type="filepath", label="Input")
                model_dropdown = gr.Dropdown(model_choices, type="value", value="4xGRL", label="model")
                submit_button = gr.Button(value="Submit")
            # Right column: the super-resolved result.
            with gr.Column(scale=3):
                output_image = gr.Image(type="numpy", label="Output image")

        with gr.Row(elem_classes=["container"]):
            gr.Examples(example_rows, [input_image])

        submit_button.click(inference, inputs=[input_image, model_dropdown], outputs=[output_image])

    demo.launch()
|