# Spaces: Sleeping
import warnings

import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms, datasets, models
# Preprocessing pipeline published with the ImageNet ResNet-18 weights; it is
# reused at inference time to prepare incoming images for the backbone.
transformer = models.ResNet18_Weights.IMAGENET1K_V1.transforms()

# Prefer the GPU when one is visible, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Output labels; the position in this list is the class index of the head.
class_names = ['Ahegao', 'Angry', 'Happy', 'Neutral', 'Sad', 'Surprise']
classes_count = len(class_names)

# NOTE(review): weights='DEFAULT' fetches the pretrained backbone before the
# checkpoint below is applied. If the checkpoint holds every layer, weights=None
# would skip that download -- confirm the checkpoint contents before changing.
model = models.resnet18(weights='DEFAULT')

# Replace the 1000-way ImageNet classifier with a head sized for class_names.
model.fc = nn.Sequential(
    nn.Linear(512, classes_count)
)

# Move the model only AFTER the head has been swapped: the new nn.Linear is
# created on the CPU, so moving earlier would leave it on a different device
# than the backbone and break inference on CUDA hosts.
model = model.to(device)

# NOTE(review): torch.load unpickles the file, so './model_params.pt' must come
# from a trusted source.
load_result = model.load_state_dict(
    torch.load('./model_params.pt', map_location=device),
    strict=False,
)

# strict=False tolerates key mismatches; report them instead of silently
# serving a model whose layers were not restored from the checkpoint.
if load_result.missing_keys or load_result.unexpected_keys:
    warnings.warn(
        "Checkpoint './model_params.pt' did not match the model exactly "
        f"(missing keys: {list(load_result.missing_keys)}, "
        f"unexpected keys: {list(load_result.unexpected_keys)})."
    )
def predict(image):
    """Classify the emotion shown in ``image``.

    Args:
        image: PIL image delivered by the Gradio ``Image`` input, or ``None``
            when the request was submitted without an image.

    Returns:
        Dict mapping each entry of ``class_names`` to its softmax probability
        as a Python float, in the form expected by ``gr.Label``.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Surface a readable message in the UI instead of failing inside the
        # preprocessing pipeline.
        raise gr.Error("Please provide an image.")

    # Force three channels so grayscale or RGBA inputs go through the same
    # preprocessing as ordinary RGB photos; unsqueeze adds the batch dimension.
    batch = transformer(image.convert("RGB")).unsqueeze(0).to(device)

    # Idempotent: keeps dropout/batch-norm layers in inference behaviour.
    model.eval()
    with torch.inference_mode():
        probabilities = torch.softmax(model(batch), dim=1)[0]

    return dict(zip(class_names, probabilities.tolist()))
# Heading and blurb displayed at the top of the Gradio page.
title = "Emotion Checker"
description = "Can classify 6 emotions: Ahegao, Angry, Happy, Neutral, Sad, Surprise"

# Sample images offered in the UI: ./example_1.jpg through ./example_6.jpg.
examples = [f'./example_{index}.jpg' for index in range(1, 7)]

# Input/output widgets: a PIL image goes in, a ranked label list comes out
# with one row per class.
image_input = gr.Image(type="pil")
prediction_label = gr.Label(num_top_classes=classes_count, label="Predictions")

# Wire predict() to the widgets defined above.
app = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=[prediction_label],
    examples=examples,
    title=title,
    description=description,
)

# Start the server, requesting a public share link and an 800px-high embed.
app.launch(share=True, height=800)