import numpy as np
import scipy.linalg
from scipy.spatial.transform import Rotation as R
import torch as th
import torch.nn as nn
import torch.nn.functional as F

from src.warping import GeometricTimeWarper, MonotoneTimeWarper
from src.utils import Net


class GeometricWarper(nn.Module):
    '''
    Warp a mono signal to two (left/right ear) channels using displacements derived
    from the transmitter mouth position and the receiver ear offsets.
    '''

    def __init__(self, sampling_rate=48000):
        super().__init__()
        self.warper = GeometricTimeWarper(sampling_rate=sampling_rate)

    def _transmitter_mouth(self, view):
        '''
        Rotate the fixed marker-to-mouth offset by the inverse of each frame's orientation.
        :param view: rx/tx position/orientation as tensor of shape B x 7 x K; channels 3:7 are
                     passed as-is to scipy's Rotation.from_quat (scalar-last (x, y, z, w) by default)
        :return: rotated mouth offset as tensor of shape B x 3 x K, on the same device as view
        '''
        # offset between tracking markers and real mouth position in the dataset
        mouth_offset = np.array([0.09, 0, -0.20])
        # B x 4 x K -> (B*K) x 4 numpy array; rotations are evaluated by scipy on the CPU
        quat = view[:, 3:, :].transpose(2, 1).contiguous().detach().cpu().view(-1, 4).numpy()
        # make sure zero-padded values are set to non-zero values (else scipy raises an exception)
        norms = scipy.linalg.norm(quat, axis=1)
        eps_val = (norms == 0).astype(np.float32)
        quat = quat + eps_val[:, None]
        transmitter_rot_mat = R.from_quat(quat)
        transmitter_mouth = transmitter_rot_mat.apply(mouth_offset, inverse=True)
        # create the result directly on view's device (a bare .cuda() would always pick the
        # current default GPU, breaking models placed on e.g. cuda:1)
        transmitter_mouth = th.as_tensor(transmitter_mouth, dtype=th.get_default_dtype(), device=view.device)
        return transmitter_mouth.view(view.shape[0], -1, 3).transpose(2, 1).contiguous()

    def _3d_displacements(self, view):
        '''
        Compute displacement vectors between the transmitter mouth and the receiver ears.
        :param view: rx/tx position/orientation as tensor of shape B x 7 x K
        :return: displacements as tensor of shape B x 2 x 3 x K (index 0: left ear, 1: right ear)
        '''
        transmitter_mouth = self._transmitter_mouth(view)
        # offset between tracking markers and ears in the dataset
        left_ear_offset = th.tensor([0, -0.08, -0.22], dtype=th.get_default_dtype(), device=view.device)
        right_ear_offset = th.tensor([0, 0.08, -0.22], dtype=th.get_default_dtype(), device=view.device)
        # compute displacements between transmitter mouth and receiver left/right ear
        displacement_left = view[:, 0:3, :] + transmitter_mouth - left_ear_offset[None, :, None]
        displacement_right = view[:, 0:3, :] + transmitter_mouth - right_ear_offset[None, :, None]
        return th.stack([displacement_left, displacement_right], dim=1)

    def _warpfield(self, view, seq_length):
        '''
        Convert the 3d displacements into a warpfield via GeometricTimeWarper.displacements2warpfield.
        :param view: rx/tx position/orientation as tensor of shape B x 7 x K
        :param seq_length: number of audio samples T the warpfield should cover
        :return: whatever displacements2warpfield returns (presumably B x 2 x T — confirm in src.warping)
        '''
        return self.warper.displacements2warpfield(self._3d_displacements(view), seq_length)

    def forward(self, mono, view):
        '''
        :param mono: input signal as tensor of shape B x 1 x T
        :param view: rx/tx position/orientation as tensor of shape B x 7 x K (K = T / 400)
        :return: warped: warped left/right ear signal as tensor of shape B x 2 x T
        '''
        return self.warper(th.cat([mono, mono], dim=1), self._3d_displacements(view))


class Warpnet(nn.Module):
    '''
    Warp a mono signal to left/right ear signals using the sum of a geometric warpfield
    and a learned (neural) correction predicted from the view.
    '''

    def __init__(self, layers=4, channels=64, view_dim=7):
        super().__init__()
        self.layers = nn.ModuleList(
            [nn.Conv1d(view_dim if i == 0 else channels, channels, kernel_size=2) for i in range(layers)]
        )
        self.linear = nn.Conv1d(channels, 2, kernel_size=1)
        self.neural_warper = MonotoneTimeWarper()
        self.geometric_warper = GeometricWarper()

    def neural_warpfield(self, view, seq_length):
        '''
        Predict a two-channel warpfield from the view and upsample it to the audio length.
        :param view: rx/tx position/orientation as tensor of shape B x view_dim x K
        :param seq_length: number of audio samples T
        :return: warpfield as tensor of shape B x 2 x seq_length
        '''
        warpfield = view
        for layer in self.layers:
            # left-pad by one frame so the kernel_size=2 convolution is causal and keeps length K
            warpfield = F.pad(warpfield, pad=[1, 0])
            warpfield = F.relu(layer(warpfield))
        warpfield = self.linear(warpfield)
        # upsample from K view frames to seq_length samples (default 'nearest' mode)
        return F.interpolate(warpfield, size=seq_length)

    def forward(self, mono, view):
        '''
        :param mono: input signal as tensor of shape B x 1 x T
        :param view: rx/tx position/orientation as tensor of shape B x 7 x K (K = T / 400)
        :return: warped: warped left/right ear signal as tensor of shape B x 2 x T
        '''
        geometric_warpfield = self.geometric_warper._warpfield(view, mono.shape[-1])
        neural_warpfield = self.neural_warpfield(view, mono.shape[-1])
        warpfield = geometric_warpfield + neural_warpfield
        # ensure causality: clamp the warpfield to values <= 0, i.e. min(warpfield, 0)
        warpfield = -F.relu(-warpfield)
        # the predicted warp
        return self.neural_warper(th.cat([mono, mono], dim=1), warpfield)


class BinauralNetwork(Net):
    '''
    Binaural synthesis network producing left/right ear signals from a mono input and the
    receiver/transmitter view. view_dim sets the input channels of the neural warp network;
    the geometric warper always reads channels 0:3 (position) and 3:7 (quaternion).
    '''

    def __init__(self,
                 view_dim=7,
                 warpnet_layers=4,
                 warpnet_channels=64,
                 model_name='binaural_network',
                 use_cuda=True
                 ):
        super().__init__(model_name, use_cuda)
        # forward view_dim so a non-default view size actually reaches the first conv layer
        self.warper = Warpnet(warpnet_layers, warpnet_channels, view_dim=view_dim)
        if self.use_cuda:
            self.cuda()

    def forward(self, mono, view):
        '''
        :param mono: the input signal as a B x 1 x T tensor
        :param view: the receiver/transmitter position/orientation as a B x view_dim x K tensor (K = T / 400)
        :return: warped: the warped left/right ear signal as a B x 2 x T tensor
        '''
        return self.warper(mono, view)