# xVASynth-TTS: python/xvapitch/xvapitch_model.py
import math
import time
import traceback
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn.modules.conv import Conv1d
from python.xvapitch.glow_tts import RelativePositionTransformer
from python.xvapitch.wavenet import WN
from python.xvapitch.hifigan import HifiganGenerator
from python.xvapitch.sdp import StochasticDurationPredictor#, StochasticPredictor
from python.xvapitch.util import maximum_path, rand_segments, segment, sequence_mask, generate_path
from python.xvapitch.text import get_text_preprocessor, ALL_SYMBOLS, lang_names
class xVAPitch(nn.Module):
def __init__(self, args):
super().__init__()
self.args = args
self.args.init_discriminator = True
self.args.speaker_embedding_channels = 512
self.args.use_spectral_norm_disriminator = False
self.args.d_vector_dim = 512
self.args.use_language_embedding = True
self.args.detach_dp_input = True
self.END2END = True
self.embedded_language_dim = 12
self.latent_size = 256
num_languages = len(list(lang_names.keys()))
self.emb_l = nn.Embedding(num_languages, self.embedded_language_dim)
self.length_scale = 1.0
self.noise_scale = 1.0
self.inference_noise_scale = 0.333
self.inference_noise_scale_dp = 0.333
self.noise_scale_dp = 1.0
self.max_inference_len = None
self.spec_segment_size = 32
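        # VITS-style end-to-end TTS pipeline: a text (prior) encoder, a posterior spectrogram encoder,
        # a normalizing flow between the two latent spaces, a stochastic duration predictor, and a
        # HiFi-GAN decoder that generates the waveform directly from the latent.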
self.text_encoder = TextEncoder(
# 165,
len(ALL_SYMBOLS),
self.latent_size,#192,
self.latent_size,#192,
768,
2,
10,
3,
0.1,
# language_emb_dim=4,
language_emb_dim=self.embedded_language_dim,
)
self.posterior_encoder = PosteriorEncoder(
513,
self.latent_size,#+self.embedded_language_dim if self.args.flc else self.latent_size,#192,
self.latent_size,#+self.embedded_language_dim if self.args.flc else self.latent_size,#192,
kernel_size=5,
dilation_rate=1,
num_layers=16,
cond_channels=self.args.d_vector_dim,
)
self.flow = ResidualCouplingBlocks(
self.latent_size,#192,
self.latent_size,#192,
kernel_size=5,
dilation_rate=1,
num_layers=4,
cond_channels=self.args.d_vector_dim,
args=self.args
)
self.duration_predictor = StochasticDurationPredictor(
self.latent_size,#192,
self.latent_size,#192,
3,
0.5,
4,
cond_channels=self.args.d_vector_dim,
language_emb_dim=self.embedded_language_dim,
)
self.waveform_decoder = HifiganGenerator(
self.latent_size,#192,
1,
"1",
[[1,3,5],[1,3,5],[1,3,5]],
[3,7,11],
[16,16,4,4],
512,
[8,8,2,2],
inference_padding=0,
# cond_channels=self.args.d_vector_dim+self.embedded_language_dim if self.args.flc else self.args.d_vector_dim,
cond_channels=self.args.d_vector_dim,
conv_pre_weight_norm=False,
conv_post_weight_norm=False,
conv_post_bias=False,
)
self.USE_PITCH_COND = False
# self.USE_PITCH_COND = True
if self.USE_PITCH_COND:
self.pitch_predictor = RelativePositioningPitchEnergyEncoder(
# 165,
# len(ALL_SYMBOLS),
out_channels=1,
hidden_channels=self.latent_size+self.embedded_language_dim,#196,
hidden_channels_ffn=768,
num_heads=2,
# num_layers=10,
num_layers=3,
kernel_size=3,
dropout_p=0.1,
# language_emb_dim=4,
conditioning_emb_dim=self.args.d_vector_dim,
)
self.pitch_emb = nn.Conv1d(
# 1, 384,
# 1, 196,
1,
self.args.expanded_flow_dim if args.expanded_flow else self.latent_size,
# pitch_conditioning_formants, symbols_embedding_dim,
kernel_size=3,
padding=int((3 - 1) / 2))
self.TEMP_timing = []
def infer_get_lang_emb (self, language_id):
aux_input = {
# "d_vectors": embedding.unsqueeze(dim=0),
"language_ids": language_id
}
sid, g, lid = self._set_cond_input(aux_input)
lang_emb = self.emb_l(lid).unsqueeze(-1)
return lang_emb
def infer_advanced (self, logger, plugin_manager, cleaned_text, text, lang_embs, speaker_embs, pace=1.0, editor_data=None, old_sequence=None, pitch_amp=None):
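        # editor_data (when provided) carries per-symbol values supplied by the editor:
        # [pitch, durations, energy, em_angry, em_happy, em_sad, em_surprise, _]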
if (editor_data is not None) and ((editor_data[0] is not None and len(editor_data[0])) or (editor_data[1] is not None and len(editor_data[1]))):
pitch_pred, dur_pred, energy_pred, em_angry_pred, em_happy_pred, em_sad_pred, em_surprise_pred, _ = editor_data
# TODO, use energy_pred
dur_pred = torch.tensor(dur_pred)
dur_pred = dur_pred.view((1, dur_pred.shape[0])).float().to(self.device)
pitch_pred = torch.tensor(pitch_pred)
pitch_pred = pitch_pred.view((1, pitch_pred.shape[0])).float().to(self.device)
energy_pred = torch.tensor(energy_pred)
energy_pred = energy_pred.view((1, energy_pred.shape[0])).float().to(self.device)
            em_angry_pred = em_angry_pred.clone().detach() if isinstance(em_angry_pred, torch.Tensor) else torch.tensor(em_angry_pred)
            em_angry_pred = em_angry_pred.view((1, em_angry_pred.shape[0])).float().to(self.device)
            em_happy_pred = em_happy_pred.clone().detach() if isinstance(em_happy_pred, torch.Tensor) else torch.tensor(em_happy_pred)
            em_happy_pred = em_happy_pred.view((1, em_happy_pred.shape[0])).float().to(self.device)
            em_sad_pred = em_sad_pred.clone().detach() if isinstance(em_sad_pred, torch.Tensor) else torch.tensor(em_sad_pred)
            em_sad_pred = em_sad_pred.view((1, em_sad_pred.shape[0])).float().to(self.device)
            em_surprise_pred = em_surprise_pred.clone().detach() if isinstance(em_surprise_pred, torch.Tensor) else torch.tensor(em_surprise_pred)
            em_surprise_pred = em_surprise_pred.view((1, em_surprise_pred.shape[0])).float().to(self.device)
# Pitch speaker embedding deltas
if not self.USE_PITCH_COND and pitch_pred.shape[1]==speaker_embs.shape[2]:
pitch_delta = self.pitch_emb_values.to(pitch_pred.device) * pitch_pred
speaker_embs = speaker_embs + pitch_delta.float()
# Emotion speaker embedding deltas
emotions_strength = 0.00003 # Global scaling
if em_angry_pred.shape[1]==speaker_embs.shape[2]:
em_angry_delta = self.angry_emb_values.to(em_angry_pred.device) * em_angry_pred * emotions_strength
speaker_embs = speaker_embs + em_angry_delta.float()
if em_happy_pred.shape[1]==speaker_embs.shape[2]:
em_happy_delta = self.happy_emb_values.to(em_happy_pred.device) * em_happy_pred * emotions_strength
speaker_embs = speaker_embs + em_happy_delta.float()
if em_sad_pred.shape[1]==speaker_embs.shape[2]:
em_sad_delta = self.sad_emb_values.to(em_sad_pred.device) * em_sad_pred * emotions_strength
speaker_embs = speaker_embs + em_sad_delta.float()
if em_surprise_pred.shape[1]==speaker_embs.shape[2]:
em_surprise_delta = self.surprise_emb_values.to(em_surprise_pred.device) * em_surprise_pred * emotions_strength
speaker_embs = speaker_embs + em_surprise_delta.float()
try:
logger.info("editor data infer_using_vals")
wav, dur_pred, pitch_pred_out, energy_pred, em_pred_out, start_index, end_index, wav_mult = self.infer_using_vals(logger, plugin_manager, cleaned_text, text, lang_embs, \
speaker_embs, pace, dur_pred_existing=dur_pred, pitch_pred_existing=pitch_pred, energy_pred_existing=energy_pred, em_pred_existing=[em_angry_pred, em_happy_pred, em_sad_pred, em_surprise_pred], old_sequence=old_sequence, new_sequence=text, pitch_amp=pitch_amp)
[em_angry_pred_out, em_happy_pred_out, em_sad_pred_out, em_surprise_pred_out] = em_pred_out
pitch_pred_out = pitch_pred
em_angry_pred_out = em_angry_pred
em_happy_pred_out = em_happy_pred
em_sad_pred_out = em_sad_pred
em_surprise_pred_out = em_surprise_pred
return wav, dur_pred, pitch_pred_out, energy_pred, [em_angry_pred_out, em_happy_pred_out, em_sad_pred_out, em_surprise_pred_out], start_index, end_index, wav_mult
            except Exception:
print(traceback.format_exc())
logger.info(traceback.format_exc())
# return traceback.format_exc()
logger.info("editor data corrupt; fallback to infer_using_vals")
return self.infer_using_vals(logger, plugin_manager, cleaned_text, text, lang_embs, speaker_embs, pace, None, None, None, None, None, None, pitch_amp=pitch_amp)
else:
logger.info("no editor infer_using_vals")
return self.infer_using_vals(logger, plugin_manager, cleaned_text, text, lang_embs, speaker_embs, pace, None, None, None, None, None, None, pitch_amp=pitch_amp)
def infer_using_vals (self, logger, plugin_manager, cleaned_text, sequence, lang_embs, speaker_embs, pace, dur_pred_existing, pitch_pred_existing, energy_pred_existing, em_pred_existing, old_sequence, new_sequence, pitch_amp=None):
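        # Runs the full synthesis pass. When old_sequence is given, only the changed span of the text
        # is re-generated; pitch/duration/emotion values outside that span are spliced back in from
        # the existing predictions.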
start_index = None
end_index = None
[em_angry_pred_existing, em_happy_pred_existing, em_sad_pred_existing, em_surprise_pred_existing] = em_pred_existing if em_pred_existing is not None else [None, None, None, None]
# Calculate text splicing bounds, if needed
if old_sequence is not None:
old_sequence_np = old_sequence.cpu().detach().numpy()
old_sequence_np = list(old_sequence_np[0])
new_sequence_np = new_sequence.cpu().detach().numpy()
new_sequence_np = list(new_sequence_np[0])
# Get the index of the first changed value
if old_sequence_np[0]==new_sequence_np[0]: # If the start of both sequences is the same, then the change is not at the start
for i in range(len(old_sequence_np)):
if i<len(new_sequence_np):
if old_sequence_np[i]!=new_sequence_np[i]:
start_index = i-1
break
else:
start_index = i-1
break
if start_index is None:
start_index = len(old_sequence_np)-1
# Get the index of the last changed value
old_sequence_np.reverse()
new_sequence_np.reverse()
if old_sequence_np[0]==new_sequence_np[0]: # If the end of both reversed sequences is the same, then the change is not at the end
for i in range(len(old_sequence_np)):
if i<len(new_sequence_np):
if old_sequence_np[i]!=new_sequence_np[i]:
end_index = len(old_sequence_np)-1-i+1
break
else:
end_index = len(old_sequence_np)-1-i+1
break
old_sequence_np.reverse()
new_sequence_np.reverse()
# cleaned_text is the actual text phonemes
input_symbols = sequence
x_lengths = torch.where(input_symbols > 0, torch.ones_like(input_symbols), torch.zeros_like(input_symbols)).sum(dim=1)
lang_emb_full = None # TODO
self.text_encoder.logger = logger
        # TODO: store a bank of the 31 trained language embeddings, to use for interpolating between languages
lang_emb = self.emb_l(lang_embs).unsqueeze(-1)
if len(lang_embs.shape)>1: # Batch mode
lang_emb_full = lang_emb.squeeze(1).squeeze(-1)
else: # Individual line from the UI
lang_emb_full = lang_emb.transpose(2, 1).squeeze(1).unsqueeze(0)
x, x_emb, x_mask = self.text_encoder(input_symbols, x_lengths, lang_emb=None, stats=False, lang_emb_full=lang_emb_full)
m_p, logs_p = self.text_encoder(x, x_lengths, lang_emb=None, lang_emb_full=lang_emb_full, stats=True, x_mask=x_mask)
lang_emb_full = lang_emb_full.reshape(lang_emb_full.shape[0],lang_emb_full.shape[2],lang_emb_full.shape[1])
        self.inference_noise_scale_dp = 0 # TEMP DEBUGGING: higher values seem to make the predicted durations worse, so keep this at 0 for now
# Calculate its own pitch, and duration vals if these were not already provided
if (dur_pred_existing is None or dur_pred_existing.shape[1]==0) or old_sequence is not None:
# Predict durations
self.duration_predictor.logger = logger
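            # The stochastic duration predictor runs in reverse (sampling) mode and returns
            # log-durations per input symbol; exp() converts them to frame counts.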
logw = self.duration_predictor(x, x_mask, g=speaker_embs, reverse=True, noise_scale=self.inference_noise_scale_dp, lang_emb=lang_emb_full)
w = torch.exp(logw) * x_mask * self.length_scale
# w = w * 1.3 # The model seems to generate quite fast speech, so I'm gonna just globally adjust that
w = w * (pace.unsqueeze(2) if torch.is_tensor(pace) else pace)
            w_ceil = torch.ceil(w)
            dur_pred = w_ceil
else:
dur_pred = dur_pred_existing.unsqueeze(dim=0)
dur_pred = dur_pred * pace
y_lengths = torch.clamp_min(torch.sum(torch.round(dur_pred), [1, 2]), 1).long()
y_mask = sequence_mask(y_lengths, None).to(x_mask.dtype)
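        # Expand the per-symbol prior stats (m_p, logs_p) to frame level using a monotonic
        # alignment path built from the predicted durations.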
if dur_pred.shape[0]>1:
attn_all = []
m_p_all = []
logs_p_all = []
for b in range(dur_pred.shape[0]):
attn_mask = torch.unsqueeze(x_mask[b,:].unsqueeze(0), 2) * torch.unsqueeze(y_mask[b,:].unsqueeze(0), -1)
attn_all.append(generate_path(dur_pred.squeeze(1)[b,:].unsqueeze(0), attn_mask.squeeze(0).transpose(1, 2)))
m_p_all.append(torch.matmul(attn_all[-1].transpose(1, 2), m_p[b,:].unsqueeze(0).transpose(1, 2)).transpose(1, 2))
logs_p_all.append(torch.matmul(attn_all[-1].transpose(1, 2), logs_p[b,:].unsqueeze(0).transpose(1, 2)).transpose(1, 2))
del attn_all
m_p = torch.stack(m_p_all, dim=1).squeeze(dim=0)
logs_p = torch.stack(logs_p_all, dim=1).squeeze(dim=0)
pitch_pred = torch.zeros((x.shape[0], x.shape[0], x.shape[2])).to(x)
else:
attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
attn = generate_path(dur_pred.squeeze(1), attn_mask.squeeze(0).transpose(1, 2))
m_p = torch.matmul(attn.transpose(1, 2), m_p.transpose(1, 2)).transpose(1, 2)
logs_p = torch.matmul(attn.transpose(1, 2), logs_p.transpose(1, 2)).transpose(1, 2)
pitch_pred = torch.zeros((x.shape[0], x.shape[0], x.shape[2])).to(x)
emAngry_pred = torch.zeros((x.shape[0], x.shape[0], x.shape[2])).to(x)
emHappy_pred = torch.zeros((x.shape[0], x.shape[0], x.shape[2])).to(x)
emSad_pred = torch.zeros((x.shape[0], x.shape[0], x.shape[2])).to(x)
emSurprise_pred = torch.zeros((x.shape[0], x.shape[0], x.shape[2])).to(x)
# Splice/replace pitch/duration values from the old input if simulating only a partial re-generation
if start_index is not None or end_index is not None:
dur_pred_np = list(dur_pred.cpu().detach().numpy())[0][0]
pitch_pred_np = list(pitch_pred.cpu().detach().numpy())[0][0]
emAngry_pred_np = list(emAngry_pred.cpu().detach().numpy())[0][0]
emHappy_pred_np = list(emHappy_pred.cpu().detach().numpy())[0][0]
emSad_pred_np = list(emSad_pred.cpu().detach().numpy())[0][0]
emSurprise_pred_np = list(emSurprise_pred.cpu().detach().numpy())[0][0]
dur_pred_existing_np = list(dur_pred_existing.cpu().detach().numpy())[0]
pitch_pred_existing_np = list(pitch_pred_existing.cpu().detach().numpy())[0]
emAngry_pred_existing_np = list(em_angry_pred_existing.cpu().detach().numpy())[0]
emHappy_pred_existing_np = list(em_happy_pred_existing.cpu().detach().numpy())[0]
emSad_pred_existing_np = list(em_sad_pred_existing.cpu().detach().numpy())[0]
emSurprise_pred_existing_np = list(em_surprise_pred_existing.cpu().detach().numpy())[0]
if start_index is not None: # Replace starting values
for i in range(start_index+1):
dur_pred_np[i] = dur_pred_existing_np[i]
pitch_pred_np[i] = pitch_pred_existing_np[i]
emAngry_pred_np[i] = emAngry_pred_existing_np[i]
emHappy_pred_np[i] = emHappy_pred_existing_np[i]
emSad_pred_np[i] = emSad_pred_existing_np[i]
emSurprise_pred_np[i] = emSurprise_pred_existing_np[i]
if end_index is not None: # Replace end values
for i in range(len(old_sequence_np)-end_index):
dur_pred_np[-i-1] = dur_pred_existing_np[-i-1]
pitch_pred_np[-i-1] = pitch_pred_existing_np[-i-1]
emAngry_pred_np[-i-1] = emAngry_pred_existing_np[-i-1]
emHappy_pred_np[-i-1] = emHappy_pred_existing_np[-i-1]
emSad_pred_np[-i-1] = emSad_pred_existing_np[-i-1]
emSurprise_pred_np[-i-1] = emSurprise_pred_existing_np[-i-1]
dur_pred = torch.tensor(dur_pred_np).to(self.device).unsqueeze(0)
pitch_pred = torch.tensor(pitch_pred_np).to(self.device).unsqueeze(0).unsqueeze(0)
emAngry_pred = torch.tensor(emAngry_pred_np).to(self.device).unsqueeze(0).unsqueeze(0)
emHappy_pred = torch.tensor(emHappy_pred_np).to(self.device).unsqueeze(0).unsqueeze(0)
emSad_pred = torch.tensor(emSad_pred_np).to(self.device).unsqueeze(0).unsqueeze(0)
emSurprise_pred = torch.tensor(emSurprise_pred_np).to(self.device).unsqueeze(0).unsqueeze(0)
if pitch_amp is not None:
pitch_pred = pitch_pred * pitch_amp.unsqueeze(dim=-1)
if plugin_manager is not None and len(plugin_manager.plugins["synth-line"]["mid"]):
pitch_pred_numpy = pitch_pred.cpu().detach().numpy()
plugin_data = {
"pace": pace,
"duration": dur_pred.cpu().detach().numpy(),
"pitch": pitch_pred_numpy.reshape((pitch_pred_numpy.shape[0],pitch_pred_numpy.shape[2])),
"emAngry": emAngry_pred.reshape((emAngry_pred.shape[0],emAngry_pred.shape[2])),
"emHappy": emHappy_pred.reshape((emHappy_pred.shape[0],emHappy_pred.shape[2])),
"emSad": emSad_pred.reshape((emSad_pred.shape[0],emSad_pred.shape[2])),
"emSurprise": emSurprise_pred.reshape((emSurprise_pred.shape[0],emSurprise_pred.shape[2])),
"sequence": sequence,
"is_fresh_synth": pitch_pred_existing is None and dur_pred_existing is None,
"pluginsContext": plugin_manager.context,
"hasDataChanged": False
}
plugin_manager.run_plugins(plist=plugin_manager.plugins["synth-line"]["mid"], event="mid synth-line", data=plugin_data)
if (
pace != plugin_data["pace"]
or plugin_data["hasDataChanged"]
):
logger.info("Inference data has been changed by plugins, rerunning infer_advanced")
pace = plugin_data["pace"]
editor_data = [
plugin_data["pitch"][0],
plugin_data["duration"][0][0],
[1.0 for _ in range(pitch_pred_numpy.shape[-1])],
plugin_data["emAngry"][0],
plugin_data["emHappy"][0],
plugin_data["emSad"][0],
plugin_data["emSurprise"][0],
None
]
# rerun infer_advanced so that emValues take effect
# second argument ensures no loop
return self.infer_advanced (logger, None, cleaned_text, sequence, lang_embs, speaker_embs, pace=pace, editor_data=editor_data, old_sequence=sequence, pitch_amp=None)
else:
# skip rerunning infer_advanced
logger.info("Inference data unchanged by plugins")
# TODO, incorporate some sort of control for this
# self.inference_noise_scale = 0
# for flow in self.flow.flows:
# flow.logger = logger
# flow.enc.logger = logger
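        # Sample the latent from the prior, invert the flow to map it into posterior space,
        # then vocode it with the HiFi-GAN decoder, conditioned on the speaker embedding.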
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * self.inference_noise_scale
z = self.flow(z_p, y_mask, g=speaker_embs, reverse=True)
self.waveform_decoder.logger = logger
wav = self.waveform_decoder((z * y_mask.unsqueeze(1))[:, :, : self.max_inference_len], g=speaker_embs)
        # In batch mode, trim the shorter audio waves in the batch. The masking doesn't seem to work, so it has to be done manually
if dur_pred.shape[0]>1:
wav_all = []
for b in range(dur_pred.shape[0]):
percent_to_mask = torch.sum(y_mask[b])/y_mask.shape[1]
wav_all.append(wav[b,0,0:int((wav.shape[2]*percent_to_mask).item())])
wav = wav_all
start_index = -1 if start_index is None else start_index
end_index = -1 if end_index is None else end_index
# Apply volume adjustments
stretched_energy_mult = None
if energy_pred_existing is not None and pitch_pred_existing is not None:
energy_mult = self.expand_vals_by_durations(energy_pred_existing.unsqueeze(0), dur_pred, logger=logger)
stretched_energy_mult = torch.nn.functional.interpolate(energy_mult.unsqueeze(0).unsqueeze(0), (1,1,wav.shape[2])).squeeze()
stretched_energy_mult = stretched_energy_mult.cpu().detach().numpy()
energy_pred = energy_pred_existing.squeeze()
else:
energy_pred = [1.0 for _ in range(pitch_pred.shape[-1])]
energy_pred = torch.tensor(energy_pred)
# energy_pred = energy_pred.squeeze()
em_pred_out = [emAngry_pred, emHappy_pred, emSad_pred, emSurprise_pred]
return wav, dur_pred, pitch_pred, energy_pred, em_pred_out, start_index, end_index, stretched_energy_mult
def voice_conversion(self, y, y_lengths=None, spk1_emb=None, spk2_emb=None):
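        # Voice conversion: encode audio y with the source speaker embedding, map it through the flow,
        # invert the flow with the target speaker embedding, and decode back to a waveform.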
if y_lengths is None:
y_lengths = self.y_lengths_default
z, _, _, y_mask = self.posterior_encoder(y, y_lengths, g=spk1_emb)
# z_hat = z
y_mask = y_mask.squeeze(0)
z_p = self.flow(z, y_mask, g=spk1_emb)
z_hat = self.flow(z_p, y_mask, g=spk2_emb, reverse=True)
o_hat = self.waveform_decoder(z_hat * y_mask, g=spk2_emb)
return o_hat
def _set_cond_input (self, aux_input):
"""Set the speaker conditioning input based on the multi-speaker mode."""
sid, g, lid = None, None, None
# if "speaker_ids" in aux_input and aux_input["speaker_ids"] is not None:
# sid = aux_input["speaker_ids"]
# if sid.ndim == 0:
# sid = sid.unsqueeze_(0)
if "d_vectors" in aux_input and aux_input["d_vectors"] is not None:
g = F.normalize(aux_input["d_vectors"]).unsqueeze(-1)
if g.ndim == 2:
g = g.unsqueeze_(0)
if "language_ids" in aux_input and aux_input["language_ids"] is not None:
lid = aux_input["language_ids"]
if lid.ndim == 0:
lid = lid.unsqueeze_(0)
return sid, g, lid
# Opposite of average_pitch; Repeat per-symbol values by durations, to get sequence-wide values
def expand_vals_by_durations (self, vals, durations, logger=None):
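        # e.g. per-symbol vals [2.0, 5.0] with durations [3, 2] expand to [2.0, 2.0, 2.0, 5.0, 5.0]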
vals = vals.view((vals.shape[0], vals.shape[2]))
if len(durations.shape)>2:
durations = durations.view((durations.shape[0], durations.shape[2]))
        max_dur = int(torch.max(torch.sum(torch.round(durations), dim=1)).item())
expanded = torch.zeros((vals.shape[0], 1, max_dur)).to(vals)
for b in range(vals.shape[0]):
b_vals = vals[b]
b_durs = durations[b]
expanded_vals = []
for vi in range(b_vals.shape[0]):
for dur_i in range(round(b_durs[vi].item())):
if len(durations.shape)>2:
expanded_vals.append(b_vals[vi])
else:
expanded_vals.append(b_vals[vi].unsqueeze(dim=0))
expanded_vals = torch.tensor(expanded_vals).to(expanded)
expanded[b,:,0:expanded_vals.shape[0]] += expanded_vals
return expanded
class TextEncoder(nn.Module):
def __init__(
self,
n_vocab: int, # len(ALL_SYMBOLS)
out_channels: int, # 192
hidden_channels: int, # 192
hidden_channels_ffn: int, # 768
num_heads: int, # 2
num_layers: int, # 10
kernel_size: int, # 3
dropout_p: float, # 0.1
language_emb_dim: int = None,
):
"""Text Encoder for VITS model.
Args:
n_vocab (int): Number of characters for the embedding layer.
out_channels (int): Number of channels for the output.
hidden_channels (int): Number of channels for the hidden layers.
hidden_channels_ffn (int): Number of channels for the convolutional layers.
num_heads (int): Number of attention heads for the Transformer layers.
num_layers (int): Number of Transformer layers.
kernel_size (int): Kernel size for the FFN layers in Transformer network.
            dropout_p (float): Dropout rate for the Transformer layers.
            language_emb_dim (int, optional): Dimensionality of the language embedding concatenated to the character embeddings. Defaults to None.
        """
super().__init__()
self.out_channels = out_channels
self.hidden_channels = hidden_channels
self.emb = nn.Embedding(n_vocab, hidden_channels)
nn.init.normal_(self.emb.weight, 0.0, hidden_channels ** -0.5)
if language_emb_dim:
hidden_channels += language_emb_dim
self.encoder = RelativePositionTransformer(
in_channels=hidden_channels,
out_channels=hidden_channels,
hidden_channels=hidden_channels,
hidden_channels_ffn=hidden_channels_ffn,
num_heads=num_heads,
num_layers=num_layers,
kernel_size=kernel_size,
dropout_p=dropout_p,
layer_norm_type="2",
rel_attn_window_size=4,
)
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
def forward(self, x, x_lengths, lang_emb=None, stats=False, x_mask=None, lang_emb_full=None):
"""
Shapes:
- x: :math:`[B, T]`
- x_length: :math:`[B]`
"""
if stats:
stats = self.proj(x) * x_mask
m, logs = torch.split(stats, self.out_channels, dim=1)
return m, logs
else:
x_emb = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h]
# concat the lang emb in embedding chars
if lang_emb is not None or lang_emb_full is not None:
# x = torch.cat((x_emb, lang_emb.transpose(2, 1).expand(x_emb.size(0), x_emb.size(1), -1)), dim=-1)
if lang_emb_full is None:
lang_emb_full = lang_emb.transpose(2, 1).expand(x_emb.size(0), x_emb.size(1), -1)
x = torch.cat((x_emb, lang_emb_full), dim=-1)
x = torch.transpose(x, 1, -1) # [b, h, t]
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
x = self.encoder(x * x_mask, x_mask)
# stats = self.proj(x) * x_mask
# m, logs = torch.split(stats, self.out_channels, dim=1)
return x, x_emb, x_mask
class RelativePositioningPitchEnergyEncoder(nn.Module):
def __init__(
self,
# n_vocab: int, # len(ALL_SYMBOLS)
out_channels: int, # 192
hidden_channels: int, # 192
hidden_channels_ffn: int, # 768
num_heads: int, # 2
num_layers: int, # 10
kernel_size: int, # 3
dropout_p: float, # 0.1
conditioning_emb_dim: int = None,
):
super().__init__()
self.out_channels = out_channels
self.hidden_channels = hidden_channels
# self.emb = nn.Embedding(n_vocab, hidden_channels)
# nn.init.normal_(self.emb.weight, 0.0, hidden_channels ** -0.5)
if conditioning_emb_dim:
hidden_channels += conditioning_emb_dim
self.encoder = RelativePositionTransformer(
in_channels=hidden_channels,
# out_channels=hidden_channels,
out_channels=1,
# out_channels=196,
hidden_channels=hidden_channels,
hidden_channels_ffn=hidden_channels_ffn,
num_heads=num_heads,
num_layers=num_layers,
kernel_size=kernel_size,
dropout_p=dropout_p,
layer_norm_type="2",
rel_attn_window_size=4,
)
# self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
# self.proj = nn.Conv1d(196, out_channels * 2, 1)
def forward(self, x, x_lengths=None, speaker_emb=None, stats=False, x_mask=None):
"""
Shapes:
- x: :math:`[B, T]`
- x_length: :math:`[B]`
"""
# concat the lang emb in embedding chars
if speaker_emb is not None:
x = torch.cat((x, speaker_emb.transpose(2, 1).expand(x.size(0), x.size(1), -1)), dim=-1)
x = torch.transpose(x, 1, -1) # [b, h, t]
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
x = self.encoder(x * x_mask, x_mask)
return x#, x_mask
class ResidualCouplingBlocks(nn.Module):
def __init__(
self,
channels: int,
hidden_channels: int,
kernel_size: int,
dilation_rate: int,
num_layers: int,
num_flows=4,
cond_channels=0,
args=None
):
"""Redisual Coupling blocks for VITS flow layers.
Args:
channels (int): Number of input and output tensor channels.
hidden_channels (int): Number of hidden network channels.
kernel_size (int): Kernel size of the WaveNet layers.
dilation_rate (int): Dilation rate of the WaveNet layers.
num_layers (int): Number of the WaveNet layers.
num_flows (int, optional): Number of Residual Coupling blocks. Defaults to 4.
cond_channels (int, optional): Number of channels of the conditioning tensor. Defaults to 0.
"""
super().__init__()
self.args = args
self.channels = channels
self.hidden_channels = hidden_channels
self.kernel_size = kernel_size
self.dilation_rate = dilation_rate
self.num_layers = num_layers
self.num_flows = num_flows
self.cond_channels = cond_channels
self.flows = nn.ModuleList()
for flow_i in range(num_flows):
self.flows.append(
ResidualCouplingBlock(
(192+self.args.expanded_flow_dim+self.args.expanded_flow_dim) if flow_i==(num_flows-1) and self.args.expanded_flow else channels,
(192+self.args.expanded_flow_dim+self.args.expanded_flow_dim) if flow_i==(num_flows-1) and self.args.expanded_flow else hidden_channels,
kernel_size,
dilation_rate,
num_layers,
cond_channels=cond_channels,
out_channels_override=(192+self.args.expanded_flow_dim+self.args.expanded_flow_dim) if flow_i==(num_flows-1) and self.args.expanded_flow else None,
mean_only=True,
)
)
def forward(self, x, x_mask, g=None, reverse=False):
"""
Shapes:
- x: :math:`[B, C, T]`
- x_mask: :math:`[B, 1, T]`
- g: :math:`[B, C, 1]`
"""
if not reverse:
for fi, flow in enumerate(self.flows):
x, _ = flow(x, x_mask, g=g, reverse=reverse)
x = torch.flip(x, [1])
else:
for flow in reversed(self.flows):
x = torch.flip(x, [1])
x = flow(x, x_mask, g=g, reverse=reverse)
return x
class PosteriorEncoder(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
hidden_channels: int,
kernel_size: int,
dilation_rate: int,
num_layers: int,
cond_channels=0,
):
"""Posterior Encoder of VITS model.
::
x -> conv1x1() -> WaveNet() (non-causal) -> conv1x1() -> split() -> [m, s] -> sample(m, s) -> z
Args:
in_channels (int): Number of input tensor channels.
out_channels (int): Number of output tensor channels.
hidden_channels (int): Number of hidden channels.
kernel_size (int): Kernel size of the WaveNet convolution layers.
dilation_rate (int): Dilation rate of the WaveNet layers.
num_layers (int): Number of the WaveNet layers.
cond_channels (int, optional): Number of conditioning tensor channels. Defaults to 0.
"""
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.hidden_channels = hidden_channels
self.kernel_size = kernel_size
self.dilation_rate = dilation_rate
self.num_layers = num_layers
self.cond_channels = cond_channels
self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
self.enc = WN(
hidden_channels, hidden_channels, kernel_size, dilation_rate, num_layers, c_in_channels=cond_channels
)
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
def forward(self, x, x_lengths, g=None):
"""
Shapes:
- x: :math:`[B, C, T]`
- x_lengths: :math:`[B, 1]`
- g: :math:`[B, C, 1]`
"""
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
x = self.pre(x) * x_mask
x = self.enc(x, x_mask, g=g)
stats = self.proj(x) * x_mask
mean, log_scale = torch.split(stats, self.out_channels, dim=1)
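        # Reparameterisation: z = mean + eps * std, with std = exp(log_scale), masked to the valid frames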
z = (mean + torch.randn_like(mean) * torch.exp(log_scale)) * x_mask
return z, mean, log_scale, x_mask
class ResidualCouplingBlock(nn.Module):
def __init__(
self,
channels,
hidden_channels,
kernel_size,
dilation_rate,
num_layers,
dropout_p=0,
cond_channels=0,
out_channels_override=None,
mean_only=False,
):
assert channels % 2 == 0, "channels should be divisible by 2"
super().__init__()
self.half_channels = channels // 2
self.mean_only = mean_only
# input layer
self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
# coupling layers
self.enc = WN(
hidden_channels,
hidden_channels,
kernel_size,
dilation_rate,
num_layers,
dropout_p=dropout_p,
c_in_channels=cond_channels,
)
# output layer
# Initializing last layer to 0 makes the affine coupling layers
# do nothing at first. This helps with training stability
self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
self.conv1d_projector = None
if out_channels_override:
self.conv1d_projector = nn.Conv1d(192, out_channels_override, 1)
def forward(self, x, x_mask, g=None, reverse=False):
"""
Shapes:
- x: :math:`[B, C, T]`
- x_mask: :math:`[B, 1, T]`
- g: :math:`[B, C, 1]`
"""
if self.conv1d_projector is not None and not reverse:
x = self.conv1d_projector(x)
x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
h = self.pre(x0) * x_mask.unsqueeze(1)
h = self.enc(h, x_mask.unsqueeze(1), g=g)
stats = self.post(h) * x_mask.unsqueeze(1)
if not self.mean_only:
m, log_scale = torch.split(stats, [self.half_channels] * 2, 1)
else:
m = stats
log_scale = torch.zeros_like(m)
if not reverse:
x1 = m + x1 * torch.exp(log_scale) * x_mask.unsqueeze(1)
x = torch.cat([x0, x1], 1)
logdet = torch.sum(log_scale, [1, 2])
return x, logdet
else:
x1 = (x1 - m) * torch.exp(-log_scale) * x_mask.unsqueeze(1)
x = torch.cat([x0, x1], 1)
return x
def mask_from_lens(lens, max_len=None):
if max_len is None:
max_len = lens.max()
ids = torch.arange(0, max_len, device=lens.device, dtype=lens.dtype)
mask = torch.lt(ids, lens.unsqueeze(1))
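    # e.g. lens=[3, 1] with max_len=4 -> [[True, True, True, False], [True, False, False, False]]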
return mask
class TemporalPredictor(nn.Module):
"""Predicts a single float per each temporal location"""
def __init__(self, input_size, filter_size, kernel_size, dropout,
n_layers=2, n_predictions=1):
super(TemporalPredictor, self).__init__()
self.layers = nn.Sequential(*[
ConvReLUNorm(input_size if i == 0 else filter_size, filter_size,
kernel_size=kernel_size, dropout=dropout)
for i in range(n_layers)]
)
self.n_predictions = n_predictions
self.fc = nn.Linear(filter_size, self.n_predictions, bias=True)
def forward(self, enc_out, enc_out_mask):
out = enc_out * enc_out_mask
out = self.layers(out.transpose(1, 2)).transpose(1, 2)
out = self.fc(out) * enc_out_mask
return out
class ConvReLUNorm(torch.nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=1, dropout=0.0):
super(ConvReLUNorm, self).__init__()
self.conv = torch.nn.Conv1d(in_channels, out_channels,
kernel_size=kernel_size,
padding=(kernel_size // 2))
self.norm = torch.nn.LayerNorm(out_channels)
self.dropout = torch.nn.Dropout(dropout)
def forward(self, signal):
out = F.relu(self.conv(signal))
out = self.norm(out.transpose(1, 2)).transpose(1, 2).to(signal.dtype)
return self.dropout(out)