jerpint commited on
Commit
94ab357
1 Parent(s): b879aeb

actually classify images

Browse files
Files changed (1) hide show
  1. app.py +37 -3
app.py CHANGED
@@ -1,7 +1,41 @@
1
  import gradio as gr
 
 
 
 
2
 
3
- def greet(name):
4
- return "Bonjour " + name + "!!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
  demo.launch()
 
1
  import gradio as gr
2
+ import torch
3
+ from torchvision import transforms
4
+ from PIL import Image
5
+ import numpy as np
6
 
7
+ np.int = int
8
+
9
+
10
+ model = torch.hub.load('nicolalandro/ntsnet-cub200', 'ntsnet', pretrained=True,
11
+ **{'topN': 6, 'device':'cpu', 'num_classes': 200})
12
+
13
+ model.eval()
14
+
15
+ def classify_bird(img):
16
+
17
+ transform_test = transforms.Compose([
18
+ transforms.Resize((600, 600), Image.BILINEAR),
19
+ transforms.CenterCrop((448, 448)),
20
+ # transforms.RandomHorizontalFlip(), # only if train
21
+ transforms.ToTensor(),
22
+ transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
23
+ ])
24
+
25
+ scaled_img = transform_test(img)
26
+ torch_images = scaled_img.unsqueeze(0)
27
+
28
+ with torch.no_grad():
29
+ top_n_coordinates, concat_out, raw_logits, concat_logits, part_logits, top_n_index, top_n_prob = model(torch_images)
30
+
31
+ _, predict = torch.max(concat_logits, 1)
32
+ pred_id = predict.item()
33
+ bird_class = model.bird_classes[pred_id]
34
+ print(f"{bird_class=}")
35
+
36
+ return bird_class
37
+
38
+ image_component = gr.Image(type="pil", label="Bird Image")
39
+ demo = gr.Interface(fn=classify_bird, inputs=image_component, outputs="text")
40
 
 
41
  demo.launch()