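"""Gradio demo: box-prompted image segmentation with Meta's SAM 2.

Draw bounding boxes on an uploaded image; each box is fed to a
SAM2ImagePredictor as a prompt, and the resulting masks are rendered as an
overlay and offered as a zip download.
"""
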
import pathlib
import zipfile
from typing import Any, Dict, List

import cv2
import gradio as gr
import numpy as np
import torch
from gradio_image_annotation import image_annotator
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

from src.plot_utils import render_masks

choice_mapping: Dict[str, List[str]] = {
    "tiny": ["sam2_hiera_t.yaml", "assets/checkpoints/sam2_hiera_tiny.pt"],
    "small": ["sam2_hiera_s.yaml", "assets/checkpoints/sam2_hiera_small.pt"],
    "base_plus": ["sam2_hiera_b+.yaml", "assets/checkpoints/sam2_hiera_base_plus.pt"],
    "large": ["sam2_hiera_l.yaml", "assets/checkpoints/sam2_hiera_large.pt"],
}
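# NOTE: this app assumes the checkpoint files above were downloaded to
# assets/checkpoints/ ahead of time (e.g. with the download script shipped in
# the SAM 2 repository); nothing here fetches them.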


def predict(model_choice, annotations: Dict[str, Any]):
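    """Segment the annotated image with SAM 2, returning one mask per box.

    ``model_choice`` is a key of ``choice_mapping``; ``annotations`` is the
    dict produced by ``image_annotator`` ("image" plus a list of "boxes").
    Returns the rendered mask overlay and a now-visible download button.
    """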
    config_file, ckpt_path = choice_mapping[str(model_choice)]
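    # A fresh model is built on each request so the user can switch
    # checkpoints between runs; prefer the GPU when one is available.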
    device = "cuda" if torch.cuda.is_available() else "cpu"
    sam2_model = build_sam2(config_file, ckpt_path, device=device)
    predictor = SAM2ImagePredictor(sam2_model)
    predictor.set_image(annotations["image"])
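    # The image embedding computed by set_image above is reused for every box
    # prompt below. Collect each drawn box as [xmin, ymin, xmax, ymax].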
    coordinates = [
        [int(box["xmin"]), int(box["ymin"]), int(box["xmax"]), int(box["ymax"])]
        for box in annotations["boxes"]
    ]

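    # One prompt per box; multimask_output=False yields a single mask per box.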
    masks, scores, _ = predictor.predict(
        point_coords=None,
        point_labels=None,
        box=np.array(coordinates),
        multimask_output=False,
    )
    # Save each mask as a PNG, keeping track of the files written this run.
    mask_dir = pathlib.Path("assets/")
    mask_paths: List[pathlib.Path] = []
    for count, mask in enumerate(masks):
        mask = mask.transpose(1, 2, 0)  # type:ignore  # (C, H, W) -> (H, W, C)
        mask_image = (mask * 255).astype(np.uint8)  # Convert to uint8 format
        mask_path = mask_dir / f"mask_{count}.png"
        cv2.imwrite(str(mask_path), mask_image)
        mask_paths.append(mask_path)

    # Build the archive once, after the loop, instead of rewriting it on every
    # iteration; zipping only this run's files also avoids bundling stale
    # masks left over from a previous request with more boxes.
    with zipfile.ZipFile("assets/masks.zip", "w") as archive:
        for mask_path in mask_paths:
            archive.write(mask_path, arcname=mask_path.relative_to(mask_dir))

    return [
        render_masks(annotations["image"], masks),
        gr.DownloadButton("Download Mask", value="assets/masks.zip", visible=True),
    ]


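# Build the UI: a model picker, an annotator for drawing boxes, and a button
# that wires both into predict(). delete_cache=(30, 30) tells Gradio to sweep
# its cached files every 30 seconds, deleting those older than 30 seconds.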
with gr.Blocks(delete_cache=(30, 30)) as demo:
    gr.Markdown(
        """
        # 1. Choose Model Checkpoint
        """
    )
    with gr.Row():
        model = gr.Dropdown(
            choices=["tiny", "small", "base_plus", "large"],
            value="tiny",
            label="Model Checkpoint",
            info="Which model checkpoint to load?",
        )

    gr.Markdown(
        """
        # 2. Upload your image and draw one or more bounding boxes
        """
    )

    annotator = image_annotator(
        value={"image": cv2.imread("assets/example.png")},
        disable_edit_boxes=True,
        label="Draw a bounding box",
    )
    btn = gr.Button("Get Segmentation Mask")
    download_btn = gr.DownloadButton(
        "Download Mask", value="assets/masks.zip", visible=False
    )
    plot = gr.Plot()
    btn.click(fn=predict, inputs=[model, annotator], outputs=[plot, download_btn])

demo.launch()
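# To try it locally (assuming the SAM 2 package, the gradio_image_annotation
# component, and the checkpoints under assets/checkpoints/ are installed):
#   python app.py  # hypothetical filename for this script
# then open the local URL that Gradio prints.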