Spaces:
Runtime error
Runtime error
File size: 1,495 Bytes
af5888a ec0b3c1 af5888a ec0b3c1 af5888a ec0b3c1 af5888a ec0b3c1 af5888a 44b03d2 ec0b3c1 af5888a ec0b3c1 af5888a 16d828f af5888a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 |
from typing import Dict, Tuple
import torch
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
# Human-readable identifiers for the interaction modes.
BOX_PROMPT_MODE = "box prompt"
MASK_GENERATION_MODE = "mask generation"
VIDEO_SEGMENTATION_MODE = "video segmentation"

# NOTE(review): VIDEO_SEGMENTATION_MODE is defined above but is not listed
# here -- confirm that omission is intentional.
MODE_NAMES = [BOX_PROMPT_MODE, MASK_GENERATION_MODE]

# Checkpoint name -> [config file name, checkpoint weights path].
CHECKPOINTS = {
    "tiny": ["sam2_hiera_t.yaml", "checkpoints/sam2_hiera_tiny.pt"],
    "small": ["sam2_hiera_s.yaml", "checkpoints/sam2_hiera_small.pt"],
    "base_plus": ["sam2_hiera_b+.yaml", "checkpoints/sam2_hiera_base_plus.pt"],
    "large": ["sam2_hiera_l.yaml", "checkpoints/sam2_hiera_large.pt"],
}

# Checkpoint names in the insertion order of CHECKPOINTS.
CHECKPOINT_NAMES = list(CHECKPOINTS)
def load_models(
    device: torch.device
) -> Tuple[Dict[str, SAM2ImagePredictor], Dict[str, SAM2AutomaticMaskGenerator]]:
    """Build a predictor and a mask generator for every entry in CHECKPOINTS.

    For each checkpoint name, one model is built with ``build_sam2`` on
    ``device`` and then wrapped twice: once in a ``SAM2ImagePredictor`` and
    once in a ``SAM2AutomaticMaskGenerator``. Both wrappers for a given name
    are handed the same model object.

    Args:
        device: Passed to ``build_sam2`` as the ``device`` keyword argument.

    Returns:
        A ``(image_predictors, mask_generators)`` pair of dictionaries, each
        keyed by the checkpoint names found in ``CHECKPOINTS``.
    """
    # Settings applied identically to every automatic mask generator.
    generator_settings = dict(
        points_per_side=32,
        points_per_batch=64,
        pred_iou_thresh=0.7,
        stability_score_thresh=0.92,
        stability_score_offset=0.7,
        crop_n_layers=1,
        box_nms_thresh=0.7,
    )

    predictors_by_name = {}
    generators_by_name = {}
    for name, (config_file, weights_path) in CHECKPOINTS.items():
        sam_model = build_sam2(config_file, weights_path, device=device)
        predictors_by_name[name] = SAM2ImagePredictor(sam_model=sam_model)
        generators_by_name[name] = SAM2AutomaticMaskGenerator(
            model=sam_model,
            **generator_settings,
        )

    return predictors_by_name, generators_by_name
|