Amitai Getzler committed
Commit ec28284
1 Parent(s): 93c7837
:heavy_plus_sign: Add

Files changed:
- Dockerfile +20 -0
- app.py +162 -0
- requirements.txt +4 -1

Dockerfile
ADDED
@@ -0,0 +1,20 @@
+# Use an official Python runtime as a parent image
+FROM python:3.9-slim
+
+# Set the working directory in the container
+WORKDIR /app
+
+# Copy the current directory contents into the container at /app
+COPY . /app
+
+# Install any needed packages specified in requirements.txt
+RUN pip install --no-cache-dir -r requirements.txt
+
+# Make port 80 available to the world outside this container
+EXPOSE 80
+
+# Define environment variable
+ENV NAME World
+
+# Run app.py when the container launches
+CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "80"]
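
Not part of the diff, but for context: assuming the image is tagged `marqo-fashion-api` (a hypothetical name), it builds with `docker build -t marqo-fashion-api .` and runs with `docker run -p 80:80 marqo-fashion-api`, after which the API is reachable on port 80 (matching both `EXPOSE 80` and the `--port 80` in `CMD`). The `ENV NAME World` line appears to be boilerplate from the standard Docker Python sample; nothing in app.py reads it.
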
app.py
ADDED
@@ -0,0 +1,162 @@
+from fastapi import FastAPI, HTTPException
+from pydantic import BaseModel
+from typing import List, Dict, Any, Union
+from base64 import b64decode
+from io import BytesIO
+import open_clip
+import requests
+import torch
+import numpy as np
+from PIL import Image
+import uvicorn
+
+app = FastAPI()
+
+class EndpointHandler:
+    def __init__(self, path="hf-hub:Styld/marqo-fashionSigLIP"):
+        self.model, self.preprocess_train, self.preprocess_val = (
+            open_clip.create_model_and_transforms(path)
+        )
+
+        if torch.cuda.is_available():
+            self.model = self.model.cuda()
+
+        self.tokenizer = open_clip.get_tokenizer(path)
+
+    def classify_image(self, candidate_labels, image):
+        def get_top_prediction(text_probs, labels):
+            max_index = text_probs[0].argmax().item()
+            return {
+                "label": labels[max_index],
+                "score": text_probs[0][max_index].item(),
+            }
+
+        top_prediction = None
+        for i in range(0, len(candidate_labels), 10):
+            batch_labels = candidate_labels[i : i + 10]
+            image_tensor = self.preprocess_val(image).unsqueeze(0)
+            text = self.tokenizer(batch_labels)
+
+            with torch.no_grad(), torch.cuda.amp.autocast():
+                image_features = self.model.encode_image(image_tensor)
+                text_features = self.model.encode_text(text)
+                image_features /= image_features.norm(dim=-1, keepdim=True)
+                text_features /= text_features.norm(dim=-1, keepdim=True)
+
+                text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
+
+            current_top = get_top_prediction(text_probs, batch_labels)
+            if top_prediction is None or current_top["score"] > top_prediction["score"]:
+                top_prediction = current_top
+
+        return {"label": top_prediction["label"]}
+
+    def combine_embeddings(self, text_embeddings, image_embeddings, text_weight=0.5, image_weight=0.5):
+        if text_embeddings is not None:
+            avg_text_embedding = np.mean(np.vstack(text_embeddings), axis=0)
+        else:
+            avg_text_embedding = np.zeros_like(image_embeddings[0])
+
+        if image_embeddings is not None:
+            avg_image_embeddings = np.mean(np.vstack(image_embeddings), axis=0)
+        else:
+            avg_image_embeddings = np.zeros_like(text_embeddings[0])
+
+        combined_embedding = np.average(
+            np.vstack((avg_text_embedding, avg_image_embeddings)),
+            axis=0,
+            weights=[text_weight, image_weight],
+        )
+        return combined_embedding
+
+    def average_text(self, doc):
+        text_chunks = [
+            " ".join(doc.split(" ")[i : i + 40])
+            for i in range(0, len(doc.split(" ")), 40)
+        ]
+        text_embeddings = []
+        for chunk in text_chunks:
+            inputs = self.tokenizer(chunk)
+            text_features = self.model.encode_text(inputs)
+            text_features /= text_features.norm(dim=-1, keepdim=True)
+            text_embeddings.append(text_features.detach().squeeze().numpy())
+        combined = self.combine_embeddings(
+            text_embeddings, None, text_weight=1, image_weight=0
+        )
+        return combined
+
+    def embedd_image(self, doc) -> list:
+        if not isinstance(doc, str):
+            image = doc.get("image")
+            if "https://" in image:
+                image = image.split("|")
+                image = [
+                    Image.open(BytesIO(response.content))
+                    for response in [requests.get(image) for image in image]
+                ][0]
+            image = self.preprocess_val(image).unsqueeze(0)
+            image_features = self.model.encode_image(image)
+            image_features /= image_features.norm(dim=-1, keepdim=True)
+            image_embedding = image_features.detach().squeeze().numpy()
+            if doc.get("description", "") == "":
+                return image_embedding.tolist()
+            else:
+                average_texts = self.average_text(doc.get("description"))
+                combined = self.combine_embeddings(
+                    [average_texts],
+                    [image_embedding],
+                    text_weight=0.5,
+                    image_weight=0.5,
+                )
+                return combined.tolist()
+        elif isinstance(doc, str):
+            return self.average_text(doc).tolist()
+
+    def process_batch(self, batch) -> object:
+        try:
+            batch = batch.get("batch")
+            if not isinstance(batch, list):
+                return "Invalid input: batch must be an array of strings.", 400
+            embeddings = [self.embedd_image(item) for item in batch]
+            return embeddings
+        except Exception as e:
+            return "An error occurred while processing the request.", 500
+
+    def base64_image_to_pil(self, base64_str) -> Image:
+        image_data = b64decode(base64_str)
+        image_buffer = BytesIO(image_data)
+        image = Image.open(image_buffer)
+        return image
+
+handler = EndpointHandler()
+
+class ClassifyRequest(BaseModel):
+    candidates: List[str]
+    image: str
+
+class EmbeddRequest(BaseModel):
+    batch: List[Union[str, Dict[str, str]]]
+
+@app.post("/classify")
+def classify(request: ClassifyRequest):
+    try:
+        image = (
+            Image.open(BytesIO(requests.get(request.image).content))
+            if "https://" in request.image
+            else handler.base64_image_to_pil(request.image)
+        )
+        response = handler.classify_image(request.candidates, image)
+        return response
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+@app.post("/embedd")
+def embedd(request: EmbeddRequest):
+    try:
+        embeddings = handler.process_batch(request.dict())
+        return {"embeddings": embeddings}
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+if __name__ == "__main__":
+    uvicorn.run(app, host="0.0.0.0", port=8000)
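
For reference, a minimal client sketch against the two endpoints (not part of the commit; it assumes the server was started locally with `python app.py`, hence port 8000, and the image URL is a hypothetical placeholder):

```python
# Minimal client sketch (illustrative, not part of the commit).
# Assumes the server runs locally on port 8000, per the __main__ block;
# IMAGE_URL is a hypothetical placeholder.
import requests

BASE = "http://localhost:8000"
IMAGE_URL = "https://example.com/dress.jpg"  # placeholder

# /classify takes candidate labels plus one image, given either as an
# https:// URL or as a base64-encoded string.
resp = requests.post(
    f"{BASE}/classify",
    json={"candidates": ["dress", "shirt", "shoes"], "image": IMAGE_URL},
)
print(resp.json())  # e.g. {"label": "dress"}

# /embedd takes a batch whose items are plain strings (text only) or
# {"image": ..., "description": ...} dicts (image, optionally with text).
resp = requests.post(
    f"{BASE}/embedd",
    json={"batch": ["red summer dress",
                    {"image": IMAGE_URL, "description": "red summer dress"}]},
)
print(len(resp.json()["embeddings"]))  # one embedding per batch item
```

Two caveats stand out when reading the committed code. First, `classify_image` softmaxes scores within each batch of 10 labels and then compares the winning probabilities across batches; softmax outputs over different label sets are not on a common scale, so with more than 10 candidates the final pick can depend on how labels are grouped. Comparing the raw similarities (`image_features @ text_features.T`) across batches would avoid this. Second, when CUDA is available `__init__` moves the model to the GPU, but the image and text tensors are never moved, so the encode calls would raise a device-mismatch error on GPU hosts; CPU-only machines are unaffected.
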
requirements.txt
CHANGED
@@ -1,4 +1,7 @@
 open_clip_torch
 numpy
 pillow
-requests
+requests
+##
+fastapi
+uvicorn
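
With these dependencies installed (`pip install -r requirements.txt`), the service also runs without Docker: `python app.py` serves on port 8000 via the `__main__` block, whereas the Dockerfile's CMD serves on port 80. The bare `##` line is an empty comment that pip ignores; note that `torch` is not pinned here and is pulled in transitively through `open_clip_torch`.
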