|
import pathlib |
|
import zipfile |
|
from typing import Any, Dict, List |
|
|
|
import cv2 |
|
import gradio as gr |
|
import numpy as np |
|
import torch |
|
from gradio_image_annotation import image_annotator |
|
from sam2.build_sam import build_sam2 |
|
from sam2.sam2_image_predictor import SAM2ImagePredictor |
|
|
|
from src.plot_utils import render_masks |
|
|
|
# Maps each selectable checkpoint size to a two-element list:
# [model config file name, path to the checkpoint weights].
# ``predict`` unpacks the list positionally, so the order inside each entry matters.
choice_mapping: Dict[str, List[str]] = {
    size: [config_file, checkpoint_path]
    for size, config_file, checkpoint_path in (
        ("tiny", "sam2_hiera_t.yaml", "assets/checkpoints/sam2_hiera_tiny.pt"),
        ("small", "sam2_hiera_s.yaml", "assets/checkpoints/sam2_hiera_small.pt"),
        ("base_plus", "sam2_hiera_b+.yaml", "assets/checkpoints/sam2_hiera_base_plus.pt"),
        ("large", "sam2_hiera_l.yaml", "assets/checkpoints/sam2_hiera_large.pt"),
    )
}
|
|
|
|
|
def predict(model_choice: str, annotations: Dict[str, Any]) -> List[Any]:
    """Segment every drawn bounding box with SAM 2 and package the masks.

    Args:
        model_choice: Key into ``choice_mapping`` selecting the model config
            and checkpoint to load.
        annotations: Annotator value; a dict holding the image under
            ``"image"`` and a list of boxes under ``"boxes"``, where each box
            provides ``xmin``, ``ymin``, ``xmax`` and ``ymax``.

    Returns:
        A two-element list: the result of ``render_masks`` for the image and
        predicted masks, and a visible ``gr.DownloadButton`` whose value is
        ``assets/masks.zip``.

    Raises:
        gr.Error: If no image is present or no bounding box was drawn.
        KeyError: If ``model_choice`` is not a key of ``choice_mapping``.
    """
    # Validate the cheap inputs first so a bad request never pays for loading
    # a checkpoint.
    if not annotations or annotations.get("image") is None:
        raise gr.Error("Please upload an image first.")
    boxes = annotations.get("boxes") or []
    if not boxes:
        raise gr.Error("Please draw at least one bounding box.")

    image = annotations["image"]
    box_coords = np.array(
        [
            [int(box[key]) for key in ("xmin", "ymin", "xmax", "ymax")]
            for box in boxes
        ]
    )

    config_file, ckpt_path = choice_mapping[str(model_choice)]
    device = "cuda" if torch.cuda.is_available() else "cpu"
    sam2_model = build_sam2(config_file, ckpt_path, device=device)
    predictor = SAM2ImagePredictor(sam2_model)
    predictor.set_image(image)

    masks, _, _ = predictor.predict(
        point_coords=None,
        point_labels=None,
        box=box_coords,
        multimask_output=False,
    )

    mask_dir = pathlib.Path("assets/")
    mask_dir.mkdir(parents=True, exist_ok=True)

    written_masks = []
    for count, mask in enumerate(masks):
        # NOTE(review): SAM 2 appears to return (1, H, W) for a single box but
        # (N, 1, H, W) for several - confirm against the installed version.
        # Keeping only the last two axes yields an (H, W) image in both cases,
        # whereas a fixed 3-axis transpose fails on a 2-D mask.
        mask_2d = mask.reshape(mask.shape[-2:])
        mask_file = mask_dir / f"mask_{count}.png"
        cv2.imwrite(str(mask_file), (mask_2d * 255).astype(np.uint8))
        written_masks.append(mask_file)

    # Archive only the files written by this call. Globbing the directory
    # would also pick up masks left behind by an earlier run with more boxes.
    with zipfile.ZipFile("assets/masks.zip", "w") as archive:
        for mask_file in written_masks:
            archive.write(mask_file, arcname=mask_file.name)

    return [
        render_masks(image, masks),
        gr.DownloadButton("Download Mask(s)", value="assets/masks.zip", visible=True),
    ]
|
|
|
|
|
# Gradio UI: pick a checkpoint, draw boxes on an image, then run ``predict``.
with gr.Blocks(delete_cache=(30, 30)) as demo:
    gr.Markdown(
        """
        # 1. Choose Model Checkpoint
        """
    )
    with gr.Row():
        model = gr.Dropdown(
            # Derived from the lookup table so the dropdown can never offer a
            # size that ``predict`` cannot resolve.
            choices=list(choice_mapping),
            value="tiny",
            label="Model Checkpoint",
            info="Which model checkpoint to load?",
        )

    gr.Markdown(
        """
        # 2. Upload your Image and draw bounding box(es)
        """
    )

    # OpenCV decodes images as BGR, so convert to RGB before handing the array
    # to the annotator; otherwise the example is shown (and later segmented)
    # with red and blue swapped. ``imread`` returns None for an unreadable
    # file, in which case the annotator simply starts empty.
    example_bgr = cv2.imread("assets/example.png")
    example_value = (
        None
        if example_bgr is None
        else {"image": cv2.cvtColor(example_bgr, cv2.COLOR_BGR2RGB)}
    )

    annotator = image_annotator(
        value=example_value,
        disable_edit_boxes=True,
        label="Draw a bounding box",
    )
    btn = gr.Button("Get Segmentation Mask(s)")
    # Hidden until ``predict`` returns a visible replacement pointing at the
    # freshly written archive.
    download_btn = gr.DownloadButton(
        "Download Mask(s)", value="assets/masks.zip", visible=False
    )
    mask_plot = gr.Plot()
    btn.click(fn=predict, inputs=[model, annotator], outputs=[mask_plot, download_btn])

demo.launch()
|
|