Spaces:
Build error
Build error
Upload 12 files
Browse files- mono2binaural/src/__pycache__/models.cpython-38.pyc +0 -0
- mono2binaural/src/__pycache__/utils.cpython-38.pyc +0 -0
- mono2binaural/src/__pycache__/warping.cpython-38.pyc +0 -0
- mono2binaural/src/models.py +110 -0
- mono2binaural/src/utils.py +251 -0
- mono2binaural/src/warping.py +113 -0
- mono2binaural/useful_ckpts/m2b/binaural_network.net +0 -0
- mono2binaural/useful_ckpts/m2b/tx_positions.txt +0 -0
- mono2binaural/useful_ckpts/m2b/tx_positions2.txt +0 -0
- mono2binaural/useful_ckpts/m2b/tx_positions3.txt +0 -0
- mono2binaural/useful_ckpts/m2b/tx_positions4.txt +0 -0
- mono2binaural/useful_ckpts/m2b/tx_positions5.txt +0 -0
mono2binaural/src/__pycache__/models.cpython-38.pyc
ADDED
Binary file (5.12 kB). View file
|
|
mono2binaural/src/__pycache__/utils.cpython-38.pyc
ADDED
Binary file (2.54 kB). View file
|
|
mono2binaural/src/__pycache__/warping.cpython-38.pyc
ADDED
Binary file (4.47 kB). View file
|
|
mono2binaural/src/models.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import scipy.linalg
|
3 |
+
from scipy.spatial.transform import Rotation as R
|
4 |
+
import torch as th
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from src.warping import GeometricTimeWarper, MonotoneTimeWarper
|
8 |
+
from src.utils import Net
|
9 |
+
|
10 |
+
|
11 |
+
class GeometricWarper(nn.Module):
    '''
    Warps a mono signal into a left/right ear signal using only the geometry
    (transmitter mouth position vs. receiver ear positions) encoded in the view.
    The time warping itself is delegated to GeometricTimeWarper.
    '''

    def __init__(self, sampling_rate=48000):
        '''
        :param sampling_rate: sampling rate of the audio, forwarded to GeometricTimeWarper
        '''
        super().__init__()
        self.warper = GeometricTimeWarper(sampling_rate=sampling_rate)

    def _transmitter_mouth(self, view):
        '''
        :param view: tensor of shape B x 7 x K; channels 3..6 are read as one quaternion per frame
        :return: the mouth offset rotated by the inverse of each frame's rotation,
                 as a B x 3 x K tensor on the same device as view
        '''
        # offset between tracking markers and real mouth position in the dataset
        mouth_offset = np.array([0.09, 0, -0.20])
        # scipy runs on the CPU: flatten all frames into an (B*K) x 4 numpy array
        quat = view[:, 3:, :].transpose(2, 1).contiguous().detach().cpu().view(-1, 4).numpy()
        # make sure zero-padded values are set to non-zero values (else scipy raises an exception)
        norms = scipy.linalg.norm(quat, axis=1)
        eps_val = (norms == 0).astype(np.float32)
        quat = quat + eps_val[:, None]
        # NOTE(review): scipy's from_quat defaults to scalar-last (x, y, z, w) order - confirm the dataset matches
        transmitter_rot_mat = R.from_quat(quat)
        transmitter_mouth = transmitter_rot_mat.apply(mouth_offset, inverse=True)
        # allocate directly on view's device; a bare .cuda() would pick the *current* CUDA
        # device, which is wrong when view lives on a non-default GPU
        transmitter_mouth = th.as_tensor(transmitter_mouth, dtype=th.get_default_dtype(), device=view.device)
        return transmitter_mouth.view(view.shape[0], -1, 3).transpose(2, 1).contiguous()

    def _3d_displacements(self, view):
        '''
        :param view: tensor of shape B x 7 x K; channels 0..2 are read as a 3D position,
                     channels 3..6 as a quaternion
        :return: displacement vectors for left and right ear stacked as a B x 2 x 3 x K tensor
        '''
        transmitter_mouth = self._transmitter_mouth(view)
        # offset between tracking markers and ears in the dataset
        left_ear_offset = th.tensor([0, -0.08, -0.22], dtype=th.get_default_dtype(), device=view.device)
        right_ear_offset = th.tensor([0, 0.08, -0.22], dtype=th.get_default_dtype(), device=view.device)
        # compute displacements between transmitter mouth and receiver left/right ear
        displacement_left = view[:, 0:3, :] + transmitter_mouth - left_ear_offset[None, :, None]
        displacement_right = view[:, 0:3, :] + transmitter_mouth - right_ear_offset[None, :, None]
        displacement = th.stack([displacement_left, displacement_right], dim=1)
        return displacement

    def _warpfield(self, view, seq_length):
        '''
        :param view: rx/tx position/orientation as tensor of shape B x 7 x K
        :param seq_length: number of audio samples the warpfield is produced for
        :return: the geometric warpfield computed by GeometricTimeWarper.displacements2warpfield
        '''
        return self.warper.displacements2warpfield(self._3d_displacements(view), seq_length)

    def forward(self, mono, view):
        '''
        :param mono: input signal as tensor of shape B x 1 x T
        :param view: rx/tx position/orientation as tensor of shape B x 7 x K (K = T / 400)
        :return: warped: warped left/right ear signal as tensor of shape B x 2 x T
        '''
        # the same mono signal is warped once per ear
        return self.warper(th.cat([mono, mono], dim=1), self._3d_displacements(view))
|
52 |
+
|
53 |
+
|
54 |
+
class Warpnet(nn.Module):
    '''
    Combines a geometric warpfield with a learned correction predicted from the view
    and applies the (causal, monotone) result to the mono signal.
    '''

    def __init__(self, layers=4, channels=64, view_dim=7):
        '''
        :param layers: number of kernel-size-2 convolutions in the warpfield predictor
        :param channels: number of channels of every hidden convolution
        :param view_dim: number of channels of the view tensor fed to the first convolution
        '''
        super().__init__()
        convs = []
        in_channels = view_dim
        for _ in range(layers):
            convs.append(nn.Conv1d(in_channels, channels, kernel_size=2))
            in_channels = channels
        self.layers = nn.ModuleList(convs)
        # projects the hidden features to one warp value per ear
        self.linear = nn.Conv1d(channels, 2, kernel_size=1)
        self.neural_warper = MonotoneTimeWarper()
        self.geometric_warper = GeometricWarper()

    def neural_warpfield(self, view, seq_length):
        '''
        :param view: rx/tx position/orientation as tensor of shape B x view_dim x K
        :param seq_length: number of audio samples to upsample the warpfield to
        :return: learned warpfield as tensor of shape B x 2 x seq_length
        '''
        hidden = view
        for conv in self.layers:
            # one frame of left padding keeps the length at K and makes the convolution causal
            hidden = F.relu(conv(F.pad(hidden, pad=[1, 0])))
        return F.interpolate(self.linear(hidden), size=seq_length)

    def forward(self, mono, view):
        '''
        :param mono: input signal as tensor of shape B x 1 x T
        :param view: rx/tx position/orientation as tensor of shape B x 7 x K (K = T / 400)
        :return: warped: warped left/right ear signal as tensor of shape B x 2 x T
        '''
        seq_length = mono.shape[-1]
        geometric_part = self.geometric_warper._warpfield(view, seq_length)
        neural_part = self.neural_warpfield(view, seq_length)
        # ensure causality: positive warps (reading from the future) are cut to zero
        warpfield = -F.relu(-(geometric_part + neural_part))
        stereo = th.cat([mono, mono], dim=1)
        return self.neural_warper(stereo, warpfield)
|
85 |
+
|
86 |
+
class BinauralNetwork(Net):
    '''
    Network that turns a mono signal into a two-channel (left/right ear) signal
    by warping it with a Warpnet.
    '''

    def __init__(self,
                 view_dim=7,
                 warpnet_layers=4,
                 warpnet_channels=64,
                 model_name='binaural_network',
                 use_cuda=True):
        '''
        :param view_dim: number of channels of the view tensor, forwarded to Warpnet
        :param warpnet_layers: number of convolutional layers in the Warpnet
        :param warpnet_channels: number of channels of each Warpnet layer
        :param model_name: name passed to the Net base class
        :param use_cuda: passed to the Net base class; if set, the model is moved to the GPU
        '''
        super().__init__(model_name, use_cuda)
        # view_dim must reach the Warpnet, otherwise a non-default value is silently ignored
        self.warper = Warpnet(warpnet_layers, warpnet_channels, view_dim)
        if self.use_cuda:
            self.cuda()

    def forward(self, mono, view):
        '''
        :param mono: the input signal as a B x 1 x T tensor
        :param view: the receiver/transmitter position/orientation as a B x 7 x K tensor (K = T / 400)
        :return: warped: the warped left/right ear signal as a B x 2 x T tensor
        '''
        return self.warper(mono, view)
|
mono2binaural/src/utils.py
ADDED
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) Facebook, Inc. and its affiliates.
|
3 |
+
All rights reserved.
|
4 |
+
|
5 |
+
This source code is licensed under the license found in the
|
6 |
+
LICENSE file in the root directory of this source tree.
|
7 |
+
"""
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import torch as th
|
11 |
+
#import torchaudio as ta
|
12 |
+
|
13 |
+
|
14 |
+
class Net(th.nn.Module):
    '''
    Base class for networks that adds helpers to save/load the state dict
    under model_dir/model_name[.suffix].net.
    '''

    def __init__(self, model_name="network", use_cuda=True):
        '''
        :param model_name: name used to build the checkpoint file name
        :param use_cuda: whether the model is expected to live on the GPU
        '''
        super().__init__()
        self.use_cuda = use_cuda
        self.model_name = model_name

    def _checkpoint_path(self, model_dir, suffix):
        '''
        :param model_dir: directory containing the checkpoint
        :param suffix: suffix to append after model name (empty string for none)
        :return: model_dir/model_name.net or model_dir/model_name.suffix.net
        '''
        if suffix == "":
            return f"{model_dir}/{self.model_name}.net"
        return f"{model_dir}/{self.model_name}.{suffix}.net"

    def save(self, model_dir, suffix=''):
        '''
        save the network to model_dir/model_name.suffix.net
        :param model_dir: directory to save the model to
        :param suffix: suffix to append after model name
        '''
        fname = self._checkpoint_path(model_dir, suffix)
        # parameters are written from the CPU so the checkpoint holds CPU tensors
        if self.use_cuda:
            self.cpu()
        try:
            th.save(self.state_dict(), fname)
        finally:
            # move the model back even if writing failed
            if self.use_cuda:
                self.cuda()

    def load_from_file(self, model_file):
        '''
        load network parameters from model_file
        :param model_file: file containing the model parameters
        '''
        if self.use_cuda:
            self.cpu()
        try:
            # NOTE: th.load unpickles the file; only load checkpoints from trusted sources.
            # map_location="cpu" lets checkpoints that contain CUDA tensors load on CPU-only hosts.
            states = th.load(model_file, map_location="cpu")
            self.load_state_dict(states)
        finally:
            # move the model back even if loading failed
            if self.use_cuda:
                self.cuda()
        print(f"Loaded: {model_file}")

    def load(self, model_dir, suffix=''):
        '''
        load network parameters from model_dir/model_name.suffix.net
        :param model_dir: directory to load the model from
        :param suffix: suffix to append after model name
        '''
        self.load_from_file(self._checkpoint_path(model_dir, suffix))

    def num_trainable_parameters(self):
        '''
        :return: the number of trainable parameters in the model
        '''
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
|
71 |
+
|
72 |
+
|
73 |
+
# class NewbobAdam(th.optim.Adam):
|
74 |
+
|
75 |
+
# def __init__(self,
|
76 |
+
# weights,
|
77 |
+
# net,
|
78 |
+
# artifacts_dir,
|
79 |
+
# initial_learning_rate=0.001,
|
80 |
+
# decay=0.5,
|
81 |
+
# max_decay=0.01
|
82 |
+
# ):
|
83 |
+
# '''
|
84 |
+
# Newbob learning rate scheduler
|
85 |
+
# :param weights: weights to optimize
|
86 |
+
# :param net: the network, must be an instance of type src.utils.Net
|
87 |
+
# :param artifacts_dir: (str) directory to save/restore models to/from
|
88 |
+
# :param initial_learning_rate: (float) initial learning rate
|
89 |
+
# :param decay: (float) value to decrease learning rate by when loss doesn't improve further
|
90 |
+
# :param max_decay: (float) maximum decay of learning rate
|
91 |
+
# '''
|
92 |
+
# super().__init__(weights, lr=initial_learning_rate)
|
93 |
+
# self.last_epoch_loss = np.inf
|
94 |
+
# self.total_decay = 1
|
95 |
+
# self.net = net
|
96 |
+
# self.decay = decay
|
97 |
+
# self.max_decay = max_decay
|
98 |
+
# self.artifacts_dir = artifacts_dir
|
99 |
+
# # store initial state as backup
|
100 |
+
# if decay < 1.0:
|
101 |
+
# net.save(artifacts_dir, suffix="newbob")
|
102 |
+
|
103 |
+
# def update_lr(self, loss):
|
104 |
+
# '''
|
105 |
+
# update the learning rate based on the current loss value and historic loss values
|
106 |
+
# :param loss: the loss after the current iteration
|
107 |
+
# '''
|
108 |
+
# if loss > self.last_epoch_loss and self.decay < 1.0 and self.total_decay > self.max_decay:
|
109 |
+
# self.total_decay = self.total_decay * self.decay
|
110 |
+
# print(f"NewbobAdam: Decay learning rate (loss degraded from {self.last_epoch_loss} to {loss})."
|
111 |
+
# f"Total decay: {self.total_decay}")
|
112 |
+
# # restore previous network state
|
113 |
+
# self.net.load(self.artifacts_dir, suffix="newbob")
|
114 |
+
# # decrease learning rate
|
115 |
+
# for param_group in self.param_groups:
|
116 |
+
# param_group['lr'] = param_group['lr'] * self.decay
|
117 |
+
# else:
|
118 |
+
# self.last_epoch_loss = loss
|
119 |
+
# # save last snapshot to restore it in case of lr decrease
|
120 |
+
# if self.decay < 1.0 and self.total_decay > self.max_decay:
|
121 |
+
# self.net.save(self.artifacts_dir, suffix="newbob")
|
122 |
+
|
123 |
+
|
124 |
+
# class FourierTransform:
|
125 |
+
# def __init__(self,
|
126 |
+
# fft_bins=2048,
|
127 |
+
# win_length_ms=40,
|
128 |
+
# frame_rate_hz=100,
|
129 |
+
# causal=False,
|
130 |
+
# preemphasis=0.0,
|
131 |
+
# sample_rate=48000,
|
132 |
+
# normalized=False):
|
133 |
+
# self.sample_rate = sample_rate
|
134 |
+
# self.frame_rate_hz = frame_rate_hz
|
135 |
+
# self.preemphasis = preemphasis
|
136 |
+
# self.fft_bins = fft_bins
|
137 |
+
# self.win_length = int(sample_rate * win_length_ms / 1000)
|
138 |
+
# self.hop_length = int(sample_rate / frame_rate_hz)
|
139 |
+
# self.causal = causal
|
140 |
+
# self.normalized = normalized
|
141 |
+
# if self.win_length > self.fft_bins:
|
142 |
+
# print('FourierTransform Warning: fft_bins should be larger than win_length')
|
143 |
+
|
144 |
+
# def _convert_format(self, data, expected_dims):
|
145 |
+
# if not type(data) == th.Tensor:
|
146 |
+
# data = th.Tensor(data)
|
147 |
+
# if len(data.shape) < expected_dims:
|
148 |
+
# data = data.unsqueeze(0)
|
149 |
+
# if not len(data.shape) == expected_dims:
|
150 |
+
# raise Exception(f"FourierTransform: data needs to be a Tensor with {expected_dims} dimensions but got shape {data.shape}")
|
151 |
+
# return data
|
152 |
+
|
153 |
+
# def _preemphasis(self, audio):
|
154 |
+
# if self.preemphasis > 0:
|
155 |
+
# return th.cat((audio[:, 0:1], audio[:, 1:] - self.preemphasis * audio[:, :-1]), dim=1)
|
156 |
+
# return audio
|
157 |
+
|
158 |
+
# def _revert_preemphasis(self, audio):
|
159 |
+
# if self.preemphasis > 0:
|
160 |
+
# for i in range(1, audio.shape[1]):
|
161 |
+
# audio[:, i] = audio[:, i] + self.preemphasis * audio[:, i-1]
|
162 |
+
# return audio
|
163 |
+
|
164 |
+
# def _magphase(self, complex_stft):
|
165 |
+
# mag, phase = ta.functional.magphase(complex_stft, 1.0)
|
166 |
+
# return mag, phase
|
167 |
+
|
168 |
+
# def stft(self, audio):
|
169 |
+
# '''
|
170 |
+
# wrapper around th.stft
|
171 |
+
# audio: wave signal as th.Tensor
|
172 |
+
# '''
|
173 |
+
# hann = th.hann_window(self.win_length)
|
174 |
+
# hann = hann.cuda() if audio.is_cuda else hann
|
175 |
+
# spec = th.stft(audio, n_fft=self.fft_bins, hop_length=self.hop_length, win_length=self.win_length,
|
176 |
+
# window=hann, center=not self.causal, normalized=self.normalized)
|
177 |
+
# return spec.contiguous()
|
178 |
+
|
179 |
+
# def complex_spectrogram(self, audio):
|
180 |
+
# '''
|
181 |
+
# audio: wave signal as th.Tensor
|
182 |
+
# return: th.Tensor of size channels x frequencies x time_steps (channels x y_axis x x_axis)
|
183 |
+
# '''
|
184 |
+
# self._convert_format(audio, expected_dims=2)
|
185 |
+
# audio = self._preemphasis(audio)
|
186 |
+
# return self.stft(audio)
|
187 |
+
|
188 |
+
# def magnitude_phase(self, audio):
|
189 |
+
# '''
|
190 |
+
# audio: wave signal as th.Tensor
|
191 |
+
# return: tuple containing two th.Tensor of size channels x frequencies x time_steps for magnitude and phase spectrum
|
192 |
+
# '''
|
193 |
+
# stft = self.complex_spectrogram(audio)
|
194 |
+
# return self._magphase(stft)
|
195 |
+
|
196 |
+
# def mag_spectrogram(self, audio):
|
197 |
+
# '''
|
198 |
+
# audio: wave signal as th.Tensor
|
199 |
+
# return: magnitude spectrum as th.Tensor of size channels x frequencies x time_steps for magnitude and phase spectrum
|
200 |
+
# '''
|
201 |
+
# return self.magnitude_phase(audio)[0]
|
202 |
+
|
203 |
+
# def power_spectrogram(self, audio):
|
204 |
+
# '''
|
205 |
+
# audio: wave signal as th.Tensor
|
206 |
+
# return: power spectrum as th.Tensor of size channels x frequencies x time_steps for magnitude and phase spectrum
|
207 |
+
# '''
|
208 |
+
# return th.pow(self.mag_spectrogram(audio), 2.0)
|
209 |
+
|
210 |
+
# def phase_spectrogram(self, audio):
|
211 |
+
# '''
|
212 |
+
# audio: wave signal as th.Tensor
|
213 |
+
# return: phase spectrum as th.Tensor of size channels x frequencies x time_steps for magnitude and phase spectrum
|
214 |
+
# '''
|
215 |
+
# return self.magnitude_phase(audio)[1]
|
216 |
+
|
217 |
+
# def mel_spectrogram(self, audio, n_mels):
|
218 |
+
# '''
|
219 |
+
# audio: wave signal as th.Tensor
|
220 |
+
# n_mels: number of bins used for mel scale warping
|
221 |
+
# return: mel spectrogram as th.Tensor of size channels x n_mels x time_steps for magnitude and phase spectrum
|
222 |
+
# '''
|
223 |
+
# spec = self.power_spectrogram(audio)
|
224 |
+
# mel_warping = ta.transforms.MelScale(n_mels, self.sample_rate)
|
225 |
+
# return mel_warping(spec)
|
226 |
+
|
227 |
+
# def complex_spec2wav(self, complex_spec, length):
|
228 |
+
# '''
|
229 |
+
# inverse stft
|
230 |
+
# complex_spec: complex spectrum as th.Tensor of size channels x frequencies x time_steps x 2 (real part/imaginary part)
|
231 |
+
# length: length of the audio to be reconstructed (in frames)
|
232 |
+
# '''
|
233 |
+
# complex_spec = self._convert_format(complex_spec, expected_dims=4)
|
234 |
+
# hann = th.hann_window(self.win_length)
|
235 |
+
# hann = hann.cuda() if complex_spec.is_cuda else hann
|
236 |
+
# wav = ta.functional.istft(complex_spec, n_fft=self.fft_bins, hop_length=self.hop_length, win_length=self.win_length, window=hann, length=length, center=not self.causal)
|
237 |
+
# wav = self._revert_preemphasis(wav)
|
238 |
+
# return wav
|
239 |
+
|
240 |
+
# def magphase2wav(self, mag_spec, phase_spec, length):
|
241 |
+
# '''
|
242 |
+
# reconstruction of wav signal from magnitude and phase spectrum
|
243 |
+
# mag_spec: magnitude spectrum as th.Tensor of size channels x frequencies x time_steps
|
244 |
+
# phase_spec: phase spectrum as th.Tensor of size channels x frequencies x time_steps
|
245 |
+
# length: length of the audio to be reconstructed (in frames)
|
246 |
+
# '''
|
247 |
+
# mag_spec = self._convert_format(mag_spec, expected_dims=3)
|
248 |
+
# phase_spec = self._convert_format(phase_spec, expected_dims=3)
|
249 |
+
# complex_spec = th.stack([mag_spec * th.cos(phase_spec), mag_spec * th.sin(phase_spec)], dim=-1)
|
250 |
+
# return self.complex_spec2wav(complex_spec, length)
|
251 |
+
|
mono2binaural/src/warping.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) Facebook, Inc. and its affiliates.
|
3 |
+
All rights reserved.
|
4 |
+
|
5 |
+
This source code is licensed under the license found in the
|
6 |
+
LICENSE file in the root directory of this source tree.
|
7 |
+
"""
|
8 |
+
|
9 |
+
import torch as th
|
10 |
+
import torch.nn as nn
|
11 |
+
import torch.nn.functional as F
|
12 |
+
|
13 |
+
|
14 |
+
class TimeWarperFunction(th.autograd.Function):
    '''
    Differentiable lookup of a signal at fractional positions using linear
    interpolation between the two neighbouring samples.
    '''

    @staticmethod
    def _interpolation_terms(input, warpfield):
        '''
        :param input: input signal (B x 2 x T)
        :param warpfield: absolute fractional sample positions (B x 2 x T)
        :return: (idx_left, idx_right, alpha): indices of the neighbouring samples
                 and the interpolation weight of the right neighbour
        '''
        lower = warpfield.floor()
        idx_left = lower.type(th.long)
        # ceil may point one past the end when the position is not clamped; keep it in range
        idx_right = th.clamp(warpfield.ceil().type(th.long), max=input.shape[-1]-1)
        alpha = warpfield - lower
        return idx_left, idx_right, alpha

    @staticmethod
    def forward(ctx, input, warpfield):
        '''
        :param ctx: autograd context
        :param input: input signal (B x 2 x T)
        :param warpfield: the corresponding warpfield (B x 2 x T)
        :return: the warped signal (B x 2 x T)
        '''
        ctx.save_for_backward(input, warpfield)
        idx_left, idx_right, alpha = TimeWarperFunction._interpolation_terms(input, warpfield)
        # linear interpolation
        output = (1 - alpha) * th.gather(input, 2, idx_left) + alpha * th.gather(input, 2, idx_right)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        '''
        :param ctx: autograd context holding input and warpfield
        :param grad_output: gradient w.r.t. the warped signal (B x 2 x T)
        :return: (grad_input, grad_warpfield)
        '''
        input, warpfield = ctx.saved_tensors
        idx_left, idx_right, alpha = TimeWarperFunction._interpolation_terms(input, warpfield)
        # warpfield gradient: slope of the interpolation line between both neighbours
        grad_warpfield = th.gather(input, 2, idx_right) - th.gather(input, 2, idx_left)
        grad_warpfield = (grad_output * grad_warpfield).to(warpfield.dtype)
        # input gradient: zeros_like keeps dtype and device of input (th.zeros(shape) is
        # always float32 and makes scatter_add fail for e.g. float64 signals)
        grad_input = th.zeros_like(input)
        grad_input = grad_input.scatter_add(2, idx_left, (grad_output * (1 - alpha)).to(input.dtype))
        grad_input = grad_input.scatter_add(2, idx_right, (grad_output * alpha).to(input.dtype))
        return grad_input, grad_warpfield
|
49 |
+
|
50 |
+
|
51 |
+
class TimeWarper(nn.Module):
    '''
    Warps a signal along time according to a warpfield of relative sample offsets.
    '''

    def __init__(self):
        super().__init__()
        # apply is a class-level entry point; no Function instance is needed
        self.warper = TimeWarperFunction.apply

    def _to_absolute_positions(self, warpfield, seq_length):
        '''
        :param warpfield: relative warp offsets; the last dimension must have length seq_length
        :param seq_length: length T of the signal to be warped
        :return: absolute (fractional) sample positions, clamped to [0, seq_length-1]
        '''
        # translate warpfield from relative warp indices to absolute indices ([0...T-1] + warpfield);
        # the range is created on warpfield's own device so non-default GPUs work as well
        temp_range = th.arange(seq_length, dtype=th.float, device=warpfield.device)
        return th.clamp(warpfield + temp_range[None, None, :], min=0, max=seq_length-1)

    def forward(self, input, warpfield):
        '''
        :param input: audio signal to be warped (B x 2 x T)
        :param warpfield: the corresponding warpfield (B x 2 x T)
        :return: the warped signal (B x 2 x T)
        '''
        warpfield = self._to_absolute_positions(warpfield, input.shape[-1])
        warped = self.warper(input, warpfield)
        return warped
|
72 |
+
|
73 |
+
|
74 |
+
class MonotoneTimeWarper(TimeWarper):
    '''
    TimeWarper whose lookup positions are forced to be non-decreasing over time.
    '''

    def forward(self, input, warpfield):
        '''
        :param input: audio signal to be warped (B x 2 x T)
        :param warpfield: the corresponding warpfield (B x 2 x T)
        :return: the warped signal (B x 2 x T), read at monotonically non-decreasing positions
        '''
        seq_length = input.shape[-1]
        positions = self._to_absolute_positions(warpfield, seq_length)
        # ensure monotonicity: the running maximum never lets a position fall
        # below any position that came before it
        monotone_positions, _ = th.cummax(positions, dim=-1)
        return self.warper(input, monotone_positions)
|
89 |
+
|
90 |
+
|
91 |
+
class GeometricTimeWarper(TimeWarper):
    '''
    TimeWarper whose warpfield is the propagation delay implied by the
    length of 3D displacement vectors.
    '''

    def __init__(self, sampling_rate=48000, speed_of_sound=343.0):
        '''
        :param sampling_rate: sampling rate of the audio signal
        :param speed_of_sound: propagation speed used to convert distance into delay,
                               in displacement units per second (default 343.0)
        '''
        super().__init__()
        self.sampling_rate = sampling_rate
        self.speed_of_sound = speed_of_sound

    def displacements2warpfield(self, displacements, seq_length):
        '''
        :param displacements: 3D displacement vectors with the xyz components in dim 2
                              (B x 2 x 3 x K)
        :param seq_length: number of audio samples T to upsample the warpfield to
        :return: warpfield of (negative) delays in samples as tensor of shape B x 2 x T
        '''
        # Euclidean length of each displacement vector
        distance = th.sum(displacements**2, dim=2) ** 0.5
        distance = F.interpolate(distance, size=seq_length)
        # delay = distance / speed * rate; negative because the warp reads from the past
        warpfield = -distance / self.speed_of_sound * self.sampling_rate
        return warpfield

    def forward(self, input, displacements):
        '''
        :param input: audio signal to be warped (B x 2 x T)
        :param displacements: sequence of 3D displacement vectors for geometric warping (B x 2 x 3 x K)
        :return: the warped signal (B x 2 x T)
        '''
        warpfield = self.displacements2warpfield(displacements, input.shape[-1])
        warped = super().forward(input, warpfield)
        return warped
|
mono2binaural/useful_ckpts/m2b/binaural_network.net
ADDED
Binary file (107 kB). View file
|
|
mono2binaural/useful_ckpts/m2b/tx_positions.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
mono2binaural/useful_ckpts/m2b/tx_positions2.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
mono2binaural/useful_ckpts/m2b/tx_positions3.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
mono2binaural/useful_ckpts/m2b/tx_positions4.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
mono2binaural/useful_ckpts/m2b/tx_positions5.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|