# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging

import torch
from torch import nn

from fairseq import utils
from fairseq.data.data_utils import lengths_to_padding_mask
from fairseq.models import (
    FairseqEncoder,
    FairseqEncoderModel,
    register_model,
    register_model_architecture,
)
from fairseq.models.text_to_speech.hub_interface import TTSHubInterface
from fairseq.models.text_to_speech.tacotron2 import Postnet
from fairseq.modules import (
    FairseqDropout,
    LayerNorm,
    MultiheadAttention,
    PositionalEmbedding,
)


logger = logging.getLogger(__name__)


def model_init(m):
    if isinstance(m, nn.Conv1d):
        nn.init.xavier_uniform_(m.weight, torch.nn.init.calculate_gain("relu"))


def Embedding(num_embeddings, embedding_dim, padding_idx=None):
    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    nn.init.normal_(m.weight, mean=0, std=embedding_dim**-0.5)
    return m


class PositionwiseFeedForward(nn.Module):
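    """Position-wise feed-forward sub-layer of an FFT block: two 1D
    convolutions over the time axis with a ReLU in between, followed by
    dropout, a residual connection and layer normalization.
    """
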
    def __init__(self, in_dim, hidden_dim, kernel_size, dropout):
        super().__init__()
        self.ffn = nn.Sequential(
            nn.Conv1d(
                in_dim,
                hidden_dim,
                kernel_size=kernel_size,
                padding=(kernel_size - 1) // 2,
            ),
            nn.ReLU(),
            nn.Conv1d(
                hidden_dim,
                in_dim,
                kernel_size=kernel_size,
                padding=(kernel_size - 1) // 2,
            ),
        )
        self.layer_norm = LayerNorm(in_dim)
        self.dropout = self.dropout_module = FairseqDropout(
            p=dropout, module_name=self.__class__.__name__
        )

    def forward(self, x):
        # B x T x C
        residual = x
        x = self.ffn(x.transpose(1, 2)).transpose(1, 2)
        x = self.dropout(x)
        return self.layer_norm(x + residual)


class FFTLayer(torch.nn.Module):
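    """Feed-Forward Transformer (FFT) block: multi-head self-attention with a
    residual connection and layer norm, followed by the convolutional
    position-wise feed-forward network above.
    """
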
    def __init__(
        self, embed_dim, n_heads, hidden_dim, kernel_size, dropout, attention_dropout
    ):
        super().__init__()
        self.self_attn = MultiheadAttention(
            embed_dim, n_heads, dropout=attention_dropout, self_attention=True
        )
        self.layer_norm = LayerNorm(embed_dim)
        self.ffn = PositionwiseFeedForward(
            embed_dim, hidden_dim, kernel_size, dropout=dropout
        )

    def forward(self, x, padding_mask=None):
        # B x T x C
        residual = x
        x = x.transpose(0, 1)
        x, _ = self.self_attn(
            query=x, key=x, value=x, key_padding_mask=padding_mask, need_weights=False
        )
        x = x.transpose(0, 1)
        x = self.layer_norm(x + residual)
        return self.ffn(x)


class LengthRegulator(nn.Module):
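    """Expands the input along the time axis so that position ``t`` of example
    ``b`` is repeated ``durations[b, t]`` times (e.g. durations ``[2, 0, 3]``
    map 3 input frames to 5 output frames). Returns the expanded, zero-padded
    batch and the per-example output lengths.
    """
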
    def forward(self, x, durations):
        # x: B x T x C
        out_lens = durations.sum(dim=1)
        max_len = out_lens.max()
        bsz, seq_len, dim = x.size()
        out = x.new_zeros((bsz, max_len, dim))

        for b in range(bsz):
            indices = []
            for t in range(seq_len):
                indices.extend([t] * utils.item(durations[b, t]))
            indices = torch.tensor(indices, dtype=torch.long).to(x.device)
            out_len = utils.item(out_lens[b])
            out[b, :out_len] = x[b].index_select(0, indices)

        return out, out_lens


class VariancePredictor(nn.Module):
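    """Predicts one scalar per input position (used for duration in the log
    domain, pitch and energy) with two Conv1d + ReLU + LayerNorm + dropout
    stages followed by a linear projection.
    """
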
    def __init__(self, args):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv1d(
                args.encoder_embed_dim,
                args.var_pred_hidden_dim,
                kernel_size=args.var_pred_kernel_size,
                padding=(args.var_pred_kernel_size - 1) // 2,
            ),
            nn.ReLU(),
        )
        self.ln1 = nn.LayerNorm(args.var_pred_hidden_dim)
        self.dropout_module = FairseqDropout(
            p=args.var_pred_dropout, module_name=self.__class__.__name__
        )
        self.conv2 = nn.Sequential(
            nn.Conv1d(
                args.var_pred_hidden_dim,
                args.var_pred_hidden_dim,
                kernel_size=args.var_pred_kernel_size,
                padding=1,
            ),
            nn.ReLU(),
        )
        self.ln2 = nn.LayerNorm(args.var_pred_hidden_dim)
        self.proj = nn.Linear(args.var_pred_hidden_dim, 1)

    def forward(self, x):
        # Input: B x T x C; Output: B x T
        x = self.conv1(x.transpose(1, 2)).transpose(1, 2)
        x = self.dropout_module(self.ln1(x))
        x = self.conv2(x.transpose(1, 2)).transpose(1, 2)
        x = self.dropout_module(self.ln2(x))
        return self.proj(x).squeeze(dim=2)


class VarianceAdaptor(nn.Module):
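    """Variance adaptor from FastSpeech 2: predicts duration, pitch and
    energy, adds bucketized pitch/energy embeddings to the encoder output,
    then expands it with the length regulator. Ground-truth targets are used
    when given (training); otherwise the predictions, scaled by the
    corresponding ``*_factor``, are used (inference).
    """
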
    def __init__(self, args):
        super().__init__()
        self.args = args
        self.length_regulator = LengthRegulator()
        self.duration_predictor = VariancePredictor(args)
        self.pitch_predictor = VariancePredictor(args)
        self.energy_predictor = VariancePredictor(args)

        n_bins, steps = self.args.var_pred_n_bins, self.args.var_pred_n_bins - 1
        self.pitch_bins = torch.linspace(args.pitch_min, args.pitch_max, steps)
        self.embed_pitch = Embedding(n_bins, args.encoder_embed_dim)
        self.energy_bins = torch.linspace(args.energy_min, args.energy_max, steps)
        self.embed_energy = Embedding(n_bins, args.encoder_embed_dim)

    def get_pitch_emb(self, x, tgt=None, factor=1.0):
        out = self.pitch_predictor(x)
        bins = self.pitch_bins.to(x.device)
        if tgt is None:
            out = out * factor
            emb = self.embed_pitch(torch.bucketize(out, bins))
        else:
            emb = self.embed_pitch(torch.bucketize(tgt, bins))
        return out, emb

    def get_energy_emb(self, x, tgt=None, factor=1.0):
        out = self.energy_predictor(x)
        bins = self.energy_bins.to(x.device)
        if tgt is None:
            out = out * factor
            emb = self.embed_energy(torch.bucketize(out, bins))
        else:
            emb = self.embed_energy(torch.bucketize(tgt, bins))
        return out, emb

    def forward(
        self,
        x,
        padding_mask,
        durations=None,
        pitches=None,
        energies=None,
        d_factor=1.0,
        p_factor=1.0,
        e_factor=1.0,
    ):
        # x: B x T x C
        log_dur_out = self.duration_predictor(x)
        dur_out = torch.clamp(
            torch.round((torch.exp(log_dur_out) - 1) * d_factor).long(), min=0
        )
        dur_out.masked_fill_(padding_mask, 0)

        pitch_out, pitch_emb = self.get_pitch_emb(x, pitches, p_factor)
        x = x + pitch_emb
        energy_out, energy_emb = self.get_energy_emb(x, energies, e_factor)
        x = x + energy_emb

        x, out_lens = self.length_regulator(
            x, dur_out if durations is None else durations
        )

        return x, out_lens, log_dur_out, pitch_out, energy_out


class FastSpeech2Encoder(FairseqEncoder):
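    """Non-autoregressive network that implements the full FastSpeech 2 model:
    token + scaled positional embeddings, encoder FFT stack, optional speaker
    embedding, variance adaptor, decoder FFT stack, a linear projection to
    output frames and an optional postnet.
    """
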
    def __init__(self, args, src_dict, embed_speaker):
        super().__init__(src_dict)
        self.args = args
        self.padding_idx = src_dict.pad()
        self.n_frames_per_step = args.n_frames_per_step
        self.out_dim = args.output_frame_dim * args.n_frames_per_step

        self.embed_speaker = embed_speaker
        self.spk_emb_proj = None
        if embed_speaker is not None:
            self.spk_emb_proj = nn.Linear(
                args.encoder_embed_dim + args.speaker_embed_dim, args.encoder_embed_dim
            )

        self.dropout_module = FairseqDropout(
            p=args.dropout, module_name=self.__class__.__name__
        )
        self.embed_tokens = Embedding(
            len(src_dict), args.encoder_embed_dim, padding_idx=self.padding_idx
        )

        self.embed_positions = PositionalEmbedding(
            args.max_source_positions, args.encoder_embed_dim, self.padding_idx
        )
        self.pos_emb_alpha = nn.Parameter(torch.ones(1))
        self.dec_pos_emb_alpha = nn.Parameter(torch.ones(1))

        self.encoder_fft_layers = nn.ModuleList(
            FFTLayer(
                args.encoder_embed_dim,
                args.encoder_attention_heads,
                args.fft_hidden_dim,
                args.fft_kernel_size,
                dropout=args.dropout,
                attention_dropout=args.attention_dropout,
            )
            for _ in range(args.encoder_layers)
        )

        self.var_adaptor = VarianceAdaptor(args)

        self.decoder_fft_layers = nn.ModuleList(
            FFTLayer(
                args.decoder_embed_dim,
                args.decoder_attention_heads,
                args.fft_hidden_dim,
                args.fft_kernel_size,
                dropout=args.dropout,
                attention_dropout=args.attention_dropout,
            )
            for _ in range(args.decoder_layers)
        )

        self.out_proj = nn.Linear(args.decoder_embed_dim, self.out_dim)

        self.postnet = None
        if args.add_postnet:
            self.postnet = Postnet(
                self.out_dim,
                args.postnet_conv_dim,
                args.postnet_conv_kernel_size,
                args.postnet_layers,
                args.postnet_dropout,
            )

        self.apply(model_init)

    def forward(
        self,
        src_tokens,
        src_lengths=None,
        speaker=None,
        durations=None,
        pitches=None,
        energies=None,
        **kwargs,
    ):
        x = self.embed_tokens(src_tokens)

        enc_padding_mask = src_tokens.eq(self.padding_idx)
        x += self.pos_emb_alpha * self.embed_positions(enc_padding_mask)
        x = self.dropout_module(x)

        for layer in self.encoder_fft_layers:
            x = layer(x, enc_padding_mask)

        if self.embed_speaker is not None:
            bsz, seq_len, _ = x.size()
            emb = self.embed_speaker(speaker).expand(bsz, seq_len, -1)
            x = self.spk_emb_proj(torch.cat([x, emb], dim=2))

        x, out_lens, log_dur_out, pitch_out, energy_out = self.var_adaptor(
            x, enc_padding_mask, durations, pitches, energies
        )

        dec_padding_mask = lengths_to_padding_mask(out_lens)
        x += self.dec_pos_emb_alpha * self.embed_positions(dec_padding_mask)
        for layer in self.decoder_fft_layers:
            x = layer(x, dec_padding_mask)

        x = self.out_proj(x)
        x_post = None
        if self.postnet is not None:
            x_post = x + self.postnet(x)
        return x, x_post, out_lens, log_dur_out, pitch_out, energy_out


@register_model("fastspeech2")
class FastSpeech2Model(FairseqEncoderModel):
    """
    Implementation for https://arxiv.org/abs/2006.04558
    """

    NON_AUTOREGRESSIVE = True
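
    # Usage sketch: load a pretrained checkpoint through the hub interface
    # returned by from_pretrained(). The checkpoint name comes from
    # hub_models(); the predict() call assumes the TTSHubInterface API of
    # recent fairseq releases and may differ across versions.
    #
    #   tts = FastSpeech2Model.from_pretrained(
    #       "fastspeech2-en-ljspeech",   # resolved via hub_models()
    #       vocoder="griffin_lim",
    #   )
    #   wav, sample_rate = tts.predict("Hello, world.")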

    @classmethod
    def hub_models(cls):
        base_url = "http://dl.fbaipublicfiles.com/fairseq/s2"
        model_ids = [
            "fastspeech2-en-ljspeech",
            "fastspeech2-en-200_speaker-cv4",
        ]
        return {i: f"{base_url}/{i}.tar.gz" for i in model_ids}

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path,
        checkpoint_file="model.pt",
        data_name_or_path=".",
        config_yaml="config.yaml",
        vocoder: str = "griffin_lim",
        fp16: bool = False,
        **kwargs,
    ):
        from fairseq import hub_utils

        x = hub_utils.from_pretrained(
            model_name_or_path,
            checkpoint_file,
            data_name_or_path,
            archive_map=cls.hub_models(),
            config_yaml=config_yaml,
            vocoder=vocoder,
            fp16=fp16,
            **kwargs,
        )
        return TTSHubInterface(x["args"], x["task"], x["models"][0])

    @staticmethod
    def add_args(parser):
        parser.add_argument("--dropout", type=float)
        parser.add_argument("--output-frame-dim", type=int)
        parser.add_argument("--speaker-embed-dim", type=int)
        # FFT blocks
        parser.add_argument("--fft-hidden-dim", type=int)
        parser.add_argument("--fft-kernel-size", type=int)
        parser.add_argument("--attention-dropout", type=float)
        parser.add_argument("--encoder-layers", type=int)
        parser.add_argument("--encoder-embed-dim", type=int)
        parser.add_argument("--encoder-attention-heads", type=int)
        parser.add_argument("--decoder-layers", type=int)
        parser.add_argument("--decoder-embed-dim", type=int)
        parser.add_argument("--decoder-attention-heads", type=int)
        # variance predictor
        parser.add_argument("--var-pred-n-bins", type=int)
        parser.add_argument("--var-pred-hidden-dim", type=int)
        parser.add_argument("--var-pred-kernel-size", type=int)
        parser.add_argument("--var-pred-dropout", type=float)
        # postnet
        parser.add_argument("--add-postnet", action="store_true")
        parser.add_argument("--postnet-dropout", type=float)
        parser.add_argument("--postnet-layers", type=int)
        parser.add_argument("--postnet-conv-dim", type=int)
        parser.add_argument("--postnet-conv-kernel-size", type=int)

    def __init__(self, encoder, args, src_dict):
        super().__init__(encoder)
        self._num_updates = 0

        out_dim = args.output_frame_dim * args.n_frames_per_step
        self.ctc_proj = None
        if getattr(args, "ctc_weight", 0.0) > 0.0:
            self.ctc_proj = nn.Linear(out_dim, len(src_dict))

    @classmethod
    def build_model(cls, args, task):
        embed_speaker = task.get_speaker_embeddings(args)
        encoder = FastSpeech2Encoder(args, task.src_dict, embed_speaker)
        return cls(encoder, args, task.src_dict)

    def set_num_updates(self, num_updates):
        super().set_num_updates(num_updates)
        self._num_updates = num_updates

    def get_normalized_probs(self, net_output, log_probs, sample=None):
        logits = self.ctc_proj(net_output[0])
        if log_probs:
            return utils.log_softmax(logits.float(), dim=-1)
        else:
            return utils.softmax(logits.float(), dim=-1)


@register_model_architecture("fastspeech2", "fastspeech2")
def base_architecture(args):
    args.dropout = getattr(args, "dropout", 0.2)
    args.output_frame_dim = getattr(args, "output_frame_dim", 80)
    args.speaker_embed_dim = getattr(args, "speaker_embed_dim", 64)
    # FFT blocks
    args.fft_hidden_dim = getattr(args, "fft_hidden_dim", 1024)
    args.fft_kernel_size = getattr(args, "fft_kernel_size", 9)
    args.attention_dropout = getattr(args, "attention_dropout", 0.0)
    args.encoder_layers = getattr(args, "encoder_layers", 4)
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 2)
    args.decoder_layers = getattr(args, "decoder_layers", 4)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 256)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 2)
    # variance predictor
    args.var_pred_n_bins = getattr(args, "var_pred_n_bins", 256)
    args.var_pred_hidden_dim = getattr(args, "var_pred_hidden_dim", 256)
    args.var_pred_kernel_size = getattr(args, "var_pred_kernel_size", 3)
    args.var_pred_dropout = getattr(args, "var_pred_dropout", 0.5)
    # postnet
    args.add_postnet = getattr(args, "add_postnet", False)
    args.postnet_dropout = getattr(args, "postnet_dropout", 0.5)
    args.postnet_layers = getattr(args, "postnet_layers", 5)
    args.postnet_conv_dim = getattr(args, "postnet_conv_dim", 512)
    args.postnet_conv_kernel_size = getattr(args, "postnet_conv_kernel_size", 5)