File size: 2,896 Bytes
3d74359 8260e47 bf29adc 410698b 6447433 6038d30 bf29adc 630e69b 8a9151b 630e69b 410698b 8a9151b 410698b 6038d30 8260e47 6038d30 6447433 bf29adc 6447433 410698b 6038d30 6447433 bf29adc af62359 bf29adc 95190fc bf29adc 630e69b bf29adc 95190fc 410698b 8260e47 bf29adc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 |
from typing import Any, Dict
import cv2
import gradio as gr
import numpy as np
from gradio_image_annotation import image_annotator
from sam2 import load_model
from sam2.utils.visualization import show_masks
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
from sam2.sam2_image_predictor import SAM2ImagePredictor
# @spaces.GPU()
def predict(model_choice, annotations: Dict[str, Any]):
    """Segment the annotated image with SAM2 and return the rendered masks.

    If the annotation payload carries bounding boxes, every box is used as a
    prompt for ``SAM2ImagePredictor``. Otherwise the whole image is segmented
    with ``SAM2AutomaticMaskGenerator``.

    Args:
        model_choice: Checkpoint variant name. It is passed to ``load_model``
            as ``variant`` and interpolated into the checkpoint path
            ``assets/checkpoints/sam2_hiera_{model_choice}.pt``.
        annotations: Annotator payload. ``annotations["image"]`` is the image
            to segment. ``annotations["boxes"]``, when present and non-empty,
            is a sequence of mappings with ``xmin``, ``ymin``, ``xmax`` and
            ``ymax`` entries.

    Returns:
        The value produced by ``show_masks`` for the image and its masks.

    Raises:
        KeyError: If ``annotations`` has no ``"image"`` entry, or a box lacks
            one of its four coordinate keys.
    """
    # NOTE: the checkpoint is (re)loaded on every call, always on CPU.
    sam2_model = load_model(
        variant=model_choice,
        ckpt_path=f"assets/checkpoints/sam2_hiera_{model_choice}.pt",
        device="cpu",
    )
    image = annotations["image"]
    # A payload without a "boxes" key is treated like an empty box list
    # instead of raising KeyError.
    boxes = annotations.get("boxes")

    if not boxes:
        # No prompts drawn: segment everything in the image.
        mask_generator = SAM2AutomaticMaskGenerator(sam2_model)  # type: ignore
        masks = mask_generator.generate(image)
        return show_masks(
            image=image,
            masks=masks,  # type: ignore
            scores=None,
            only_best=False,
            autogenerated_mask=True,
        )

    predictor = SAM2ImagePredictor(sam2_model)  # type: ignore
    predictor.set_image(image)

    # One [xmin, ymin, xmax, ymax] row per drawn box, truncated to ints.
    coordinates = [
        [int(box["xmin"]), int(box["ymin"]), int(box["xmax"]), int(box["ymax"])]
        for box in boxes
    ]
    masks, scores, _ = predictor.predict(
        point_coords=None,
        point_labels=None,
        box=np.array(coordinates),
        multimask_output=False,
    )

    # Scores are only forwarded when exactly one score entry came back; with
    # several entries all masks are shown rather than just the best one.
    multi_box = len(scores) > 1
    return show_masks(
        image=image,
        masks=masks,
        scores=scores if len(scores) == 1 else None,
        only_best=not multi_box,
    )
# Build the Gradio UI: a checkpoint selector, an image annotator for drawing
# bounding boxes, and a button that runs `predict` and shows the result.
with gr.Blocks(delete_cache=(30, 30)) as demo:
    gr.Markdown(
        """
        ## To read more about the Segment Anything Project please refer to the [Lightly AI blogpost](https://www.lightly.ai/post/segment-anything-model-and-friends)
        """
    )
    gr.Markdown(
        """
        # 1. Choose Model Checkpoint
        """
    )
    with gr.Row():
        model = gr.Dropdown(
            choices=["tiny", "small", "base_plus", "large"],
            value="tiny",
            label="Model Checkpoint",
            info="Which model checkpoint to load?",
        )
    gr.Markdown(
        """
        # 2. Upload your Image and draw bounding box(es)
        """
    )
    # cv2.imread yields BGR channel order, while arrays handed to Gradio
    # components are interpreted as RGB; convert so the example image is
    # displayed (and later segmented) with the correct colours. A missing
    # file still results in None, as before.
    example_image = cv2.imread("assets/example.png")
    if example_image is not None:
        example_image = cv2.cvtColor(example_image, cv2.COLOR_BGR2RGB)
    annotator = image_annotator(
        value={"image": example_image},
        disable_edit_boxes=True,
        label="Draw a bounding box",
    )
    btn = gr.Button("Get Segmentation Mask(s)")
    # The output image component is created inline, so it renders below the
    # button inside this Blocks context.
    btn.click(
        fn=predict, inputs=[model, annotator], outputs=[gr.Image(label="Mask(s)")]
    )

demo.launch()
|