Ellie Sleightholm
committed on
Commit
•
ced753a
1
Parent(s):
2a4d097
updating items and adding url classification option
Browse files
app.py
CHANGED
@@ -4,11 +4,12 @@ import torch
|
|
4 |
import requests
|
5 |
import numpy as np
|
6 |
from PIL import Image
|
|
|
7 |
|
8 |
# Sidebar content
|
9 |
sidebar_markdown = """
|
10 |
|
11 |
-
|
12 |
|
13 |
## Documentation
|
14 |
π [Blog Post](https://www.marqo.ai/blog/search-model-for-fashion)
|
@@ -37,36 +38,57 @@ year = {2024}
|
|
37 |
|
38 |
# List of fashion items
|
39 |
items = [
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
]
|
71 |
|
72 |
|
@@ -90,27 +112,28 @@ def generate_description(item):
|
|
90 |
return f"An item of {item} worn on the hands"
|
91 |
else:
|
92 |
return f"A fashion item called {item}"
|
93 |
-
|
94 |
-
|
95 |
items_desc = [generate_description(item) for item in items]
|
96 |
text = tokenizer(items_desc)
|
97 |
|
98 |
-
# Encode text features
|
99 |
-
with torch.no_grad(), torch.
|
100 |
text_features = model.encode_text(text)
|
101 |
text_features /= text_features.norm(dim=-1, keepdim=True)
|
102 |
|
103 |
# Prediction function
|
104 |
-
def predict(
|
105 |
-
|
|
|
|
|
|
|
|
|
106 |
|
107 |
-
with torch.no_grad(), torch.
|
108 |
-
image_features = model.encode_image(
|
109 |
image_features /= image_features.norm(dim=-1, keepdim=True)
|
110 |
|
111 |
text_probs = (100 * image_features @ text_features.T).softmax(dim=-1)
|
112 |
|
113 |
-
# Sort the confidences and get the top 10
|
114 |
sorted_confidences = sorted(
|
115 |
{items[i]: float(text_probs[0, i]) for i in range(len(items))}.items(),
|
116 |
key=lambda x: x[1],
|
@@ -118,13 +141,16 @@ def predict(inp):
|
|
118 |
)
|
119 |
top_10_confidences = dict(sorted_confidences[:10])
|
120 |
|
121 |
-
return top_10_confidences
|
|
|
|
|
|
|
|
|
122 |
|
123 |
# Gradio interface
|
124 |
title = "Fashion Item Classifier with Marqo-FashionSigLIP"
|
125 |
-
description = "Upload an image of a fashion item
|
126 |
|
127 |
-
# Example image paths with thumbnails
|
128 |
examples = [
|
129 |
["images/dress.jpg", "Dress"],
|
130 |
["images/sweatpants.jpg", "Sweatpants"],
|
@@ -152,14 +178,15 @@ with gr.Blocks(css="""
|
|
152 |
gr.Markdown(" ", elem_id="vertical-line") # Add an empty Markdown with a custom ID
|
153 |
with gr.Column(scale=2):
|
154 |
input_image = gr.Image(type="pil", label="Upload Fashion Item Image", height=312)
|
155 |
-
|
|
|
|
|
|
|
156 |
gr.Markdown("Or click on one of the images below to classify it:")
|
157 |
gr.Examples(examples=examples, inputs=input_image)
|
158 |
-
# with gr.Column(scale=3):
|
159 |
output_label = gr.Label(num_top_classes=6)
|
160 |
-
predict_button.click(predict, inputs=input_image, outputs=output_label)
|
|
|
161 |
|
162 |
-
|
163 |
# Launch the interface
|
164 |
-
demo.launch(
|
165 |
-
|
|
|
4 |
import requests
|
5 |
import numpy as np
|
6 |
from PIL import Image
|
7 |
+
from io import BytesIO
|
8 |
|
9 |
# Sidebar content
|
10 |
sidebar_markdown = """
|
11 |
|
12 |
+
Note, this demo can classify 300 items. If you didn't find what you're looking for, reach out to us on our [Community](https://join.slack.com/t/marqo-community/shared_invite/zt-2iab0260n-QJrZLUSOJYUifVxf964Gdw) and request an item to be added.
|
13 |
|
14 |
## Documentation
|
15 |
π [Blog Post](https://www.marqo.ai/blog/search-model-for-fashion)
|
|
|
38 |
|
39 |
# List of fashion items
|
40 |
items = [
|
41 |
+
'abaya', 'anorak', 'apron', 'ascot tie',
|
42 |
+
'balaclava', 'ball gown', 'bandanna', 'baseball cap', 'bathing suit',
|
43 |
+
'beanie', 'bedclothes', 'bell-bottoms', 'belt', 'beret',
|
44 |
+
'Bermuda shorts', 'baby clothes', 'bib', 'bikini', 'blazer', 'bloomers', 'blouse', 'boa',
|
45 |
+
'bonnet', 'boot', 'bow', 'bow tie', 'boxer shorts', 'boxers', 'bra',
|
46 |
+
'bracelet', 'brassiere', 'breeches', 'briefs', 'buckle', 'button',
|
47 |
+
'caftan', 'camisole', 'camouflage', 'cap',
|
48 |
+
'cap and gown', 'cape', 'capris', 'cardigan', 'chemise', 'cloak', 'clogs',
|
49 |
+
'coat', 'collar', 'corset', 'costume', 'coveralls',
|
50 |
+
'cowboy boots', 'cowboy hat', 'cravat', 'crown', 'cuff', 'cuff links',
|
51 |
+
'culottes', 'dashiki', 'diaper', 'dinner jacket', 'dirndl',
|
52 |
+
'drawers', 'dress', 'dress shirt', 'duds', 'dungarees', 'earmuffs',
|
53 |
+
'earrings', 'elastic', 'evening gown', 'fashion', 'fedora',
|
54 |
+
'fez', 'flak jacket', 'flannel nightgown', 'flannel shirt', 'flip-flops',
|
55 |
+
'formal wear', 'frock', 'fur', 'fur coat', 'gabardine', 'gaiters',
|
56 |
+
'galoshes', 'garb', 'garters', 'getup', 'gilet',
|
57 |
+
'girdle', 'glasses', 'gloves', 'gown', 'halter top', 'handbag',
|
58 |
+
'handkerchief', 'hat', 'Hawaiian shirt', 'hazmat suit', 'headscarf',
|
59 |
+
'helmet', 'hem', 'high heels', 'hoodie', 'hook and eye', 'hose',
|
60 |
+
'hosiery', 'hospital gown', 'houndstooth', 'housecoat', 'jacket',
|
61 |
+
'jeans', 'jersey', 'jewelry', 'jodhpurs', 'jumper', 'jumpsuit', 'kerchief',
|
62 |
+
'khakis', 'kilt', 'kimono', 'kit', 'knickers', 'lab coat', 'lapel',
|
63 |
+
'leather jacket', 'leg warmers', 'leggings', 'leotard', 'life jacket',
|
64 |
+
'lingerie', 'loafers', 'loincloth', 'long johns', 'long underwear',
|
65 |
+
'miniskirt', 'mittens', 'moccasins', 'muffler', 'muumuu', 'neckerchief',
|
66 |
+
'necklace', 'nightgown', 'nightshirt', 'onesies', 'outerwear', 'outfit',
|
67 |
+
'overalls', 'overcoat', 'overshirt', 'pajamas', 'pants',
|
68 |
+
'pantsuit', 'pantyhose', 'parka', 'pea coat', 'peplum', 'petticoat',
|
69 |
+
'pinafore', 'pleat', 'pocket', 'pocketbook', 'polo shirt', 'poncho',
|
70 |
+
'poodle skirt', 'pullover', 'pumps', 'purse', 'raincoat',
|
71 |
+
'ring', 'robe', 'rugby shirt', 'sandals', 'sari', 'sarong', 'scarf',
|
72 |
+
'school uniform', 'scrubs', 'shawl', 'shirt',
|
73 |
+
'shoes', 'shorts', 'shoulder pads', 'shrug', 'singlet', 'skirt',
|
74 |
+
'slacks', 'slip', 'slippers', 'smock', 'snaps', 'sneakers', 'socks',
|
75 |
+
'sombrero', 'spacesuit', 'stockings', 'stole', 'suit',
|
76 |
+
'sun hat', 'sunbonnet', 'sundress', 'sunglasses', 'suspenders',
|
77 |
+
'sweater', 'sweatpants', 'sweatshirt', 'sweatsuit', 'swimsuit',
|
78 |
+
'T-shirt', 'tam', 'tank top', 'teddy', 'threads', 'tiara', 'tie',
|
79 |
+
'tie clip', 'tights', 'toga', 'tog', 'top', 'top coat', 'top hat', 'train',
|
80 |
+
'trench coat', 'trousers', 'trunks', 'tube top', 'tunic', 'turban',
|
81 |
+
'turtleneck', 'turtleneck shirt', 'tutu', 'tuxedo', 'tweed jacket',
|
82 |
+
'twin set', 'umbrella', 'underclothes', 'undershirt',
|
83 |
+
'underwear', 'uniform', 'veil', 'Velcro', 'vest', 'vestments', 'visor',
|
84 |
+
'waders', 'waistcoat', 'wear', 'wedding gown', 'Wellingtons', 'wetsuit',
|
85 |
+
'white tie', 'wig', 'windbreaker', 'woolens', 'wrap', 'yoke', 'zipper',
|
86 |
+
'zoris', 'jogger', 'palazzo', 'cargo', 'dresspants', 'chinos',
|
87 |
+
'crop top', 'romper', 'insulated jacket', 'fleece', 'rain jacket',
|
88 |
+
'running jacket', 'graphic top', 'pant', 'legging', 'skort', 'brief',
|
89 |
+
'sports bra', 'water shorts', 'cover up', 'goggle', 'glove', 'mitten',
|
90 |
+
'leg gaiter', 'neck gaiter', 'watch', 'bag', 'swim trunk',
|
91 |
+
'pocket watch', 'insoles', "climbing shoes",
|
92 |
]
|
93 |
|
94 |
|
|
|
112 |
return f"An item of {item} worn on the hands"
|
113 |
else:
|
114 |
return f"A fashion item called {item}"
|
|
|
|
|
115 |
items_desc = [generate_description(item) for item in items]
|
116 |
text = tokenizer(items_desc)
|
117 |
|
118 |
+
# Encode text features (unchanged)
|
119 |
+
with torch.no_grad(), torch.amp.autocast('cuda'):
|
120 |
text_features = model.encode_text(text)
|
121 |
text_features /= text_features.norm(dim=-1, keepdim=True)
|
122 |
|
123 |
# Prediction function
|
124 |
+
def predict(image, url):
|
125 |
+
if url:
|
126 |
+
response = requests.get(url)
|
127 |
+
image = Image.open(BytesIO(response.content))
|
128 |
+
|
129 |
+
processed_image = preprocess_val(image).unsqueeze(0)
|
130 |
|
131 |
+
with torch.no_grad(), torch.amp.autocast('cuda'):
|
132 |
+
image_features = model.encode_image(processed_image)
|
133 |
image_features /= image_features.norm(dim=-1, keepdim=True)
|
134 |
|
135 |
text_probs = (100 * image_features @ text_features.T).softmax(dim=-1)
|
136 |
|
|
|
137 |
sorted_confidences = sorted(
|
138 |
{items[i]: float(text_probs[0, i]) for i in range(len(items))}.items(),
|
139 |
key=lambda x: x[1],
|
|
|
141 |
)
|
142 |
top_10_confidences = dict(sorted_confidences[:10])
|
143 |
|
144 |
+
return image, top_10_confidences
|
145 |
+
|
146 |
+
# Clear function
|
147 |
+
def clear_fields():
|
148 |
+
return None, ""
|
149 |
|
150 |
# Gradio interface
|
151 |
title = "Fashion Item Classifier with Marqo-FashionSigLIP"
|
152 |
+
description = "Upload an image or provide a URL of a fashion item to classify it using [Marqo-FashionSigLIP](https://huggingface.co/Marqo/marqo-fashionSigLIP)!"
|
153 |
|
|
|
154 |
examples = [
|
155 |
["images/dress.jpg", "Dress"],
|
156 |
["images/sweatpants.jpg", "Sweatpants"],
|
|
|
178 |
gr.Markdown(" ", elem_id="vertical-line") # Add an empty Markdown with a custom ID
|
179 |
with gr.Column(scale=2):
|
180 |
input_image = gr.Image(type="pil", label="Upload Fashion Item Image", height=312)
|
181 |
+
input_url = gr.Textbox(label="Or provide an image URL")
|
182 |
+
with gr.Row():
|
183 |
+
predict_button = gr.Button("Classify")
|
184 |
+
clear_button = gr.Button("Clear")
|
185 |
gr.Markdown("Or click on one of the images below to classify it:")
|
186 |
gr.Examples(examples=examples, inputs=input_image)
|
|
|
187 |
output_label = gr.Label(num_top_classes=6)
|
188 |
+
predict_button.click(predict, inputs=[input_image, input_url], outputs=[input_image, output_label])
|
189 |
+
clear_button.click(clear_fields, outputs=[input_image, input_url])
|
190 |
|
|
|
191 |
# Launch the interface
|
192 |
+
demo.launch()
|
|