import math
import time
import traceback
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn.modules.conv import Conv1d
from python.xvapitch.glow_tts import RelativePositionTransformer
from python.xvapitch.wavenet import WN
from python.xvapitch.hifigan import HifiganGenerator
from python.xvapitch.sdp import StochasticDurationPredictor  # , StochasticPredictor
from python.xvapitch.util import maximum_path, rand_segments, segment, sequence_mask, generate_path
from python.xvapitch.text import get_text_preprocessor, ALL_SYMBOLS, lang_names
class xVAPitch(nn.Module): | |
def __init__(self, args): | |
super().__init__() | |
self.args = args | |
self.args.init_discriminator = True | |
self.args.speaker_embedding_channels = 512 | |
self.args.use_spectral_norm_disriminator = False | |
self.args.d_vector_dim = 512 | |
self.args.use_language_embedding = True | |
self.args.detach_dp_input = True | |
self.END2END = True | |
self.embedded_language_dim = 12 | |
self.latent_size = 256 | |
num_languages = len(list(lang_names.keys())) | |
self.emb_l = nn.Embedding(num_languages, self.embedded_language_dim) | |
self.length_scale = 1.0 | |
self.noise_scale = 1.0 | |
self.inference_noise_scale = 0.333 | |
self.inference_noise_scale_dp = 0.333 | |
self.noise_scale_dp = 1.0 | |
self.max_inference_len = None | |
self.spec_segment_size = 32 | |
self.text_encoder = TextEncoder( | |
# 165, | |
len(ALL_SYMBOLS), | |
self.latent_size,#192, | |
self.latent_size,#192, | |
768, | |
2, | |
10, | |
3, | |
0.1, | |
# language_emb_dim=4, | |
language_emb_dim=self.embedded_language_dim, | |
) | |
self.posterior_encoder = PosteriorEncoder( | |
513, | |
self.latent_size,#+self.embedded_language_dim if self.args.flc else self.latent_size,#192, | |
self.latent_size,#+self.embedded_language_dim if self.args.flc else self.latent_size,#192, | |
kernel_size=5, | |
dilation_rate=1, | |
num_layers=16, | |
cond_channels=self.args.d_vector_dim, | |
) | |
self.flow = ResidualCouplingBlocks( | |
self.latent_size,#192, | |
self.latent_size,#192, | |
kernel_size=5, | |
dilation_rate=1, | |
num_layers=4, | |
cond_channels=self.args.d_vector_dim, | |
args=self.args | |
) | |
self.duration_predictor = StochasticDurationPredictor( | |
self.latent_size,#192, | |
self.latent_size,#192, | |
3, | |
0.5, | |
4, | |
cond_channels=self.args.d_vector_dim, | |
language_emb_dim=self.embedded_language_dim, | |
) | |
self.waveform_decoder = HifiganGenerator( | |
self.latent_size,#192, | |
1, | |
"1", | |
[[1,3,5],[1,3,5],[1,3,5]], | |
[3,7,11], | |
[16,16,4,4], | |
512, | |
[8,8,2,2], | |
inference_padding=0, | |
# cond_channels=self.args.d_vector_dim+self.embedded_language_dim if self.args.flc else self.args.d_vector_dim, | |
cond_channels=self.args.d_vector_dim, | |
conv_pre_weight_norm=False, | |
conv_post_weight_norm=False, | |
conv_post_bias=False, | |
) | |
self.USE_PITCH_COND = False | |
# self.USE_PITCH_COND = True | |
if self.USE_PITCH_COND: | |
self.pitch_predictor = RelativePositioningPitchEnergyEncoder( | |
# 165, | |
# len(ALL_SYMBOLS), | |
out_channels=1, | |
hidden_channels=self.latent_size+self.embedded_language_dim,#196, | |
hidden_channels_ffn=768, | |
num_heads=2, | |
# num_layers=10, | |
num_layers=3, | |
kernel_size=3, | |
dropout_p=0.1, | |
# language_emb_dim=4, | |
conditioning_emb_dim=self.args.d_vector_dim, | |
) | |
self.pitch_emb = nn.Conv1d( | |
# 1, 384, | |
# 1, 196, | |
1, | |
self.args.expanded_flow_dim if args.expanded_flow else self.latent_size, | |
# pitch_conditioning_formants, symbols_embedding_dim, | |
kernel_size=3, | |
padding=int((3 - 1) / 2)) | |
self.TEMP_timing = [] | |
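# Overview of the components wired up above (a VITS-style, end-to-end TTS pipeline): the text encoder
# turns phoneme symbols plus a language embedding into a prior distribution over latent frames, the
# posterior encoder produces latents from linear spectrograms during training, the residual coupling
# flow maps between the two spaces, the stochastic duration predictor aligns symbols to frames, and
# the HiFi-GAN generator decodes latent frames to a waveform.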
def infer_get_lang_emb (self, language_id): | |
aux_input = { | |
# "d_vectors": embedding.unsqueeze(dim=0), | |
"language_ids": language_id | |
} | |
sid, g, lid = self._set_cond_input(aux_input) | |
lang_emb = self.emb_l(lid).unsqueeze(-1) | |
return lang_emb | |
def infer_advanced (self, logger, plugin_manager, cleaned_text, text, lang_embs, speaker_embs, pace=1.0, editor_data=None, old_sequence=None, pitch_amp=None): | |
if (editor_data is not None) and ((editor_data[0] is not None and len(editor_data[0])) or (editor_data[1] is not None and len(editor_data[1]))): | |
pitch_pred, dur_pred, energy_pred, em_angry_pred, em_happy_pred, em_sad_pred, em_surprise_pred, _ = editor_data | |
# TODO, use energy_pred | |
dur_pred = torch.tensor(dur_pred) | |
dur_pred = dur_pred.view((1, dur_pred.shape[0])).float().to(self.device) | |
pitch_pred = torch.tensor(pitch_pred) | |
pitch_pred = pitch_pred.view((1, pitch_pred.shape[0])).float().to(self.device) | |
energy_pred = torch.tensor(energy_pred) | |
energy_pred = energy_pred.view((1, energy_pred.shape[0])).float().to(self.device) | |
em_angry_pred = em_angry_pred.clone().detach() if torch.is_tensor(em_angry_pred) else torch.tensor(em_angry_pred)
em_angry_pred = em_angry_pred.view((1, em_angry_pred.shape[0])).float().to(self.device)
em_happy_pred = em_happy_pred.clone().detach() if torch.is_tensor(em_happy_pred) else torch.tensor(em_happy_pred)
em_happy_pred = em_happy_pred.view((1, em_happy_pred.shape[0])).float().to(self.device)
em_sad_pred = em_sad_pred.clone().detach() if torch.is_tensor(em_sad_pred) else torch.tensor(em_sad_pred)
em_sad_pred = em_sad_pred.view((1, em_sad_pred.shape[0])).float().to(self.device)
em_surprise_pred = em_surprise_pred.clone().detach() if torch.is_tensor(em_surprise_pred) else torch.tensor(em_surprise_pred)
em_surprise_pred = em_surprise_pred.view((1, em_surprise_pred.shape[0])).float().to(self.device)
# Pitch speaker embedding deltas | |
if not self.USE_PITCH_COND and pitch_pred.shape[1]==speaker_embs.shape[2]: | |
pitch_delta = self.pitch_emb_values.to(pitch_pred.device) * pitch_pred | |
speaker_embs = speaker_embs + pitch_delta.float() | |
# Emotion speaker embedding deltas | |
emotions_strength = 0.00003 # Global scaling | |
if em_angry_pred.shape[1]==speaker_embs.shape[2]: | |
em_angry_delta = self.angry_emb_values.to(em_angry_pred.device) * em_angry_pred * emotions_strength | |
speaker_embs = speaker_embs + em_angry_delta.float() | |
if em_happy_pred.shape[1]==speaker_embs.shape[2]: | |
em_happy_delta = self.happy_emb_values.to(em_happy_pred.device) * em_happy_pred * emotions_strength | |
speaker_embs = speaker_embs + em_happy_delta.float() | |
if em_sad_pred.shape[1]==speaker_embs.shape[2]: | |
em_sad_delta = self.sad_emb_values.to(em_sad_pred.device) * em_sad_pred * emotions_strength | |
speaker_embs = speaker_embs + em_sad_delta.float() | |
if em_surprise_pred.shape[1]==speaker_embs.shape[2]: | |
em_surprise_delta = self.surprise_emb_values.to(em_surprise_pred.device) * em_surprise_pred * emotions_strength | |
speaker_embs = speaker_embs + em_surprise_delta.float() | |
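# Note on the delta logic above: the per-symbol editor curves are applied as linear offsets to the
# 512-dim speaker embedding. self.pitch_emb_values, self.angry_emb_values, self.happy_emb_values,
# self.sad_emb_values and self.surprise_emb_values are assumed to be precomputed delta vectors loaded
# elsewhere (they are not defined in __init__ in this file). Conceptually, for a slider value s and a
# delta vector d, the update is speaker_embs = speaker_embs + d * s * emotions_strength, so a value of
# 0 leaves the voice unchanged and larger values push the embedding along that pitch/emotion direction.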
try: | |
logger.info("editor data infer_using_vals") | |
wav, dur_pred, pitch_pred_out, energy_pred, em_pred_out, start_index, end_index, wav_mult = self.infer_using_vals(logger, plugin_manager, cleaned_text, text, lang_embs, \ | |
speaker_embs, pace, dur_pred_existing=dur_pred, pitch_pred_existing=pitch_pred, energy_pred_existing=energy_pred, em_pred_existing=[em_angry_pred, em_happy_pred, em_sad_pred, em_surprise_pred], old_sequence=old_sequence, new_sequence=text, pitch_amp=pitch_amp) | |
# Return the editor-provided pitch/emotion curves unchanged (rather than the values unpacked from em_pred_out), so the UI state is preserved
pitch_pred_out = pitch_pred
em_angry_pred_out = em_angry_pred
em_happy_pred_out = em_happy_pred
em_sad_pred_out = em_sad_pred
em_surprise_pred_out = em_surprise_pred
return wav, dur_pred, pitch_pred_out, energy_pred, [em_angry_pred_out, em_happy_pred_out, em_sad_pred_out, em_surprise_pred_out], start_index, end_index, wav_mult | |
except: | |
print(traceback.format_exc()) | |
logger.info(traceback.format_exc()) | |
# return traceback.format_exc() | |
logger.info("editor data corrupt; fallback to infer_using_vals") | |
return self.infer_using_vals(logger, plugin_manager, cleaned_text, text, lang_embs, speaker_embs, pace, None, None, None, None, None, None, pitch_amp=pitch_amp) | |
else: | |
logger.info("no editor infer_using_vals") | |
return self.infer_using_vals(logger, plugin_manager, cleaned_text, text, lang_embs, speaker_embs, pace, None, None, None, None, None, None, pitch_amp=pitch_amp) | |
def infer_using_vals (self, logger, plugin_manager, cleaned_text, sequence, lang_embs, speaker_embs, pace, dur_pred_existing, pitch_pred_existing, energy_pred_existing, em_pred_existing, old_sequence, new_sequence, pitch_amp=None): | |
start_index = None | |
end_index = None | |
[em_angry_pred_existing, em_happy_pred_existing, em_sad_pred_existing, em_surprise_pred_existing] = em_pred_existing if em_pred_existing is not None else [None, None, None, None] | |
# Calculate text splicing bounds, if needed | |
if old_sequence is not None: | |
old_sequence_np = old_sequence.cpu().detach().numpy() | |
old_sequence_np = list(old_sequence_np[0]) | |
new_sequence_np = new_sequence.cpu().detach().numpy() | |
new_sequence_np = list(new_sequence_np[0]) | |
# Get the index of the first changed value | |
if old_sequence_np[0]==new_sequence_np[0]: # If the start of both sequences is the same, then the change is not at the start | |
for i in range(len(old_sequence_np)): | |
if i<len(new_sequence_np): | |
if old_sequence_np[i]!=new_sequence_np[i]: | |
start_index = i-1 | |
break | |
else: | |
start_index = i-1 | |
break | |
if start_index is None: | |
start_index = len(old_sequence_np)-1 | |
# Get the index of the last changed value | |
old_sequence_np.reverse() | |
new_sequence_np.reverse() | |
if old_sequence_np[0]==new_sequence_np[0]: # If the end of both reversed sequences is the same, then the change is not at the end | |
for i in range(len(old_sequence_np)): | |
if i<len(new_sequence_np): | |
if old_sequence_np[i]!=new_sequence_np[i]: | |
end_index = len(old_sequence_np)-1-i+1 | |
break | |
else: | |
end_index = len(old_sequence_np)-1-i+1 | |
break | |
old_sequence_np.reverse() | |
new_sequence_np.reverse() | |
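# Worked example of the splicing bounds above (illustrative): with
#   old_sequence = [A, B, C, D, E]   and   new_sequence = [A, B, X, D, E]
# the first mismatch from the front is at i=2, so start_index = 1 (the last unchanged symbol); scanning
# the reversed sequences the first mismatch is also at i=2, so end_index = len(old) - 1 - 2 + 1 = 3.
# Symbols outside [start_index, end_index] can then re-use the previously generated duration/pitch values.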
# cleaned_text is the actual text phonemes | |
input_symbols = sequence | |
x_lengths = torch.where(input_symbols > 0, torch.ones_like(input_symbols), torch.zeros_like(input_symbols)).sum(dim=1) | |
lang_emb_full = None # TODO | |
self.text_encoder.logger = logger | |
# TODO: store a bank of the 31 trained language embeddings, to use for interpolation
lang_emb = self.emb_l(lang_embs).unsqueeze(-1) | |
if len(lang_embs.shape)>1: # Batch mode | |
lang_emb_full = lang_emb.squeeze(1).squeeze(-1) | |
else: # Individual line from the UI | |
lang_emb_full = lang_emb.transpose(2, 1).squeeze(1).unsqueeze(0) | |
x, x_emb, x_mask = self.text_encoder(input_symbols, x_lengths, lang_emb=None, stats=False, lang_emb_full=lang_emb_full) | |
m_p, logs_p = self.text_encoder(x, x_lengths, lang_emb=None, lang_emb_full=lang_emb_full, stats=True, x_mask=x_mask) | |
lang_emb_full = lang_emb_full.reshape(lang_emb_full.shape[0],lang_emb_full.shape[2],lang_emb_full.shape[1]) | |
self.inference_noise_scale_dp = 0 # Duration predictor noise is disabled at inference; higher values seem to degrade the output
# Calculate duration (and pitch) values here, if these were not already provided
if (dur_pred_existing is None or dur_pred_existing.shape[1]==0) or old_sequence is not None: | |
# Predict durations | |
self.duration_predictor.logger = logger | |
logw = self.duration_predictor(x, x_mask, g=speaker_embs, reverse=True, noise_scale=self.inference_noise_scale_dp, lang_emb=lang_emb_full) | |
w = torch.exp(logw) * x_mask * self.length_scale | |
# w = w * 1.3 # The model seems to generate quite fast speech, so I'm gonna just globally adjust that | |
w = w * (pace.unsqueeze(2) if torch.is_tensor(pace) else pace) | |
w_ceil = torch.ceil(w)
dur_pred = w_ceil
else: | |
dur_pred = dur_pred_existing.unsqueeze(dim=0) | |
dur_pred = dur_pred * pace | |
y_lengths = torch.clamp_min(torch.sum(torch.round(dur_pred), [1, 2]), 1).long() | |
y_mask = sequence_mask(y_lengths, None).to(x_mask.dtype) | |
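# Note on the values above (descriptive): logw is a per-symbol log-duration, so
# w = exp(logw) * x_mask * length_scale is a duration in frames for every input symbol. It is
# multiplied by the per-call pace value and rounded up to give dur_pred; y_lengths is then the total
# frame count per batch item, which sizes the frame-level mask y_mask used by the flow and the
# vocoder further below.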
if dur_pred.shape[0]>1: | |
attn_all = [] | |
m_p_all = [] | |
logs_p_all = [] | |
for b in range(dur_pred.shape[0]): | |
attn_mask = torch.unsqueeze(x_mask[b,:].unsqueeze(0), 2) * torch.unsqueeze(y_mask[b,:].unsqueeze(0), -1) | |
attn_all.append(generate_path(dur_pred.squeeze(1)[b,:].unsqueeze(0), attn_mask.squeeze(0).transpose(1, 2))) | |
m_p_all.append(torch.matmul(attn_all[-1].transpose(1, 2), m_p[b,:].unsqueeze(0).transpose(1, 2)).transpose(1, 2)) | |
logs_p_all.append(torch.matmul(attn_all[-1].transpose(1, 2), logs_p[b,:].unsqueeze(0).transpose(1, 2)).transpose(1, 2)) | |
del attn_all | |
m_p = torch.stack(m_p_all, dim=1).squeeze(dim=0) | |
logs_p = torch.stack(logs_p_all, dim=1).squeeze(dim=0) | |
pitch_pred = torch.zeros((x.shape[0], x.shape[0], x.shape[2])).to(x) | |
else: | |
attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) | |
attn = generate_path(dur_pred.squeeze(1), attn_mask.squeeze(0).transpose(1, 2)) | |
m_p = torch.matmul(attn.transpose(1, 2), m_p.transpose(1, 2)).transpose(1, 2) | |
logs_p = torch.matmul(attn.transpose(1, 2), logs_p.transpose(1, 2)).transpose(1, 2) | |
pitch_pred = torch.zeros((x.shape[0], x.shape[0], x.shape[2])).to(x) | |
emAngry_pred = torch.zeros((x.shape[0], x.shape[0], x.shape[2])).to(x) | |
emHappy_pred = torch.zeros((x.shape[0], x.shape[0], x.shape[2])).to(x) | |
emSad_pred = torch.zeros((x.shape[0], x.shape[0], x.shape[2])).to(x) | |
emSurprise_pred = torch.zeros((x.shape[0], x.shape[0], x.shape[2])).to(x) | |
# Splice/replace pitch/duration values from the old input if simulating only a partial re-generation | |
if start_index is not None or end_index is not None: | |
dur_pred_np = list(dur_pred.cpu().detach().numpy())[0][0] | |
pitch_pred_np = list(pitch_pred.cpu().detach().numpy())[0][0] | |
emAngry_pred_np = list(emAngry_pred.cpu().detach().numpy())[0][0] | |
emHappy_pred_np = list(emHappy_pred.cpu().detach().numpy())[0][0] | |
emSad_pred_np = list(emSad_pred.cpu().detach().numpy())[0][0] | |
emSurprise_pred_np = list(emSurprise_pred.cpu().detach().numpy())[0][0] | |
dur_pred_existing_np = list(dur_pred_existing.cpu().detach().numpy())[0] | |
pitch_pred_existing_np = list(pitch_pred_existing.cpu().detach().numpy())[0] | |
emAngry_pred_existing_np = list(em_angry_pred_existing.cpu().detach().numpy())[0] | |
emHappy_pred_existing_np = list(em_happy_pred_existing.cpu().detach().numpy())[0] | |
emSad_pred_existing_np = list(em_sad_pred_existing.cpu().detach().numpy())[0] | |
emSurprise_pred_existing_np = list(em_surprise_pred_existing.cpu().detach().numpy())[0] | |
if start_index is not None: # Replace starting values | |
for i in range(start_index+1): | |
dur_pred_np[i] = dur_pred_existing_np[i] | |
pitch_pred_np[i] = pitch_pred_existing_np[i] | |
emAngry_pred_np[i] = emAngry_pred_existing_np[i] | |
emHappy_pred_np[i] = emHappy_pred_existing_np[i] | |
emSad_pred_np[i] = emSad_pred_existing_np[i] | |
emSurprise_pred_np[i] = emSurprise_pred_existing_np[i] | |
if end_index is not None: # Replace end values | |
for i in range(len(old_sequence_np)-end_index): | |
dur_pred_np[-i-1] = dur_pred_existing_np[-i-1] | |
pitch_pred_np[-i-1] = pitch_pred_existing_np[-i-1] | |
emAngry_pred_np[-i-1] = emAngry_pred_existing_np[-i-1] | |
emHappy_pred_np[-i-1] = emHappy_pred_existing_np[-i-1] | |
emSad_pred_np[-i-1] = emSad_pred_existing_np[-i-1] | |
emSurprise_pred_np[-i-1] = emSurprise_pred_existing_np[-i-1] | |
dur_pred = torch.tensor(dur_pred_np).to(self.device).unsqueeze(0) | |
pitch_pred = torch.tensor(pitch_pred_np).to(self.device).unsqueeze(0).unsqueeze(0) | |
emAngry_pred = torch.tensor(emAngry_pred_np).to(self.device).unsqueeze(0).unsqueeze(0) | |
emHappy_pred = torch.tensor(emHappy_pred_np).to(self.device).unsqueeze(0).unsqueeze(0) | |
emSad_pred = torch.tensor(emSad_pred_np).to(self.device).unsqueeze(0).unsqueeze(0) | |
emSurprise_pred = torch.tensor(emSurprise_pred_np).to(self.device).unsqueeze(0).unsqueeze(0) | |
if pitch_amp is not None: | |
pitch_pred = pitch_pred * pitch_amp.unsqueeze(dim=-1) | |
if plugin_manager is not None and len(plugin_manager.plugins["synth-line"]["mid"]): | |
pitch_pred_numpy = pitch_pred.cpu().detach().numpy() | |
plugin_data = { | |
"pace": pace, | |
"duration": dur_pred.cpu().detach().numpy(), | |
"pitch": pitch_pred_numpy.reshape((pitch_pred_numpy.shape[0],pitch_pred_numpy.shape[2])), | |
"emAngry": emAngry_pred.reshape((emAngry_pred.shape[0],emAngry_pred.shape[2])), | |
"emHappy": emHappy_pred.reshape((emHappy_pred.shape[0],emHappy_pred.shape[2])), | |
"emSad": emSad_pred.reshape((emSad_pred.shape[0],emSad_pred.shape[2])), | |
"emSurprise": emSurprise_pred.reshape((emSurprise_pred.shape[0],emSurprise_pred.shape[2])), | |
"sequence": sequence, | |
"is_fresh_synth": pitch_pred_existing is None and dur_pred_existing is None, | |
"pluginsContext": plugin_manager.context, | |
"hasDataChanged": False | |
} | |
plugin_manager.run_plugins(plist=plugin_manager.plugins["synth-line"]["mid"], event="mid synth-line", data=plugin_data) | |
if ( | |
pace != plugin_data["pace"] | |
or plugin_data["hasDataChanged"] | |
): | |
logger.info("Inference data has been changed by plugins, rerunning infer_advanced") | |
pace = plugin_data["pace"] | |
editor_data = [ | |
plugin_data["pitch"][0], | |
plugin_data["duration"][0][0], | |
[1.0 for _ in range(pitch_pred_numpy.shape[-1])], | |
plugin_data["emAngry"][0], | |
plugin_data["emHappy"][0], | |
plugin_data["emSad"][0], | |
plugin_data["emSurprise"][0], | |
None | |
] | |
# rerun infer_advanced so that emValues take effect | |
# second argument ensures no loop | |
return self.infer_advanced (logger, None, cleaned_text, sequence, lang_embs, speaker_embs, pace=pace, editor_data=editor_data, old_sequence=sequence, pitch_amp=None) | |
else: | |
# skip rerunning infer_advanced | |
logger.info("Inference data unchanged by plugins") | |
# TODO, incorporate some sort of control for this | |
# self.inference_noise_scale = 0 | |
# for flow in self.flow.flows: | |
# flow.logger = logger | |
# flow.enc.logger = logger | |
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * self.inference_noise_scale | |
z = self.flow(z_p, y_mask, g=speaker_embs, reverse=True) | |
self.waveform_decoder.logger = logger | |
wav = self.waveform_decoder((z * y_mask.unsqueeze(1))[:, :, : self.max_inference_len], g=speaker_embs) | |
# In batch mode, trim the shorter audio waves in the batch. The masking doesn't seem to work, so have to do it manually | |
if dur_pred.shape[0]>1: | |
wav_all = [] | |
for b in range(dur_pred.shape[0]): | |
percent_to_mask = torch.sum(y_mask[b])/y_mask.shape[1] | |
wav_all.append(wav[b,0,0:int((wav.shape[2]*percent_to_mask).item())]) | |
wav = wav_all | |
start_index = -1 if start_index is None else start_index | |
end_index = -1 if end_index is None else end_index | |
# Apply volume adjustments | |
stretched_energy_mult = None | |
if energy_pred_existing is not None and pitch_pred_existing is not None: | |
energy_mult = self.expand_vals_by_durations(energy_pred_existing.unsqueeze(0), dur_pred, logger=logger) | |
stretched_energy_mult = torch.nn.functional.interpolate(energy_mult.unsqueeze(0).unsqueeze(0), (1,1,wav.shape[2])).squeeze() | |
stretched_energy_mult = stretched_energy_mult.cpu().detach().numpy() | |
energy_pred = energy_pred_existing.squeeze() | |
else: | |
energy_pred = [1.0 for _ in range(pitch_pred.shape[-1])] | |
energy_pred = torch.tensor(energy_pred) | |
# energy_pred = energy_pred.squeeze() | |
em_pred_out = [emAngry_pred, emHappy_pred, emSad_pred, emSurprise_pred] | |
return wav, dur_pred, pitch_pred, energy_pred, em_pred_out, start_index, end_index, stretched_energy_mult | |
def voice_conversion(self, y, y_lengths=None, spk1_emb=None, spk2_emb=None): | |
if y_lengths is None: | |
y_lengths = self.y_lengths_default | |
z, _, _, y_mask = self.posterior_encoder(y, y_lengths, g=spk1_emb) | |
# z_hat = z | |
y_mask = y_mask.squeeze(0) | |
z_p = self.flow(z, y_mask, g=spk1_emb) | |
z_hat = self.flow(z_p, y_mask, g=spk2_emb, reverse=True) | |
o_hat = self.waveform_decoder(z_hat * y_mask, g=spk2_emb) | |
return o_hat | |
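# Voice conversion above works entirely in latent space: the posterior encoder maps the source audio to
# z conditioned on the source speaker embedding, the flow maps z into the (approximately) speaker-independent
# prior space, and running the flow in reverse with the target speaker embedding gives a latent that the
# HiFi-GAN decoder renders in the target voice. A minimal, hypothetical call (spec being a [1, 513, T]
# linear spectrogram and src_emb/tgt_emb [1, 512, 1] speaker embeddings):
#   wav = model.voice_conversion(spec, y_lengths=torch.tensor([spec.shape[2]]), spk1_emb=src_emb, spk2_emb=tgt_emb)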
def _set_cond_input (self, aux_input): | |
"""Set the speaker conditioning input based on the multi-speaker mode.""" | |
sid, g, lid = None, None, None | |
# if "speaker_ids" in aux_input and aux_input["speaker_ids"] is not None: | |
# sid = aux_input["speaker_ids"] | |
# if sid.ndim == 0: | |
# sid = sid.unsqueeze_(0) | |
if "d_vectors" in aux_input and aux_input["d_vectors"] is not None: | |
g = F.normalize(aux_input["d_vectors"]).unsqueeze(-1) | |
if g.ndim == 2: | |
g = g.unsqueeze_(0) | |
if "language_ids" in aux_input and aux_input["language_ids"] is not None: | |
lid = aux_input["language_ids"] | |
if lid.ndim == 0: | |
lid = lid.unsqueeze_(0) | |
return sid, g, lid | |
# Opposite of average_pitch; Repeat per-symbol values by durations, to get sequence-wide values | |
def expand_vals_by_durations (self, vals, durations, logger=None): | |
vals = vals.view((vals.shape[0], vals.shape[2])) | |
if len(durations.shape)>2: | |
durations = durations.view((durations.shape[0], durations.shape[2])) | |
max_dur = int(torch.max(torch.sum(torch.round(durations), dim=1)).item())
expanded = torch.zeros((vals.shape[0], 1, max_dur)).to(vals) | |
for b in range(vals.shape[0]): | |
b_vals = vals[b] | |
b_durs = durations[b] | |
expanded_vals = [] | |
for vi in range(b_vals.shape[0]): | |
for dur_i in range(round(b_durs[vi].item())): | |
if len(durations.shape)>2: | |
expanded_vals.append(b_vals[vi]) | |
else: | |
expanded_vals.append(b_vals[vi].unsqueeze(dim=0)) | |
expanded_vals = torch.cat(expanded_vals).to(expanded) # concatenate the per-frame values into a single [num_frames] tensor
expanded[b,:,0:expanded_vals.shape[0]] += expanded_vals | |
return expanded | |
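# Minimal stand-alone sketch of the length-regulation idea implemented by expand_vals_by_durations above
# (illustrative only, plain PyTorch; the helper name is hypothetical and is not used by the model): each
# per-symbol value is repeated for the number of frames that symbol is predicted to last.
def _example_expand_by_durations(vals, durations):
    # vals: [T_symbols] float values, durations: [T_symbols] integer frame counts
    return torch.repeat_interleave(vals, durations)
# _example_expand_by_durations(torch.tensor([0.2, 0.7]), torch.tensor([3, 2]))
# -> tensor([0.2000, 0.2000, 0.2000, 0.7000, 0.7000])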
class TextEncoder(nn.Module): | |
def __init__( | |
self, | |
n_vocab: int, # len(ALL_SYMBOLS) | |
out_channels: int, # 192 | |
hidden_channels: int, # 192 | |
hidden_channels_ffn: int, # 768 | |
num_heads: int, # 2 | |
num_layers: int, # 10 | |
kernel_size: int, # 3 | |
dropout_p: float, # 0.1 | |
language_emb_dim: int = None, | |
): | |
"""Text Encoder for VITS model. | |
Args: | |
n_vocab (int): Number of characters for the embedding layer. | |
out_channels (int): Number of channels for the output. | |
hidden_channels (int): Number of channels for the hidden layers. | |
hidden_channels_ffn (int): Number of channels for the convolutional layers. | |
num_heads (int): Number of attention heads for the Transformer layers. | |
num_layers (int): Number of Transformer layers. | |
kernel_size (int): Kernel size for the FFN layers in Transformer network. | |
dropout_p (float): Dropout rate for the Transformer layers.
language_emb_dim (int): Number of language embedding channels concatenated to the character embeddings.
"""
super().__init__() | |
self.out_channels = out_channels | |
self.hidden_channels = hidden_channels | |
self.emb = nn.Embedding(n_vocab, hidden_channels) | |
nn.init.normal_(self.emb.weight, 0.0, hidden_channels ** -0.5) | |
if language_emb_dim: | |
hidden_channels += language_emb_dim | |
self.encoder = RelativePositionTransformer( | |
in_channels=hidden_channels, | |
out_channels=hidden_channels, | |
hidden_channels=hidden_channels, | |
hidden_channels_ffn=hidden_channels_ffn, | |
num_heads=num_heads, | |
num_layers=num_layers, | |
kernel_size=kernel_size, | |
dropout_p=dropout_p, | |
layer_norm_type="2", | |
rel_attn_window_size=4, | |
) | |
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) | |
def forward(self, x, x_lengths, lang_emb=None, stats=False, x_mask=None, lang_emb_full=None): | |
""" | |
Shapes: | |
- x: :math:`[B, T]` | |
- x_length: :math:`[B]` | |
""" | |
if stats: | |
stats = self.proj(x) * x_mask | |
m, logs = torch.split(stats, self.out_channels, dim=1) | |
return m, logs | |
else: | |
x_emb = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h] | |
# concatenate the language embedding to the character embeddings
if lang_emb is not None or lang_emb_full is not None: | |
# x = torch.cat((x_emb, lang_emb.transpose(2, 1).expand(x_emb.size(0), x_emb.size(1), -1)), dim=-1) | |
if lang_emb_full is None: | |
lang_emb_full = lang_emb.transpose(2, 1).expand(x_emb.size(0), x_emb.size(1), -1) | |
x = torch.cat((x_emb, lang_emb_full), dim=-1) | |
x = torch.transpose(x, 1, -1) # [b, h, t] | |
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) | |
x = self.encoder(x * x_mask, x_mask) | |
# stats = self.proj(x) * x_mask | |
# m, logs = torch.split(stats, self.out_channels, dim=1) | |
return x, x_emb, x_mask | |
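# Hypothetical usage sketch for TextEncoder (illustrative only; it mirrors how xVAPitch.__init__
# constructs the encoder above and is not called anywhere in this file):
def _example_text_encoder_shapes():
    enc = TextEncoder(len(ALL_SYMBOLS), 256, 256, 768, 2, 10, 3, 0.1, language_emb_dim=12)
    tokens = torch.randint(1, len(ALL_SYMBOLS), (1, 20))     # [B, T] symbol ids
    lengths = torch.tensor([20])                              # [B]
    lang_emb_full = torch.zeros((1, 20, 12))                  # per-symbol language embedding
    x, x_emb, x_mask = enc(tokens, lengths, lang_emb_full=lang_emb_full)               # x: [B, 256+12, T]
    m, logs = enc(x, lengths, lang_emb_full=lang_emb_full, stats=True, x_mask=x_mask)  # each: [B, 256, T]
    return m.shape, logs.shape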
class RelativePositioningPitchEnergyEncoder(nn.Module): | |
def __init__( | |
self, | |
# n_vocab: int, # len(ALL_SYMBOLS) | |
out_channels: int, # 192 | |
hidden_channels: int, # 192 | |
hidden_channels_ffn: int, # 768 | |
num_heads: int, # 2 | |
num_layers: int, # 10 | |
kernel_size: int, # 3 | |
dropout_p: float, # 0.1 | |
conditioning_emb_dim: int = None, | |
): | |
super().__init__() | |
self.out_channels = out_channels | |
self.hidden_channels = hidden_channels | |
# self.emb = nn.Embedding(n_vocab, hidden_channels) | |
# nn.init.normal_(self.emb.weight, 0.0, hidden_channels ** -0.5) | |
if conditioning_emb_dim: | |
hidden_channels += conditioning_emb_dim | |
self.encoder = RelativePositionTransformer( | |
in_channels=hidden_channels, | |
# out_channels=hidden_channels, | |
out_channels=1, | |
# out_channels=196, | |
hidden_channels=hidden_channels, | |
hidden_channels_ffn=hidden_channels_ffn, | |
num_heads=num_heads, | |
num_layers=num_layers, | |
kernel_size=kernel_size, | |
dropout_p=dropout_p, | |
layer_norm_type="2", | |
rel_attn_window_size=4, | |
) | |
# self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) | |
# self.proj = nn.Conv1d(196, out_channels * 2, 1) | |
def forward(self, x, x_lengths=None, speaker_emb=None, stats=False, x_mask=None): | |
""" | |
Shapes: | |
- x: :math:`[B, T]` | |
- x_length: :math:`[B]` | |
""" | |
# concatenate the conditioning (speaker) embedding to the inputs
if speaker_emb is not None: | |
x = torch.cat((x, speaker_emb.transpose(2, 1).expand(x.size(0), x.size(1), -1)), dim=-1) | |
x = torch.transpose(x, 1, -1) # [b, h, t] | |
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) | |
x = self.encoder(x * x_mask, x_mask) | |
return x#, x_mask | |
class ResidualCouplingBlocks(nn.Module): | |
def __init__( | |
self, | |
channels: int, | |
hidden_channels: int, | |
kernel_size: int, | |
dilation_rate: int, | |
num_layers: int, | |
num_flows=4, | |
cond_channels=0, | |
args=None | |
): | |
"""Redisual Coupling blocks for VITS flow layers. | |
Args: | |
channels (int): Number of input and output tensor channels. | |
hidden_channels (int): Number of hidden network channels. | |
kernel_size (int): Kernel size of the WaveNet layers. | |
dilation_rate (int): Dilation rate of the WaveNet layers. | |
num_layers (int): Number of the WaveNet layers. | |
num_flows (int, optional): Number of Residual Coupling blocks. Defaults to 4. | |
cond_channels (int, optional): Number of channels of the conditioning tensor. Defaults to 0. | |
""" | |
super().__init__() | |
self.args = args | |
self.channels = channels | |
self.hidden_channels = hidden_channels | |
self.kernel_size = kernel_size | |
self.dilation_rate = dilation_rate | |
self.num_layers = num_layers | |
self.num_flows = num_flows | |
self.cond_channels = cond_channels | |
self.flows = nn.ModuleList() | |
for flow_i in range(num_flows): | |
self.flows.append( | |
ResidualCouplingBlock( | |
(192+self.args.expanded_flow_dim+self.args.expanded_flow_dim) if flow_i==(num_flows-1) and self.args.expanded_flow else channels, | |
(192+self.args.expanded_flow_dim+self.args.expanded_flow_dim) if flow_i==(num_flows-1) and self.args.expanded_flow else hidden_channels, | |
kernel_size, | |
dilation_rate, | |
num_layers, | |
cond_channels=cond_channels, | |
out_channels_override=(192+self.args.expanded_flow_dim+self.args.expanded_flow_dim) if flow_i==(num_flows-1) and self.args.expanded_flow else None, | |
mean_only=True, | |
) | |
) | |
def forward(self, x, x_mask, g=None, reverse=False): | |
""" | |
Shapes: | |
- x: :math:`[B, C, T]` | |
- x_mask: :math:`[B, 1, T]` | |
- g: :math:`[B, C, 1]` | |
""" | |
if not reverse: | |
for fi, flow in enumerate(self.flows): | |
x, _ = flow(x, x_mask, g=g, reverse=reverse) | |
x = torch.flip(x, [1]) | |
else: | |
for flow in reversed(self.flows): | |
x = torch.flip(x, [1]) | |
x = flow(x, x_mask, g=g, reverse=reverse) | |
return x | |
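# Note on the direction handling above: in the forward (training) direction each coupling block is applied
# and the channel order is then flipped, so the reverse (inference) direction must flip first and then
# invert the block, which is what the loop over reversed(self.flows) does. Schematically, with generic
# invertible steps f1, f2 (hypothetical placeholders):
#   forward:  x -> f1 -> flip -> f2 -> flip
#   reverse:  y -> flip -> f2^-1 -> flip -> f1^-1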
class PosteriorEncoder(nn.Module): | |
def __init__( | |
self, | |
in_channels: int, | |
out_channels: int, | |
hidden_channels: int, | |
kernel_size: int, | |
dilation_rate: int, | |
num_layers: int, | |
cond_channels=0, | |
): | |
"""Posterior Encoder of VITS model. | |
:: | |
x -> conv1x1() -> WaveNet() (non-causal) -> conv1x1() -> split() -> [m, s] -> sample(m, s) -> z | |
Args: | |
in_channels (int): Number of input tensor channels. | |
out_channels (int): Number of output tensor channels. | |
hidden_channels (int): Number of hidden channels. | |
kernel_size (int): Kernel size of the WaveNet convolution layers. | |
dilation_rate (int): Dilation rate of the WaveNet layers. | |
num_layers (int): Number of the WaveNet layers. | |
cond_channels (int, optional): Number of conditioning tensor channels. Defaults to 0. | |
""" | |
super().__init__() | |
self.in_channels = in_channels | |
self.out_channels = out_channels | |
self.hidden_channels = hidden_channels | |
self.kernel_size = kernel_size | |
self.dilation_rate = dilation_rate | |
self.num_layers = num_layers | |
self.cond_channels = cond_channels | |
self.pre = nn.Conv1d(in_channels, hidden_channels, 1) | |
self.enc = WN( | |
hidden_channels, hidden_channels, kernel_size, dilation_rate, num_layers, c_in_channels=cond_channels | |
) | |
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) | |
def forward(self, x, x_lengths, g=None): | |
""" | |
Shapes: | |
- x: :math:`[B, C, T]` | |
- x_lengths: :math:`[B, 1]` | |
- g: :math:`[B, C, 1]` | |
""" | |
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) | |
x = self.pre(x) * x_mask | |
x = self.enc(x, x_mask, g=g) | |
stats = self.proj(x) * x_mask | |
mean, log_scale = torch.split(stats, self.out_channels, dim=1) | |
z = (mean + torch.randn_like(mean) * torch.exp(log_scale)) * x_mask | |
return z, mean, log_scale, x_mask | |
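# The sampling step above is the standard reparameterisation trick: the encoder predicts a per-channel,
# per-frame mean and log-scale, and z is drawn as mean + eps * exp(log_scale). Stand-alone sketch
# (illustrative only; the helper name is hypothetical):
def _example_reparameterise(mean, log_scale):
    # mean, log_scale: [B, C, T] -> z: [B, C, T]
    return mean + torch.randn_like(mean) * torch.exp(log_scale)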
class ResidualCouplingBlock(nn.Module): | |
def __init__( | |
self, | |
channels, | |
hidden_channels, | |
kernel_size, | |
dilation_rate, | |
num_layers, | |
dropout_p=0, | |
cond_channels=0, | |
out_channels_override=None, | |
mean_only=False, | |
): | |
assert channels % 2 == 0, "channels should be divisible by 2" | |
super().__init__() | |
self.half_channels = channels // 2 | |
self.mean_only = mean_only | |
# input layer | |
self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) | |
# coupling layers | |
self.enc = WN( | |
hidden_channels, | |
hidden_channels, | |
kernel_size, | |
dilation_rate, | |
num_layers, | |
dropout_p=dropout_p, | |
c_in_channels=cond_channels, | |
) | |
# output layer | |
# Initializing the last layer to 0 makes the affine coupling layers
# do nothing at first, which helps with training stability
self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
self.post.weight.data.zero_()
self.post.bias.data.zero_()
self.conv1d_projector = None | |
if out_channels_override: | |
self.conv1d_projector = nn.Conv1d(192, out_channels_override, 1) | |
def forward(self, x, x_mask, g=None, reverse=False): | |
""" | |
Shapes: | |
- x: :math:`[B, C, T]` | |
- x_mask: :math:`[B, 1, T]` | |
- g: :math:`[B, C, 1]` | |
""" | |
if self.conv1d_projector is not None and not reverse: | |
x = self.conv1d_projector(x) | |
x0, x1 = torch.split(x, [self.half_channels] * 2, 1) | |
h = self.pre(x0) * x_mask.unsqueeze(1) | |
h = self.enc(h, x_mask.unsqueeze(1), g=g) | |
stats = self.post(h) * x_mask.unsqueeze(1) | |
if not self.mean_only: | |
m, log_scale = torch.split(stats, [self.half_channels] * 2, 1) | |
else: | |
m = stats | |
log_scale = torch.zeros_like(m) | |
if not reverse: | |
x1 = m + x1 * torch.exp(log_scale) * x_mask.unsqueeze(1) | |
x = torch.cat([x0, x1], 1) | |
logdet = torch.sum(log_scale, [1, 2]) | |
return x, logdet | |
else: | |
x1 = (x1 - m) * torch.exp(-log_scale) * x_mask.unsqueeze(1) | |
x = torch.cat([x0, x1], 1) | |
return x | |
def mask_from_lens(lens, max_len=None):
if max_len is None:
max_len = lens.max()
ids = torch.arange(0, max_len, device=lens.device, dtype=lens.dtype)
mask = torch.lt(ids, lens.unsqueeze(1))
return mask
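# Example (illustrative): mask_from_lens(torch.tensor([3, 5]))
# -> tensor([[ True,  True,  True, False, False],
#            [ True,  True,  True,  True,  True]])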
class TemporalPredictor(nn.Module): | |
"""Predicts a single float per each temporal location""" | |
def __init__(self, input_size, filter_size, kernel_size, dropout, | |
n_layers=2, n_predictions=1): | |
super(TemporalPredictor, self).__init__() | |
self.layers = nn.Sequential(*[ | |
ConvReLUNorm(input_size if i == 0 else filter_size, filter_size, | |
kernel_size=kernel_size, dropout=dropout) | |
for i in range(n_layers)] | |
) | |
self.n_predictions = n_predictions | |
self.fc = nn.Linear(filter_size, self.n_predictions, bias=True) | |
def forward(self, enc_out, enc_out_mask): | |
out = enc_out * enc_out_mask | |
out = self.layers(out.transpose(1, 2)).transpose(1, 2) | |
out = self.fc(out) * enc_out_mask | |
return out | |
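# Hypothetical usage sketch for TemporalPredictor (illustrative only; not called anywhere in this file):
# it maps a [B, T, C] encoder output and a [B, T, 1] mask to one prediction per time step.
def _example_temporal_predictor_shapes():
    pred = TemporalPredictor(input_size=256, filter_size=256, kernel_size=3, dropout=0.1)
    enc_out = torch.randn(2, 40, 256)     # [B, T, C]
    enc_mask = torch.ones(2, 40, 1)       # [B, T, 1]
    return pred(enc_out, enc_mask).shape  # torch.Size([2, 40, 1])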
class ConvReLUNorm(torch.nn.Module): | |
def __init__(self, in_channels, out_channels, kernel_size=1, dropout=0.0): | |
super(ConvReLUNorm, self).__init__() | |
self.conv = torch.nn.Conv1d(in_channels, out_channels, | |
kernel_size=kernel_size, | |
padding=(kernel_size // 2)) | |
self.norm = torch.nn.LayerNorm(out_channels) | |
self.dropout = torch.nn.Dropout(dropout) | |
def forward(self, signal): | |
out = F.relu(self.conv(signal)) | |
out = self.norm(out.transpose(1, 2)).transpose(1, 2).to(signal.dtype) | |
return self.dropout(out) | |