Tirath5504 commited on
Commit
b069732
1 Parent(s): cae1118

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -0
app.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import ViTForImageClassification, ViTImageProcessor
3
+ from PIL import Image
4
+ import gradio as gr
5
+
6
+ model = ViTForImageClassification.from_pretrained('vit-hateful-gesture-classification')
7
+ processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
8
+
9
+ class_names = ['cut_throat_gesture', 'finger_gun_to_the_head', 'middle_finger', 'slanted_eyes_gesture', 'swastika']
10
+
11
+ def predict(image):
12
+ inputs = processor(images=image, return_tensors="pt")
13
+
14
+ with torch.no_grad():
15
+ outputs = model(**inputs).logits
16
+
17
+ predicted_class_idx = outputs.argmax(-1).item()
18
+ predicted_class = class_names[predicted_class_idx]
19
+
20
+ return predicted_class
21
+
22
+ iface = gr.Interface(fn=predict,
23
+ inputs=gr.inputs.Image(type="pil"),
24
+ outputs=gr.outputs.Label(num_top_classes=1),
25
+ title="Hateful Gesture Detection",
26
+ description="Upload an image to classify hateful gestures or symbols")
27
+
28
+ if __name__ == "__main__":
29
+ iface.launch(share=True)