Spaces:
Running
Running
#!/usr/bin/env python | |
from __future__ import annotations | |
import functools | |
import json | |
import os | |
import pathlib | |
import tarfile | |
from typing import Callable | |
import gradio as gr | |
import huggingface_hub | |
import PIL.Image | |
import torch | |
import torchvision.transforms as T | |
# Markdown heading rendered at the top of the demo page.
DESCRIPTION = "# [RF5/danbooru-pretrained](https://github.com/RF5/danbooru-pretrained)"
# Hugging Face Hub repository the model weights and label file are downloaded from.
MODEL_REPO = "public-data/danbooru-pretrained"
def load_sample_image_paths() -> list[pathlib.Path]:
    """Return the sorted paths of the sample images, downloading them if needed.

    When no ``images`` directory exists in the current working directory, the
    ``images.tar.gz`` archive is fetched from the ``hysts/sample-images-TADNE``
    dataset repository and unpacked into the current working directory
    (presumably the archive contains a top-level ``images/`` folder).

    Returns:
        Every entry directly inside ``images``, in sorted order.
    """
    image_dir = pathlib.Path("images")
    if not image_dir.exists():
        dataset_repo = "hysts/sample-images-TADNE"
        archive_path = huggingface_hub.hf_hub_download(dataset_repo, "images.tar.gz", repo_type="dataset")
        with tarfile.open(archive_path) as archive:
            # The archive comes from the network: use the "data" extraction
            # filter where the interpreter supports it so that members cannot
            # escape the destination directory (absolute paths, "..", links).
            if hasattr(tarfile, "data_filter"):
                archive.extractall(filter="data")
            else:
                # Older interpreters have no extraction filters.
                archive.extractall()
    return sorted(image_dir.glob("*"))
def load_model(device: torch.device) -> torch.nn.Module:
    """Build the ResNet-50 tagger, load its weights and prepare it for inference.

    The checkpoint ``resnet50-13306192.pth`` is downloaded from ``MODEL_REPO``
    and the architecture is created through ``torch.hub`` from
    ``RF5/danbooru-pretrained`` with ``pretrained=False``.

    Args:
        device: Device the model is moved to after the weights are loaded.

    Returns:
        The model, on ``device`` and switched to evaluation mode.
    """
    path = huggingface_hub.hf_hub_download(MODEL_REPO, "resnet50-13306192.pth")
    # Deserialize onto the CPU so that a checkpoint saved from CUDA tensors
    # still loads on a machine without a GPU; the model is moved to the
    # requested device below.
    # NOTE: torch.load unpickles the file, so it must come from a trusted repo.
    state_dict = torch.load(path, map_location="cpu")
    model = torch.hub.load("RF5/danbooru-pretrained", "resnet50", pretrained=False)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model
def load_labels() -> list[str]:
    """Download ``class_names_6000.json`` from ``MODEL_REPO`` and return its content.

    Returns:
        The parsed JSON document (annotated as a list of label names).
    """
    path = huggingface_hub.hf_hub_download(MODEL_REPO, "class_names_6000.json")
    # JSON is UTF-8 by specification; do not depend on the platform's default
    # text encoding.
    with open(path, encoding="utf-8") as f:
        return json.load(f)
def predict(
    image: PIL.Image.Image,
    score_threshold: float,
    transform: Callable,
    device: torch.device,
    model: torch.nn.Module,
    labels: list[str],
) -> dict[str, float]:
    """Score an image against every label and keep the confident ones.

    Args:
        image: Input image; handed to ``transform`` unchanged.
        score_threshold: Labels whose probability is below this value are
            dropped; a probability equal to the threshold is kept.
        transform: Callable that converts ``image`` into a tensor without a
            batch dimension.
        device: Device the input tensor is moved to before the forward pass.
        model: Network applied to the single-item batch; the first item of
            its output is used.
        labels: Label names, paired positionally with the model outputs.

    Returns:
        A mapping from label to sigmoid probability for every label that
        passes the threshold.
    """
    data = transform(image)
    data = data.to(device).unsqueeze(0)
    # Inference only: without no_grad the output would track gradients, which
    # wastes memory and prevents converting the result to plain numbers.
    with torch.no_grad():
        logits = model(data)[0]
        probabilities = torch.sigmoid(logits)
    scores = probabilities.cpu().tolist()
    # zip() stops at the shorter sequence, so surplus labels or outputs are ignored.
    return {label: prob for prob, label in zip(scores, labels) if not prob < score_threshold}
# Rows for the examples gallery: each pairs an image path with a score
# threshold of 0.4.
image_paths = load_sample_image_paths()
examples = [[sample.as_posix(), 0.4] for sample in image_paths]

# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
model = load_model(device)
labels = load_labels()

# Preprocessing applied to every input image: resize, convert to a tensor,
# then normalize with fixed per-channel mean and standard deviation.
transform = T.Compose(
    [
        T.Resize(360),
        T.ToTensor(),
        T.Normalize(mean=[0.7137, 0.6628, 0.6519], std=[0.2970, 0.3017, 0.2979]),
    ]
)

# Bind everything except the image and the threshold, which come from the UI.
fn = functools.partial(
    predict,
    transform=transform,
    device=device,
    model=model,
    labels=labels,
)
# Demo page: inputs and the run button in one column, the label output in the
# other, followed by the examples gallery.
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column():
            image = gr.Image(type="pil", label="Input")
            threshold = gr.Slider(
                label="Score Threshold",
                minimum=0,
                maximum=1,
                step=0.05,
                value=0.4,
            )
            run_button = gr.Button()
        with gr.Column():
            result = gr.Label(label="Output")

    inputs = [image, threshold]
    # Example caching is opt-in through the CACHE_EXAMPLES environment variable.
    gr.Examples(
        fn=fn,
        examples=examples,
        inputs=inputs,
        outputs=result,
        cache_examples=os.getenv("CACHE_EXAMPLES") == "1",
    )
    # The button click is also registered under the API name "predict".
    run_button.click(
        fn=fn,
        inputs=inputs,
        outputs=result,
        api_name="predict",
    )
if __name__ == "__main__":
    # Enable the request queue (max_size=15), then start the app.
    queued_demo = demo.queue(max_size=15)
    queued_demo.launch()