jerpint commited on
Commit
aa9394d
1 Parent(s): 64b2c77

add label component

Browse files
Files changed (1) hide show
  1. app.py +5 -6
app.py CHANGED
@@ -28,14 +28,13 @@ def classify_bird(img):
28
  with torch.no_grad():
29
  top_n_coordinates, concat_out, raw_logits, concat_logits, part_logits, top_n_index, top_n_prob = model(torch_images)
30
 
31
- _, predict = torch.max(concat_logits, 1)
32
- pred_id = predict.item()
33
- bird_class = model.bird_classes[pred_id]
34
- print(f"{bird_class=}")
35
 
36
- return bird_class
37
 
38
  image_component = gr.Image(type="pil", label="Bird Image")
39
- demo = gr.Interface(fn=classify_bird, inputs=image_component, outputs="text")
 
40
 
41
  demo.launch()
 
28
  with torch.no_grad():
29
  top_n_coordinates, concat_out, raw_logits, concat_logits, part_logits, top_n_index, top_n_prob = model(torch_images)
30
 
31
+ probs = torch.softmax(concat_logits, 1)[0]
32
+ prob_dict = {bird_cls: float(prob) for bird_cls, prob in zip(model.bird_classes, probs)}
 
 
33
 
34
+ return prob_dict
35
 
36
  image_component = gr.Image(type="pil", label="Bird Image")
37
+ label_component = gr.Label(label="Classification result", num_top_classes=3)
38
+ demo = gr.Interface(fn=classify_bird, inputs=image_component, outputs=label_component)
39
 
40
  demo.launch()