ElenaRyumina committed on
Commit • 031ec86
1 Parent(s): c601827
Summary

Files changed:
- .gitignore +171 -0
- LICENSE +21 -0
- app.py +191 -0
- images/fig1.jpg +0 -0
- images/fig2.jpg +0 -0
- images/fig3.jpg +0 -0
- images/fig4.jpg +0 -0
- images/fig5.jpg +0 -0
- images/fig6.jpg +0 -0
- images/fig7.jpg +0 -0
.gitignore
ADDED
@@ -0,0 +1,171 @@
+# Compiled source #
+###################
+*.com
+*.class
+*.dll
+*.exe
+*.o
+*.so
+*.pyc
+
+# Packages #
+############
+# it's better to unpack these files and commit the raw source
+# git has its own built in compression methods
+*.7z
+*.dmg
+*.gz
+*.iso
+*.rar
+#*.tar
+*.zip
+
+# Logs and databases #
+######################
+*.log
+*.sqlite
+
+# OS generated files #
+######################
+.DS_Store
+ehthumbs.db
+Icon
+Thumbs.db
+.tmtags
+.idea
+.vscode
+tags
+vendor.tags
+tmtagsHistory
+*.sublime-project
+*.sublime-workspace
+.bundle
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+node_modules/
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# Custom
+*.pth
LICENSE
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2024 Elena Ryumina
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
app.py
ADDED
@@ -0,0 +1,191 @@
+import torch
+from PIL import Image
+from torchvision import transforms
+import mediapipe as mp
+import numpy as np
+import math
+import requests
+
+import gradio as gr
+
+model_url = "https://huggingface.co/ElenaRyumina/face_emotion_recognition/resolve/main/FER_static_ResNet50_AffectNet.pth"
+model_path = "FER_static_ResNet50_AffectNet.pth"
+
+response = requests.get(model_url, stream=True)
+with open(model_path, 'wb') as file:
+    for chunk in response.iter_content(chunk_size=8192):
+        file.write(chunk)
+
+pth_model = torch.jit.load(model_path).to('cuda')
+pth_model.eval()
+
+DICT_EMO = {0: 'Neutral', 1: 'Happiness', 2: 'Sadness', 3: 'Surprise', 4: 'Fear', 5: 'Disgust', 6: 'Anger'}
+
+mp_face_mesh = mp.solutions.face_mesh
+
+def pth_processing(fp):
+    class PreprocessInput(torch.nn.Module):
+        def __init__(self):
+            super(PreprocessInput, self).__init__()
+
+        def forward(self, x):
+            x = x.to(torch.float32)
+            x = torch.flip(x, dims=(0,))
+            x[0, :, :] -= 91.4953
+            x[1, :, :] -= 103.8827
+            x[2, :, :] -= 131.0912
+            return x
+
+    def get_img_torch(img):
+
+        ttransform = transforms.Compose([
+            transforms.PILToTensor(),
+            PreprocessInput()
+        ])
+        img = img.resize((224, 224), Image.Resampling.NEAREST)
+        img = ttransform(img)
+        img = torch.unsqueeze(img, 0).to('cuda')
+        return img
+    return get_img_torch(fp)
+
+def norm_coordinates(normalized_x, normalized_y, image_width, image_height):
+
+    x_px = min(math.floor(normalized_x * image_width), image_width - 1)
+    y_px = min(math.floor(normalized_y * image_height), image_height - 1)
+
+    return x_px, y_px
+
+def get_box(fl, w, h):
+    idx_to_coors = {}
+    for idx, landmark in enumerate(fl.landmark):
+        landmark_px = norm_coordinates(landmark.x, landmark.y, w, h)
+
+        if landmark_px:
+            idx_to_coors[idx] = landmark_px
+
+    x_min = np.min(np.asarray(list(idx_to_coors.values()))[:,0])
+    y_min = np.min(np.asarray(list(idx_to_coors.values()))[:,1])
+    endX = np.max(np.asarray(list(idx_to_coors.values()))[:,0])
+    endY = np.max(np.asarray(list(idx_to_coors.values()))[:,1])
+
+    (startX, startY) = (max(0, x_min), max(0, y_min))
+    (endX, endY) = (min(w - 1, endX), min(h - 1, endY))
+
+    return startX, startY, endX, endY
+
+def predict(inp):
+
+    inp = np.array(inp)
+    h, w = inp.shape[:2]
+
+    with mp_face_mesh.FaceMesh(
+            max_num_faces=1,
+            refine_landmarks=False,
+            min_detection_confidence=0.5,
+            min_tracking_confidence=0.5) as face_mesh:
+        results = face_mesh.process(inp)
+        if results.multi_face_landmarks:
+            for fl in results.multi_face_landmarks:
+                startX, startY, endX, endY = get_box(fl, w, h)
+                cur_face = inp[startY:endY, startX: endX]
+                cur_face_n = pth_processing(Image.fromarray(cur_face))
+                prediction = torch.nn.functional.softmax(pth_model(cur_face_n), dim=1).cpu().detach().numpy()[0]
+                confidences = {DICT_EMO[i]: float(prediction[i]) for i in range(7)}
+
+    return cur_face, confidences
+
+def clear():
+    return (
+        gr.Image(value=None, type="pil"),
+        gr.Image(value=None, scale=1, elem_classes="dl2"),
+        gr.Label(value=None, num_top_classes=3, scale=1, elem_classes="dl3")
+    )
+
+style = """
+div.dl1 div.upload-container {
+    height: 350px;
+    max-height: 350px;
+}
+
+div.dl2 {
+    max-height: 200px;
+}
+
+div.dl2 img {
+    max-height: 200px;
+}
+
+.submit {
+    display: inline-block;
+    padding: 10px 20px;
+    font-size: 16px;
+    font-weight: bold;
+    text-align: center;
+    text-decoration: none;
+    cursor: pointer;
+    border: var(--button-border-width) solid var(--button-primary-border-color);
+    background: var(--button-primary-background-fill);
+    color: var(--button-primary-text-color);
+    border-radius: 8px;
+    transition: all 0.3s ease;
+}
+
+.submit[disabled] {
+    cursor: not-allowed;
+    opacity: 0.6;
+}
+
+.submit:hover:not([disabled]) {
+    border-color: var(--button-primary-border-color-hover);
+    background: var(--button-primary-background-fill-hover);
+    color: var(--button-primary-text-color-hover);
+}
+
+.submit:active:not([disabled]) {
+    transform: scale(0.98);
+}
+"""
+
+with gr.Blocks(css=style) as demo:
+    with gr.Row():
+        with gr.Column(scale=2, elem_classes="dl1"):
+            input_image = gr.Image(type="pil")
+            with gr.Row():
+                submit = gr.Button(
+                    value="Submit", interactive=True, scale=1, elem_classes="submit"
+                )
+                clear_btn = gr.Button(
+                    value="Clear", interactive=True, scale=1
+                )
+        with gr.Column(scale=1, elem_classes="dl4"):
+            output_image = gr.Image(scale=1, elem_classes="dl2")
+            output_label = gr.Label(num_top_classes=3, scale=1, elem_classes="dl3")
+    gr.Examples(
+        ["images/fig7.jpg", "images/fig1.jpg", "images/fig2.jpg", "images/fig3.jpg",
+         "images/fig4.jpg", "images/fig5.jpg", "images/fig6.jpg"],
+        [input_image],
+    )
+
+
+    submit.click(
+        fn=predict,
+        inputs=[input_image],
+        outputs=[
+            output_image,
+            output_label
+        ],
+        queue=True,
+    )
+    clear_btn.click(
+        fn=clear,
+        inputs=[],
+        outputs=[
+            input_image,
+            output_image,
+            output_label,
+        ],
+        queue=True,
+    )
+
+if __name__ == "__main__":
+    demo.queue(api_open=False).launch(share=False)
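Note that app.py pins both the TorchScript model and the input tensors to 'cuda', so this commit requires a GPU host. A minimal device-agnostic sketch (not part of the commit; the `device` variable is introduced here for illustration) would load the checkpoint with `map_location` and reuse that one device throughout:

```python
import torch

# Sketch: fall back to CPU when no GPU is visible.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# torch.jit.load accepts map_location, like torch.load.
pth_model = torch.jit.load("FER_static_ResNet50_AffectNet.pth", map_location=device)
pth_model.eval()

# get_img_torch() would then move inputs with .to(device) instead of .to('cuda').
```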
images/fig1.jpg
ADDED
images/fig2.jpg
ADDED
images/fig3.jpg
ADDED
images/fig4.jpg
ADDED
images/fig5.jpg
ADDED
images/fig6.jpg
ADDED
images/fig7.jpg
ADDED
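With the dependencies from app.py installed (torch, torchvision, mediapipe, gradio, requests, Pillow) and a CUDA device available, the pipeline can be smoke-tested without the Gradio UI by calling predict() on one of the bundled example images; a hypothetical usage sketch:

```python
from PIL import Image

# predict() comes from app.py; fig1.jpg ships with this commit.
face_crop, confidences = predict(Image.open("images/fig1.jpg"))

# confidences maps the seven DICT_EMO labels to softmax scores.
print(max(confidences, key=confidences.get))
```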