# Talking_Head_Anime_3: tha3/nn/image_processing_util.py
# Author: Harry_FBK. Commit 60094bd ("Clone original THA3"); file size 2.28 kB.
import torch
from torch import Tensor
from torch.nn.functional import affine_grid, grid_sample
def apply_rgb_change(alpha: Tensor, color_change: Tensor, image: Tensor):
    """Blend the RGB channels of ``color_change`` over ``image``, keeping ``image``'s alpha.

    Channels 0-2 of the result are ``alpha * color_change + (1 - alpha) * image``
    computed elementwise; channel 3 is copied unchanged from ``image``. Any
    channels of ``color_change`` past the first three are ignored.

    Args:
        alpha: Blend weight; must broadcast against ``image[:, 0:3, :, :]``
            (presumably shape (n, 1, h, w) -- TODO confirm with callers).
        color_change: Tensor indexed like ``image``; channels 0-2 supply the new colour.
        image: Tensor indexed with four dimensions, dim 1 being channels
            (0-2 treated as RGB, 3 as alpha).

    Returns:
        The blended RGB channels concatenated (along dim 1) with ``image[:, 3:4]``.
        If ``image`` has only three channels, only the blended RGB is returned.
    """
    blended = alpha * color_change[:, :3, :, :] + (1 - alpha) * image[:, :3, :, :]
    return torch.cat((blended, image[:, 3:4, :, :]), 1)
def apply_grid_change(grid_change: Tensor, image: Tensor, align_corners: bool = False) -> Tensor:
    """Warp ``image`` by adding per-pixel offsets to an identity sampling grid.

    Args:
        grid_change: Offsets with ``n * 2 * h * w`` elements laid out as
            ``(n, 2, h, w)``. Plane 0 is added to the x sampling coordinate and
            plane 1 to the y coordinate, both in the normalized [-1, 1]
            coordinates produced by ``affine_grid``. Must support ``.view``
            (e.g. be contiguous).
        image: Batch of shape ``(n, c, h, w)`` to resample.
        align_corners: Forwarded to both ``affine_grid`` and ``grid_sample``.
            Defaults to ``False``, the value this function previously
            hard-coded, so existing callers are unaffected.

    Returns:
        ``image`` bilinearly resampled at the displaced grid, shape
        ``(n, c, h, w)``. Samples falling outside the image take the nearest
        border value (``padding_mode='border'``).
    """
    n, c, h, w = image.shape
    device = grid_change.device
    # (n, 2, h, w) -> (n, h, w, 2): grid_sample expects the (x, y) pair in the last dim.
    grid_change = torch.transpose(grid_change.view(n, 2, h * w), 1, 2).view(n, h, w, 2)
    # One 2x3 identity affine matrix per batch element -> the undisplaced sampling grid.
    identity = torch.tensor(
        [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
        dtype=grid_change.dtype,
        device=device).unsqueeze(0).repeat(n, 1, 1)
    base_grid = affine_grid(identity, [n, c, h, w], align_corners=align_corners)
    grid = base_grid + grid_change
    return grid_sample(image, grid, mode='bilinear', padding_mode='border', align_corners=align_corners)
class GridChangeApplier:
    """Warp images by per-pixel grid offsets, reusing the batched identity matrix.

    The ``(n, 2, 3)`` identity affine matrix built by ``apply`` is cached and
    reused on the next call when the batch size, device and dtype all match.
    """

    def __init__(self):
        # Identity matrix from the most recent call, plus the (batch size,
        # device, dtype) it was built for; all None until the first call.
        self.last_n = None
        self.last_device = None
        self.last_dtype = None
        self.last_identity = None

    def apply(self, grid_change: Tensor, image: Tensor, align_corners: bool = False) -> Tensor:
        """Resample ``image`` at an identity grid displaced by ``grid_change``.

        Args:
            grid_change: Offsets with ``n * 2 * h * w`` elements laid out as
                ``(n, 2, h, w)``. Plane 0 is added to the x sampling coordinate
                and plane 1 to the y coordinate, in the normalized [-1, 1]
                coordinates produced by ``affine_grid``. Must support ``.view``.
            image: Batch of shape ``(n, c, h, w)``.
            align_corners: Forwarded to both ``affine_grid`` and ``grid_sample``.

        Returns:
            Bilinearly resampled image of shape ``(n, c, h, w)``; out-of-range
            samples take the nearest border value.
        """
        n, c, h, w = image.shape
        device = grid_change.device
        dtype = grid_change.dtype
        # (n, 2, h, w) -> (n, h, w, 2): grid_sample expects (x, y) in the last dim.
        grid_change = torch.transpose(grid_change.view(n, 2, h * w), 1, 2).view(n, h, w, 2)
        # dtype must be part of the cache key: a stale identity of another dtype
        # would give base_grid the wrong dtype, which grid_sample rejects (or
        # silently promotes) when it no longer matches ``image``.
        if n == self.last_n and device == self.last_device and dtype == self.last_dtype:
            identity = self.last_identity
        else:
            identity = torch.tensor(
                [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
                dtype=dtype,
                device=device,
                requires_grad=False) \
                .unsqueeze(0).repeat(n, 1, 1)
            self.last_identity = identity
            self.last_n = n
            self.last_device = device
            self.last_dtype = dtype
        base_grid = affine_grid(identity, [n, c, h, w], align_corners=align_corners)
        grid = base_grid + grid_change
        resampled_image = grid_sample(image, grid, mode='bilinear', padding_mode='border', align_corners=align_corners)
        return resampled_image
def apply_color_change(alpha, color_change, image: Tensor) -> Tensor:
    """Linearly blend ``color_change`` over ``image`` with weight ``alpha``.

    Computes ``alpha * color_change + (1 - alpha) * image`` elementwise across
    all channels, following standard broadcasting rules. Unlike
    ``apply_rgb_change``, no channel is left untouched.
    """
    keep_weight = 1 - alpha
    return alpha * color_change + keep_weight * image