# utils.py
# (Page banner captured together with this file: "Spaces: Runtime error".)
import os
from typing import Union, Any, Tuple, Dict
from unittest.mock import patch

import supervision as sv
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor
from transformers.dynamic_module_utils import get_imports
# Hugging Face Hub identifier of the Florence-2 checkpoint used by default.
FLORENCE_CHECKPOINT = "microsoft/Florence-2-base"

# Task tokens. `run_florence_inference` builds the model prompt as
# `task + text`, so each token may be followed by free-form text.
FLORENCE_OBJECT_DETECTION_TASK = '<OD>'
FLORENCE_DETAILED_CAPTION_TASK = '<MORE_DETAILED_CAPTION>'
FLORENCE_CAPTION_TO_PHRASE_GROUNDING_TASK = '<CAPTION_TO_PHRASE_GROUNDING>'
FLORENCE_OPEN_VOCABULARY_DETECTION_TASK = '<OPEN_VOCABULARY_DETECTION>'
FLORENCE_DENSE_REGION_CAPTION_TASK = '<DENSE_REGION_CAPTION>'
def fixed_get_imports(filename: Union[str, os.PathLike]) -> list[str]:
    """Return the imports of *filename*, without ``flash_attn`` for Florence-2.

    Work around for https://huggingface.co/microsoft/phi-1_5/discussions/72.

    The import list is obtained from
    ``transformers.dynamic_module_utils.get_imports``. It is returned as-is
    for every file except ``modeling_florence2.py``; for that file any
    ``flash_attn`` entry is dropped.

    Args:
        filename: Path of the module source file being inspected.

    Returns:
        The list of imported module names, minus ``flash_attn`` when
        *filename* is the Florence-2 modeling file.
    """
    imports = get_imports(filename)
    # Compare the base name rather than a "/"-prefixed suffix so the check
    # also matches a bare file name or a path using the OS's own separator.
    if os.path.basename(str(filename)) != "modeling_florence2.py":
        return imports
    # Filter instead of list.remove(): remove() raises ValueError when
    # "flash_attn" is not present in the list.
    return [name for name in imports if name != "flash_attn"]
def load_florence_model(
    device: torch.device, checkpoint: str = FLORENCE_CHECKPOINT
) -> Tuple[Any, Any]:
    """Load a Florence-2 model and its processor.

    Args:
        device: Device to place the model on. ``None`` selects ``cuda:0``
            when CUDA is available and ``cpu`` otherwise. A CUDA device
            requested on a machine without CUDA falls back to ``cpu``.
        checkpoint: Hugging Face Hub identifier (or local path) of the
            checkpoint to load.

    Returns:
        A ``(model, processor)`` tuple. The model is loaded in ``float16``
        when it ends up on a CUDA device and in ``float32`` otherwise.

    NOTE(review): ``fixed_get_imports`` and ``patch`` exist in this file but
    are not applied here; confirm whether loading should be wrapped in that
    flash_attn workaround.
    """
    cuda_available = torch.cuda.is_available()
    if device is None:
        target = torch.device("cuda:0" if cuda_available else "cpu")
    else:
        target = torch.device(device)
        # Keep the CPU fallback on machines without CUDA.
        if target.type == "cuda" and not cuda_available:
            target = torch.device("cpu")

    # Choose the dtype from the device actually used, not from mere CUDA
    # availability, so a CPU model is never loaded in half precision.
    torch_dtype = torch.float16 if target.type == "cuda" else torch.float32

    model = AutoModelForCausalLM.from_pretrained(
        checkpoint, torch_dtype=torch_dtype, trust_remote_code=True
    ).to(target)
    processor = AutoProcessor.from_pretrained(checkpoint, trust_remote_code=True)
    return model, processor
def run_florence_inference(
    model: Any,
    processor: Any,
    device: torch.device,
    image: Image.Image,
    task: str,
    text: str = ""
) -> Tuple[str, Dict]:
    """Run one Florence-2 generation pass and post-process the output.

    Args:
        model: Model exposing ``generate`` (as returned by
            ``load_florence_model``).
        processor: Matching processor; it is called to build the inputs and
            provides ``batch_decode`` and ``post_process_generation``.
        device: Device the inputs are moved to before generation.
        image: Input image; its ``size`` is passed to the post-processor.
        task: Task token, e.g. ``'<OD>'``.
        text: Optional text appended directly after the task token.

    Returns:
        ``(generated_text, response)``: the first decoded sequence (special
        tokens kept) and the result of ``post_process_generation`` for it.
    """
    prompt = task + text
    inputs = processor(text=prompt, images=image, return_tensors="pt")

    # Match the model's floating-point dtype as well as its device: a float16
    # model fed float32 pixel values fails with a dtype mismatch. Per the
    # transformers `BatchFeature.to` documentation only floating-point
    # tensors are cast, so `input_ids` keep their integer dtype.
    model_dtype = getattr(model, "dtype", None)
    if model_dtype is None:
        inputs = inputs.to(device)
    else:
        inputs = inputs.to(device, model_dtype)

    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        num_beams=3
    )
    # Special tokens are kept in the decoded text that is handed to
    # `post_process_generation`.
    generated_text = processor.batch_decode(
        generated_ids, skip_special_tokens=False)[0]
    response = processor.post_process_generation(
        generated_text, task=task, image_size=image.size)
    return generated_text, response
def detect_objects_in_image(image_input_path, texts, device):
    """Detect the objects described by *texts* in one image.

    For every text an open-vocabulary Florence-2 detection is run, converted
    to ``sv.Detections`` and passed through ``run_sam_inference``. The
    per-text results are then merged and passed through ``run_sam_inference``
    once more.

    ``FLORENCE_MODEL``, ``FLORENCE_PROCESSOR``, ``SAM_IMAGE_MODEL`` and
    ``run_sam_inference`` are module-level names that are not defined in the
    visible part of this file.

    Args:
        image_input_path: Path (or file object) accepted by ``Image.open``.
        texts: Iterable of text prompts, one detection pass per entry.
        device: Device both models are moved to before inference.

    Returns:
        Whatever ``run_sam_inference`` returns for the merged detections
        (presumably an ``sv.Detections`` with masks - confirm against
        ``run_sam_inference``).
    """
    # Load the image inside a context manager so the underlying file is
    # closed right away, and normalise to RGB so palette, grayscale or RGBA
    # files reach the models in a single, predictable mode.
    with Image.open(image_input_path) as opened_image:
        image_input = opened_image.convert("RGB")

    # Moving the models is loop-invariant; do it once up front.
    florence_model = FLORENCE_MODEL.to(device)
    sam_model = SAM_IMAGE_MODEL.to(device)

    # One detection result per text prompt.
    detections_list = []
    for text in texts:
        _, result = run_florence_inference(
            model=florence_model,
            processor=FLORENCE_PROCESSOR,
            device=device,
            image=image_input,
            task=FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
            text=text
        )
        # Build supervision detections from the Florence-2 result.
        detections = sv.Detections.from_lmm(
            lmm=sv.LMM.FLORENCE_2,
            result=result,
            resolution_wh=image_input.size
        )
        # Run SAM inference on this prompt's detections.
        detections = run_sam_inference(sam_model, image_input, detections)
        detections_list.append(detections)

    # Merge all per-prompt detections, then run SAM inference once more on
    # the merged set.
    detections = sv.Detections.merge(detections_list)
    detections = run_sam_inference(sam_model, image_input, detections)
    return detections