Spaces:
Build error
Build error
""" | |
Copyright (c) Facebook, Inc. and its affiliates.
All rights reserved.
This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
""" | |
import torch as th | |
import torch.nn as nn | |
import torch.nn.functional as F | |
class TimeWarperFunction(th.autograd.Function):
    """Differentiable linear-interpolation lookup of a signal at fractional positions.

    ``forward`` reads ``input`` along its last axis (dim 2) at the absolute,
    possibly fractional sample positions given by ``warpfield``. ``backward``
    returns gradients for both the signal and the warpfield.
    """

    @staticmethod
    def _interpolation_indices(input, warpfield):
        '''
        Compute the neighbouring integer sample indices and the interpolation weight.

        :param input: signal that will be indexed along dim 2 (length T on the last axis)
        :param warpfield: absolute fractional positions; expected to lie in [0, T-1]
        :return: (idx_left, idx_right, alpha) where alpha = warpfield - floor(warpfield)
        '''
        floor = warpfield.floor()
        idx_left = floor.type(th.long)
        # guard the right neighbour against running past the last sample
        idx_right = th.clamp(warpfield.ceil().type(th.long), max=input.shape[-1] - 1)
        alpha = warpfield - floor
        return idx_left, idx_right, alpha

    @staticmethod
    def forward(ctx, input, warpfield):
        '''
        :param ctx: autograd context
        :param input: input signal (B x 2 x T)
        :param warpfield: absolute (fractional) sample positions to read from (B x 2 x T)
        :return: the warped signal (B x 2 x T)
        '''
        ctx.save_for_backward(input, warpfield)
        idx_left, idx_right, alpha = TimeWarperFunction._interpolation_indices(input, warpfield)
        # linear interpolation between the two neighbouring samples
        return (1 - alpha) * th.gather(input, 2, idx_left) + alpha * th.gather(input, 2, idx_right)

    @staticmethod
    def backward(ctx, grad_output):
        '''
        :param ctx: autograd context holding (input, warpfield) from forward
        :param grad_output: gradient w.r.t. the warped signal
        :return: (grad_input, grad_warpfield); an entry is None when that input needs no grad
        '''
        input, warpfield = ctx.saved_tensors
        idx_left, idx_right, alpha = TimeWarperFunction._interpolation_indices(input, warpfield)
        grad_input = grad_warpfield = None
        if ctx.needs_input_grad[0]:
            # Each output sample pushes its gradient back onto its two source samples.
            # Allocate in grad_output's dtype so scatter_add's src/self dtypes agree
            # (a hard-coded float32 buffer breaks for float64 inputs); autograd casts
            # the result back to input's dtype if they differ.
            grad_input = th.zeros_like(input, dtype=grad_output.dtype)
            grad_input = grad_input.scatter_add(2, idx_left, grad_output * (1 - alpha))
            grad_input = grad_input.scatter_add(2, idx_right, grad_output * alpha)
        if ctx.needs_input_grad[1]:
            # d output / d warpfield = input[right] - input[left]; zero at integer positions
            slope = th.gather(input, 2, idx_right) - th.gather(input, 2, idx_left)
            grad_warpfield = grad_output * slope
        return grad_input, grad_warpfield
class TimeWarper(nn.Module):
    """Warp a signal in time by reading it at shifted, fractional sample positions."""

    def __init__(self):
        super().__init__()
        # Autograd Functions are invoked through the class itself; instantiating one
        # triggers a deprecation warning in recent PyTorch releases.
        self.warper = TimeWarperFunction.apply

    def _to_absolute_positions(self, warpfield, seq_length):
        '''
        Translate a relative warpfield into absolute sample positions.

        :param warpfield: per-sample offsets relative to each sample's own index;
            broadcast against a (1 x 1 x seq_length) index range
        :param seq_length: number of samples T along the last axis
        :return: warpfield + [0, 1, ..., T-1], clamped to the valid range [0, T-1]
        '''
        # Build the index range directly on warpfield's device rather than the default
        # CUDA device, so tensors living on e.g. cuda:1 do not hit a device mismatch.
        temp_range = th.arange(seq_length, dtype=th.float, device=warpfield.device)
        return th.clamp(warpfield + temp_range[None, None, :], min=0, max=seq_length - 1)

    def forward(self, input, warpfield):
        '''
        :param input: audio signal to be warped (B x 2 x T)
        :param warpfield: the corresponding relative warpfield (B x 2 x T)
        :return: the warped signal (B x 2 x T)
        '''
        warpfield = self._to_absolute_positions(warpfield, input.shape[-1])
        return self.warper(input, warpfield)
class MonotoneTimeWarper(TimeWarper):
    """TimeWarper that never lets the read position move backwards in time."""

    def forward(self, input, warpfield):
        '''
        :param input: audio signal to be warped (B x 2 x T)
        :param warpfield: the corresponding relative warpfield (B x 2 x T)
        :return: the warped signal (B x 2 x T), read at monotonically non-decreasing positions
        '''
        positions = self._to_absolute_positions(warpfield, input.shape[-1])
        # A running maximum along time makes the absolute positions non-decreasing,
        # which in relative terms means warp[t] >= warp[t-1] - 1.
        positions, _ = th.cummax(positions, dim=-1)
        return self.warper(input, positions)
class GeometricTimeWarper(TimeWarper):
    """TimeWarper whose warpfield is derived from 3D displacement vectors.

    Displacement lengths are turned into sample offsets via
    ``-distance / speed_of_sound * sampling_rate``. These offsets are then applied
    through TimeWarper's relative-warpfield machinery.
    """

    def __init__(self, sampling_rate=48000, speed_of_sound=343.0):
        '''
        :param sampling_rate: samples per second of the signal being warped
        :param speed_of_sound: propagation speed used to convert distances into delays.
            The default 343.0 is presumably metres per second in air, which would imply
            displacements are given in metres -- TODO confirm against callers.
        '''
        super().__init__()
        self.sampling_rate = sampling_rate
        self.speed_of_sound = speed_of_sound

    def displacements2warpfield(self, displacements, seq_length):
        '''
        Convert displacement vectors into a relative warpfield of length seq_length.

        :param displacements: displacement vectors with the 3D coordinate axis at dim 2.
            NOTE(review): the norm over dim 2 followed by F.interpolate (which needs a
            3D+ tensor) implies a 4D layout, presumably (B x 2 x 3 x K) so the result
            matches a (B x 2 x T) signal; a (B x 3 x T) tensor cannot work here -- confirm.
        :param seq_length: target number of samples T on the last axis
        :return: relative warpfield (non-positive offsets in samples), last axis of length T
        '''
        # Euclidean length of each displacement. Unlike sqrt(sum(x**2)), vector_norm
        # yields a zero (not NaN) gradient for an all-zero displacement.
        distance = th.linalg.vector_norm(displacements, ord=2, dim=2)
        # resample along the last axis to the signal length (F.interpolate default mode)
        distance = F.interpolate(distance, size=seq_length)
        # negative offsets make every output sample read from an earlier input position
        return -distance / self.speed_of_sound * self.sampling_rate

    def forward(self, input, displacements):
        '''
        :param input: audio signal to be warped (B x 2 x T)
        :param displacements: sequence of 3D displacement vectors for geometric warping
            (see displacements2warpfield for the layout the code actually requires)
        :return: the warped signal (B x 2 x T)
        '''
        warpfield = self.displacements2warpfield(displacements, input.shape[-1])
        return super().forward(input, warpfield)