from transformers import AutoModelForCausalLM, AutoTokenizer
from PIL import Image
import torch

class moondream:
    """Thin wrapper around the moondream2 vision-language model for image captioning and VQA."""

    def __init__(self, model_path="vikhyatk/moondream2", device="cuda:1"):
        # Load the model in fp16 and move it to the requested device.
        self.moondream_model = AutoModelForCausalLM.from_pretrained(
            model_path, trust_remote_code=True, revision="2024-04-02", torch_dtype=torch.float16
        ).to(device)
        self.moondream_tokenizer = AutoTokenizer.from_pretrained(model_path, revision="2024-04-02")

    def infer(self, prompt, image='vqa.jpg', caption=True):
        # Encode the image once, then ask either for a generic caption or for the user's prompt.
        image = Image.open(image)
        enc_image = self.moondream_model.encode_image(image)
        question = "Describe this image." if caption else prompt
        out = self.moondream_model.answer_question(enc_image, question, self.moondream_tokenizer)

        output_dict = {
            "llm_output": str(out),
            "real_output": {"display": None, "name": None, "metadata": None},
            "type": "text",
        }
        return output_dict
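
# A minimal usage sketch. It assumes a CUDA device is available at "cuda:1" and that an
# image file named "vqa.jpg" sits next to this module; both are assumptions for
# illustration, not requirements of the class itself.
if __name__ == "__main__":
    vqa = moondream()
    # Caption mode: the prompt is ignored and a generic description is returned.
    print(vqa.infer(prompt="", image="vqa.jpg", caption=True)["llm_output"])
    # Question mode: pass caption=False to ask a specific question about the image.
    print(vqa.infer(prompt="What objects are in this image?", image="vqa.jpg", caption=False)["llm_output"])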