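"""Gradio demo for SAM 2 (Segment Anything Model 2): draw bounding boxes on an
image to get box-prompted segmentation masks, or submit without any boxes to
run the automatic mask generator over the whole image."""
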
from typing import Any, Dict

import cv2
import gradio as gr
import numpy as np
from gradio_image_annotation import image_annotator
from sam2 import load_model
from sam2.utils.visualization import show_masks
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
from sam2.sam2_image_predictor import SAM2ImagePredictor


# @spaces.GPU()  # uncomment (and `import spaces`) when deploying on ZeroGPU hardware
def predict(model_choice: str, annotations: Dict[str, Any]):
    """Segment the annotated image: box-prompted if boxes were drawn,
    otherwise via the automatic mask generator."""
    sam2_model = load_model(
        variant=model_choice,
        ckpt_path=f"assets/checkpoints/sam2_hiera_{model_choice}.pt",
        device="cpu",
    )
    if annotations["boxes"]:
        # One or more boxes were drawn: run box-prompted prediction.
        predictor = SAM2ImagePredictor(sam2_model)  # type: ignore
        predictor.set_image(annotations["image"])
        coordinates = [
            [int(box["xmin"]), int(box["ymin"]), int(box["xmax"]), int(box["ymax"])]
            for box in annotations["boxes"]
        ]

        masks, scores, _ = predictor.predict(
            point_coords=None,
            point_labels=None,
            box=np.array(coordinates),
            multimask_output=False,  # one mask per box
        )

        # With a single box, show only the best mask together with its score;
        # with several boxes, overlay every mask and omit the scores.
        multi_box = len(scores) > 1
        return show_masks(
            image=annotations["image"],
            masks=masks,
            scores=None if multi_box else scores,
            only_best=not multi_box,
        )

    else:
        # No boxes were drawn: segment everything in the image automatically.
        mask_generator = SAM2AutomaticMaskGenerator(sam2_model)  # type: ignore
        masks = mask_generator.generate(annotations["image"])
        return show_masks(
            image=annotations["image"],
            masks=masks,  # type: ignore
            scores=None,
            only_best=False,
            autogenerated_mask=True,
        )

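# A minimal sketch (not executed) of calling predict() directly, outside the
# Gradio UI. The dict layout mirrors what gradio_image_annotation emits; the
# box coordinates below are made-up illustrative values.
#
#   image = cv2.cvtColor(cv2.imread("assets/example.png"), cv2.COLOR_BGR2RGB)
#   overlay = predict(
#       "tiny",
#       {
#           "image": image,
#           "boxes": [{"xmin": 50, "ymin": 50, "xmax": 300, "ymax": 300}],
#       },
#   )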

with gr.Blocks(delete_cache=(30, 30)) as demo:
    gr.Markdown(
        """
        ## To read more about the Segment Anything project, please refer to the [Lightly AI blog post](https://www.lightly.ai/post/segment-anything-model-and-friends)
        """
    )
    gr.Markdown(
        """
        # 1. Choose Model Checkpoint
        """
    )
    with gr.Row():
        model = gr.Dropdown(
            choices=["tiny", "small", "base_plus", "large"],
            value="tiny",
            label="Model Checkpoint",
            info="Which model checkpoint to load?",
        )
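
    # Each choice maps to a checkpoint file loaded in predict(), e.g. the
    # "tiny" choice resolves to assets/checkpoints/sam2_hiera_tiny.pt.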

    gr.Markdown(
        """
        # 2. Upload Your Image and Draw Bounding Box(es)
        """
    )

    annotator = image_annotator(
        # cv2.imread returns BGR; convert to RGB so colors render correctly
        value={"image": cv2.cvtColor(cv2.imread("assets/example.png"), cv2.COLOR_BGR2RGB)},
        disable_edit_boxes=True,
        label="Draw a bounding box",
    )
    btn = gr.Button("Get Segmentation Mask(s)")
    btn.click(
        fn=predict, inputs=[model, annotator], outputs=[gr.Image(label="Mask(s)")]
    )

demo.launch()