sam2-playground / modules /sam_inference.py
jhj0517
Add point prompt
41938cd
raw
history blame
6.86 kB
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from typing import Dict, List, Optional
import torch
import os
from datetime import datetime
import numpy as np
from modules.model_downloader import (
AVAILABLE_MODELS, DEFAULT_MODEL_TYPE, OUTPUT_DIR,
is_sam_exist,
download_sam_model_url
)
from modules.paths import SAM2_CONFIGS_DIR, MODELS_DIR
from modules.constants import BOX_PROMPT_MODE, AUTOMATIC_MODE
from modules.mask_utils import (
save_psd_with_masks,
create_mask_combined_images,
create_mask_gallery
)
# Map each supported SAM2 model type to the path of its config YAML file
# inside SAM2_CONFIGS_DIR. The file name is "sam2_hiera_<size>.yaml", where
# <size> is the short size suffix paired with the model type below.
CONFIGS = {
    model_name: os.path.join(SAM2_CONFIGS_DIR, f"sam2_hiera_{size_suffix}.yaml")
    for model_name, size_suffix in (
        ("sam2_hiera_tiny", "t"),
        ("sam2_hiera_small", "s"),
        ("sam2_hiera_base_plus", "b+"),
        ("sam2_hiera_large", "l"),
    )
}
class SamInference:
    """Lazily loads a SAM2 model and turns its masks into layered PSD output.

    Two input modes are handled by ``divide_layer``:
    * AUTOMATIC_MODE: ``SAM2AutomaticMaskGenerator`` segments the whole image.
    * BOX_PROMPT_MODE: ``SAM2ImagePredictor`` segments from user box/point prompts.

    The model is built on first use and rebuilt whenever a different
    ``model_type`` is requested.
    """

    def __init__(self,
                 model_dir: str = MODELS_DIR,
                 output_dir: str = OUTPUT_DIR
                 ):
        """Set up paths and device; no model is loaded until first use.

        Args:
            model_dir: Directory that holds SAM2 checkpoint files.
            output_dir: Root output directory; PSDs are written to ``<output_dir>/psd``.
        """
        self.model = None
        self.available_models = list(AVAILABLE_MODELS.keys())
        self.model_type = DEFAULT_MODEL_TYPE
        self.model_dir = model_dir
        self.output_dir = output_dir
        # AVAILABLE_MODELS values are (checkpoint filename, download url) pairs.
        self.model_path = os.path.join(self.model_dir, AVAILABLE_MODELS[DEFAULT_MODEL_TYPE][0])
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.mask_generator = None
        self.image_predictor = None
        self.video_predictor = None

    def load_model(self):
        """Build the SAM2 model for ``self.model_type``.

        Downloads the checkpoint first if it is missing.

        Raises:
            RuntimeError: if ``build_sam2`` fails. ``self.model`` is reset to
                ``None`` beforehand, so a model of a previously loaded type is
                never silently reused under the new ``model_type``.
        """
        config = CONFIGS[self.model_type]
        filename = AVAILABLE_MODELS[self.model_type][0]
        model_path = os.path.join(self.model_dir, filename)
        # NOTE(review): is_sam_exist/download_sam_model_url only receive the
        # model type, not self.model_dir -- confirm they use the same directory.
        if not is_sam_exist(self.model_type):
            print(f"\nNo SAM2 model found, downloading {self.model_type} model...")
            download_sam_model_url(self.model_type)
        print("\nApplying configs to model..")
        try:
            self.model = build_sam2(
                config_file=config,
                ckpt_path=model_path,
                device=self.device
            )
        except Exception as e:
            # Callers have already switched self.model_type. Keeping the old model
            # would make later calls believe the new type is loaded.
            self.model = None
            print(f"Error while Loading SAM2 model! {e}")
            raise RuntimeError(f"Error while Loading SAM2 model! {e}") from e
        self.model_path = model_path

    def _ensure_model_loaded(self, model_type: str):
        """Load ``model_type`` unless that model type is already loaded."""
        if self.model is None or self.model_type != model_type:
            self.model_type = model_type
            self.load_model()

    def generate_mask(self,
                      image: np.ndarray,
                      model_type: str,
                      **params):
        """Run SAM2 automatic mask generation over the whole image.

        Args:
            image: Image array passed straight to ``SAM2AutomaticMaskGenerator.generate``.
            model_type: Key of CONFIGS / AVAILABLE_MODELS to use.
            **params: Keyword arguments forwarded to ``SAM2AutomaticMaskGenerator``.

        Returns:
            Whatever ``SAM2AutomaticMaskGenerator.generate`` returns (a list of mask records).

        Raises:
            RuntimeError: if model loading or mask generation fails.
        """
        self._ensure_model_loaded(model_type)
        self.mask_generator = SAM2AutomaticMaskGenerator(
            model=self.model,
            **params
        )
        try:
            generated_masks = self.mask_generator.generate(image)
        except Exception as e:
            # Raising a plain string is itself a TypeError; wrap in a real exception.
            raise RuntimeError(f"Error while auto generating masks: {e}") from e
        return generated_masks

    def predict_image(self,
                      image: np.ndarray,
                      model_type: str,
                      box: Optional[np.ndarray] = None,
                      point_coords: Optional[np.ndarray] = None,
                      point_labels: Optional[np.ndarray] = None,
                      **params):
        """Predict masks for ``image`` from box and/or point prompts.

        Args:
            image: Image array passed to ``SAM2ImagePredictor.set_image``.
            model_type: Key of CONFIGS / AVAILABLE_MODELS to use.
            box: Optional box prompt(s).
            point_coords: Optional point prompt coordinates.
            point_labels: Optional labels for ``point_coords`` (1 = foreground, 0 = background).
            **params: Only ``multimask_output`` is read. It defaults to True,
                which is SAM2ImagePredictor.predict's own default.

        Returns:
            ``(masks, scores, logits)`` as returned by ``SAM2ImagePredictor.predict``.

        Raises:
            RuntimeError: if model loading or prediction fails.
        """
        self._ensure_model_loaded(model_type)
        self.image_predictor = SAM2ImagePredictor(sam_model=self.model)
        self.image_predictor.set_image(image)
        try:
            masks, scores, logits = self.image_predictor.predict(
                box=box,
                point_coords=point_coords,
                point_labels=point_labels,
                multimask_output=params.get("multimask_output", True),
            )
        except Exception as e:
            raise RuntimeError(f"Error while predicting image with prompt: {e}") from e
        return masks, scores, logits

    def divide_layer(self,
                     image_input: np.ndarray,
                     image_prompt_input_data: Dict,
                     input_mode: str,
                     model_type: str,
                     *params):
        """Segment an image and save each mask as a layer of a timestamped PSD.

        Args:
            image_input: Image used in AUTOMATIC_MODE.
            image_prompt_input_data: Used in BOX_PROMPT_MODE. Must contain
                ``"image"`` (an object with ``.convert("RGB")``, presumably a
                PIL image) and ``"points"`` (the list of prompts).
            input_mode: AUTOMATIC_MODE or BOX_PROMPT_MODE.
            model_type: Key of CONFIGS / AVAILABLE_MODELS to use.
            *params: At least 11 positional values, in the order of the
                ``hparams`` keys below (points_per_side ... multimask_output).

        Returns:
            ``([combined_image, *gallery_images], output_path)``. When
            BOX_PROMPT_MODE is given no prompts, returns ``([image], [])``
            and writes nothing.

        Raises:
            ValueError: if ``input_mode`` is not a supported mode.
        """
        timestamp = datetime.now().strftime("%m%d%H%M%S")
        output_file_name = f"result-{timestamp}.psd"
        output_path = os.path.join(self.output_dir, "psd", output_file_name)

        hparams = {
            'points_per_side': int(params[0]),
            'points_per_batch': int(params[1]),
            'pred_iou_thresh': float(params[2]),
            'stability_score_thresh': float(params[3]),
            'stability_score_offset': float(params[4]),
            'crop_n_layers': int(params[5]),
            'box_nms_thresh': float(params[6]),
            'crop_n_points_downscale_factor': int(params[7]),
            'min_mask_region_area': int(params[8]),
            'use_m2m': bool(params[9]),
            'multimask_output': bool(params[10])
        }

        if input_mode == AUTOMATIC_MODE:
            image = image_input
            generated_masks = self.generate_mask(
                image=image,
                model_type=model_type,
                **hparams
            )
        elif input_mode == BOX_PROMPT_MODE:
            image = image_prompt_input_data["image"]
            image = np.array(image.convert("RGB"))
            prompt = image_prompt_input_data["points"]
            if len(prompt) == 0:
                return [image], []

            # Each prompt row is unpacked as (x1, y1, is_left_click, x2, y2, kind).
            # A trailing 4.0 marks a point prompt; anything else is treated as a box.
            # NOTE(review): only the first prompt's kind is checked, so mixed
            # point/box prompts are all interpreted as that first kind.
            is_prompt_point = prompt[0][-1] == 4.0
            box = point_coords = point_labels = None
            if is_prompt_point:
                point_labels = np.array([1 if is_left_click else 0 for x1, y1, is_left_click, x2, y2, _ in prompt])
                point_coords = np.array([[x1, y1] for x1, y1, is_left_click, x2, y2, _ in prompt])
            else:
                box = np.array([[x1, y1, x2, y2] for x1, y1, is_left_click, x2, y2, _ in prompt])

            predicted_masks, scores, logits = self.predict_image(
                image=image,
                model_type=model_type,
                box=box,
                point_coords=point_coords,
                point_labels=point_labels,
                multimask_output=hparams["multimask_output"]
            )
            generated_masks = self.format_to_auto_result(predicted_masks)
        else:
            # Previously this fell through to an UnboundLocalError on `image`.
            raise ValueError(f"Unsupported input mode: {input_mode}")

        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        save_psd_with_masks(image, generated_masks, output_path)
        mask_combined_image = create_mask_combined_images(image, generated_masks)
        gallery = create_mask_gallery(image, generated_masks)
        return [mask_combined_image] + gallery, output_path

    @staticmethod
    def format_to_auto_result(
            masks: np.ndarray
    ):
        """Wrap predictor masks in the record format used by automatic generation.

        ``masks`` is padded with leading axes up to 4-D. The first channel
        ``mask[0]`` of each entry along the first axis then becomes one record.
        NOTE(review): with multimask_output=True only channel 0 is kept,
        regardless of its score -- confirm this is intended.

        Returns:
            A list of ``{"segmentation": <mask array>, "area": 0}`` dicts.
            ``area`` is a placeholder and is always 0.
        """
        place_holder = 0
        # 3-D input gains one leading axis, as before. 2-D (H, W) input gains two,
        # so mask[0] is the full mask instead of its first row.
        while masks.ndim < 4:
            masks = np.expand_dims(masks, axis=0)
        result = [{"segmentation": mask[0], "area": place_holder} for mask in masks]
        return result