nph4rd's picture
Update app.py
28e6c0e verified
raw
history blame contribute delete
No virus
5.06 kB
import gradio as gr
import PIL.Image
import transformers
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
import torch
import os
import string
import functools
import re
import numpy as np
import spaces
# Dropdown label -> Hugging Face Hub repository of each fine-tuned checkpoint.
MODEL_IDS = {
    "paligemma-3b-ft-widgetcap-waveui-448": "agentsea/paligemma-3b-ft-widgetcap-waveui-448",
    "paligemma-3b-ft-waveui-896": "agentsea/paligemma-3b-ft-waveui-896",
}
COLORS = ['#4285f4', '#db4437', '#f4b400', '#0f9d58', '#e48ef1']
# Prefer CUDA when it is available, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load every model (switched to eval mode and moved to `device`) together with
# its processor eagerly at import time, keyed by the same labels as MODEL_IDS.
models = {}
processors = {}
for _label, _repo_id in MODEL_IDS.items():
    models[_label] = PaliGemmaForConditionalGeneration.from_pretrained(_repo_id).eval().to(device)
    processors[_label] = PaliGemmaProcessor.from_pretrained(_repo_id)
# Don't leak the loop variables as module globals.
del _label, _repo_id
###### Transformers Inference
@spaces.GPU
def infer(
    image: PIL.Image.Image,
    text: str,
    max_new_tokens: int,
    model_choice: str
) -> str:
    """Generate text for ``image`` and prompt ``text`` with the chosen model.

    Args:
        image: Input image passed to the processor.
        text: Prompt, e.g. ``"detect profile picture"``.
        max_new_tokens: Maximum number of tokens to generate.
        model_choice: Key into the module-level ``models`` / ``processors``
            dicts (i.e. a key of ``MODEL_IDS``).

    Returns:
        Only the newly generated text (the prompt is excluded), decoded with
        special tokens skipped and leading newlines removed.
    """
    model = models[model_choice]
    processor = processors[model_choice]
    inputs = processor(text=text, images=image, return_tensors="pt").to(device)
    # generate() returns prompt + continuation. Drop the prompt at the token
    # level. Slicing the decoded string by len(text) only works if decoding
    # reproduces the prompt character-for-character, which detokenization
    # does not guarantee.
    prompt_len = inputs["input_ids"].shape[-1]
    with torch.inference_mode():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # deterministic decoding
        )
    result = processor.batch_decode(generated_ids[:, prompt_len:], skip_special_tokens=True)
    return result[0].lstrip("\n")
def parse_segmentation(input_image, input_text, model_choice):
    """Run detection for one UI request and format it for ``gr.AnnotatedImage``.

    Args:
        input_image: Image from the ``gr.Image(type="pil")`` input, or None
            when the user did not upload one.
        input_text: Prompt such as ``"detect profile picture"``.
        model_choice: Selected dropdown value (a key of ``models``), or None.

    Returns:
        ``(input_image, annotations)`` where each annotation is
        ``(mask_or_xyxy, label)``. Parsed records without a label get ``''``.

    Raises:
        gr.Error: if no image was given or the model choice is unknown. This
            shows the user a readable message instead of a traceback from
            ``.size`` / ``models[...]``.
    """
    if input_image is None:
        raise gr.Error("Please upload an image.")
    if model_choice not in models:
        raise gr.Error("Please select a model.")
    out = infer(input_image, input_text, max_new_tokens=100, model_choice=model_choice)
    width, height = input_image.size
    objs = extract_objs(out.lstrip("\n"), width, height, unique_labels=True)
    # Plain-text fragments returned by extract_objs carry neither 'mask' nor
    # 'xyxy' and are skipped. Records always carry both, so prefer a mask
    # when one exists, else the box.
    annotations = [
        (
            obj['mask'] if obj.get('mask') is not None else obj['xyxy'],
            obj['name'] or '',
        )
        for obj in objs
        if 'mask' in obj or 'xyxy' in obj
    ]
    return input_image, annotations
######## Demo
# Markdown shown at the top of the demo. The blank lines before "Note:" and
# "Usage:" matter: without them Markdown treats those lines as lazy
# continuations of the preceding bullet instead of new paragraphs.
INTRO_TEXT = """## PaliGemma WaveUI

Two models fine-tuned on the [WaveUI dataset](https://huggingface.co/datasets/agentsea/wave-ui), each from a different base:

- [paligemma-3b-ft-widgetcap-waveui-448](https://huggingface.co/agentsea/paligemma-3b-ft-widgetcap-waveui-448)
- [paligemma-3b-ft-waveui-896](https://huggingface.co/agentsea/paligemma-3b-ft-waveui-896)

Note:

- the task they were fine-tuned on was detection, so they may not generalize to other tasks.

Usage: write the task keyword "detect" before the element you want the model to detect. For example, "detect profile picture".
"""
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(INTRO_TEXT)
    with gr.Tab("Detection"):
        # Preselect the first model. With no default the dropdown can submit
        # None, which `models[model_choice]` would turn into a KeyError.
        model_choice = gr.Dropdown(
            label="Select Model",
            choices=list(MODEL_IDS.keys()),
            value=next(iter(MODEL_IDS)),
        )
        image = gr.Image(type="pil")
        seg_input = gr.Text(label="Detect instruction (e.g. 'detect sign in button')")
        seg_btn = gr.Button("Submit")
        annotated_image = gr.AnnotatedImage(label="Output")
        # Examples fill the image and prompt only; the model comes from the dropdown.
        examples = [["./airbnb.jpg", "detect 'Amazing pools' button"]]
        gr.Examples(
            examples=examples,
            inputs=[image, seg_input],
        )
    seg_inputs = [
        image,
        seg_input,
        model_choice,
    ]
    seg_outputs = [
        annotated_image,
    ]
    seg_btn.click(
        fn=parse_segmentation,
        inputs=seg_inputs,
        outputs=seg_outputs,
    )
# Matches one detection/segmentation record, optionally preceded by free text:
# four <locNNNN> tokens, an optional run of sixteen <segNNN> tokens, an
# optional label, and an optional "; " separator.
# Groups: 1 = preceding text; 2-5 = the four <loc> values; 6-21 = the sixteen
# optional <seg> values; 22 = the optional label.
_SEGMENT_DETECT_RE = re.compile(
    r'(.*?)' +
    r'<loc(\d{4})>' * 4 + r'\s*' +
    '(?:%s)?' % (r'<seg(\d{3})>' * 16) +
    r'\s*([^;<>]+)? ?(?:; )?',
)


def extract_objs(text, width, height, unique_labels=False):
    """Parse model output containing ``<loc>``/``<seg>`` tokens into objects.

    Args:
        text: Decoded model output.
        width: Image width in pixels; scales the x coordinates.
        height: Image height in pixels; scales the y coordinates.
        unique_labels: If true, a label already seen gets ``'`` appended
            until unique (``button``, ``button'``, ``button''``, ...).

    Returns:
        A list of dicts in input order. Each parsed record becomes
        ``{"content", "xyxy", "mask", "name"}``:

        - ``content`` is the raw matched text.
        - ``xyxy`` is ``(x1, y1, x2, y2)`` in pixels.
        - ``mask`` is always ``None``; ``<seg>`` tokens are matched but not
          decoded.
        - ``name`` is the label with trailing whitespace removed, or None.

        Text outside records is kept as ``{"content": ...}``.
    """
    objs = []
    seen = set()
    while text:
        m = _SEGMENT_DETECT_RE.match(text)
        if not m:
            break
        before, *codes, name = m.groups()
        # <loc> values are normalised by 1024 and unpacked in y1, x1, y2, x2 order.
        y1, x1, y2, x2 = (int(v) / 1024 for v in codes[:4])
        y1, x1, y2, x2 = map(round, (y1 * height, x1 * width, y2 * height, x2 * width))
        if name is not None:
            # The label group runs up to ';' or '<' and so can carry trailing
            # spaces/newlines; strip them so labels display and dedupe cleanly.
            name = name.rstrip()
        content = m.group()
        if before:
            objs.append(dict(content=before))
            content = content[len(before):]
        while unique_labels and name in seen:
            name = (name or '') + "'"
        seen.add(name)
        objs.append(dict(
            content=content, xyxy=(x1, y1, x2, y2), mask=None, name=name))
        # The match is anchored at 0, so m.end() covers preceding text + record.
        text = text[m.end():]
    if text:
        objs.append(dict(content=text))
    return objs
#########
if __name__ == "__main__":
    # Allow at most 10 pending requests in the queue, then serve the app
    # with debug=True.
    queued_demo = demo.queue(max_size=10)
    queued_demo.launch(debug=True)