DimaML commited on
Commit
f9e07cc
1 Parent(s): 5963eb5

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -0
app.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchvision import transforms, datasets, models
4
+ import gradio as gr
5
+
6
+ transformer = models.ResNet18_Weights.IMAGENET1K_V1.transforms()
7
+ transformer
8
+
9
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
+
11
+ class_names = ['Ahegao', 'Angry', 'Happy', 'Neutral', 'Sad', 'Surprise']
12
+ classes_count = len(class_names)
13
+
14
+ model = models.resnet18(weights='DEFAULT').to(device)
15
+ model.fc = nn.Sequential(
16
+ nn.Linear(512, classes_count)
17
+ )
18
+ model.load_state_dict(torch.load('./model_params.pt', map_location=device), strict=False)
19
+
20
+ def predict(image):
21
+ transformed_image = transformer(image).unsqueeze(0).to(device)
22
+ model.eval()
23
+
24
+ with torch.inference_mode():
25
+ pred = torch.softmax(model(transformed_image), dim=1)
26
+
27
+ pred_and_labels = {class_names[i]: pred[0][i].item() for i in range(len(pred[0]))}
28
+
29
+ return pred_and_labels
30
+
31
+ title = "Emotion Checker"
32
+ description = "Can classify 6 emotions: Ahegao, Angry, Happy, Neutral, Sad, Surprise"
33
+
34
+ examples = [
35
+ './example_1.jpg',
36
+ './example_2.jpg',
37
+ './example_3.jpg',
38
+ './example_4.jpg',
39
+ './example_5.jpg',
40
+ './example_6.jpg',
41
+ ]
42
+
43
+
44
+ app = gr.Interface(
45
+ fn=predict,
46
+ inputs=gr.Image(type="pil"),
47
+ outputs=[gr.Label(num_top_classes=classes_count, label="Predictions")],
48
+ examples=examples,
49
+ title=title,
50
+ description=description
51
+ )
52
+
53
+ app.launch(
54
+ share=True,
55
+ height=800
56
+ )