|
|
|
|
|
|
|
|
|
|
|
import torch |
|
import torch.nn as nn |
|
import os |
|
from utils.io import save_audio |
|
from tqdm import tqdm |
|
from models.tts.base import TTSTrainer |
|
from models.tts.jets.jets import Jets |
|
from models.tts.jets.jets_loss import GeneratorLoss, DiscriminatorLoss |
|
from models.tts.jets.jets_dataset import JetsDataset, JetsCollator |
|
from optimizer.optimizers import NoamLR |
|
from torch.optim.lr_scheduler import ExponentialLR |
|
from models.vocoders.gan.discriminator.mpd import MultiScaleMultiPeriodDiscriminator |
|
|
|
|
|
def get_segments(
    x: torch.Tensor,
    start_idxs: torch.Tensor,
    segment_size: int,
) -> torch.Tensor:
    """Cut one fixed-length window out of every item of a batch.

    Args:
        x (Tensor): Input tensor (B, C, T).
        start_idxs (Tensor): Start index tensor (B,); entry ``i`` is the
            first time step copied from ``x[i]``.
        segment_size (int): Number of time steps per window.

    Returns:
        Tensor: Segmented tensor (B, C, segment_size), allocated with the
        same dtype and device as ``x``.

    """
    batch_size, num_channels, _ = x.size()
    # Pre-allocate the result; rows without a matching start index stay zero.
    out = x.new_zeros(batch_size, num_channels, segment_size)
    for row, begin in enumerate(start_idxs):
        end = begin + segment_size
        out[row] = x[row, :, begin:end]
    return out
|
|
|
|
|
class JetsTrainer(TTSTrainer):
    """Adversarial trainer for the JETS text-to-speech model.

    ``self.model`` holds two networks keyed ``"generator"`` and
    ``"discriminator"``. Each network has its own AdamW optimizer
    (``optimizer_g`` / ``optimizer_d``) and ExponentialLR scheduler
    (``scheduler_g`` / ``scheduler_d``), stored in dicts on the trainer.
    """

    def __init__(self, args, cfg):
        """Initialise the base trainer and keep a reference to ``cfg``."""
        super().__init__(args, cfg)
        self.cfg = cfg

    def _build_dataset(self):
        """Return the dataset class and collator class used for JETS."""
        return JetsDataset, JetsCollator

    def __build_scheduler(self):
        """Build a NoamLR scheduler from ``cfg.train.lr_scheduler``.

        The double leading underscore makes Python mangle this name to
        ``_JetsTrainer__build_scheduler``, so it does not override the
        single-underscore ``_build_scheduler`` defined further down.
        """
        # NOTE(review): ``self.optimizer`` is a dict in this class (see
        # ``_build_optimizer``); this looks like leftover code - confirm
        # before relying on it.
        return NoamLR(self.optimizer, **self.cfg.train.lr_scheduler)

    def _write_media(self, images, audios, audio_sampling_rate):
        """Write optional image and audio dicts to the summary writer.

        Args:
            images (dict | None): Tag -> image; logged with
                ``dataformats="HWC"``.
            audios (dict | None): Tag -> audio.
            audio_sampling_rate (int): Sample rate passed to ``add_audio``.
        """
        # The code visible here only maintains ``self.step``; honour
        # ``self.global_step`` if a base class provides it, else fall back.
        media_step = getattr(self, "global_step", self.step)
        for key, value in (images or {}).items():
            # Assumes ``self.sw`` follows the tensorboard SummaryWriter API,
            # whose layout keyword is ``dataformats`` - TODO confirm.
            self.sw.add_image(key, value, media_step, dataformats="HWC")
        for key, value in (audios or {}).items():
            self.sw.add_audio(key, value, media_step, audio_sampling_rate)

    def _write_summary(
        self,
        losses,
        stats,
        images=None,
        audios=None,
        audio_sampling_rate=24000,
        tag="train",
    ):
        """Log training losses, the generator learning rate and media.

        Args:
            losses (dict): Loss name -> scalar; logged as ``"<tag>/<name>"``.
            stats: Accepted for interface compatibility; not used here.
            images (dict | None): Optional tag -> image to log.
            audios (dict | None): Optional tag -> audio to log.
            audio_sampling_rate (int): Sample rate for logged audio.
            tag (str): Prefix for the scalar names.
        """
        for key, value in losses.items():
            self.sw.add_scalar(tag + "/" + key, value, self.step)

        # ``self.optimizer`` is a dict of optimizers, so the learning rate is
        # read from the generator optimizer's first parameter group.
        self.sw.add_scalar(
            "learning_rate",
            self.optimizer["optimizer_g"].param_groups[0]["lr"],
            self.step,
        )

        self._write_media(images, audios, audio_sampling_rate)

    def _write_valid_summary(
        self, losses, stats, images=None, audios=None, audio_sampling_rate=24000, tag="val"
    ):
        """Log validation losses and optional media.

        Args:
            losses (dict): Loss name -> scalar; logged as ``"<tag>/<name>"``.
            stats: Accepted for interface compatibility; not used here.
            images (dict | None): Optional tag -> image to log.
            audios (dict | None): Optional tag -> audio to log.
            audio_sampling_rate (int): Sample rate for logged audio.
            tag (str): Prefix for the scalar names.
        """
        for key, value in losses.items():
            self.sw.add_scalar(tag + "/" + key, value, self.step)

        self._write_media(images, audios, audio_sampling_rate)

    def _build_criterion(self):
        """Return the generator and discriminator loss modules in a dict."""
        criterion = {
            "generator": GeneratorLoss(self.cfg),
            "discriminator": DiscriminatorLoss(self.cfg),
        }
        return criterion

    def get_state_dict(self):
        """Collect everything needed for a checkpoint.

        Returns:
            dict: State dicts of both networks, both optimizers and both
            schedulers, plus ``step``, ``epoch`` and the configured
            ``batch_size``.
        """
        state_dict = {
            "generator": self.model["generator"].state_dict(),
            "discriminator": self.model["discriminator"].state_dict(),
            "optimizer_g": self.optimizer["optimizer_g"].state_dict(),
            "optimizer_d": self.optimizer["optimizer_d"].state_dict(),
            "scheduler_g": self.scheduler["scheduler_g"].state_dict(),
            "scheduler_d": self.scheduler["scheduler_d"].state_dict(),
            "step": self.step,
            "epoch": self.epoch,
            "batch_size": self.cfg.train.batch_size,
        }
        return state_dict

    def _build_optimizer(self):
        """Create one AdamW optimizer per network.

        Both optimizers share ``cfg.train.learning_rate`` and the
        ``cfg.train.AdamW`` ``betas`` / ``eps`` settings.

        Returns:
            dict: ``{"optimizer_g": ..., "optimizer_d": ...}``.
        """
        optimizer_g = torch.optim.AdamW(
            self.model["generator"].parameters(),
            self.cfg.train.learning_rate,
            betas=self.cfg.train.AdamW.betas,
            eps=self.cfg.train.AdamW.eps,
        )
        optimizer_d = torch.optim.AdamW(
            self.model["discriminator"].parameters(),
            self.cfg.train.learning_rate,
            betas=self.cfg.train.AdamW.betas,
            eps=self.cfg.train.AdamW.eps,
        )
        optimizer = {"optimizer_g": optimizer_g, "optimizer_d": optimizer_d}

        return optimizer

    def _build_scheduler(self):
        """Create one ExponentialLR scheduler per optimizer.

        Both use ``cfg.train.lr_decay`` as ``gamma`` and start from
        ``self.epoch - 1`` as ``last_epoch``.

        Returns:
            dict: ``{"scheduler_g": ..., "scheduler_d": ...}``.
        """
        scheduler_g = ExponentialLR(
            self.optimizer["optimizer_g"],
            gamma=self.cfg.train.lr_decay,
            last_epoch=self.epoch - 1,
        )
        scheduler_d = ExponentialLR(
            self.optimizer["optimizer_d"],
            gamma=self.cfg.train.lr_decay,
            last_epoch=self.epoch - 1,
        )

        scheduler = {"scheduler_g": scheduler_g, "scheduler_d": scheduler_d}
        return scheduler

    def _build_model(self):
        """Instantiate generator and discriminator and store them on ``self``.

        Returns:
            dict: ``{"generator": Jets, "discriminator":
            MultiScaleMultiPeriodDiscriminator}`` (also kept as
            ``self.model``).
        """
        net_g = Jets(self.cfg)
        net_d = MultiScaleMultiPeriodDiscriminator()
        self.model = {"generator": net_g, "discriminator": net_d}
        return self.model

    def _train_epoch(self):
        r"""Training epoch. Should return average loss of a batch (sample) over
        one epoch. See ``train_loop`` for usage.

        Batches whose smallest ``target_len`` is below
        ``cfg.train.segment_size`` are skipped entirely.

        Returns:
            tuple: ``(epoch_sum_loss, epoch_losses)``. Both are the sums
            collected on every ``gradient_accumulation_step``-th batch,
            divided by ``len(self.train_dataloader)`` and multiplied by
            ``cfg.train.gradient_accumulation_step``.
        """
        self.model["generator"].train()
        self.model["discriminator"].train()

        grad_accum = self.cfg.train.gradient_accumulation_step
        epoch_sum_loss: float = 0.0
        epoch_losses: dict = {}

        for batch in tqdm(
            self.train_dataloader,
            desc=f"Training Epoch {self.epoch}",
            unit="batch",
            colour="GREEN",
            leave=False,
            dynamic_ncols=True,
            smoothing=0.04,
            disable=not self.accelerator.is_main_process,
        ):
            with self.accelerator.accumulate(self.model):
                # A batch containing an utterance shorter than one training
                # segment cannot be sliced, so it is skipped.
                if batch["target_len"].min() < self.cfg.train.segment_size:
                    continue
                total_loss, train_losses, training_stats = self._train_step(batch)
            self.batch_count += 1

            # Statistics and the step counter only advance once per
            # accumulation window.
            if self.batch_count % grad_accum == 0:
                epoch_sum_loss += total_loss
                for key, value in train_losses.items():
                    if key not in epoch_losses:
                        epoch_losses[key] = value
                    else:
                        epoch_losses[key] += value

                    self.accelerator.log(
                        {
                            "Step/Train {} Loss".format(key): value,
                        },
                        step=self.step,
                    )

                self.step += 1

        self.accelerator.wait_for_everyone()

        num_batches = len(self.train_dataloader)
        epoch_sum_loss = epoch_sum_loss / num_batches * grad_accum
        for key in epoch_losses:
            epoch_losses[key] = epoch_losses[key] / num_batches * grad_accum

        return epoch_sum_loss, epoch_losses

    def _get_target_segments(self, batch, start_idxs):
        """Slice the ground-truth audio to match the generated segments.

        ``start_idxs`` and ``cfg.train.segment_size`` are both multiplied by
        ``cfg.train.upsample_factor`` before slicing ``batch["audio"]``.

        Args:
            batch (dict): Batch containing an ``"audio"`` tensor; a channel
                axis is inserted at dim 1 (presumably (B, T) -> (B, 1, T) -
                TODO confirm against the collator).
            start_idxs (Tensor): Segment start indices from the generator.

        Returns:
            Tensor: Result of ``get_segments`` on the audio.
        """
        speech = batch["audio"].unsqueeze(1)
        upsample_factor = self.cfg.train.upsample_factor
        return get_segments(
            x=speech,
            start_idxs=start_idxs * upsample_factor,
            segment_size=self.cfg.train.segment_size * upsample_factor,
        )

    def _train_step(self, batch):
        """Run one discriminator update followed by one generator update.

        Args:
            batch (dict): Training batch fed to the generator; must contain
                ``"audio"``.

        Returns:
            tuple: ``(total_loss, train_losses, training_stats)`` where
            ``total_loss`` is ``g_total_loss + loss_disc_all`` as a float,
            ``train_losses`` maps every discriminator/generator loss term to
            a float, and ``training_stats`` is an empty dict.
        """
        training_stats = {}

        # Generator forward pass; the 1st output is the generated waveform
        # segment and the 4th holds the segment start indices.
        outputs_g = self.model["generator"](batch)
        speech_hat_, _, _, start_idxs, *_ = outputs_g
        speech_ = self._get_target_segments(batch, start_idxs)

        # --- Discriminator update -----------------------------------------
        # The generated audio is detached so this backward pass does not
        # reach the generator.
        p_hat = self.model["discriminator"](speech_hat_.detach())
        p = self.model["discriminator"](speech_)
        loss_d = self.criterion["discriminator"](p, p_hat)

        self.optimizer["optimizer_d"].zero_grad()
        self.accelerator.backward(loss_d["loss_disc_all"])
        self.optimizer["optimizer_d"].step()

        # --- Generator update ---------------------------------------------
        # Re-run the (just updated) discriminator; gradients are only needed
        # through the generated branch.
        p_hat = self.model["discriminator"](speech_hat_)
        with torch.no_grad():
            p = self.model["discriminator"](speech_)

        outputs_d = (p_hat, p)
        loss_g = self.criterion["generator"](outputs_g, outputs_d, speech_)

        self.optimizer["optimizer_g"].zero_grad()
        self.accelerator.backward(loss_g["g_total_loss"])
        self.optimizer["optimizer_g"].step()

        total_loss = loss_g["g_total_loss"] + loss_d["loss_disc_all"]

        # Discriminator terms first, then generator terms, as plain floats.
        train_losses = {}
        train_losses.update(loss_d)
        train_losses.update(loss_g)
        train_losses = {name: value.item() for name, value in train_losses.items()}

        return (
            total_loss.item(),
            train_losses,
            training_stats,
        )

    @torch.inference_mode()
    def _valid_step(self, batch):
        """Compute discriminator and generator losses without any update.

        Args:
            batch (dict): Validation batch fed to the generator; must contain
                ``"audio"``.

        Returns:
            tuple: ``(total_loss, valid_losses, valid_stats)`` with the same
            layout as the return value of ``_train_step``.
        """
        valid_stats = {}

        outputs_g = self.model["generator"](batch)
        speech_hat_, _, _, start_idxs, *_ = outputs_g
        speech_ = self._get_target_segments(batch, start_idxs)

        # Under inference mode no graph is recorded and no weights change
        # between the two losses, so one discriminator pass per input serves
        # both (assumes the discriminator forward is deterministic).
        p_hat = self.model["discriminator"](speech_hat_)
        p = self.model["discriminator"](speech_)

        loss_d = self.criterion["discriminator"](p, p_hat)

        outputs_d = (p_hat, p)
        loss_g = self.criterion["generator"](outputs_g, outputs_d, speech_)

        total_loss = loss_g["g_total_loss"] + loss_d["loss_disc_all"]

        valid_losses = {}
        valid_losses.update(loss_d)
        valid_losses.update(loss_g)
        valid_losses = {name: value.item() for name, value in valid_losses.items()}

        return (
            total_loss.item(),
            valid_losses,
            valid_stats,
        )

    @torch.inference_mode()
    def _valid_epoch(self):
        r"""Testing epoch. Should return average loss of a batch (sample) over
        one epoch. See ``train_loop`` for usage.

        Returns:
            tuple: ``(epoch_sum_loss, epoch_losses)``, each divided by
            ``len(self.valid_dataloader)``.
        """
        # Put every network into eval mode, whether ``self.model`` is a dict
        # of modules or a single module.
        if isinstance(self.model, dict):
            for net in self.model.values():
                net.eval()
        else:
            self.model.eval()

        epoch_sum_loss = 0.0
        epoch_losses = dict()
        for batch in tqdm(
            self.valid_dataloader,
            desc=f"Validating Epoch {self.epoch}",
            unit="batch",
            colour="GREEN",
            leave=False,
            dynamic_ncols=True,
            smoothing=0.04,
            disable=not self.accelerator.is_main_process,
        ):
            total_loss, valid_losses, valid_stats = self._valid_step(batch)
            epoch_sum_loss += total_loss
            if isinstance(valid_losses, dict):
                for key, value in valid_losses.items():
                    if key not in epoch_losses:
                        epoch_losses[key] = value
                    else:
                        epoch_losses[key] += value

                    self.accelerator.log(
                        {
                            "Step/Valid {} Loss".format(key): value,
                        },
                        step=self.step,
                    )

        num_batches = len(self.valid_dataloader)
        epoch_sum_loss = epoch_sum_loss / num_batches
        for key in epoch_losses:
            epoch_losses[key] = epoch_losses[key] / num_batches

        self.accelerator.wait_for_everyone()

        return epoch_sum_loss, epoch_losses
|
|