# colqwen-api / handler.py
# Last commit by dounan1 (d79ac69): "Bypass pipeline and use model name explicitly"
import torch
from PIL import Image
from io import BytesIO
import base64
from colpali_engine.models import ColQwen2, ColQwen2Processor
model_name = "vidore/colqwen2-v1.0"


class EndpointHandler:
    """Inference-endpoint handler that scores text queries against images with ColQwen2.

    The model and processor are always loaded from the module-level ``model_name``;
    the ``path`` argument passed in by the hosting runtime is accepted but ignored.
    """

    def __init__(self, path=""):
        # `path` is intentionally unused: weights are loaded explicitly by `model_name`.
        self.model = ColQwen2.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto",  # This will use CUDA if available, otherwise CPU
        ).eval()
        self.processor = ColQwen2Processor.from_pretrained(model_name)

    @staticmethod
    def _as_list(value):
        """Normalize a request field to a list.

        ``None`` becomes ``[]``; a lone ``str``/``bytes`` becomes a one-element list
        (instead of being iterated character by character); anything else is
        materialized with ``list()``.
        """
        if value is None:
            return []
        if isinstance(value, (str, bytes)):
            return [value]
        return list(value)

    @staticmethod
    def _strip_data_url(img_data):
        """Return the base64 payload of *img_data*, dropping a ``data:...,`` URL prefix if present.

        Without this, characters of the prefix such as ``data``, ``image`` and ``/``
        are valid base64 symbols and would be decoded into garbage bytes.
        """
        if isinstance(img_data, str) and img_data.startswith("data:"):
            _, sep, payload = img_data.partition(",")
            if sep:
                return payload
        return img_data

    @staticmethod
    def _decode_image(img_data):
        """Decode one base64 string (or base64 ``data:`` URL) into an RGB PIL image.

        Raises:
            binascii.Error: if the payload is not valid base64.
            PIL.UnidentifiedImageError: if the bytes are not a recognizable image.
        """
        raw = base64.b64decode(EndpointHandler._strip_data_url(img_data))
        # Image.open is lazy; convert() forces the decode here and hands the
        # processor a single, consistent colour mode.
        return Image.open(BytesIO(raw)).convert("RGB")

    def __call__(self, data):
        """Score every query against every image.

        Args:
            data: Request payload with optional ``"images"`` (base64-encoded image
                strings, optionally as ``data:`` URLs) and ``"queries"`` (text
                strings). A single string is accepted in place of a list. If
                neither key is present at the top level but ``data["inputs"]`` is
                a dict, the keys are read from that dict instead. ``data`` itself
                is not modified.

        Returns:
            ``{"scores": rows}`` with one row per query and, per colpali_engine's
            ``score_multi_vector`` (shape ``(n_queries, n_images)``), one column per
            image. If either input list is empty the model is not run and each
            query gets an empty row.
        """
        payload = data
        if (
            "images" not in data
            and "queries" not in data
            and isinstance(data.get("inputs"), dict)
        ):
            # Support the conventional {"inputs": {...}} request envelope.
            payload = data["inputs"]

        images_data = self._as_list(payload.get("images"))
        queries = self._as_list(payload.get("queries"))

        # The processor/scorer cannot handle an empty batch; answer trivially instead.
        if not queries or not images_data:
            return {"scores": [[] for _ in queries]}

        images = [self._decode_image(item) for item in images_data]

        batch_images = self.processor.process_images(images).to(self.model.device)
        batch_queries = self.processor.process_queries(queries).to(self.model.device)

        with torch.no_grad():
            image_embeddings = self.model(**batch_images)
            query_embeddings = self.model(**batch_queries)

        scores = self.processor.score_multi_vector(query_embeddings, image_embeddings)
        return {"scores": scores.tolist()}