# new-test-model / handler.py
# Author: jeff-RQ
# Last commit: 2a7e830 ("Update handler.py")
import base64
import io
from typing import Any, Dict, List, Optional, Tuple

import torch
from PIL import Image
from transformers import Blip2ForConditionalGeneration, Blip2Processor
class EndpointHandler:
    """Inference handler that runs a BLIP-2 model on a base64-encoded image.

    The processor and model are loaded once at construction time. Each call
    decodes one image, optionally pairs it with a text prompt, and returns
    the generated text wrapped as ``[{"answer": <text>}]``.
    """

    def __init__(self, path=""):
        """Load the BLIP-2 processor and model from ``path``.

        Args:
            path: Model directory (or any identifier accepted by
                ``from_pretrained``).
        """
        self.processor = Blip2Processor.from_pretrained(path)
        # Half precision is only used on GPU. Many float16 kernels are not
        # implemented on CPU, so a CPU-only host falls back to float32 instead
        # of failing on the previously hard-coded "cuda" device.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.dtype = torch.float16 if self.device == "cuda" else torch.float32
        self.model = Blip2ForConditionalGeneration.from_pretrained(
            path, torch_dtype=self.dtype
        )
        self.model.to(self.device)

    @staticmethod
    def _parse_inputs(data: Dict[str, Any]) -> Tuple[Optional[str], bytes]:
        """Extract the optional text prompt and the raw image bytes from a request.

        The fields may be nested under ``"inputs"`` or given at the top level:
        ``{"inputs": {"image": <base64 str>, "text": <optional str>}}``.
        The request dict is only read, never mutated.

        Raises:
            KeyError: if the payload has no ``"image"`` field.
        """
        payload = data.get("inputs", data)
        # A missing prompt must be None (image-only captioning). Falling back
        # to the payload dict itself would hand a dict to the tokenizer.
        text = payload.get("text")
        image_bytes = base64.b64decode(payload["image"])
        return text, image_bytes

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        """Generate text for the image (and optional prompt) in ``data``.

        Args:
            data: Request dict; see ``_parse_inputs`` for the accepted shape.

        Returns:
            A one-element list ``[{"answer": <generated text>}]``.
        """
        text, image_bytes = self._parse_inputs(data)
        image = Image.open(io.BytesIO(image_bytes))
        # The dtype argument matches the model's weights (float16 on GPU,
        # float32 on CPU), as in the upstream BLIP-2 usage examples.
        inputs = self.processor(images=image, text=text, return_tensors="pt").to(
            self.device, self.dtype
        )
        generated_ids = self.model.generate(**inputs)
        generated_text = self.processor.batch_decode(
            generated_ids, skip_special_tokens=True
        )[0].strip()
        return [{"answer": generated_text}]