Spaces:
Runtime error
Runtime error
add label component
Browse files
app.py
CHANGED
@@ -28,14 +28,13 @@ def classify_bird(img):
|
|
28 |
with torch.no_grad():
|
29 |
top_n_coordinates, concat_out, raw_logits, concat_logits, part_logits, top_n_index, top_n_prob = model(torch_images)
|
30 |
|
31 |
-
|
32 |
-
|
33 |
-
bird_class = model.bird_classes[pred_id]
|
34 |
-
print(f"{bird_class=}")
|
35 |
|
36 |
-
return
|
37 |
|
38 |
image_component = gr.Image(type="pil", label="Bird Image")
|
39 |
-
|
|
|
40 |
|
41 |
demo.launch()
|
|
|
28 |
with torch.no_grad():
|
29 |
top_n_coordinates, concat_out, raw_logits, concat_logits, part_logits, top_n_index, top_n_prob = model(torch_images)
|
30 |
|
31 |
+
probs = torch.softmax(concat_logits, 1)[0]
|
32 |
+
prob_dict = {bird_cls: float(prob) for bird_cls, prob in zip(model.bird_classes, probs)}
|
|
|
|
|
33 |
|
34 |
+
return prob_dict
|
35 |
|
36 |
image_component = gr.Image(type="pil", label="Bird Image")
|
37 |
+
label_component = gr.Label(label="Classification result", num_top_classes=3)
|
38 |
+
demo = gr.Interface(fn=classify_bird, inputs=image_component, outputs=label_component)
|
39 |
|
40 |
demo.launch()
|