Spaces:
Runtime error
Runtime error
# Copyright (c) 2024 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch | |
import torch.nn as nn | |
import os | |
from utils.io import save_audio | |
from tqdm import tqdm | |
from models.tts.base import TTSTrainer | |
from models.tts.jets.jets import Jets | |
from models.tts.jets.jets_loss import GeneratorLoss, DiscriminatorLoss | |
from models.tts.jets.jets_dataset import JetsDataset, JetsCollator | |
from optimizer.optimizers import NoamLR | |
from torch.optim.lr_scheduler import ExponentialLR | |
from models.vocoders.gan.discriminator.mpd import MultiScaleMultiPeriodDiscriminator | |
def get_segments(
    x: torch.Tensor,
    start_idxs: torch.Tensor,
    segment_size: int,
) -> torch.Tensor:
    """Cut one fixed-length window out of every batch item.

    Args:
        x (Tensor): Input tensor (B, C, T).
        start_idxs (Tensor): Per-item start index along the last axis (B,).
        segment_size (int): Length of the window to extract.

    Returns:
        Tensor: Windows stacked along the batch axis (B, C, segment_size),
            allocated with the same dtype and device as ``x``.
    """
    batch_size, channels, _ = x.size()
    # Pre-allocate the result so it inherits dtype/device from the input.
    out = x.new_zeros(batch_size, channels, segment_size)
    for row, begin in enumerate(start_idxs):
        end = begin + segment_size
        out[row] = x[row, :, begin:end]
    return out
class JetsTrainer(TTSTrainer):
    """Adversarial trainer for the Jets TTS model.

    A ``Jets`` generator is trained against a
    ``MultiScaleMultiPeriodDiscriminator``. Models, optimizers and
    schedulers are each stored in a dict keyed by role
    (``"generator"``/``"discriminator"``, ``"optimizer_g"``/``"optimizer_d"``,
    ``"scheduler_g"``/``"scheduler_d"``).

    The methods below read several attributes that are not assigned in this
    class (``sw``, ``step``, ``epoch``, ``batch_count``, ``accelerator``,
    ``criterion``, ``optimizer``, ``scheduler``, ``train_dataloader``,
    ``valid_dataloader``); they are expected to come from ``TTSTrainer``.
    """

    def __init__(self, args, cfg):
        """Initialise the base trainer and keep a reference to ``cfg``."""
        super().__init__(args, cfg)
        self.cfg = cfg

    def _build_dataset(self):
        """Return the dataset class and collator class used for Jets."""
        return JetsDataset, JetsCollator

    def __build_scheduler(self):
        """Build a ``NoamLR`` scheduler from ``cfg.train.lr_scheduler``.

        NOTE(review): the double underscore name-mangles this method to
        ``_JetsTrainer__build_scheduler`` and nothing in this class calls it;
        ``_build_scheduler`` below is a separate method. It also passes
        ``self.optimizer`` as-is, which the rest of this class treats as a
        dict of optimizers -- confirm whether this method is still needed.
        """
        return NoamLR(self.optimizer, **self.cfg.train.lr_scheduler)

    def _write_media_summary(self, images, audios, audio_sampling_rate):
        """Write optional image and audio entries to the summary writer.

        Args:
            images: Mapping of tag -> image, or ``None``/empty to skip.
            audios: Mapping of tag -> audio, or ``None``/empty to skip.
            audio_sampling_rate: Sample rate forwarded to ``add_audio``.
        """
        # NOTE(review): media are indexed by ``self.global_step`` while
        # scalars use ``self.step``; ``global_step`` is never assigned in
        # this class -- confirm the base class provides it.
        if images:
            for key, value in images.items():
                # ``dataformats`` is the keyword of SummaryWriter.add_image
                # (assumes ``self.sw`` is a TensorBoard SummaryWriter).
                self.sw.add_image(key, value, self.global_step, dataformats="HWC")
        if audios:
            for key, value in audios.items():
                self.sw.add_audio(key, value, self.global_step, audio_sampling_rate)

    def _write_summary(
        self,
        losses,
        stats,
        images=None,
        audios=None,
        audio_sampling_rate=24000,
        tag="train",
    ):
        """Log losses, the generator learning rate and optional media.

        Args:
            losses: Mapping of loss name -> scalar value; each is written
                under ``"<tag>/<name>"`` at ``self.step``.
            stats: Accepted for interface compatibility; not used here.
            images: Optional mapping of tag -> image.
            audios: Optional mapping of tag -> audio.
            audio_sampling_rate: Sample rate forwarded to ``add_audio``.
            tag: Prefix for the loss scalars.
        """
        for key, value in losses.items():
            self.sw.add_scalar(tag + "/" + key, value, self.step)
        # ``self.optimizer`` is a dict of optimizers, so the learning rate
        # is read from the generator optimizer's first parameter group.
        self.sw.add_scalar(
            "learning_rate",
            self.optimizer["optimizer_g"].param_groups[0]["lr"],
            self.step,
        )
        self._write_media_summary(images, audios, audio_sampling_rate)

    def _write_valid_summary(
        self, losses, stats, images=None, audios=None, audio_sampling_rate=24000, tag="val"
    ):
        """Log validation losses and optional media.

        Args:
            losses: Mapping of loss name -> scalar value; each is written
                under ``"<tag>/<name>"`` at ``self.step``.
            stats: Accepted for interface compatibility; not used here.
            images: Optional mapping of tag -> image.
            audios: Optional mapping of tag -> audio.
            audio_sampling_rate: Sample rate forwarded to ``add_audio``.
            tag: Prefix for the loss scalars.
        """
        for key, value in losses.items():
            self.sw.add_scalar(tag + "/" + key, value, self.step)
        self._write_media_summary(images, audios, audio_sampling_rate)

    def _build_criterion(self):
        """Return the generator and discriminator loss modules as a dict."""
        criterion = {
            "generator": GeneratorLoss(self.cfg),
            "discriminator": DiscriminatorLoss(self.cfg),
        }
        return criterion

    def get_state_dict(self):
        """Collect everything needed to resume training into one dict.

        Returns:
            dict: State dicts of both models, both optimizers and both
                schedulers, plus ``step``, ``epoch`` and the configured
                ``batch_size``.
        """
        state_dict = {
            "generator": self.model["generator"].state_dict(),
            "discriminator": self.model["discriminator"].state_dict(),
            "optimizer_g": self.optimizer["optimizer_g"].state_dict(),
            "optimizer_d": self.optimizer["optimizer_d"].state_dict(),
            "scheduler_g": self.scheduler["scheduler_g"].state_dict(),
            "scheduler_d": self.scheduler["scheduler_d"].state_dict(),
            "step": self.step,
            "epoch": self.epoch,
            "batch_size": self.cfg.train.batch_size,
        }
        return state_dict

    def _build_optimizer(self):
        """Create one AdamW optimizer per model with shared hyper-parameters.

        Returns:
            dict: ``{"optimizer_g": ..., "optimizer_d": ...}``.
        """
        optimizer_g = torch.optim.AdamW(
            self.model["generator"].parameters(),
            self.cfg.train.learning_rate,
            betas=self.cfg.train.AdamW.betas,
            eps=self.cfg.train.AdamW.eps,
        )
        optimizer_d = torch.optim.AdamW(
            self.model["discriminator"].parameters(),
            self.cfg.train.learning_rate,
            betas=self.cfg.train.AdamW.betas,
            eps=self.cfg.train.AdamW.eps,
        )
        optimizer = {"optimizer_g": optimizer_g, "optimizer_d": optimizer_d}
        return optimizer

    def _build_scheduler(self):
        """Create one ``ExponentialLR`` scheduler per optimizer.

        Both use ``cfg.train.lr_decay`` as ``gamma`` and are positioned at
        ``self.epoch - 1`` via ``last_epoch``.

        Returns:
            dict: ``{"scheduler_g": ..., "scheduler_d": ...}``.
        """
        scheduler_g = ExponentialLR(
            self.optimizer["optimizer_g"],
            gamma=self.cfg.train.lr_decay,
            last_epoch=self.epoch - 1,
        )
        scheduler_d = ExponentialLR(
            self.optimizer["optimizer_d"],
            gamma=self.cfg.train.lr_decay,
            last_epoch=self.epoch - 1,
        )
        scheduler = {"scheduler_g": scheduler_g, "scheduler_d": scheduler_d}
        return scheduler

    def _build_model(self):
        """Instantiate generator and discriminator and store them on ``self``.

        Returns:
            dict: ``{"generator": Jets, "discriminator": MultiScaleMultiPeriodDiscriminator}``
                (the same object assigned to ``self.model``).
        """
        net_g = Jets(self.cfg)
        net_d = MultiScaleMultiPeriodDiscriminator()
        self.model = {"generator": net_g, "discriminator": net_d}
        return self.model

    def _train_epoch(self):
        r"""Training epoch. Should return average loss of a batch (sample) over
        one epoch. See ``train_loop`` for usage.

        Returns:
            tuple: ``(epoch_sum_loss, epoch_losses)`` where both the total
                and each per-key sum are divided by the number of batches
                in the dataloader and multiplied by the gradient
                accumulation step.
        """
        self.model["generator"].train()
        self.model["discriminator"].train()

        accumulation_step = self.cfg.train.gradient_accumulation_step
        epoch_sum_loss: float = 0.0
        epoch_losses: dict = {}

        for batch in tqdm(
            self.train_dataloader,
            desc=f"Training Epoch {self.epoch}",
            unit="batch",
            colour="GREEN",
            leave=False,
            dynamic_ncols=True,
            smoothing=0.04,
            disable=not self.accelerator.is_main_process,
        ):
            with self.accelerator.accumulate(self.model):
                # Skip batches containing an item shorter than one training
                # segment; no step is taken and no counter is advanced.
                if batch["target_len"].min() < self.cfg.train.segment_size:
                    continue
                total_loss, train_losses, _ = self._train_step(batch)
            self.batch_count += 1

            # Only account for (and log) losses once per accumulation cycle.
            if self.batch_count % accumulation_step == 0:
                epoch_sum_loss += total_loss
                for key, value in train_losses.items():
                    if key in epoch_losses:
                        epoch_losses[key] += value
                    else:
                        epoch_losses[key] = value
                    self.accelerator.log(
                        {
                            "Step/Train {} Loss".format(key): value,
                        },
                        step=self.step,
                    )
                self.step += 1

        self.accelerator.wait_for_everyone()

        num_batches = len(self.train_dataloader)
        epoch_sum_loss = epoch_sum_loss / num_batches * accumulation_step
        for key in epoch_losses:
            epoch_losses[key] = epoch_losses[key] / num_batches * accumulation_step
        return epoch_sum_loss, epoch_losses

    def _train_step(self, batch):
        """Run one discriminator update followed by one generator update.

        Args:
            batch: Mapping passed straight to the generator; its ``"audio"``
                entry is used as the real waveform.

        Returns:
            tuple: ``(total_loss, train_losses, training_stats)`` where
                ``total_loss`` is the Python float of generator total loss
                plus discriminator total loss, ``train_losses`` maps every
                loss name to a Python float, and ``training_stats`` is an
                empty dict.
        """
        train_losses = {}
        training_stats = {}

        # ---- Train Discriminator ----
        # Generator output: the first element is the generated segment and
        # the fourth holds the start index of each segment.
        outputs_g = self.model["generator"](batch)
        speech_hat_, _, _, start_idxs, *_ = outputs_g

        # Cut the matching window out of the real audio. Indices and size
        # are scaled by ``upsample_factor`` (presumably frames -> samples;
        # confirm against the generator).
        speech = batch["audio"].unsqueeze(1)
        upsample_factor = self.cfg.train.upsample_factor
        speech_ = get_segments(
            x=speech,
            start_idxs=start_idxs * upsample_factor,
            segment_size=self.cfg.train.segment_size * upsample_factor,
        )

        # Detach the fake sample so the discriminator loss does not
        # back-propagate into the generator.
        p_hat = self.model["discriminator"](speech_hat_.detach())
        p = self.model["discriminator"](speech_)

        loss_d = self.criterion["discriminator"](p, p_hat)
        train_losses.update(loss_d)

        self.optimizer["optimizer_d"].zero_grad()
        self.accelerator.backward(loss_d["loss_disc_all"])
        self.optimizer["optimizer_d"].step()

        # ---- Train Generator ----
        # Re-evaluate with the updated discriminator; the real branch needs
        # no gradients here.
        p_hat = self.model["discriminator"](speech_hat_)
        with torch.no_grad():
            p = self.model["discriminator"](speech_)

        outputs_d = (p_hat, p)
        loss_g = self.criterion["generator"](outputs_g, outputs_d, speech_)
        train_losses.update(loss_g)

        self.optimizer["optimizer_g"].zero_grad()
        self.accelerator.backward(loss_g["g_total_loss"])
        self.optimizer["optimizer_g"].step()

        # Convert the reported losses to plain Python numbers.
        for item in train_losses:
            train_losses[item] = train_losses[item].item()

        total_loss = loss_g["g_total_loss"] + loss_d["loss_disc_all"]

        return (
            total_loss.item(),
            train_losses,
            training_stats,
        )

    def _valid_step(self, batch):
        """Compute generator and discriminator losses without updating weights.

        The whole step runs under ``torch.no_grad()`` because no backward
        pass is performed and only Python floats are returned.

        Args:
            batch: Mapping passed straight to the generator; its ``"audio"``
                entry is used as the real waveform.

        Returns:
            tuple: ``(total_loss, valid_losses, valid_stats)`` where
                ``total_loss`` is the Python float of generator total loss
                plus discriminator total loss, ``valid_losses`` maps every
                loss name to a Python float, and ``valid_stats`` is an
                empty dict.
        """
        valid_losses = {}
        valid_stats = {}

        with torch.no_grad():
            # Generator output: first element is the generated segment,
            # fourth holds the start index of each segment.
            outputs_g = self.model["generator"](batch)
            speech_hat_, _, _, start_idxs, *_ = outputs_g

            # Matching window of the real audio (see ``_train_step``).
            speech = batch["audio"].unsqueeze(1)
            upsample_factor = self.cfg.train.upsample_factor
            speech_ = get_segments(
                x=speech,
                start_idxs=start_idxs * upsample_factor,
                segment_size=self.cfg.train.segment_size * upsample_factor,
            )

            # Discriminator loss
            p_hat = self.model["discriminator"](speech_hat_.detach())
            p = self.model["discriminator"](speech_)
            loss_d = self.criterion["discriminator"](p, p_hat)
            valid_losses.update(loss_d)

            # Generator loss
            p_hat = self.model["discriminator"](speech_hat_)
            p = self.model["discriminator"](speech_)
            outputs_d = (p_hat, p)
            loss_g = self.criterion["generator"](outputs_g, outputs_d, speech_)
            valid_losses.update(loss_g)

            # Convert the reported losses to plain Python numbers.
            for item in valid_losses:
                valid_losses[item] = valid_losses[item].item()

            total_loss = loss_g["g_total_loss"] + loss_d["loss_disc_all"]

        return (
            total_loss.item(),
            valid_losses,
            valid_stats,
        )

    def _valid_epoch(self):
        r"""Testing epoch. Should return average loss of a batch (sample) over
        one epoch. See ``train_loop`` for usage.

        Returns:
            tuple: ``(epoch_sum_loss, epoch_losses)``, each divided by the
                number of batches in the validation dataloader.
        """
        # Put every model into eval mode, whether stored as a dict or as a
        # single module.
        if isinstance(self.model, dict):
            for key in self.model.keys():
                self.model[key].eval()
        else:
            self.model.eval()

        epoch_sum_loss = 0.0
        epoch_losses = dict()

        for batch in tqdm(
            self.valid_dataloader,
            desc=f"Validating Epoch {self.epoch}",
            unit="batch",
            colour="GREEN",
            leave=False,
            dynamic_ncols=True,
            smoothing=0.04,
            disable=not self.accelerator.is_main_process,
        ):
            total_loss, valid_losses, _ = self._valid_step(batch)
            epoch_sum_loss += total_loss
            if isinstance(valid_losses, dict):
                for key, value in valid_losses.items():
                    if key in epoch_losses:
                        epoch_losses[key] += value
                    else:
                        epoch_losses[key] = value
                    self.accelerator.log(
                        {
                            "Step/Valid {} Loss".format(key): value,
                        },
                        step=self.step,
                    )

        num_batches = len(self.valid_dataloader)
        epoch_sum_loss = epoch_sum_loss / num_batches
        for key in epoch_losses:
            epoch_losses[key] = epoch_losses[key] / num_batches

        self.accelerator.wait_for_everyone()

        return epoch_sum_loss, epoch_losses