import torch from torch import Tensor from torch.nn.functional import affine_grid, grid_sample def apply_rgb_change(alpha: Tensor, color_change: Tensor, image: Tensor): image_rgb = image[:, 0:3, :, :] color_change_rgb = color_change[:, 0:3, :, :] output_rgb = color_change_rgb * alpha + image_rgb * (1 - alpha) return torch.cat([output_rgb, image[:, 3:4, :, :]], dim=1) def apply_grid_change(grid_change, image: Tensor) -> Tensor: n, c, h, w = image.shape device = grid_change.device grid_change = torch.transpose(grid_change.view(n, 2, h * w), 1, 2).view(n, h, w, 2) identity = torch.tensor( [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], dtype=grid_change.dtype, device=device).unsqueeze(0).repeat(n, 1, 1) base_grid = affine_grid(identity, [n, c, h, w], align_corners=False) grid = base_grid + grid_change resampled_image = grid_sample(image, grid, mode='bilinear', padding_mode='border', align_corners=False) return resampled_image class GridChangeApplier: def __init__(self): self.last_n = None self.last_device = None self.last_identity = None def apply(self, grid_change: Tensor, image: Tensor, align_corners: bool = False) -> Tensor: n, c, h, w = image.shape device = grid_change.device grid_change = torch.transpose(grid_change.view(n, 2, h * w), 1, 2).view(n, h, w, 2) if n == self.last_n and device == self.last_device: identity = self.last_identity else: identity = torch.tensor( [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], dtype=grid_change.dtype, device=device, requires_grad=False) \ .unsqueeze(0).repeat(n, 1, 1) self.last_identity = identity self.last_n = n self.last_device = device base_grid = affine_grid(identity, [n, c, h, w], align_corners=align_corners) grid = base_grid + grid_change resampled_image = grid_sample(image, grid, mode='bilinear', padding_mode='border', align_corners=align_corners) return resampled_image def apply_color_change(alpha, color_change, image: Tensor) -> Tensor: return color_change * alpha + image * (1 - alpha)