# maskgct/models/tts/naturalspeech2/prior_encoder.py
# (Uploaded by Hecheng0625; commit c968fc3.)
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from modules.naturalpseech2.transformers import (
TransformerEncoder,
DurationPredictor,
PitchPredictor,
LengthRegulator,
)
class PriorEncoder(nn.Module):
    """Prior encoder: phone tokens -> frame-level prior embedding.

    Pipeline:
    1. Encode the phone sequence with a reference-conditioned transformer.
    2. Predict per-phone durations.
    3. Expand phone features to frame level with a length regulator.
    4. Predict frame-level pitch.
    5. Add a quantized-pitch embedding to the expanded features.

    Fields read from ``cfg``: ``vocab_size``, ``encoder`` (incl.
    ``encoder_hidden``), ``duration_predictor``, ``pitch_predictor``,
    ``pitch_min``, ``pitch_max`` and ``pitch_bins_num``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        # Phone-token embedding handed to the encoder; index 0 is padding.
        self.enc_emb_tokens = nn.Embedding(
            cfg.vocab_size, cfg.encoder.encoder_hidden, padding_idx=0
        )
        # Start the token embeddings very close to zero (std 1e-5).
        self.enc_emb_tokens.weight.data.normal_(mean=0.0, std=1e-5)
        self.encoder = TransformerEncoder(
            enc_emb_tokens=self.enc_emb_tokens, cfg=cfg.encoder
        )

        self.duration_predictor = DurationPredictor(cfg.duration_predictor)
        self.pitch_predictor = PitchPredictor(cfg.pitch_predictor)
        self.length_regulator = LengthRegulator()

        # Pitch quantization uses ``pitch_bins_num - 1`` boundaries, evenly
        # spaced in log scale between pitch_min and pitch_max (stored
        # exponentiated, i.e. in linear scale). torch.bucketize maps a value
        # to an index in [0, pitch_bins_num - 1], which is exactly the index
        # range of ``pitch_embedding`` below.
        self.pitch_min = cfg.pitch_min
        self.pitch_max = cfg.pitch_max
        self.pitch_bins_num = cfg.pitch_bins_num
        pitch_bins = torch.exp(
            torch.linspace(
                np.log(self.pitch_min), np.log(self.pitch_max), self.pitch_bins_num - 1
            )
        )
        # A buffer (not a parameter): moves with the module's device and is
        # saved in the state dict, but is never trained.
        self.register_buffer("pitch_bins", pitch_bins)
        self.pitch_embedding = nn.Embedding(
            self.pitch_bins_num, cfg.encoder.encoder_hidden
        )

    def forward(
        self,
        phone_id,
        duration=None,
        pitch=None,
        phone_mask=None,
        mask=None,
        ref_emb=None,
        ref_mask=None,
        is_inference=False,
    ):
        """Compute the frame-level prior embedding.

        input:
            phone_id: (B, N)
            duration: (B, N) ground-truth per-phone durations. If None, or
                if ``is_inference`` is True, the rounded predicted durations
                are used instead.
            pitch: (B, T) ground-truth pitch, bucketized directly against the
                linear-scale ``pitch_bins``. If None, or if ``is_inference``
                is True, the predicted pitch is used instead.
            phone_mask: (B, N); mask is 0
            mask: (B, T); mask is 0. Applied to the output only when
                ``is_inference`` is False.
            ref_emb: (B, d, T'). Required despite the ``None`` default: it
                is transposed and passed to the encoder unconditionally.
            ref_mask: (B, T'); mask is 0
            is_inference: if True, always use predicted durations and pitch,
                and skip the final frame masking.
        output (dict):
            dur_pred_round, dur_pred_log, dur_pred: duration-predictor
                outputs, each (B, N) -- TODO confirm shapes against
                DurationPredictor.
            pitch_pred_log: predicted pitch in log space, presumably (B, T).
            pitch_token: (B, T) pitch-bin indices that were embedded.
            mel_len: per-item expanded lengths from the length regulator.
            prior_out: (B, T, d) prior embedding.
        """
        # The encoder receives the reference as (B, T', d).
        x = self.encoder(phone_id, phone_mask, ref_emb.transpose(1, 2))
        dur_pred_out = self.duration_predictor(x, phone_mask, ref_emb, ref_mask)
        # dur_pred_out keys: dur_pred_log, dur_pred, dur_pred_round

        if is_inference or duration is None:
            dur_round = dur_pred_out["dur_pred_round"]
            x, mel_len = self.length_regulator(
                x,
                dur_round,
                max_len=torch.max(torch.sum(dur_round, dim=1)),
            )
        else:
            if pitch is not None:
                # Pad/trim the expanded sequence to the ground-truth frame count.
                max_len = pitch.shape[1]
            else:
                # No pitch to align with: fall back to the longest total
                # duration (previously this path crashed on ``pitch.shape``).
                max_len = torch.max(torch.sum(duration, dim=1))
            x, mel_len = self.length_regulator(x, duration, max_len=max_len)

        pitch_pred_log = self.pitch_predictor(x, mask, ref_emb, ref_mask)

        if is_inference or pitch is None:
            # The prediction is in log space while pitch_bins holds
            # linear-scale boundaries, so exponentiate before bucketizing.
            pitch_tokens = torch.bucketize(pitch_pred_log.exp(), self.pitch_bins)
        else:
            pitch_tokens = torch.bucketize(pitch, self.pitch_bins)
        x = x + self.pitch_embedding(pitch_tokens)

        if (not is_inference) and (mask is not None):
            # Zero out padded frames; mask is (B, T) and x is (B, T, d).
            x = x * mask.to(x.dtype)[:, :, None]

        prior_out = {
            "dur_pred_round": dur_pred_out["dur_pred_round"],
            "dur_pred_log": dur_pred_out["dur_pred_log"],
            "dur_pred": dur_pred_out["dur_pred"],
            "pitch_pred_log": pitch_pred_log,
            "pitch_token": pitch_tokens,
            "mel_len": mel_len,
            "prior_out": x,
        }
        return prior_out