Spaces:
Runtime error
Runtime error
File size: 773 Bytes
c3a1897 eb902b3 c3a1897 eb902b3 c3a1897 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 |
import os

import cv2
import torch
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
class SegmentAnything:
    """Wrapper that builds a SAM automatic mask generator and runs it on image files.

    Attributes:
        device: Device the SAM network is moved to (passed to ``sam.to``).
        model: The ``SamAutomaticMaskGenerator`` built by ``initialize_model``.
            Note that this is the mask generator, not the bare SAM network.
    """

    def __init__(
        self,
        device,
        arch="vit_h",
        pretrained_weights="pretrained_models/sam_vit_h_4b8939.pth",
    ):
        """Build the mask generator for the given architecture and checkpoint.

        Args:
            device: Target device forwarded to ``sam.to(device=...)``.
            arch: Key looked up in ``sam_model_registry``.
            pretrained_weights: Checkpoint path handed to the registry builder.
        """
        # ``device`` must be set first: ``initialize_model`` reads ``self.device``.
        self.device = device
        self.model = self.initialize_model(arch, pretrained_weights)

    def initialize_model(self, arch, pretrained_weights):
        """Load the SAM network, move it to ``self.device`` and wrap it.

        Args:
            arch: Key looked up in ``sam_model_registry``.
            pretrained_weights: Checkpoint path handed to the registry builder.

        Returns:
            A ``SamAutomaticMaskGenerator`` wrapping the loaded network.
        """
        sam = sam_model_registry[arch](checkpoint=pretrained_weights)
        sam.to(device=self.device)
        return SamAutomaticMaskGenerator(sam)

    def generate_mask(self, img_src):
        """Read an image from disk and run the automatic mask generator on it.

        Args:
            img_src: Path of the image file to segment.

        Returns:
            Whatever ``self.model.generate`` returns for the RGB image
            (per the segment-anything documentation, a list of per-mask
            annotation dicts -- confirm against the installed version).

        Raises:
            FileNotFoundError: If ``img_src`` does not point to an existing file.
            ValueError: If the file exists but OpenCV cannot decode it.
        """
        # ``cv2.imread`` signals failure by returning None rather than raising,
        # which would otherwise surface later as an opaque error in ``cvtColor``.
        if not os.path.isfile(img_src):
            raise FileNotFoundError(f"Image file not found: {img_src}")

        bgr_image = cv2.imread(img_src)
        if bgr_image is None:
            raise ValueError(f"Could not decode image file: {img_src}")

        # OpenCV loads images in BGR channel order; convert to RGB before
        # handing the array to the mask generator.
        rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
        return self.model.generate(rgb_image)