import logging

import torch
import torch.nn.functional as F
from torch import Tensor, nn

from ..melspec import MelSpectrogram
from .hparams import HParams
from .unet import UNet

logger = logging.getLogger(__name__)


def _normalize(x: Tensor) -> Tensor:
    """Peak-normalize each signal along its last dimension.

    Each row is divided by its maximum absolute value. The small epsilon
    avoids division by zero, so an all-zero row stays all zeros.
    """
    return x / (x.abs().max(dim=-1, keepdim=True).values + 1e-7)


class Denoiser(nn.Module):
    """Separate a foreground signal from a mixed waveform in the STFT domain.

    A UNet sees (magnitude, cos, sin) of the input STFT and predicts a
    magnitude mask plus a phase residual. These are applied to the input
    spectrum, and the result is inverted back to a waveform.
    """

    @property
    def stft_cfg(self) -> dict:
        """STFT keyword arguments: hop = ``hp.hop_size``, FFT/window = 4 * hop."""
        hop_size = self.hp.hop_size
        return dict(hop_length=hop_size, n_fft=hop_size * 4, win_length=hop_size * 4)

    @property
    def n_fft(self):
        """FFT size used by the STFT (``4 * hp.hop_size``)."""
        return self.stft_cfg["n_fft"]

    @property
    def eps(self):
        """Small constant that keeps magnitudes strictly positive in ``_magphase``."""
        return 1e-7

    def __init__(self, hp: HParams):
        super().__init__()
        self.hp = hp
        # Three channels in: (mag, cos, sin). Three out: (mask logits, real, imag).
        self.net = UNet(input_dim=3, output_dim=3)
        self.mel_fn = MelSpectrogram(hp)

        # Non-persistent buffer used only as a device/dtype reference:
        # ``forward`` moves inputs with ``x.to(self.dummy)``.
        self.dummy: Tensor
        self.register_buffer("dummy", torch.zeros(1), persistent=False)

    def to_mel(self, x: Tensor, drop_last=True):
        """
        Args:
            x: (b t), wavs
            drop_last: if True, drop the final frame of the mel output
        Returns:
            o: (b c t), mels
        """
        if drop_last:
            return self.mel_fn(x)[..., :-1]  # (b d t)
        return self.mel_fn(x)

    def _stft(self, x):
        """
        Args:
            x: (b t)
        Returns:
            mag: (b f t) in [0, inf)
            cos: (b f t) in [-1, 1]
            sin: (b f t) in [-1, 1]
        """
        dtype = x.dtype
        device = x.device

        # Tensors on MPS are moved to CPU for the transform (presumably an
        # MPS STFT support workaround). Results go back to the original
        # device/dtype below.
        if x.is_mps:
            x = x.cpu()

        window = torch.hann_window(self.stft_cfg["win_length"], device=x.device)
        s = torch.stft(x.float(), **self.stft_cfg, window=window, return_complex=True)  # (b f t+1)
        # Drop the trailing frame; ``_istft`` re-creates it by replicate-padding.
        s = s[..., :-1]  # (b f t)

        mag = s.abs()  # (b f t)
        phi = s.angle()  # (b f t)
        cos = phi.cos()  # (b f t)
        sin = phi.sin()  # (b f t)

        mag = mag.to(dtype=dtype, device=device)
        cos = cos.to(dtype=dtype, device=device)
        sin = sin.to(dtype=dtype, device=device)

        return mag, cos, sin

    def _istft(self, mag: Tensor, cos: Tensor, sin: Tensor):
        """
        Args:
            mag: (b f t) in [0, inf)
            cos: (b f t) in [-1, 1]
            sin: (b f t) in [-1, 1]
        Returns:
            x: (b t), in the dtype/device of ``mag``. Any NaN samples in the
            inverse-STFT output are replaced with zeros (with a warning).
        """
        device = mag.device
        dtype = mag.dtype

        # Same MPS -> CPU fallback as ``_stft``.
        if mag.is_mps:
            mag = mag.cpu()
            cos = cos.cpu()
            sin = sin.cpu()

        # Mirror ``_stft``, which transforms in float32. Half/bfloat16 inputs
        # are upcast so the complex spectrum and the window share a supported
        # precision. float32/float64 inputs keep their dtype.
        compute_dtype = torch.float32
        for t in (mag, cos, sin):
            compute_dtype = torch.promote_types(compute_dtype, t.dtype)
        mag = mag.to(compute_dtype)
        cos = cos.to(compute_dtype)
        sin = sin.to(compute_dtype)

        real = mag * cos  # (b f t)
        imag = mag * sin  # (b f t)

        s = torch.complex(real, imag)  # (b f t)

        if s.isnan().any():
            logger.warning("NaN detected in ISTFT input.")

        # Restore the trailing frame dropped in ``_stft`` by repeating the last one.
        s = F.pad(s, (0, 1), "replicate")  # (b f t+1)

        window = torch.hann_window(
            self.stft_cfg["win_length"], dtype=compute_dtype, device=s.device
        )
        x = torch.istft(s, **self.stft_cfg, window=window, return_complex=False)

        if x.isnan().any():
            logger.warning("NaN detected in ISTFT output, set to zero.")
            x = torch.where(x.isnan(), torch.zeros_like(x), x)

        x = x.to(dtype=dtype, device=device)

        return x

    def _magphase(self, real, imag):
        """Split a (real, imag) pair into magnitude and unit phase (cos, sin).

        ``eps`` inside the sqrt keeps ``mag`` strictly positive, so the
        divisions below never divide by zero.
        """
        mag = (real.pow(2) + imag.pow(2) + self.eps).sqrt()
        cos = real / mag
        sin = imag / mag
        return mag, cos, sin

    def _predict(self, mag: Tensor, cos: Tensor, sin: Tensor):
        """
        Args:
            mag: (b f t)
            cos: (b f t)
            sin: (b f t)
        Returns:
            A tuple in this exact order (``forward`` relies on it):
            mag_mask: (b f t) in [0, 1], magnitude mask
            sin_res: (b f t) in [-1, 1], phase residual (sine)
            cos_res: (b f t) in [-1, 1], phase residual (cosine)
        """
        x = torch.stack([mag, cos, sin], dim=1)  # (b 3 f t)
        mag_mask, real, imag = self.net(x).unbind(1)  # (b 3 f t)
        mag_mask = mag_mask.sigmoid()  # (b f t)
        real = real.tanh()  # (b f t)
        imag = imag.tanh()  # (b f t)
        # Only the unit phase of (real, imag) is kept; its magnitude is discarded.
        _, cos_res, sin_res = self._magphase(real, imag)  # (b f t)
        # NOTE: order is (mask, sin, cos), not (mask, cos, sin); kept for compatibility.
        return mag_mask, sin_res, cos_res

    def _separate(self, mag, cos, sin, mag_mask, cos_res, sin_res):
        """Apply the predicted mask and phase residual to the input spectrum.

        The magnitude is scaled by the mask. The phase is rotated by the
        residual via the angle-addition identities
        cos(a + b) = cos a cos b - sin a sin b and
        sin(a + b) = sin a cos b + cos a sin b.

        Ref: https://audio-agi.github.io/Separate-Anything-You-Describe/AudioSep_arXiv.pdf
        """
        sep_mag = F.relu(mag * mag_mask)
        sep_cos = cos * cos_res - sin * sin_res
        sep_sin = sin * cos_res + cos * sin_res
        return sep_mag, sep_cos, sep_sin

    def forward(self, x: Tensor, y: Tensor | None = None):
        """
        Args:
            x: (b t), a mixed audio
            y: (b t), a fg audio (optional). When given, ``self.losses`` is set
                to ``dict(l1=...)`` comparing the output with peak-normalized y.
        Returns:
            o: (b t), the separated audio, padded/cropped to x's length.
        """
        assert x.dim() == 2, f"Expected (b t), got {x.size()}"
        x = x.to(self.dummy)
        x = _normalize(x)

        if y is not None:
            assert y.dim() == 2, f"Expected (b t), got {y.size()}"
            y = y.to(self.dummy)
            y = _normalize(y)

        mag, cos, sin = self._stft(x)  # (b f t)
        # Unpack in the order ``_predict`` actually returns: (mask, sin, cos).
        mag_mask, sin_res, cos_res = self._predict(mag, cos, sin)
        sep_mag, sep_cos, sep_sin = self._separate(mag, cos, sin, mag_mask, cos_res, sin_res)

        o = self._istft(sep_mag, sep_cos, sep_sin)

        # The ISTFT length can differ from the input; F.pad with a negative
        # amount crops, so this matches o to x in both directions.
        npad = x.shape[-1] - o.shape[-1]
        o = F.pad(o, (0, npad))

        # NOTE(review): ``self.losses`` is only refreshed when y is given; an
        # inference call leaves any previous value in place.
        if y is not None:
            self.losses = dict(l1=F.l1_loss(o, y))

        return o