# OmniParser / app.py
# Last commit: "Update app.py" by TotoB12 (c9933c7, verified)
from typing import Optional
import gradio as gr
import numpy as np
import torch
from PIL import Image
import io
import base64, os
from utils import check_ocr_box, get_yolo_model, get_caption_model_processor, get_som_labeled_img
from PIL import Image
from ultralytics import YOLO
# Icon-detection model, loaded once at import time from the bundled weights.
yolo_model = YOLO('weights/icon_detect/best.pt')

from transformers import AutoProcessor, AutoModelForCausalLM

# Florence-2 captioner: the processor comes from the upstream
# "microsoft/Florence-2-base" checkpoint, while the model weights are read from
# a local checkpoint (presumably fine-tuned for icon captioning -- confirm).
processor = AutoProcessor.from_pretrained(
    "microsoft/Florence-2-base",
    trust_remote_code=True,
)
model = AutoModelForCausalLM.from_pretrained(
    "weights/icon_caption_florence",
    trust_remote_code=True,
    torch_dtype=torch.float32,
)

# Bundle handed to get_som_labeled_img as `caption_model_processor`.
caption_model_processor = dict(processor=processor, model=model)
print('Finished loading model.')

# NOTE(review): not referenced anywhere else in this file.
platform = 'pc'

# Box/label drawing style forwarded to get_som_labeled_img as `draw_bbox_config`.
draw_bbox_config = dict(
    text_scale=0.8,
    text_thickness=2,
    text_padding=2,
    thickness=2,
)
MARKDOWN = """
# OmniParser for Pure Vision Based General GUI Agent 🔥
<div>
<a href="https://arxiv.org/pdf/2408.00203">
<img src="https://img.shields.io/badge/arXiv-2408.00203-b31b1b.svg" alt="Arxiv" style="display:inline-block;">
</a>
</div>
OmniParser is a screen parsing tool to convert general GUI screens to structured elements.
"""
@torch.inference_mode()
def process(
    image_input,
    box_threshold,
    iou_threshold
) -> "tuple[Image.Image, str, str]":
    """Parse a GUI screenshot into labelled screen elements.

    The uploaded image is written to disk, OCR'd with ``check_ocr_box``
    (PaddleOCR enabled), and then passed together with the OCR boxes/text to
    ``get_som_labeled_img``, which also receives the module-level
    ``yolo_model`` and ``caption_model_processor``.

    Args:
        image_input: PIL image from the upload component, or ``None`` when the
            user submits without uploading anything.
        box_threshold: Detection threshold forwarded as ``BOX_TRESHOLD``.
        iou_threshold: IoU threshold forwarded to ``get_som_labeled_img``.

    Returns:
        A tuple of (annotated image, parsed element descriptions joined by
        newlines, label coordinates rendered as text).

    Raises:
        gr.Error: If no image was uploaded.
    """
    if image_input is None:
        # Without this guard the .save() below raises AttributeError, which the
        # UI shows as an opaque error instead of an actionable message.
        raise gr.Error('Please upload an image before submitting.')

    # The helpers below take a file path, so the upload is written to disk first.
    # NOTE(review): this fixed path is shared by every request; it is only safe
    # while the queue runs one request at a time -- confirm the concurrency limit.
    image_save_path = 'imgs/saved_image_demo.png'
    # Create the target directory so a fresh checkout does not hit FileNotFoundError.
    os.makedirs(os.path.dirname(image_save_path), exist_ok=True)
    image_input.save(image_save_path)

    # The second return value (goal-filtering flag) is not needed here.
    ocr_bbox_rslt, _ = check_ocr_box(
        image_save_path,
        display_img=False,
        output_bb_format='xyxy',
        goal_filtering=None,
        easyocr_args={'paragraph': False, 'text_threshold': 0.9},
        use_paddleocr=True,
    )
    text, ocr_bbox = ocr_bbox_rslt

    dino_labeled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
        image_save_path,
        yolo_model,
        BOX_TRESHOLD=box_threshold,
        output_coord_in_ratio=True,
        ocr_bbox=ocr_bbox,
        draw_bbox_config=draw_bbox_config,
        caption_model_processor=caption_model_processor,
        ocr_text=text,
        iou_threshold=iou_threshold,
    )

    # The annotated image comes back base64-encoded; decode it into a PIL image.
    image = Image.open(io.BytesIO(base64.b64decode(dino_labeled_img)))
    print('Finished processing.')

    parsed_content_list_str = '\n'.join(parsed_content_list)
    # The output Textbox would str() this value anyway; doing it here makes the
    # declared return type accurate.
    label_coordinates_str = str(label_coordinates)
    return image, parsed_content_list_str, label_coordinates_str
with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)

    with gr.Row():
        # Left column: the screenshot plus the two detection thresholds.
        with gr.Column():
            image_input_component = gr.Image(label='Upload Image', type='pil')
            box_threshold_component = gr.Slider(
                minimum=0.01,
                maximum=1.0,
                step=0.01,
                value=0.05,
                label='Box Threshold',
            )
            iou_threshold_component = gr.Slider(
                minimum=0.01,
                maximum=1.0,
                step=0.01,
                value=0.1,
                label='IOU Threshold',
            )
            submit_button_component = gr.Button(variant='primary', value='Submit')

        # Right column: annotated image, element list and coordinates.
        with gr.Column():
            image_output_component = gr.Image(label='Image Output', type='pil')
            text_output_component = gr.Textbox(
                placeholder='Text Output',
                label='Parsed Screen Elements',
            )
            coordinates_output_component = gr.Textbox(
                placeholder='Coordinates Output',
                label='Coordinates',
            )

    # Wire the button: process(image, box_thr, iou_thr) -> (image, text, coords).
    # Positional arguments are (fn, inputs, outputs).
    submit_button_component.click(
        process,
        [image_input_component, box_threshold_component, iou_threshold_component],
        [image_output_component, text_output_component, coordinates_output_component],
    )

# queue() configures the request queue on `demo` itself, so launching it afterwards
# is the same as chaining demo.queue().launch(...).
demo.queue()
demo.launch(share=False)