# Copyright (c) 2024 Amphion.
#
# This code is modified from https://github.com/imdanboy/jets/blob/main/espnet2/gan_tts/jets/loss.py
# Licensed under Apache License 2.0
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import librosa
from models.vocoders.gan.discriminator.mpd import MultiScaleMultiPeriodDiscriminator
from models.tts.jets.alignments import make_non_pad_mask, make_pad_mask
class GeneratorAdversarialLoss(torch.nn.Module):
"""Generator adversarial loss module."""
def __init__(self):
super().__init__()
    def forward(self, outputs) -> torch.Tensor:
        """Calculate generator adversarial loss.

        Args:
            outputs (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator
                outputs, list of discriminator outputs, or list of list of
                discriminator outputs calculated from the generator.

        Returns:
            Tensor: Generator adversarial loss value.
        """
if isinstance(outputs, (tuple, list)):
adv_loss = 0.0
for i, outputs_ in enumerate(outputs):
if isinstance(outputs_, (tuple, list)):
# NOTE(kan-bayashi): case including feature maps
outputs_ = outputs_[-1]
adv_loss += F.mse_loss(outputs_, outputs_.new_ones(outputs_.size()))
else:
adv_loss = F.mse_loss(outputs, outputs.new_ones(outputs.size()))
return adv_loss
class FeatureMatchLoss(torch.nn.Module):
"""Feature matching loss module."""
def __init__(
self,
average_by_layers: bool = False,
average_by_discriminators: bool = False,
include_final_outputs: bool = True,
):
"""Initialize FeatureMatchLoss module.
Args:
average_by_layers (bool): Whether to average the loss by the number
of layers.
average_by_discriminators (bool): Whether to average the loss by
the number of discriminators.
include_final_outputs (bool): Whether to include the final output of
each discriminator for loss calculation.
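
        Example (illustrative sketch; feature-map shapes are assumptions):
            >>> criterion = FeatureMatchLoss()
            >>> feats_hat = [[torch.randn(2, 16, 50) for _ in range(3)] for _ in range(2)]
            >>> feats = [[torch.randn(2, 16, 50) for _ in range(3)] for _ in range(2)]
            >>> loss = criterion(feats_hat, feats)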
"""
super().__init__()
self.average_by_layers = average_by_layers
self.average_by_discriminators = average_by_discriminators
self.include_final_outputs = include_final_outputs
def forward(
self,
feats_hat: Union[List[List[torch.Tensor]], List[torch.Tensor]],
feats: Union[List[List[torch.Tensor]], List[torch.Tensor]],
) -> torch.Tensor:
"""Calculate feature matching loss.
Args:
            feats_hat (Union[List[List[Tensor]], List[Tensor]]): List of list of
                discriminator outputs or list of discriminator outputs calculated
                from the generator's outputs.
            feats (Union[List[List[Tensor]], List[Tensor]]): List of list of
                discriminator outputs or list of discriminator outputs calculated
                from groundtruth.
Returns:
Tensor: Feature matching loss value.
"""
feat_match_loss = 0.0
for i, (feats_hat_, feats_) in enumerate(zip(feats_hat, feats)):
feat_match_loss_ = 0.0
if not self.include_final_outputs:
feats_hat_ = feats_hat_[:-1]
feats_ = feats_[:-1]
for j, (feat_hat_, feat_) in enumerate(zip(feats_hat_, feats_)):
feat_match_loss_ += F.l1_loss(feat_hat_, feat_.detach())
if self.average_by_layers:
feat_match_loss_ /= j + 1
feat_match_loss += feat_match_loss_
if self.average_by_discriminators:
feat_match_loss /= i + 1
return feat_match_loss
class DurationPredictorLoss(torch.nn.Module):
"""Loss function module for duration predictor.
    The loss value is calculated in the log domain so that the target
    distribution becomes closer to Gaussian.
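
    Example (illustrative; duration values are assumptions):
        >>> criterion = DurationPredictorLoss(offset=1.0)
        >>> d_outs = torch.zeros(2, 5)  # predictions, already in log domain
        >>> ds = torch.ones(2, 5)       # groundtruth durations in frames
        >>> loss = criterion(d_outs, ds)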
"""
def __init__(self, offset=1.0, reduction="mean"):
"""Initilize duration predictor loss module.
Args:
offset (float, optional): Offset value to avoid nan in log domain.
reduction (str): Reduction type in loss calculation.
"""
super().__init__()
self.criterion = torch.nn.MSELoss(reduction=reduction)
self.offset = offset
    def forward(self, outputs, targets):
        """Compute MSE between `outputs` (log domain) and log(`targets` + offset)."""
targets = torch.log(targets.float() + self.offset)
loss = self.criterion(outputs, targets)
return loss
class VarianceLoss(torch.nn.Module):
def __init__(self):
"""Initialize JETS variance loss module."""
super().__init__()
# define criterions
reduction = "mean"
self.mse_criterion = torch.nn.MSELoss(reduction=reduction)
self.duration_criterion = DurationPredictorLoss(reduction=reduction)
def forward(
self,
d_outs: torch.Tensor,
ds: torch.Tensor,
p_outs: torch.Tensor,
ps: torch.Tensor,
e_outs: torch.Tensor,
es: torch.Tensor,
ilens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Calculate forward propagation.
Args:
d_outs (LongTensor): Batch of outputs of duration predictor (B, T_text).
ds (LongTensor): Batch of durations (B, T_text).
p_outs (Tensor): Batch of outputs of pitch predictor (B, T_text, 1).
ps (Tensor): Batch of target token-averaged pitch (B, T_text, 1).
e_outs (Tensor): Batch of outputs of energy predictor (B, T_text, 1).
es (Tensor): Batch of target token-averaged energy (B, T_text, 1).
ilens (LongTensor): Batch of the lengths of each input (B,).
Returns:
Tensor: Duration predictor loss value.
Tensor: Pitch predictor loss value.
Tensor: Energy predictor loss value.
"""
# apply mask to remove padded part
duration_masks = make_non_pad_mask(ilens).to(ds.device)
d_outs = d_outs.masked_select(duration_masks)
ds = ds.masked_select(duration_masks)
pitch_masks = make_non_pad_mask(ilens).to(ds.device)
pitch_masks_ = make_non_pad_mask(ilens).unsqueeze(-1).to(ds.device)
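        # NOTE: masked_select returns flattened 1-D tensors, so the predictions
        # (masked with the 2-D mask) and the targets (masked with the 3-D mask)
        # line up element-wise despite the differing trailing dimensions.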
p_outs = p_outs.masked_select(pitch_masks)
e_outs = e_outs.masked_select(pitch_masks)
ps = ps.masked_select(pitch_masks_)
es = es.masked_select(pitch_masks_)
# calculate loss
duration_loss = self.duration_criterion(d_outs, ds)
pitch_loss = self.mse_criterion(p_outs, ps)
energy_loss = self.mse_criterion(e_outs, es)
return duration_loss, pitch_loss, energy_loss
class ForwardSumLoss(torch.nn.Module):
"""Forwardsum loss described at https://openreview.net/forum?id=0NQwnnwAORi"""
def __init__(self):
"""Initialize forwardsum loss module."""
super().__init__()
def forward(
self,
log_p_attn: torch.Tensor,
ilens: torch.Tensor,
olens: torch.Tensor,
blank_prob: float = np.e**-1,
) -> torch.Tensor:
"""Calculate forward propagation.
Args:
log_p_attn (Tensor): Batch of log probability of attention matrix
(B, T_feats, T_text).
ilens (Tensor): Batch of the lengths of each input (B,).
olens (Tensor): Batch of the lengths of each target (B,).
blank_prob (float): Blank symbol probability.
Returns:
Tensor: forwardsum loss value.
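
        Example (illustrative sketch; lengths are assumptions):
            >>> criterion = ForwardSumLoss()
            >>> log_p_attn = torch.randn(2, 20, 6).log_softmax(dim=-1)
            >>> ilens = torch.tensor([6, 4])
            >>> olens = torch.tensor([20, 15])
            >>> loss = criterion(log_p_attn, ilens, olens)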
"""
B = log_p_attn.size(0)
        # a column for the CTC blank token must be added along the text axis
        # (B, T_feats, T_text+1)
log_p_attn_pd = F.pad(log_p_attn, (1, 0, 0, 0, 0, 0), value=np.log(blank_prob))
loss = 0
for bidx in range(B):
            # construct target sequence.
            # Every text token is mapped to a unique sequence number.
target_seq = torch.arange(1, ilens[bidx] + 1).unsqueeze(0)
            cur_log_p_attn_pd = log_p_attn_pd[
                bidx, : olens[bidx], : ilens[bidx] + 1
            ].unsqueeze(1)  # (T_feats, 1, T_text+1)
cur_log_p_attn_pd = F.log_softmax(cur_log_p_attn_pd, dim=-1)
loss += F.ctc_loss(
log_probs=cur_log_p_attn_pd,
targets=target_seq,
input_lengths=olens[bidx : bidx + 1],
target_lengths=ilens[bidx : bidx + 1],
zero_infinity=True,
)
loss = loss / B
return loss
class MelSpectrogramLoss(torch.nn.Module):
"""Mel-spectrogram loss."""
def __init__(
self,
fs: int = 22050,
n_fft: int = 1024,
hop_length: int = 256,
win_length: Optional[int] = None,
window: str = "hann",
n_mels: int = 80,
fmin: Optional[int] = 0,
fmax: Optional[int] = None,
center: bool = True,
normalized: bool = False,
onesided: bool = True,
htk: bool = False,
):
"""Initialize Mel-spectrogram loss.
Args:
fs (int): Sampling rate.
n_fft (int): FFT points.
hop_length (int): Hop length.
win_length (Optional[int]): Window length.
window (str): Window type.
n_mels (int): Number of Mel basis.
fmin (Optional[int]): Minimum frequency for Mel.
fmax (Optional[int]): Maximum frequency for Mel.
center (bool): Whether to use center window.
normalized (bool): Whether to use normalized one.
            onesided (bool): Whether to use onesided one.
            htk (bool): Whether to use HTK formula in Mel-filter matrix calculation.
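
        Example (illustrative; waveform shape (B, 1, T) is an assumption):
            >>> criterion = MelSpectrogramLoss()
            >>> y_hat = torch.randn(2, 1, 22050)
            >>> y = torch.randn(2, 1, 22050)
            >>> loss = criterion(y_hat, y)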
"""
super().__init__()
self.fs = fs
self.n_fft = n_fft
self.hop_length = hop_length
        self.win_length = win_length if win_length is not None else n_fft
self.window = window
self.n_mels = n_mels
self.fmin = 0 if fmin is None else fmin
self.fmax = fs / 2 if fmax is None else fmax
self.center = center
self.normalized = normalized
self.onesided = onesided
self.htk = htk
def logmel(self, feat, ilens):
mel_options = dict(
sr=self.fs,
n_fft=self.n_fft,
n_mels=self.n_mels,
fmin=self.fmin,
fmax=self.fmax,
htk=self.htk,
)
melmat = librosa.filters.mel(**mel_options)
melmat = torch.from_numpy(melmat.T).float().to(feat.device)
mel_feat = torch.matmul(feat, melmat)
mel_feat = torch.clamp(mel_feat, min=1e-10)
logmel_feat = mel_feat.log10()
        # zero out the padded frames
        if ilens is not None:
            logmel_feat = logmel_feat.masked_fill(
                make_pad_mask(ilens, logmel_feat, 1), 0.0
            )
        return logmel_feat
def wav_to_mel(self, input, input_lengths=None):
        window = None
        if self.window is not None:
            window_func = getattr(torch, f"{self.window}_window")
            window = window_func(
                self.win_length, dtype=input.dtype, device=input.device
            )
stft_kwargs = dict(
n_fft=self.n_fft,
win_length=self.win_length,
hop_length=self.hop_length,
center=self.center,
window=window,
normalized=self.normalized,
onesided=self.onesided,
return_complex=True,
)
bs = input.size(0)
if input.dim() == 3:
multi_channel = True
# input: (Batch, Nsample, Channels) -> (Batch * Channels, Nsample)
input = input.transpose(1, 2).reshape(-1, input.size(1))
else:
multi_channel = False
input_stft = torch.stft(input, **stft_kwargs)
input_stft = torch.view_as_real(input_stft)
input_stft = input_stft.transpose(1, 2)
if multi_channel:
input_stft = input_stft.view(
bs, -1, input_stft.size(1), input_stft.size(2), 2
).transpose(1, 2)
if input_lengths is not None:
if self.center:
pad = self.n_fft // 2
input_lengths = input_lengths + 2 * pad
feats_lens = (input_lengths - self.n_fft) // self.hop_length + 1
input_stft.masked_fill_(make_pad_mask(feats_lens, input_stft, 1), 0.0)
else:
feats_lens = None
input_power = input_stft[..., 0] ** 2 + input_stft[..., 1] ** 2
input_amp = torch.sqrt(torch.clamp(input_power, min=1.0e-10))
input_feats = self.logmel(input_amp, feats_lens)
return input_feats, feats_lens
    def forward(
        self,
        y_hat: torch.Tensor,
        y: torch.Tensor,
    ) -> torch.Tensor:
        """Calculate L1 loss between the Mel spectrograms of generated and groundtruth waveforms."""
mel_hat, _ = self.wav_to_mel(y_hat.squeeze(1))
mel, _ = self.wav_to_mel(y.squeeze(1))
mel_loss = F.l1_loss(mel_hat, mel)
return mel_loss
class GeneratorLoss(nn.Module):
"""The total loss of the generator"""
def __init__(self, cfg):
super().__init__()
self.cfg = cfg
self.mel_loss = MelSpectrogramLoss()
self.generator_adv_loss = GeneratorAdversarialLoss()
self.feat_match_loss = FeatureMatchLoss()
self.var_loss = VarianceLoss()
self.forwardsum_loss = ForwardSumLoss()
self.lambda_adv = 1.0
self.lambda_mel = 45.0
self.lambda_feat_match = 2.0
self.lambda_var = 1.0
self.lambda_align = 2.0
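        # Total objective assembled in forward():
        #   g_total = lambda_mel * L_mel + lambda_adv * L_adv
        #           + lambda_feat_match * L_feat_match
        #           + lambda_var * (L_dur + L_pitch + L_energy)
        #           + lambda_align * (L_forwardsum + L_bin)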
    def forward(self, outputs_g, outputs_d, speech_):
        """Calculate all generator losses and collect them into a dict."""
loss_g = {}
# parse generator output
(
speech_hat_,
bin_loss,
log_p_attn,
start_idxs,
d_outs,
ds,
p_outs,
ps,
e_outs,
es,
text_lengths,
feats_lengths,
) = outputs_g
# parse discriminator output
(p_hat, p) = outputs_d
# calculate losses
mel_loss = self.mel_loss(speech_hat_, speech_)
adv_loss = self.generator_adv_loss(p_hat)
feat_match_loss = self.feat_match_loss(p_hat, p)
dur_loss, pitch_loss, energy_loss = self.var_loss(
d_outs, ds, p_outs, ps, e_outs, es, text_lengths
)
forwardsum_loss = self.forwardsum_loss(log_p_attn, text_lengths, feats_lengths)
# calculate total loss
mel_loss = mel_loss * self.lambda_mel
loss_g["mel_loss"] = mel_loss
adv_loss = adv_loss * self.lambda_adv
loss_g["adv_loss"] = adv_loss
feat_match_loss = feat_match_loss * self.lambda_feat_match
loss_g["feat_match_loss"] = feat_match_loss
g_loss = mel_loss + adv_loss + feat_match_loss
loss_g["g_loss"] = g_loss
var_loss = (dur_loss + pitch_loss + energy_loss) * self.lambda_var
loss_g["var_loss"] = var_loss
align_loss = (forwardsum_loss + bin_loss) * self.lambda_align
loss_g["align_loss"] = align_loss
g_total_loss = g_loss + var_loss + align_loss
loss_g["g_total_loss"] = g_total_loss
return loss_g
class DiscriminatorAdversarialLoss(torch.nn.Module):
"""Discriminator adversarial loss module."""
def __init__(
self,
average_by_discriminators: bool = True,
loss_type: str = "mse",
):
"""Initialize DiscriminatorAversarialLoss module.
Args:
average_by_discriminators (bool): Whether to average the loss by
the number of discriminators.
loss_type (str): Loss type, "mse" or "hinge".
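
        Example (illustrative sketch; output shapes are assumptions):
            >>> criterion = DiscriminatorAdversarialLoss(loss_type="mse")
            >>> outs_hat = [torch.randn(2, 1, 100) for _ in range(3)]
            >>> outs = [torch.randn(2, 1, 100) for _ in range(3)]
            >>> real_loss, fake_loss = criterion(outs_hat, outs)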
"""
super().__init__()
self.average_by_discriminators = average_by_discriminators
assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported."
if loss_type == "mse":
self.fake_criterion = self._mse_fake_loss
self.real_criterion = self._mse_real_loss
else:
self.fake_criterion = self._hinge_fake_loss
self.real_criterion = self._hinge_real_loss
def forward(
self,
outputs_hat: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
outputs: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Calcualate discriminator adversarial loss.
Args:
outputs_hat (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator
outputs, list of discriminator outputs, or list of list of discriminator
outputs calculated from generator.
outputs (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator
outputs, list of discriminator outputs, or list of list of discriminator
outputs calculated from groundtruth.
Returns:
Tensor: Discriminator real loss value.
Tensor: Discriminator fake loss value.
"""
if isinstance(outputs, (tuple, list)):
real_loss = 0.0
fake_loss = 0.0
for i, (outputs_hat_, outputs_) in enumerate(zip(outputs_hat, outputs)):
if isinstance(outputs_hat_, (tuple, list)):
# NOTE(kan-bayashi): case including feature maps
outputs_hat_ = outputs_hat_[-1]
outputs_ = outputs_[-1]
real_loss += self.real_criterion(outputs_)
fake_loss += self.fake_criterion(outputs_hat_)
if self.average_by_discriminators:
fake_loss /= i + 1
real_loss /= i + 1
else:
real_loss = self.real_criterion(outputs)
fake_loss = self.fake_criterion(outputs_hat)
return real_loss, fake_loss
def _mse_real_loss(self, x: torch.Tensor) -> torch.Tensor:
return F.mse_loss(x, x.new_ones(x.size()))
def _mse_fake_loss(self, x: torch.Tensor) -> torch.Tensor:
return F.mse_loss(x, x.new_zeros(x.size()))
def _hinge_real_loss(self, x: torch.Tensor) -> torch.Tensor:
return -torch.mean(torch.min(x - 1, x.new_zeros(x.size())))
def _hinge_fake_loss(self, x: torch.Tensor) -> torch.Tensor:
return -torch.mean(torch.min(-x - 1, x.new_zeros(x.size())))
class DiscriminatorLoss(torch.nn.Module):
"""The total loss of the discriminator"""
def __init__(self, cfg):
        super().__init__()
self.cfg = cfg
self.discriminator = MultiScaleMultiPeriodDiscriminator()
self.discriminator_adv_loss = DiscriminatorAdversarialLoss()
def forward(self, speech_real, speech_generated):
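        # NOTE: despite the parameter names, these are expected to be the
        # discriminator outputs for generated and real speech (e.g., from
        # MultiScaleMultiPeriodDiscriminator), not raw waveforms.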
loss_d = {}
real_loss, fake_loss = self.discriminator_adv_loss(
speech_generated, speech_real
)
loss_d["loss_disc_all"] = real_loss + fake_loss
return loss_d