shyamgupta196 commited on
Commit
29ad364
1 Parent(s): 37de13f
Files changed (2) hide show
  1. app.py +23 -0
  2. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import SegformerFeatureExtractor, SegformerForImageClassification
2
+ from PIL import Image
3
+ import requests
4
+
5
+ import gradio as gr
6
+
7
+
8
+
9
+ def seg(image):
10
+ feature_extractor = SegformerFeatureExtractor.from_pretrained("nvidia/mit-b0")
11
+ model = SegformerForImageClassification.from_pretrained("nvidia/mit-b0")
12
+ print(model)
13
+
14
+ inputs = feature_extractor(images=image, return_tensors="pt")
15
+ outputs = model(**inputs)
16
+ logits = outputs.logits
17
+ # model predicts one of the 1000 ImageNet classes
18
+ predicted_class_idx = logits.argmax(-1).item()
19
+ return model.config.id2label[predicted_class_idx]
20
+
21
+
22
+ iface = gr.Interface(fn=seg, inputs=gr.inputs.Image(type='pil'), outputs='label')
23
+ iface.launch()
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ transformers
2
+ torch
3
+ timm
4
+ datasets