Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,066 Bytes
3b6fea8 32925c4 3b6fea8 32925c4 3b6fea8 32925c4 3b6fea8 ec4a1e7 3b6fea8 32925c4 3b6fea8 ec4a1e7 1af5cc8 ec4a1e7 6091f66 1af5cc8 6091f66 32925c4 6091f66 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
#!/usr/bin/env python
from __future__ import annotations
import pathlib
import sys
import cv2
import gradio as gr
import numpy as np
import spaces
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
# Put the MangaLineExtraction_PyTorch directory that sits next to this file at
# the front of sys.path. This has to happen before the import below.
# NOTE(review): model_torch is presumably provided by that directory — confirm.
current_dir = pathlib.Path(__file__).parent
submodule_dir = current_dir / "MangaLineExtraction_PyTorch"
sys.path.insert(0, submodule_dir.as_posix())
from model_torch import res_skip
# Markdown heading rendered at the top of the demo page.
DESCRIPTION = "# [MangaLineExtraction_PyTorch](https://github.com/ljsabc/MangaLineExtraction_PyTorch)"
def load_model(device: torch.device) -> nn.Module:
    """Download the pretrained weights and build the line-extraction network.

    Args:
        device: Device the returned model is moved to.

    Returns:
        A ``res_skip`` network with the ``erika.pth`` weights loaded, moved to
        ``device`` and switched to evaluation mode.
    """
    ckpt_path = hf_hub_download("public-data/MangaLineExtraction_PyTorch", "erika.pth")
    # Deserialize onto the CPU first so that a checkpoint saved from CUDA
    # tensors still loads on a CPU-only host; the model is moved to the
    # requested device afterwards.
    # NOTE(review): torch.load unpickles the file. The checkpoint comes from a
    # fixed Hub repository, but consider passing weights_only=True if the
    # installed torch version supports it.
    state_dict = torch.load(ckpt_path, map_location="cpu")
    model = res_skip()
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model
# Upper bound, in pixels, for the longer side of the grayscale image; predict()
# scales larger inputs down to this size before running the model.
MAX_SIZE = 1000
# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Loaded once at import time and reused by every predict() call.
model = load_model(device)
@spaces.GPU
@torch.inference_mode()
def predict(image: np.ndarray) -> np.ndarray:
    """Run the line-extraction model on an RGB image.

    The image is converted to grayscale, scaled down so that its longer side
    does not exceed ``MAX_SIZE``, padded up to a multiple of 16 in both
    dimensions, passed through the model, and cropped back to the
    (possibly resized) grayscale size.

    Args:
        image: RGB image array as delivered by the Gradio input component.

    Returns:
        A 2-D ``uint8`` array with the model output clipped to ``[0, 255]``.

    Raises:
        gr.Error: If no image was supplied (``image`` is ``None``).
    """
    if image is None:
        # An empty image input arrives as None; report that to the user
        # instead of failing inside OpenCV with an opaque error.
        raise gr.Error("Please provide an input image.")

    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    longest_side = max(gray.shape)
    if longest_side > MAX_SIZE:
        scale = MAX_SIZE / longest_side
        gray = cv2.resize(gray, None, fx=scale, fy=scale)

    h, w = gray.shape
    # Round both dimensions up to the next multiple of 16.
    # NOTE(review): presumably required by the network's down/up-sampling
    # stages — confirm against model_torch.
    size = 16
    new_w = (w + size - 1) // size * size
    new_h = (h + size - 1) // size * size

    # The padding area is filled with ones; the image occupies the top-left
    # corner of the (1, 1, new_h, new_w) input.
    patch = np.ones((1, 1, new_h, new_w), dtype=np.float32)
    patch[0, 0, :h, :w] = gray

    tensor = torch.from_numpy(patch).to(device)
    out = model(tensor)

    # Drop the batch/channel axes and the padding added above.
    res = out.cpu().numpy()[0, 0, :h, :w]
    return np.clip(res, 0, 255).astype(np.uint8)
# Build the demo page: a heading, then a row with the input controls in the
# first column and the result image in the second.
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input", type="numpy")
            run_button = gr.Button()
        with gr.Column():
            result = gr.Image(label="Result", elem_id="result")

    # Clicking the button feeds the input image to predict() and shows its
    # return value in the result component.
    run_button.click(fn=predict, inputs=input_image, outputs=result)
if __name__ == "__main__":
    # Enable the request queue (at most 20 pending requests), then start the
    # app; only runs when this file is executed as a script.
    demo.queue(max_size=20).launch()
|