Spaces:
Runtime error
import functools

import gradio as gr
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
# Compatibility shim: NumPy 1.24 removed the deprecated `np.int` alias.
# It is restored here before loading the hub model -- presumably because the
# ntsnet hub code still references `np.int` (TODO confirm against upstream).
np.int = int

# NTS-Net bird classifier from Torch Hub with pretrained weights, 200 output
# classes (CUB-200 per the repo name), running on CPU. The extra keyword
# arguments are forwarded by torch.hub.load to the hub entry point.
# NOTE(review): topN is presumably the number of proposal regions NTS-Net
# keeps -- confirm against the upstream repo.
model = torch.hub.load(
    "nicolalandro/ntsnet-cub200",
    "ntsnet",
    pretrained=True,
    topN=6,
    device="cpu",
    num_classes=200,
)
# Inference only: switch the network to evaluation mode.
model.eval()
@functools.lru_cache(maxsize=1)
def _test_transform():
    """Build (once) the evaluation-time preprocessing pipeline.

    Resize to 600x600 (bilinear), center-crop to 448x448, convert to a
    tensor and normalise with the standard ImageNet channel mean/std.
    Random augmentation (e.g. horizontal flip) is deliberately omitted:
    it belongs to training only.
    """
    return transforms.Compose([
        transforms.Resize(
            (600, 600), interpolation=transforms.InterpolationMode.BILINEAR
        ),
        transforms.CenterCrop((448, 448)),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])


def classify_bird(img):
    """Predict the bird class shown in *img* using the module-level model.

    Args:
        img: A PIL image (the Gradio input component is created with
            ``type="pil"``), or ``None`` when the user submits without
            uploading an image.

    Returns:
        The predicted class name looked up in ``model.bird_classes``, or a
        short notice string when no image was supplied.
    """
    if img is None:
        return "Please upload an image."
    # Force 3 channels: RGBA / greyscale / palette images would otherwise
    # reach Normalize with the wrong channel count and raise.
    img = img.convert("RGB")
    # Add a batch dimension -> (1, 3, 448, 448).
    batch = _test_transform()(img).unsqueeze(0)
    with torch.no_grad():
        # The hub model returns a 7-tuple; only concat_logits is used here.
        (_top_n_coordinates, _concat_out, _raw_logits, concat_logits,
         _part_logits, _top_n_index, _top_n_prob) = model(batch)
    pred_id = concat_logits.argmax(dim=1).item()
    bird_class = model.bird_classes[pred_id]
    print(f"{bird_class=}")
    return bird_class
# Gradio UI: a single PIL image input wired to classify_bird, text output.
image_component = gr.Image(type="pil", label="Bird Image")
demo = gr.Interface(fn=classify_bird, inputs=image_component, outputs="text")

if __name__ == "__main__":
    # Guarded so that importing this module (e.g. from tests) does not
    # start a web server as a side effect.
    demo.launch()