# blip2-flan-t5-xxl / handler.py
# Author: merve (HF staff) — "Update handler.py", commit 369380f
from transformers import AutoProcessor, Blip2ForConditionalGeneration
import base64
from io import BytesIO
from PIL import Image
import string
import torch
class EndpointHandler:
    """Inference Endpoints handler for BLIP-2 (Flan-T5-XXL) captioning / VQA.

    Loads a 4-bit quantized BLIP-2 model once at startup and answers
    requests carrying a base64-encoded image plus a text prompt.
    """

    def __init__(self, path=""):
        """Load processor and 4-bit quantized model from *path*.

        ``device_map="auto"`` lets accelerate place the shards;
        ``load_in_4bit=True`` requires bitsandbytes and a CUDA GPU.
        """
        self.processor = AutoProcessor.from_pretrained(path)
        self.model = Blip2ForConditionalGeneration.from_pretrained(
            path, device_map="auto", load_in_4bit=True
        )
        # Inference-only service: eval mode disables dropout & co.
        self.model.eval()

    def __call__(self, data):
        """Generate text for one request.

        Args:
            data: dict with an ``"inputs"`` key (or the payload itself)
                holding ``"image"`` (base64-encoded image bytes) and
                ``"text"`` (the prompt/question string).

        Returns:
            ``[{"generated_text": <string ending in punctuation>}]``

        Raises:
            KeyError: if ``"image"`` or ``"text"`` is missing.
        """
        payload = data.pop("inputs", data)

        # Convert to RGB so palette/RGBA images (e.g. PNGs with alpha)
        # don't feed the processor a wrong channel count.
        image = Image.open(BytesIO(base64.b64decode(payload["image"]))).convert("RGB")

        model_inputs = self.processor(
            images=image, text=payload["text"], return_tensors="pt"
        ).to("cuda", torch.float16)

        # inference_mode: skip autograd bookkeeping during generation.
        with torch.inference_mode():
            generated_ids = self.model.generate(
                **model_inputs,
                # NOTE(review): temperature/top_p only take effect with
                # do_sample=True; under beam search they are ignored.
                # Kept to preserve the original call signature.
                temperature=1.0,
                length_penalty=1.0,
                repetition_penalty=1.5,
                max_length=30,
                min_length=1,
                num_beams=5,
                top_p=0.9,
            )

        result = self.processor.batch_decode(
            generated_ids, skip_special_tokens=True
        )[0].strip()
        # Ensure the output ends with punctuation for cleaner display.
        if result and result[-1] not in string.punctuation:
            result += "."
        return [{"generated_text": result}]