|
from typing import Dict, Any |
|
|
|
import torch |
|
from transformers import Blip2Processor, Blip2Config, Blip2ForConditionalGeneration, BitsAndBytesConfig |
|
from accelerate import init_empty_weights, infer_auto_device_map |
|
|
|
from PIL import Image |
|
from io import BytesIO |
|
import base64 |
|
import torch.nn.functional as F |
|
|
|
|
|
class EndpointHandler():
    """Custom inference endpoint handler for Salesforce/blip2-flan-t5-xxl.

    Dispatches on ``data["inputs"]["mode"]``:

    * ``"generate_text"`` -- conditional text generation from a base64 image
      plus a text prompt.
    * ``"get_continuation_likelihood"`` -- per-token log-probabilities of
      ``prompt + continuation`` given the image.
    """

    def __init__(self, path=""):
        # The processor handles both image preprocessing and tokenization.
        self.processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xxl")

        # Build an empty-weight model first so we can infer a device map that
        # never splits a T5Block across devices.
        config = Blip2Config.from_pretrained("Salesforce/blip2-flan-t5-xxl")
        with init_empty_weights():
            model = Blip2ForConditionalGeneration(config)
        device_map = infer_auto_device_map(model, no_split_module_classes=["T5Block"])
        # lm_head is tied to the input embeddings; pin it to the same device so
        # the tied weights don't end up split across devices.
        device_map['language_model.lm_head'] = device_map["language_model.encoder.embed_tokens"]

        self.model = Blip2ForConditionalGeneration.from_pretrained(
            "Salesforce/blip2-flan-t5-xxl", device_map=device_map,
            torch_dtype=torch.float16,
            quantization_config=BitsAndBytesConfig(load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True)
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one inference request.

        Args:
            data: ``{"inputs": {...}}`` where ``inputs["mode"]`` selects the
                operation; remaining keys depend on the mode (see below).

        Returns:
            ``generate_text`` mode: ``{"output_text": str}``.
            ``get_continuation_likelihood`` mode: dict with ``prompt``,
            ``continuation``, ``tokens`` (decoded per-token strings) and
            ``logprobs`` (plain Python floats, one per token).

        Raises:
            ValueError: if ``inputs["mode"]`` is not a recognized mode.
        """
        request = data["inputs"]

        if request["mode"] == 'generate_text':
            input_text: str = request['input_text']
            image: Image.Image = Image.open(BytesIO(base64.b64decode(request['image'])))
            max_new_tokens: int = request['max_new_tokens']
            stop: str = request['stop']
            temperature: float = request['temperature']

            # Keep the request dict and the tensor batch in separate names
            # (the original shadowed `inputs` here).
            model_inputs = self.processor(images=image, text=input_text, return_tensors="pt").to(
                self.model.device, self.model.dtype
            )
            # NOTE(review): generate() defaults to greedy decoding, so
            # `temperature` has no effect unless sampling is enabled -- confirm
            # whether `do_sample=True` was intended.
            output = self.model.generate(
                **model_inputs, max_new_tokens=max_new_tokens, temperature=temperature
            )[0]
            output_text = self.processor.decode(output, skip_special_tokens=True).strip()
            # Guard against an empty stop string: "" is contained in every
            # string and find("") == 0 would truncate the whole output.
            if stop and stop in output_text:
                output_text = output_text[: output_text.find(stop)]

            return {'output_text': output_text}

        elif request["mode"] == 'get_continuation_likelihood':
            prompt: str = request['prompt']
            continuation = request['continuation']
            image: Image.Image = Image.open(BytesIO(base64.b64decode(request['image'])))

            model_inputs = self.processor(
                images=image, text=(prompt + continuation), return_tensors="pt"
            ).to(self.model.device, self.model.dtype)
            # Feeding the ids back in as labels makes the seq2seq decoder score
            # the full prompt+continuation sequence; T5 shifts labels right
            # internally, so logits[i] is the prediction for labels[i].
            model_inputs["labels"] = model_inputs["input_ids"]
            input_ids = model_inputs["input_ids"][0]
            tokens = [self.processor.decode([t]) for t in input_ids]

            # Pure inference: disable autograd so we don't build and hold the
            # activation graph for an 11B-parameter forward pass.
            with torch.no_grad():
                logits = self.model(**model_inputs).logits[0]
            all_logprobs = F.log_softmax(logits, dim=-1)
            # Convert torch scalar tensors to plain floats so the response is
            # JSON-serializable (the original returned tensors here).
            logprobs = [all_logprobs[i, input_ids[i]].item() for i in range(len(tokens))]

            return {
                'prompt': prompt,
                'continuation': continuation,
                'tokens': tokens,
                'logprobs': logprobs
            }

        raise ValueError(f"Unknown mode: {request['mode']!r}")
|
|