|
import gradio as gr |
|
from Models import VisionModel |
|
import huggingface_hub |
|
from PIL import Image |
|
import torch.amp.autocast_mode |
|
from pathlib import Path |
|
import torch |
|
import torchvision.transforms.functional as TVF |
|
|
|
|
|
# Hugging Face Hub repository that `snapshot_download` fetches the weights and
# tag list from.
MODEL_REPO: str = "fancyfeast/joytag"

# A tag is listed in the tag string when its sigmoid probability is strictly
# greater than this value (see `predict`).
THRESHOLD: float = 0.4

# Markdown blurb shown under the title of the Gradio interface.
DESCRIPTION: str = """
Demo for the JoyTag model: https://huggingface.co/fancyfeast/joytag
"""
|
|
|
|
|
def prepare_image(image: Image.Image, target_size: int) -> torch.Tensor:
    """Pad an image to a white square, resize it, and normalize it for JoyTag.

    Args:
        image: Source PIL image in any mode. Images carrying transparency
            (an alpha band, or a palette image with a ``transparency`` entry)
            are alpha-composited onto the white background; every other mode
            is converted to RGB.
        target_size: Side length, in pixels, of the square input the model
            expects (``predict`` passes ``model.image_size``).

    Returns:
        A float tensor of shape ``(3, target_size, target_size)``: pixel values
        scaled to [0, 1] and then normalized per channel with the mean/std
        below.
    """
    # Centre the image on a white square canvas so resizing to a square keeps
    # the original aspect ratio.
    width, height = image.size
    max_dim = max(width, height)
    pad_left = (max_dim - width) // 2
    pad_top = (max_dim - height) // 2

    padded_image = Image.new('RGB', (max_dim, max_dim), (255, 255, 255))

    has_alpha = 'A' in image.getbands() or (
        image.mode == 'P' and 'transparency' in image.info
    )
    if has_alpha:
        # A plain paste onto an RGB canvas discards the alpha band and exposes
        # whatever colour is stored under transparent pixels (often black).
        # Using the alpha band as the paste mask blends the image onto the same
        # white background that the padding uses.
        rgba = image.convert('RGBA')
        padded_image.paste(rgba, (pad_left, pad_top), mask=rgba)
    else:
        rgb = image if image.mode == 'RGB' else image.convert('RGB')
        padded_image.paste(rgb, (pad_left, pad_top))

    if max_dim != target_size:
        padded_image = padded_image.resize((target_size, target_size), Image.BICUBIC)

    # uint8 (C, H, W) tensor -> float values in [0, 1].
    image_tensor = TVF.pil_to_tensor(padded_image) / 255.0

    # NOTE(review): these per-channel statistics coincide with the CLIP image
    # normalization; presumably they must match the preprocessing the
    # checkpoint was trained with, so do not change them independently.
    image_tensor = TVF.normalize(
        image_tensor,
        mean=[0.48145466, 0.4578275, 0.40821073],
        std=[0.26862954, 0.26130258, 0.27577711],
    )

    return image_tensor
|
|
|
|
|
@torch.no_grad()
def predict(image: Image.Image):
    """Run JoyTag on one image and return its tags.

    Args:
        image: The PIL image supplied by the Gradio input, or ``None`` when the
            form is submitted without an image.

    Returns:
        A ``(tag_string, scores)`` pair:

        * ``tag_string`` - comma-separated tags whose probability is strictly
          greater than ``THRESHOLD`` (empty string when ``image`` is ``None``).
        * ``scores`` - dict mapping every tag in ``top_tags`` to its sigmoid
          probability as a plain Python float, for ``gr.Label``.

    Raises:
        ValueError: If the model returns fewer tag scores than there are
            entries in ``top_tags``.
    """
    # Gradio calls the function with None when no image was provided; return
    # empty outputs instead of crashing inside prepare_image.
    if image is None:
        return '', {}

    image_tensor = prepare_image(image, model.image_size)
    batch = {
        'image': image_tensor.unsqueeze(0),  # add a batch dimension of 1
    }

    with torch.amp.autocast_mode.autocast('cpu', enabled=True):
        preds = model(batch)
        tag_preds = preds['tags'].sigmoid().cpu()

    # Convert the single image's row to Python floats in one call instead of
    # indexing the tensor once per tag: far fewer tensor ops, and gr.Label
    # receives plain numbers rather than 0-d tensors.
    probabilities = tag_preds[0].tolist()
    if len(probabilities) < len(top_tags):
        raise ValueError(
            f"model returned {len(probabilities)} tag scores but "
            f"top_tags lists {len(top_tags)} tags"
        )

    # Tags and scores are paired positionally.
    scores = dict(zip(top_tags, probabilities))
    predicted_tags = [tag for tag, score in scores.items() if score > THRESHOLD]
    tag_string = ', '.join(predicted_tags)

    return tag_string, scores
|
|
|
|
|
# Fetch (or reuse the locally cached copy of) the model repository; `path` is
# the local snapshot directory.
print("Downloading model...")
path = huggingface_hub.snapshot_download(MODEL_REPO)

print("Loading model...")
model = VisionModel.load_model(path)
model.eval()  # evaluation mode for inference

# One tag per non-blank line, with surrounding whitespace stripped. `predict`
# pairs these with the model's outputs by position, so the file order must
# match the model's tag order. Read explicitly as UTF-8 so decoding does not
# depend on the platform's default locale encoding.
with open(Path(path) / 'top_tags.txt', 'r', encoding='utf-8') as f:
    top_tags = [line.strip() for line in f if line.strip()]
|
|
|
print("Starting server...")

# Single-page UI: a PIL image (uploaded or captured from the webcam) is passed
# to `predict`, and its two return values fill the outputs in order - the tag
# string goes to the textbox, the score dict to the label widget.
gradio_app = gr.Interface(
    fn=predict,
    inputs=gr.Image(type='pil', sources=['upload', 'webcam'], label="Source"),
    outputs=[
        gr.Textbox(label="Tag String"),
        # Display at most the 100 highest-scoring tags from the score dict.
        gr.Label(num_top_classes=100, label="Tag Predictions"),
    ],
    title="JoyTag",
    description=DESCRIPTION,
    allow_flagging="never",
)
|
|
|
|
|
# Start the web server only when this file is executed directly, not when it
# is imported.
if __name__ == "__main__":
    gradio_app.launch()
|
|