# Spaces: Running on Zero
from typing import Any, Dict, Optional, Sequence, Tuple
from unittest.mock import patch

import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

from utils.imports import fixed_get_imports
# Checkpoint identifiers that load_models() loads by default.
CHECKPOINTS = [
    "microsoft/Florence-2-large-ft",
    "microsoft/Florence-2-large",
    "microsoft/Florence-2-base-ft",
    "microsoft/Florence-2-base",
]
def load_models(
    device: torch.device,
    checkpoints: Optional[Sequence[str]] = None,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Load a model and a processor for each requested checkpoint.

    Args:
        device: Device every loaded model is moved to.
        checkpoints: Checkpoint identifiers to load. ``None`` (the default)
            loads every entry of the module-level ``CHECKPOINTS`` list.

    Returns:
        A ``(models, processors)`` pair of dictionaries, both keyed by
        checkpoint identifier.
    """
    names = CHECKPOINTS if checkpoints is None else checkpoints
    models: Dict[str, Any] = {}
    processors: Dict[str, Any] = {}
    # ``get_imports`` is replaced by the project's ``fixed_get_imports`` only
    # for the duration of the loading calls; both loaders run with
    # ``trust_remote_code=True``, so both stay inside the patch.
    with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
        for name in names:
            models[name] = AutoModelForCausalLM.from_pretrained(
                name, trust_remote_code=True
            ).to(device)
            processors[name] = AutoProcessor.from_pretrained(
                name, trust_remote_code=True
            )
    return models, processors
def run_inference(
    model: Any,
    processor: Any,
    device: torch.device,
    image: Image.Image,
    task: str,
    text: str = "",
    *,
    max_new_tokens: int = 1024,
    num_beams: int = 3,
) -> Tuple[str, Dict]:
    """Run one generation pass for ``task`` on ``image`` and post-process it.

    Args:
        model: Object exposing ``generate``; called with the processor's
            ``input_ids`` and ``pixel_values``.
        processor: Callable that builds the model inputs and also provides
            ``batch_decode`` and ``post_process_generation``.
        device: Device the processor outputs are moved to before generation.
        image: Input image; its ``size`` is forwarded to post-processing.
        task: Task prompt; also passed as ``task`` to post-processing.
        text: Optional extra text appended directly after ``task``.
        max_new_tokens: Forwarded to ``model.generate``.
        num_beams: Forwarded to ``model.generate``.

    Returns:
        ``(generated_text, response)``: the first decoded sequence with
        special tokens kept, and whatever
        ``processor.post_process_generation`` returns for it.
    """
    # The prompt is a plain concatenation; no separator is inserted.
    prompt = task + text
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=max_new_tokens,
        num_beams=num_beams,
    )
    # Special tokens are kept in the decoded text that is handed to
    # post-processing; only the first decoded sequence is used.
    generated_text = processor.batch_decode(
        generated_ids, skip_special_tokens=False
    )[0]
    response = processor.post_process_generation(
        generated_text, task=task, image_size=image.size
    )
    return generated_text, response