bird-classifier / app.py
jerpint's picture
actually classify images
94ab357
raw
history blame
1.21 kB
import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
import numpy as np
# NOTE(review): NumPy removed the deprecated `np.int` alias in 1.24. Restoring
# it here presumably keeps the torch.hub model code (fetched at runtime, not
# shown in this file) working on newer NumPy — confirm against the ntsnet repo.
np.int = int

# NTS-Net bird classifier loaded from torch.hub with pretrained weights, on CPU.
# num_classes=200 matches the "cub200" checkpoint name (presumably the
# CUB-200-2011 species set). topN is forwarded to the hub entry point as-is —
# TODO confirm its exact meaning in the ntsnet repo.
model = torch.hub.load(
    'nicolalandro/ntsnet-cub200',
    'ntsnet',
    pretrained=True,
    topN=6,
    device='cpu',
    num_classes=200,
).eval()  # nn.Module.eval() returns the module itself, so chaining is safe
def classify_bird(img):
    """Return the predicted bird species label for a single image.

    The image is resized to 600x600 (bilinear), center-cropped to 448x448,
    converted to a tensor and normalized with the given per-channel mean/std,
    then passed through the module-level NTS-Net ``model``. The prediction is
    the arg-max of the model's concatenated logits.

    Args:
        img: A ``PIL.Image.Image`` (as produced by ``gr.Image(type="pil")``),
            or ``None`` when the user submits without uploading an image.

    Returns:
        The entry of ``model.bird_classes`` at the predicted class index.

    Raises:
        gr.Error: If ``img`` is ``None``.
    """
    if img is None:
        # Gradio calls fn with None when Submit is pressed on an empty image
        # box; report it in the UI instead of failing deep inside Resize.
        raise gr.Error("Please upload an image of a bird.")

    transform_test = transforms.Compose([
        transforms.Resize(
            (600, 600), interpolation=transforms.InterpolationMode.BILINEAR
        ),
        transforms.CenterCrop((448, 448)),
        transforms.ToTensor(),
        # NOTE(review): ImageNet statistics; assumed to match the
        # normalization the checkpoint was trained with.
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    # Normalize above has exactly 3 channels; RGBA, greyscale or palette
    # images would otherwise yield 4- or 1-channel tensors and fail there.
    rgb_img = img.convert("RGB")
    # Add a batch dimension -> shape (1, 3, 448, 448).
    batch = transform_test(rgb_img).unsqueeze(0)

    with torch.no_grad():
        # The model returns a 7-tuple; only the concatenated logits are used
        # for the final prediction.
        _, _, _, concat_logits, _, _, _ = model(batch)

    pred_id = concat_logits.argmax(dim=1).item()
    bird_class = model.bird_classes[pred_id]
    print(f"{bird_class=}")
    return bird_class
# Gradio front end: one PIL image in, the predicted species name out as text.
image_component = gr.Image(type="pil", label="Bird Image")
demo = gr.Interface(
    fn=classify_bird,
    inputs=image_component,
    outputs="text",
)

# Start the Gradio web server for the interface defined above.
demo.launch()