import os
from unittest.mock import patch
import gradio as gr
from transformers import AutoProcessor, AutoModelForCausalLM
from transformers.dynamic_module_utils import get_imports
import torch
from PIL import Image, ImageDraw
import random
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import io

# Define colormap
colormap = ['red', 'green', 'blue', 'yellow', 'orange', 'purple', 'cyan']


# Workaround to fix import issues for Florence-2 model
def workaround_fixed_get_imports(filename):
    """Return the imports of *filename*, dropping ``flash_attn`` for Florence-2.

    Intended as a drop-in replacement for
    ``transformers.dynamic_module_utils.get_imports`` (see the ``patch`` call
    in ``initialize_model``).

    Args:
        filename: Path (str or path-like) of the module being inspected.

    Returns:
        The list produced by ``get_imports``; for ``modeling_florence2.py``
        the ``flash_attn`` entry is removed when present.
    """
    imports = get_imports(filename)
    # Compare the base name instead of a "/"-prefixed suffix so the check also
    # matches paths that use the platform's own separator (e.g. "\\" on Windows).
    if os.path.basename(str(filename)) != "modeling_florence2.py":
        return imports
    if "flash_attn" in imports:
        imports.remove("flash_attn")  # Remove 'flash_attn' if it's causing issues
    return imports


def initialize_model():
    """Load the Florence-2 model and processor onto the best available device.

    Returns:
        A ``(model, processor, device)`` tuple. When loading raises, the error
        is printed and ``(None, None, device)`` is returned instead.
    """
    # Check if CUDA (GPU) is available and set the device accordingly
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Patch the get_imports function and load the model and processor
    with patch("transformers.dynamic_module_utils.get_imports", workaround_fixed_get_imports):
        try:
            model = AutoModelForCausalLM.from_pretrained(
                "microsoft/Florence-2-large-ft", trust_remote_code=True
            ).to(device).eval()
            processor = AutoProcessor.from_pretrained(
                "microsoft/Florence-2-large-ft", trust_remote_code=True
            )
            print("Model and processor loaded successfully.")
            return model, processor, device
        except Exception as e:
            # Best effort: report the failure and let the caller see the None placeholders.
            print(f"An error occurred while loading the model or processor: {e}")
            return None, None, device


# Initialize the model and processor
model, processor, device = initialize_model()

# def run_example(task_prompt, image, text_input=None):
#     prompt = task_prompt if text_input is None else task_prompt + text_input
#     inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
#     with torch.inference_mode():
#         generated_ids = model.generate(**inputs, max_new_tokens=1024, early_stopping=False, do_sample=False, num_beams=3)
#     generated_text =
processor.batch_decode(generated_ids, skip_special_tokens=False)[0] # return processor.post_process_generation(generated_text, task=task_prompt, image_size=(image.size[0], image.size[1])) def run_example(task_prompt, image, text_input=None): if text_input is None: prompt = task_prompt else: prompt = task_prompt + text_input inputs = processor(text=prompt, images=image, return_tensors="pt") generated_ids = model.generate( input_ids=inputs["input_ids"], pixel_values=inputs["pixel_values"], max_new_tokens=1024, early_stopping=False, do_sample=False, num_beams=3, ) generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0] parsed_answer = processor.post_process_generation( generated_text, task=task_prompt, image_size=(image.width, image.height) ) return parsed_answer def fig_to_pil(fig): buf = io.BytesIO() fig.savefig(buf, format='png', dpi=300, bbox_inches='tight') buf.seek(0) return Image.open(buf) def plot_bbox_img(image, data): fig, ax = plt.subplots(figsize=(10, 10)) ax.imshow(image) if 'bboxes' in data and 'labels' in data: bboxes, labels = data['bboxes'], data['labels'] else: return fig_to_pil(fig) for bbox, label in zip(bboxes, labels): x1, y1, x2, y2 = bbox rect = patches.Rectangle((x1, y1), x2-x1, y2-y1, linewidth=2, edgecolor='indigo', facecolor='none') ax.add_patch(rect) plt.text(x1, y1, label, color='white', fontsize=10, bbox=dict(facecolor='indigo', alpha=0.8)) ax.axis('off') return fig_to_pil(fig) def draw_polygons(image, prediction, fill_mask=False): fig, ax = plt.subplots(figsize=(10, 10)) ax.imshow(image) for polygons, label in zip(prediction.get('polygons', []), prediction.get('labels', [])): color = random.choice(colormap) for polygon in polygons: if isinstance(polygon[0], (int, float)): polygon = [(polygon[i], polygon[i+1]) for i in range(0, len(polygon), 2)] poly = patches.Polygon(polygon, edgecolor=color, facecolor=color if fill_mask else 'none', alpha=0.5 if fill_mask else 1, linewidth=2) ax.add_patch(poly) if polygon: 
plt.text(polygon[0][0], polygon[0][1], label, color='white', fontsize=10, bbox=dict(facecolor=color, alpha=0.8)) ax.axis('off') return fig_to_pil(fig) def process_image(image, task, text): task_mapping = { "Caption": ("