import torch from torch import nn as nn import numpy as np import isegm.model.initializer as initializer def select_activation_function(activation): if isinstance(activation, str): if activation.lower() == 'relu': return nn.ReLU elif activation.lower() == 'softplus': return nn.Softplus else: raise ValueError(f"Unknown activation type {activation}") elif isinstance(activation, nn.Module): return activation else: raise ValueError(f"Unknown activation type {activation}") class BilinearConvTranspose2d(nn.ConvTranspose2d): def __init__(self, in_channels, out_channels, scale, groups=1): kernel_size = 2 * scale - scale % 2 self.scale = scale super().__init__( in_channels, out_channels, kernel_size=kernel_size, stride=scale, padding=1, groups=groups, bias=False) self.apply(initializer.Bilinear(scale=scale, in_channels=in_channels, groups=groups)) class DistMaps(nn.Module): def __init__(self, norm_radius, spatial_scale=1.0, cpu_mode=False, use_disks=False): super(DistMaps, self).__init__() self.spatial_scale = spatial_scale self.norm_radius = norm_radius self.cpu_mode = cpu_mode self.use_disks = use_disks if self.cpu_mode: from isegm.utils.cython import get_dist_maps self._get_dist_maps = get_dist_maps def get_coord_features(self, points, batchsize, rows, cols): if self.cpu_mode: coords = [] for i in range(batchsize): norm_delimeter = 1.0 if self.use_disks else self.spatial_scale * self.norm_radius coords.append(self._get_dist_maps(points[i].cpu().float().numpy(), rows, cols, norm_delimeter)) coords = torch.from_numpy(np.stack(coords, axis=0)).to(points.device).float() else: num_points = points.shape[1] // 2 points = points.view(-1, points.size(2)) points, points_order = torch.split(points, [2, 1], dim=1) invalid_points = torch.max(points, dim=1, keepdim=False)[0] < 0 row_array = torch.arange(start=0, end=rows, step=1, dtype=torch.float32, device=points.device) col_array = torch.arange(start=0, end=cols, step=1, dtype=torch.float32, device=points.device) coord_rows, coord_cols = torch.meshgrid(row_array, col_array) coords = torch.stack((coord_rows, coord_cols), dim=0).unsqueeze(0).repeat(points.size(0), 1, 1, 1) add_xy = (points * self.spatial_scale).view(points.size(0), points.size(1), 1, 1) coords.add_(-add_xy) if not self.use_disks: coords.div_(self.norm_radius * self.spatial_scale) coords.mul_(coords) coords[:, 0] += coords[:, 1] coords = coords[:, :1] coords[invalid_points, :, :, :] = 1e6 coords = coords.view(-1, num_points, 1, rows, cols) coords = coords.min(dim=1)[0] # -> (bs * num_masks * 2) x 1 x h x w coords = coords.view(-1, 2, rows, cols) if self.use_disks: coords = (coords <= (self.norm_radius * self.spatial_scale) ** 2).float() else: coords.sqrt_().mul_(2).tanh_() return coords def forward(self, x, coords): return self.get_coord_features(coords, x.shape[0], x.shape[2], x.shape[3]) class ScaleLayer(nn.Module): def __init__(self, init_value=1.0, lr_mult=1): super().__init__() self.lr_mult = lr_mult self.scale = nn.Parameter( torch.full((1,), init_value / lr_mult, dtype=torch.float32) ) def forward(self, x): scale = torch.abs(self.scale * self.lr_mult) return x * scale class BatchImageNormalize: def __init__(self, mean, std, dtype=torch.float): self.mean = torch.as_tensor(mean, dtype=dtype)[None, :, None, None] self.std = torch.as_tensor(std, dtype=dtype)[None, :, None, None] def __call__(self, tensor): tensor = tensor.clone() tensor.sub_(self.mean.to(tensor.device)).div_(self.std.to(tensor.device)) return tensor