import math import time import traceback import numpy as np import torch import torch.nn as nn from torch.nn import functional as F from torch.nn.modules.conv import Conv1d from python.xvapitch.glow_tts import RelativePositionTransformer from python.xvapitch.wavenet import WN from python.xvapitch.hifigan import HifiganGenerator from python.xvapitch.sdp import StochasticDurationPredictor#, StochasticPredictor from python.xvapitch.util import maximum_path, rand_segments, segment, sequence_mask, generate_path from python.xvapitch.text import get_text_preprocessor, ALL_SYMBOLS, lang_names class xVAPitch(nn.Module): def __init__(self, args): super().__init__() self.args = args self.args.init_discriminator = True self.args.speaker_embedding_channels = 512 self.args.use_spectral_norm_disriminator = False self.args.d_vector_dim = 512 self.args.use_language_embedding = True self.args.detach_dp_input = True self.END2END = True self.embedded_language_dim = 12 self.latent_size = 256 num_languages = len(list(lang_names.keys())) self.emb_l = nn.Embedding(num_languages, self.embedded_language_dim) self.length_scale = 1.0 self.noise_scale = 1.0 self.inference_noise_scale = 0.333 self.inference_noise_scale_dp = 0.333 self.noise_scale_dp = 1.0 self.max_inference_len = None self.spec_segment_size = 32 self.text_encoder = TextEncoder( # 165, len(ALL_SYMBOLS), self.latent_size,#192, self.latent_size,#192, 768, 2, 10, 3, 0.1, # language_emb_dim=4, language_emb_dim=self.embedded_language_dim, ) self.posterior_encoder = PosteriorEncoder( 513, self.latent_size,#+self.embedded_language_dim if self.args.flc else self.latent_size,#192, self.latent_size,#+self.embedded_language_dim if self.args.flc else self.latent_size,#192, kernel_size=5, dilation_rate=1, num_layers=16, cond_channels=self.args.d_vector_dim, ) self.flow = ResidualCouplingBlocks( self.latent_size,#192, self.latent_size,#192, kernel_size=5, dilation_rate=1, num_layers=4, cond_channels=self.args.d_vector_dim, args=self.args ) self.duration_predictor = StochasticDurationPredictor( self.latent_size,#192, self.latent_size,#192, 3, 0.5, 4, cond_channels=self.args.d_vector_dim, language_emb_dim=self.embedded_language_dim, ) self.waveform_decoder = HifiganGenerator( self.latent_size,#192, 1, "1", [[1,3,5],[1,3,5],[1,3,5]], [3,7,11], [16,16,4,4], 512, [8,8,2,2], inference_padding=0, # cond_channels=self.args.d_vector_dim+self.embedded_language_dim if self.args.flc else self.args.d_vector_dim, cond_channels=self.args.d_vector_dim, conv_pre_weight_norm=False, conv_post_weight_norm=False, conv_post_bias=False, ) self.USE_PITCH_COND = False # self.USE_PITCH_COND = True if self.USE_PITCH_COND: self.pitch_predictor = RelativePositioningPitchEnergyEncoder( # 165, # len(ALL_SYMBOLS), out_channels=1, hidden_channels=self.latent_size+self.embedded_language_dim,#196, hidden_channels_ffn=768, num_heads=2, # num_layers=10, num_layers=3, kernel_size=3, dropout_p=0.1, # language_emb_dim=4, conditioning_emb_dim=self.args.d_vector_dim, ) self.pitch_emb = nn.Conv1d( # 1, 384, # 1, 196, 1, self.args.expanded_flow_dim if args.expanded_flow else self.latent_size, # pitch_conditioning_formants, symbols_embedding_dim, kernel_size=3, padding=int((3 - 1) / 2)) self.TEMP_timing = [] def infer_get_lang_emb (self, language_id): aux_input = { # "d_vectors": embedding.unsqueeze(dim=0), "language_ids": language_id } sid, g, lid = self._set_cond_input(aux_input) lang_emb = self.emb_l(lid).unsqueeze(-1) return lang_emb def infer_advanced (self, logger, plugin_manager, cleaned_text, text, lang_embs, speaker_embs, pace=1.0, editor_data=None, old_sequence=None, pitch_amp=None): if (editor_data is not None) and ((editor_data[0] is not None and len(editor_data[0])) or (editor_data[1] is not None and len(editor_data[1]))): pitch_pred, dur_pred, energy_pred, em_angry_pred, em_happy_pred, em_sad_pred, em_surprise_pred, _ = editor_data # TODO, use energy_pred dur_pred = torch.tensor(dur_pred) dur_pred = dur_pred.view((1, dur_pred.shape[0])).float().to(self.device) pitch_pred = torch.tensor(pitch_pred) pitch_pred = pitch_pred.view((1, pitch_pred.shape[0])).float().to(self.device) energy_pred = torch.tensor(energy_pred) energy_pred = energy_pred.view((1, energy_pred.shape[0])).float().to(self.device) em_angry_pred = em_angry_pred.clone().detach() if (type(em_angry_pred) == 'torch.Tensor') else torch.tensor(em_angry_pred) em_angry_pred = em_angry_pred.view((1, em_angry_pred.shape[0])).float().to(self.device) em_happy_pred = em_happy_pred.clone().detach() if (type(em_happy_pred) == 'torch.Tensor') else torch.tensor(em_happy_pred) em_happy_pred = em_happy_pred.view((1, em_happy_pred.shape[0])).float().to(self.device) em_sad_pred = em_sad_pred.clone().detach() if (type(em_sad_pred) == 'torch.Tensor') else torch.tensor(em_sad_pred) em_sad_pred = em_sad_pred.view((1, em_sad_pred.shape[0])).float().to(self.device) em_surprise_pred = em_surprise_pred.clone().detach() if (type(em_surprise_pred) == 'torch.Tensor') else torch.tensor(em_surprise_pred) em_surprise_pred = em_surprise_pred.view((1, em_surprise_pred.shape[0])).float().to(self.device) # Pitch speaker embedding deltas if not self.USE_PITCH_COND and pitch_pred.shape[1]==speaker_embs.shape[2]: pitch_delta = self.pitch_emb_values.to(pitch_pred.device) * pitch_pred speaker_embs = speaker_embs + pitch_delta.float() # Emotion speaker embedding deltas emotions_strength = 0.00003 # Global scaling if em_angry_pred.shape[1]==speaker_embs.shape[2]: em_angry_delta = self.angry_emb_values.to(em_angry_pred.device) * em_angry_pred * emotions_strength speaker_embs = speaker_embs + em_angry_delta.float() if em_happy_pred.shape[1]==speaker_embs.shape[2]: em_happy_delta = self.happy_emb_values.to(em_happy_pred.device) * em_happy_pred * emotions_strength speaker_embs = speaker_embs + em_happy_delta.float() if em_sad_pred.shape[1]==speaker_embs.shape[2]: em_sad_delta = self.sad_emb_values.to(em_sad_pred.device) * em_sad_pred * emotions_strength speaker_embs = speaker_embs + em_sad_delta.float() if em_surprise_pred.shape[1]==speaker_embs.shape[2]: em_surprise_delta = self.surprise_emb_values.to(em_surprise_pred.device) * em_surprise_pred * emotions_strength speaker_embs = speaker_embs + em_surprise_delta.float() try: logger.info("editor data infer_using_vals") wav, dur_pred, pitch_pred_out, energy_pred, em_pred_out, start_index, end_index, wav_mult = self.infer_using_vals(logger, plugin_manager, cleaned_text, text, lang_embs, \ speaker_embs, pace, dur_pred_existing=dur_pred, pitch_pred_existing=pitch_pred, energy_pred_existing=energy_pred, em_pred_existing=[em_angry_pred, em_happy_pred, em_sad_pred, em_surprise_pred], old_sequence=old_sequence, new_sequence=text, pitch_amp=pitch_amp) [em_angry_pred_out, em_happy_pred_out, em_sad_pred_out, em_surprise_pred_out] = em_pred_out pitch_pred_out = pitch_pred em_angry_pred_out = em_angry_pred em_happy_pred_out = em_happy_pred em_sad_pred_out = em_sad_pred em_surprise_pred_out = em_surprise_pred return wav, dur_pred, pitch_pred_out, energy_pred, [em_angry_pred_out, em_happy_pred_out, em_sad_pred_out, em_surprise_pred_out], start_index, end_index, wav_mult except: print(traceback.format_exc()) logger.info(traceback.format_exc()) # return traceback.format_exc() logger.info("editor data corrupt; fallback to infer_using_vals") return self.infer_using_vals(logger, plugin_manager, cleaned_text, text, lang_embs, speaker_embs, pace, None, None, None, None, None, None, pitch_amp=pitch_amp) else: logger.info("no editor infer_using_vals") return self.infer_using_vals(logger, plugin_manager, cleaned_text, text, lang_embs, speaker_embs, pace, None, None, None, None, None, None, pitch_amp=pitch_amp) def infer_using_vals (self, logger, plugin_manager, cleaned_text, sequence, lang_embs, speaker_embs, pace, dur_pred_existing, pitch_pred_existing, energy_pred_existing, em_pred_existing, old_sequence, new_sequence, pitch_amp=None): start_index = None end_index = None [em_angry_pred_existing, em_happy_pred_existing, em_sad_pred_existing, em_surprise_pred_existing] = em_pred_existing if em_pred_existing is not None else [None, None, None, None] # Calculate text splicing bounds, if needed if old_sequence is not None: old_sequence_np = old_sequence.cpu().detach().numpy() old_sequence_np = list(old_sequence_np[0]) new_sequence_np = new_sequence.cpu().detach().numpy() new_sequence_np = list(new_sequence_np[0]) # Get the index of the first changed value if old_sequence_np[0]==new_sequence_np[0]: # If the start of both sequences is the same, then the change is not at the start for i in range(len(old_sequence_np)): if i 0, torch.ones_like(input_symbols), torch.zeros_like(input_symbols)).sum(dim=1) lang_emb_full = None # TODO self.text_encoder.logger = logger # TODO, store a bank of trained 31 language embeds, to use for interpolating lang_emb = self.emb_l(lang_embs).unsqueeze(-1) if len(lang_embs.shape)>1: # Batch mode lang_emb_full = lang_emb.squeeze(1).squeeze(-1) else: # Individual line from the UI lang_emb_full = lang_emb.transpose(2, 1).squeeze(1).unsqueeze(0) x, x_emb, x_mask = self.text_encoder(input_symbols, x_lengths, lang_emb=None, stats=False, lang_emb_full=lang_emb_full) m_p, logs_p = self.text_encoder(x, x_lengths, lang_emb=None, lang_emb_full=lang_emb_full, stats=True, x_mask=x_mask) lang_emb_full = lang_emb_full.reshape(lang_emb_full.shape[0],lang_emb_full.shape[2],lang_emb_full.shape[1]) self.inference_noise_scale_dp = 0 # TEMP DEBUGGING. REMOVE - or should I? It seems to make it worse, the higher it is # Calculate its own pitch, and duration vals if these were not already provided if (dur_pred_existing is None or dur_pred_existing.shape[1]==0) or old_sequence is not None: # Predict durations self.duration_predictor.logger = logger logw = self.duration_predictor(x, x_mask, g=speaker_embs, reverse=True, noise_scale=self.inference_noise_scale_dp, lang_emb=lang_emb_full) w = torch.exp(logw) * x_mask * self.length_scale # w = w * 1.3 # The model seems to generate quite fast speech, so I'm gonna just globally adjust that w = w * (pace.unsqueeze(2) if torch.is_tensor(pace) else pace) w_ceil = w w_ceil = torch.ceil(w) dur_pred = w_ceil else: dur_pred = dur_pred_existing.unsqueeze(dim=0) dur_pred = dur_pred * pace y_lengths = torch.clamp_min(torch.sum(torch.round(dur_pred), [1, 2]), 1).long() y_mask = sequence_mask(y_lengths, None).to(x_mask.dtype) if dur_pred.shape[0]>1: attn_all = [] m_p_all = [] logs_p_all = [] for b in range(dur_pred.shape[0]): attn_mask = torch.unsqueeze(x_mask[b,:].unsqueeze(0), 2) * torch.unsqueeze(y_mask[b,:].unsqueeze(0), -1) attn_all.append(generate_path(dur_pred.squeeze(1)[b,:].unsqueeze(0), attn_mask.squeeze(0).transpose(1, 2))) m_p_all.append(torch.matmul(attn_all[-1].transpose(1, 2), m_p[b,:].unsqueeze(0).transpose(1, 2)).transpose(1, 2)) logs_p_all.append(torch.matmul(attn_all[-1].transpose(1, 2), logs_p[b,:].unsqueeze(0).transpose(1, 2)).transpose(1, 2)) del attn_all m_p = torch.stack(m_p_all, dim=1).squeeze(dim=0) logs_p = torch.stack(logs_p_all, dim=1).squeeze(dim=0) pitch_pred = torch.zeros((x.shape[0], x.shape[0], x.shape[2])).to(x) else: attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) attn = generate_path(dur_pred.squeeze(1), attn_mask.squeeze(0).transpose(1, 2)) m_p = torch.matmul(attn.transpose(1, 2), m_p.transpose(1, 2)).transpose(1, 2) logs_p = torch.matmul(attn.transpose(1, 2), logs_p.transpose(1, 2)).transpose(1, 2) pitch_pred = torch.zeros((x.shape[0], x.shape[0], x.shape[2])).to(x) emAngry_pred = torch.zeros((x.shape[0], x.shape[0], x.shape[2])).to(x) emHappy_pred = torch.zeros((x.shape[0], x.shape[0], x.shape[2])).to(x) emSad_pred = torch.zeros((x.shape[0], x.shape[0], x.shape[2])).to(x) emSurprise_pred = torch.zeros((x.shape[0], x.shape[0], x.shape[2])).to(x) # Splice/replace pitch/duration values from the old input if simulating only a partial re-generation if start_index is not None or end_index is not None: dur_pred_np = list(dur_pred.cpu().detach().numpy())[0][0] pitch_pred_np = list(pitch_pred.cpu().detach().numpy())[0][0] emAngry_pred_np = list(emAngry_pred.cpu().detach().numpy())[0][0] emHappy_pred_np = list(emHappy_pred.cpu().detach().numpy())[0][0] emSad_pred_np = list(emSad_pred.cpu().detach().numpy())[0][0] emSurprise_pred_np = list(emSurprise_pred.cpu().detach().numpy())[0][0] dur_pred_existing_np = list(dur_pred_existing.cpu().detach().numpy())[0] pitch_pred_existing_np = list(pitch_pred_existing.cpu().detach().numpy())[0] emAngry_pred_existing_np = list(em_angry_pred_existing.cpu().detach().numpy())[0] emHappy_pred_existing_np = list(em_happy_pred_existing.cpu().detach().numpy())[0] emSad_pred_existing_np = list(em_sad_pred_existing.cpu().detach().numpy())[0] emSurprise_pred_existing_np = list(em_surprise_pred_existing.cpu().detach().numpy())[0] if start_index is not None: # Replace starting values for i in range(start_index+1): dur_pred_np[i] = dur_pred_existing_np[i] pitch_pred_np[i] = pitch_pred_existing_np[i] emAngry_pred_np[i] = emAngry_pred_existing_np[i] emHappy_pred_np[i] = emHappy_pred_existing_np[i] emSad_pred_np[i] = emSad_pred_existing_np[i] emSurprise_pred_np[i] = emSurprise_pred_existing_np[i] if end_index is not None: # Replace end values for i in range(len(old_sequence_np)-end_index): dur_pred_np[-i-1] = dur_pred_existing_np[-i-1] pitch_pred_np[-i-1] = pitch_pred_existing_np[-i-1] emAngry_pred_np[-i-1] = emAngry_pred_existing_np[-i-1] emHappy_pred_np[-i-1] = emHappy_pred_existing_np[-i-1] emSad_pred_np[-i-1] = emSad_pred_existing_np[-i-1] emSurprise_pred_np[-i-1] = emSurprise_pred_existing_np[-i-1] dur_pred = torch.tensor(dur_pred_np).to(self.device).unsqueeze(0) pitch_pred = torch.tensor(pitch_pred_np).to(self.device).unsqueeze(0).unsqueeze(0) emAngry_pred = torch.tensor(emAngry_pred_np).to(self.device).unsqueeze(0).unsqueeze(0) emHappy_pred = torch.tensor(emHappy_pred_np).to(self.device).unsqueeze(0).unsqueeze(0) emSad_pred = torch.tensor(emSad_pred_np).to(self.device).unsqueeze(0).unsqueeze(0) emSurprise_pred = torch.tensor(emSurprise_pred_np).to(self.device).unsqueeze(0).unsqueeze(0) if pitch_amp is not None: pitch_pred = pitch_pred * pitch_amp.unsqueeze(dim=-1) if plugin_manager is not None and len(plugin_manager.plugins["synth-line"]["mid"]): pitch_pred_numpy = pitch_pred.cpu().detach().numpy() plugin_data = { "pace": pace, "duration": dur_pred.cpu().detach().numpy(), "pitch": pitch_pred_numpy.reshape((pitch_pred_numpy.shape[0],pitch_pred_numpy.shape[2])), "emAngry": emAngry_pred.reshape((emAngry_pred.shape[0],emAngry_pred.shape[2])), "emHappy": emHappy_pred.reshape((emHappy_pred.shape[0],emHappy_pred.shape[2])), "emSad": emSad_pred.reshape((emSad_pred.shape[0],emSad_pred.shape[2])), "emSurprise": emSurprise_pred.reshape((emSurprise_pred.shape[0],emSurprise_pred.shape[2])), "sequence": sequence, "is_fresh_synth": pitch_pred_existing is None and dur_pred_existing is None, "pluginsContext": plugin_manager.context, "hasDataChanged": False } plugin_manager.run_plugins(plist=plugin_manager.plugins["synth-line"]["mid"], event="mid synth-line", data=plugin_data) if ( pace != plugin_data["pace"] or plugin_data["hasDataChanged"] ): logger.info("Inference data has been changed by plugins, rerunning infer_advanced") pace = plugin_data["pace"] editor_data = [ plugin_data["pitch"][0], plugin_data["duration"][0][0], [1.0 for _ in range(pitch_pred_numpy.shape[-1])], plugin_data["emAngry"][0], plugin_data["emHappy"][0], plugin_data["emSad"][0], plugin_data["emSurprise"][0], None ] # rerun infer_advanced so that emValues take effect # second argument ensures no loop return self.infer_advanced (logger, None, cleaned_text, sequence, lang_embs, speaker_embs, pace=pace, editor_data=editor_data, old_sequence=sequence, pitch_amp=None) else: # skip rerunning infer_advanced logger.info("Inference data unchanged by plugins") # TODO, incorporate some sort of control for this # self.inference_noise_scale = 0 # for flow in self.flow.flows: # flow.logger = logger # flow.enc.logger = logger z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * self.inference_noise_scale z = self.flow(z_p, y_mask, g=speaker_embs, reverse=True) self.waveform_decoder.logger = logger wav = self.waveform_decoder((z * y_mask.unsqueeze(1))[:, :, : self.max_inference_len], g=speaker_embs) # In batch mode, trim the shorter audio waves in the batch. The masking doesn't seem to work, so have to do it manually if dur_pred.shape[0]>1: wav_all = [] for b in range(dur_pred.shape[0]): percent_to_mask = torch.sum(y_mask[b])/y_mask.shape[1] wav_all.append(wav[b,0,0:int((wav.shape[2]*percent_to_mask).item())]) wav = wav_all start_index = -1 if start_index is None else start_index end_index = -1 if end_index is None else end_index # Apply volume adjustments stretched_energy_mult = None if energy_pred_existing is not None and pitch_pred_existing is not None: energy_mult = self.expand_vals_by_durations(energy_pred_existing.unsqueeze(0), dur_pred, logger=logger) stretched_energy_mult = torch.nn.functional.interpolate(energy_mult.unsqueeze(0).unsqueeze(0), (1,1,wav.shape[2])).squeeze() stretched_energy_mult = stretched_energy_mult.cpu().detach().numpy() energy_pred = energy_pred_existing.squeeze() else: energy_pred = [1.0 for _ in range(pitch_pred.shape[-1])] energy_pred = torch.tensor(energy_pred) # energy_pred = energy_pred.squeeze() em_pred_out = [emAngry_pred, emHappy_pred, emSad_pred, emSurprise_pred] return wav, dur_pred, pitch_pred, energy_pred, em_pred_out, start_index, end_index, stretched_energy_mult def voice_conversion(self, y, y_lengths=None, spk1_emb=None, spk2_emb=None): if y_lengths is None: y_lengths = self.y_lengths_default z, _, _, y_mask = self.posterior_encoder(y, y_lengths, g=spk1_emb) # z_hat = z y_mask = y_mask.squeeze(0) z_p = self.flow(z, y_mask, g=spk1_emb) z_hat = self.flow(z_p, y_mask, g=spk2_emb, reverse=True) o_hat = self.waveform_decoder(z_hat * y_mask, g=spk2_emb) return o_hat def _set_cond_input (self, aux_input): """Set the speaker conditioning input based on the multi-speaker mode.""" sid, g, lid = None, None, None # if "speaker_ids" in aux_input and aux_input["speaker_ids"] is not None: # sid = aux_input["speaker_ids"] # if sid.ndim == 0: # sid = sid.unsqueeze_(0) if "d_vectors" in aux_input and aux_input["d_vectors"] is not None: g = F.normalize(aux_input["d_vectors"]).unsqueeze(-1) if g.ndim == 2: g = g.unsqueeze_(0) if "language_ids" in aux_input and aux_input["language_ids"] is not None: lid = aux_input["language_ids"] if lid.ndim == 0: lid = lid.unsqueeze_(0) return sid, g, lid # Opposite of average_pitch; Repeat per-symbol values by durations, to get sequence-wide values def expand_vals_by_durations (self, vals, durations, logger=None): vals = vals.view((vals.shape[0], vals.shape[2])) if len(durations.shape)>2: durations = durations.view((durations.shape[0], durations.shape[2])) max_dur = int(torch.round(durations).sum().item()) max_dur = int(torch.max(torch.sum(torch.round(durations), dim=1)).item()) expanded = torch.zeros((vals.shape[0], 1, max_dur)).to(vals) for b in range(vals.shape[0]): b_vals = vals[b] b_durs = durations[b] expanded_vals = [] for vi in range(b_vals.shape[0]): for dur_i in range(round(b_durs[vi].item())): if len(durations.shape)>2: expanded_vals.append(b_vals[vi]) else: expanded_vals.append(b_vals[vi].unsqueeze(dim=0)) expanded_vals = torch.tensor(expanded_vals).to(expanded) expanded[b,:,0:expanded_vals.shape[0]] += expanded_vals return expanded class TextEncoder(nn.Module): def __init__( self, n_vocab: int, # len(ALL_SYMBOLS) out_channels: int, # 192 hidden_channels: int, # 192 hidden_channels_ffn: int, # 768 num_heads: int, # 2 num_layers: int, # 10 kernel_size: int, # 3 dropout_p: float, # 0.1 language_emb_dim: int = None, ): """Text Encoder for VITS model. Args: n_vocab (int): Number of characters for the embedding layer. out_channels (int): Number of channels for the output. hidden_channels (int): Number of channels for the hidden layers. hidden_channels_ffn (int): Number of channels for the convolutional layers. num_heads (int): Number of attention heads for the Transformer layers. num_layers (int): Number of Transformer layers. kernel_size (int): Kernel size for the FFN layers in Transformer network. dropout_p (float): Dropout rate for the Transformer layers. """ super().__init__() self.out_channels = out_channels self.hidden_channels = hidden_channels self.emb = nn.Embedding(n_vocab, hidden_channels) nn.init.normal_(self.emb.weight, 0.0, hidden_channels ** -0.5) if language_emb_dim: hidden_channels += language_emb_dim self.encoder = RelativePositionTransformer( in_channels=hidden_channels, out_channels=hidden_channels, hidden_channels=hidden_channels, hidden_channels_ffn=hidden_channels_ffn, num_heads=num_heads, num_layers=num_layers, kernel_size=kernel_size, dropout_p=dropout_p, layer_norm_type="2", rel_attn_window_size=4, ) self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) def forward(self, x, x_lengths, lang_emb=None, stats=False, x_mask=None, lang_emb_full=None): """ Shapes: - x: :math:`[B, T]` - x_length: :math:`[B]` """ if stats: stats = self.proj(x) * x_mask m, logs = torch.split(stats, self.out_channels, dim=1) return m, logs else: x_emb = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h] # concat the lang emb in embedding chars if lang_emb is not None or lang_emb_full is not None: # x = torch.cat((x_emb, lang_emb.transpose(2, 1).expand(x_emb.size(0), x_emb.size(1), -1)), dim=-1) if lang_emb_full is None: lang_emb_full = lang_emb.transpose(2, 1).expand(x_emb.size(0), x_emb.size(1), -1) x = torch.cat((x_emb, lang_emb_full), dim=-1) x = torch.transpose(x, 1, -1) # [b, h, t] x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) x = self.encoder(x * x_mask, x_mask) # stats = self.proj(x) * x_mask # m, logs = torch.split(stats, self.out_channels, dim=1) return x, x_emb, x_mask class RelativePositioningPitchEnergyEncoder(nn.Module): def __init__( self, # n_vocab: int, # len(ALL_SYMBOLS) out_channels: int, # 192 hidden_channels: int, # 192 hidden_channels_ffn: int, # 768 num_heads: int, # 2 num_layers: int, # 10 kernel_size: int, # 3 dropout_p: float, # 0.1 conditioning_emb_dim: int = None, ): super().__init__() self.out_channels = out_channels self.hidden_channels = hidden_channels # self.emb = nn.Embedding(n_vocab, hidden_channels) # nn.init.normal_(self.emb.weight, 0.0, hidden_channels ** -0.5) if conditioning_emb_dim: hidden_channels += conditioning_emb_dim self.encoder = RelativePositionTransformer( in_channels=hidden_channels, # out_channels=hidden_channels, out_channels=1, # out_channels=196, hidden_channels=hidden_channels, hidden_channels_ffn=hidden_channels_ffn, num_heads=num_heads, num_layers=num_layers, kernel_size=kernel_size, dropout_p=dropout_p, layer_norm_type="2", rel_attn_window_size=4, ) # self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) # self.proj = nn.Conv1d(196, out_channels * 2, 1) def forward(self, x, x_lengths=None, speaker_emb=None, stats=False, x_mask=None): """ Shapes: - x: :math:`[B, T]` - x_length: :math:`[B]` """ # concat the lang emb in embedding chars if speaker_emb is not None: x = torch.cat((x, speaker_emb.transpose(2, 1).expand(x.size(0), x.size(1), -1)), dim=-1) x = torch.transpose(x, 1, -1) # [b, h, t] x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) x = self.encoder(x * x_mask, x_mask) return x#, x_mask class ResidualCouplingBlocks(nn.Module): def __init__( self, channels: int, hidden_channels: int, kernel_size: int, dilation_rate: int, num_layers: int, num_flows=4, cond_channels=0, args=None ): """Redisual Coupling blocks for VITS flow layers. Args: channels (int): Number of input and output tensor channels. hidden_channels (int): Number of hidden network channels. kernel_size (int): Kernel size of the WaveNet layers. dilation_rate (int): Dilation rate of the WaveNet layers. num_layers (int): Number of the WaveNet layers. num_flows (int, optional): Number of Residual Coupling blocks. Defaults to 4. cond_channels (int, optional): Number of channels of the conditioning tensor. Defaults to 0. """ super().__init__() self.args = args self.channels = channels self.hidden_channels = hidden_channels self.kernel_size = kernel_size self.dilation_rate = dilation_rate self.num_layers = num_layers self.num_flows = num_flows self.cond_channels = cond_channels self.flows = nn.ModuleList() for flow_i in range(num_flows): self.flows.append( ResidualCouplingBlock( (192+self.args.expanded_flow_dim+self.args.expanded_flow_dim) if flow_i==(num_flows-1) and self.args.expanded_flow else channels, (192+self.args.expanded_flow_dim+self.args.expanded_flow_dim) if flow_i==(num_flows-1) and self.args.expanded_flow else hidden_channels, kernel_size, dilation_rate, num_layers, cond_channels=cond_channels, out_channels_override=(192+self.args.expanded_flow_dim+self.args.expanded_flow_dim) if flow_i==(num_flows-1) and self.args.expanded_flow else None, mean_only=True, ) ) def forward(self, x, x_mask, g=None, reverse=False): """ Shapes: - x: :math:`[B, C, T]` - x_mask: :math:`[B, 1, T]` - g: :math:`[B, C, 1]` """ if not reverse: for fi, flow in enumerate(self.flows): x, _ = flow(x, x_mask, g=g, reverse=reverse) x = torch.flip(x, [1]) else: for flow in reversed(self.flows): x = torch.flip(x, [1]) x = flow(x, x_mask, g=g, reverse=reverse) return x class PosteriorEncoder(nn.Module): def __init__( self, in_channels: int, out_channels: int, hidden_channels: int, kernel_size: int, dilation_rate: int, num_layers: int, cond_channels=0, ): """Posterior Encoder of VITS model. :: x -> conv1x1() -> WaveNet() (non-causal) -> conv1x1() -> split() -> [m, s] -> sample(m, s) -> z Args: in_channels (int): Number of input tensor channels. out_channels (int): Number of output tensor channels. hidden_channels (int): Number of hidden channels. kernel_size (int): Kernel size of the WaveNet convolution layers. dilation_rate (int): Dilation rate of the WaveNet layers. num_layers (int): Number of the WaveNet layers. cond_channels (int, optional): Number of conditioning tensor channels. Defaults to 0. """ super().__init__() self.in_channels = in_channels self.out_channels = out_channels self.hidden_channels = hidden_channels self.kernel_size = kernel_size self.dilation_rate = dilation_rate self.num_layers = num_layers self.cond_channels = cond_channels self.pre = nn.Conv1d(in_channels, hidden_channels, 1) self.enc = WN( hidden_channels, hidden_channels, kernel_size, dilation_rate, num_layers, c_in_channels=cond_channels ) self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) def forward(self, x, x_lengths, g=None): """ Shapes: - x: :math:`[B, C, T]` - x_lengths: :math:`[B, 1]` - g: :math:`[B, C, 1]` """ x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) x = self.pre(x) * x_mask x = self.enc(x, x_mask, g=g) stats = self.proj(x) * x_mask mean, log_scale = torch.split(stats, self.out_channels, dim=1) z = (mean + torch.randn_like(mean) * torch.exp(log_scale)) * x_mask return z, mean, log_scale, x_mask class ResidualCouplingBlock(nn.Module): def __init__( self, channels, hidden_channels, kernel_size, dilation_rate, num_layers, dropout_p=0, cond_channels=0, out_channels_override=None, mean_only=False, ): assert channels % 2 == 0, "channels should be divisible by 2" super().__init__() self.half_channels = channels // 2 self.mean_only = mean_only # input layer self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) # coupling layers self.enc = WN( hidden_channels, hidden_channels, kernel_size, dilation_rate, num_layers, dropout_p=dropout_p, c_in_channels=cond_channels, ) # output layer # Initializing last layer to 0 makes the affine coupling layers # do nothing at first. This helps with training stability self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) self.conv1d_projector = None if out_channels_override: self.conv1d_projector = nn.Conv1d(192, out_channels_override, 1) def forward(self, x, x_mask, g=None, reverse=False): """ Shapes: - x: :math:`[B, C, T]` - x_mask: :math:`[B, 1, T]` - g: :math:`[B, C, 1]` """ if self.conv1d_projector is not None and not reverse: x = self.conv1d_projector(x) x0, x1 = torch.split(x, [self.half_channels] * 2, 1) h = self.pre(x0) * x_mask.unsqueeze(1) h = self.enc(h, x_mask.unsqueeze(1), g=g) stats = self.post(h) * x_mask.unsqueeze(1) if not self.mean_only: m, log_scale = torch.split(stats, [self.half_channels] * 2, 1) else: m = stats log_scale = torch.zeros_like(m) if not reverse: x1 = m + x1 * torch.exp(log_scale) * x_mask.unsqueeze(1) x = torch.cat([x0, x1], 1) logdet = torch.sum(log_scale, [1, 2]) return x, logdet else: x1 = (x1 - m) * torch.exp(-log_scale) * x_mask.unsqueeze(1) x = torch.cat([x0, x1], 1) return x def mask_from_lens(lens, max_len= None): if max_len is None: max_len = lens.max() ids = torch.arange(0, max_len, device=lens.device, dtype=lens.dtype) mask = torch.lt(ids, lens.unsqueeze(1)) return mask class TemporalPredictor(nn.Module): """Predicts a single float per each temporal location""" def __init__(self, input_size, filter_size, kernel_size, dropout, n_layers=2, n_predictions=1): super(TemporalPredictor, self).__init__() self.layers = nn.Sequential(*[ ConvReLUNorm(input_size if i == 0 else filter_size, filter_size, kernel_size=kernel_size, dropout=dropout) for i in range(n_layers)] ) self.n_predictions = n_predictions self.fc = nn.Linear(filter_size, self.n_predictions, bias=True) def forward(self, enc_out, enc_out_mask): out = enc_out * enc_out_mask out = self.layers(out.transpose(1, 2)).transpose(1, 2) out = self.fc(out) * enc_out_mask return out class ConvReLUNorm(torch.nn.Module): def __init__(self, in_channels, out_channels, kernel_size=1, dropout=0.0): super(ConvReLUNorm, self).__init__() self.conv = torch.nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, padding=(kernel_size // 2)) self.norm = torch.nn.LayerNorm(out_channels) self.dropout = torch.nn.Dropout(dropout) def forward(self, signal): out = F.relu(self.conv(signal)) out = self.norm(out.transpose(1, 2)).transpose(1, 2).to(signal.dtype) return self.dropout(out)