|
from typing import Dict, List, Any |
|
import torch |
|
from transformers import AutoProcessor, LlavaForConditionalGeneration |
|
|
|
class EndpointHandler():
    """Custom inference-endpoint handler for NousResearch/Obsidian-3B-V0.5 (LLaVA).

    Loads the multimodal model once at startup and serves text+image
    generation requests through ``__call__``.
    """

    def __init__(self, path="", vision_model="obsidian3b"):
        # path / vision_model are part of the handler contract but unused here;
        # the checkpoint is pinned to the hub repo below.
        # BUG FIX: the original called torch.is_cuda_available(), which does not
        # exist — the correct API is torch.cuda.is_available().
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = LlavaForConditionalGeneration.from_pretrained(
            "NousResearch/Obsidian-3B-V0.5",
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
        ).to(self.device)
        self.processor = AutoProcessor.from_pretrained("NousResearch/Obsidian-3B-V0.5")

    def __call__(self, data: Dict[str, Any]) -> str:
        """Run one generation request.

        data args:
            inputs (:obj:`str`): the text prompt (defaults to "").
            image (:obj:`Image`): the image to condition on (defaults to None).
        Return:
            :obj:`str`: the decoded generation, serialized and returned.
        """
        prompt = data.pop("inputs", "")
        image = data.pop("image", None)

        inputs = self.processor(prompt, image, return_tensors="pt")
        # Move tensors to the model's device; the original left them on CPU,
        # which fails with a device-mismatch error when the model is on CUDA.
        inputs = inputs.to(self.model.device)
        # The model weights are float16; the processor emits float32 pixel
        # values, so cast them to avoid a dtype-mismatch in the forward pass.
        if "pixel_values" in inputs:
            inputs["pixel_values"] = inputs["pixel_values"].to(self.model.dtype)

        res = self.model.generate(**inputs, do_sample=False, max_new_tokens=4096)
        return self.processor.decode(res[0], skip_special_tokens=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|