Spaces:
Build error
Build error
vincentclaes
committed on
Commit
•
e45afa6
1
Parent(s):
0df1067
have a working model
Browse files- README.md +1 -1
- app.py +59 -6
- poetry.lock +10 -68
- pyproject.toml +1 -0
- requirements.txt +1 -0
README.md
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
---
|
2 |
title: Emoji Predictor
|
3 |
-
emoji:
|
4 |
colorFrom: pink
|
5 |
colorTo: indigo
|
6 |
sdk: gradio
|
|
|
1 |
---
|
2 |
title: Emoji Predictor
|
3 |
+
emoji: π
|
4 |
colorFrom: pink
|
5 |
colorTo: indigo
|
6 |
sdk: gradio
|
app.py
CHANGED
@@ -1,10 +1,16 @@
|
|
1 |
import gradio as gr
|
2 |
import torch
|
|
|
|
|
3 |
from PIL import Image
|
|
|
|
|
|
|
4 |
from transformers import CLIPProcessor, CLIPModel
|
5 |
|
6 |
checkpoint = "vincentclaes/emoji-predictor"
|
7 |
-
|
|
|
8 |
emojis_as_images = [Image.open(f"emojis/{i}.png") for i in no_of_emojis]
|
9 |
K = 4
|
10 |
|
@@ -12,6 +18,29 @@ processor = CLIPProcessor.from_pretrained(checkpoint)
|
|
12 |
model = CLIPModel.from_pretrained(checkpoint)
|
13 |
|
14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
def get_emoji(text, model=model, processor=processor, emojis=emojis_as_images, K=4):
|
16 |
inputs = processor(text=text, images=emojis, return_tensors="pt", padding=True, truncation=True)
|
17 |
outputs = model(**inputs)
|
@@ -23,11 +52,35 @@ def get_emoji(text, model=model, processor=processor, emojis=emojis_as_images, K
|
|
23 |
predictions_suggestions_for_chunk = [torch.topk(prob, K).indices.tolist() for prob in probs][0]
|
24 |
predictions_suggestions_for_chunk
|
25 |
|
26 |
-
|
|
|
|
|
27 |
|
28 |
|
29 |
-
text = gr.inputs.Textbox()
|
30 |
title = "Predicting an Emoji"
|
31 |
-
description = "
|
32 |
-
|
33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
import torch
|
3 |
+
import os
|
4 |
+
|
5 |
from PIL import Image
|
6 |
+
from pathlib import Path
|
7 |
+
from more_itertools import chunked
|
8 |
+
|
9 |
from transformers import CLIPProcessor, CLIPModel
|
10 |
|
11 |
checkpoint = "vincentclaes/emoji-predictor"
|
12 |
+
x_, _, files = next(os.walk("./emojis"))
|
13 |
+
no_of_emojis = range(len(files))
|
14 |
emojis_as_images = [Image.open(f"emojis/{i}.png") for i in no_of_emojis]
|
15 |
K = 4
|
16 |
|
|
|
18 |
model = CLIPModel.from_pretrained(checkpoint)
|
19 |
|
20 |
|
21 |
+
def concat_images(*images):
    """Paste up to four images into a single 2x2 composite image.

    Adapted from https://stackoverflow.com/a/71315656/1771155

    Images are placed row by row: the first two on the top row, the next
    two on the bottom row. Every grid cell is as wide as the widest input
    and as tall as the tallest input. Cells without an image keep the
    default fill of the freshly created composite.

    Args:
        *images: Objects exposing ``width`` and ``height`` attributes that
            ``Image.paste`` accepts.

    Returns:
        A new RGB image of size ``(2 * width, 2 * height)``.

    Raises:
        ValueError: If more than four images are supplied, or if none are
            supplied (raised by ``max()`` on an empty sequence).
    """
    # Validate the number of images actually received instead of asserting
    # on an unrelated module-level constant; a fifth image would otherwise
    # be dropped silently.
    if len(images) > 4:
        raise ValueError(
            "concat_images lays out a 2x2 grid and accepts at most 4 images, "
            f"got {len(images)}."
        )

    # Cell size: the widest width and the tallest height among the inputs.
    width = max(image.width for image in images)
    height = max(image.height for image in images)

    # The composite holds two cells per row and two rows.
    composite = Image.new('RGB', (2 * width, 2 * height))
    for i, image in enumerate(images):
        # Index 0..3 maps to (row, column) = (0, 0), (0, 1), (1, 0), (1, 1).
        row, column = divmod(i, 2)
        composite.paste(image, (column * width, row * height))
    return composite
|
42 |
+
|
43 |
+
|
44 |
def get_emoji(text, model=model, processor=processor, emojis=emojis_as_images, K=4):
|
45 |
inputs = processor(text=text, images=emojis, return_tensors="pt", padding=True, truncation=True)
|
46 |
outputs = model(**inputs)
|
|
|
52 |
predictions_suggestions_for_chunk = [torch.topk(prob, K).indices.tolist() for prob in probs][0]
|
53 |
predictions_suggestions_for_chunk
|
54 |
|
55 |
+
images = [Image.open(f"emojis/{i}.png") for i in predictions_suggestions_for_chunk]
|
56 |
+
images_concat = concat_images(*images)
|
57 |
+
return images_concat
|
58 |
|
59 |
|
60 |
+
text = gr.inputs.Textbox(placeholder="Enter a text and we will try to predict an emoji...")
|
61 |
title = "Predicting an Emoji"
|
62 |
+
description = """You provide a sentence and our few-shot fine tuned CLIP model will predict from the following emoji's:
|
63 |
+
\nβ€οΈ π π π π₯ π π β¨ π π π· πΊπΈ β π π π― π π πΈ π βΉοΈ π π π‘ π’ π€ π³ π π© π π π\n
|
64 |
+
"""
|
65 |
+
article = """
|
66 |
+
\n
|
67 |
+
+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
68 |
+
\n
|
69 |
+
We fine tuned Open Ai's CLIP model on both text (tweets) and images of emoji's!\n
|
70 |
+
The current model is fine-tuned on 15 samples per emoji.
|
71 |
+
|
72 |
+
- model: https://huggingface.co/vincentclaes/emoji-predictor \n
|
73 |
+
- dataset: https://huggingface.co/datasets/vincentclaes/emoji-predictor \n
|
74 |
+
- code: https://github.com/vincentclaes/emoji-predictor \n
|
75 |
+
- profile: https://huggingface.co/vincentclaes \n
|
76 |
+
"""
|
77 |
+
examples = [
|
78 |
+
"I'm so happy for you!",
|
79 |
+
"I'm not feeling great today.",
|
80 |
+
"This makes me angry!",
|
81 |
+
"Can I follow you?",
|
82 |
+
"I'm so bored right now ...",
|
83 |
+
]
|
84 |
+
gr.Interface(fn=get_emoji, inputs=text, outputs=gr.Image(shape=(72,72)),
|
85 |
+
examples=examples, title=title, description=description,
|
86 |
+
article=article).launch()
|
poetry.lock
CHANGED
@@ -155,17 +155,6 @@ category = "main"
|
|
155 |
optional = false
|
156 |
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
|
157 |
|
158 |
-
[[package]]
|
159 |
-
name = "commonmark"
|
160 |
-
version = "0.9.1"
|
161 |
-
description = "Python parser for the CommonMark Markdown spec"
|
162 |
-
category = "main"
|
163 |
-
optional = false
|
164 |
-
python-versions = "*"
|
165 |
-
|
166 |
-
[package.extras]
|
167 |
-
test = ["flake8 (==3.7.8)", "hypothesis (==3.55.3)"]
|
168 |
-
|
169 |
[[package]]
|
170 |
name = "contourpy"
|
171 |
version = "1.0.5"
|
@@ -211,29 +200,6 @@ category = "main"
|
|
211 |
optional = false
|
212 |
python-versions = ">=3.6"
|
213 |
|
214 |
-
[[package]]
|
215 |
-
name = "docarray"
|
216 |
-
version = "0.16.5"
|
217 |
-
description = "The data structure for unstructured data"
|
218 |
-
category = "main"
|
219 |
-
optional = false
|
220 |
-
python-versions = "*"
|
221 |
-
|
222 |
-
[package.dependencies]
|
223 |
-
numpy = "*"
|
224 |
-
rich = ">=12.0.0"
|
225 |
-
|
226 |
-
[package.extras]
|
227 |
-
annlite = ["annlite (>=0.3.10)"]
|
228 |
-
benchmark = ["pandas", "seaborn"]
|
229 |
-
common = ["protobuf (>=3.13.0)", "lz4", "requests", "matplotlib", "pillow", "fastapi", "uvicorn", "jina-hubble-sdk (>=0.11.0)"]
|
230 |
-
elasticsearch = ["elasticsearch (>=8.2.0)"]
|
231 |
-
full = ["protobuf (>=3.13.0)", "lz4", "requests", "matplotlib", "pillow", "trimesh", "scipy", "jina-hubble-sdk (>=0.10.0)", "av", "fastapi", "uvicorn", "strawberry-graphql"]
|
232 |
-
qdrant = ["qdrant-client (>=0.7.3,<0.8.0)"]
|
233 |
-
redis = ["redis (>=4.3.0)"]
|
234 |
-
test = ["pytest", "pytest-timeout", "pytest-mock", "pytest-cov", "pytest-repeat", "pytest-reraise", "mock", "pytest-custom-exit-code", "black (==22.3.0)", "tensorflow (==2.7.0)", "paddlepaddle (==2.2.0)", "torch (==1.9.0)", "torchvision (==0.10.0)", "datasets", "onnx", "onnxruntime", "jupyterlab", "transformers (>=4.16.2)", "weaviate-client (>=3.3.0,<3.4.0)", "annlite (>=0.3.10)", "elasticsearch (>=8.2.0)", "redis (>=4.3.0)", "jina"]
|
235 |
-
weaviate = ["weaviate-client (>=3.3.0,<3.4.0)"]
|
236 |
-
|
237 |
[[package]]
|
238 |
name = "fastapi"
|
239 |
version = "0.85.0"
|
@@ -567,6 +533,14 @@ category = "main"
|
|
567 |
optional = false
|
568 |
python-versions = "*"
|
569 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
570 |
[[package]]
|
571 |
name = "multidict"
|
572 |
version = "6.0.2"
|
@@ -694,17 +668,6 @@ category = "main"
|
|
694 |
optional = false
|
695 |
python-versions = "*"
|
696 |
|
697 |
-
[[package]]
|
698 |
-
name = "pygments"
|
699 |
-
version = "2.13.0"
|
700 |
-
description = "Pygments is a syntax highlighting package written in Python."
|
701 |
-
category = "main"
|
702 |
-
optional = false
|
703 |
-
python-versions = ">=3.6"
|
704 |
-
|
705 |
-
[package.extras]
|
706 |
-
plugins = ["importlib-metadata"]
|
707 |
-
|
708 |
[[package]]
|
709 |
name = "pynacl"
|
710 |
version = "1.5.0"
|
@@ -809,21 +772,6 @@ idna = {version = "*", optional = true, markers = "extra == \"idna2008\""}
|
|
809 |
[package.extras]
|
810 |
idna2008 = ["idna"]
|
811 |
|
812 |
-
[[package]]
|
813 |
-
name = "rich"
|
814 |
-
version = "12.5.1"
|
815 |
-
description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal"
|
816 |
-
category = "main"
|
817 |
-
optional = false
|
818 |
-
python-versions = ">=3.6.3,<4.0.0"
|
819 |
-
|
820 |
-
[package.dependencies]
|
821 |
-
commonmark = ">=0.9.0,<0.10.0"
|
822 |
-
pygments = ">=2.6.0,<3.0.0"
|
823 |
-
|
824 |
-
[package.extras]
|
825 |
-
jupyter = ["ipywidgets (>=7.5.1,<8.0.0)"]
|
826 |
-
|
827 |
[[package]]
|
828 |
name = "setuptools-scm"
|
829 |
version = "7.0.5"
|
@@ -1051,7 +999,7 @@ multidict = ">=4.0"
|
|
1051 |
[metadata]
|
1052 |
lock-version = "1.1"
|
1053 |
python-versions = "^3.9"
|
1054 |
-
content-hash = "
|
1055 |
|
1056 |
[metadata.files]
|
1057 |
aiohttp = [
|
@@ -1152,17 +1100,12 @@ click = [
|
|
1152 |
{file = "click-8.1.3.tar.gz", hash = "sha256:7682dc8afb30297001674575ea00d1814d808d6a36af415a82bd481d37ba7b8e"},
|
1153 |
]
|
1154 |
colorama = []
|
1155 |
-
commonmark = [
|
1156 |
-
{file = "commonmark-0.9.1-py2.py3-none-any.whl", hash = "sha256:da2f38c92590f83de410ba1a3cbceafbc74fee9def35f9251ba9a971d6d66fd9"},
|
1157 |
-
{file = "commonmark-0.9.1.tar.gz", hash = "sha256:452f9dc859be7f06631ddcb328b6919c67984aca654e5fefb3914d54691aed60"},
|
1158 |
-
]
|
1159 |
contourpy = []
|
1160 |
cryptography = []
|
1161 |
cycler = [
|
1162 |
{file = "cycler-0.11.0-py3-none-any.whl", hash = "sha256:3a27e95f763a428a739d2add979fa7494c912a32c17c4c38c4d5f082cad165a3"},
|
1163 |
{file = "cycler-0.11.0.tar.gz", hash = "sha256:9c87405839a19696e837b3b818fed3f5f69f16f1eec1a1ad77e043dcea9c772f"},
|
1164 |
]
|
1165 |
-
docarray = []
|
1166 |
fastapi = []
|
1167 |
ffmpy = []
|
1168 |
filelock = []
|
@@ -1225,6 +1168,7 @@ matplotlib = []
|
|
1225 |
mdit-py-plugins = []
|
1226 |
mdurl = []
|
1227 |
monotonic = []
|
|
|
1228 |
multidict = [
|
1229 |
{file = "multidict-6.0.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:0b9e95a740109c6047602f4db4da9949e6c5945cefbad34a1299775ddc9a62e2"},
|
1230 |
{file = "multidict-6.0.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ac0e27844758d7177989ce406acc6a83c16ed4524ebc363c1f748cba184d89d3"},
|
@@ -1302,7 +1246,6 @@ pycparser = [
|
|
1302 |
pycryptodome = []
|
1303 |
pydantic = []
|
1304 |
pydub = []
|
1305 |
-
pygments = []
|
1306 |
pynacl = []
|
1307 |
pyparsing = [
|
1308 |
{file = "pyparsing-3.0.9-py3-none-any.whl", hash = "sha256:5026bae9a10eeaefb61dab2f09052b9f4307d44aee4eda64b309723d8d206bbc"},
|
@@ -1354,7 +1297,6 @@ pyyaml = [
|
|
1354 |
regex = []
|
1355 |
requests = []
|
1356 |
rfc3986 = []
|
1357 |
-
rich = []
|
1358 |
setuptools-scm = []
|
1359 |
six = [
|
1360 |
{file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"},
|
|
|
155 |
optional = false
|
156 |
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
|
157 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
158 |
[[package]]
|
159 |
name = "contourpy"
|
160 |
version = "1.0.5"
|
|
|
200 |
optional = false
|
201 |
python-versions = ">=3.6"
|
202 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
203 |
[[package]]
|
204 |
name = "fastapi"
|
205 |
version = "0.85.0"
|
|
|
533 |
optional = false
|
534 |
python-versions = "*"
|
535 |
|
536 |
+
[[package]]
|
537 |
+
name = "more-itertools"
|
538 |
+
version = "8.14.0"
|
539 |
+
description = "More routines for operating on iterables, beyond itertools"
|
540 |
+
category = "main"
|
541 |
+
optional = false
|
542 |
+
python-versions = ">=3.5"
|
543 |
+
|
544 |
[[package]]
|
545 |
name = "multidict"
|
546 |
version = "6.0.2"
|
|
|
668 |
optional = false
|
669 |
python-versions = "*"
|
670 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
671 |
[[package]]
|
672 |
name = "pynacl"
|
673 |
version = "1.5.0"
|
|
|
772 |
[package.extras]
|
773 |
idna2008 = ["idna"]
|
774 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
775 |
[[package]]
|
776 |
name = "setuptools-scm"
|
777 |
version = "7.0.5"
|
|
|
999 |
[metadata]
|
1000 |
lock-version = "1.1"
|
1001 |
python-versions = "^3.9"
|
1002 |
+
content-hash = "5bc12d64b69b9c1f0f68ae6858e97ba26663256bae5a9172c0f5bb69402f6c62"
|
1003 |
|
1004 |
[metadata.files]
|
1005 |
aiohttp = [
|
|
|
1100 |
{file = "click-8.1.3.tar.gz", hash = "sha256:7682dc8afb30297001674575ea00d1814d808d6a36af415a82bd481d37ba7b8e"},
|
1101 |
]
|
1102 |
colorama = []
|
|
|
|
|
|
|
|
|
1103 |
contourpy = []
|
1104 |
cryptography = []
|
1105 |
cycler = [
|
1106 |
{file = "cycler-0.11.0-py3-none-any.whl", hash = "sha256:3a27e95f763a428a739d2add979fa7494c912a32c17c4c38c4d5f082cad165a3"},
|
1107 |
{file = "cycler-0.11.0.tar.gz", hash = "sha256:9c87405839a19696e837b3b818fed3f5f69f16f1eec1a1ad77e043dcea9c772f"},
|
1108 |
]
|
|
|
1109 |
fastapi = []
|
1110 |
ffmpy = []
|
1111 |
filelock = []
|
|
|
1168 |
mdit-py-plugins = []
|
1169 |
mdurl = []
|
1170 |
monotonic = []
|
1171 |
+
more-itertools = []
|
1172 |
multidict = [
|
1173 |
{file = "multidict-6.0.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:0b9e95a740109c6047602f4db4da9949e6c5945cefbad34a1299775ddc9a62e2"},
|
1174 |
{file = "multidict-6.0.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ac0e27844758d7177989ce406acc6a83c16ed4524ebc363c1f748cba184d89d3"},
|
|
|
1246 |
pycryptodome = []
|
1247 |
pydantic = []
|
1248 |
pydub = []
|
|
|
1249 |
pynacl = []
|
1250 |
pyparsing = [
|
1251 |
{file = "pyparsing-3.0.9-py3-none-any.whl", hash = "sha256:5026bae9a10eeaefb61dab2f09052b9f4307d44aee4eda64b309723d8d206bbc"},
|
|
|
1297 |
regex = []
|
1298 |
requests = []
|
1299 |
rfc3986 = []
|
|
|
1300 |
setuptools-scm = []
|
1301 |
six = [
|
1302 |
{file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"},
|
pyproject.toml
CHANGED
@@ -9,6 +9,7 @@ python = "^3.9"
|
|
9 |
torch = "^1.12.1"
|
10 |
gradio = "^3.3.1"
|
11 |
transformers = "^4.22.1"
|
|
|
12 |
|
13 |
[tool.poetry.dev-dependencies]
|
14 |
|
|
|
9 |
torch = "^1.12.1"
|
10 |
gradio = "^3.3.1"
|
11 |
transformers = "^4.22.1"
|
12 |
+
more-itertools = "^8.14.0"
|
13 |
|
14 |
[tool.poetry.dev-dependencies]
|
15 |
|
requirements.txt
CHANGED
@@ -34,6 +34,7 @@ matplotlib==3.6.0
|
|
34 |
mdit-py-plugins==0.3.0
|
35 |
mdurl==0.1.2
|
36 |
monotonic==1.6
|
|
|
37 |
multidict==6.0.2
|
38 |
numpy==1.23.3
|
39 |
orjson==3.8.0
|
|
|
34 |
mdit-py-plugins==0.3.0
|
35 |
mdurl==0.1.2
|
36 |
monotonic==1.6
|
37 |
+
more-itertools==8.14.0
|
38 |
multidict==6.0.2
|
39 |
numpy==1.23.3
|
40 |
orjson==3.8.0
|