Spaces:
Running
Running
import os | |
import gradio as gr | |
import torch | |
from PIL import Image | |
from model import MangaColorizer | |
from utils import pil_to_torch, torch_to_pil | |
def load_html_template():
    """Return the text of ``templates/index.html`` located next to this file.

    The template is always decoded as UTF-8 so the result does not depend on
    the platform's default locale encoding.

    Returns:
        The template contents as a string, or ``None`` when the template
        cannot be found (an error message is printed in that case).
    """
    index_html_path = os.path.join(
        os.path.dirname(__file__), "templates", "index.html"
    )
    try:
        # Open directly instead of checking for existence first: avoids the
        # window in which the file could vanish between the check and the open.
        with open(index_html_path, "r", encoding="utf-8") as html_file:
            return html_file.read()
    except (FileNotFoundError, NotADirectoryError):
        print(f"Error: {index_html_path} not found.")
        return None
def load_model():
    """Build a ``MangaColorizer`` and load its checkpoint for inference.

    Weights are read from ``../model/best_model_checkpoint.pth`` (relative to
    this file) and mapped onto the CPU. When the checkpoint file is missing,
    an error is printed and the model keeps the weights its constructor
    produced.

    Returns:
        The ``MangaColorizer`` instance, switched to evaluation mode.
    """
    colorizer = MangaColorizer()
    model_file = os.path.join(
        os.path.dirname(__file__), "..", "model", "best_model_checkpoint.pth"
    )
    if os.path.exists(model_file):
        # NOTE(review): torch.load unpickles the file, so only trusted
        # checkpoints should ever be placed at this path.
        with open(model_file, "rb") as f:
            checkpoint = torch.load(f, map_location="cpu")
        colorizer.load_state_dict(checkpoint)
    else:
        print(f"Error: {model_file} not found.")
    # The app only runs inference, so put layers such as dropout and batch
    # norm into evaluation behavior instead of the default training mode.
    colorizer.eval()
    return colorizer


# Loaded once at import time and shared by every colorize_image() call.
model = load_model()
def colorize_image(image):
    """Colorize an image with the module-level ``model``.

    The input is converted to a single-channel ("L") PIL image, passed through
    ``pil_to_torch``, the model, and ``torch_to_pil``.

    Args:
        image: Array data accepted by ``PIL.Image.fromarray``, or ``None``
            when no image was supplied.

    Returns:
        The value produced by ``torch_to_pil`` for the model output, or
        ``None`` when ``image`` is ``None``.
    """
    # Nothing to colorize (e.g. the form was submitted without an image).
    if image is None:
        return None

    grayscale = Image.fromarray(image).convert("L")
    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        output = model(pil_to_torch(grayscale)).detach().cpu()
    return torch_to_pil(output)
def main():
    """Assemble the Gradio page and start serving it.

    The page consists of the HTML template returned by
    ``load_html_template``, an image-to-image interface backed by
    ``colorize_image`` (flagging disabled), and a footer link to the
    project's GitHub repository.
    """
    header_markup = load_html_template()
    footer_markup = """
    <p style="text-align: center;font-size: large;">
    Checkout the <a href="https://github.com/zaidmehdi/manga-colorizer">Github Repo</a>
    </p>
    """

    with gr.Blocks() as demo:
        # Components render in the order they are created inside the context.
        gr.HTML(header_markup)
        gr.Interface(
            colorize_image,
            inputs=["image"],
            outputs=["image"],
            allow_flagging="never",
        )
        gr.HTML(footer_markup)

    demo.launch()


# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    main()