# blip2_test — uploaded by florentgbelidji (commit message: "added files"; 2.37 kB).
import torch
from transformers import pipeline, AutoProcessor, Blip2ForConditionalGeneration
import os
"""import base64
from io import BytesIO
from PIL import Image"""
# Pick the compute device: index 0 (first CUDA GPU) when CUDA is available,
# otherwise -1, which transformers pipelines interpret as "run on CPU".
if torch.cuda.is_available():
    device = 0
else:
    device = -1
class EndpointHandler:
    """Inference-endpoint handler wrapping a BLIP-2 image-captioning model.

    ``__init__`` loads the BLIP-2 processor and an 8-bit BLIP-2 model.
    ``__call__`` is currently a placeholder: it ignores the request and
    returns the constant 73. The captioning logic it is meant to run is kept
    as a commented-out draft inside ``__call__``.
    """

    def __init__(self, path=""):
        """Load the BLIP-2 processor and model and keep them on the instance.

        Args:
            path: Directory that contains a ``sharded`` sub-folder holding the
                model weights. Defaults to ``""``, i.e. ``./sharded``.
        """
        # The processor is loaded by model id. Only the model weights are
        # read from the local ``sharded`` folder.
        # Both objects must live on ``self``. Binding them to locals (as
        # before) discarded them as soon as __init__ returned.
        self.blip2_proc = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
        self.blip2 = Blip2ForConditionalGeneration.from_pretrained(
            os.path.join(path, "sharded"), device_map="auto", load_in_8bit=True
        )
        # Optional English -> German translator used by the disabled draft in
        # __call__; re-enable together with it:
        # self.translator = pipeline(
        #     "translation", model="Helsinki-NLP/opus-mt-en-de", device=device
        # )

    def __call__(self, data):
        """Handle one inference request.

        Placeholder implementation: ``data`` is not read and the constant
        ``73`` is always returned.

        Args:
            data: The deserialized request payload (currently unused).

        Returns:
            The integer 73.
        """
        # Disabled captioning draft, kept for reference with its logic bugs
        # fixed:
        #   - lang/decode conditions used `or` where `==` / `in` was meant;
        #   - the German caption was immediately overwritten by the English one;
        #   - do_sample was passed to batch_decode instead of generate;
        #   - the translation pipeline result's "translation_text" wasn't read;
        #   - .to(device) fails on CPU, where device is -1.
        # Re-enabling it requires `import base64`, `from io import BytesIO`,
        # `from PIL import Image`, and the translator in __init__.
        #
        # b64_img = data.pop("b64", data)
        # lang = data.pop("lang", None)
        # decode = data.pop("decode", None)
        # # NOTE(review): the original image line was garbled; presumably it
        # # was Image.open(im_file).convert("RGB") — confirm.
        # image = Image.open(BytesIO(base64.b64decode(b64_img))).convert("RGB")
        # inputs = self.blip2_proc(image, return_tensors="pt").to(
        #     self.blip2.device, torch.float16
        # )
        # output = {}
        # if decode in (None, "beam"):
        #     # TODO: generate() is greedy unless num_beams > 1; pick a width.
        #     generated_ids = self.blip2.generate(**inputs, max_new_tokens=20)
        #     caption = self.blip2_proc.batch_decode(
        #         generated_ids, skip_special_tokens=True
        #     )[0].strip()
        #     if lang == "de":
        #         caption = self.translator(caption)[0]["translation_text"]
        #     output["beam"] = caption
        # if decode == "nucleus":
        #     # TODO: consider setting top_p for true nucleus sampling.
        #     generated_ids = self.blip2.generate(
        #         **inputs, max_new_tokens=20, do_sample=True
        #     )
        #     caption = self.blip2_proc.batch_decode(
        #         generated_ids, skip_special_tokens=True
        #     )[0].strip()
        #     if lang == "de":
        #         caption = self.translator(caption)[0]["translation_text"]
        #     output["nucleus"] = caption
        # return output
        return 73