import os
from typing import Any, Dict

import cv2
import gradio as gr
import numpy as np
import torch
from gradio_image_annotation import image_annotator
from sam2 import load_model
from sam2.sam2_image_predictor import SAM2ImagePredictor

from src.plot_utils import export_mask

from spaces import GPU

os.environ["ZEROGPU_V2"] = "true"

EXAMPLE_IMAGE_PATH = "assets/example.png"


@GPU()
def predict(model_choice: str, annotations: Dict[str, Any]):
    """Segment the annotated image with SAM 2, using the drawn boxes as prompts.

    Args:
        model_choice: Checkpoint variant name. It is passed to ``load_model``
            as ``variant`` and also selects the checkpoint file
            ``assets/checkpoints/sam2_hiera_{model_choice}.pt``.
        annotations: Payload of the image annotator. Must hold the image under
            ``"image"`` and a list of boxes under ``"boxes"``, each box being a
            mapping with ``"xmin"``, ``"ymin"``, ``"xmax"`` and ``"ymax"``.

    Returns:
        Whatever ``export_mask`` produces for the predicted masks.

    Raises:
        gr.Error: If no annotation payload or no bounding box was provided.
    """
    # Fail fast, before the (expensive) model load, with a message the UI can
    # show instead of an opaque error from deep inside the predictor.
    if not annotations or not annotations.get("boxes"):
        raise gr.Error("Please draw at least one bounding box on the image.")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    sam2_model = load_model(
        variant=model_choice,
        ckpt_path=f"assets/checkpoints/sam2_hiera_{model_choice}.pt",
        device=device,
    )
    predictor = SAM2ImagePredictor(sam2_model)  # type:ignore
    predictor.set_image(annotations["image"])

    # One [xmin, ymin, xmax, ymax] row per drawn box, truncated to integers.
    coordinates = [
        [
            int(box["xmin"]),
            int(box["ymin"]),
            int(box["xmax"]),
            int(box["ymax"]),
        ]
        for box in annotations["boxes"]
    ]

    masks, _, _ = predictor.predict(
        point_coords=None,
        point_labels=None,
        box=np.array(coordinates),
        multimask_output=False,
    )

    # Handle single mask cases: add a leading axis so that export_mask gets the
    # same rank as in the multi-box case.
    # NOTE(review): presumably the predictor drops the per-box axis when only
    # one box is given -- confirm against the installed sam2 version.
    if masks.shape[0] == 1:
        masks = np.expand_dims(masks, axis=0)

    return export_mask(masks)


def _load_example_image(path: str):
    """Read the example image from *path* as an RGB array, or None if unreadable."""
    bgr_image = cv2.imread(path)
    if bgr_image is None:
        # cv2.imread signals a missing/unreadable file by returning None.
        return None
    # OpenCV decodes to BGR, while the annotator (like other Gradio image
    # components) treats numpy arrays as RGB; convert so colours are not swapped.
    return cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)


with gr.Blocks(delete_cache=(30, 30)) as demo:
    gr.Markdown(""" # 1. Choose Model Checkpoint """)
    with gr.Row():
        model = gr.Dropdown(
            choices=["tiny", "small", "base_plus", "large"],
            value="tiny",
            label="Model Checkpoint",
            info="Which model checkpoint to load?",
        )

    gr.Markdown(""" # 2. Upload your Image and draw bounding box(es) """)
    annotator = image_annotator(
        value={"image": _load_example_image(EXAMPLE_IMAGE_PATH)},
        disable_edit_boxes=True,
        label="Draw a bounding box",
    )

    btn = gr.Button("Get Segmentation Mask(s)")
    # Run the segmentation on click and show the result in a fresh Image output.
    btn.click(
        fn=predict,
        inputs=[model, annotator],
        outputs=[gr.Image(label="Mask(s)")],
    )

demo.launch()