import os

import torch
import torch.distributed as dist
from torch import nn
from torch.utils.data import DistributedSampler

from tasks.vocoder.dataset_utils import VocoderDataset, EndlessDistributedSampler
from text_to_speech.utils.audio.io import save_wav
from text_to_speech.utils.commons.base_task import BaseTask
from text_to_speech.utils.commons.dataset_utils import data_loader
from text_to_speech.utils.commons.hparams import hparams
from text_to_speech.utils.commons.tensor_utils import tensors_to_scalars


class VocoderBaseTask(BaseTask):
    """Shared data / optimization / evaluation plumbing for vocoder tasks.

    This class relies on attributes and methods it does not define itself:
    ``self.model_gen`` (called as ``model_gen(mels, f0)``), ``self.model_disc``
    and ``self._training_step(sample, batch_idx, optimizer_idx)`` — presumably
    supplied by a subclass (TODO confirm). It builds the data loaders, one AdamW
    optimizer + StepLR scheduler each for the generator and discriminator,
    clips gradients, logs audio during validation and writes wavs during test.
    """

    def __init__(self):
        super(VocoderBaseTask, self).__init__()
        self.max_sentences = hparams['max_sentences']
        self.max_valid_sentences = hparams['max_valid_sentences']
        # -1 means "reuse the training batch size"; the resolved value is also
        # written back into the global hparams.
        if self.max_valid_sentences == -1:
            hparams['max_valid_sentences'] = self.max_valid_sentences = self.max_sentences
        self.dataset_cls = VocoderDataset

    @data_loader
    def train_dataloader(self):
        """Shuffled training loader; endless sampling if ``hparams['endless_ds']`` is set."""
        train_dataset = self.dataset_cls('train', shuffle=True)
        # Missing 'endless_ds' in the config falls back to a regular (finite) sampler.
        return self.build_dataloader(
            train_dataset, True, self.max_sentences, hparams.get('endless_ds', False))

    @data_loader
    def val_dataloader(self):
        """Unshuffled validation loader.

        NOTE(review): validation reads the 'test' split, same as
        ``test_dataloader`` — presumably intentional; confirm.
        """
        valid_dataset = self.dataset_cls('test', shuffle=False)
        return self.build_dataloader(valid_dataset, False, self.max_valid_sentences)

    @data_loader
    def test_dataloader(self):
        """Unshuffled test loader over the 'test' split."""
        test_dataset = self.dataset_cls('test', shuffle=False)
        return self.build_dataloader(test_dataset, False, self.max_valid_sentences)

    def build_dataloader(self, dataset, shuffle, max_sentences, endless=False):
        """Wrap ``dataset`` in a DataLoader driven by a (distributed) sampler.

        Args:
            dataset: dataset object; its ``collater`` and ``num_workers``
                attributes are used for the loader.
            shuffle: whether the sampler shuffles indices.
            max_sentences: batch size.
            endless: use ``EndlessDistributedSampler`` instead of
                ``DistributedSampler``.

        Returns:
            A ``torch.utils.data.DataLoader``.
        """
        world_size = 1
        rank = 0
        if dist.is_initialized():
            world_size = dist.get_world_size()
            rank = dist.get_rank()
        sampler_cls = EndlessDistributedSampler if endless else DistributedSampler
        sampler = sampler_cls(
            dataset=dataset,
            num_replicas=world_size,
            rank=rank,
            shuffle=shuffle,
        )
        # Shuffling is delegated to the sampler: DataLoader rejects
        # shuffle=True together with an explicit sampler.
        return torch.utils.data.DataLoader(
            dataset=dataset,
            shuffle=False,
            collate_fn=dataset.collater,
            batch_size=max_sentences,
            num_workers=dataset.num_workers,
            sampler=sampler,
            pin_memory=True,
        )

    def build_optimizer(self, model):
        """Create AdamW optimizers for the generator and the discriminator.

        ``model`` is unused; parameters are taken from ``self.model_gen`` and
        ``self.model_disc``. The returned list order defines the optimizer
        index used by the hooks below: 0 = generator, 1 = discriminator.
        """
        optimizer_gen = torch.optim.AdamW(
            self.model_gen.parameters(), lr=hparams['lr'],
            betas=[hparams['adam_b1'], hparams['adam_b2']])
        optimizer_disc = torch.optim.AdamW(
            self.model_disc.parameters(), lr=hparams['lr'],
            betas=[hparams['adam_b1'], hparams['adam_b2']])
        return [optimizer_gen, optimizer_disc]

    def build_scheduler(self, optimizer):
        """Return StepLR schedulers keyed 'gen' / 'disc' for ``optimizer[0]`` / ``optimizer[1]``."""
        return {
            "gen": torch.optim.lr_scheduler.StepLR(
                optimizer=optimizer[0],
                **hparams["generator_scheduler_params"]),
            "disc": torch.optim.lr_scheduler.StepLR(
                optimizer=optimizer[1],
                **hparams["discriminator_scheduler_params"]),
        }

    def validation_step(self, sample, batch_idx):
        """Compute generator-side losses and periodically log audio samples.

        Every ``hparams['valid_infer_interval']`` steps, for the first 10
        validation batches, the generator output is peak-normalized and logged;
        ground truth is logged only at global step 0.

        Returns:
            dict with 'losses' and 'total_loss' converted by ``tensors_to_scalars``.
        """
        outputs = {}
        total_loss, loss_output = self._training_step(sample, batch_idx, 0)
        outputs['losses'] = tensors_to_scalars(loss_output)
        outputs['total_loss'] = tensors_to_scalars(total_loss)

        if self.global_step % hparams['valid_infer_interval'] == 0 and batch_idx < 10:
            mels = sample['mels']
            y = sample['wavs']
            f0 = sample['f0']
            y_ = self.model_gen(mels, f0)
            for idx, (wav_pred, wav_gt, _item_name) in enumerate(zip(y_, y, sample["item_name"])):
                # Peak-normalize for listening. The peak is clamped so an
                # all-silent waveform yields zeros instead of NaN from 0/0.
                wav_pred = wav_pred / wav_pred.abs().max().clamp(min=1e-8)
                if self.global_step == 0:
                    wav_gt = wav_gt / wav_gt.abs().max().clamp(min=1e-8)
                    self.logger.add_audio(
                        f'wav_{batch_idx}_{idx}_gt', wav_gt,
                        self.global_step, hparams['audio_sample_rate'])
                self.logger.add_audio(
                    f'wav_{batch_idx}_{idx}_pred', wav_pred,
                    self.global_step, hparams['audio_sample_rate'])
        return outputs

    def _make_gen_dir(self):
        """Create (if needed) and return the directory that test wavs are written to."""
        gen_dir = os.path.join(
            hparams['work_dir'],
            f'generated_{self.trainer.global_step}_{hparams["gen_dir_name"]}')
        os.makedirs(gen_dir, exist_ok=True)
        return gen_dir

    def test_start(self):
        """Prepare the output directory for generated test audio."""
        self.gen_dir = self._make_gen_dir()

    def test_step(self, sample, batch_idx):
        """Synthesize audio from ``sample`` and save ground-truth / predicted wavs.

        Waveforms are clamped to [-1, 1] before saving as
        ``<item_name>_gt.wav`` and ``<item_name>_pred.wav``.

        Returns:
            An empty dict (no test losses are computed here).
        """
        mels = sample['mels']
        y = sample['wavs']
        f0 = sample['f0']
        loss_output = {}
        y_ = self.model_gen(mels, f0)
        gen_dir = self._make_gen_dir()
        for wav_pred, wav_gt, item_name in zip(y_, y, sample["item_name"]):
            wav_gt = wav_gt.clamp(-1, 1)
            wav_pred = wav_pred.clamp(-1, 1)
            save_wav(
                wav_gt.view(-1).cpu().float().numpy(), f'{gen_dir}/{item_name}_gt.wav',
                hparams['audio_sample_rate'])
            save_wav(
                wav_pred.view(-1).cpu().float().numpy(), f'{gen_dir}/{item_name}_pred.wav',
                hparams['audio_sample_rate'])
        return loss_output

    def test_end(self, outputs):
        """No aggregation of test outputs; returns an empty dict."""
        return {}

    def on_before_optimization(self, opt_idx):
        """Clip gradient norm of the model that optimizer ``opt_idx`` updates (0 = generator)."""
        if opt_idx == 0:
            nn.utils.clip_grad_norm_(self.model_gen.parameters(), hparams['generator_grad_norm'])
        else:
            nn.utils.clip_grad_norm_(self.model_disc.parameters(), hparams["discriminator_grad_norm"])

    def on_after_optimization(self, epoch, batch_idx, optimizer, optimizer_idx):
        """Advance the scheduler that belongs to ``optimizer_idx`` (0 = generator).

        The step count passed is ``global_step // accumulate_grad_batches``
        (presumably the number of optimizer updates — TODO confirm). Passing an
        explicit value to ``StepLR.step`` is deprecated in recent PyTorch but
        sets the LR from the closed-form schedule; kept to preserve behavior.
        """
        if optimizer_idx == 0:
            self.scheduler['gen'].step(self.global_step // hparams['accumulate_grad_batches'])
        else:
            self.scheduler['disc'].step(self.global_step // hparams['accumulate_grad_batches'])