Spaces:
Runtime error
Runtime error
File size: 1,790 Bytes
b879aeb 94ab357 b879aeb 94ab357 aa9394d 94ab357 aa9394d 94ab357 aa9394d ae37634 b879aeb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
import numpy as np
# Compatibility shim: NumPy 1.24 removed the deprecated ``np.int`` alias.
# NOTE(review): presumably the ntsnet hub code still references ``np.int``;
# restoring it as the builtin ``int`` keeps that code importable.
np.int = int

# Pretrained NTS-Net from the nicolalandro/ntsnet-cub200 PyTorch Hub repo,
# configured for CPU inference with 200 output classes.
# NOTE(review): topN presumably sets how many proposal regions NTS-Net keeps;
# confirm against the hub entrypoint.
model = torch.hub.load(
    'nicolalandro/ntsnet-cub200',
    'ntsnet',
    pretrained=True,
    topN=6,
    device='cpu',
    num_classes=200,
)
# Evaluation mode: affects train/eval-dependent layers (e.g. dropout,
# batch norm) if the network contains any.
model.eval()
def classify_bird(img):
    """Classify a bird photo with the pretrained NTS-Net hub model.

    Args:
        img: A PIL image, as delivered by ``gr.Image(type="pil")``. Any PIL
            mode is accepted; it is converted to RGB before preprocessing.

    Returns:
        A dict mapping each class name in ``model.bird_classes`` to its
        softmax probability (computed from the model's ``concat_logits``),
        in the shape ``gr.Label`` expects.

    Raises:
        gr.Error: If no image was supplied (``img`` is ``None``).
    """
    if img is None:
        # Gradio passes None when the user submits without an image; surface
        # a readable message instead of a crash deep inside the transforms.
        raise gr.Error("Please upload a bird image.")

    transform_test = transforms.Compose([
        # Use the InterpolationMode enum: passing PIL integer constants
        # (Image.BILINEAR) to torchvision is deprecated. Same resampling.
        transforms.Resize(
            (600, 600), interpolation=transforms.InterpolationMode.BILINEAR
        ),
        transforms.CenterCrop((448, 448)),
        transforms.ToTensor(),
        # Per-channel mean/std (the usual ImageNet statistics) for 3 channels.
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    # Normalize above expects exactly three channels; RGBA, greyscale or
    # palette images would otherwise fail there, so force RGB first.
    torch_images = transform_test(img.convert("RGB")).unsqueeze(0)

    with torch.no_grad():
        # The hub model's forward returns a 7-tuple; only the fourth element
        # (concat_logits) is needed for classification.
        _, _, _, concat_logits, _, _, _ = model(torch_images)

    probs = torch.softmax(concat_logits, dim=1)[0].tolist()
    return dict(zip(model.bird_classes, probs))
# ---------------------------------------------------------------------------
# Gradio UI: one image in, a top-3 label out, wired to classify_bird.
# ---------------------------------------------------------------------------

# Input widget hands classify_bird a PIL image; the output widget renders the
# returned {class: probability} dict, showing only the three best classes.
image_component = gr.Image(type="pil", label="Bird Image")
label_component = gr.Label(label="Classification result", num_top_classes=3)

# Markdown shown under the title. The literal "\n" escapes add an extra line
# break after each link line.
# NOTE(review): the non-ASCII characters in the title/description look
# mis-encoded (probably emoji); confirm against the original source file.
description = """
## About π€
Tutorial for deploying a gradio app on huggingface. This was done during a [livestream](https://youtube.com/live/bN9WTxzLBRE) on YouTube.
## Links π
π YouTube Livestream: https://youtube.com/live/bN9WTxzLBRE\n
π Torchvision Model: https://pytorch.org/hub/nicolalandro_ntsnet-cub200_ntsnet/\n
π Paper: http://artelab.dista.uninsubria.it/res/research/papers/2019/2019-IVCNZ-Nawaz-Birds.pdf\n
"""
title = "Bird Classifier π£"

demo = gr.Interface(
    fn=classify_bird,
    inputs=image_component,
    outputs=label_component,
    title=title,
    description=description,
)
demo.launch()
|