"""Gradio demo for Kosmos-2: caption an image and draw grounded entity boxes."""

import random
import re

import gradio as gr
import requests
from PIL import Image, ImageDraw, ImageFont
from transformers import AutoProcessor, AutoModelForVision2Seq

MODEL_ID = "microsoft/kosmos-2-patch14-224"

# Kosmos-2 only emits entity spans and bounding boxes when the prompt carries
# this token; without it the entity list comes back empty and no boxes are drawn.
GROUNDING_TOKEN = "<grounding>"

# Load the model and processor
model = AutoModelForVision2Seq.from_pretrained(MODEL_ID)
processor = AutoProcessor.from_pretrained(MODEL_ID)


def draw_bounding_boxes(image: Image.Image, entities):
    """Draw every entity's boxes and label onto ``image`` in place.

    Args:
        image: PIL image to draw on. It is mutated; pass a copy to keep the
            original untouched.
        entities: Iterable of ``(label, span, boxes)`` triples. Each box is a
            4-sequence that is scaled by the image width/height, i.e. the
            coordinates are presumably normalized to ``[0, 1]`` in
            ``(x1, y1, x2, y2)`` order.

    Returns:
        The same ``image`` object, with rectangles and labels drawn.
    """
    draw = ImageDraw.Draw(image)
    width, height = image.size
    color_bank = ["#0AC2FF", "#30D5C8", "#F3C300", "#47FF0A", "#C2FF0A"]

    font_size = 20
    try:
        font = ImageFont.truetype("assets/arial.ttf", font_size)
    except IOError:
        # Font file missing or unreadable: fall back to PIL's built-in font.
        font = ImageFont.load_default()

    for label, _, boxes in entities:
        for box in boxes:
            box_coords = [
                box[0] * width,
                box[1] * height,
                box[2] * width,
                box[3] * height,
            ]
            outline_color = random.choice(color_bank)
            text_fill_color = random.choice(color_bank)
            draw.rectangle(box_coords, outline=outline_color, width=4)
            # Place the label just above the box, but never above the top
            # edge of the image where it would be clipped away.
            text_position = (
                box_coords[0] + 5,
                max(box_coords[1] - font_size - 5, 0),
            )
            draw.text(text_position, label, fill=text_fill_color, font=font)

    return image


def highlight_entities(text, entities):
    """Return ``text`` with every entity label wrapped in asterisks.

    All labels are substituted in a single pass so that a label occurring
    twice in ``entities``, or a label that is a substring of another one, is
    never wrapped more than once. Empty labels are ignored.

    Args:
        text: The caption to annotate.
        entities: Iterable of sequences whose first item is the label string.

    Returns:
        The annotated text, or ``text`` unchanged when there are no labels.
    """
    labels = {entity[0] for entity in entities if entity[0]}
    if not labels:
        return text
    # Longest first so the alternation prefers "a red car" over "a red".
    ordered = sorted(labels, key=len, reverse=True)
    pattern = "|".join(re.escape(label) for label in ordered)
    # Highlighting by enclosing in asterisks
    return re.sub(pattern, lambda match: f"*{match.group(0)}*", text)


def _with_grounding(prompt):
    """Prefix ``prompt`` with the grounding token unless it already has one."""
    if GROUNDING_TOKEN in prompt:
        return prompt
    return f"{GROUNDING_TOKEN}{prompt}"


def process_image(image, prompt_option, custom_prompt):
    """Run Kosmos-2 on ``image`` and return the annotated results.

    Args:
        image: A PIL image, or anything ``PIL.Image.open`` accepts.
        prompt_option: ``"Brief"``, ``"Detailed"``, or any other value to use
            ``custom_prompt``.
        custom_prompt: Prompt text used when ``prompt_option`` is neither
            ``"Brief"`` nor ``"Detailed"``.

    Returns:
        A 4-tuple ``(processed_image, processed_text, entities,
        highlighted_text)``: a copy of the image with boxes drawn, the cleaned
        caption, the entity list from the processor, and the caption with the
        entity labels wrapped in asterisks.

    Raises:
        gr.Error: If no image was supplied.
    """
    if image is None:
        raise gr.Error("Please upload an image before running the model.")
    if not isinstance(image, Image.Image):
        image = Image.open(image)

    # Use the selected prompt option
    if prompt_option == "Brief":
        prompt = "An image of"
    elif prompt_option == "Detailed":
        prompt = " Describe this image in detail:"
    else:  # Custom
        prompt = custom_prompt or ""
    prompt = _with_grounding(prompt)

    inputs = processor(text=prompt, images=image, return_tensors="pt")

    generated_ids = model.generate(
        pixel_values=inputs["pixel_values"],
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        image_embeds=None,
        image_embeds_position_mask=inputs["image_embeds_position_mask"],
        use_cache=True,
        max_new_tokens=128,
    )
    generated_text = processor.batch_decode(
        generated_ids, skip_special_tokens=True
    )[0]
    processed_text, entities = processor.post_process_generation(generated_text)

    # Draw bounding boxes on a copy of the image
    processed_image = draw_bounding_boxes(image.copy(), entities)
    highlighted_text = highlight_entities(processed_text, entities)

    return processed_image, processed_text, entities, highlighted_text


def _run_model(image, prompt_option, custom_prompt):
    """Adapt ``process_image`` to the three output components of the UI.

    ``process_image`` returns four values while the Run button is wired to
    three components; this picks the image, the highlighted caption and the
    entity list, in the order of the ``outputs`` list.
    """
    processed_image, _, entities, highlighted_text = process_image(
        image, prompt_option, custom_prompt
    )
    return processed_image, highlighted_text, entities


def clear_interface():
    """Return one ``None`` per component wired to the Clear button."""
    return None, None, None, None


with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Kosmos-2 VQA Demo")
    gr.Markdown("Run this space on your own hardware with this command: ```docker run -it```")

    with gr.Row(equal_height=True):
        image_input = gr.Image(type="pil", label="Upload Image")
        processed_image_output = gr.Image(label="Processed Image")

    with gr.Row(equal_height=True):
        with gr.Column():
            with gr.Accordion("Prompt Options"):
                prompt_option = gr.Radio(
                    choices=["Brief", "Detailed", "Custom"],
                    label="Select Prompt Option",
                    value="Brief",
                )
                custom_prompt_input = gr.Textbox(label="Custom Prompt", visible=False)

                def show_custom_prompt_input(prompt_option):
                    """Show the custom prompt box only for the "Custom" option."""
                    # A bare bool would be written into the textbox as its
                    # value; an update object is needed to toggle visibility.
                    return gr.update(visible=prompt_option == "Custom")

                prompt_option.change(
                    show_custom_prompt_input,
                    inputs=[prompt_option],
                    outputs=[custom_prompt_input],
                )

    with gr.Row(equal_height=True):
        submit_button = gr.Button("Run Model")
        clear_button = gr.Button("Clear", elem_id="clear_button")

    with gr.Row(equal_height=True):
        with gr.Column():
            highlighted_entities = gr.Textbox(label="Processed Text")
        with gr.Column():
            with gr.Accordion("Entities"):
                entities_output = gr.JSON(label="Entities", elem_id="entities_output")

    # Define examples
    examples = [
        ["assets/snowman.jpg", "Custom", " Question: Where is the fire next to? Answer:"],
        ["assets/traffic.jpg", "Detailed", " Describe this image in detail:"],
        ["assets/umbrellas.jpg", "Brief", "An image of"],
    ]
    gr.Examples(examples, inputs=[image_input, prompt_option, custom_prompt_input])

    with gr.Row(equal_height=True):
        with gr.Accordion("Additional Info"):
            gr.Markdown("This demo uses the [Kosmos-2]")

    submit_button.click(
        fn=_run_model,
        inputs=[image_input, prompt_option, custom_prompt_input],
        outputs=[processed_image_output, highlighted_entities, entities_output],
    )

    clear_button.click(
        fn=clear_interface,
        inputs=[],
        outputs=[image_input, processed_image_output, highlighted_entities, entities_output],
    )

if __name__ == "__main__":
    demo.launch()