Image2Paragraph / models /segment_models /semgent_anything_model.py
Awiny's picture
update gradio ui
eb902b3
raw
history blame
773 Bytes
import cv2
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
import torch
class SegmentAnything:
    """Wrapper around Segment Anything's automatic mask generator.

    Builds a SAM model of the requested architecture from a checkpoint, moves
    it to ``device``, and exposes :meth:`generate_mask` to segment an image
    file on disk.
    """

    def __init__(self, device, arch="vit_h", pretrained_weights="pretrained_models/sam_vit_h_4b8939.pth"):
        """Load the SAM checkpoint and wrap it in an automatic mask generator.

        Args:
            device: Device passed to ``sam.to(device=...)`` (e.g. ``"cuda"``
                or ``"cpu"``).
            arch: Key into ``sam_model_registry`` selecting the model variant.
            pretrained_weights: Path to the checkpoint file to load.
        """
        self.device = device
        # NOTE: despite the attribute name, this holds the
        # SamAutomaticMaskGenerator built by initialize_model, not the bare
        # SAM network.
        self.model = self.initialize_model(arch, pretrained_weights)

    def initialize_model(self, arch, pretrained_weights):
        """Build a SAM model on ``self.device`` and return a mask generator for it.

        Args:
            arch: Key into ``sam_model_registry``; an unknown key fails on the
                registry lookup.
            pretrained_weights: Checkpoint path forwarded as ``checkpoint=``.

        Returns:
            A ``SamAutomaticMaskGenerator`` wrapping the loaded model, created
            with the library's default generation settings.
        """
        sam = sam_model_registry[arch](checkpoint=pretrained_weights)
        sam.to(device=self.device)
        return SamAutomaticMaskGenerator(sam)

    def generate_mask(self, img_src):
        """Run automatic mask generation on the image stored at ``img_src``.

        Args:
            img_src: Filesystem path of the image, as accepted by
                ``cv2.imread``.

        Returns:
            Whatever ``SamAutomaticMaskGenerator.generate`` returns for the
            image (presumably a list of per-mask annotation dicts - confirm
            against the installed segment_anything version).

        Raises:
            ValueError: If OpenCV cannot read an image from ``img_src``
                (missing file, unreadable file, or unsupported format).
        """
        image = cv2.imread(img_src)
        if image is None:
            # cv2.imread reports failure by returning None instead of raising;
            # without this guard cvtColor fails with an opaque assertion error
            # that hides the real cause (a bad path or undecodable file).
            raise ValueError(f"Could not read image from {img_src!r}")
        # OpenCV decodes to BGR channel order; convert to RGB before handing
        # the array to the mask generator.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        return self.model.generate(image)