"""Gradio demo: box-prompted image segmentation with SAM 2.

Draw bounding boxes on an image, run the selected SAM 2 checkpoint, and
download the predicted masks as a zip archive.
"""

import pathlib
import zipfile
from typing import Any, Dict, List

import cv2
import gradio as gr
import numpy as np
import torch
from gradio_image_annotation import image_annotator
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

from src.plot_utils import render_masks

# Maps each checkpoint size to its [model config, checkpoint path] pair.
choice_mapping: Dict[str, List[str]] = {
    "tiny": ["sam2_hiera_t.yaml", "assets/checkpoints/sam2_hiera_tiny.pt"],
    "small": ["sam2_hiera_s.yaml", "assets/checkpoints/sam2_hiera_small.pt"],
    "base_plus": ["sam2_hiera_b+.yaml", "assets/checkpoints/sam2_hiera_base_plus.pt"],
    "large": ["sam2_hiera_l.yaml", "assets/checkpoints/sam2_hiera_large.pt"],
}


def predict(model_choice: str, annotations: Dict[str, Any]):
    """Segment the annotated image with the chosen SAM 2 checkpoint."""
    config_file, ckpt_path = choice_mapping[model_choice]
    device = "cuda" if torch.cuda.is_available() else "cpu"
    sam2_model = build_sam2(config_file, ckpt_path, device=device)
    predictor = SAM2ImagePredictor(sam2_model)
    predictor.set_image(annotations["image"])

    # Collect every user-drawn box as [xmin, ymin, xmax, ymax].
    coordinates = [
        [int(box["xmin"]), int(box["ymin"]), int(box["xmax"]), int(box["ymax"])]
        for box in annotations["boxes"]
    ]
    if not coordinates:
        raise gr.Error("Draw at least one bounding box before running the model.")

    masks, _, _ = predictor.predict(
        point_coords=None,
        point_labels=None,
        box=np.array(coordinates),
        multimask_output=False,
    )

    # Remove masks left over from a previous run so they don't end up in the zip.
    mask_dir = pathlib.Path("assets")
    for stale_mask in mask_dir.glob("mask_*.png"):
        stale_mask.unlink()

    for count, mask in enumerate(masks):
        # The predictor yields an (H, W) mask for a single box prompt and a
        # (1, H, W) mask per box for batched prompts; squeeze handles both.
        mask_image = (np.squeeze(mask) * 255).astype(np.uint8)
        cv2.imwrite(f"assets/mask_{count}.png", mask_image)

    with zipfile.ZipFile("assets/masks.zip", "w") as archive:
        for mask_file in mask_dir.glob("mask_*.png"):
            archive.write(mask_file, arcname=mask_file.relative_to(mask_dir))

    return [
        render_masks(annotations["image"], masks),
        gr.DownloadButton("Download Mask", value="assets/masks.zip", visible=True),
    ]


with gr.Blocks(delete_cache=(30, 30)) as demo:
    gr.Markdown(
        """
        # 1. Choose Model Checkpoint
        """
    )
    with gr.Row():
        model = gr.Dropdown(
            choices=["tiny", "small", "base_plus", "large"],
            value="tiny",
            label="Model Checkpoint",
            info="Which model checkpoint to load?",
        )
    gr.Markdown(
        """
        # 2. Upload your image and draw a bounding box
        """
    )
    annotator = image_annotator(
        # cv2.imread returns BGR; convert to RGB so the annotator and SAM 2
        # both see the image with correct colors.
        value={"image": cv2.cvtColor(cv2.imread("assets/example.png"), cv2.COLOR_BGR2RGB)},
        disable_edit_boxes=True,
        label="Draw a bounding box",
    )
    btn = gr.Button("Get Segmentation Mask")
    download_btn = gr.DownloadButton(
        "Download Mask", value="assets/masks.zip", visible=False
    )
    btn.click(fn=predict, inputs=[model, annotator], outputs=[gr.Plot(), download_btn])

demo.launch()