from functools import partial
from typing import Any, Dict, List, Tuple

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F

from isegm.model.ops import DistMaps, BatchImageNormalize, ScaleLayer
from isegm.utils.serialization import serialize

from .image_encoder import ImageEncoderViT
from .image_encoder_lora import ImageEncoderViT_lora
from .mask_decoder import MaskDecoder
from .prompt_encoder import PromptEncoder
from .transformer import TwoWayTransformer


class SAMISWrapper(nn.Module):
    mask_threshold: float = 0.0
    image_format: str = "RGB"

    @serialize
    def __init__(
        self,
        encoder_embed_dim=1280,
        encoder_depth=32,
        encoder_num_heads=16,
        encoder_global_attn_indexes=[7, 15, 23, 31],
        enable_lora=True,
        enable_gra=True,
        pixel_mean: List[float] = [123.675, 116.28, 103.53],
        pixel_std: List[float] = [58.395, 57.12, 57.375],
        with_aux_output=False,
        with_prev_mask=False,
        norm_mean_std=([.485, .456, .406], [.229, .224, .225]),
        image_size=1024,
    ):
        super().__init__()
        self.with_aux_output = with_aux_output
        self.with_prev_mask = with_prev_mask
        self.normalization = BatchImageNormalize(norm_mean_std[0], norm_mean_std[1])

        prompt_embed_dim = 256
        vit_patch_size = 16
        image_embedding_size = image_size // vit_patch_size

        # The LoRA-augmented and plain ViT encoders take identical arguments,
        # so only the class differs.
        encoder_cls = ImageEncoderViT_lora if enable_lora else ImageEncoderViT
        self.image_encoder = encoder_cls(
            depth=encoder_depth,
            embed_dim=encoder_embed_dim,
            img_size=image_size,
            mlp_ratio=4,
            norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
            num_heads=encoder_num_heads,
            patch_size=vit_patch_size,
            qkv_bias=True,
            use_rel_pos=True,
            global_attn_indexes=encoder_global_attn_indexes,
            window_size=14,
            out_chans=prompt_embed_dim,
        )
        self.prompt_encoder = PromptEncoder(
            embed_dim=prompt_embed_dim,
            image_embedding_size=(image_embedding_size, image_embedding_size),
            input_image_size=(image_size, image_size),
            mask_in_chans=16,
        )
        self.mask_decoder = MaskDecoder(
            num_multimask_outputs=3,
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=prompt_embed_dim,
                mlp_dim=2048,
                num_heads=8,
            ),
            transformer_dim=prompt_embed_dim,
            enable_gra=enable_gra,
            iou_head_depth=3,
            iou_head_hidden_dim=256,
        )
        self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
        self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)

    @property
    def device(self) -> Any:
        return self.pixel_mean.device

    def forward(self, image, points, gra=None, multimask_output=False, return_logits=True):
        image, prev_mask = self.prepare_input(image)
        point_coords, point_labels = self.get_model_input(points)

        batched_input = []
        for bindx in range(image.shape[0]):
            batched_input.append(
                {
                    "image": image[bindx],
                    "point_coords": point_coords[bindx:bindx + 1],
                    "point_labels": point_labels[bindx:bindx + 1],
                    # Previous-step mask (if any) is forwarded to the prompt encoder
                    # as a dense prompt; guard against the no-prev-mask case.
                    "mask_inputs": prev_mask[bindx:bindx + 1] if prev_mask is not None else None,
                    "gra": gra[bindx] if gra is not None else None,
                }
            )

        input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0)
        image_embeddings = self.image_encoder(input_images)

        output_masks = []
        output_iou_predictions = []
        output_low_res_masks = []
        for image_record, curr_embedding in zip(batched_input, image_embeddings):
            if "point_coords" in image_record:
                points = (image_record["point_coords"], image_record["point_labels"])
            else:
                points = None
            sparse_embeddings, dense_embeddings = self.prompt_encoder(
                points=points,
                boxes=image_record.get("boxes", None),
                masks=image_record.get("mask_inputs", None),
            )
            low_res_masks, iou_predictions = self.mask_decoder(
                image_embeddings=curr_embedding.unsqueeze(0),
                gra=image_record["gra"],
                image_pe=self.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=multimask_output,
            )
            masks = self.postprocess_masks(
                low_res_masks,
                input_size=image_record["image"].shape[-2:],
                original_size=image_record["image"].shape[-2:],
            )
            if not return_logits:
                masks = masks > self.mask_threshold
            output_masks.append(masks)
            output_iou_predictions.append(iou_predictions)
            output_low_res_masks.append(low_res_masks)

        return {
            "instances": torch.cat(output_masks, dim=0),
            "iou_predictions": torch.cat(output_iou_predictions, dim=0),
            "low_res_logits": torch.cat(output_low_res_masks, dim=0),
        }

    def postprocess_masks(
        self,
        masks: torch.Tensor,
        input_size: Tuple[int, ...],
        original_size: Tuple[int, ...],
    ) -> torch.Tensor:
        """
        Remove padding and upscale masks to the original image size.

        Arguments:
          masks (torch.Tensor): Batched masks from the mask_decoder,
            in BxCxHxW format.
          input_size (tuple(int, int)): The size of the image input to the
            model, in (H, W) format. Used to remove padding.
          original_size (tuple(int, int)): The original size of the image
            before resizing for input to the model, in (H, W) format.

        Returns:
          (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
            is given by original_size.
        """
        masks = F.interpolate(
            masks,
            (self.image_encoder.img_size, self.image_encoder.img_size),
            mode="bilinear",
            align_corners=False,
        )
        masks = masks[..., : input_size[0], : input_size[1]]
        masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
        return masks

    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize pixel values and pad to a square input."""
        # Normalize colors
        x = (x - self.pixel_mean) / self.pixel_std

        # Pad
        h, w = x.shape[-2:]
        padh = self.image_encoder.img_size - h
        padw = self.image_encoder.img_size - w
        x = F.pad(x, (0, padw, 0, padh))
        return x

    def prepare_input(self, image):
        prev_mask = None
        if self.with_prev_mask:
            prev_mask = image[:, 3:, :, :]
            image = image[:, :3, :, :]

        image = self.normalization(image)
        return image, prev_mask

    def backbone_forward(self, image, coord_features=None):
        raise NotImplementedError

    def get_model_input(self, points_nd):
        """Convert isegm click tensors of shape (B, 2*N, >=2) into SAM-style
        point coordinates and labels."""
        device = points_nd.device
        points_nd = points_nd.cpu().numpy()

        points_coords = []
        points_labels = []
        for bindx in range(points_nd.shape[0]):
            points = points_nd[bindx]
            point_length = len(points) // 2

            point_coords = []
            point_labels = []
            for i, point in enumerate(points):
                if point[0] == -1:
                    # Empty click slot: label -1 marks a non-point for the prompt encoder.
                    point_labels.append(-1)
                else:
                    # First half of the rows are positive clicks, second half negative.
                    if i < point_length:
                        point_labels.append(1)
                    else:
                        point_labels.append(0)
                point_coords.append([point[1], point[0]])  # swap to the (x, y) order expected by SAM

            points_coords.append(point_coords)
            points_labels.append(point_labels)

        coords_torch = torch.as_tensor(np.array(points_coords), dtype=torch.float, device=device)
        labels_torch = torch.as_tensor(np.array(points_labels), dtype=torch.int, device=device)
        return coords_torch, labels_torch


def split_points_by_order(tpoints: torch.Tensor, groups):
    points = tpoints.cpu().numpy()
    num_groups = len(groups)
    bs = points.shape[0]
    num_points = points.shape[1] // 2

    groups = [x if x > 0 else num_points for x in groups]
    group_points = [
        np.full((bs, 2 * x, 3), -1, dtype=np.float32)
        for x in groups
    ]
    last_point_indx_group = np.zeros((bs, num_groups, 2), dtype=np.int_)
    for group_indx, group_size in enumerate(groups):
        last_point_indx_group[:, group_indx, 1] = group_size

    for bindx in range(bs):
        for pindx in range(2 * num_points):
            point = points[bindx, pindx, :]
            group_id = int(point[2])
            if group_id < 0:
                continue

            is_negative = int(pindx >= num_points)
            if group_id >= num_groups or (group_id == 0 and is_negative):  # disable negative first click
                group_id = num_groups - 1

            new_point_indx = last_point_indx_group[bindx, group_id, is_negative]
            last_point_indx_group[bindx, group_id, is_negative] += 1

            group_points[group_id][bindx, new_point_indx, :] = point

    group_points = [
        torch.tensor(x, dtype=tpoints.dtype, device=tpoints.device)
        for x in group_points
    ]

    return group_points
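

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the training pipeline).
# It assumes the click convention used by get_model_input and
# split_points_by_order: a (batch, 2 * N, 3) tensor whose first N rows are
# positive clicks and last N rows are negative clicks, each row holding
# (y, x, index) with -1 marking an empty slot. The same layout is what
# SAMISWrapper.forward expects for its `points` argument. The concrete
# coordinates and group layout below are made up for demonstration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # One image, up to two positive and two negative clicks.
    clicks = torch.full((1, 4, 3), -1.0)
    clicks[0, 0] = torch.tensor([30.0, 40.0, 0.0])   # positive click at (y=30, x=40), group 0
    clicks[0, 1] = torch.tensor([55.0, 10.0, 1.0])   # positive click, group 1
    clicks[0, 2] = torch.tensor([80.0, 90.0, 1.0])   # negative click, group 1

    # Split the clicks into per-group tensors; -1 in `groups` means
    # "use the maximum number of points per polarity" (here 2).
    grouped = split_points_by_order(clicks, groups=(2, -1))
    for gid, g in enumerate(grouped):
        print(f"group {gid}: shape {tuple(g.shape)}")
    # group 0 keeps only the first positive click; group 1 holds the second
    # positive click plus the negative click (negatives are stored in the
    # second half of each group's tensor).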