# MiniCPM-v-2_6 / handler.py
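"""Custom Hugging Face Inference Endpoints handler for MiniCPM-V-2.6.

Expects a JSON payload of the form
    {"inputs": {"image": "<base64-encoded image>", "text": "<prompt>"}}
and returns {"generated_text": "..."} (or {"error": "..."} on bad input).
"""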
import base64
from io import BytesIO

import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer

class EndpointHandler:
    def __init__(self, path="/repository"):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Load the model; bfloat16 on GPU, float32 on CPU
        self.model = AutoModel.from_pretrained(
            path,
            trust_remote_code=True,
            attn_implementation="sdpa",
            torch_dtype=torch.bfloat16 if self.device.type == "cuda" else torch.float32,
        ).to(self.device)
        self.model.eval()
        # Load the tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            path,
            trust_remote_code=True,
        )
    def __call__(self, data):
        # Extract the base64-encoded image and the text prompt from the input data
        image_data = data.get("inputs", {}).get("image", "")
        text_prompt = data.get("inputs", {}).get("text", "")
        if not image_data or not text_prompt:
            return {"error": "Both 'image' and 'text' must be provided in the input data."}

        # Decode and open the image
        try:
            image_bytes = base64.b64decode(image_data)
            image = Image.open(BytesIO(image_bytes)).convert("RGB")
        except Exception as e:
            return {"error": f"Failed to process image data: {e}"}

        # Prepare the messages for the model (image and prompt in a single user turn)
        msgs = [{"role": "user", "content": [image, text_prompt]}]
        # Generate output
        with torch.no_grad():
            res = self.model.chat(
                image=None,
                msgs=msgs,
                tokenizer=self.tokenizer,
                sampling=True,
                temperature=0.7,
                top_p=0.95,
                max_length=2000,
            )

        # model.chat returns the generated text directly
        return {"generated_text": res}
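
# --- Usage sketch (not part of the deployed handler) ---
# A minimal local smoke test, assuming a model snapshot is available at the
# default "/repository" path. "sample.jpg" is a placeholder filename; the
# payload shape mirrors what __call__ above expects, and a deployed endpoint
# would receive the same JSON body via POST.
if __name__ == "__main__":
    import json

    handler = EndpointHandler(path="/repository")  # assumption: local model snapshot

    # Encode a sample image to base64, as a client would before sending it
    with open("sample.jpg", "rb") as f:  # placeholder image file
        image_b64 = base64.b64encode(f.read()).decode("utf-8")

    payload = {
        "inputs": {
            "image": image_b64,
            "text": "Describe this image.",
        }
    }
    print(json.dumps(handler(payload), indent=2))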