# NOTE: web-page residue captured together with this file (not program code):
#   lmzjms's picture
#   Upload 591 files
#   9206300
#   raw
#   history blame
#   3.5 kB
import torch
import utils
from modules.diff.diffusion import GaussianDiffusion
from modules.diff.net import DiffNet
from tasks.tts.fs2 import FastSpeech2Task
from utils.hparams import hparams
def _build_wavenet_denoiser(hp):
    """Return a ``DiffNet`` constructed from ``hp['audio_num_mel_bins']``."""
    return DiffNet(hp['audio_num_mel_bins'])


# Registry of denoiser factories, keyed by decoder-type name. Each factory
# receives the hparams mapping and returns the denoising network.
DIFF_DECODERS = {
    'wavenet': _build_wavenet_denoiser,
}
class DiffFsTask(FastSpeech2Task):
    """FastSpeech2 task whose TTS model is a ``GaussianDiffusion`` module.

    Overrides model construction, loss collection, the training and
    validation steps, and the learning-rate schedule of ``FastSpeech2Task``.
    All configuration is read from the global ``hparams`` mapping.
    """

    def build_tts_model(self):
        """Create ``self.model`` as a ``GaussianDiffusion`` instance.

        The denoising network is produced by the factory registered in
        ``DIFF_DECODERS`` under ``hparams['diff_decoder_type']``.
        """
        mel_bins = hparams['audio_num_mel_bins']
        self.model = GaussianDiffusion(
            phone_encoder=self.phone_encoder,
            out_dims=mel_bins,
            denoise_fn=DIFF_DECODERS[hparams['diff_decoder_type']](hparams),
            timesteps=hparams['timesteps'],
            loss_type=hparams['diff_loss_type'],
            spec_min=hparams['spec_min'],
            spec_max=hparams['spec_max'],
        )

    def run_model(self, model, sample, return_output=False, infer=False):
        """Run ``model`` on one batch and collect the loss terms.

        Args:
            model: Callable model. It must also expose ``cwt2f0_norm`` when
                ``hparams['pitch_type'] == 'cwt'``.
            sample: Batch dict. Reads ``txt_tokens``, ``mels``, ``mel2ph``,
                ``f0``, ``uv`` and ``energy``; optionally ``spk_embed`` or
                ``spk_ids``; and ``cwt_spec``, ``f0_mean``, ``f0_std`` in
                the ``cwt`` case. In the ``cwt`` case ``sample['f0_cwt']``
                is written as a side effect.
            return_output: When true, also return the raw model output.
            infer: Forwarded to the model as its ``infer`` argument.

        Returns:
            The ``losses`` dict, or ``(losses, output)`` when
            ``return_output`` is true.
        """
        txt_tokens = sample['txt_tokens']  # [B, T_t]
        target = sample['mels']  # [B, T_s, 80]
        mel2ph = sample['mel2ph']  # [B, T_s]
        f0 = sample['f0']
        uv = sample['uv']
        energy = sample['energy']
        # Speaker conditioning comes from ids or embeddings depending on
        # the config; ``.get`` leaves it as None when the key is absent.
        spk_embed = sample.get('spk_embed') if not hparams['use_spk_id'] else sample.get('spk_ids')
        if hparams['pitch_type'] == 'cwt':
            cwt_spec = sample['cwt_spec']
            f0_mean = sample['f0_mean']
            f0_std = sample['f0_std']
            # The f0 passed to the model is replaced by the value derived
            # from the CWT spectrogram; it is also stored back on the batch.
            sample['f0_cwt'] = f0 = model.cwt2f0_norm(cwt_spec, f0_mean, f0_std, mel2ph)
        output = model(txt_tokens, mel2ph=mel2ph, spk_embed=spk_embed,
                       ref_mels=target, f0=f0, uv=uv, energy=energy, infer=infer)

        losses = {}
        # 'diff_loss' is only recorded when the model returned it.
        if 'diff_loss' in output:
            losses['mel'] = output['diff_loss']
        # The inherited add_*_loss helpers receive ``losses`` and are
        # expected to add their own entries to it.
        self.add_dur_loss(output['dur'], mel2ph, txt_tokens, losses=losses)
        if hparams['use_pitch_embed']:
            self.add_pitch_loss(output, sample, losses)
        if hparams['use_energy_embed']:
            self.add_energy_loss(output['energy_pred'], energy, losses)
        if return_output:
            return losses, output
        return losses

    def _training_step(self, sample, batch_idx, _):
        """Compute the training loss and the values to log for one batch.

        Returns:
            ``(total_loss, log_outputs)``. ``total_loss`` sums only the
            loss entries that are tensors requiring grad. ``log_outputs``
            is the loss dict extended with ``batch_size`` and, when a
            scheduler is set, ``lr``.
        """
        log_outputs = self.run_model(self.model, sample)
        total_loss = sum(
            v for v in log_outputs.values()
            if isinstance(v, torch.Tensor) and v.requires_grad
        )
        log_outputs['batch_size'] = sample['txt_tokens'].size()[0]
        scheduler = self.scheduler
        # Same None guard as optimizer_step.
        if scheduler is not None:
            # get_lr() is only meaningful while the scheduler is inside
            # step(); called from here, StepLR applies gamma once more at
            # every decay boundary. get_last_lr() reports the rate actually
            # in effect. Fall back for PyTorch versions that lack it.
            if hasattr(scheduler, 'get_last_lr'):
                log_outputs['lr'] = scheduler.get_last_lr()[0]
            else:
                log_outputs['lr'] = scheduler.get_lr()[0]
        return total_loss, log_outputs

    def validation_step(self, sample, batch_idx):
        """Evaluate one validation batch and optionally plot its mel output.

        Returns:
            Dict with ``losses``, ``total_loss`` (sum of all loss values)
            and ``nsamples``, passed through ``utils.tensors_to_scalars``.
        """
        losses, _ = self.run_model(self.model, sample, return_output=True, infer=False)
        outputs = {
            'losses': losses,
            'total_loss': sum(losses.values()),
            'nsamples': sample['nsamples'],
        }
        outputs = utils.tensors_to_scalars(outputs)
        # For the first ``num_valid_plots`` batches, run the model again
        # with infer=True and plot its 'mel_out' against the target mels.
        if batch_idx < hparams['num_valid_plots']:
            _, model_out = self.run_model(self.model, sample, return_output=True, infer=True)
            self.plot_mel(batch_idx, sample['mels'], model_out['mel_out'])
        return outputs

    def build_scheduler(self, optimizer):
        """Return a ``StepLR`` that halves the rate every ``hparams['decay_steps']`` steps."""
        return torch.optim.lr_scheduler.StepLR(optimizer, hparams['decay_steps'], gamma=0.5)

    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx):
        """Apply one optimizer update, clear gradients, and advance the scheduler.

        Does nothing when ``optimizer`` is None. The scheduler, if set, is
        stepped to ``global_step // hparams['accumulate_grad_batches']``.
        """
        if optimizer is None:
            return
        optimizer.step()
        optimizer.zero_grad()
        if self.scheduler is not None:
            # NOTE(review): passing an explicit epoch to ``step()`` is
            # deprecated in recent PyTorch releases; kept unchanged here.
            self.scheduler.step(self.global_step // hparams['accumulate_grad_batches'])