Spaces:
Runtime error
Runtime error
import torch | |
from torch import Tensor | |
from torch.nn.functional import affine_grid, grid_sample | |
def apply_rgb_change(alpha: Tensor, color_change: Tensor, image: Tensor): | |
image_rgb = image[:, 0:3, :, :] | |
color_change_rgb = color_change[:, 0:3, :, :] | |
output_rgb = color_change_rgb * alpha + image_rgb * (1 - alpha) | |
return torch.cat([output_rgb, image[:, 3:4, :, :]], dim=1) | |
def apply_grid_change(grid_change, image: Tensor) -> Tensor: | |
n, c, h, w = image.shape | |
device = grid_change.device | |
grid_change = torch.transpose(grid_change.view(n, 2, h * w), 1, 2).view(n, h, w, 2) | |
identity = torch.tensor( | |
[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], | |
dtype=grid_change.dtype, | |
device=device).unsqueeze(0).repeat(n, 1, 1) | |
base_grid = affine_grid(identity, [n, c, h, w], align_corners=False) | |
grid = base_grid + grid_change | |
resampled_image = grid_sample(image, grid, mode='bilinear', padding_mode='border', align_corners=False) | |
return resampled_image | |
class GridChangeApplier: | |
def __init__(self): | |
self.last_n = None | |
self.last_device = None | |
self.last_identity = None | |
def apply(self, grid_change: Tensor, image: Tensor, align_corners: bool = False) -> Tensor: | |
n, c, h, w = image.shape | |
device = grid_change.device | |
grid_change = torch.transpose(grid_change.view(n, 2, h * w), 1, 2).view(n, h, w, 2) | |
if n == self.last_n and device == self.last_device: | |
identity = self.last_identity | |
else: | |
identity = torch.tensor( | |
[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], | |
dtype=grid_change.dtype, | |
device=device, | |
requires_grad=False) \ | |
.unsqueeze(0).repeat(n, 1, 1) | |
self.last_identity = identity | |
self.last_n = n | |
self.last_device = device | |
base_grid = affine_grid(identity, [n, c, h, w], align_corners=align_corners) | |
grid = base_grid + grid_change | |
resampled_image = grid_sample(image, grid, mode='bilinear', padding_mode='border', align_corners=align_corners) | |
return resampled_image | |
def apply_color_change(alpha, color_change, image: Tensor) -> Tensor: | |
return color_change * alpha + image * (1 - alpha) | |