# Spaces: Runtime error
import gradio as gr | |
from PIL import Image | |
import pandas as pd | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
from torchvision import datasets, models, transforms | |
from torch_mtcnn import detect_faces | |
from torch_mtcnn import show_bboxes | |
# Checkpoint holding the fine-tuned weights (path is relative to the working directory).
_FAIRFACE_WEIGHTS = 'res34_fair_align_multi_7_20190809.pt'
# Position i in each tuple is the label reported for predicted class index i.
_GENDER_LABELS = ('Male', 'Female')
_AGE_LABELS = ('0-2', '3-9', '10-19', '20-29', '30-39', '40-49', '50-59', '60-69', '70+')
# Filled on first use so the checkpoint is read once instead of on every request.
_FAIRFACE_CACHE = {}


def _get_fairface_model():
    """Return ``(model, device)``, building and caching the network on first use."""
    if "model" not in _FAIRFACE_CACHE:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        # Use the `models` name this file imports from torchvision; the bare
        # `torchvision` module is never imported, so `torchvision.models...`
        # raised NameError. No pretrained download is requested because
        # load_state_dict (strict by default) overwrites every weight anyway.
        model = models.resnet34()
        model.fc = nn.Linear(model.fc.in_features, 18)
        model.load_state_dict(
            torch.load(_FAIRFACE_WEIGHTS, map_location=torch.device('cpu'))
        )
        model = model.to(device)
        model.eval()
        _FAIRFACE_CACHE["model"] = model
        _FAIRFACE_CACHE["device"] = device
    return _FAIRFACE_CACHE["model"], _FAIRFACE_CACHE["device"]


def pipeline(img):
    """Predict gender and age bracket for the first face detected in *img*.

    Args:
        img: A PIL image.

    Returns:
        A two-item list ``[gender, age]`` where ``gender`` is ``'Male'`` or
        ``'Female'`` and ``age`` is one of the bracket strings from ``'0-2'``
        to ``'70+'``.

    Raises:
        ValueError: If the face detector returns no bounding box.
    """
    # The 3-channel Normalize() and the network input both need exactly RGB.
    img = img.convert("RGB")

    bounding_boxes, _landmarks = detect_faces(img)
    if len(bounding_boxes) == 0:
        raise ValueError("No face was detected in the image.")

    # Only the first detection is used; its first four values are the crop box.
    box = tuple(float(v) for v in bounding_boxes[0][:4])
    face = img.crop(box)

    model, device = _get_fairface_model()
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # unsqueeze(0) adds the batch dimension: (3, 224, 224) -> (1, 3, 224, 224).
    batch = preprocess(face).unsqueeze(0).to(device)

    # Inference only, so skip building the autograd graph.
    with torch.no_grad():
        logits = model(batch)[0].cpu().numpy()

    # Layout of the 18 outputs as sliced here: [7:9] gender, [9:18] age.
    # [0:7] is not used (presumably the 7 race classes of the "fair 7" model).
    # Softmax is monotonic, so argmax over the raw logits selects the same
    # class while avoiding overflow in exp() for large values.
    gender_idx = int(np.argmax(logits[7:9]))
    age_idx = int(np.argmax(logits[9:18]))
    return [_GENDER_LABELS[gender_idx], _AGE_LABELS[age_idx]]
def predict(image):
    """Describe the person in *image* as one English sentence.

    Hands the image to ``pipeline`` and joins the first two items of its
    result (used as gender and age range, in that order) into the text
    returned to the caller.
    """
    labels = pipeline(image)
    gender, age_range = labels[0], labels[1]
    return "".join(("A ", gender, " in the age range of ", age_range))
# `gr.inputs` is deprecated in Gradio 3 and absent from Gradio 4, where
# touching it fails at import time. Prefer the top-level component and fall
# back to the legacy path only on installs that lack it.
_image_component = getattr(gr, "Image", None) or gr.inputs.Image

# Build the web UI: one PIL image in, one line of text out, served by predict().
demo = gr.Interface(
    predict,
    inputs=_image_component(label="Upload a profile picture of a single person", type="pil"),
    outputs="text",
    title="Estimate age and gender from profile picture",
    examples=["ex0.jpg", "ex4.jpg", "ex1.jpg", "ex2.jpg", "ex3.jpg", "ex5.jpg"],
)
demo.launch()