# Source: DiffSpeech/tasks/vocoder/vocoder_base.py (commit d1b91e7, "init")
import os
import torch
import torch.distributed as dist
from torch import nn
from torch.utils.data import DistributedSampler
from tasks.vocoder.dataset_utils import VocoderDataset, EndlessDistributedSampler
from utils.audio.io import save_wav
from utils.commons.base_task import BaseTask
from utils.commons.dataset_utils import data_loader
from utils.commons.hparams import hparams
from utils.commons.tensor_utils import tensors_to_scalars
class VocoderBaseTask(BaseTask):
    """Shared scaffolding for vocoder tasks built from a generator and a discriminator.

    This class provides data loading, optimizer/scheduler construction,
    validation audio logging, test-time waveform export and gradient clipping.
    It reads ``self.model_gen``, ``self.model_disc`` and calls
    ``self._training_step``; none of these are defined here, so they are
    expected to come from ``BaseTask`` or a subclass.
    """

    def __init__(self):
        super().__init__()
        self.max_sentences = hparams['max_sentences']
        self.max_valid_sentences = hparams['max_valid_sentences']
        if self.max_valid_sentences == -1:
            # -1 falls back to the training batch size; the resolved value is
            # also written back into hparams so other readers see it.
            hparams['max_valid_sentences'] = self.max_valid_sentences = self.max_sentences
        self.dataset_cls = VocoderDataset

    @data_loader
    def train_dataloader(self):
        """Return the loader over the shuffled 'train' split."""
        train_dataset = self.dataset_cls('train', shuffle=True)
        return self.build_dataloader(train_dataset, True, self.max_sentences, hparams['endless_ds'])

    @data_loader
    def val_dataloader(self):
        """Return the validation loader.

        NOTE(review): this reads the 'test' split, exactly like
        ``test_dataloader`` — confirm that no separate 'valid' split is intended.
        """
        valid_dataset = self.dataset_cls('test', shuffle=False)
        return self.build_dataloader(valid_dataset, False, self.max_valid_sentences)

    @data_loader
    def test_dataloader(self):
        """Return the loader over the unshuffled 'test' split."""
        test_dataset = self.dataset_cls('test', shuffle=False)
        return self.build_dataloader(test_dataset, False, self.max_valid_sentences)

    def build_dataloader(self, dataset, shuffle, max_sentences, endless=False):
        """Wrap ``dataset`` in a ``DataLoader`` driven by a distributed sampler.

        Args:
            dataset: Dataset exposing ``collater`` and ``num_workers``.
            shuffle: Whether the sampler should shuffle indices.
            max_sentences: Batch size handed to the ``DataLoader``.
            endless: Use ``EndlessDistributedSampler`` instead of
                ``DistributedSampler``.

        Returns:
            A ``torch.utils.data.DataLoader``.
        """
        # Single-process defaults; overridden when torch.distributed is active.
        world_size, rank = 1, 0
        if dist.is_initialized():
            world_size = dist.get_world_size()
            rank = dist.get_rank()

        sampler_cls = EndlessDistributedSampler if endless else DistributedSampler
        sampler = sampler_cls(
            dataset=dataset,
            num_replicas=world_size,
            rank=rank,
            shuffle=shuffle,
        )
        return torch.utils.data.DataLoader(
            dataset=dataset,
            # Shuffling is delegated to the sampler, so it stays off here.
            shuffle=False,
            collate_fn=dataset.collater,
            batch_size=max_sentences,
            num_workers=dataset.num_workers,
            sampler=sampler,
            pin_memory=True,
        )

    def build_optimizer(self, model):
        """Create one AdamW optimizer each for the generator and discriminator.

        Args:
            model: Unused; parameters are taken from ``self.model_gen`` and
                ``self.model_disc``.

        Returns:
            ``[optimizer_gen, optimizer_disc]`` — generator first.
        """
        # Immutable, so it is safe to share between both optimizers.
        betas = (hparams['adam_b1'], hparams['adam_b2'])
        optimizer_gen = torch.optim.AdamW(
            self.model_gen.parameters(), lr=hparams['lr'], betas=betas)
        optimizer_disc = torch.optim.AdamW(
            self.model_disc.parameters(), lr=hparams['lr'], betas=betas)
        return [optimizer_gen, optimizer_disc]

    def build_scheduler(self, optimizer):
        """Create StepLR schedulers keyed by ``"gen"`` and ``"disc"``.

        Args:
            optimizer: Sequence as returned by ``build_optimizer``; index 0 is
                used for the generator and index 1 for the discriminator.
        """
        return {
            "gen": torch.optim.lr_scheduler.StepLR(
                optimizer=optimizer[0],
                **hparams["generator_scheduler_params"]),
            "disc": torch.optim.lr_scheduler.StepLR(
                optimizer=optimizer[1],
                **hparams["discriminator_scheduler_params"]),
        }

    @staticmethod
    def _peak_normalize(wav):
        """Scale ``wav`` so its largest absolute sample is 1.

        An all-zero waveform is returned unchanged instead of being divided by
        zero, which would fill it with NaNs.
        """
        peak = wav.abs().max()
        return wav / peak if peak > 0 else wav

    def validation_step(self, sample, batch_idx):
        """Compute validation losses and periodically log audio.

        Audio is logged only for the first 10 batches, and only when
        ``global_step`` is a multiple of ``hparams['valid_infer_interval']``.
        Ground-truth audio is logged only at ``global_step == 0``.

        Returns:
            Dict with ``'losses'`` and ``'total_loss'`` converted to scalars.
        """
        # NOTE(review): the trailing 0 is presumably the optimizer index
        # (generator) — confirm against _training_step's signature.
        total_loss, loss_output = self._training_step(sample, batch_idx, 0)
        outputs = {
            'losses': tensors_to_scalars(loss_output),
            'total_loss': tensors_to_scalars(total_loss),
        }

        should_log_audio = (
            self.global_step % hparams['valid_infer_interval'] == 0 and batch_idx < 10
        )
        if not should_log_audio:
            return outputs

        mels = sample['mels']
        y = sample['wavs']
        f0 = sample['f0']
        y_ = self.model_gen(mels, f0)
        # item_name is not used for logging, but it stays in the zip so the
        # iteration length matches the original behaviour.
        for idx, (wav_pred, wav_gt, _) in enumerate(zip(y_, y, sample["item_name"])):
            wav_pred = self._peak_normalize(wav_pred)
            if self.global_step == 0:
                # The reference audio never changes, so it is logged once.
                wav_gt = self._peak_normalize(wav_gt)
                self.logger.add_audio(f'wav_{batch_idx}_{idx}_gt', wav_gt, self.global_step,
                                      hparams['audio_sample_rate'])
            self.logger.add_audio(f'wav_{batch_idx}_{idx}_pred', wav_pred, self.global_step,
                                  hparams['audio_sample_rate'])
        return outputs

    def _generation_dir(self):
        """Return the output directory path for generated test waveforms."""
        return os.path.join(
            hparams['work_dir'],
            f'generated_{self.trainer.global_step}_{hparams["gen_dir_name"]}')

    def test_start(self):
        """Create the output directory for test waveforms and store it in ``self.gen_dir``."""
        self.gen_dir = self._generation_dir()
        os.makedirs(self.gen_dir, exist_ok=True)

    def test_step(self, sample, batch_idx):
        """Run the generator on one batch and save predicted and reference wavs.

        For every item, ``<item_name>_gt.wav`` and ``<item_name>_pred.wav`` are
        written to the generation directory after clamping to [-1, 1].

        Returns:
            An empty dict; no losses are computed at test time.
        """
        mels = sample['mels']
        y = sample['wavs']
        f0 = sample['f0']
        loss_output = {}
        y_ = self.model_gen(mels, f0)

        # Recomputed rather than read from self.gen_dir, so this step does not
        # depend on test_start having run first.
        gen_dir = self._generation_dir()
        os.makedirs(gen_dir, exist_ok=True)

        sample_rate = hparams['audio_sample_rate']
        for wav_pred, wav_gt, item_name in zip(y_, y, sample["item_name"]):
            wav_gt = wav_gt.clamp(-1, 1)
            wav_pred = wav_pred.clamp(-1, 1)
            save_wav(
                wav_gt.view(-1).cpu().float().numpy(), f'{gen_dir}/{item_name}_gt.wav',
                sample_rate)
            save_wav(
                wav_pred.view(-1).cpu().float().numpy(), f'{gen_dir}/{item_name}_pred.wav',
                sample_rate)
        return loss_output

    def test_end(self, outputs):
        """Return an empty result; test outputs are not aggregated."""
        return {}

    def on_before_optimization(self, opt_idx):
        """Clip gradient norms: generator for ``opt_idx == 0``, discriminator otherwise."""
        if opt_idx == 0:
            nn.utils.clip_grad_norm_(self.model_gen.parameters(), hparams['generator_grad_norm'])
        else:
            nn.utils.clip_grad_norm_(self.model_disc.parameters(), hparams["discriminator_grad_norm"])

    def on_after_optimization(self, epoch, batch_idx, optimizer, optimizer_idx):
        """Step the scheduler that matches the optimizer that just ran.

        ``optimizer_idx == 0`` selects the generator scheduler; any other value
        selects the discriminator scheduler.
        """
        scheduler_key = 'gen' if optimizer_idx == 0 else 'disc'
        # NOTE(review): passing an explicit epoch to step() is flagged as
        # deprecated in recent PyTorch releases — confirm before upgrading.
        self.scheduler[scheduler_key].step(
            self.global_step // hparams['accumulate_grad_batches'])