from typing import Dict, List, Any
from transformers import AutoProcessor, Blip2ForConditionalGeneration
import base64
from io import BytesIO
from PIL import Image
import string
import torch


class EndpointHandler:
    def __init__(self, path: str = ""):
        # Load the BLIP-2 processor and the 4-bit quantized model from the repository path.
        self.processor = AutoProcessor.from_pretrained(path)
        self.model = Blip2ForConditionalGeneration.from_pretrained(
            path, device_map="auto", load_in_4bit=True
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Args:
            data: dict with an "inputs" key holding a base64-encoded image and a text prompt.

        Returns:
            A list with a single {"generated_text": ...} dict.
        """
        # Unwrap the payload and decode the base64-encoded image.
        inputs = data.pop("inputs", data)
        image = Image.open(BytesIO(base64.b64decode(inputs["image"])))

        # Preprocess the image/text pair and move the tensors to the GPU in half precision.
        inputs = self.processor(
            images=image, text=inputs["text"], return_tensors="pt"
        ).to("cuda", torch.float16)

        # Generate a caption/answer with beam search.
        generated_ids = self.model.generate(
            **inputs,
            temperature=1.0,
            length_penalty=1.0,
            repetition_penalty=1.5,
            max_length=30,
            min_length=1,
            num_beams=5,
            top_p=0.9,
        )
        result = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()

        # Make sure the generated text ends with punctuation.
        if result and result[-1] not in string.punctuation:
            result += "."

        return [{"generated_text": result}]
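

# Minimal local smoke test, not part of the Inference Endpoints contract. It assumes a
# BLIP-2 checkpoint such as "Salesforce/blip2-opt-2.7b" (a hypothetical choice here) can be
# downloaded and that a CUDA device is available for the 4-bit model.
if __name__ == "__main__":
    handler = EndpointHandler(path="Salesforce/blip2-opt-2.7b")

    # Encode a throwaway image to base64, mirroring what a client request would send.
    buffer = BytesIO()
    Image.new("RGB", (224, 224), color="white").save(buffer, format="JPEG")
    payload = {
        "inputs": {
            "image": base64.b64encode(buffer.getvalue()).decode("utf-8"),
            "text": "Question: What is in the picture? Answer:",
        }
    }
    print(handler(payload))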