# Copyright (c) 2024 Amphion. # # This code is modified from https://github.com/imdanboy/jets/blob/main/espnet2/gan_tts/jets/loss.py # Licensed under Apache License 2.0 from typing import List, Optional, Tuple, Union import torch import torch.nn as nn import torch.nn.functional as F import numpy as np import librosa from models.vocoders.gan.discriminator.mpd import MultiScaleMultiPeriodDiscriminator from models.tts.jets.alignments import make_non_pad_mask, make_pad_mask class GeneratorAdversarialLoss(torch.nn.Module): """Generator adversarial loss module.""" def __init__(self): super().__init__() def forward(self, outputs) -> torch.Tensor: if isinstance(outputs, (tuple, list)): adv_loss = 0.0 for i, outputs_ in enumerate(outputs): if isinstance(outputs_, (tuple, list)): # NOTE(kan-bayashi): case including feature maps outputs_ = outputs_[-1] adv_loss += F.mse_loss(outputs_, outputs_.new_ones(outputs_.size())) else: adv_loss = F.mse_loss(outputs, outputs.new_ones(outputs.size())) return adv_loss class FeatureMatchLoss(torch.nn.Module): """Feature matching loss module.""" def __init__( self, average_by_layers: bool = False, average_by_discriminators: bool = False, include_final_outputs: bool = True, ): """Initialize FeatureMatchLoss module. Args: average_by_layers (bool): Whether to average the loss by the number of layers. average_by_discriminators (bool): Whether to average the loss by the number of discriminators. include_final_outputs (bool): Whether to include the final output of each discriminator for loss calculation. """ super().__init__() self.average_by_layers = average_by_layers self.average_by_discriminators = average_by_discriminators self.include_final_outputs = include_final_outputs def forward( self, feats_hat: Union[List[List[torch.Tensor]], List[torch.Tensor]], feats: Union[List[List[torch.Tensor]], List[torch.Tensor]], ) -> torch.Tensor: """Calculate feature matching loss. Args: feats_hat (Union[List[List[Tensor]], List[Tensor]]): List of list of discriminator outputs or list of discriminator outputs calcuated from generator's outputs. feats (Union[List[List[Tensor]], List[Tensor]]): List of list of discriminator outputs or list of discriminator outputs calcuated from groundtruth.. Returns: Tensor: Feature matching loss value. """ feat_match_loss = 0.0 for i, (feats_hat_, feats_) in enumerate(zip(feats_hat, feats)): feat_match_loss_ = 0.0 if not self.include_final_outputs: feats_hat_ = feats_hat_[:-1] feats_ = feats_[:-1] for j, (feat_hat_, feat_) in enumerate(zip(feats_hat_, feats_)): feat_match_loss_ += F.l1_loss(feat_hat_, feat_.detach()) if self.average_by_layers: feat_match_loss_ /= j + 1 feat_match_loss += feat_match_loss_ if self.average_by_discriminators: feat_match_loss /= i + 1 return feat_match_loss class DurationPredictorLoss(torch.nn.Module): """Loss function module for duration predictor. The loss value is Calculated in log domain to make it Gaussian. """ def __init__(self, offset=1.0, reduction="mean"): """Initilize duration predictor loss module. Args: offset (float, optional): Offset value to avoid nan in log domain. reduction (str): Reduction type in loss calculation. """ super().__init__() self.criterion = torch.nn.MSELoss(reduction=reduction) self.offset = offset def forward(self, outputs, targets): targets = torch.log(targets.float() + self.offset) loss = self.criterion(outputs, targets) return loss class VarianceLoss(torch.nn.Module): def __init__(self): """Initialize JETS variance loss module.""" super().__init__() # define criterions reduction = "mean" self.mse_criterion = torch.nn.MSELoss(reduction=reduction) self.duration_criterion = DurationPredictorLoss(reduction=reduction) def forward( self, d_outs: torch.Tensor, ds: torch.Tensor, p_outs: torch.Tensor, ps: torch.Tensor, e_outs: torch.Tensor, es: torch.Tensor, ilens: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """Calculate forward propagation. Args: d_outs (LongTensor): Batch of outputs of duration predictor (B, T_text). ds (LongTensor): Batch of durations (B, T_text). p_outs (Tensor): Batch of outputs of pitch predictor (B, T_text, 1). ps (Tensor): Batch of target token-averaged pitch (B, T_text, 1). e_outs (Tensor): Batch of outputs of energy predictor (B, T_text, 1). es (Tensor): Batch of target token-averaged energy (B, T_text, 1). ilens (LongTensor): Batch of the lengths of each input (B,). Returns: Tensor: Duration predictor loss value. Tensor: Pitch predictor loss value. Tensor: Energy predictor loss value. """ # apply mask to remove padded part duration_masks = make_non_pad_mask(ilens).to(ds.device) d_outs = d_outs.masked_select(duration_masks) ds = ds.masked_select(duration_masks) pitch_masks = make_non_pad_mask(ilens).to(ds.device) pitch_masks_ = make_non_pad_mask(ilens).unsqueeze(-1).to(ds.device) p_outs = p_outs.masked_select(pitch_masks) e_outs = e_outs.masked_select(pitch_masks) ps = ps.masked_select(pitch_masks_) es = es.masked_select(pitch_masks_) # calculate loss duration_loss = self.duration_criterion(d_outs, ds) pitch_loss = self.mse_criterion(p_outs, ps) energy_loss = self.mse_criterion(e_outs, es) return duration_loss, pitch_loss, energy_loss class ForwardSumLoss(torch.nn.Module): """Forwardsum loss described at https://openreview.net/forum?id=0NQwnnwAORi""" def __init__(self): """Initialize forwardsum loss module.""" super().__init__() def forward( self, log_p_attn: torch.Tensor, ilens: torch.Tensor, olens: torch.Tensor, blank_prob: float = np.e**-1, ) -> torch.Tensor: """Calculate forward propagation. Args: log_p_attn (Tensor): Batch of log probability of attention matrix (B, T_feats, T_text). ilens (Tensor): Batch of the lengths of each input (B,). olens (Tensor): Batch of the lengths of each target (B,). blank_prob (float): Blank symbol probability. Returns: Tensor: forwardsum loss value. """ B = log_p_attn.size(0) # a row must be added to the attention matrix to account for # blank token of CTC loss # (B,T_feats,T_text+1) log_p_attn_pd = F.pad(log_p_attn, (1, 0, 0, 0, 0, 0), value=np.log(blank_prob)) loss = 0 for bidx in range(B): # construct target sequnece. # Every text token is mapped to a unique sequnece number. target_seq = torch.arange(1, ilens[bidx] + 1).unsqueeze(0) cur_log_p_attn_pd = log_p_attn_pd[ bidx, : olens[bidx], : ilens[bidx] + 1 ].unsqueeze( 1 ) # (T_feats,1,T_text+1) cur_log_p_attn_pd = F.log_softmax(cur_log_p_attn_pd, dim=-1) loss += F.ctc_loss( log_probs=cur_log_p_attn_pd, targets=target_seq, input_lengths=olens[bidx : bidx + 1], target_lengths=ilens[bidx : bidx + 1], zero_infinity=True, ) loss = loss / B return loss class MelSpectrogramLoss(torch.nn.Module): """Mel-spectrogram loss.""" def __init__( self, fs: int = 22050, n_fft: int = 1024, hop_length: int = 256, win_length: Optional[int] = None, window: str = "hann", n_mels: int = 80, fmin: Optional[int] = 0, fmax: Optional[int] = None, center: bool = True, normalized: bool = False, onesided: bool = True, htk: bool = False, ): """Initialize Mel-spectrogram loss. Args: fs (int): Sampling rate. n_fft (int): FFT points. hop_length (int): Hop length. win_length (Optional[int]): Window length. window (str): Window type. n_mels (int): Number of Mel basis. fmin (Optional[int]): Minimum frequency for Mel. fmax (Optional[int]): Maximum frequency for Mel. center (bool): Whether to use center window. normalized (bool): Whether to use normalized one. onesided (bool): Whether to use oneseded one. """ super().__init__() self.fs = fs self.n_fft = n_fft self.hop_length = hop_length self.win_length = n_fft self.window = window self.n_mels = n_mels self.fmin = 0 if fmin is None else fmin self.fmax = fs / 2 if fmax is None else fmax self.center = center self.normalized = normalized self.onesided = onesided self.htk = htk def logmel(self, feat, ilens): mel_options = dict( sr=self.fs, n_fft=self.n_fft, n_mels=self.n_mels, fmin=self.fmin, fmax=self.fmax, htk=self.htk, ) melmat = librosa.filters.mel(**mel_options) melmat = torch.from_numpy(melmat.T).float().to(feat.device) mel_feat = torch.matmul(feat, melmat) mel_feat = torch.clamp(mel_feat, min=1e-10) logmel_feat = mel_feat.log10() # Zero padding if ilens is not None: logmel_feat = logmel_feat.masked_fill( make_pad_mask(ilens, logmel_feat, 1), 0.0 ) else: ilens = feat.new_full( [feat.size(0)], fill_value=feat.size(1), dtype=torch.long ) return logmel_feat def wav_to_mel(self, input, input_lengths=None): if self.window is not None: window_func = getattr(torch, f"{self.window}_window") window = window_func( self.win_length, dtype=input.dtype, device=input.device ) stft_kwargs = dict( n_fft=self.n_fft, win_length=self.win_length, hop_length=self.hop_length, center=self.center, window=window, normalized=self.normalized, onesided=self.onesided, return_complex=True, ) bs = input.size(0) if input.dim() == 3: multi_channel = True # input: (Batch, Nsample, Channels) -> (Batch * Channels, Nsample) input = input.transpose(1, 2).reshape(-1, input.size(1)) else: multi_channel = False input_stft = torch.stft(input, **stft_kwargs) input_stft = torch.view_as_real(input_stft) input_stft = input_stft.transpose(1, 2) if multi_channel: input_stft = input_stft.view( bs, -1, input_stft.size(1), input_stft.size(2), 2 ).transpose(1, 2) if input_lengths is not None: if self.center: pad = self.n_fft // 2 input_lengths = input_lengths + 2 * pad feats_lens = (input_lengths - self.n_fft) // self.hop_length + 1 input_stft.masked_fill_(make_pad_mask(feats_lens, input_stft, 1), 0.0) else: feats_lens = None input_power = input_stft[..., 0] ** 2 + input_stft[..., 1] ** 2 input_amp = torch.sqrt(torch.clamp(input_power, min=1.0e-10)) input_feats = self.logmel(input_amp, feats_lens) return input_feats, feats_lens def forward( self, y_hat: torch.Tensor, y: torch.Tensor, ) -> torch.Tensor: mel_hat, _ = self.wav_to_mel(y_hat.squeeze(1)) mel, _ = self.wav_to_mel(y.squeeze(1)) mel_loss = F.l1_loss(mel_hat, mel) return mel_loss class GeneratorLoss(nn.Module): """The total loss of the generator""" def __init__(self, cfg): super().__init__() self.cfg = cfg self.mel_loss = MelSpectrogramLoss() self.generator_adv_loss = GeneratorAdversarialLoss() self.feat_match_loss = FeatureMatchLoss() self.var_loss = VarianceLoss() self.forwardsum_loss = ForwardSumLoss() self.lambda_adv = 1.0 self.lambda_mel = 45.0 self.lambda_feat_match = 2.0 self.lambda_var = 1.0 self.lambda_align = 2.0 def forward(self, outputs_g, outputs_d, speech_): loss_g = {} # parse generator output ( speech_hat_, bin_loss, log_p_attn, start_idxs, d_outs, ds, p_outs, ps, e_outs, es, text_lengths, feats_lengths, ) = outputs_g # parse discriminator output (p_hat, p) = outputs_d # calculate losses mel_loss = self.mel_loss(speech_hat_, speech_) adv_loss = self.generator_adv_loss(p_hat) feat_match_loss = self.feat_match_loss(p_hat, p) dur_loss, pitch_loss, energy_loss = self.var_loss( d_outs, ds, p_outs, ps, e_outs, es, text_lengths ) forwardsum_loss = self.forwardsum_loss(log_p_attn, text_lengths, feats_lengths) # calculate total loss mel_loss = mel_loss * self.lambda_mel loss_g["mel_loss"] = mel_loss adv_loss = adv_loss * self.lambda_adv loss_g["adv_loss"] = adv_loss feat_match_loss = feat_match_loss * self.lambda_feat_match loss_g["feat_match_loss"] = feat_match_loss g_loss = mel_loss + adv_loss + feat_match_loss loss_g["g_loss"] = g_loss var_loss = (dur_loss + pitch_loss + energy_loss) * self.lambda_var loss_g["var_loss"] = var_loss align_loss = (forwardsum_loss + bin_loss) * self.lambda_align loss_g["align_loss"] = align_loss g_total_loss = g_loss + var_loss + align_loss loss_g["g_total_loss"] = g_total_loss return loss_g class DiscriminatorAdversarialLoss(torch.nn.Module): """Discriminator adversarial loss module.""" def __init__( self, average_by_discriminators: bool = True, loss_type: str = "mse", ): """Initialize DiscriminatorAversarialLoss module. Args: average_by_discriminators (bool): Whether to average the loss by the number of discriminators. loss_type (str): Loss type, "mse" or "hinge". """ super().__init__() self.average_by_discriminators = average_by_discriminators assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported." if loss_type == "mse": self.fake_criterion = self._mse_fake_loss self.real_criterion = self._mse_real_loss else: self.fake_criterion = self._hinge_fake_loss self.real_criterion = self._hinge_real_loss def forward( self, outputs_hat: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor], outputs: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: """Calcualate discriminator adversarial loss. Args: outputs_hat (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator outputs, list of discriminator outputs, or list of list of discriminator outputs calculated from generator. outputs (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator outputs, list of discriminator outputs, or list of list of discriminator outputs calculated from groundtruth. Returns: Tensor: Discriminator real loss value. Tensor: Discriminator fake loss value. """ if isinstance(outputs, (tuple, list)): real_loss = 0.0 fake_loss = 0.0 for i, (outputs_hat_, outputs_) in enumerate(zip(outputs_hat, outputs)): if isinstance(outputs_hat_, (tuple, list)): # NOTE(kan-bayashi): case including feature maps outputs_hat_ = outputs_hat_[-1] outputs_ = outputs_[-1] real_loss += self.real_criterion(outputs_) fake_loss += self.fake_criterion(outputs_hat_) if self.average_by_discriminators: fake_loss /= i + 1 real_loss /= i + 1 else: real_loss = self.real_criterion(outputs) fake_loss = self.fake_criterion(outputs_hat) return real_loss, fake_loss def _mse_real_loss(self, x: torch.Tensor) -> torch.Tensor: return F.mse_loss(x, x.new_ones(x.size())) def _mse_fake_loss(self, x: torch.Tensor) -> torch.Tensor: return F.mse_loss(x, x.new_zeros(x.size())) def _hinge_real_loss(self, x: torch.Tensor) -> torch.Tensor: return -torch.mean(torch.min(x - 1, x.new_zeros(x.size()))) def _hinge_fake_loss(self, x: torch.Tensor) -> torch.Tensor: return -torch.mean(torch.min(-x - 1, x.new_zeros(x.size()))) class DiscriminatorLoss(torch.nn.Module): """The total loss of the discriminator""" def __init__(self, cfg): super(DiscriminatorLoss, self).__init__() self.cfg = cfg self.discriminator = MultiScaleMultiPeriodDiscriminator() self.discriminator_adv_loss = DiscriminatorAdversarialLoss() def forward(self, speech_real, speech_generated): loss_d = {} real_loss, fake_loss = self.discriminator_adv_loss( speech_generated, speech_real ) loss_d["loss_disc_all"] = real_loss + fake_loss return loss_d