# DiffSpeech: modules/tts/fs2_orig.py
# Author: RayeRen; commit d1b91e7 ("init"); file size 5.36 kB.
import torch
from torch import nn
from modules.commons.layers import Embedding
from modules.commons.nar_tts_modules import EnergyPredictor, PitchPredictor
from modules.tts.commons.align_ops import expand_states
from modules.tts.fs import FastSpeech
from utils.audio.cwt import cwt2f0, get_lf0_cwt
from utils.audio.pitch.utils import denorm_f0, f0_to_coarse, norm_f0
import numpy as np
class FastSpeech2Orig(FastSpeech):
    """FastSpeech 2 acoustic model with explicit pitch and energy predictors.

    Extends ``FastSpeech`` with:

    * an energy predictor whose predicted (or ground-truth) energy is quantized
      and embedded, enabled by ``hparams['use_energy_embed']``;
    * a CWT-based pitch predictor plus an utterance-level f0 mean/std regressor,
      used when ``hparams['pitch_type'] == 'cwt'`` and
      ``hparams['use_pitch_embed']``. Other pitch types are delegated to the
      parent ``forward_pitch``.
    """

    # Cached CWT scales for ``cwt2f0_norm``; filled lazily on first use.
    _cwt_scales = None

    def __init__(self, dict_size, hparams, out_dims=None):
        super().__init__(dict_size, hparams, out_dims)
        # A non-positive 'predictor_hidden' means "reuse the model hidden size".
        predictor_hidden = hparams['predictor_hidden'] if hparams['predictor_hidden'] > 0 else self.hidden_size
        if hparams['use_energy_embed']:
            # 300-entry table; forward_energy only produces indices in [0, 255].
            self.energy_embed = Embedding(300, self.hidden_size, 0)
            # NOTE(review): odim=2, but only output channel 0 is read in forward_energy.
            self.energy_predictor = EnergyPredictor(
                self.hidden_size, n_chans=predictor_hidden,
                n_layers=5, dropout_rate=0.1, odim=2,
                kernel_size=hparams['predictor_kernel'])
        if hparams['pitch_type'] == 'cwt' and hparams['use_pitch_embed']:
            # 11 output channels: channels 0-9 are read as CWT coefficients and
            # channel 10 as the voiced/unvoiced logit (see forward_pitch).
            self.pitch_predictor = PitchPredictor(
                self.hidden_size, n_chans=predictor_hidden,
                n_layers=5, dropout_rate=0.1, odim=11,
                kernel_size=hparams['predictor_kernel'])
            # Regresses the utterance-level f0 [mean, std] from the first encoder frame.
            self.cwt_stats_layers = nn.Sequential(
                nn.Linear(self.hidden_size, self.hidden_size), nn.ReLU(),
                nn.Linear(self.hidden_size, self.hidden_size), nn.ReLU(), nn.Linear(self.hidden_size, 2))

    def forward(self, txt_tokens, mel2ph=None, spk_embed=None, spk_id=None,
                f0=None, uv=None, energy=None, infer=False, **kwargs):
        """Run encoder, duration/pitch/energy variance adaptors and decoder.

        Args:
            txt_tokens: token ids, rank 2 ``[B, T_txt]``; id 0 is treated as padding.
            mel2ph: optional frame-to-token alignment; passed to ``forward_dur``,
                whose return value (rank 2, 0 = padding frame) is used below.
            spk_embed, spk_id: speaker conditioning for ``forward_style_embed``.
            f0, uv: optional ground-truth pitch / voicing for ``forward_pitch``.
            energy: optional ground-truth energy for ``forward_energy``.
            infer: forwarded to ``forward_decoder``.
            **kwargs: forwarded to ``forward_decoder``. ``adv_z`` may supply the
                decoder-input noise when ``hparams['dec_inp_add_noise']`` is set.

        Returns:
            The ``ret`` dict filled by the sub-steps, including ``'decoder_inp'``
            and ``'mel_out'`` (and ``'adv_z'`` when noise is added).
        """
        ret = {}
        encoder_out = self.encoder(txt_tokens)  # [B, T, C]
        src_nonpadding = (txt_tokens > 0).float()[:, :, None]
        style_embed = self.forward_style_embed(spk_embed, spk_id)

        # duration / length regulation
        dur_inp = (encoder_out + style_embed) * src_nonpadding
        mel2ph = self.forward_dur(dur_inp, mel2ph, txt_tokens, ret)
        tgt_nonpadding = (mel2ph > 0).float()[:, :, None]
        # decoder_inp_ keeps the plain expanded states so pitch and energy
        # predictors both see the same input, independent of each other's embeds.
        decoder_inp = decoder_inp_ = expand_states(encoder_out, mel2ph)

        # add pitch embed
        if self.hparams['use_pitch_embed']:
            pitch_inp = (decoder_inp_ + style_embed) * tgt_nonpadding
            decoder_inp = decoder_inp + self.forward_pitch(pitch_inp, f0, uv, mel2ph, ret, encoder_out)

        # add energy embed
        if self.hparams['use_energy_embed']:
            energy_inp = (decoder_inp_ + style_embed) * tgt_nonpadding
            decoder_inp = decoder_inp + self.forward_energy(energy_inp, energy, ret)

        # decoder input
        ret['decoder_inp'] = decoder_inp = (decoder_inp + style_embed) * tgt_nonpadding
        if self.hparams['dec_inp_add_noise']:
            B, T, _ = decoder_inp.shape
            # Build the random default only when no noise was supplied, so a
            # caller-provided adv_z neither wastes an allocation nor advances the RNG.
            z = kwargs.get('adv_z')
            if z is None:
                z = torch.randn([B, T, self.z_channels])
            z = z.to(decoder_inp.device)
            ret['adv_z'] = z
            decoder_inp = torch.cat([decoder_inp, z], -1)
            decoder_inp = self.dec_inp_noise_proj(decoder_inp) * tgt_nonpadding
        ret['mel_out'] = self.forward_decoder(decoder_inp, tgt_nonpadding, ret, infer=infer, **kwargs)
        return ret

    def forward_pitch(self, decoder_inp, f0, uv, mel2ph, ret, encoder_out=None):
        """Predict pitch and return its embedding.

        For ``pitch_type == 'cwt'``: predicts CWT coefficients (``ret['cwt']``)
        and f0 mean/std (``ret['f0_mean']``, ``ret['f0_std']``). When ``f0`` is
        None, it reconstructs normalized f0 from the prediction, with std scaled
        by ``hparams['cwt_std_scale']``. In that case, and only if
        ``hparams['use_uv']`` is set, ``uv`` is also taken from the predictor.
        Given f0/uv are used as-is. Stores ``ret['f0_denorm']`` and returns
        ``self.pitch_embed`` of the coarse-quantized pitch. Other pitch types are
        delegated to the parent implementation.
        """
        if self.hparams['pitch_type'] == 'cwt':
            # Forward value is unchanged; the gradient into the predictor input
            # is scaled by hparams['predictor_grad'].
            decoder_inp = decoder_inp.detach() + self.hparams['predictor_grad'] * (decoder_inp - decoder_inp.detach())
            pitch_padding = mel2ph == 0
            ret['cwt'] = cwt_out = self.pitch_predictor(decoder_inp)
            stats_out = self.cwt_stats_layers(encoder_out[:, 0, :])  # [B, 2]
            mean = ret['f0_mean'] = stats_out[:, 0]
            std = ret['f0_std'] = stats_out[:, 1]
            cwt_spec = cwt_out[:, :, :10]
            if f0 is None:
                std = std * self.hparams['cwt_std_scale']
                f0 = self.cwt2f0_norm(cwt_spec, mean, std, mel2ph)
                # Predicted voicing is only used when f0 itself is predicted;
                # a supplied (ground-truth) uv is never overwritten.
                if self.hparams['use_uv']:
                    assert cwt_out.shape[-1] == 11
                    uv = cwt_out[:, :, -1] > 0
            ret['f0_denorm'] = f0_denorm = denorm_f0(f0, uv if self.hparams['use_uv'] else None,
                                                     pitch_padding=pitch_padding)
            pitch = f0_to_coarse(f0_denorm)  # start from 0
            pitch_embed = self.pitch_embed(pitch)
            return pitch_embed
        else:
            return super(FastSpeech2Orig, self).forward_pitch(decoder_inp, f0, uv, mel2ph, ret, encoder_out)

    def forward_energy(self, decoder_inp, energy, ret):
        """Predict energy (``ret['energy_pred']``) and return its embedding.

        Embeds ``energy`` if given, else the prediction, after quantizing it to
        ``clamp(floor(x * 256 / 4), 0, 255)``.
        """
        # Scale the gradient reaching the predictor input by hparams['predictor_grad'].
        decoder_inp = decoder_inp.detach() + self.hparams['predictor_grad'] * (decoder_inp - decoder_inp.detach())
        ret['energy_pred'] = energy_pred = self.energy_predictor(decoder_inp)[:, :, 0]
        energy_embed_inp = energy_pred if energy is None else energy
        energy_embed_inp = torch.clamp(energy_embed_inp * 256 // 4, min=0, max=255).long()
        energy_embed = self.energy_embed(energy_embed_inp)
        return energy_embed

    def cwt2f0_norm(self, cwt_spec, mean, std, mel2ph):
        """Reconstruct normalized f0 from CWT coefficients.

        Args:
            cwt_spec: predicted CWT coefficients (channels 0-9 of the pitch predictor).
            mean, std: per-utterance f0 statistics passed to ``cwt2f0``.
            mel2ph: alignment whose length (dim 1) the f0 is padded to.

        Returns:
            ``norm_f0`` of the reconstructed f0. It is right-padded by repeating
            its last frame up to ``mel2ph.shape[1]``, and never truncated.
        """
        if self._cwt_scales is None:
            # get_lf0_cwt is called on a constant dummy signal only to obtain the
            # scales. Assuming it is deterministic, the result never changes,
            # so compute it once instead of on every forward pass.
            _, self._cwt_scales = get_lf0_cwt(np.ones(10))
        f0 = cwt2f0(cwt_spec, mean, std, self._cwt_scales)
        # A non-positive pad count yields an empty list, leaving f0 unchanged.
        n_pad = mel2ph.shape[1] - f0.shape[1]
        f0 = torch.cat([f0] + [f0[:, -1:]] * n_pad, 1)
        f0_norm = norm_f0(f0, None)
        return f0_norm