# Copyright (c) 2024 Amphion.
#
# This code is modified from https://github.com/imdanboy/jets/blob/main/espnet2/gan_tts/jets/generator.py
# Licensed under Apache License 2.0
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from modules.transformer.Models import Encoder, Decoder
from modules.transformer.Layers import PostNet
from collections import OrderedDict
from models.tts.jets.alignments import (
AlignmentModule,
viterbi_decode,
average_by_duration,
make_pad_mask,
make_non_pad_mask,
get_random_segments,
)
from models.tts.jets.length_regulator import GaussianUpsampling
from models.vocoders.gan.generator.hifigan import HiFiGAN
import os
import json
from utils.util import load_config
def get_mask_from_lengths(lengths, max_len=None):
device = lengths.device
batch_size = lengths.shape[0]
if max_len is None:
max_len = torch.max(lengths).item()
ids = torch.arange(0, max_len).unsqueeze(0).expand(batch_size, -1).to(device)
mask = ids >= lengths.unsqueeze(1).expand(-1, max_len)
return mask
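# Illustrative example (a sketch, not part of the original module): padded
# positions are marked True, so for lengths [3, 5]:
#   >>> get_mask_from_lengths(torch.tensor([3, 5]))
#   tensor([[False, False, False,  True,  True],
#           [False, False, False, False, False]])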
def pad(input_ele, mel_max_length=None):
if mel_max_length:
max_len = mel_max_length
else:
max_len = max([input_ele[i].size(0) for i in range(len(input_ele))])
out_list = list()
for i, batch in enumerate(input_ele):
if len(batch.shape) == 1:
one_batch_padded = F.pad(
batch, (0, max_len - batch.size(0)), "constant", 0.0
)
        elif len(batch.shape) == 2:
            one_batch_padded = F.pad(
                batch, (0, 0, 0, max_len - batch.size(0)), "constant", 0.0
            )
        else:
            raise ValueError(
                f"pad() expects 1-D or 2-D tensors, got shape {tuple(batch.shape)}"
            )
        out_list.append(one_batch_padded)
out_padded = torch.stack(out_list)
return out_padded
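# Illustrative example (a sketch): tensors of different lengths are
# zero-padded along dim 0 and stacked into one batch:
#   >>> pad([torch.ones(2, 3), torch.ones(4, 3)]).shape
#   torch.Size([2, 4, 3])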
class VarianceAdaptor(nn.Module):
"""Variance Adaptor"""
def __init__(self, cfg):
super(VarianceAdaptor, self).__init__()
self.duration_predictor = VariancePredictor(cfg)
self.length_regulator = LengthRegulator()
self.pitch_predictor = VariancePredictor(cfg)
self.energy_predictor = VariancePredictor(cfg)
# assign the pitch/energy feature level
if cfg.preprocess.use_frame_pitch:
self.pitch_feature_level = "frame_level"
self.pitch_dir = cfg.preprocess.pitch_dir
else:
self.pitch_feature_level = "phoneme_level"
self.pitch_dir = cfg.preprocess.phone_pitch_dir
if cfg.preprocess.use_frame_energy:
self.energy_feature_level = "frame_level"
self.energy_dir = cfg.preprocess.energy_dir
else:
self.energy_feature_level = "phoneme_level"
self.energy_dir = cfg.preprocess.phone_energy_dir
assert self.pitch_feature_level in ["phoneme_level", "frame_level"]
assert self.energy_feature_level in ["phoneme_level", "frame_level"]
pitch_quantization = cfg.model.variance_embedding.pitch_quantization
energy_quantization = cfg.model.variance_embedding.energy_quantization
n_bins = cfg.model.variance_embedding.n_bins
assert pitch_quantization in ["linear", "log"]
assert energy_quantization in ["linear", "log"]
with open(
os.path.join(
cfg.preprocess.processed_dir,
cfg.dataset[0],
self.energy_dir,
"statistics.json",
)
) as f:
stats = json.load(f)
stats = stats[cfg.dataset[0] + "_" + cfg.dataset[0]]
mean, std = (
stats["voiced_positions"]["mean"],
stats["voiced_positions"]["std"],
)
energy_min = (stats["total_positions"]["min"] - mean) / std
energy_max = (stats["total_positions"]["max"] - mean) / std
with open(
os.path.join(
cfg.preprocess.processed_dir,
cfg.dataset[0],
self.pitch_dir,
"statistics.json",
)
) as f:
stats = json.load(f)
stats = stats[cfg.dataset[0] + "_" + cfg.dataset[0]]
mean, std = (
stats["voiced_positions"]["mean"],
stats["voiced_positions"]["std"],
)
pitch_min = (stats["total_positions"]["min"] - mean) / std
pitch_max = (stats["total_positions"]["max"] - mean) / std
if pitch_quantization == "log":
self.pitch_bins = nn.Parameter(
torch.exp(
torch.linspace(np.log(pitch_min), np.log(pitch_max), n_bins - 1)
),
requires_grad=False,
)
else:
self.pitch_bins = nn.Parameter(
torch.linspace(pitch_min, pitch_max, n_bins - 1),
requires_grad=False,
)
if energy_quantization == "log":
self.energy_bins = nn.Parameter(
torch.exp(
torch.linspace(np.log(energy_min), np.log(energy_max), n_bins - 1)
),
requires_grad=False,
)
else:
self.energy_bins = nn.Parameter(
torch.linspace(energy_min, energy_max, n_bins - 1),
requires_grad=False,
)
self.pitch_embedding = nn.Embedding(
n_bins, cfg.model.transformer.encoder_hidden
)
self.energy_embedding = nn.Embedding(
n_bins, cfg.model.transformer.encoder_hidden
)
def get_pitch_embedding(self, x, target, mask, control):
prediction = self.pitch_predictor(x, mask)
if target is not None:
embedding = self.pitch_embedding(torch.bucketize(target, self.pitch_bins))
else:
prediction = prediction * control
embedding = self.pitch_embedding(
torch.bucketize(prediction, self.pitch_bins)
)
return prediction, embedding
def get_energy_embedding(self, x, target, mask, control):
prediction = self.energy_predictor(x, mask)
if target is not None:
embedding = self.energy_embedding(torch.bucketize(target, self.energy_bins))
else:
prediction = prediction * control
embedding = self.energy_embedding(
torch.bucketize(prediction, self.energy_bins)
)
return prediction, embedding
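    # The two helpers above quantize a continuous pitch/energy value into one
    # of n_bins buckets (torch.bucketize against the precomputed bin edges)
    # and look up a learned embedding: the ground-truth target when one is
    # given, the (optionally control-scaled) prediction otherwise.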
def forward(
self,
x,
src_mask,
mel_mask=None,
max_len=None,
pitch_target=None,
energy_target=None,
duration_target=None,
p_control=1.0,
e_control=1.0,
d_control=1.0,
pitch_embedding=None,
energy_embedding=None,
):
log_duration_prediction = self.duration_predictor(x, src_mask)
x = x + pitch_embedding
x = x + energy_embedding
pitch_prediction = self.pitch_predictor(x, src_mask)
energy_prediction = self.energy_predictor(x, src_mask)
if duration_target is not None:
x, mel_len = self.length_regulator(x, duration_target, max_len)
duration_rounded = duration_target
else:
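            # inference path: invert the log-domain duration prediction
            # (exp(.) - 1), scale by d_control, round, and clip at zero frames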
duration_rounded = torch.clamp(
(torch.round(torch.exp(log_duration_prediction) - 1) * d_control),
min=0,
)
x, mel_len = self.length_regulator(x, duration_rounded, max_len)
mel_mask = get_mask_from_lengths(mel_len)
return (
x,
pitch_prediction,
energy_prediction,
log_duration_prediction,
duration_rounded,
mel_len,
mel_mask,
)
def inference(
self,
x,
src_mask,
mel_mask=None,
max_len=None,
pitch_target=None,
energy_target=None,
duration_target=None,
p_control=1.0,
e_control=1.0,
d_control=1.0,
pitch_embedding=None,
energy_embedding=None,
):
p_outs = self.pitch_predictor(x, src_mask)
e_outs = self.energy_predictor(x, src_mask)
d_outs = self.duration_predictor(x, src_mask)
return p_outs, e_outs, d_outs
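    # NOTE: inference() returns only the raw pitch/energy/duration predictor
    # outputs; embedding those tracks and running length regulation happens
    # in Jets.inference() below.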
class LengthRegulator(nn.Module):
"""Length Regulator"""
def __init__(self):
super(LengthRegulator, self).__init__()
def LR(self, x, duration, max_len):
device = x.device
output = list()
mel_len = list()
for batch, expand_target in zip(x, duration):
expanded = self.expand(batch, expand_target)
output.append(expanded)
mel_len.append(expanded.shape[0])
if max_len is not None:
output = pad(output, max_len)
else:
output = pad(output)
return output, torch.LongTensor(mel_len).to(device)
def expand(self, batch, predicted):
out = list()
for i, vec in enumerate(batch):
expand_size = predicted[i].item()
out.append(vec.expand(max(int(expand_size), 0), -1))
out = torch.cat(out, 0)
return out
def forward(self, x, duration, max_len):
output, mel_len = self.LR(x, duration, max_len)
return output, mel_len
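# Illustrative example (a sketch): each phoneme-level vector is repeated
# according to its duration, e.g. durations [2, 1, 3] expand 3 phoneme
# frames into 6 mel frames:
#   >>> lr = LengthRegulator()
#   >>> out, mel_len = lr(torch.randn(1, 3, 8), torch.tensor([[2, 1, 3]]), None)
#   >>> out.shape, mel_len
#   (torch.Size([1, 6, 8]), tensor([6]))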
class VariancePredictor(nn.Module):
"""Duration, Pitch and Energy Predictor"""
def __init__(self, cfg):
super(VariancePredictor, self).__init__()
self.input_size = cfg.model.transformer.encoder_hidden
self.filter_size = cfg.model.variance_predictor.filter_size
self.kernel = cfg.model.variance_predictor.kernel_size
self.conv_output_size = cfg.model.variance_predictor.filter_size
self.dropout = cfg.model.variance_predictor.dropout
self.conv_layer = nn.Sequential(
OrderedDict(
[
(
"conv1d_1",
Conv(
self.input_size,
self.filter_size,
kernel_size=self.kernel,
padding=(self.kernel - 1) // 2,
),
),
("relu_1", nn.ReLU()),
("layer_norm_1", nn.LayerNorm(self.filter_size)),
("dropout_1", nn.Dropout(self.dropout)),
                    (
                        "conv1d_2",
                        Conv(
                            self.filter_size,
                            self.filter_size,
                            kernel_size=self.kernel,
                            # match conv1d_1 so the sequence length is
                            # preserved for any odd kernel size (a hard-coded
                            # padding of 1 is only correct for kernel size 3)
                            padding=(self.kernel - 1) // 2,
                        ),
                    ),
("relu_2", nn.ReLU()),
("layer_norm_2", nn.LayerNorm(self.filter_size)),
("dropout_2", nn.Dropout(self.dropout)),
]
)
)
self.linear_layer = nn.Linear(self.conv_output_size, 1)
def forward(self, encoder_output, mask):
out = self.conv_layer(encoder_output)
out = self.linear_layer(out)
out = out.squeeze(-1)
if mask is not None:
out = out.masked_fill(mask, 0.0)
return out
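# VariancePredictor maps (B, T, encoder_hidden) features to one scalar per
# position, shape (B, T), zeroing masked (padded) positions. Illustrative
# call, with cfg as constructed elsewhere in this repo:
#   >>> vp = VariancePredictor(cfg)
#   >>> x = torch.randn(2, 7, cfg.model.transformer.encoder_hidden)
#   >>> vp(x, get_mask_from_lengths(torch.tensor([5, 7]))).shape
#   torch.Size([2, 7])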
class Conv(nn.Module):
"""
Convolution Module
"""
def __init__(
self,
in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=0,
dilation=1,
bias=True,
w_init="linear",
):
"""
:param in_channels: dimension of input
:param out_channels: dimension of output
:param kernel_size: size of kernel
:param stride: size of stride
:param padding: size of padding
:param dilation: dilation rate
:param bias: boolean. if True, bias is included.
        :param w_init: str. weight init scheme (accepted for compatibility; unused)
"""
super(Conv, self).__init__()
self.conv = nn.Conv1d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
bias=bias,
)
def forward(self, x):
x = x.contiguous().transpose(1, 2)
x = self.conv(x)
x = x.contiguous().transpose(1, 2)
return x
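# Conv keeps the (B, T, C) channel-last layout used throughout this file by
# transposing to (B, C, T) around nn.Conv1d and back, so it composes with
# nn.LayerNorm and nn.Dropout inside a single nn.Sequential.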
class Jets(nn.Module):
def __init__(self, cfg) -> None:
super(Jets, self).__init__()
self.cfg = cfg
self.encoder = Encoder(cfg.model)
self.variance_adaptor = VarianceAdaptor(cfg)
self.decoder = Decoder(cfg.model)
self.length_regulator_infer = LengthRegulator()
self.mel_linear = nn.Linear(
cfg.model.transformer.decoder_hidden,
cfg.preprocess.n_mel,
)
self.postnet = PostNet(n_mel_channels=cfg.preprocess.n_mel)
self.speaker_emb = None
if cfg.train.multi_speaker_training:
with open(
os.path.join(
cfg.preprocess.processed_dir, cfg.dataset[0], "spk2id.json"
),
"r",
) as f:
n_speaker = len(json.load(f))
self.speaker_emb = nn.Embedding(
n_speaker,
cfg.model.transformer.encoder_hidden,
)
output_dim = cfg.preprocess.n_mel
attention_dim = 256
self.alignment_module = AlignmentModule(attention_dim, output_dim)
# NOTE(kan-bayashi): We use continuous pitch + FastPitch style avg
pitch_embed_kernel_size: int = 9
pitch_embed_dropout: float = 0.5
self.pitch_embed = torch.nn.Sequential(
torch.nn.Conv1d(
in_channels=1,
out_channels=attention_dim,
kernel_size=pitch_embed_kernel_size,
padding=(pitch_embed_kernel_size - 1) // 2,
),
torch.nn.Dropout(pitch_embed_dropout),
)
        # NOTE(kan-bayashi): We use continuous energy + FastPitch style avg
energy_embed_kernel_size: int = 9
energy_embed_dropout: float = 0.5
self.energy_embed = torch.nn.Sequential(
torch.nn.Conv1d(
in_channels=1,
out_channels=attention_dim,
kernel_size=energy_embed_kernel_size,
padding=(energy_embed_kernel_size - 1) // 2,
),
torch.nn.Dropout(energy_embed_dropout),
)
# define length regulator
self.length_regulator = GaussianUpsampling()
self.segment_size = cfg.train.segment_size
# Define HiFiGAN generator
hifi_cfg = load_config("egs/vocoder/gan/hifigan/exp_config.json")
# hifi_cfg.model.hifigan.resblock_kernel_sizes = [3, 7, 11]
hifi_cfg.preprocess.n_mel = attention_dim
self.generator = HiFiGAN(hifi_cfg)
def _source_mask(self, ilens: torch.Tensor) -> torch.Tensor:
"""Make masks for self-attention.
Args:
ilens (LongTensor): Batch of lengths (B,).
Returns:
Tensor: Mask tensor for self-attention.
                dtype=torch.uint8 in PyTorch versions before 1.2
                dtype=torch.bool in PyTorch 1.2 and later
Examples:
>>> ilens = [5, 3]
>>> self._source_mask(ilens)
            tensor([[[1, 1, 1, 1, 1]],
                    [[1, 1, 1, 0, 0]]], dtype=torch.uint8)
"""
x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device)
return x_masks.unsqueeze(-2)
def forward(self, data, p_control=1.0, e_control=1.0, d_control=1.0):
speakers = data["spk_id"]
texts = data["texts"]
src_lens = data["text_len"]
max_src_len = max(src_lens)
feats = data["mel"]
mel_lens = data["target_len"] if "target_len" in data else None
feats_lengths = mel_lens
max_mel_len = max(mel_lens) if "target_len" in data else None
p_targets = data["pitch"] if "pitch" in data else None
e_targets = data["energy"] if "energy" in data else None
src_masks = get_mask_from_lengths(src_lens, max_src_len)
mel_masks = (
get_mask_from_lengths(mel_lens, max_mel_len)
if mel_lens is not None
else None
)
output = self.encoder(texts, src_masks)
if self.speaker_emb is not None:
output = output + self.speaker_emb(speakers).unsqueeze(1).expand(
-1, max_src_len, -1
)
# Forward alignment module and obtain duration, averaged pitch, energy
h_masks = make_pad_mask(src_lens).to(output.device)
log_p_attn = self.alignment_module(
output, feats, src_lens, feats_lengths, h_masks
)
ds, bin_loss = viterbi_decode(log_p_attn, src_lens, feats_lengths)
ps = average_by_duration(
ds, p_targets.squeeze(-1), src_lens, feats_lengths
).unsqueeze(-1)
es = average_by_duration(
ds, e_targets.squeeze(-1), src_lens, feats_lengths
).unsqueeze(-1)
p_embs = self.pitch_embed(ps.transpose(1, 2)).transpose(1, 2)
e_embs = self.energy_embed(es.transpose(1, 2)).transpose(1, 2)
# FastSpeech2 variance adaptor
(
output,
p_predictions,
e_predictions,
log_d_predictions,
d_rounded,
mel_lens,
mel_masks,
) = self.variance_adaptor(
output,
src_masks,
mel_masks,
max_mel_len,
p_targets,
e_targets,
ds,
p_control,
e_control,
d_control,
            p_embs,
            e_embs,
)
# forward decoder
zs, _ = self.decoder(output, mel_masks) # (B, T_feats, adim)
# get random segments
z_segments, z_start_idxs = get_random_segments(
zs.transpose(1, 2),
feats_lengths,
self.segment_size,
)
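        # decode only a short random segment per utterance so the HiFiGAN
        # forward/backward pass stays memory-bounded during adversarial training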
# forward generator
wav = self.generator(z_segments)
return (
wav,
bin_loss,
log_p_attn,
z_start_idxs,
log_d_predictions,
ds,
p_predictions,
ps,
e_predictions,
es,
src_lens,
feats_lengths,
)
def inference(self, data, p_control=1.0, e_control=1.0, d_control=1.0):
speakers = data["spk_id"]
texts = data["texts"]
src_lens = data["text_len"]
max_src_len = max(src_lens)
mel_lens = data["target_len"] if "target_len" in data else None
feats_lengths = mel_lens
max_mel_len = max(mel_lens) if "target_len" in data else None
p_targets = data["pitch"] if "pitch" in data else None
e_targets = data["energy"] if "energy" in data else None
d_targets = data["durations"] if "durations" in data else None
src_masks = get_mask_from_lengths(src_lens, max_src_len)
mel_masks = (
get_mask_from_lengths(mel_lens, max_mel_len)
if mel_lens is not None
else None
)
        hs = self.encoder(texts, src_masks)
        if self.speaker_emb is not None:
            hs = hs + self.speaker_emb(speakers).unsqueeze(1).expand(
                -1, max_src_len, -1
            )
(
p_outs,
e_outs,
d_outs,
) = self.variance_adaptor.inference(
hs,
src_masks,
)
p_embs = self.pitch_embed(p_outs.unsqueeze(-1).transpose(1, 2)).transpose(1, 2)
e_embs = self.energy_embed(e_outs.unsqueeze(-1).transpose(1, 2)).transpose(1, 2)
hs = hs + e_embs + p_embs
# Duration predictor inference mode: log_d_pred to d_pred
offset = 1.0
d_predictions = torch.clamp(
torch.round(d_outs.exp() - offset), min=0
).long() # avoid negative value
# forward decoder
hs, mel_len = self.length_regulator_infer(hs, d_predictions, max_mel_len)
mel_mask = get_mask_from_lengths(mel_len)
zs, _ = self.decoder(hs, mel_mask) # (B, T_feats, adim)
# forward generator
wav = self.generator(zs.transpose(1, 2))
return wav, d_predictions
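# Minimal usage sketch (illustrative; assumes a fully populated Amphion cfg
# and inputs preprocessed as expected by forward()/inference() above):
#   >>> model = Jets(cfg)
#   >>> batch = {"spk_id": spk_id, "texts": texts, "text_len": text_len}
#   >>> wav, durations = model.inference(batch)  # waveform, per-phoneme frames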