from typing import Any, Dict, Tuple
from unittest.mock import patch

import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

from utils.imports import fixed_get_imports

# Florence-2 checkpoints served by the demo: base and large, each in a
# pre-trained and a fine-tuned (-ft) variant.
CHECKPOINTS = [
    "microsoft/Florence-2-large-ft",
    "microsoft/Florence-2-large",
    "microsoft/Florence-2-base-ft",
    "microsoft/Florence-2-base",
]


def load_models(device: torch.device) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Load every checkpoint in CHECKPOINTS, and its processor, onto `device`."""
    # Patch transformers' dynamic-module import resolution with the project's
    # fixed_get_imports so the remote Florence-2 code can load even when
    # optional dependencies such as flash_attn are not installed.
    with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
models = {}
processors = {}
for checkpoint in CHECKPOINTS:
models[checkpoint] = AutoModelForCausalLM.from_pretrained(
checkpoint, trust_remote_code=True).to(device)
processors[checkpoint] = AutoProcessor.from_pretrained(
checkpoint, trust_remote_code=True)
return models, processors


def run_inference(
    model: Any,
    processor: Any,
    device: torch.device,
    image: Image.Image,
    task: str,
    text: str = ""
) -> Tuple[str, Dict]:
    """Run a single Florence-2 task on `image`; return (raw text, parsed response)."""
    # Florence-2 prompts are a task token (e.g. "<CAPTION>"), optionally
    # followed by extra text for tasks that take an input phrase.
    prompt = task + text
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
    # Beam search (3 beams), with headroom for long structured outputs.
    generated_ids = model.generate(
input_ids=inputs["input_ids"],
pixel_values=inputs["pixel_values"],
max_new_tokens=1024,
num_beams=3
)
    # Keep special tokens in the decoded string: Florence-2 emits location and
    # task tokens that post_process_generation parses into structured output.
    generated_text = processor.batch_decode(
        generated_ids, skip_special_tokens=False)[0]
    response = processor.post_process_generation(
        generated_text, task=task, image_size=image.size)
return generated_text, response
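

# Minimal usage sketch, not part of the original module: assumes a local
# "image.jpg" exists and picks "<CAPTION>" as the task token; both are
# illustrative choices, not requirements of the functions above.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    models, processors = load_models(device)
    checkpoint = "microsoft/Florence-2-base"
    image = Image.open("image.jpg")
    raw, parsed = run_inference(
        models[checkpoint], processors[checkpoint], device, image,
        task="<CAPTION>")
    print(parsed)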