# kosmos-2-demo / app.py -- commit a5153ba ("init") by macadeliccc, 5.18 kB
import random
import re

import gradio as gr
import requests
from PIL import Image, ImageDraw, ImageFont
from transformers import AutoProcessor, AutoModelForVision2Seq
# Load the Kosmos-2 model and its matching processor once, at import time.
# Both come from the same checkpoint, so the identifier is spelled a single time.
_CHECKPOINT = "microsoft/kosmos-2-patch14-224"
model = AutoModelForVision2Seq.from_pretrained(_CHECKPOINT)
processor = AutoProcessor.from_pretrained(_CHECKPOINT)
def draw_bounding_boxes(image: Image.Image, entities):
    """Draw each entity's bounding boxes and label onto ``image``.

    Args:
        image: PIL image to annotate. It is drawn on in place, so pass a copy
            if the original must stay untouched.
        entities: Iterable of ``(label, _, boxes)`` triples. Every box is
            indexed as ``(x1, y1, x2, y2)`` and multiplied by the image width
            and height, so the coordinates are presumably normalized to
            ``[0, 1]`` -- confirm against the producer of ``entities``.

    Returns:
        The same ``image`` object, now annotated.
    """
    draw = ImageDraw.Draw(image)
    width, height = image.size
    color_bank = [
        "#0AC2FF", "#30D5C8", "#F3C300", "#47FF0A", "#C2FF0A"
    ]
    font_size = 20
    label_margin = 5
    try:
        font = ImageFont.truetype("assets/arial.ttf", font_size)
    except IOError:
        # The bundled font is optional; fall back to Pillow's built-in font.
        font = ImageFont.load_default()
    for entity in entities:
        label, _, boxes = entity
        for box in boxes:
            left, top = box[0] * width, box[1] * height
            right, bottom = box[2] * width, box[3] * height
            outline_color = random.choice(color_bank)
            text_fill_color = random.choice(color_bank)
            draw.rectangle([left, top, right, bottom], outline=outline_color, width=4)
            # Prefer placing the label just above the box. When the box touches
            # the top of the image that position is negative and the label
            # would be clipped, so draw it just inside the box instead.
            text_y = top - font_size - label_margin
            if text_y < 0:
                text_y = top + label_margin
            draw.text((left + label_margin, text_y), label, fill=text_fill_color, font=font)
    return image
def highlight_entities(text, entities):
    """Return ``text`` with every entity label enclosed in asterisks.

    Args:
        text: The string to annotate.
        entities: Iterable of sequences whose first item is the entity label.
            Any further items are ignored here.

    Returns:
        ``text`` with each occurrence of a label rewritten as ``*label*``.

    All labels are substituted in a single pass. Replacing them one after
    another would wrap a repeated label twice (``**a dog**``) and would
    re-wrap a short label sitting inside an already highlighted longer one.
    Empty labels are skipped because replacing ``""`` would insert asterisks
    between every character.
    """
    labels = {entity[0] for entity in entities if entity[0]}
    if not labels:
        return text
    # Longest label first so that, where two labels start at the same place,
    # the alternation picks the more specific one.
    ordered = sorted(labels, key=lambda label: (-len(label), label))
    pattern = "|".join(re.escape(label) for label in ordered)
    return re.sub(pattern, lambda match: f"*{match.group(0)}*", text)
def process_image(image, prompt_option, custom_prompt):
    """Run Kosmos-2 on ``image`` and return the annotated results.

    Args:
        image: A PIL image, or anything ``Image.open`` accepts (for example a
            file path).
        prompt_option: ``"Brief"`` or ``"Detailed"`` selects a built-in
            prompt; any other value means ``custom_prompt`` is used.
        custom_prompt: Prompt text used when ``prompt_option`` is not one of
            the built-in choices.

    Returns:
        A 4-tuple ``(processed_image, processed_text, entities,
        highlighted_entities)``: a copy of the image with bounding boxes
        drawn, the post-processed generated text, the entities reported by
        the processor, and the text with entity labels wrapped in asterisks.

    Raises:
        gr.Error: If no image was supplied.
    """
    # Without this guard a missing image reaches Image.open(None) and fails
    # with an error that tells the user nothing about what went wrong.
    if image is None:
        raise gr.Error("Please upload an image before running the model.")
    if not isinstance(image, Image.Image):
        image = Image.open(image)
    # Use the selected prompt option
    if prompt_option == "Brief":
        prompt = "<grounding>An image of"
    elif prompt_option == "Detailed":
        prompt = "<grounding> Describe this image in detail:"
    else:  # Custom
        prompt = custom_prompt
    inputs = processor(text=prompt, images=image, return_tensors="pt")
    generated_ids = model.generate(
        pixel_values=inputs["pixel_values"],
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        image_embeds=None,
        image_embeds_position_mask=inputs["image_embeds_position_mask"],
        use_cache=True,
        max_new_tokens=128,
    )
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    processed_text, entities = processor.post_process_generation(generated_text)
    # Draw bounding boxes on a copy so the caller's image stays unmodified
    processed_image = draw_bounding_boxes(image.copy(), entities)
    highlighted_entities = highlight_entities(processed_text, entities)
    return processed_image, processed_text, entities, highlighted_entities
def clear_interface():
    """Return four ``None`` values, one per component the Clear button resets."""
    return (None,) * 4
# Build the Gradio UI: image in / annotated image out, prompt selection,
# run and clear buttons, text and entity outputs, and a few examples.
with gr.Blocks(gr.themes.Soft()) as demo:
    gr.Markdown("# Kosmos-2 VQA Demo")
    gr.Markdown("Run this space on your own hardware with this command: ```docker run -it```")
    with gr.Row(equal_height=True):
        image_input = gr.Image(type="pil", label="Upload Image")
        processed_image_output = gr.Image(label="Processed Image")
    with gr.Row(equal_height=True):
        with gr.Column():
            with gr.Accordion("Prompt Options"):
                prompt_option = gr.Radio(choices=["Brief", "Detailed", "Custom"], label="Select Prompt Option", value="Brief")
                custom_prompt_input = gr.Textbox(label="Custom Prompt", visible=False)

                def show_custom_prompt_input(selected_option):
                    """Show the custom prompt box only while "Custom" is selected."""
                    # A bare bool would be written into the textbox as its
                    # value; visibility has to be changed through an update.
                    return gr.update(visible=selected_option == "Custom")

                prompt_option.change(show_custom_prompt_input, inputs=[prompt_option], outputs=[custom_prompt_input])
    with gr.Row(equal_height=True):
        submit_button = gr.Button("Run Model")
        clear_button = gr.Button("Clear", elem_id="clear_button")
    with gr.Row(equal_height=True):
        with gr.Column():
            highlighted_entities = gr.Textbox(label="Processed Text")
        with gr.Column():
            with gr.Accordion("Entities"):
                entities_output = gr.JSON(label="Entities", elem_id="entities_output")
    # Define examples
    examples = [
        ["assets/snowman.jpg", "Custom", "<grounding> Question: Where is<phrase> the fire</phrase><object><patch_index_0005><patch_index_0911></object> next to? Answer:"],
        ["assets/traffic.jpg", "Detailed", "<grounding> Describe this image in detail:"],
        ["assets/umbrellas.jpg", "Brief", "<grounding>An image of"],
    ]
    gr.Examples(examples, inputs=[image_input, prompt_option, custom_prompt_input])
    with gr.Row(equal_height=True):
        with gr.Accordion("Additional Info"):
            gr.Markdown("This demo uses the [Kosmos-2]")

    def run_model(image, selected_option, custom_prompt):
        """Adapt process_image's four return values to the three outputs below.

        process_image returns (image, plain text, entities, highlighted text).
        Wiring it straight to three components pairs values by position, so
        the highlighted text was never delivered to the textbox created for
        it. Select the values explicitly instead.
        """
        annotated_image, _plain_text, entities, highlighted_text = process_image(
            image, selected_option, custom_prompt
        )
        return annotated_image, highlighted_text, entities

    submit_button.click(
        fn=run_model,
        inputs=[image_input, prompt_option, custom_prompt_input],
        outputs=[processed_image_output, highlighted_entities, entities_output]
    )
    clear_button.click(
        fn=clear_interface,
        inputs=[],
        outputs=[image_input, processed_image_output, highlighted_entities, entities_output]
    )
demo.launch()