|
import numpy as np |
|
from sam2.build_sam import build_sam2 |
|
from sam2.sam2_image_predictor import SAM2ImagePredictor |
|
|
|
class Predictor:
    """Stateful wrapper around ``SAM2ImagePredictor`` for point-prompted masks.

    Usage: call ``set_image`` once per image, then ``predict`` any number of
    times with point prompts against that image.
    """

    def __init__(self, model_cfg, checkpoint, device):
        """Build the SAM2 model and wrap it in an image predictor.

        Args:
            model_cfg: Model config forwarded to ``build_sam2``.
            checkpoint: Checkpoint forwarded to ``build_sam2``.
            device: Device forwarded to ``build_sam2``; also kept on
                ``self.device``.
        """
        self.device = device
        self.model = build_sam2(model_cfg, checkpoint, device=device)
        self.predictor = SAM2ImagePredictor(self.model)
        # Define every piece of per-image state up front so ``.image`` exists
        # (as None) before the first set_image() call instead of raising
        # AttributeError when read.
        self.image = None
        self.image_set = False

    def set_image(self, image):
        """Set the image for SAM prediction.

        Args:
            image: Image handed unchanged to ``SAM2ImagePredictor.set_image``
                (presumably an HxWxC array or PIL image — confirm against the
                installed SAM2 version).

        Raises:
            Whatever ``SAM2ImagePredictor.set_image`` raises. In that case the
            wrapper is left in the "no image set" state (``image`` is None and
            ``image_set`` is False), so ``predict`` will refuse to run.
        """
        # Invalidate first: if the underlying call fails, we must not keep
        # image_set=True from a previous image, nor record an image whose
        # embedding was never computed.
        self.image_set = False
        self.image = None
        self.predictor.set_image(image)
        # Commit only after the underlying predictor accepted the image.
        self.image = image
        self.image_set = True

    def predict(self, point_coords, point_labels, multimask_output=False):
        """Run SAM prediction on the current image using point prompts.

        Args:
            point_coords: Point prompt coordinates, forwarded unchanged
                (presumably an (N, 2) array of (x, y) pixels — confirm).
            point_labels: Per-point labels, forwarded unchanged (presumably
                1 = foreground, 0 = background — confirm).
            multimask_output: Forwarded to the underlying predictor; defaults
                to False here (single mask per prompt).

        Returns:
            Whatever ``SAM2ImagePredictor.predict`` returns (per SAM2's docs a
            ``(masks, scores, low_res_logits)`` tuple — verify for your
            version).

        Raises:
            RuntimeError: If no image has been successfully set yet.
        """
        if not self.image_set:
            raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
        return self.predictor.predict(
            point_coords=point_coords,
            point_labels=point_labels,
            multimask_output=multimask_output,
        )
|
|