Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -14,14 +14,84 @@ from torch_mtcnn import detect_faces
|
|
14 |
from torch_mtcnn import show_bboxes
|
15 |
|
16 |
# pipeline = pipeline(task="image-classification", model="njgroene/fairface")
|
|
|
|
|
|
|
|
|
17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
def predict(image):
|
19 |
predictions = pipeline(image)
|
20 |
-
return
|
21 |
|
22 |
gr.Interface(
|
23 |
predict,
|
24 |
inputs=gr.inputs.Image(label="Upload a profile picture of a single person", type="filepath"),
|
25 |
-
outputs=
|
26 |
title="Estimate age and gender from profile picture",
|
27 |
).launch()
|
|
|
14 |
from torch_mtcnn import show_bboxes
|
15 |
|
16 |
# pipeline = pipeline(task="image-classification", model="njgroene/fairface")
|
17 |
+
def pipeline(image):
|
18 |
+
bounding_boxes, landmarks = detect_faces(img)
|
19 |
+
|
20 |
+
img_cropped = img.crop(bb)
|
21 |
|
22 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
23 |
+
|
24 |
+
model_fair_7 = torchvision.models.resnet34(pretrained=True)
|
25 |
+
model_fair_7.fc = nn.Linear(model_fair_7.fc.in_features, 18)
|
26 |
+
model_fair_7.load_state_dict(torch.load('res34_fair_align_multi_7_20190809.pt', map_location=torch.device('cpu')))
|
27 |
+
model_fair_7 = model_fair_7.to(device)
|
28 |
+
model_fair_7.eval()
|
29 |
+
|
30 |
+
trans = transforms.Compose([
|
31 |
+
transforms.Resize((224, 224)),
|
32 |
+
transforms.ToTensor(),
|
33 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
34 |
+
])
|
35 |
+
|
36 |
+
face_names = []
|
37 |
+
gender_scores_fair = []
|
38 |
+
age_scores_fair = []
|
39 |
+
gender_preds_fair = []
|
40 |
+
age_preds_fair = []
|
41 |
+
|
42 |
+
image = trans(img_cropped)
|
43 |
+
image = image.view(1, 3, 224, 224) # reshape image to match model dimensions (1 batch size)
|
44 |
+
image = image.to(device)
|
45 |
+
|
46 |
+
# fair 7 class
|
47 |
+
outputs = model_fair_7(image)
|
48 |
+
outputs = outputs.cpu().detach().numpy()
|
49 |
+
outputs = np.squeeze(outputs)
|
50 |
+
|
51 |
+
gender_outputs = outputs[7:9]
|
52 |
+
age_outputs = outputs[9:18]
|
53 |
+
|
54 |
+
gender_score = np.exp(gender_outputs) / np.sum(np.exp(gender_outputs))
|
55 |
+
age_score = np.exp(age_outputs) / np.sum(np.exp(age_outputs))
|
56 |
+
|
57 |
+
gender_pred = np.argmax(gender_score)
|
58 |
+
age_pred = np.argmax(age_score)
|
59 |
+
|
60 |
+
gender_scores_fair.append(gender_score)
|
61 |
+
age_scores_fair.append(age_score)
|
62 |
+
|
63 |
+
gender_preds_fair.append(gender_pred)
|
64 |
+
age_preds_fair.append(age_pred)
|
65 |
+
|
66 |
+
result = pd.DataFrame([gender_preds_fair,
|
67 |
+
age_preds_fair]).T
|
68 |
+
|
69 |
+
result.columns = ['gender_preds_fair',
|
70 |
+
'age_preds_fair']
|
71 |
+
# gender
|
72 |
+
result.loc[result['gender_preds_fair'] == 0, 'gender'] = 'Male'
|
73 |
+
result.loc[result['gender_preds_fair'] == 1, 'gender'] = 'Female'
|
74 |
+
|
75 |
+
# age
|
76 |
+
result.loc[result['age_preds_fair'] == 0, 'age'] = '0-2'
|
77 |
+
result.loc[result['age_preds_fair'] == 1, 'age'] = '3-9'
|
78 |
+
result.loc[result['age_preds_fair'] == 2, 'age'] = '10-19'
|
79 |
+
result.loc[result['age_preds_fair'] == 3, 'age'] = '20-29'
|
80 |
+
result.loc[result['age_preds_fair'] == 4, 'age'] = '30-39'
|
81 |
+
result.loc[result['age_preds_fair'] == 5, 'age'] = '40-49'
|
82 |
+
result.loc[result['age_preds_fair'] == 6, 'age'] = '50-59'
|
83 |
+
result.loc[result['age_preds_fair'] == 7, 'age'] = '60-69'
|
84 |
+
result.loc[result['age_preds_fair'] == 8, 'age'] = '70+'
|
85 |
+
|
86 |
+
return [result['gender'][0],result['age'][0]]
|
87 |
+
|
88 |
def predict(image):
|
89 |
predictions = pipeline(image)
|
90 |
+
return "A " + predictions[0] + " in the age range of " + predictions[1]
|
91 |
|
92 |
gr.Interface(
|
93 |
predict,
|
94 |
inputs=gr.inputs.Image(label="Upload a profile picture of a single person", type="filepath"),
|
95 |
+
outputs=("text"),
|
96 |
title="Estimate age and gender from profile picture",
|
97 |
).launch()
|