# Page banner captured together with the file (not part of the program):
# Spaces:
# Runtime error
# Runtime error
from typing import Optional | |
import gradio as gr | |
import numpy as np | |
import supervision as sv | |
import torch | |
from PIL import Image | |
from gradio_image_prompter import ImagePrompter | |
from utils.models import load_models, CHECKPOINT_NAMES, MODE_NAMES, \ | |
MASK_GENERATION_MODE, BOX_PROMPT_MODE | |
import spaces | |
# Header rendered at the top of the demo via `gr.Markdown(MARKDOWN)`.
# Blank lines surround the raw-HTML badge block so the Markdown parser ends the
# HTML block before the closing paragraph; otherwise the paragraph (including
# its `**bold**` markup) is treated as part of the HTML and shown literally.
MARKDOWN = """
# Segment Anything Model 2 🔥

<div>
<a href="https://github.com/facebookresearch/segment-anything-2">
<img src="https://badges.aleen42.com/src/github.svg" alt="GitHub" style="display:inline-block;">
</a>
<a href="https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/how-to-segment-images-with-sam-2.ipynb">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Colab" style="display:inline-block;">
</a>
<a href="https://blog.roboflow.com/what-is-segment-anything-2/">
<img src="https://raw.githubusercontent.com/roboflow-ai/notebooks/main/assets/badges/roboflow-blogpost.svg" alt="Roboflow" style="display:inline-block;">
</a>
<a href="https://www.youtube.com/watch?v=Dv003fTyO-Y">
<img src="https://badges.aleen42.com/src/youtube.svg" alt="YouTube" style="display:inline-block;">
</a>
</div>

Segment Anything Model 2 (SAM 2) is a foundation model designed to address promptable
visual segmentation in both images and videos. **Video segmentation will be available
soon.**
"""
# Sample images offered in the examples gallery.
_EXAMPLE_IMAGE_URLS = (
    "https://media.roboflow.com/notebooks/examples/dog-2.jpeg",
    "https://media.roboflow.com/notebooks/examples/dog-3.jpeg",
    "https://media.roboflow.com/notebooks/examples/dog-4.jpeg",
)

# One row per example, in the order the UI passes values to `process`:
# [checkpoint, mode, image, image prompt]. Every example uses the "tiny"
# checkpoint in mask-generation mode and supplies no image prompt.
EXAMPLES = [
    ["tiny", MASK_GENERATION_MODE, image_url, None]
    for image_url in _EXAMPLE_IMAGE_URLS
]
# Device handed to `load_models`; this module is hard-wired to CUDA.
DEVICE = torch.device('cuda')

# Enter a CUDA bfloat16 autocast context and intentionally never exit it, so it
# stays active after this module has finished importing.
# NOTE(review): PyTorch documents autocast state as thread-local, so this
# presumably covers only the importing thread — confirm for handler threads.
_AUTOCAST_CONTEXT = torch.autocast(device_type="cuda", dtype=torch.bfloat16)
_AUTOCAST_CONTEXT.__enter__()

# Allow TF32 for CUDA matmul and cuDNN when device 0 reports a compute
# capability major version of 8 or higher.
_DEVICE_PROPERTIES = torch.cuda.get_device_properties(0)
if _DEVICE_PROPERTIES.major >= 8:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

# Shared annotator; masks are colored by their index within the detections.
MASK_ANNOTATOR = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX)

# Load the models once at import time; `process` looks both mappings up by
# checkpoint name.
IMAGE_PREDICTORS, MASK_GENERATORS = load_models(device=DEVICE)
# Autocast state is thread-local, so the autocast context entered at module
# import time does not reach the worker thread that runs this handler; apply it
# here so inference actually uses the bfloat16 autocast the module sets up.
# NOTE(review): `spaces` is imported at file level but never used; if this app
# runs on ZeroGPU hardware this handler likely also needs `@spaces.GPU` —
# confirm against the deployment target.
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
def process(
    checkpoint_dropdown: str,
    mode_dropdown: str,
    image_input: Optional[Image.Image],
    image_prompter_input: Optional[dict]
) -> Optional[Image.Image]:
    """Segment an image with the selected checkpoint and return it annotated.

    Args:
        checkpoint_dropdown: Key used to look up the model in
            `IMAGE_PREDICTORS` / `MASK_GENERATORS`.
        mode_dropdown: `BOX_PROMPT_MODE` to segment user-drawn boxes, or
            `MASK_GENERATION_MODE` to generate masks for the whole image.
        image_input: Image used in mask-generation mode. Ignored in box-prompt
            mode, where the image comes from `image_prompter_input`.
        image_prompter_input: Mapping with an `"image"` entry and a `"points"`
            entry. Each point is unpacked as `(x1, y1, _, x2, y2, _)`; the
            third and sixth values are ignored.

    Returns:
        The annotated image. In box-prompt mode with no prompt entries, the
        prompter image is returned unchanged. `None` is returned when the
        required image / prompter payload is missing or the mode is not
        recognized.
    """
    if mode_dropdown == BOX_PROMPT_MODE:
        # Nothing was uploaded to the prompter yet: no image to work on.
        if not image_prompter_input:
            return None

        image_input = image_prompter_input["image"]
        prompt = image_prompter_input["points"]
        if image_input is None or len(prompt) == 0:
            return image_input

        model = IMAGE_PREDICTORS[checkpoint_dropdown]
        image = np.array(image_input.convert("RGB"))
        box = np.array([[x1, y1, x2, y2] for x1, y1, _, x2, y2, _ in prompt])

        model.set_image(image)
        masks, _, _ = model.predict(box=box, multimask_output=False)

        # A 4-D result carries a singleton axis at position 1 because
        # multimask_output=False. Drop only that axis: an unrestricted squeeze
        # would also collapse any other size-1 axis and break the 3-D mask
        # layout that `sv.Detections` receives below.
        if masks.ndim == 4:
            masks = np.squeeze(masks, axis=1)

        detections = sv.Detections(
            xyxy=sv.mask_to_xyxy(masks=masks),
            mask=masks.astype(bool)
        )
        return MASK_ANNOTATOR.annotate(image_input, detections)

    if mode_dropdown == MASK_GENERATION_MODE:
        # Submit pressed before an image was uploaded.
        if image_input is None:
            return None

        model = MASK_GENERATORS[checkpoint_dropdown]
        image = np.array(image_input.convert("RGB"))
        result = model.generate(image)
        detections = sv.Detections.from_sam(result)
        return MASK_ANNOTATOR.annotate(image_input, detections)

    # Unknown mode: nothing to render.
    return None
# Build the demo page: header, checkpoint/mode selectors, input and output
# panels, example gallery, and the event wiring that connects them to `process`.
with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)

    # Checkpoint and mode selectors, side by side.
    with gr.Row():
        checkpoint_dropdown_component = gr.Dropdown(
            label="Checkpoint",
            info="Select a SAM2 checkpoint to use.",
            choices=CHECKPOINT_NAMES,
            value=CHECKPOINT_NAMES[0],
            interactive=True,
        )
        mode_dropdown_component = gr.Dropdown(
            label="Mode",
            info=(
                "Select a mode to use. `box prompt` if you want to generate masks for "
                "selected objects, `mask generation` if you want to generate masks "
                "for the whole image."
            ),
            choices=MODE_NAMES,
            value=MODE_NAMES[0],
            interactive=True,
        )

    # Left column: inputs and the submit button. Right column: the result.
    with gr.Row():
        with gr.Column():
            image_input_component = gr.Image(type='pil', label='Upload image')
            # Starts hidden; the mode dropdown handler below toggles which of
            # the two image inputs is visible.
            image_prompter_input_component = ImagePrompter(
                type='pil',
                label='Image prompt',
                visible=False,
            )
            submit_button_component = gr.Button(value='Submit', variant='primary')
        with gr.Column():
            image_output_component = gr.Image(type='pil', label='Image Output')

    # The examples gallery and the submit button feed `process` identically.
    process_inputs = [
        checkpoint_dropdown_component,
        mode_dropdown_component,
        image_input_component,
        image_prompter_input_component,
    ]
    process_outputs = [image_output_component]

    with gr.Row():
        gr.Examples(
            fn=process,
            examples=EXAMPLES,
            inputs=process_inputs,
            outputs=process_outputs,
            cache_examples=False,
            run_on_click=True,
        )

    def on_mode_dropdown_change(text):
        """Return visibility updates for the plain image input and the prompter.

        The plain image input is shown for `MASK_GENERATION_MODE`, the image
        prompter for `BOX_PROMPT_MODE`.
        """
        show_image_input = text == MASK_GENERATION_MODE
        show_image_prompter = text == BOX_PROMPT_MODE
        return [
            gr.Image(visible=show_image_input),
            ImagePrompter(visible=show_image_prompter),
        ]

    mode_dropdown_component.change(
        fn=on_mode_dropdown_change,
        inputs=[mode_dropdown_component],
        outputs=[image_input_component, image_prompter_input_component],
    )
    submit_button_component.click(
        fn=process,
        inputs=process_inputs,
        outputs=process_outputs,
    )

demo.launch(debug=False, show_error=True)