Spaces:
Runtime error
Runtime error
import gradio as gr | |
import requests | |
from PIL import Image, ImageDraw, ImageFont | |
import random | |
from transformers import AutoProcessor, AutoModelForVision2Seq | |
# Hugging Face Hub id of the Kosmos-2 checkpoint; the model and its processor
# are both loaded from this same id.
MODEL_ID = "microsoft/kosmos-2-patch14-224"

# Loaded once at import time; process_image() reuses these module globals.
model = AutoModelForVision2Seq.from_pretrained(MODEL_ID)
processor = AutoProcessor.from_pretrained(MODEL_ID)
def draw_bounding_boxes(image: Image.Image, entities):
    """Draw a labelled box for every box of every entity onto ``image``.

    The image is modified in place and also returned.

    Args:
        image: PIL image to draw on (process_image passes a copy).
        entities: Iterable of ``(label, span, boxes)`` tuples, as returned by
            ``processor.post_process_generation``. Each box is
            ``(x1, y1, x2, y2)``; the values are multiplied by the image
            width/height, so they are assumed to be normalised to [0, 1].

    Returns:
        The same ``image`` object with boxes and labels drawn on it.
    """
    draw = ImageDraw.Draw(image)
    width, height = image.size
    color_bank = [
        "#0AC2FF", "#30D5C8", "#F3C300", "#47FF0A", "#C2FF0A"
    ]
    font_size = 20
    try:
        font = ImageFont.truetype("assets/arial.ttf", font_size)
    except OSError:
        # Bundled font missing or unreadable: use Pillow's built-in font.
        font = ImageFont.load_default()

    for index, (label, _, boxes) in enumerate(entities):
        # One colour per entity, shared by its box(es) and its label, so a
        # label can be matched to its box; cycling keeps neighbours distinct.
        color = color_bank[index % len(color_bank)]
        for box in boxes:
            # Sort each axis so an inverted box from the model cannot make
            # Pillow reject the rectangle (it requires x1 >= x0, y1 >= y0).
            left, right = sorted((box[0] * width, box[2] * width))
            top, bottom = sorted((box[1] * height, box[3] * height))
            draw.rectangle([left, top, right, bottom], outline=color, width=4)
            # Label goes above the box; if that would fall off the top of the
            # image, place it just inside the box instead.
            text_y = top - font_size - 5
            if text_y < 0:
                text_y = top + 5
            draw.text((left + 5, text_y), label, fill=color, font=font)
    return image
def highlight_entities(text, entities):
    """Return ``text`` with each grounded entity phrase wrapped in asterisks.

    Args:
        text: The cleaned text returned by ``processor.post_process_generation``.
        entities: Iterable of ``(label, (start, end), boxes)`` tuples whose
            ``(start, end)`` span indexes into ``text``.

    Returns:
        A new string where each entity span is rendered as ``*phrase*``.

    Marking by span (rather than ``str.replace`` on the label) marks only the
    grounded occurrence: a repeated label is not wrapped twice, and a label
    that also appears inside another word is left alone there.
    """
    pieces = []
    cursor = 0
    for start, end in sorted(entity[1] for entity in entities):
        # Skip empty, out-of-range, or overlapping spans (one already marked).
        if start < cursor or start >= end or end > len(text):
            continue
        pieces.append(text[cursor:start])
        pieces.append(f"*{text[start:end]}*")
        cursor = end
    pieces.append(text[cursor:])
    return "".join(pieces)
def process_image(image, prompt_option, custom_prompt):
    """Run Kosmos-2 grounded generation on one image.

    Args:
        image: A ``PIL.Image.Image``, or anything ``Image.open`` accepts
            (e.g. a file path).
        prompt_option: ``"Brief"`` or ``"Detailed"`` select a preset prompt;
            any other value (the UI uses ``"Custom"``) uses ``custom_prompt``.
        custom_prompt: Prompt text used when no preset is selected.

    Returns:
        A 4-tuple ``(processed_image, processed_text, entities,
        highlighted_text)``:
        - a copy of the image with the entity boxes drawn on it;
        - the cleaned generated text;
        - the entity list from ``processor.post_process_generation``;
        - the text with entity phrases wrapped in asterisks.

    Raises:
        gr.Error: If no image was provided, or a custom prompt is required
            but empty; Gradio shows the message to the user.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    if not isinstance(image, Image.Image):
        image = Image.open(image)

    # Use the selected prompt option
    if prompt_option == "Brief":
        prompt = "<grounding>An image of"
    elif prompt_option == "Detailed":
        prompt = "<grounding> Describe this image in detail:"
    else:  # Custom
        # Reject a missing/blank custom prompt up front instead of running
        # the model on it.
        if not custom_prompt or not custom_prompt.strip():
            raise gr.Error("Please enter a custom prompt.")
        prompt = custom_prompt

    inputs = processor(text=prompt, images=image, return_tensors="pt")
    generated_ids = model.generate(
        pixel_values=inputs["pixel_values"],
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        image_embeds=None,
        image_embeds_position_mask=inputs["image_embeds_position_mask"],
        use_cache=True,
        max_new_tokens=128,
    )
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    # Split the raw generation into plain text and its grounded entities.
    processed_text, entities = processor.post_process_generation(generated_text)

    # Draw bounding boxes on a copy so the caller's image is left untouched.
    processed_image = draw_bounding_boxes(image.copy(), entities)
    highlighted_entities = highlight_entities(processed_text, entities)
    return processed_image, processed_text, entities, highlighted_entities
def clear_interface():
    """Produce the blank values used to reset the UI.

    Returns:
        A tuple of four ``None`` values, one for each component the Clear
        button resets.
    """
    return (None,) * 4
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Kosmos-2 VQA Demo")
    gr.Markdown("Run this space on your own hardware with this command: ```docker run -it```")
    with gr.Row(equal_height=True):
        image_input = gr.Image(type="pil", label="Upload Image")
        processed_image_output = gr.Image(label="Processed Image")
    with gr.Row(equal_height=True):
        with gr.Column():
            with gr.Accordion("Prompt Options"):
                prompt_option = gr.Radio(
                    choices=["Brief", "Detailed", "Custom"],
                    label="Select Prompt Option",
                    value="Brief",
                )
                custom_prompt_input = gr.Textbox(label="Custom Prompt", visible=False)

                def show_custom_prompt_input(prompt_option):
                    """Show the custom-prompt textbox only when "Custom" is selected."""
                    # Returning a bare bool would be written into the textbox
                    # as its value; gr.update(visible=...) toggles visibility.
                    return gr.update(visible=prompt_option == "Custom")

                prompt_option.change(
                    show_custom_prompt_input,
                    inputs=[prompt_option],
                    outputs=[custom_prompt_input],
                )
            with gr.Row(equal_height=True):
                submit_button = gr.Button("Run Model")
                clear_button = gr.Button("Clear", elem_id="clear_button")
    with gr.Row(equal_height=True):
        with gr.Column():
            highlighted_entities = gr.Textbox(label="Processed Text")
        with gr.Column():
            with gr.Accordion("Entities"):
                entities_output = gr.JSON(label="Entities", elem_id="entities_output")

    # Define examples
    examples = [
        ["assets/snowman.jpg", "Custom", "<grounding> Question: Where is<phrase> the fire</phrase><object><patch_index_0005><patch_index_0911></object> next to? Answer:"],
        ["assets/traffic.jpg", "Detailed", "<grounding> Describe this image in detail:"],
        ["assets/umbrellas.jpg", "Brief", "<grounding>An image of"],
    ]
    gr.Examples(examples, inputs=[image_input, prompt_option, custom_prompt_input])

    with gr.Row(equal_height=True):
        with gr.Accordion("Additional Info"):
            gr.Markdown("This demo uses the [Kosmos-2](https://huggingface.co/microsoft/kosmos-2-patch14-224) model.")

    def run_model(image, prompt_option, custom_prompt):
        """Run process_image and map its results onto the three output components."""
        # process_image returns four values but only three outputs are wired;
        # route the highlighted text (not the plain text) to the text box so
        # highlight_entities' result is actually displayed.
        processed_image, _, entities, highlighted_text = process_image(
            image, prompt_option, custom_prompt
        )
        return processed_image, highlighted_text, entities

    submit_button.click(
        fn=run_model,
        inputs=[image_input, prompt_option, custom_prompt_input],
        outputs=[processed_image_output, highlighted_entities, entities_output],
    )
    clear_button.click(
        fn=clear_interface,
        inputs=[],
        outputs=[image_input, processed_image_output, highlighted_entities, entities_output],
    )
demo.launch()