Spaces:
Build error
Build error
File size: 6,954 Bytes
9206300 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
import torch
from modules.portaspeech.portaspeech_flow import PortaSpeechFlow
from tasks.tts.fs2 import FastSpeech2Task
from tasks.tts.ps import PortaSpeechTask
from utils.pitch_utils import denorm_f0
from utils.hparams import hparams
class PortaSpeechFlowTask(PortaSpeechTask):
    """PortaSpeech task variant built on ``PortaSpeechFlow`` with an optional post flow.

    Training can run in two stages (``hparams['two_stage']``). Before
    ``hparams['post_glow_training_start']`` the main model is trained; from then on
    (when ``hparams['use_post_flow']`` is set) the post-flow loss is trained. In
    two-stage post-flow mode the task owns two optimizers (see ``build_optimizer``)
    and ``_training_step`` skips whichever one does not match the current stage.
    """

    def __init__(self):
        super().__init__()
        # Whether the post-flow ("post glow") stage is active; refreshed from
        # global_step at the start of every training and validation step.
        self.training_post_glow = False

    def _post_glow_stage_reached(self):
        """Return whether the post-flow stage is active at the current global step.

        Truthy only when ``global_step`` has reached
        ``hparams['post_glow_training_start']`` and ``hparams['use_post_flow']``
        is enabled.
        """
        return self.global_step >= hparams['post_glow_training_start'] \
            and hparams['use_post_flow']

    def build_tts_model(self):
        """Instantiate ``PortaSpeechFlow`` sized to the phoneme and word vocabularies."""
        ph_dict_size = len(self.token_encoder)
        word_dict_size = len(self.word_encoder)
        self.model = PortaSpeechFlow(ph_dict_size, word_dict_size, hparams)

    def _training_step(self, sample, batch_idx, opt_idx):
        """Compute the training loss for optimizer ``opt_idx``.

        Args:
            sample: batch dict (keys as read by ``run_model``).
            batch_idx: index of the batch (not used here).
            opt_idx: index into the list returned by ``build_optimizer``.

        Returns:
            ``None`` when this optimizer must be skipped for the step (it does not
            belong to the current stage, or the model produced no post-flow loss);
            otherwise ``(total_loss, loss_output)`` where ``total_loss`` sums every
            loss tensor that requires grad.
        """
        self.training_post_glow = self._post_glow_stage_reached()
        if hparams['two_stage']:
            # Optimizer 0 owns the main model and optimizer 1 the post flow
            # (see build_optimizer); each one only steps during its own stage.
            wrong_stage = (opt_idx == 0 and self.training_post_glow) or \
                          (opt_idx == 1 and not self.training_post_glow)
            if wrong_stage:
                return None
        loss_output, _ = self.run_model(sample)
        if 'postflow' in loss_output and loss_output['postflow'] is None:
            # No post-flow loss was produced: skip the step instead of stepping.
            return None
        # Detached tensors and non-tensor entries stay in loss_output but are
        # excluded from the optimized total.
        total_loss = sum(v for v in loss_output.values()
                         if isinstance(v, torch.Tensor) and v.requires_grad)
        loss_output['batch_size'] = sample['txt_tokens'].size()[0]
        return total_loss, loss_output

    def run_model(self, sample, infer=False, *args, **kwargs):
        """Run ``PortaSpeechFlow`` on a batch.

        Args:
            sample: batch dict. Reads ``txt_tokens``, ``word_tokens``, ``ph2word``,
                ``word_lengths``, ``mel2ph``, ``mel2word`` and (training only)
                ``mels``, plus the optional ``pitch``, ``spk_embed``, ``spk_ids``
                and ``bert_feats`` entries.
            infer: False for the loss-computing forward pass, True for inference.
            **kwargs: ``infer_use_gt_dur`` overrides ``hparams['use_gt_dur']``
                during inference.

        Returns:
            ``(losses, output)`` when ``infer`` is False, otherwise the model
            output alone.
        """
        spk_embed = sample.get('spk_embed')
        spk_id = sample.get('spk_ids')
        if not infer:
            training_post_glow = self.training_post_glow
            output = self.model(sample['txt_tokens'],
                                sample['word_tokens'],
                                ph2word=sample['ph2word'],
                                mel2word=sample['mel2word'],
                                mel2ph=sample['mel2ph'],
                                word_len=sample['word_lengths'].max(),
                                tgt_mels=sample['mels'],
                                pitch=sample.get('pitch'),
                                spk_embed=spk_embed,
                                spk_id=spk_id,
                                infer=False,
                                forward_post_glow=training_post_glow,
                                two_stage=hparams['two_stage'],
                                global_step=self.global_step,
                                bert_feats=sample.get('bert_feats'))
            losses = {}
            self.add_mel_loss(output['mel_out'], sample['mels'], losses)
            if (training_post_glow or not hparams['two_stage']) and hparams['use_post_flow']:
                losses['postflow'] = output['postflow']
                # Detach the 'l1'/'ssim' mel losses so they contribute no gradient
                # while the post flow is trained (they remain in `losses`).
                losses['l1'] = losses['l1'].detach()
                losses['ssim'] = losses['ssim'].detach()
            if not training_post_glow or not hparams['two_stage'] or not self.training:
                # Main-model losses; skipped only while training in two-stage
                # mode during the post-flow stage.
                kl = output['kl']
                if self.global_step < hparams['kl_start_steps']:
                    # Before kl_start_steps the KL term carries no gradient.
                    kl = kl.detach()
                else:
                    kl = torch.clamp(kl, min=hparams['kl_min'])
                losses['kl'] = kl * hparams['lambda_kl']
                if hparams['dur_level'] == 'word':
                    self.add_dur_loss(
                        output['dur'], sample['mel2word'], sample['word_lengths'], sample['txt_tokens'], losses)
                    self.get_attn_stats(output['attn'], sample, losses)
                else:
                    # Phoneme-level duration loss. This class does not override
                    # add_dur_loss, so a bare super() would resolve to the same
                    # method as self.add_dur_loss above (which takes an extra
                    # word_lengths argument). Skip PortaSpeechTask in the MRO to
                    # reach its base's (presumably phoneme-level) implementation.
                    super(PortaSpeechTask, self).add_dur_loss(
                        output['dur'], sample['mel2ph'], sample['txt_tokens'], losses)
            return losses, output

        use_gt_dur = kwargs.get('infer_use_gt_dur', hparams['use_gt_dur'])
        # The post flow is only used at inference 1000 steps after its training
        # starts (presumably to let it warm up first — TODO confirm).
        forward_post_glow = self.global_step >= hparams['post_glow_training_start'] + 1000 \
            and hparams['use_post_flow']
        return self.model(
            sample['txt_tokens'],
            sample['word_tokens'],
            ph2word=sample['ph2word'],
            word_len=sample['word_lengths'].max(),
            pitch=sample.get('pitch'),
            mel2ph=sample['mel2ph'] if use_gt_dur else None,
            # Honour the infer_use_gt_dur override for the word alignment too,
            # consistently with mel2ph above.
            mel2word=sample['mel2word'] if hparams['profile_infer'] or use_gt_dur else None,
            infer=True,
            forward_post_glow=forward_post_glow,
            spk_embed=spk_embed,
            spk_id=spk_id,
            two_stage=hparams['two_stage'],
            bert_feats=sample.get('bert_feats'))

    def validation_step(self, sample, batch_idx):
        """Refresh the post-flow stage flag, then defer to the parent validation step."""
        self.training_post_glow = self._post_glow_stage_reached()
        return super().validation_step(sample, batch_idx)

    def save_valid_result(self, sample, batch_idx, model_out):
        """Save the parent's validation results plus, when enabled, the FVAE result.

        After step 0 and with ``hparams['use_post_flow']``, the
        ``mel_out_fvae`` prediction is vocoded (with ground-truth F0 when the batch
        provides ``f0``), logged as audio and plotted against the target mels.
        """
        super().save_valid_result(sample, batch_idx, model_out)
        if self.global_step <= 0 or not hparams['use_post_flow']:
            return
        sr = hparams['audio_sample_rate']
        f0_gt = None
        if sample.get('f0') is not None:
            f0_gt = denorm_f0(sample['f0'][0].cpu(), sample['uv'][0].cpu())
        # save FVAE result
        mel_fvae = model_out['mel_out_fvae'][0]
        wav_pred = self.vocoder.spec2wav(mel_fvae.cpu(), f0=f0_gt)
        self.logger.add_audio(f'wav_fvae_{batch_idx}', wav_pred, self.global_step, sr)
        self.plot_mel(batch_idx, sample['mels'], mel_fvae,
                      f'mel_fvae_{batch_idx}', f0s=f0_gt)

    def build_optimizer(self, model):
        """Create the AdamW optimizer(s) for ``self.model``.

        ``model`` is accepted for interface compatibility but not used; parameters
        always come from ``self.model``.

        Returns:
            In two-stage post-flow mode, ``[main_optimizer, post_flow_optimizer]``:
            the first covers every parameter whose name does not contain
            ``'post_flow'`` (learning rate ``hparams['lr']``), the second covers
            ``self.model.post_flow`` (learning rate ``hparams['post_flow_lr']``).
            Otherwise a single optimizer over all parameters.
        """
        def make_adamw(params, lr):
            # Shared AdamW settings for every optimizer this task builds.
            return torch.optim.AdamW(
                params,
                lr=lr,
                betas=(hparams['optimizer_adam_beta1'], hparams['optimizer_adam_beta2']),
                weight_decay=hparams['weight_decay'])

        if hparams['two_stage'] and hparams['use_post_flow']:
            self.optimizer = make_adamw(
                [p for name, p in self.model.named_parameters() if 'post_flow' not in name],
                hparams['lr'])
            self.post_flow_optimizer = make_adamw(
                self.model.post_flow.parameters(), hparams['post_flow_lr'])
            return [self.optimizer, self.post_flow_optimizer]
        self.optimizer = make_adamw(self.model.parameters(), hparams['lr'])
        return [self.optimizer]

    def build_scheduler(self, optimizer):
        """Delegate to ``FastSpeech2Task.build_scheduler`` for the first optimizer only.

        A post-flow optimizer, if present, gets no scheduler here.
        """
        return FastSpeech2Task.build_scheduler(self, optimizer[0])