import io
import random
from unittest.mock import patch

import gradio as gr
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
from transformers.dynamic_module_utils import get_imports

# Colors cycled through when drawing segmentation polygons
colormap = ['red', 'green', 'blue', 'yellow', 'orange', 'purple', 'cyan']

# Workaround for Florence-2's remote code, which imports flash_attn even when
# it is not installed: strip it from the dynamic import list.
def workaround_fixed_get_imports(filename):
    if not str(filename).endswith("/modeling_florence2.py"):
        return get_imports(filename)
    imports = get_imports(filename)
    if "flash_attn" in imports:
        imports.remove("flash_attn")  # Remove 'flash_attn' if it's causing issues
    return imports

def initialize_model():
    # Use the GPU when available, otherwise fall back to the CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Patch get_imports while the model and processor are loaded
    with patch("transformers.dynamic_module_utils.get_imports", workaround_fixed_get_imports):
        try:
            model = AutoModelForCausalLM.from_pretrained(
                "microsoft/Florence-2-large-ft", trust_remote_code=True
            ).to(device).eval()
            processor = AutoProcessor.from_pretrained(
                "microsoft/Florence-2-large-ft", trust_remote_code=True
            )
            print("Model and processor loaded successfully.")
            return model, processor, device
        except Exception as e:
            print(f"An error occurred while loading the model or processor: {e}")
            return None, None, device

# Initialize the model and processor
model, processor, device = initialize_model()

def run_example(task_prompt, image, text_input=None):
    prompt = task_prompt if text_input is None else task_prompt + text_input
    # Move the inputs to the same device as the model before generating
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
    with torch.inference_mode():
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            early_stopping=False,
            do_sample=False,
            num_beams=3,
        )
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed_answer = processor.post_process_generation(
        generated_text, task=task_prompt, image_size=(image.width, image.height)
    )
    return parsed_answer
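# Illustrative usage (not executed at import time): Florence-2 selects its task
# via a special prompt token, and grounding-style tasks expect extra text
# appended to that token. Assuming `img` is a PIL image you have loaded:
#
#   img = Image.open("Fiat-500-9-scaled.jpg").convert("RGB")
#   run_example("<CAPTION>", img)
#   run_example("<CAPTION_TO_PHRASE_GROUNDING>", img, "A white car parked on the street.")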
def fig_to_pil(fig):
    # Render a Matplotlib figure to an in-memory PNG and return it as a PIL image
    buf = io.BytesIO()
    fig.savefig(buf, format='png', dpi=300, bbox_inches='tight')
    buf.seek(0)
    return Image.open(buf)

def plot_bbox_img(image, data):
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(image)
    if 'bboxes' in data and 'labels' in data:
        bboxes, labels = data['bboxes'], data['labels']
    else:
        return fig_to_pil(fig)
    for bbox, label in zip(bboxes, labels):
        x1, y1, x2, y2 = bbox
        rect = patches.Rectangle((x1, y1), x2 - x1, y2 - y1,
                                 linewidth=2, edgecolor='indigo', facecolor='none')
        ax.add_patch(rect)
        plt.text(x1, y1, label, color='white', fontsize=10,
                 bbox=dict(facecolor='indigo', alpha=0.8))
    ax.axis('off')
    return fig_to_pil(fig)

def draw_polygons(image, prediction, fill_mask=False):
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(image)
    for polygons, label in zip(prediction.get('polygons', []), prediction.get('labels', [])):
        color = random.choice(colormap)
        for polygon in polygons:
            # Florence-2 returns polygons as flat [x1, y1, x2, y2, ...] lists;
            # convert them to (x, y) pairs before plotting
            if isinstance(polygon[0], (int, float)):
                polygon = [(polygon[i], polygon[i + 1]) for i in range(0, len(polygon), 2)]
            poly = patches.Polygon(polygon, edgecolor=color,
                                   facecolor=color if fill_mask else 'none',
                                   alpha=0.5 if fill_mask else 1, linewidth=2)
            ax.add_patch(poly)
            if polygon:
                plt.text(polygon[0][0], polygon[0][1], label, color='white', fontsize=10,
                         bbox=dict(facecolor=color, alpha=0.8))
    ax.axis('off')
    return fig_to_pil(fig)

def process_image(image, task, text):
    # Map each UI task to its Florence-2 prompt token and a post-processing step
    task_mapping = {
        "Caption": ("<CAPTION>", lambda result: (result['<CAPTION>'], image)),
        "Detailed Caption": ("<DETAILED_CAPTION>", lambda result: (result['<DETAILED_CAPTION>'], image)),
        "More Detailed Caption": ("<MORE_DETAILED_CAPTION>", lambda result: (result.get('<MORE_DETAILED_CAPTION>', 'Failed to generate detailed caption'), image)),
        "Caption to Phrase Grounding": ("<CAPTION_TO_PHRASE_GROUNDING>", lambda result: (str(result['<CAPTION_TO_PHRASE_GROUNDING>']), plot_bbox_img(image, result['<CAPTION_TO_PHRASE_GROUNDING>']))),
        "Object Detection": ("<OD>", lambda result: (str(result['<OD>']), plot_bbox_img(image, result['<OD>']))),
        "Referring Expression Segmentation": ("<REFERRING_EXPRESSION_SEGMENTATION>", lambda result: (str(result['<REFERRING_EXPRESSION_SEGMENTATION>']), draw_polygons(image, result['<REFERRING_EXPRESSION_SEGMENTATION>'], fill_mask=True))),
        "Region to Segmentation": ("<REGION_TO_SEGMENTATION>", lambda result: (str(result['<REGION_TO_SEGMENTATION>']), draw_polygons(image, result['<REGION_TO_SEGMENTATION>'], fill_mask=True))),
        "OCR": ("<OCR>", lambda result: (result['<OCR>'], image)),
    }
    if task in task_mapping:
        prompt, process_func = task_mapping[task]
        print(f"Task: {task}, Prompt: {prompt}")  # Debugging statement
        result = run_example(prompt, image, text)
        print(f"Result: {result}")  # Debugging statement
        return process_func(result)
    else:
        return "", image
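# For reference, post_process_generation keys its output by the task token.
# For "<OD>" the parsed result typically looks like (illustrative values):
#   {'<OD>': {'bboxes': [[34.2, 160.1, 597.8, 371.3]], 'labels': ['car']}}
# which is exactly the shape plot_bbox_img expects.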

image_path_1 = "Fiat-500-9-scaled.jpg"
image_path_2 = "OCR_2.png"

with gr.Blocks() as demo:
    gr.HTML("<h1 style='text-align: center;'>Florence-2 Vision</h1>")
    with gr.Tab(label="Image"):
        with gr.Row():
            with gr.Column():
                input_img = gr.Image(label="Input Picture", type="pil")
                task_dropdown = gr.Dropdown(
                    choices=["Caption", "Detailed Caption", "More Detailed Caption",
                             "Object Detection", "Caption to Phrase Grounding",
                             "Referring Expression Segmentation", "Region to Segmentation", "OCR"],
                    label="Task",
                    value="Caption"
                )
                text_input = gr.Textbox(label="Text Input (Optional)", visible=False)
                gr.Examples(
                    examples=[
                        [image_path_1, "Detailed Caption", ""],
                        [image_path_1, "Object Detection", ""],
                        [image_path_1, "More Detailed Caption", ""],
                        [image_path_1, "Caption to Phrase Grounding", "A white car parked on the street."],
                        [image_path_1, "Region to Segmentation", ""],
                        [image_path_2, "OCR", ""]
                    ],
                    inputs=[input_img, task_dropdown, text_input],
                    cache_examples=False  # Set this to False if caching is not needed
                )
                submit_btn = gr.Button(value="Submit")
            with gr.Column():
                output_text = gr.Textbox(label="Results")
                output_image = gr.Image(label="Image", type="pil")

        def update_text_input(task):
            # Show the text box only for tasks that take an extra text argument
            return gr.update(visible=task in ["Caption to Phrase Grounding",
                                              "Referring Expression Segmentation",
                                              "Region to Segmentation"])

        task_dropdown.change(fn=update_text_input, inputs=task_dropdown, outputs=text_input)
        submit_btn.click(fn=process_image, inputs=[input_img, task_dropdown, text_input],
                         outputs=[output_text, output_image])

demo.launch()