import logging
from typing import Union

import matplotlib.pyplot as plt
import pandas as pd
import torch
from torch import Tensor, nn
from torch.distributions import Beta

from ..common import Normalizer
from ..denoiser.inference import load_denoiser
from ..melspec import MelSpectrogram
from .hparams import HParams
from .lcfm import CFM, IRMAE, LCFM
from .univnet import UnivNet

logger = logging.getLogger(__name__)


def _maybe(fn):
    """Wrap ``fn`` so that it returns ``None`` when its first argument is ``None``.

    Only positional arguments are forwarded to ``fn``.
    """

    def _fn(*args):
        if args[0] is None:
            return None
        return fn(*args)

    return _fn


def _normalize_wav(x: Tensor):
    """Scale ``x`` by its peak absolute value along the last dimension.

    A small epsilon is added to the peak so that an all-zero input does not
    cause a division by zero.
    """
    peak = x.abs().max(dim=-1, keepdim=True).values
    return x / (peak + 1e-7)


class Enhancer(nn.Module):
    """Waveform enhancer built from an LCFM (IRMAE + CFM), a UnivNet vocoder
    and a pretrained denoiser.

    The mel spectrogram of the input is (optionally) blended with the mel of a
    denoised version of the input, passed through the LCFM and finally decoded
    to a waveform by the vocoder.
    """

    def __init__(self, hp: HParams):
        super().__init__()
        self.hp = hp

        n_mels = self.hp.num_mels
        vocoder_input_dim = n_mels + self.hp.vocoder_extra_dim
        latent_dim = self.hp.lcfm_latent_dim

        self.lcfm = LCFM(
            IRMAE(
                input_dim=n_mels,
                output_dim=vocoder_input_dim,
                latent_dim=latent_dim,
            ),
            CFM(
                cond_dim=n_mels,
                output_dim=self.hp.lcfm_latent_dim,
                solver_nfe=self.hp.cfm_solver_nfe,
                solver_method=self.hp.cfm_solver_method,
                time_mapping_divisor=self.hp.cfm_time_mapping_divisor,
            ),
            z_scale=self.hp.lcfm_z_scale,
        )
        self.lcfm.set_mode_(self.hp.lcfm_training_mode)

        self.mel_fn = MelSpectrogram(hp)
        self.vocoder = UnivNet(self.hp, vocoder_input_dim)
        self.denoiser = load_denoiser(self.hp.denoiser_run_dir, "cpu")
        self.normalizer = Normalizer()

        # Denoiser strength used in eval mode; set through `configurate_`.
        self._eval_lambd = 0.0

        self.dummy: Tensor
        self.register_buffer("dummy", torch.zeros(1))

        if self.hp.enhancer_stage1_run_dir is not None:
            pretrained_path = (
                self.hp.enhancer_stage1_run_dir
                / "ds/G/default/mp_rank_00_model_states.pt"
            )
            self._load_pretrained(pretrained_path)

        # logger.info(f"{self.__class__.__name__} summary")
        # logger.info(f"{self.summarize()}")

    def _load_pretrained(self, path):
        """Load the stage-1 checkpoint at ``path`` while keeping the current
        CFM and denoiser weights.

        The checkpoint is loaded non-strictly into the whole module, after
        which the CFM and the denoiser are reset to the state they had before
        the load.
        """
        # Clone is necessary as otherwise it holds a reference to the original model
        cfm_state_dict = {k: v.clone() for k, v in self.lcfm.cfm.state_dict().items()}
        denoiser_state_dict = {
            k: v.clone() for k, v in self.denoiser.state_dict().items()
        }

        # NOTE: torch.load unpickles the file; only load checkpoints from a
        # trusted source.
        state_dict = torch.load(path, map_location="cpu")["module"]
        self.load_state_dict(state_dict, strict=False)

        self.lcfm.cfm.load_state_dict(cfm_state_dict)  # Reset cfm
        self.denoiser.load_state_dict(denoiser_state_dict)  # Reset denoiser
        logger.info(f"Loaded pretrained model from {path}")

    def summarize(self):
        """Return a markdown table with the trainable and total parameter
        counts of every direct child module, followed by a ``total`` row for
        the whole model.
        """

        def count_params(module, trainable_only=False):
            return sum(
                p.numel()
                for p in module.parameters()
                if p.requires_grad or not trainable_only
            )

        rows = [
            dict(
                name=name,
                trainable=count_params(module, trainable_only=True),
                total=count_params(module),
            )
            for name, module in self.named_children()
        ]
        rows.append(
            dict(
                name="total",
                trainable=count_params(self, trainable_only=True),
                total=count_params(self),
            )
        )
        return pd.DataFrame(rows).to_markdown(index=False)

    def to_mel(self, x: Tensor, drop_last=True):
        """
        Args:
            x: (b t), wavs
            drop_last: if True, the last frame of the mel is dropped
        Returns:
            o: (b c t), mels
        """
        mel = self.mel_fn(x)
        if drop_last:
            return mel[..., :-1]  # (b d t)
        return mel

    def _may_denoise(self, x: Tensor, y: Union[Tensor, None] = None):
        """Run the denoiser on ``x`` (with ``y`` forwarded as its second
        argument) in "cfm" training mode; otherwise return ``x`` unchanged.
        """
        if self.hp.lcfm_training_mode == "cfm":
            return self.denoiser(x, y)
        return x

    def configurate_(self, nfe, solver, lambd, tau):
        """
        Args:
            nfe: number of function evaluations
            solver: solver method
            lambd: denoiser strength [0, 1]
            tau: prior temperature [0, 1]
        """
        self.lcfm.cfm.solver.configurate_(nfe, solver)
        self.lcfm.eval_tau_(tau)
        self._eval_lambd = lambd

    def _blend_with_denoised(self, x, z, x_mel_original, lambd):
        """Return ``lambd * mel(denoise(x)) + (1 - lambd) * x_mel_original``.

        The denoised mel is detached, so no gradient flows through the
        denoiser branch.
        """
        x_mel_denoised = self.normalizer(
            self.to_mel(self._may_denoise(x, z)), update=False
        )
        x_mel_denoised = x_mel_denoised.detach()
        return lambd * x_mel_denoised + (1 - lambd) * x_mel_original

    def forward(
        self,
        x: Tensor,
        y: Union[Tensor, None] = None,
        z: Union[Tensor, None] = None,
    ):
        """
        Args:
            x: (b t), mix wavs (fg + bg)
            y: (b t), fg clean wavs
            z: (b t), fg distorted wavs
        Returns:
            o: (b t), reconstructed wavs, or None when the LCFM returns None
        """
        assert x.dim() == 2, f"Expected (b t), got {x.size()}"
        assert y is None or y.dim() == 2, f"Expected (b t), got {y.size()}"

        cfm_mode = self.hp.lcfm_training_mode == "cfm"

        if cfm_mode:
            self.normalizer.eval()

        x = _normalize_wav(x)
        y = _maybe(_normalize_wav)(y)
        z = _maybe(_normalize_wav)(z)

        x_mel_original = self.normalizer(self.to_mel(x), update=False)  # (b d t)

        if not cfm_mode:
            x_mel_denoised = x_mel_original
        elif self.training:
            # Random per-sample denoiser strength, broadcast over (d, t).
            lambd = Beta(0.2, 0.2).sample(x.shape[:1]).to(x.device)
            lambd = lambd[:, None, None]
            x_mel_denoised = self._blend_with_denoised(x, z, x_mel_original, lambd)
            # `_visualize` is a debugging hook that is not defined in this part
            # of the class; call it only when present instead of failing with
            # AttributeError on every training step.
            visualize = getattr(self, "_visualize", None)
            if callable(visualize):
                visualize(x_mel_original, x_mel_denoised)
        elif self._eval_lambd == 0:
            # Denoiser disabled: skip the denoiser forward pass entirely.
            x_mel_denoised = x_mel_original
        else:
            x_mel_denoised = self._blend_with_denoised(
                x, z, x_mel_original, self._eval_lambd
            )

        y_mel = _maybe(self.to_mel)(y)  # (b d t)
        y_mel = _maybe(self.normalizer)(y_mel)

        lcfm_decoded = self.lcfm(x_mel_denoised, y_mel, ψ0=x_mel_original)  # (b d t)

        if lcfm_decoded is None:
            return None

        return self.vocoder(lcfm_decoded, y)