SkalskiP's picture
add examples and improve SAM2AutomaticMaskGenerator results
16d828f
raw
history blame
5.49 kB
from typing import Optional
import gradio as gr
import numpy as np
import supervision as sv
import torch
from PIL import Image
from gradio_image_prompter import ImagePrompter
from utils.models import load_models, CHECKPOINT_NAMES, MODE_NAMES, \
MASK_GENERATION_MODE, BOX_PROMPT_MODE
# Header rendered at the top of the demo page.
MARKDOWN = """
# Segment Anything Model 2 🔥
<div>
<a href="https://github.com/facebookresearch/segment-anything-2">
<img src="https://badges.aleen42.com/src/github.svg" alt="GitHub" style="display:inline-block;">
</a>
<a href="https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/how-to-segment-images-with-sam-2.ipynb">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Colab" style="display:inline-block;">
</a>
<a href="https://blog.roboflow.com/what-is-segment-anything-2/">
<img src="https://raw.githubusercontent.com/roboflow-ai/notebooks/main/assets/badges/roboflow-blogpost.svg" alt="Roboflow" style="display:inline-block;">
</a>
<a href="https://www.youtube.com/watch?v=Dv003fTyO-Y">
<img src="https://badges.aleen42.com/src/youtube.svg" alt="YouTube" style="display:inline-block;">
</a>
</div>
Segment Anything Model 2 (SAM 2) is a foundation model designed to address promptable
visual segmentation in both images and videos. **Video segmentation will be available
soon.**
"""

# Example rows for gr.Examples. Each row follows the input order of `process`:
# (checkpoint, mode, uploaded image, image-prompter value). All examples use
# the "tiny" checkpoint in mask-generation mode, so the prompter slot is None.
EXAMPLES = [
    ["tiny", MASK_GENERATION_MODE, example_url, None]
    for example_url in (
        "https://media.roboflow.com/notebooks/examples/dog-2.jpeg",
        "https://media.roboflow.com/notebooks/examples/dog-3.jpeg",
        "https://media.roboflow.com/notebooks/examples/dog-4.jpeg",
    )
]

# Use the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')

# A single shared annotator that colors each detection by its index.
MASK_ANNOTATOR = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX)

# Load every checkpoint once at startup. Both dictionaries are indexed by the
# checkpoint name chosen in the UI dropdown.
IMAGE_PREDICTORS, MASK_GENERATORS = load_models(device=DEVICE)
def _boxes_from_prompt(prompt) -> np.ndarray:
    """Convert ImagePrompter records into an (N, 4) array of xyxy boxes.

    Each record is unpacked as ``(x1, y1, _, x2, y2, _)``, the same way the
    original implementation did. The corners are then reordered so that
    ``x_min <= x_max`` and ``y_min <= y_max``. This makes the result
    independent of the direction in which the box was dragged.

    NOTE(review): every record is treated as a box. If the prompter can also
    emit single-point records, they will yield degenerate boxes — confirm
    against the gradio_image_prompter record format.
    """
    corners = np.array([[x1, y1, x2, y2] for x1, y1, _, x2, y2, _ in prompt])
    first, second = corners[:, :2], corners[:, 2:]
    return np.hstack([np.minimum(first, second), np.maximum(first, second)])


def _segment_with_boxes(checkpoint_dropdown, image_prompter_input) -> Optional[Image.Image]:
    """Run box-prompted SAM2 prediction and overlay the resulting masks."""
    # Nothing uploaded yet: there is no image to annotate.
    if not image_prompter_input or image_prompter_input.get("image") is None:
        return None
    image_input = image_prompter_input["image"]
    prompt = image_prompter_input.get("points") or []
    # An image without drawn boxes is echoed back unchanged.
    if len(prompt) == 0:
        return image_input

    model = IMAGE_PREDICTORS[checkpoint_dropdown]
    model.set_image(np.array(image_input.convert("RGB")))
    masks, _, _ = model.predict(
        box=_boxes_from_prompt(prompt), multimask_output=False)
    # With multimask_output=False there is one mask per box. When several boxes
    # are passed, the predictor returns a 4-D array with a singleton
    # per-prompt mask axis at position 1. Drop only that axis.
    if masks.ndim == 4:
        masks = np.squeeze(masks, axis=1)
    masks = masks.astype(bool)

    detections = sv.Detections(xyxy=sv.mask_to_xyxy(masks=masks), mask=masks)
    return MASK_ANNOTATOR.annotate(image_input, detections)


def _segment_everything(checkpoint_dropdown, image_input) -> Optional[Image.Image]:
    """Run SAM2 automatic mask generation over the whole image and overlay it."""
    # Nothing uploaded yet: there is no image to segment.
    if image_input is None:
        return None
    model = MASK_GENERATORS[checkpoint_dropdown]
    result = model.generate(np.array(image_input.convert("RGB")))
    detections = sv.Detections.from_sam(result)
    return MASK_ANNOTATOR.annotate(image_input, detections)


def process(
    checkpoint_dropdown,
    mode_dropdown,
    image_input,
    image_prompter_input
) -> Optional[Image.Image]:
    """Segment the user's image with SAM2 according to the selected mode.

    Args:
        checkpoint_dropdown: checkpoint name. It is used as the key into
            IMAGE_PREDICTORS (box mode) or MASK_GENERATORS (generation mode).
        mode_dropdown: BOX_PROMPT_MODE or MASK_GENERATION_MODE.
        image_input: PIL image from the plain uploader. It is used only in
            mask-generation mode.
        image_prompter_input: ImagePrompter value, a dict with "image" (PIL)
            and "points" (box records). It is used only in box-prompt mode.

    Returns:
        The image with masks drawn by MASK_ANNOTATOR, or the unannotated
        image when no boxes were drawn. Returns None when no image was
        provided or the mode is not recognised.
    """
    if mode_dropdown == BOX_PROMPT_MODE:
        return _segment_with_boxes(checkpoint_dropdown, image_prompter_input)
    if mode_dropdown == MASK_GENERATION_MODE:
        return _segment_everything(checkpoint_dropdown, image_input)
    return None
with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)

    # Top row: choose which SAM2 checkpoint and which segmentation mode to run.
    with gr.Row():
        checkpoint_dropdown_component = gr.Dropdown(
            label="Checkpoint",
            info="Select a SAM2 checkpoint to use.",
            choices=CHECKPOINT_NAMES,
            value=CHECKPOINT_NAMES[0],
            interactive=True,
        )
        mode_dropdown_component = gr.Dropdown(
            label="Mode",
            info=(
                "Select a mode to use. `box prompt` if you want to generate masks for "
                "selected objects, `mask generation` if you want to generate masks "
                "for the whole image."
            ),
            choices=MODE_NAMES,
            value=MODE_NAMES[0],
            interactive=True,
        )

    # Main row: inputs and the submit button on the left, the result on the right.
    # The plain uploader starts hidden; the mode dropdown toggles which input shows.
    with gr.Row():
        with gr.Column():
            image_input_component = gr.Image(
                type='pil', label='Upload image', visible=False)
            image_prompter_input_component = ImagePrompter(
                type='pil', label='Image prompt')
            submit_button_component = gr.Button(
                value='Submit', variant='primary')
        with gr.Column():
            image_output_component = gr.Image(type='pil', label='Image Output')

    # The examples gallery and the submit button both call `process`. The
    # inputs are listed in the order of its positional parameters.
    process_inputs = [
        checkpoint_dropdown_component,
        mode_dropdown_component,
        image_input_component,
        image_prompter_input_component,
    ]
    process_outputs = [image_output_component]

    with gr.Row():
        gr.Examples(
            fn=process,
            examples=EXAMPLES,
            inputs=process_inputs,
            outputs=process_outputs,
            run_on_click=True
        )

    def on_mode_dropdown_change(selected_mode):
        """Show the uploader for mask generation and the prompter for box prompts."""
        show_uploader = selected_mode == MASK_GENERATION_MODE
        show_prompter = selected_mode == BOX_PROMPT_MODE
        return [
            gr.Image(visible=show_uploader),
            ImagePrompter(visible=show_prompter),
        ]

    mode_dropdown_component.change(
        on_mode_dropdown_change,
        inputs=[mode_dropdown_component],
        outputs=[image_input_component, image_prompter_input_component],
    )
    submit_button_component.click(
        fn=process,
        inputs=process_inputs,
        outputs=process_outputs,
    )

# Serve one request at a time. `process` calls set_image() on a shared
# predictor before predicting, so concurrent requests could interleave.
demo.launch(debug=False, show_error=True, max_threads=1)