import re

import gradio as gr
import PIL.Image
import spaces
import torch
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
# Model IDs
MODEL_IDS = {
    "paligemma-3b-ft-widgetcap-waveui-448": "agentsea/paligemma-3b-ft-widgetcap-waveui-448",
    "paligemma-3b-ft-waveui-896": "agentsea/paligemma-3b-ft-waveui-896",
}
# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load models and processors
models = {
    name: PaliGemmaForConditionalGeneration.from_pretrained(model_id).eval().to(device)
    for name, model_id in MODEL_IDS.items()
}
processors = {
    name: PaliGemmaProcessor.from_pretrained(model_id)
    for name, model_id in MODEL_IDS.items()
}
###### Transformers Inference
@spaces.GPU
def infer(
    image: PIL.Image.Image,
    text: str,
    max_new_tokens: int,
    model_choice: str,
) -> str:
    model = models[model_choice]
    processor = processors[model_choice]
    inputs = processor(text=text, images=image, return_tensors="pt").to(device)
    with torch.inference_mode():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
        )
    result = processor.batch_decode(generated_ids, skip_special_tokens=True)
    # The decoded sequence starts with the prompt; strip it to keep only the generation.
    return result[0][len(text):].lstrip("\n")
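
# Example call (a sketch; "./airbnb.jpg" is the sample image bundled with this demo):
#   img = PIL.Image.open("./airbnb.jpg")
#   infer(img, "detect 'Amazing pools' button", max_new_tokens=100,
#         model_choice="paligemma-3b-ft-waveui-896")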

def parse_segmentation(input_image, input_text, model_choice):
    """Runs detection and formats the resulting boxes for gr.AnnotatedImage."""
    out = infer(input_image, input_text, max_new_tokens=100, model_choice=model_choice)
    objs = extract_objs(out.lstrip("\n"), input_image.size[0], input_image.size[1], unique_labels=True)
    annotated_img = (
        input_image,
        [
            (
                obj['mask'] if obj.get('mask') is not None else obj['xyxy'],
                obj['name'] or '',
            )
            for obj in objs
            if 'mask' in obj or 'xyxy' in obj
        ],
    )
    return annotated_img
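
# gr.AnnotatedImage renders a (base_image, annotations) tuple, where each
# annotation is a (bounding_box_or_mask, label) pair and a bounding box is an
# (x1, y1, x2, y2) tuple in pixels, e.g. (illustrative coordinates, not real output):
#   parse_segmentation(img, "detect sign in button", "paligemma-3b-ft-waveui-896")
#   -> (img, [((120, 40, 310, 88), 'sign in button')])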
######## Demo
INTRO_TEXT = """## PaliGemma WaveUI

Two models fine-tuned on the [WaveUI dataset](https://huggingface.co/datasets/agentsea/wave-ui), each from a different base checkpoint:

- [paligemma-3b-ft-widgetcap-waveui-448](https://huggingface.co/agentsea/paligemma-3b-ft-widgetcap-waveui-448)
- [paligemma-3b-ft-waveui-896](https://huggingface.co/agentsea/paligemma-3b-ft-waveui-896)

Note: these models were fine-tuned on detection, so they may not generalize to other tasks.

Usage: write the task keyword "detect" before the element you want the model to find, e.g. "detect profile picture".
"""

with gr.Blocks(css="style.css") as demo:
    gr.Markdown(INTRO_TEXT)
    with gr.Tab("Detection"):
        model_choice = gr.Dropdown(label="Select Model", choices=list(MODEL_IDS.keys()))
        image = gr.Image(type="pil")
        seg_input = gr.Text(label="Detect instruction (e.g. 'detect sign in button')")
        seg_btn = gr.Button("Submit")
        annotated_image = gr.AnnotatedImage(label="Output")

        examples = [["./airbnb.jpg", "detect 'Amazing pools' button"]]
        gr.Examples(
            examples=examples,
            inputs=[image, seg_input],
        )

        seg_inputs = [image, seg_input, model_choice]
        seg_outputs = [annotated_image]
        seg_btn.click(
            fn=parse_segmentation,
            inputs=seg_inputs,
            outputs=seg_outputs,
        )
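
# PaliGemma encodes each detected box as four <locYYYY> tokens (y_min, x_min,
# y_max, x_max, each a 4-digit integer that extract_objs rescales by /1024 to
# image size), optionally followed by sixteen <seg###> mask tokens, then the
# label; multiple detections are separated by "; ". An illustrative output:
#   "<loc0045><loc0112><loc0230><loc0867> sign in button"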
_SEGMENT_DETECT_RE = re.compile(
    r'(.*?)' +
    r'<loc(\d{4})>' * 4 + r'\s*' +
    '(?:%s)?' % (r'<seg(\d{3})>' * 16) +
    r'\s*([^;<>]+)? ?(?:; )?',
)

def extract_objs(text, width, height, unique_labels=False):
    """Returns objs for a string with "<loc>" and "<seg>" tokens."""
    objs = []
    seen = set()
    while text:
        m = _SEGMENT_DETECT_RE.match(text)
        if not m:
            break
        gs = list(m.groups())
        before = gs.pop(0)
        name = gs.pop()
        # Rescale the normalized token coordinates to pixel coordinates.
        y1, x1, y2, x2 = [int(x) / 1024 for x in gs[:4]]
        y1, x1, y2, x2 = map(round, (y1 * height, x1 * width, y2 * height, x2 * width))
        mask = None  # <seg> mask tokens are matched but not decoded in this demo.
        content = m.group()
        if before:
            objs.append(dict(content=before))
            content = content[len(before):]
        # Suffix duplicate labels with apostrophes so each annotation name is unique.
        while unique_labels and name in seen:
            name = (name or '') + "'"
        seen.add(name)
        objs.append(dict(
            content=content, xyxy=(x1, y1, x2, y2), mask=mask, name=name))
        text = text[len(before) + len(content):]
    if text:
        objs.append(dict(content=text))
    return objs
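
# Worked example (illustrative tokens, not real model output) on a 1000x1000 image:
#   extract_objs("<loc0256><loc0128><loc0512><loc0640> search bar", 1000, 1000)
#   -> [{'content': '<loc0256>...<loc0640> search bar',
#        'xyxy': (125, 250, 625, 500), 'mask': None, 'name': 'search bar'}]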
#########
if __name__ == "__main__":
    demo.queue(max_size=10).launch(debug=True)