import math import torch import numpy as np from typing import List from isegm.inference.clicker import Click from .base import BaseTransform class Crops(BaseTransform): def __init__(self, crop_size=(320, 480), min_overlap=0.2): super().__init__() self.crop_height, self.crop_width = crop_size self.min_overlap = min_overlap self.x_offsets = None self.y_offsets = None self._counts = None def transform(self, image_nd, clicks_lists: List[List[Click]]): assert image_nd.shape[0] == 1 and len(clicks_lists) == 1 image_height, image_width = image_nd.shape[2:4] self._counts = None if image_height < self.crop_height or image_width < self.crop_width: return image_nd, clicks_lists self.x_offsets = get_offsets(image_width, self.crop_width, self.min_overlap) self.y_offsets = get_offsets(image_height, self.crop_height, self.min_overlap) self._counts = np.zeros((image_height, image_width)) image_crops = [] for dy in self.y_offsets: for dx in self.x_offsets: self._counts[dy:dy + self.crop_height, dx:dx + self.crop_width] += 1 image_crop = image_nd[:, :, dy:dy + self.crop_height, dx:dx + self.crop_width] image_crops.append(image_crop) image_crops = torch.cat(image_crops, dim=0) self._counts = torch.tensor(self._counts, device=image_nd.device, dtype=torch.float32) clicks_list = clicks_lists[0] clicks_lists = [] for dy in self.y_offsets: for dx in self.x_offsets: crop_clicks = [x.copy(coords=(x.coords[0] - dy, x.coords[1] - dx)) for x in clicks_list] clicks_lists.append(crop_clicks) return image_crops, clicks_lists def inv_transform(self, prob_map): if self._counts is None: return prob_map new_prob_map = torch.zeros((1, 1, *self._counts.shape), dtype=prob_map.dtype, device=prob_map.device) crop_indx = 0 for dy in self.y_offsets: for dx in self.x_offsets: new_prob_map[0, 0, dy:dy + self.crop_height, dx:dx + self.crop_width] += prob_map[crop_indx, 0] crop_indx += 1 new_prob_map = torch.div(new_prob_map, self._counts) return new_prob_map def get_state(self): return self.x_offsets, self.y_offsets, self._counts def set_state(self, state): self.x_offsets, self.y_offsets, self._counts = state def reset(self): self.x_offsets = None self.y_offsets = None self._counts = None def get_offsets(length, crop_size, min_overlap_ratio=0.2): if length == crop_size: return [0] N = (length / crop_size - min_overlap_ratio) / (1 - min_overlap_ratio) N = math.ceil(N) overlap_ratio = (N - length / crop_size) / (N - 1) overlap_width = int(crop_size * overlap_ratio) offsets = [0] for i in range(1, N): new_offset = offsets[-1] + crop_size - overlap_width if new_offset + crop_size > length: new_offset = length - crop_size offsets.append(new_offset) return offsets