Amitai Getzler committed on
Commit ec28284
1 Parent(s): 93c7837

:heavy_plus_sign: Add

Files changed (3)
  1. Dockerfile +20 -0
  2. app.py +162 -0
  3. requirements.txt +4 -1
Dockerfile ADDED
@@ -0,0 +1,20 @@
+ # Use an official Python runtime as a parent image
+ FROM python:3.9-slim
+
+ # Set the working directory in the container
+ WORKDIR /app
+
+ # Copy the current directory contents into the container at /app
+ COPY . /app
+
+ # Install any needed packages specified in requirements.txt
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Make port 80 available to the world outside this container
+ EXPOSE 80
+
+ # Define environment variable (key=value form; the space-separated form is deprecated)
+ ENV NAME=World
+
+ # Launch the FastAPI app with uvicorn when the container starts
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "80"]
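
A quick way to sanity-check the image (the tag and host port below are illustrative, not part of the commit) is to build and run it, then poll the interactive docs page that FastAPI serves by default:

# Hypothetical smoke test; assumes the image was built and started with
#   docker build -t fashion-api .
#   docker run -p 8000:80 fashion-api
# (the tag "fashion-api" and host port 8000 are assumptions, not from the commit)
import requests

resp = requests.get("http://localhost:8000/docs", timeout=30)
print(resp.status_code)  # expect 200 once the model has loaded and uvicorn is serving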
app.py ADDED
@@ -0,0 +1,162 @@
+ from fastapi import FastAPI, HTTPException
+ from pydantic import BaseModel
+ from typing import List, Dict, Union
+ from base64 import b64decode
+ from io import BytesIO
+ import open_clip
+ import requests
+ import torch
+ import numpy as np
+ from PIL import Image
+ import uvicorn
+
+ app = FastAPI()
+
+ class EndpointHandler:
+     def __init__(self, path="hf-hub:Styld/marqo-fashionSigLIP"):
+         self.model, self.preprocess_train, self.preprocess_val = (
+             open_clip.create_model_and_transforms(path)
+         )
+         # Track the device so inputs can be moved alongside the model;
+         # the original moved only the model to CUDA, which crashed on GPU hosts.
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         self.model = self.model.to(self.device)
+         self.tokenizer = open_clip.get_tokenizer(path)
+
+     def classify_image(self, candidate_labels, image):
+         def get_top_prediction(text_probs, labels):
+             max_index = text_probs[0].argmax().item()
+             return {
+                 "label": labels[max_index],
+                 "score": text_probs[0][max_index].item(),
+             }
+
+         # The image does not change across label batches, so encode it once.
+         image_tensor = self.preprocess_val(image).unsqueeze(0).to(self.device)
+         top_prediction = None
+         # Score the candidate labels in batches of 10 and keep the best match.
+         for i in range(0, len(candidate_labels), 10):
+             batch_labels = candidate_labels[i : i + 10]
+             text = self.tokenizer(batch_labels).to(self.device)
+
+             with torch.no_grad(), torch.cuda.amp.autocast():
+                 image_features = self.model.encode_image(image_tensor)
+                 text_features = self.model.encode_text(text)
+                 image_features /= image_features.norm(dim=-1, keepdim=True)
+                 text_features /= text_features.norm(dim=-1, keepdim=True)
+                 text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
+
+             current_top = get_top_prediction(text_probs, batch_labels)
+             if top_prediction is None or current_top["score"] > top_prediction["score"]:
+                 top_prediction = current_top
+
+         return {"label": top_prediction["label"]}
+
+     def combine_embeddings(
+         self, text_embeddings, image_embeddings, text_weight=0.5, image_weight=0.5
+     ):
+         # Average each modality separately, then take their weighted mean.
+         if text_embeddings is not None:
+             avg_text_embedding = np.mean(np.vstack(text_embeddings), axis=0)
+         else:
+             avg_text_embedding = np.zeros_like(image_embeddings[0])
+
+         if image_embeddings is not None:
+             avg_image_embedding = np.mean(np.vstack(image_embeddings), axis=0)
+         else:
+             avg_image_embedding = np.zeros_like(text_embeddings[0])
+
+         combined_embedding = np.average(
+             np.vstack((avg_text_embedding, avg_image_embedding)),
+             axis=0,
+             weights=[text_weight, image_weight],
+         )
+         return combined_embedding
+
+     def average_text(self, doc):
+         # Embed the text in 40-word chunks, then average the chunk embeddings.
+         words = doc.split(" ")
+         text_chunks = [
+             " ".join(words[i : i + 40]) for i in range(0, len(words), 40)
+         ]
+         text_embeddings = []
+         for chunk in text_chunks:
+             inputs = self.tokenizer(chunk).to(self.device)
+             with torch.no_grad():
+                 text_features = self.model.encode_text(inputs)
+             text_features /= text_features.norm(dim=-1, keepdim=True)
+             text_embeddings.append(text_features.cpu().squeeze().numpy())
+         combined = self.combine_embeddings(
+             text_embeddings, None, text_weight=1, image_weight=0
+         )
+         return combined
+
+     def embedd_image(self, doc) -> list:
+         if isinstance(doc, str):
+             # Plain strings are embedded with the text tower only.
+             return self.average_text(doc).tolist()
+
+         image = doc.get("image")
+         if "https://" in image:
+             # The field may contain several URLs separated by "|"; only the
+             # first is used, so only the first is fetched.
+             first_url = image.split("|")[0]
+             image = Image.open(BytesIO(requests.get(first_url).content))
+         else:
+             # Otherwise the payload is assumed to be a base64-encoded image.
+             image = self.base64_image_to_pil(image)
+         image_tensor = self.preprocess_val(image).unsqueeze(0).to(self.device)
+         with torch.no_grad():
+             image_features = self.model.encode_image(image_tensor)
+         image_features /= image_features.norm(dim=-1, keepdim=True)
+         image_embedding = image_features.cpu().squeeze().numpy()
+         if doc.get("description", "") == "":
+             return image_embedding.tolist()
+         average_texts = self.average_text(doc.get("description"))
+         combined = self.combine_embeddings(
+             [average_texts],
+             [image_embedding],
+             text_weight=0.5,
+             image_weight=0.5,
+         )
+         return combined.tolist()
+
+     def process_batch(self, batch) -> list:
+         batch = batch.get("batch")
+         if not isinstance(batch, list):
+             # Reject malformed payloads with a real 400 response instead of
+             # returning a Flask-style (body, status) tuple.
+             raise HTTPException(
+                 status_code=400,
+                 detail="Invalid input: batch must be an array of strings or objects.",
+             )
+         return [self.embedd_image(item) for item in batch]
+
+     def base64_image_to_pil(self, base64_str) -> Image.Image:
+         image_data = b64decode(base64_str)
+         image_buffer = BytesIO(image_data)
+         return Image.open(image_buffer)
+
+ handler = EndpointHandler()
+
+ class ClassifyRequest(BaseModel):
+     candidates: List[str]
+     image: str
+
+ class EmbeddRequest(BaseModel):
+     batch: List[Union[str, Dict[str, str]]]
+
+ @app.post("/classify")
+ def classify(request: ClassifyRequest):
+     try:
+         image = (
+             Image.open(BytesIO(requests.get(request.image).content))
+             if "https://" in request.image
+             else handler.base64_image_to_pil(request.image)
+         )
+         return handler.classify_image(request.candidates, image)
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=str(e))
+
+ @app.post("/embedd")
+ def embedd(request: EmbeddRequest):
+     try:
+         embeddings = handler.process_batch(request.dict())
+         return {"embeddings": embeddings}
+     except HTTPException:
+         # Let deliberate 4xx errors from process_batch pass through unchanged.
+         raise
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=str(e))
+
+ if __name__ == "__main__":
+     uvicorn.run(app, host="0.0.0.0", port=8000)
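
For reference, a hedged client sketch for the two routes defined above, assuming the server is reachable on localhost:8000; the labels and image URL are placeholders, not part of the commit:

# Illustrative client for the /classify and /embedd routes; the base URL,
# candidate labels, and image URL below are assumptions.
import requests

BASE = "http://localhost:8000"

# /classify expects candidate labels plus an image as an https:// URL or a base64 string.
r = requests.post(f"{BASE}/classify", json={
    "candidates": ["sneakers", "boots", "sandals"],
    "image": "https://example.com/shoe.jpg",
})
print(r.json())  # e.g. {"label": "sneakers"}

# /embedd takes a batch mixing plain strings and {"image": ..., "description": ...} docs.
r = requests.post(f"{BASE}/embedd", json={
    "batch": [
        "red floral summer dress",
        {"image": "https://example.com/shoe.jpg", "description": "leather ankle boot"},
    ],
})
print(len(r.json()["embeddings"]))  # one embedding per batch item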
requirements.txt CHANGED
@@ -1,4 +1,7 @@
  open_clip_torch
  numpy
  pillow
- requests
+ requests
+ ##
+ fastapi
+ uvicorn