import re

import gradio as gr
import PIL.Image
import spaces
import torch
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor

# Model IDs
MODEL_IDS = {
    "paligemma-3b-ft-widgetcap-waveui-448": "agentsea/paligemma-3b-ft-widgetcap-waveui-448",
    "paligemma-3b-ft-waveui-896": "agentsea/paligemma-3b-ft-waveui-896",
}

COLORS = ['#4285f4', '#db4437', '#f4b400', '#0f9d58', '#e48ef1']

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load models and processors
models = {
    name: PaliGemmaForConditionalGeneration.from_pretrained(model_id).eval().to(device)
    for name, model_id in MODEL_IDS.items()
}
processors = {
    name: PaliGemmaProcessor.from_pretrained(model_id)
    for name, model_id in MODEL_IDS.items()
}

###### Transformers Inference

@spaces.GPU
def infer(
    image: PIL.Image.Image,
    text: str,
    max_new_tokens: int,
    model_choice: str,
) -> str:
    model = models[model_choice]
    processor = processors[model_choice]
    inputs = processor(text=text, images=image, return_tensors="pt").to(device)
    with torch.inference_mode():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
        )
    result = processor.batch_decode(generated_ids, skip_special_tokens=True)
    # Drop the echoed prompt from the decoded sequence.
    return result[0][len(text):].lstrip("\n")


def parse_segmentation(input_image, input_text, model_choice):
    out = infer(input_image, input_text, max_new_tokens=100, model_choice=model_choice)
    objs = extract_objs(out.lstrip("\n"), input_image.size[0], input_image.size[1],
                        unique_labels=True)
    # Build the (image, annotations) pair expected by gr.AnnotatedImage.
    annotated_img = (
        input_image,
        [
            (
                obj['mask'] if obj.get('mask') is not None else obj['xyxy'],
                obj['name'] or '',
            )
            for obj in objs
            if 'mask' in obj or 'xyxy' in obj
        ],
    )
    return annotated_img

######## Demo

INTRO_TEXT = """## PaliGemma WaveUI

Two models fine-tuned on the [WaveUI dataset](https://huggingface.co/datasets/agentsea/wave-ui), from different bases:

- [paligemma-3b-ft-widgetcap-waveui-448](https://huggingface.co/agentsea/paligemma-3b-ft-widgetcap-waveui-448)
- [paligemma-3b-ft-waveui-896](https://huggingface.co/agentsea/paligemma-3b-ft-waveui-896)

Note: these models were fine-tuned on the detection task, so they may not generalize to other tasks.

Usage: write the task keyword "detect" before the element you want the model to detect, e.g. "detect profile picture".
"""

with gr.Blocks(css="style.css") as demo:
    gr.Markdown(INTRO_TEXT)
    with gr.Tab("Detection"):
        model_choice = gr.Dropdown(label="Select Model", choices=list(MODEL_IDS.keys()))
        image = gr.Image(type="pil")
        seg_input = gr.Text(label="Detect instruction (e.g. 'detect sign in button')")
        seg_btn = gr.Button("Submit")
        annotated_image = gr.AnnotatedImage(label="Output")

        examples = [["./airbnb.jpg", "detect 'Amazing pools' button"]]
        gr.Examples(
            examples=examples,
            inputs=[image, seg_input],
        )

        seg_inputs = [image, seg_input, model_choice]
        seg_outputs = [annotated_image]
        seg_btn.click(
            fn=parse_segmentation,
            inputs=seg_inputs,
            outputs=seg_outputs,
        )

_SEGMENT_DETECT_RE = re.compile(
    r'(.*?)' +
    r'<loc(\d{4})>' * 4 + r'\s*' +
    '(?:%s)?' % (r'<seg(\d{3})>' * 16) +
    r'\s*([^;<>]+)? ?(?:; )?',
)
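# Illustrative only: a raw detection string in the standard PaliGemma output
# format that the regex above matches (coordinates made up for the example):
#
#   "<loc0256><loc0128><loc0768><loc0896> sign in button"
#
# extract_objs (below) rescales the four <loc> tokens from the 0-1023 grid to
# pixels, so on a 1000x1000 image this yields xyxy = (125, 250, 875, 750).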
def extract_objs(text, width, height, unique_labels=False):
    """Returns objs for a string with "<loc>" (and optional "<seg>") tokens."""
    objs = []
    seen = set()
    while text:
        m = _SEGMENT_DETECT_RE.match(text)
        if not m:
            break
        gs = list(m.groups())
        before = gs.pop(0)
        name = gs.pop()
        # The four <loc> tokens encode y1, x1, y2, x2 on a 0-1023 grid;
        # rescale them to pixel coordinates.
        y1, x1, y2, x2 = [int(x) / 1024 for x in gs[:4]]
        y1, x1, y2, x2 = map(round, (y1 * height, x1 * width, y2 * height, x2 * width))
        # Segmentation masks are not reconstructed here, so only boxes are drawn.
        mask = None

        content = m.group()
        if before:
            objs.append(dict(content=before))
            content = content[len(before):]
        # Disambiguate repeated labels by appending apostrophes.
        while unique_labels and name in seen:
            name = (name or '') + "'"
        seen.add(name)
        objs.append(dict(
            content=content, xyxy=(x1, y1, x2, y2), mask=mask, name=name))
        text = text[len(before) + len(content):]

    if text:
        objs.append(dict(content=text))

    return objs

#########

if __name__ == "__main__":
    demo.queue(max_size=10).launch(debug=True)
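# A sketch of running this demo locally, assuming the file is saved as app.py
# and a CUDA GPU is available (the spaces.GPU decorator is a no-op outside
# Hugging Face Spaces):
#
#   pip install gradio transformers torch pillow spaces
#   python app.py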