import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
import numpy as np

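# NumPy 1.24 removed the deprecated np.int alias; restore it so the older
# ntsnet hub code, which still uses it, keeps working.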
np.int = int


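# Load the pretrained NTS-Net fine-grained bird classifier (CUB-200-2011,
# 200 classes) from Torch Hub and run it on CPU.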
model = torch.hub.load('nicolalandro/ntsnet-cub200', 'ntsnet', pretrained=True,
                       **{'topN': 6, 'device':'cpu', 'num_classes': 200})

model.eval()  # switch to inference mode

def classify_bird(img):
    # Inference-time preprocessing: resize, center-crop to the 448x448 input
    # the network expects, and normalize with ImageNet statistics.
    transform_test = transforms.Compose([
        transforms.Resize((600, 600), Image.BILINEAR),
        transforms.CenterCrop((448, 448)),
        # transforms.RandomHorizontalFlip(),  # only if training
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    scaled_img = transform_test(img)
    torch_images = scaled_img.unsqueeze(0)  # add a batch dimension

    with torch.no_grad():
        top_n_coordinates, concat_out, raw_logits, concat_logits, part_logits, top_n_index, top_n_prob = model(torch_images)

        # Turn the concatenated logits into class probabilities and map them to
        # the CUB-200 class names shipped with the model.
        probs = torch.softmax(concat_logits, 1)[0]
        prob_dict = {bird_cls: float(prob) for bird_cls, prob in zip(model.bird_classes, probs)}

    return prob_dict

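# Gradio UI: a PIL image input and a label output showing the top 3 classes.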
image_component = gr.Image(type="pil", label="Bird Image")
label_component = gr.Label(label="Classification result", num_top_classes=3)
description = """
## About 🤗

A tutorial for deploying a Gradio app on Hugging Face, built during a [livestream](https://youtube.com/live/bN9WTxzLBRE) on YouTube.

## Links 🔗

🔗 YouTube Livestream: https://youtube.com/live/bN9WTxzLBRE\n
🔗 PyTorch Hub Model: https://pytorch.org/hub/nicolalandro_ntsnet-cub200_ntsnet/\n
🔗 Paper: http://artelab.dista.uninsubria.it/res/research/papers/2019/2019-IVCNZ-Nawaz-Birds.pdf\n

"""
title = "Bird Classifier 🐣"

demo = gr.Interface(fn=classify_bird, inputs=image_component, outputs=label_component, description=description, title=title)

demo.launch()