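# NOTE: the text "Spaces: Running Running" stood here; it is hosting-page
# residue picked up when the file was copied and is not part of the module.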
from typing import Optional

import torch
import torch.nn as nn
from torch import Tensor
from torch.utils.data import DataLoader
def atan2(y, x):
    r"""Element-wise arctangent function of y/x.

    Returns a new tensor with signed angles in radians.
    It is an alternative implementation of torch.atan2.
    Neither `y` nor `x` is modified.

    Args:
        y (Tensor): First input tensor
        x (Tensor): Second input tensor [shape=y.shape]

    Returns:
        Tensor: [shape=y.shape].
    """
    pi = 2 * torch.asin(torch.tensor(1.0))
    # Where both inputs are zero, divide by 1 instead of 0 so that the angle
    # comes out as 0 rather than NaN. This is done out of place: `x` is
    # frequently a view into the caller's data and must not be altered.
    x = x + ((x == 0) & (y == 0)).to(x.dtype)
    out = torch.atan(y / x)
    # atan only covers (-pi/2, pi/2): shift the left half-plane by +/- pi
    out += ((y >= 0) & (x < 0)) * pi
    out -= ((y < 0) & (x < 0)) * pi
    # on the imaginary axis, overwrite the value with exactly +/- pi/2
    out *= 1 - ((y > 0) & (x == 0)) * 1.0
    out += ((y > 0) & (x == 0)) * (pi / 2)
    out *= 1 - ((y < 0) & (x == 0)) * 1.0
    out += ((y < 0) & (x == 0)) * (-pi / 2)
    return out
# Define basic complex operations on torch.Tensor objects whose last dimension
# holds the real and imaginary parts (size 2).
def _norm(x: torch.Tensor) -> torch.Tensor:
    r"""Computes the squared modulus of a torch Tensor, assuming that it
    comes as real and imaginary part in its last dimension.

    Note that no square root is taken: the result is ``re ** 2 + im ** 2``.

    Args:
        x (Tensor): Input Tensor of shape [shape=(..., 2)]

    Returns:
        Tensor: shape as x excluding the last dimension.
    """
    return torch.abs(x[..., 0]) ** 2 + torch.abs(x[..., 1]) ** 2
def _mul_add(a: torch.Tensor, b: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Element-wise multiplication of two complex Tensors described
    through their real and imaginary parts.
    The result is added to the `out` tensor.

    Args:
        a (Tensor): first factor [shape=(..., 2)]
        b (Tensor): second factor, same number of dimensions as `a`
        out (Tensor, optional): accumulator. It is updated in place and
            returned when its shape equals the element-wise maximum of
            `a.shape` and `b.shape`; otherwise a fresh zero tensor of that
            shape is allocated. `out` may be `a` or `b` itself.

    Returns:
        Tensor: `out + a * b`
    """
    # check `out` and allocate it if needed
    target_shape = torch.Size([max(sa, sb) for (sa, sb) in zip(a.shape, b.shape)])
    if out is None or out.shape != target_shape:
        out = torch.zeros(target_shape, dtype=a.dtype, device=a.device)
    # Evaluate both components before writing anything: `a[..., 0]` is a view,
    # so when `out` shares memory with `a` or `b`, writing the real part first
    # would corrupt the operands of the imaginary part.
    real = a[..., 0] * b[..., 0] - a[..., 1] * b[..., 1]
    imag = a[..., 0] * b[..., 1] + a[..., 1] * b[..., 0]
    out[..., 0] = out[..., 0] + real
    out[..., 1] = out[..., 1] + imag
    return out
def _mul(a: torch.Tensor, b: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Element-wise multiplication of two complex Tensors described
    through their real and imaginary parts.

    Args:
        a (Tensor): first factor [shape=(..., 2)]
        b (Tensor): second factor, same number of dimensions as `a`
        out (Tensor, optional): destination. It is overwritten and returned
            when its shape equals the element-wise maximum of `a.shape` and
            `b.shape`; otherwise a fresh tensor of that shape is allocated.
            `out` may be `a` or `b` itself (in-place product).

    Returns:
        Tensor: `a * b`
    """
    target_shape = torch.Size([max(sa, sb) for (sa, sb) in zip(a.shape, b.shape)])
    if out is None or out.shape != target_shape:
        out = torch.zeros(target_shape, dtype=a.dtype, device=a.device)
    # Evaluate both components before writing anything: `a[..., 0]` is a view,
    # so when `out` shares memory with `a` or `b`, writing the real part first
    # would corrupt the operands of the imaginary part.
    real = a[..., 0] * b[..., 0] - a[..., 1] * b[..., 1]
    imag = a[..., 0] * b[..., 1] + a[..., 1] * b[..., 0]
    out[..., 0] = real
    out[..., 1] = imag
    return out
def _inv(z: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Element-wise multiplicative inverse of a Tensor with complex
    entries described through their real and imaginary parts.

    Computed as ``conj(z) / (re ** 2 + im ** 2)``; no protection against
    zero entries is applied.

    Args:
        z (Tensor): complex values to invert [shape=(..., 2)]
        out (Tensor, optional): destination, used when its shape equals
            `z.shape` (otherwise a new tensor is allocated). May be `z`
            itself for an in-place inverse.

    Returns:
        Tensor: [shape=z.shape] element-wise inverse of `z`
    """
    # squared modulus, computed into a new tensor before anything is written
    ez = _norm(z)
    if out is None or out.shape != z.shape:
        out = torch.zeros_like(z)
    out[..., 0] = z[..., 0] / ez
    out[..., 1] = -z[..., 1] / ez
    return out
def _conj(z, out: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Element-wise complex conjugate of a Tensor with complex entries
    described through their real and imaginary parts.

    Args:
        z (Tensor): complex values [shape=(..., 2)]
        out (Tensor, optional): destination, used when its shape equals
            `z.shape` (otherwise a new tensor is allocated). May be `z`
            itself for an in-place conjugate.

    Returns:
        Tensor: [shape=z.shape] conjugate of `z`
    """
    if out is None or out.shape != z.shape:
        out = torch.zeros_like(z)
    out[..., 0] = z[..., 0]
    out[..., 1] = -z[..., 1]
    return out
def _invert(M: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
    """
    Invert 1x1 or 2x2 matrices

    Will generate errors if the matrices are singular: user must handle this
    through his own regularization schemes.

    Args:
        M (Tensor): [shape=(..., nb_channels, nb_channels, 2)]
            matrices to invert: must be square along dimensions -3 and -2
        out (Tensor, optional): destination, used when its shape equals
            `M.shape`; otherwise a new tensor is allocated.

    Returns:
        invM (Tensor): [shape=M.shape]
            inverses of M

    Raises:
        Exception: if `nb_channels` is neither 1 nor 2.
    """
    nb_channels = M.shape[-2]
    if out is None or out.shape != M.shape:
        out = torch.empty_like(M)
    if nb_channels == 1:
        # scalar case
        out = _inv(M, out)
    elif nb_channels == 2:
        # two channels case: analytical expression
        if out is M:
            # Entries of `out` are written one by one below while entries of
            # `M` are still being read: keep an untouched copy to read from.
            M = M.clone()
        # first compute the determinant
        det = _mul(M[..., 0, 0, :], M[..., 1, 1, :])
        det = det - _mul(M[..., 0, 1, :], M[..., 1, 0, :])
        # invert it
        inv_det = _inv(det)
        neg_inv_det = -inv_det
        # then fill out the matrix with the inverse:
        # [[a, b], [c, d]]^-1 = 1 / det * [[d, -b], [-c, a]]
        out[..., 0, 0, :] = _mul(inv_det, M[..., 1, 1, :], out[..., 0, 0, :])
        out[..., 1, 0, :] = _mul(neg_inv_det, M[..., 1, 0, :], out[..., 1, 0, :])
        out[..., 0, 1, :] = _mul(neg_inv_det, M[..., 0, 1, :], out[..., 0, 1, :])
        out[..., 1, 1, :] = _mul(inv_det, M[..., 0, 0, :], out[..., 1, 1, :])
    else:
        raise Exception("Only 2 channels are supported for the torch version.")
    return out
# Now define the low-level signal-processing functions used by the Separator
def expectation_maximization(
    y: torch.Tensor,
    x: torch.Tensor,
    iterations: int = 2,
    eps: float = 1e-10,
    batch_size: int = 200,
):
    r"""Expectation maximization algorithm, for refining source separation
    estimates.

    This algorithm allows to make source separation results better by
    enforcing multichannel consistency for the estimates. This usually means
    a better perceptual quality in terms of spatial artifacts.

    The implementation follows the details presented in [1]_, taking
    inspiration from the original EM algorithm proposed in [2]_ and its
    weighted refinement proposed in [3]_, [4]_.
    It works by iteratively:

    * Re-estimating the source parameters: the power spectral densities as
      the channel-averaged energy of the current estimates, and the spatial
      covariance matrices as a weighted average of their empirical
      covariance.

    * Separating the mixture again with the new parameters: the modelled
      mixture covariance matrices are built and inverted, a multichannel
      Wiener gain is derived for every source and applied to the mixture.

    References
    ----------
    .. [1] S. Uhlich and M. Porcu and F. Giron and M. Enenkl and T. Kemp and
        N. Takahashi and Y. Mitsufuji, "Improving music source separation based
        on deep neural networks through data augmentation and network
        blending." 2017 IEEE International Conference on Acoustics, Speech
        and Signal Processing (ICASSP). IEEE, 2017.

    .. [2] N.Q. Duong and E. Vincent and R.Gribonval. "Under-determined
        reverberant audio source separation using a full-rank spatial
        covariance model." IEEE Transactions on Audio, Speech, and Language
        Processing 18.7 (2010): 1830-1840.

    .. [3] A. Nugraha and A. Liutkus and E. Vincent. "Multichannel audio source
        separation with deep neural networks." IEEE/ACM Transactions on Audio,
        Speech, and Language Processing 24.9 (2016): 1652-1664.

    .. [4] A. Nugraha and A. Liutkus and E. Vincent. "Multichannel music
        separation with deep neural networks." 2016 24th European Signal
        Processing Conference (EUSIPCO). IEEE, 2016.

    .. [5] A. Liutkus and R. Badeau and G. Richard "Kernel additive models for
        source separation." IEEE Transactions on Signal Processing
        62.16 (2014): 4298-4310.

    Args:
        y (Tensor): [shape=(nb_frames, nb_bins, nb_channels, 2, nb_sources)]
            initial estimates for the sources. Unless it requires grad (in
            which case it is cloned first), this tensor is overwritten in
            place with the refined estimates.
        x (Tensor): [shape=(nb_frames, nb_bins, nb_channels, 2)]
            complex STFT of the mixture signal
        iterations (int): [scalar]
            number of iterations for the EM algorithm.
        eps (float): [scalar]
            The epsilon value to use for regularization and filters.
        batch_size (int): [scalar]
            number of frames processed at once, which bounds the size of the
            intermediate covariance tensors. A falsy value (e.g. 0) means
            all frames in a single batch.

    Returns:
        y (Tensor): [shape=(nb_frames, nb_bins, nb_channels, 2, nb_sources)]
            estimated sources after iterations
        v (Tensor): [shape=(nb_frames, nb_bins, nb_sources)]
            estimated power spectral densities
        R (list of Tensor): `nb_sources` tensors, each of
            [shape=(nb_bins, nb_channels, nb_channels, 2)]
            estimated spatial covariance matrices

    Notes:
        * You need an initial estimate for the sources to apply this
          algorithm. This is precisely what the :func:`wiener` function does.
        * This algorithm *is not* an implementation of the "exact" EM
          proposed in [1]_. In particular, it does not compute the posterior
          covariance matrices the same (exact) way. Instead, it uses the
          simplified approximate scheme initially proposed in [5]_ and further
          refined in [3]_, [4]_, that boils down to just take the empirical
          covariance of the recent source estimates, followed by a weighted
          average for the update of the spatial covariance matrix. It has been
          empirically demonstrated that this simplified algorithm is more
          robust for music separation.

    Warning:
        It is *very* important to make sure `x.dtype` is `torch.float64`
        if you want double precision, because this function will **not**
        do such conversion for you from `torch.float32`, in case you want the
        smaller RAM usage on purpose.

        It is usually always better in terms of quality to have double
        precision, by e.g. calling :func:`expectation_maximization`
        with ``x.to(torch.float64)``.
    """
    # dimensions
    (nb_frames, nb_bins, nb_channels) = x.shape[:-1]
    nb_sources = y.shape[-1]

    # sqrt(eps) * identity (zero imaginary part), broadcast over the bins:
    # added to the mixture covariance so that it stays invertible
    regularization = torch.cat(
        (
            torch.eye(nb_channels, dtype=x.dtype, device=x.device)[..., None],
            torch.zeros((nb_channels, nb_channels, 1), dtype=x.dtype, device=x.device),
        ),
        dim=2,
    )
    regularization = torch.sqrt(torch.as_tensor(eps)) * (
        regularization[None, None, ...].expand((-1, nb_bins, -1, -1, -1))
    )

    # allocate the spatial covariance matrices and the PSD; these are the
    # values returned when `iterations` is 0
    R = [
        torch.zeros((nb_bins, nb_channels, nb_channels, 2), dtype=x.dtype, device=x.device)
        for _ in range(nb_sources)
    ]
    v: torch.Tensor = torch.zeros((nb_frames, nb_bins, nb_sources), dtype=x.dtype, device=x.device)

    # loop invariants: frame batch length and the (row, col, inner) channel
    # triples of the matrix product R_j @ inv_Cxx
    step = batch_size if batch_size else max(nb_frames, 1)
    channel_triples = [
        (row, col, inner)
        for row in range(nb_channels)
        for col in range(nb_channels)
        for inner in range(nb_channels)
    ]

    for _ in range(iterations):
        # update the PSD as the average spectrogram over channels
        v = torch.mean(torch.abs(y[..., 0, :]) ** 2 + torch.abs(y[..., 1, :]) ** 2, dim=-2)

        # update spatial covariance matrices (weighted update), accumulating
        # over batches of frames to bound the size of the covariance tensor
        for j in range(nb_sources):
            R[j] = torch.tensor(0.0, device=x.device)
            weight = torch.tensor(eps, device=x.device)
            for start in range(0, nb_frames, step):
                frames = slice(start, min(nb_frames, start + step))
                R[j] = R[j] + torch.sum(_covariance(y[frames, ..., j]), dim=0)
                weight = weight + torch.sum(v[frames, ..., j], dim=0)
            R[j] = R[j] / weight[..., None, None, None]

        # cloning y if we track gradient, because we're going to update it
        if y.requires_grad:
            y = y.clone()

        for start in range(0, nb_frames, step):
            frames = slice(start, min(nb_frames, start + step))
            # the new estimates are accumulated from scratch
            y[frames] = 0.0

            # compute mix covariance matrix
            Cxx = regularization
            for j in range(nb_sources):
                Cxx = Cxx + (v[frames, ..., j, None, None, None] * R[j][None, ...].clone())

            # invert it
            inv_Cxx = _invert(Cxx)

            # separate the sources
            for j in range(nb_sources):
                # create a wiener gain for this source
                gain = torch.zeros_like(inv_Cxx)

                # computes multichannel Wiener gain as v_j R_j inv_Cxx
                for row, col, inner in channel_triples:
                    gain[:, :, row, col, :] = _mul_add(
                        R[j][None, :, row, inner, :].clone(),
                        inv_Cxx[:, :, inner, col, :],
                        gain[:, :, row, col, :],
                    )
                gain = gain * v[frames, ..., None, None, None, j]

                # apply it to the mixture
                for i in range(nb_channels):
                    y[frames, ..., j] = _mul_add(
                        gain[..., i, :], x[frames, ..., i, None, :], y[frames, ..., j]
                    )

    return y, v, R
def wiener(
    targets_spectrograms: torch.Tensor,
    mix_stft: torch.Tensor,
    iterations: int = 1,
    softmask: bool = False,
    residual: bool = False,
    scale_factor: float = 10.0,
    eps: float = 1e-10,
):
    """Wiener-based separation for multichannel audio.

    The method uses the (possibly multichannel) spectrograms of the
    sources to separate the (complex) Short Term Fourier Transform of the
    mix. Separation is done in a sequential way by:

    * Getting an initial estimate. This can be done in two ways: either by
      directly using the spectrograms with the mixture phase, or
      by using a softmasking strategy. This initial phase is controlled
      by the `softmask` flag.

    * If required, adding an additional residual target as the mix minus
      all targets.

    * Refining these initial estimates through a call to
      :func:`expectation_maximization` if the number of iterations is nonzero.

    This implementation also allows to specify the epsilon value used for
    regularization. It is based on [1]_, [2]_, [3]_, [4]_.

    References
    ----------
    .. [1] S. Uhlich and M. Porcu and F. Giron and M. Enenkl and T. Kemp and
        N. Takahashi and Y. Mitsufuji, "Improving music source separation based
        on deep neural networks through data augmentation and network
        blending." 2017 IEEE International Conference on Acoustics, Speech
        and Signal Processing (ICASSP). IEEE, 2017.

    .. [2] A. Nugraha and A. Liutkus and E. Vincent. "Multichannel audio source
        separation with deep neural networks." IEEE/ACM Transactions on Audio,
        Speech, and Language Processing 24.9 (2016): 1652-1664.

    .. [3] A. Nugraha and A. Liutkus and E. Vincent. "Multichannel music
        separation with deep neural networks." 2016 24th European Signal
        Processing Conference (EUSIPCO). IEEE, 2016.

    .. [4] A. Liutkus and R. Badeau and G. Richard "Kernel additive models for
        source separation." IEEE Transactions on Signal Processing
        62.16 (2014): 4298-4310.

    Args:
        targets_spectrograms (Tensor): spectrograms of the sources
            [shape=(nb_frames, nb_bins, nb_channels, nb_sources)].
            This is a nonnegative tensor that is
            usually the output of the actual separation method of the user. The
            spectrograms may be mono, but they need to be 4-dimensional in all
            cases.
        mix_stft (Tensor): [shape=(nb_frames, nb_bins, nb_channels, complex=2)]
            STFT of the mixture signal.
        iterations (int): [scalar]
            number of iterations for the EM algorithm
        softmask (bool): Describes how the initial estimates are obtained.
            * if `False`, then the mixture phase will directly be used with the
            spectrogram as initial estimates.
            * if `True`, initial estimates are obtained by multiplying the
            complex mix element-wise with the ratio of each target spectrogram
            with the sum of them all. This strategy is better if the models are
            not really good, and worse otherwise.
        residual (bool): if `True`, an additional target is created, which is
            equal to the mixture minus the other targets, before application of
            expectation maximization
        scale_factor (float): controls the rescaling applied for numerical
            stability before the EM: mixture and estimates are divided by
            ``max(1, max|mix_stft| / scale_factor)`` and the refined
            estimates are multiplied by the same value afterwards.
        eps (float): Epsilon value to use for computing the separations.
            This is used whenever division with a model energy is
            performed, i.e. when softmasking and when iterating the EM.
            It can be understood as the energy of the additional white noise
            that is taken out when separating.

    Returns:
        Tensor: shape=(nb_frames, nb_bins, nb_channels, complex=2, nb_sources)
            STFT of estimated sources

    Notes:
        * Be careful that you need *magnitude spectrogram estimates* for the
          case `softmask==False`.
        * `softmask=False` is recommended
        * The epsilon value will have a huge impact on performance. If it's
          large, only the parts of the signal with a significant energy will
          be kept in the sources. This epsilon then directly controls the
          energy of the reconstruction error.

    Warning:
        As in :func:`expectation_maximization`, we recommend converting the
        mixture `x` to double precision `torch.float64` *before* calling
        :func:`wiener`.
    """
    if softmask:
        # if we use softmask, we compute the ratio mask for all targets and
        # multiply by the mix stft
        y = (
            mix_stft[..., None]
            * (
                targets_spectrograms
                / (eps + torch.sum(targets_spectrograms, dim=-1, keepdim=True).to(mix_stft.dtype))
            )[..., None, :]
        )
    else:
        # otherwise, we just multiply the targets spectrograms with mix phase
        # we tacitly assume that we have magnitude estimates.
        # The real part is handed over as a copy, so that the mixture used
        # below (and owned by the caller) can never be altered by the angle
        # computation.
        angle = atan2(mix_stft[..., 1], mix_stft[..., 0].clone())[..., None]
        nb_sources = targets_spectrograms.shape[-1]
        y = torch.zeros(
            mix_stft.shape + (nb_sources,), dtype=mix_stft.dtype, device=mix_stft.device
        )
        y[..., 0, :] = targets_spectrograms * torch.cos(angle)
        y[..., 1, :] = targets_spectrograms * torch.sin(angle)

    if residual:
        # if required, adding an additional target as the mix minus
        # available targets
        y = torch.cat([y, mix_stft[..., None] - y.sum(dim=-1, keepdim=True)], dim=-1)

    if iterations == 0:
        return y

    # we need to refine the estimates. Scales down the estimates for
    # numerical stability
    max_abs = torch.max(
        torch.as_tensor(1.0, dtype=mix_stft.dtype, device=mix_stft.device),
        torch.sqrt(_norm(mix_stft)).max() / scale_factor,
    )

    mix_stft = mix_stft / max_abs
    y = y / max_abs

    # call expectation maximization
    y = expectation_maximization(y, mix_stft, iterations, eps=eps)[0]

    # scale estimates up again
    y = y * max_abs
    return y
def _covariance(y_j):
    """
    Compute the empirical covariance for a source.

    Args:
        y_j (Tensor): complex stft of the source.
            [shape=(nb_frames, nb_bins, nb_channels, 2)].

    Returns:
        Cj (Tensor): [shape=(nb_frames, nb_bins, nb_channels, nb_channels, 2)]
            just y_j * conj(y_j.T): empirical covariance for each TF bin.
    """
    (nb_frames, nb_bins, nb_channels) = y_j.shape[:-1]
    Cj = torch.zeros(
        (nb_frames, nb_bins, nb_channels, nb_channels, 2),
        dtype=y_j.dtype,
        device=y_j.device,
    )
    # entry (row, col) is y_j[row] * conj(y_j[col]); plain integer indices are
    # enough here, no index tensor has to be built
    for row in range(nb_channels):
        for col in range(nb_channels):
            Cj[:, :, row, col, :] = _mul_add(
                y_j[:, :, row, :],
                _conj(y_j[:, :, col, :]),
                Cj[:, :, row, col, :],
            )
    return Cj