ddd
Add application file
b93970c
raw
history blame
3.48 kB
import torch
import utils
from .diff.diffusion import GaussianDiffusion
from .diff.net import DiffNet
from tasks.tts.fs2 import FastSpeech2Task
from utils.hparams import hparams
# Registry of diffusion denoiser networks, keyed by ``hparams['diff_decoder_type']``.
# Each entry is a factory taking the hparams mapping and returning the denoiser
# module; currently only a DiffNet sized by the number of mel bins is provided.
DIFF_DECODERS = dict(
    wavenet=lambda hp_cfg: DiffNet(hp_cfg['audio_num_mel_bins']),
)
class DiffFsTask(FastSpeech2Task):
    """FastSpeech2 task whose mel decoder is a ``GaussianDiffusion`` model.

    The duration / pitch / energy loss helpers are inherited from
    ``FastSpeech2Task``; the mel loss is the ``diff_loss`` reported by the
    diffusion model.
    """

    def build_tts_model(self):
        """Build the diffusion model and store it on ``self.model``.

        The denoiser network is looked up in ``DIFF_DECODERS`` by
        ``hparams['diff_decoder_type']`` and built from ``hparams``.
        """
        mel_bins = hparams['audio_num_mel_bins']
        self.model = GaussianDiffusion(
            phone_encoder=self.phone_encoder,
            out_dims=mel_bins,
            denoise_fn=DIFF_DECODERS[hparams['diff_decoder_type']](hparams),
            timesteps=hparams['timesteps'],
            loss_type=hparams['diff_loss_type'],
            spec_min=hparams['spec_min'],
            spec_max=hparams['spec_max'],
        )

    def run_model(self, model, sample, return_output=False, infer=False):
        """Run ``model`` on a batch and collect the training losses.

        Args:
            model: the diffusion TTS model (normally ``self.model``).
            sample: batch dict; reads ``txt_tokens``, ``mels``, ``mel2ph``,
                ``f0``, ``uv``, ``energy``, the speaker entry, and (when
                ``hparams['pitch_type'] == 'cwt'``) ``cwt_spec``, ``f0_mean``
                and ``f0_std``.
            return_output: if True, also return the raw model output dict.
            infer: forwarded to the model's forward call.

        Returns:
            ``losses`` dict, or ``(losses, output)`` when ``return_output``.
            ``losses['mel']`` is present only if the model output contains
            ``'diff_loss'``.
        """
        txt_tokens = sample['txt_tokens']  # [B, T_t]
        target = sample['mels']  # presumably [B, T_s, audio_num_mel_bins]
        mel2ph = sample['mel2ph']  # [B, T_s]
        f0 = sample['f0']
        uv = sample['uv']
        energy = sample['energy']
        # With use_spk_id the speaker ids are passed through the spk_embed argument.
        spk_embed = sample.get('spk_embed') if not hparams['use_spk_id'] else sample.get('spk_ids')
        if hparams['pitch_type'] == 'cwt':
            cwt_spec = sample['cwt_spec']
            f0_mean = sample['f0_mean']
            f0_std = sample['f0_std']
            # Replace f0 with the normalised f0 reconstructed from the CWT
            # spectrum; also stored on the sample for later loss computation.
            sample['f0_cwt'] = f0 = model.cwt2f0_norm(cwt_spec, f0_mean, f0_std, mel2ph)

        output = model(txt_tokens, mel2ph=mel2ph, spk_embed=spk_embed,
                       ref_mels=target, f0=f0, uv=uv, energy=energy, infer=infer)

        losses = {}
        if 'diff_loss' in output:
            losses['mel'] = output['diff_loss']
        self.add_dur_loss(output['dur'], mel2ph, txt_tokens, losses=losses)
        if hparams['use_pitch_embed']:
            self.add_pitch_loss(output, sample, losses)
        if hparams['use_energy_embed']:
            self.add_energy_loss(output['energy_pred'], energy, losses)
        if not return_output:
            return losses
        return losses, output

    def _training_step(self, sample, batch_idx, _):
        """Compute the total training loss and the values to log.

        Returns:
            ``(total_loss, log_outputs)`` where ``total_loss`` sums only the
            loss tensors that require grad.
        """
        log_outputs = self.run_model(self.model, sample)
        # Only differentiable tensors contribute to the optimised objective.
        total_loss = sum(v for v in log_outputs.values()
                         if isinstance(v, torch.Tensor) and v.requires_grad)
        log_outputs['batch_size'] = sample['txt_tokens'].size()[0]
        # get_last_lr() reports the lr actually in use. get_lr() is the
        # scheduler's internal step formula: called outside step() it returns
        # lr * gamma for StepLR at decay boundaries (and is deprecated).
        # Fall back to get_lr() only for schedulers lacking get_last_lr().
        get_last_lr = getattr(self.scheduler, 'get_last_lr', None)
        current_lrs = get_last_lr() if get_last_lr is not None else self.scheduler.get_lr()
        log_outputs['lr'] = current_lrs[0]
        return total_loss, log_outputs

    def validation_step(self, sample, batch_idx):
        """Compute validation losses and plot mels for the first few batches.

        Losses come from a non-inference forward pass. For
        ``batch_idx < hparams['num_valid_plots']`` a second, inference-mode pass
        produces ``mel_out``, which is plotted against the ground-truth mels.
        """
        outputs = {}
        outputs['losses'], _ = self.run_model(self.model, sample, return_output=True, infer=False)
        outputs['total_loss'] = sum(outputs['losses'].values())
        outputs['nsamples'] = sample['nsamples']
        outputs = utils.tensors_to_scalars(outputs)
        if batch_idx < hparams['num_valid_plots']:
            _, model_out = self.run_model(self.model, sample, return_output=True, infer=True)
            self.plot_mel(batch_idx, sample['mels'], model_out['mel_out'])
        return outputs

    def build_scheduler(self, optimizer):
        """Return a StepLR halving the lr every ``hparams['decay_steps']`` scheduler steps."""
        return torch.optim.lr_scheduler.StepLR(optimizer, hparams['decay_steps'], gamma=0.5)

    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx):
        """Apply one optimizer update, clear gradients, and advance the scheduler.

        Does nothing when ``optimizer`` is None.
        """
        if optimizer is None:
            return
        optimizer.step()
        optimizer.zero_grad()
        if self.scheduler is not None:
            # The scheduler "epoch" is set explicitly to the number of
            # optimizer updates (global_step / accumulate_grad_batches).
            # NOTE(review): passing an epoch to step() is deprecated in recent
            # PyTorch; kept because switching to step() would change the schedule.
            self.scheduler.step(self.global_step // hparams['accumulate_grad_batches'])