from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
from sam2.build_sam import build_sam2, build_sam2_video_predictor
from sam2.sam2_image_predictor import SAM2ImagePredictor
from sam2.sam2_video_predictor import SAM2VideoPredictor
from typing import Dict, List, Optional
import torch
import os
from datetime import datetime
import numpy as np
import gradio as gr

from modules.model_downloader import (
    AVAILABLE_MODELS, DEFAULT_MODEL_TYPE, OUTPUT_DIR,
    is_sam_exist,
    download_sam_model_url
)
from modules.paths import SAM2_CONFIGS_DIR, MODELS_DIR
from modules.constants import BOX_PROMPT_MODE, AUTOMATIC_MODE, COLOR_FILTER, PIXELIZE_FILTER
from modules.mask_utils import (
    save_psd_with_masks,
    create_mask_combined_images,
    create_mask_gallery,
    create_mask_pixelized_image,
    create_solid_color_mask_image
)
from modules.logger_util import get_logger

# Maps each supported model type to its SAM2 config file.
MODEL_CONFIGS = {
    "sam2_hiera_tiny": os.path.join(SAM2_CONFIGS_DIR, "sam2_hiera_t.yaml"),
    "sam2_hiera_small": os.path.join(SAM2_CONFIGS_DIR, "sam2_hiera_s.yaml"),
    "sam2_hiera_base_plus": os.path.join(SAM2_CONFIGS_DIR, "sam2_hiera_b+.yaml"),
    "sam2_hiera_large": os.path.join(SAM2_CONFIGS_DIR, "sam2_hiera_l.yaml"),
}
logger = get_logger()


class SamInference:
    """Front-end wrapper around SAM2 automatic, prompted-image and video predictors.

    Models are built lazily on first use and rebuilt whenever a different
    ``model_type`` is requested. The image model (``self.model``) and the video
    predictor (``self.video_predictor``) are tracked independently.
    """

    def __init__(self,
                 model_dir: str = MODELS_DIR,
                 output_dir: str = OUTPUT_DIR
                 ):
        """
        Args:
            model_dir: Directory holding SAM2 checkpoint files.
            output_dir: Root directory for generated outputs (PSD files go to ``<output_dir>/psd``).
        """
        self.model = None
        self.available_models = list(AVAILABLE_MODELS.keys())
        self.current_model_type = DEFAULT_MODEL_TYPE
        # Model type the current video predictor was built with. Kept separate from
        # ``current_model_type`` because the image path may switch model types
        # without rebuilding the video predictor.
        self.video_model_type = None
        self.model_dir = model_dir
        self.output_dir = output_dir
        self.model_path = os.path.join(self.model_dir, AVAILABLE_MODELS[DEFAULT_MODEL_TYPE][0])
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.mask_generator = None
        self.image_predictor = None
        self.video_predictor = None
        self.video_inference_state = None

    def load_model(self,
                   model_type: Optional[str] = None,
                   load_video_predictor: bool = False):
        """Build a SAM2 model, downloading its checkpoint first if it is missing.

        Args:
            model_type: Key of ``MODEL_CONFIGS`` / ``AVAILABLE_MODELS``.
                Defaults to ``DEFAULT_MODEL_TYPE``.
            load_video_predictor: If True, build only the video predictor into
                ``self.video_predictor`` and release the image model. Otherwise
                build the image model into ``self.model``.

        Raises:
            RuntimeError: If building the model fails.
        """
        if model_type is None:
            model_type = DEFAULT_MODEL_TYPE

        config = MODEL_CONFIGS[model_type]
        filename = AVAILABLE_MODELS[model_type][0]
        model_path = os.path.join(self.model_dir, filename)

        if not is_sam_exist(model_type):
            logger.info(f"No SAM2 model found, downloading {model_type} model...")
            download_sam_model_url(model_type)
        logger.info(f"Applying configs to {model_type} model..")

        if load_video_predictor:
            # Drop references to the image model (and the helpers built on it) so
            # it is not kept alongside the video predictor; image paths rebuild
            # it on demand because they check ``self.model is None``.
            self.model = None
            self.image_predictor = None
            self.mask_generator = None
            try:
                self.video_predictor = build_sam2_video_predictor(
                    config_file=config,
                    ckpt_path=model_path,
                    device=self.device
                )
            except Exception as e:
                logger.exception("Error while loading SAM2 model for video predictor")
                raise RuntimeError(f"Error while loading SAM2 model for video predictor!: {e}") from e
            self.video_model_type = model_type
            return

        try:
            self.model = build_sam2(
                config_file=config,
                ckpt_path=model_path,
                device=self.device
            )
        except Exception as e:
            logger.exception("Error while loading SAM2 model")
            raise RuntimeError(f"Error while loading SAM2 model!: {e}") from e

    def init_video_inference_state(self,
                                   model_type: str,
                                   vid_input: str):
        """(Re)initialize the video inference state for ``vid_input``.

        Rebuilds the video predictor if none exists or it was built with a
        different model type, and resets any previous inference state.

        Args:
            model_type: Model type to use for the video predictor.
            vid_input: Video path passed to the predictor's ``init_state``.
        """
        if self.video_predictor is None or model_type != self.video_model_type:
            self.current_model_type = model_type
            self.load_model(model_type=model_type, load_video_predictor=True)

        if self.video_inference_state is not None:
            self.video_predictor.reset_state(self.video_inference_state)
            self.video_inference_state = None

        self.video_inference_state = self.video_predictor.init_state(video_path=vid_input)

    def generate_mask(self,
                      image: np.ndarray,
                      model_type: str,
                      **params):
        """Run SAM2 automatic mask generation on ``image``.

        Args:
            image: Input image array.
            model_type: Model type; the image model is (re)loaded if it differs.
            **params: Keyword arguments forwarded to ``SAM2AutomaticMaskGenerator``.

        Returns:
            The result of ``SAM2AutomaticMaskGenerator.generate(image)``.

        Raises:
            RuntimeError: If mask generation fails.
        """
        if self.model is None or self.current_model_type != model_type:
            self.current_model_type = model_type
            self.load_model(model_type=model_type)
        self.mask_generator = SAM2AutomaticMaskGenerator(
            model=self.model,
            **params
        )
        try:
            generated_masks = self.mask_generator.generate(image)
        except Exception as e:
            logger.exception("Error while auto generating masks")
            raise RuntimeError(f"Error while auto generating masks: {e}") from e
        return generated_masks

    def predict_image(self,
                      image: np.ndarray,
                      model_type: str,
                      box: Optional[np.ndarray] = None,
                      point_coords: Optional[np.ndarray] = None,
                      point_labels: Optional[np.ndarray] = None,
                      **params):
        """Predict masks for ``image`` from point and/or box prompts.

        Args:
            image: Input image array.
            model_type: Model type; the image model is (re)loaded if it differs.
            box: Optional box prompts.
            point_coords: Optional point coordinates.
            point_labels: Optional labels matching ``point_coords``.
            **params: Must contain ``multimask_output``.

        Returns:
            ``(masks, scores, logits)`` as returned by ``SAM2ImagePredictor.predict``.

        Raises:
            RuntimeError: If prediction fails.
        """
        if self.model is None or self.current_model_type != model_type:
            self.current_model_type = model_type
            self.load_model(model_type=model_type)
        self.image_predictor = SAM2ImagePredictor(sam_model=self.model)
        self.image_predictor.set_image(image)

        try:
            masks, scores, logits = self.image_predictor.predict(
                box=box,
                point_coords=point_coords,
                point_labels=point_labels,
                multimask_output=params["multimask_output"],
            )
        except Exception as e:
            logger.exception("Error while predicting image with prompt")
            raise RuntimeError(f"Error while predicting image with prompt: {e}") from e
        return masks, scores, logits

    def predict_frame(self,
                      frame_idx: int,
                      obj_id: int,
                      inference_state: Dict,
                      points: Optional[np.ndarray] = None,
                      labels: Optional[np.ndarray] = None,
                      box: Optional[np.ndarray] = None):
        """Add prompts for ``obj_id`` on ``frame_idx`` and return the predicted masks.

        Returns:
            ``(out_frame_idx, out_obj_ids, out_mask_logits)`` from
            ``add_new_points_or_box``.

        Raises:
            RuntimeError: If the video predictor / inference state is not
                initialized, or prediction fails.
        """
        if self.video_predictor is None or self.video_inference_state is None:
            logger.error("Error while predicting frame from video, load video predictor first")
            raise RuntimeError("Error while predicting frame from video, load video predictor first")

        try:
            out_frame_idx, out_obj_ids, out_mask_logits = self.video_predictor.add_new_points_or_box(
                inference_state=inference_state,
                frame_idx=frame_idx,
                obj_id=obj_id,
                points=points,
                labels=labels,
                box=box
            )
        except Exception as e:
            logger.exception("Error while predicting frame with prompt")
            raise RuntimeError(f"Error while predicting frame with prompt: {e}") from e

        return out_frame_idx, out_obj_ids, out_mask_logits

    def predict_video(self,
                      video_input):
        """Placeholder: not implemented yet, currently returns None."""
        pass

    def add_filter_to_preview(self,
                              image_prompt_input_data: Dict,
                              filter_mode: str,
                              frame_idx: int,
                              pixel_size: Optional[int] = None,
                              color_hex: Optional[str] = None,
                              ):
        """Segment the prompted object on a video frame and apply a preview filter.

        Args:
            image_prompt_input_data: Dict with ``"image"`` (an object with a
                ``convert`` method, e.g. a PIL image) and ``"points"`` (prompt rows,
                see ``handle_prompt_data``).
            filter_mode: ``COLOR_FILTER`` or ``PIXELIZE_FILTER``; any other value
                returns the unfiltered image.
            frame_idx: Index of the frame the prompts apply to.
            pixel_size: Pixel block size for ``PIXELIZE_FILTER``.
            color_hex: Fill colour for ``COLOR_FILTER``.

        Returns:
            The (possibly filtered) image.

        Raises:
            gr.Error: If no prompts were given or the video predictor is not loaded.
        """
        if not image_prompt_input_data["points"]:
            error_message = ("Prompt data is empty! Please provide at least one point or box on the image. "
                             "If you've already added prompts, please press the eraser button "
                             "and add your prompts again.")
            logger.error(error_message)
            raise gr.Error(error_message, duration=20)
        if self.video_predictor is None or self.video_inference_state is None:
            logger.error("Error while adding filter to preview, load video predictor first")
            raise gr.Error("Error while adding filter to preview, load video predictor first")

        image, prompt = image_prompt_input_data["image"], image_prompt_input_data["points"]
        image = np.array(image.convert("RGB"))

        point_labels, point_coords, box = self.handle_prompt_data(prompt)

        if filter_mode not in (COLOR_FILTER, PIXELIZE_FILTER):
            return image

        # Both filters share the same segmentation; only the rendering differs.
        _, _, logits = self.predict_frame(
            frame_idx=frame_idx,
            obj_id=0,
            inference_state=self.video_inference_state,
            points=point_coords,
            labels=point_labels,
            box=box
        )
        masks = (logits[0] > 0.0).cpu().numpy()
        generated_masks = self.format_to_auto_result(masks)

        if filter_mode == COLOR_FILTER:
            image = create_solid_color_mask_image(image, generated_masks, color_hex)
        else:
            image = create_mask_pixelized_image(image, generated_masks, pixel_size)

        return image

    def divide_layer(self,
                     image_input: np.ndarray,
                     image_prompt_input_data: Dict,
                     input_mode: str,
                     model_type: str,
                     *params):
        """Segment an image into mask layers and save them as a PSD file.

        Args:
            image_input: Image used in ``AUTOMATIC_MODE``.
            image_prompt_input_data: Dict with ``"image"`` and ``"points"``, used in
                ``BOX_PROMPT_MODE``.
            input_mode: ``AUTOMATIC_MODE`` or ``BOX_PROMPT_MODE``.
            model_type: SAM2 model type.
            *params: Positional hyper-parameters from the UI, in the order
                points_per_side, points_per_batch, pred_iou_thresh,
                stability_score_thresh, stability_score_offset, crop_n_layers,
                box_nms_thresh, crop_n_points_downscale_factor,
                min_mask_region_area, use_m2m, multimask_output.

        Returns:
            ``(gallery, output_path)``; in ``BOX_PROMPT_MODE`` with no prompts,
            ``([image], [])`` without saving anything.

        Raises:
            ValueError: If ``input_mode`` is not a supported mode.
        """
        timestamp = datetime.now().strftime("%m%d%H%M%S")
        output_file_name = f"result-{timestamp}.psd"
        psd_dir = os.path.join(self.output_dir, "psd")
        output_path = os.path.join(psd_dir, output_file_name)

        # Pre-processed gradio components
        hparams = {
            'points_per_side': int(params[0]),
            'points_per_batch': int(params[1]),
            'pred_iou_thresh': float(params[2]),
            'stability_score_thresh': float(params[3]),
            'stability_score_offset': float(params[4]),
            'crop_n_layers': int(params[5]),
            'box_nms_thresh': float(params[6]),
            'crop_n_points_downscale_factor': int(params[7]),
            'min_mask_region_area': int(params[8]),
            'use_m2m': bool(params[9]),
            'multimask_output': bool(params[10])
        }

        if input_mode == AUTOMATIC_MODE:
            image = image_input

            generated_masks = self.generate_mask(
                image=image,
                model_type=model_type,
                **hparams
            )

        elif input_mode == BOX_PROMPT_MODE:
            image = image_prompt_input_data["image"]
            image = np.array(image.convert("RGB"))
            prompt = image_prompt_input_data["points"]
            if len(prompt) == 0:
                return [image], []

            point_labels, point_coords, box = self.handle_prompt_data(prompt)

            predicted_masks, scores, logits = self.predict_image(
                image=image,
                model_type=model_type,
                box=box,
                point_coords=point_coords,
                point_labels=point_labels,
                multimask_output=hparams["multimask_output"]
            )
            generated_masks = self.format_to_auto_result(predicted_masks)

        else:
            raise ValueError(f"Unsupported input mode: {input_mode}")

        # Make sure the PSD output directory exists before writing into it.
        os.makedirs(psd_dir, exist_ok=True)
        save_psd_with_masks(image, generated_masks, output_path)
        mask_combined_image = create_mask_combined_images(image, generated_masks)
        gallery = create_mask_gallery(image, generated_masks)
        gallery = [mask_combined_image] + gallery

        return gallery, output_path

    @staticmethod
    def format_to_auto_result(
        masks: np.ndarray
    ):
        """Wrap predicted masks in the list-of-dicts shape used for automatic results.

        Arrays with 3 or fewer dimensions get a leading axis added; then, for
        each entry along the first axis, its first sub-mask (``mask[0]``) becomes
        the ``"segmentation"``. ``"area"`` is a 0 placeholder, not a computed value.
        """
        place_holder = 0
        if len(masks.shape) <= 3:
            masks = np.expand_dims(masks, axis=0)
        result = [{"segmentation": mask[0], "area": place_holder} for mask in masks]
        return result

    @staticmethod
    def handle_prompt_data(
        prompt_data: List
    ):
        """
        Handle data from ImageInputPrompter.

        Args:
            prompt_data (List): Prompt rows of the form
                ``(x1, y1, left_click_indicator, x2, y2, point_indicator)``.
                A row whose ``point_indicator`` is 4.0 is a point at ``(x1, y1)``
                labelled with ``left_click_indicator``; any other row is a box
                ``[x1, y1, x2, y2]``.

        Returns:
            point_labels (np.ndarray | None): Point labels, or None if there are no points.
            point_coords (np.ndarray | None): Point coordinates, or None if there are no points.
            box (np.ndarray | None): Boxes, or None if there are no boxes.
        """
        point_labels, point_coords, box = [], [], []

        for x1, y1, left_click_indicator, x2, y2, point_indicator in prompt_data:
            is_point = point_indicator == 4.0
            if is_point:
                point_labels.append(left_click_indicator)
                point_coords.append([x1, y1])
            else:
                box.append([x1, y1, x2, y2])

        point_labels = np.array(point_labels) if point_labels else None
        point_coords = np.array(point_coords) if point_coords else None
        box = np.array(box) if box else None

        return point_labels, point_coords, box