|
import logging |
|
from typing import Union |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
from torch import Tensor, nn |
|
|
|
from ..melspec import MelSpectrogram |
|
from .hparams import HParams |
|
from .unet import UNet |
|
|
|
# Module-level logger; the Denoiser's ISTFT path uses it to warn about NaNs.
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
|
def _normalize(x: Tensor) -> Tensor: |
|
return x / (x.abs().max(dim=-1, keepdim=True).values + 1e-7) |
|
|
|
|
|
class Denoiser(nn.Module):
    """STFT-domain denoiser built around a 3-channel UNet.

    The waveform is decomposed into STFT magnitude and unit phase (cos/sin).
    The UNet predicts a magnitude mask and a phase residual. The masked,
    phase-rotated spectrogram is then inverted back to a waveform.
    """

    @property
    def stft_cfg(self) -> dict:
        """Keyword arguments for ``torch.stft`` / ``torch.istft``.

        ``hop_length`` is ``hp.hop_size``. ``n_fft`` and ``win_length`` are both
        four times that.
        """
        hop_size = self.hp.hop_size
        return dict(hop_length=hop_size, n_fft=hop_size * 4, win_length=hop_size * 4)

    @property
    def n_fft(self) -> int:
        """FFT size used by the STFT/ISTFT helpers."""
        return self.stft_cfg["n_fft"]

    @property
    def eps(self) -> float:
        """Small constant added under the square root in ``_magphase``."""
        return 1e-7

    def __init__(self, hp: HParams):
        super().__init__()
        self.hp = hp
        # 3 input channels (mag, cos, sin) -> 3 output channels
        # (mask logit, real, imag); see ``_predict``.
        self.net = UNet(input_dim=3, output_dim=3)
        self.mel_fn = MelSpectrogram(hp)

        # Non-persistent buffer (excluded from the state dict) that follows the
        # module across ``.to()`` calls; ``forward`` uses ``x.to(self.dummy)``
        # to move inputs onto the module's device and dtype.
        self.dummy: Tensor
        self.register_buffer("dummy", torch.zeros(1), persistent=False)

    def to_mel(self, x: Tensor, drop_last=True):
        """Compute mel spectrograms with ``self.mel_fn``.

        Args:
            x: (b t), wavs
            drop_last: if True, the final frame of the mel output is removed.

        Returns:
            o: (b c t), mels
        """
        if drop_last:
            return self.mel_fn(x)[..., :-1]
        return self.mel_fn(x)

    def _stft(self, x):
        """Compute the STFT of ``x`` as magnitude plus unit phase components.

        The last STFT frame is dropped here; ``_istft`` restores a frame by
        replicate-padding before inversion.

        Args:
            x: (b t)

        Returns:
            mag: (b f t) in [0, inf)
            cos: (b f t) in [-1, 1]
            sin: (b f t) in [-1, 1]

            All outputs are cast back to the input's dtype and device.
        """
        dtype = x.dtype
        device = x.device
        cfg = self.stft_cfg

        # NOTE(review): MPS tensors are processed on CPU, presumably because
        # torch.stft is not fully supported on MPS. Results are moved back below.
        if x.is_mps:
            x = x.cpu()

        window = torch.hann_window(cfg["win_length"], device=x.device)
        # The transform always runs in float32; the original dtype is restored below.
        s = torch.stft(x.float(), **cfg, window=window, return_complex=True)

        s = s[..., :-1]

        mag = s.abs()
        phi = s.angle()
        cos = phi.cos()
        sin = phi.sin()

        mag = mag.to(dtype=dtype, device=device)
        cos = cos.to(dtype=dtype, device=device)
        sin = sin.to(dtype=dtype, device=device)

        return mag, cos, sin

    def _istft(self, mag: Tensor, cos: Tensor, sin: Tensor):
        """Invert a magnitude/phase spectrogram back to a waveform.

        A NaN in the input is only logged. NaNs in the output are logged and
        replaced with zeros.

        Args:
            mag: (b f t) in [0, inf)
            cos: (b f t) in [-1, 1]
            sin: (b f t) in [-1, 1]

        Returns:
            x: (b t), in ``mag``'s dtype and on ``mag``'s device.
        """
        device = mag.device
        dtype = mag.dtype
        cfg = self.stft_cfg

        # Mirror ``_stft``: run the transform on CPU for MPS tensors.
        if mag.is_mps:
            mag = mag.cpu()
            cos = cos.cpu()
            sin = sin.cpu()

        real = mag * cos
        imag = mag * sin

        s = torch.complex(real, imag)

        if s.isnan().any():
            logger.warning("NaN detected in ISTFT input.")

        s = s.to(torch.complex64)
        # Re-create the frame that ``_stft`` dropped by repeating the last one.
        s = F.pad(s, (0, 1), "replicate")

        window = torch.hann_window(cfg["win_length"], device=s.device)
        x = torch.istft(s, **cfg, window=window, return_complex=False)

        if x.isnan().any():
            logger.warning("NaN detected in ISTFT output, set to zero.")
            x = torch.where(x.isnan(), torch.zeros_like(x), x)

        x = x.to(dtype=dtype, device=device)

        return x

    def _magphase(self, real, imag):
        """Split a complex value given as (real, imag) into magnitude and unit phase.

        Returns:
            mag: sqrt(real^2 + imag^2 + eps)
            cos: real / mag
            sin: imag / mag
        """
        mag = (real.pow(2) + imag.pow(2) + self.eps).sqrt()
        cos = real / mag
        sin = imag / mag
        return mag, cos, sin

    def _predict(self, mag: Tensor, cos: Tensor, sin: Tensor):
        """Run the UNet and turn its three output channels into mask and residual.

        Args:
            mag: (b f t)
            cos: (b f t)
            sin: (b f t)

        Returns:
            mag_mask: (b f t) in [0, 1], magnitude mask
            cos_res: (b f t) in [-1, 1], phase residual
            sin_res: (b f t) in [-1, 1], phase residual
        """
        x = torch.stack([mag, cos, sin], dim=1)
        mag_mask, real, imag = self.net(x).unbind(1)
        mag_mask = mag_mask.sigmoid()
        real = real.tanh()
        imag = imag.tanh()
        # Only the unit-phase part of (real, imag) is used as the residual.
        _, cos_res, sin_res = self._magphase(real, imag)
        # Returned in the documented order (previously sin_res/cos_res were
        # swapped here and swapped back by the caller).
        return mag_mask, cos_res, sin_res

    def _separate(self, mag, cos, sin, mag_mask, cos_res, sin_res):
        """Apply the predicted mask and phase residual to the mixture spectrogram.

        The magnitude is masked (and clamped at zero). The phase is rotated by
        the residual using the angle-addition identities
        cos(a+b) = cos a cos b - sin a sin b and
        sin(a+b) = sin a cos b + cos a sin b.

        Ref: https://audio-agi.github.io/Separate-Anything-You-Describe/AudioSep_arXiv.pdf
        """
        sep_mag = F.relu(mag * mag_mask)
        sep_cos = cos * cos_res - sin * sin_res
        sep_sin = sin * cos_res + cos * sin_res
        return sep_mag, sep_cos, sep_sin

    def forward(self, x: Tensor, y: Union[Tensor, None] = None):
        """Denoise ``x``. When ``y`` is given, also record an L1 loss.

        Args:
            x: (b t), a mixed audio
            y: (b t), a fg audio

        Returns:
            o: (b t), the separated audio, padded or trimmed to ``x``'s length.
            Both ``x`` and ``y`` are peak-normalized before use.

        Side effects:
            When ``y`` is given, sets ``self.losses = {"l1": <L1 loss of o vs y>}``.
        """
        assert x.dim() == 2, f"Expected (b t), got {x.size()}"
        x = x.to(self.dummy)
        x = _normalize(x)

        if y is not None:
            assert y.dim() == 2, f"Expected (b t), got {y.size()}"
            y = y.to(self.dummy)
            y = _normalize(y)

        mag, cos, sin = self._stft(x)
        mag_mask, cos_res, sin_res = self._predict(mag, cos, sin)
        sep_mag, sep_cos, sep_sin = self._separate(
            mag, cos, sin, mag_mask, cos_res, sin_res
        )

        o = self._istft(sep_mag, sep_cos, sep_sin)

        # The ISTFT length can differ from the input length. F.pad trims when
        # npad is negative, so ``o`` ends up exactly as long as ``x``.
        npad = x.shape[-1] - o.shape[-1]
        o = F.pad(o, (0, npad))

        if y is not None:
            self.losses = dict(l1=F.l1_loss(o, y))

        return o
|
|