File size: 1,495 Bytes
af5888a
ec0b3c1
af5888a
 
ec0b3c1
af5888a
ec0b3c1
af5888a
 
 
 
ec0b3c1
af5888a
44b03d2
 
 
 
 
ec0b3c1
 
 
af5888a
 
 
 
 
ec0b3c1
af5888a
 
16d828f
 
 
 
 
 
 
 
 
 
af5888a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from typing import Dict, Tuple

import torch
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

# UI mode identifiers. Note: video segmentation is defined but not exposed
# in MODE_NAMES (only box-prompt and mask-generation modes are selectable).
BOX_PROMPT_MODE = "box prompt"
MASK_GENERATION_MODE = "mask generation"
VIDEO_SEGMENTATION_MODE = "video segmentation"
MODE_NAMES = [BOX_PROMPT_MODE, MASK_GENERATION_MODE]

# Registry of SAM2 model variants: name -> [hydra config, checkpoint path].
CHECKPOINTS = {
    "tiny": ["sam2_hiera_t.yaml", "checkpoints/sam2_hiera_tiny.pt"],
    "small": ["sam2_hiera_s.yaml", "checkpoints/sam2_hiera_small.pt"],
    "base_plus": ["sam2_hiera_b+.yaml", "checkpoints/sam2_hiera_base_plus.pt"],
    "large": ["sam2_hiera_l.yaml", "checkpoints/sam2_hiera_large.pt"],
}
# Derived from the registry so the two can never drift apart
# (dict insertion order preserves the original sequence).
CHECKPOINT_NAMES = list(CHECKPOINTS)


def load_models(
    device: torch.device
) -> Tuple[Dict[str, SAM2ImagePredictor], Dict[str, SAM2AutomaticMaskGenerator]]:
    """Build every SAM2 variant in CHECKPOINTS and wrap each one twice.

    Each underlying model is constructed once and shared between its
    image predictor and its automatic mask generator.

    Args:
        device: torch device the models are built on.

    Returns:
        Two dicts keyed by checkpoint name ("tiny", "small", ...):
        the image predictors and the automatic mask generators.
    """
    # One model per variant; both wrapper dicts below share these instances.
    models = {
        name: build_sam2(config, checkpoint, device=device)
        for name, (config, checkpoint) in CHECKPOINTS.items()
    }
    image_predictors = {
        name: SAM2ImagePredictor(sam_model=model) for name, model in models.items()
    }
    mask_generators = {
        name: SAM2AutomaticMaskGenerator(
            model=model,
            points_per_side=32,
            points_per_batch=64,
            pred_iou_thresh=0.7,
            stability_score_thresh=0.92,
            stability_score_offset=0.7,
            crop_n_layers=1,
            box_nms_thresh=0.7,
        )
        for name, model in models.items()
    }
    return image_predictors, mask_generators