elliesleightholm committed on
Commit
3e8a38e
1 Parent(s): 533a1c1

initial commit

Browse files
Files changed (3) hide show
  1. .gitignore +1 -0
  2. app.py +58 -0
  3. requirements.txt +9 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ venv
app.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoModel, AutoProcessor
3
+ import torch
4
+ import requests
5
+ from PIL import Image
6
+ from io import BytesIO
7
+
8
# Candidate labels the classifier chooses between.
fashion_items = ['top', 'trousers', 'hat', 'jumper']

# Fetch the pretrained model together with its matching processor.
model_name = 'Marqo/marqo-fashionSigLIP'
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)

# Embed the label texts once at start-up; gradients are not needed for this.
with torch.no_grad():
    # truncation=True clips over-long text to the model's input size and
    # padding=True pads the shorter labels so every sequence has equal length.
    tokenized = processor(
        text=fashion_items,
        return_tensors="pt",
        truncation=True,
        padding=True,
    )
    processed_texts = tokenized['input_ids']

    text_features = model.get_text_features(processed_texts)
    # Scale every embedding to unit length along the feature dimension.
    text_norms = text_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_norms
27
+
28
+ # Prediction function
29
def predict_from_url(url):
    """Classify the image found at *url* against the labels in ``fashion_items``.

    Args:
        url: Address of the image to download. Empty or whitespace-only
            values are rejected without making a network request.

    Returns:
        A dict mapping each entry of ``fashion_items`` to its softmax
        probability as a float, or ``{"Error": <message>}`` when the URL is
        missing or the image cannot be downloaded or decoded.
    """
    # Reject missing input up front so no request is attempted.
    if not url or not url.strip():
        return {"Error": "Please input a URL"}

    # NOTE(review): the URL comes straight from the user and is fetched
    # server-side; consider restricting schemes/hosts if this is deployed
    # somewhere with access to internal services.
    try:
        # The timeout stops an unresponsive host from hanging the request
        # forever; raise_for_status reports HTTP errors (404, 500, ...)
        # directly instead of as an opaque image-decoding failure.
        response = requests.get(url.strip(), timeout=10)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
        # Image.open is lazy: force decoding here so a corrupt or truncated
        # file is reported below rather than crashing inside the processor.
        image.load()
    except Exception as e:
        # Deliberately broad: any download or decode problem is shown to the
        # user as a message instead of surfacing as an unhandled exception.
        # NOTE(review): the value here is a string while the success path
        # returns floats - confirm the output component renders this as
        # intended.
        return {"Error": f"Failed to load image: {str(e)}"}

    processed_image = processor(images=image, return_tensors="pt")['pixel_values']

    with torch.no_grad():
        image_features = model.get_image_features(processed_image)
        # Normalise to unit length so the dot product with the (already
        # normalised) text features is a cosine similarity.
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        # Similarities are scaled by 100 before the softmax over the labels.
        text_probs = (100 * image_features @ text_features.T).softmax(dim=-1)

    # Only one image was submitted, so row 0 holds its label probabilities.
    return {label: float(prob) for label, prob in zip(fashion_items, text_probs[0])}
47
+
48
+ # Gradio interface
49
# Build the input and output widgets first, then wire them to the predictor.
url_input = gr.Textbox(label="Enter Image URL")
results_output = gr.Label(label="Classification Results")

demo = gr.Interface(
    fn=predict_from_url,
    inputs=url_input,
    outputs=results_output,
    title="Fashion Item Classifier",
    allow_flagging="never",
)

# Start serving the app.
demo.launch()
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ transformers
2
+ torch
3
+ requests
4
+ Pillow
5
+ open_clip_torch
6
+ ftfy
7
+
8
+ # This is only needed for local deployment
9
+ gradio