import torch
from modules.GenerSpeech.model.glow_modules import Glow
from modules.fastspeech.tts_modules import PitchPredictor
import random
from modules.GenerSpeech.model.prosody_util import ProsodyAligner, LocalStyleAdaptor
from utils.pitch_utils import f0_to_coarse, denorm_f0
from modules.commons.common_layers import *
import torch.distributions as dist
from utils.hparams import hparams
from modules.GenerSpeech.model.mixstyle import MixStyle
from modules.fastspeech.fs2 import FastSpeech2
import json
from modules.fastspeech.tts_modules import DEFAULT_MAX_SOURCE_POSITIONS, DEFAULT_MAX_TARGET_POSITIONS


class GenerSpeech(FastSpeech2):
    '''
    GenerSpeech: Towards Style Transfer for Generalizable Out-Of-Domain Text-to-Speech
    https://arxiv.org/abs/2205.07211

    Extends FastSpeech2 with:
      * a MixStyle normalisation layer conditioned on speaker + emotion embeddings,
      * three prosody branches (utterance / phoneme / word), each made of a
        LocalStyleAdaptor, a linear projection and a ProsodyAligner,
      * a second pitch predictor whose output is summed with the inherited one,
      * a Glow post-flow applied on top of the decoder output.
    '''

    def __init__(self, dictionary, out_dims=None):
        """Build the GenerSpeech-specific sub-modules on top of FastSpeech2.

        Args:
            dictionary: forwarded unchanged to ``FastSpeech2.__init__``.
            out_dims: forwarded unchanged to ``FastSpeech2.__init__``.
        """
        super().__init__(dictionary, out_dims)

        # NOTE: sub-modules are created in the same order as before so that
        # seeded initialisation and state_dict ordering are unaffected.

        # Mixstyle
        self.norm = MixStyle(p=0.5, alpha=0.1, eps=1e-6, hidden_size=self.hidden_size)

        # emotion embedding (input feature size is fixed to 256 here)
        self.emo_embed_proj = Linear(256, self.hidden_size, bias=True)

        # build prosody extractor
        ## utterance level (this extractor is called without an alignment map)
        self.prosody_extractor_utter = LocalStyleAdaptor(self.hidden_size, hparams['nVQ'], self.padding_idx)
        self.l1_utter = nn.Linear(self.hidden_size * 2, self.hidden_size)
        self.align_utter = ProsodyAligner(num_layers=2)

        ## phoneme level
        self.prosody_extractor_ph = LocalStyleAdaptor(self.hidden_size, hparams['nVQ'], self.padding_idx)
        self.l1_ph = nn.Linear(self.hidden_size * 2, self.hidden_size)
        self.align_ph = ProsodyAligner(num_layers=2)

        ## word level
        self.prosody_extractor_word = LocalStyleAdaptor(self.hidden_size, hparams['nVQ'], self.padding_idx)
        self.l1_word = nn.Linear(self.hidden_size * 2, self.hidden_size)
        self.align_word = ProsodyAligner(num_layers=2)

        # Second pitch predictor; its output is added to self.pitch_predictor's
        # output in inpaint_pitch().
        self.pitch_inpainter_predictor = PitchPredictor(
            self.hidden_size, n_chans=self.hidden_size,
            n_layers=3, dropout_rate=0.1, odim=2,
            padding=hparams['ffn_padding'], kernel_size=hparams['predictor_kernel'])

        # build attention layer
        self.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
        self.embed_positions = SinusoidalPositionalEmbedding(
            self.hidden_size, self.padding_idx,
            init_size=self.max_source_positions + self.padding_idx + 1,
        )

        # build post flow
        # Conditioning channel count must mirror what run_post_glow() concatenates:
        # 80 (mel_out) [+ hidden_size (decoder_inp)] + 3 * hidden_size.
        cond_hs = 80
        if hparams.get('use_txt_cond', True):
            cond_hs = cond_hs + hparams['hidden_size']

        cond_hs = cond_hs + hparams['hidden_size'] * 3  # for emo, spk embedding and prosody embedding
        self.post_flow = Glow(
            80, hparams['post_glow_hidden'], hparams['post_glow_kernel_size'], 1,
            hparams['post_glow_n_blocks'], hparams['post_glow_n_block_layers'],
            n_split=4, n_sqz=2,
            gin_channels=cond_hs,
            share_cond_layers=hparams['post_share_cond_layers'],
            share_wn_layers=hparams['share_wn_layers'],
            sigmoid_scale=hparams['sigmoid_scale']
        )
        self.prior_dist = dist.Normal(0, 1)

    def forward(self, txt_tokens, mel2ph=None, ref_mel2ph=None, ref_mel2word=None, spk_embed=None, emo_embed=None,
                ref_mels=None, f0=None, uv=None, skip_decoder=False, global_steps=0, infer=False, **kwargs):
        """Run the full model and collect all outputs in a dict.

        Args:
            txt_tokens: token ids; positions ``> 0`` are treated as non-padding.
            mel2ph: alignment passed to ``add_dur``; the value it returns is used
                for state expansion, and ``> 0`` entries mark non-padding frames.
            ref_mel2ph: alignment handed to the phoneme-level prosody extractor.
            ref_mel2word: alignment handed to the word-level prosody extractor.
            spk_embed: speaker embedding, projected by ``self.spk_embed_proj``.
            emo_embed: emotion embedding, projected by ``self.emo_embed_proj``.
            ref_mels: reference mels for the prosody extractors; also used as the
                post-flow target when ``infer`` is False.
            f0, uv: optional pitch inputs forwarded to ``inpaint_pitch``.
            skip_decoder: if True, return right after ``ret['decoder_inp']`` is set.
            global_steps: training step, compared against ``hparams['vq_start']``
                and ``hparams['forcing']`` in the prosody branches.
            infer: inference switch forwarded to the sub-steps.
            **kwargs: forwarded to ``run_decoder``.

        Returns:
            dict: ``ret`` with every intermediate/output written along the way.
        """
        ret = {}
        encoder_out = self.encoder(txt_tokens)  # [B, T, C]
        src_nonpadding = (txt_tokens > 0).float()[:, :, None]

        # add spk/emo embed (a singleton time axis is inserted for broadcasting)
        spk_embed = self.spk_embed_proj(spk_embed)[:, None, :]
        emo_embed = self.emo_embed_proj(emo_embed)[:, None, :]

        # add dur
        dur_inp = (encoder_out + spk_embed + emo_embed) * src_nonpadding
        mel2ph = self.add_dur(dur_inp, mel2ph, txt_tokens, ret)
        tgt_nonpadding = (mel2ph > 0).float()[:, :, None]
        decoder_inp = self.expand_states(encoder_out, mel2ph)
        decoder_inp = self.norm(decoder_inp, spk_embed + emo_embed)

        # add prosody VQ
        ret['ref_mel2ph'] = ref_mel2ph
        ret['ref_mel2word'] = ref_mel2word
        prosody_utter_mel = self.get_prosody_utter(decoder_inp, ref_mels, ret, infer, global_steps)
        prosody_ph_mel = self.get_prosody_ph(decoder_inp, ref_mels, ret, infer, global_steps)
        prosody_word_mel = self.get_prosody_word(decoder_inp, ref_mels, ret, infer, global_steps)

        # add pitch embed
        pitch_inp_domain_agnostic = decoder_inp * tgt_nonpadding
        pitch_inp_domain_specific = (decoder_inp + spk_embed + emo_embed
                                     + prosody_utter_mel + prosody_ph_mel + prosody_word_mel) * tgt_nonpadding
        predicted_pitch = self.inpaint_pitch(pitch_inp_domain_agnostic, pitch_inp_domain_specific,
                                             f0, uv, mel2ph, ret)

        # decode
        decoder_inp = (decoder_inp + spk_embed + emo_embed + predicted_pitch
                       + prosody_utter_mel + prosody_ph_mel + prosody_word_mel)
        ret['decoder_inp'] = decoder_inp = decoder_inp * tgt_nonpadding
        if skip_decoder:
            return ret
        ret['mel_out'] = self.run_decoder(decoder_inp, tgt_nonpadding, ret, infer=infer, **kwargs)

        # postflow
        is_training = self.training
        ret['x_mask'] = tgt_nonpadding
        ret['spk_embed'] = spk_embed
        ret['emo_embed'] = emo_embed
        ret['ref_prosody'] = prosody_utter_mel + prosody_ph_mel + prosody_word_mel

        self.run_post_glow(ref_mels, infer, is_training, ret)
        return ret

    def _get_prosody(self, level, extractor, proj, aligner, extractor_args,
                     encoder_out, ret, infer, global_steps):
        """Shared body of the three ``get_prosody_*`` methods.

        Args:
            level: suffix for the ``ret`` keys (``'utter'``, ``'ph'`` or ``'word'``).
            extractor: module producing the prosody embedding.
            proj: linear layer applied to ``[prosody_embedding ; positions]``.
            aligner: attention module called with the content and the style.
            extractor_args: positional arguments passed to ``extractor``.
            encoder_out: content representation used as the attention query side.
            ret: output dict; receives ``vq_loss_*``/``ppl_*`` (VQ path only),
                ``gloss_*`` and ``attn_*``.
            infer: when True the VQ path is always used.
            global_steps: compared with ``hparams['vq_start']`` and
                ``hparams['forcing']``.

        Returns:
            The aligner output with its first two axes swapped back.
        """
        # get VQ prosody
        if global_steps > hparams['vq_start'] or infer:
            prosody_embedding, loss, ppl = extractor(*extractor_args, no_vq=False)
            ret[f'vq_loss_{level}'] = loss
            ret[f'ppl_{level}'] = ppl
        else:
            prosody_embedding = extractor(*extractor_args, no_vq=True)

        # add positional embedding
        positions = self.embed_positions(prosody_embedding[:, :, 0])
        prosody_embedding = proj(torch.cat([prosody_embedding, positions], dim=-1))

        # style-to-content attention
        # Padding is detected by comparing the first feature channel to padding_idx.
        src_key_padding_mask = encoder_out[:, :, 0].eq(self.padding_idx).data
        prosody_key_padding_mask = prosody_embedding[:, :, 0].eq(self.padding_idx).data
        # The forcing flag is on only during the first hparams['forcing'] steps.
        forcing = bool(global_steps < hparams['forcing'])
        output, guided_loss, attn_emo = aligner(
            encoder_out.transpose(0, 1), prosody_embedding.transpose(0, 1),
            src_key_padding_mask, prosody_key_padding_mask, forcing=forcing)

        ret[f'gloss_{level}'] = guided_loss
        ret[f'attn_{level}'] = attn_emo
        return output.transpose(0, 1)

    def get_prosody_ph(self, encoder_out, ref_mels, ret, infer=False, global_steps=0):
        """Phoneme-level prosody branch; reads ``ret['ref_mel2ph']`` and writes
        ``vq_loss_ph``/``ppl_ph`` (VQ path only), ``gloss_ph`` and ``attn_ph``."""
        return self._get_prosody(
            'ph', self.prosody_extractor_ph, self.l1_ph, self.align_ph,
            (ref_mels, ret['ref_mel2ph']), encoder_out, ret, infer, global_steps)

    def get_prosody_word(self, encoder_out, ref_mels, ret, infer=False, global_steps=0):
        """Word-level prosody branch; reads ``ret['ref_mel2word']`` and writes
        ``vq_loss_word``/``ppl_word`` (VQ path only), ``gloss_word`` and ``attn_word``."""
        return self._get_prosody(
            'word', self.prosody_extractor_word, self.l1_word, self.align_word,
            (ref_mels, ret['ref_mel2word']), encoder_out, ret, infer, global_steps)

    def get_prosody_utter(self, encoder_out, ref_mels, ret, infer=False, global_steps=0):
        """Utterance-level prosody branch (extractor gets only ``ref_mels``); writes
        ``vq_loss_utter``/``ppl_utter`` (VQ path only), ``gloss_utter`` and ``attn_utter``."""
        return self._get_prosody(
            'utter', self.prosody_extractor_utter, self.l1_utter, self.align_utter,
            (ref_mels,), encoder_out, ret, infer, global_steps)

    def inpaint_pitch(self, pitch_inp_domain_agnostic, pitch_inp_domain_specific, f0, uv, mel2ph, ret):
        """Predict pitch with two predictors, sum them, and embed the result.

        Args:
            pitch_inp_domain_agnostic: input of the inherited ``self.pitch_predictor``.
            pitch_inp_domain_specific: input of ``self.pitch_inpainter_predictor``.
            f0: optional f0; when None, channel 0 of the prediction is used.
            uv: optional unvoiced flags; replaced by ``prediction[..., 1] > 0`` when
                ``f0`` is None and uv is enabled.
            mel2ph: alignment; zeros mark padding for the ``'frame'`` pitch type and
                it is the gather index for the ``'ph'`` pitch type.
            ret: output dict; receives ``pitch_pred``, ``f0_denorm`` and
                ``f0_denorm_pred``.

        Returns:
            ``self.pitch_embed`` applied to the coarse pitch.
        """
        # Fix: pitch_padding used to be bound only for the 'frame' pitch type, so
        # every other pitch type hit an UnboundLocalError at the denorm_f0 calls
        # below. Fall back to None (no padding mask) in that case.
        # NOTE(review): assumes denorm_f0 accepts pitch_padding=None — confirm.
        pitch_padding = None
        if hparams['pitch_type'] == 'frame':
            pitch_padding = mel2ph == 0

        # Scale the gradient flowing back from the predictors into their inputs
        # by predictor_grad; forward values are unchanged.
        if hparams['predictor_grad'] != 1:
            pitch_inp_domain_agnostic = pitch_inp_domain_agnostic.detach() + hparams['predictor_grad'] * (
                pitch_inp_domain_agnostic - pitch_inp_domain_agnostic.detach())
            pitch_inp_domain_specific = pitch_inp_domain_specific.detach() + hparams['predictor_grad'] * (
                pitch_inp_domain_specific - pitch_inp_domain_specific.detach())

        pitch_domain_agnostic = self.pitch_predictor(pitch_inp_domain_agnostic)
        pitch_domain_specific = self.pitch_inpainter_predictor(pitch_inp_domain_specific)
        pitch_pred = pitch_domain_agnostic + pitch_domain_specific
        ret['pitch_pred'] = pitch_pred

        use_uv = hparams['pitch_type'] == 'frame' and hparams['use_uv']
        if f0 is None:
            f0 = pitch_pred[:, :, 0]  # [B, T]
            if use_uv:
                uv = pitch_pred[:, :, 1] > 0  # [B, T]
        f0_denorm = denorm_f0(f0, uv if use_uv else None, hparams, pitch_padding=pitch_padding)
        pitch = f0_to_coarse(f0_denorm)  # start from 0 [B, T_txt]
        ret['f0_denorm'] = f0_denorm
        ret['f0_denorm_pred'] = denorm_f0(pitch_pred[:, :, 0], (pitch_pred[:, :, 1] > 0) if use_uv else None,
                                          hparams, pitch_padding=pitch_padding)
        if hparams['pitch_type'] == 'ph':
            # Left-pad by one so that index 0 in mel2ph picks the padded slot.
            pitch = torch.gather(F.pad(pitch, [1, 0]), 1, mel2ph)
            ret['f0_denorm'] = torch.gather(F.pad(ret['f0_denorm'], [1, 0]), 1, mel2ph)
            ret['f0_denorm_pred'] = torch.gather(F.pad(ret['f0_denorm_pred'], [1, 0]), 1, mel2ph)
        pitch_embed = self.pitch_embed(pitch)
        return pitch_embed

    def run_post_glow(self, tgt_mels, infer, is_training, ret):
        """Train or sample the Glow post-flow conditioned on the decoder output.

        Args:
            tgt_mels: target mels, used only when ``infer`` is False.
            infer: False -> compute the flow likelihood terms; True -> sample from
                the prior and overwrite ``ret['mel_out']`` with the flow output.
            is_training: when True (and not ``infer``) ``self.train()`` is called.
            ret: dict providing ``mel_out``, ``decoder_inp``, ``spk_embed``,
                ``emo_embed``, ``ref_prosody`` and ``x_mask``; receives ``z_pf``,
                ``ldj_pf`` and ``postflow`` (training) or a new ``mel_out`` (infer).
        """
        x_recon = ret['mel_out'].transpose(1, 2)
        g = x_recon
        _, _, T = g.shape
        if hparams.get('use_txt_cond', True):
            g = torch.cat([g, ret['decoder_inp'].transpose(1, 2)], 1)
        # spk/emo embeddings carry a singleton time axis; tile them over T frames.
        g_spk_embed = ret['spk_embed'].repeat(1, T, 1).transpose(1, 2)
        g_emo_embed = ret['emo_embed'].repeat(1, T, 1).transpose(1, 2)
        l_ref_prosody = ret['ref_prosody'].transpose(1, 2)
        g = torch.cat([g, g_spk_embed, g_emo_embed, l_ref_prosody], dim=1)
        prior_dist = self.prior_dist
        if not infer:
            if is_training:
                self.train()
            x_mask = ret['x_mask'].transpose(1, 2)
            y_lengths = x_mask.sum(-1)
            # The condition is detached so the flow loss does not back-propagate
            # into the rest of the model.
            g = g.detach()
            tgt_mels = tgt_mels.transpose(1, 2)
            z_postflow, ldj = self.post_flow(tgt_mels, x_mask, g=g)
            # NOTE(review): y_lengths keeps a trailing singleton axis; if ldj is
            # 1-D this division broadcasts to a square matrix — confirm against Glow.
            # 80 matches the first Glow constructor argument (presumably mel bins).
            ldj = ldj / y_lengths / 80
            ret['z_pf'], ret['ldj_pf'] = z_postflow, ldj
            ret['postflow'] = -prior_dist.log_prob(z_postflow).mean() - ldj.mean()
        else:
            x_mask = torch.ones_like(x_recon[:, :1, :])
            z_post = prior_dist.sample(x_recon.shape).to(g.device) * hparams['noise_scale']
            x_recon_, _ = self.post_flow(z_post, x_mask, g, reverse=True)
            x_recon = x_recon_
            ret['mel_out'] = x_recon.transpose(1, 2)