# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import os.path as op

import torch
import torch.nn.functional as F
import numpy as np

from fairseq.data.audio.text_to_speech_dataset import TextToSpeechDatasetCreator
from fairseq.tasks import register_task
from fairseq.tasks.speech_to_text import SpeechToTextTask
from fairseq.speech_generator import (
    AutoRegressiveSpeechGenerator,
    NonAutoregressiveSpeechGenerator,
    TeacherForcingAutoRegressiveSpeechGenerator,
)


logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

try:
    from tensorboardX import SummaryWriter
except ImportError:
    logger.info("Please install tensorboardX: pip install tensorboardX")
    SummaryWriter = None


@register_task("text_to_speech")
class TextToSpeechTask(SpeechToTextTask):
    @staticmethod
    def add_args(parser):
        parser.add_argument("data", help="manifest root path")
        parser.add_argument(
            "--config-yaml",
            type=str,
            default="config.yaml",
            help="Configuration YAML filename (under manifest root)",
        )
        parser.add_argument(
            "--max-source-positions",
            default=1024,
            type=int,
            metavar="N",
            help="max number of tokens in the source sequence",
        )
        parser.add_argument(
            "--max-target-positions",
            default=1200,
            type=int,
            metavar="N",
            help="max number of tokens in the target sequence",
        )
        parser.add_argument("--n-frames-per-step", type=int, default=1)
        parser.add_argument("--eos-prob-threshold", type=float, default=0.5)
        parser.add_argument("--eval-inference", action="store_true")
        parser.add_argument("--eval-tb-nsample", type=int, default=8)
        parser.add_argument("--vocoder", type=str, default="griffin_lim")
        parser.add_argument("--spec-bwd-max-iter", type=int, default=8)

    def __init__(self, args, src_dict):
        super().__init__(args, src_dict)
        self.src_dict = src_dict
        self.sr = self.data_cfg.config.get("features").get("sample_rate")

        self.tensorboard_writer = None
        self.tensorboard_dir = ""
        if args.tensorboard_logdir and SummaryWriter is not None:
            self.tensorboard_dir = os.path.join(args.tensorboard_logdir, "valid_extra")

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        is_train_split = split.startswith("train")
        pre_tokenizer = self.build_tokenizer(self.args)
        bpe_tokenizer = self.build_bpe(self.args)
        self.datasets[split] = TextToSpeechDatasetCreator.from_tsv(
            self.args.data,
            self.data_cfg,
            split,
            self.src_dict,
            pre_tokenizer,
            bpe_tokenizer,
            is_train_split=is_train_split,
            epoch=epoch,
            seed=self.args.seed,
            n_frames_per_step=self.args.n_frames_per_step,
            speaker_to_id=self.speaker_to_id,
        )

    @property
    def target_dictionary(self):
        return None

    @property
    def source_dictionary(self):
        return self.src_dict

    def get_speaker_embeddings_path(self):
        speaker_emb_path = None
        if self.data_cfg.config.get("speaker_emb_filename") is not None:
            speaker_emb_path = op.join(
                self.args.data, self.data_cfg.config.get("speaker_emb_filename")
            )
        return speaker_emb_path

    @classmethod
    def get_speaker_embeddings(cls, args):
        embed_speaker = None
        if args.speaker_to_id is not None:
            if args.speaker_emb_path is None:
                embed_speaker = torch.nn.Embedding(
                    len(args.speaker_to_id), args.speaker_embed_dim
                )
            else:
                speaker_emb_mat = np.load(args.speaker_emb_path)
                assert speaker_emb_mat.shape[1] == args.speaker_embed_dim
                embed_speaker = torch.nn.Embedding.from_pretrained(
                    torch.from_numpy(speaker_emb_mat),
                    freeze=True,
                )
                logger.info(
                    f"load speaker embeddings from {args.speaker_emb_path}. "
                    f"train embedding? {embed_speaker.weight.requires_grad}\n"
                    f"embeddings:\n{speaker_emb_mat}"
                )
        return embed_speaker
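
    # Example (sketch, hypothetical path): the matrix at args.speaker_emb_path
    # is expected to be a (num_speakers, speaker_embed_dim) float array, e.g.
    #
    #   np.save("speaker_emb.npy", np.random.rand(10, 64).astype(np.float32))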

    def build_model(self, cfg, from_checkpoint=False):
        cfg.pitch_min = self.data_cfg.config["features"].get("pitch_min", None)
        cfg.pitch_max = self.data_cfg.config["features"].get("pitch_max", None)
        cfg.energy_min = self.data_cfg.config["features"].get("energy_min", None)
        cfg.energy_max = self.data_cfg.config["features"].get("energy_max", None)
        cfg.speaker_emb_path = self.get_speaker_embeddings_path()
        model = super().build_model(cfg, from_checkpoint)
        self.generator = None
        if getattr(cfg, "eval_inference", False):
            self.generator = self.build_generator([model], cfg)
        return model

    def build_generator(self, models, cfg, vocoder=None, **unused):
        if vocoder is None:
            vocoder = self.build_default_vocoder()
        model = models[0]
        if getattr(model, "NON_AUTOREGRESSIVE", False):
            return NonAutoregressiveSpeechGenerator(model, vocoder, self.data_cfg)
        else:
            generator = AutoRegressiveSpeechGenerator
            if getattr(cfg, "teacher_forcing", False):
                generator = TeacherForcingAutoRegressiveSpeechGenerator
                logger.info("Teacher forcing mode for generation")
            return generator(
                model,
                vocoder,
                self.data_cfg,
                max_iter=self.args.max_target_positions,
                eos_prob_threshold=self.args.eos_prob_threshold,
            )

    def build_default_vocoder(self):
        from fairseq.models.text_to_speech.vocoder import get_vocoder

        vocoder = get_vocoder(self.args, self.data_cfg)
        if torch.cuda.is_available() and not self.args.cpu:
            vocoder = vocoder.cuda()
        else:
            vocoder = vocoder.cpu()
        return vocoder

    def valid_step(self, sample, model, criterion):
        loss, sample_size, logging_output = super().valid_step(sample, model, criterion)

        if getattr(self.args, "eval_inference", False):
            hypos, inference_losses = self.valid_step_with_inference(
                sample, model, self.generator
            )
            for k, v in inference_losses.items():
                assert k not in logging_output
                logging_output[k] = v

            picked_id = 0
            if self.tensorboard_dir and (sample["id"] == picked_id).any():
                self.log_tensorboard(
                    sample,
                    hypos[: self.args.eval_tb_nsample],
                    model._num_updates,
                    is_na_model=getattr(model, "NON_AUTOREGRESSIVE", False),
                )
        return loss, sample_size, logging_output

    def valid_step_with_inference(self, sample, model, generator):
        hypos = generator.generate(model, sample, has_targ=True)

        losses = {
            "mcd_loss": 0.0,
            "targ_frames": 0.0,
            "pred_frames": 0.0,
            "nins": 0.0,
            "ndel": 0.0,
        }
        rets = batch_mel_cepstral_distortion(
            [hypo["targ_waveform"] for hypo in hypos],
            [hypo["waveform"] for hypo in hypos],
            self.sr,
            normalize_type=None,
        )
        for d, extra in rets:
            pathmap = extra[-1]
            losses["mcd_loss"] += d.item()
            losses["targ_frames"] += pathmap.size(0)
            losses["pred_frames"] += pathmap.size(1)
            # pathmap is the binary (targ_frames, pred_frames) DTW alignment;
            # row/column sums in excess of one count inserted/deleted frames
            losses["nins"] += (pathmap.sum(dim=1) - 1).sum().item()
            losses["ndel"] += (pathmap.sum(dim=0) - 1).sum().item()

        return hypos, losses

    def log_tensorboard(self, sample, hypos, num_updates, is_na_model=False):
        if self.tensorboard_writer is None:
            self.tensorboard_writer = SummaryWriter(self.tensorboard_dir)
        tb_writer = self.tensorboard_writer
        for b in range(len(hypos)):
            idx = sample["id"][b]
            text = sample["src_texts"][b]
            targ = hypos[b]["targ_feature"]
            pred = hypos[b]["feature"]
            attn = hypos[b]["attn"]

            if is_na_model:
                data = plot_tts_output(
                    [targ.transpose(0, 1), pred.transpose(0, 1)],
                    [f"target (idx={idx})", "output"],
                    attn,
                    "alignment",
                    ret_np=True,
                    suptitle=text,
                )
            else:
                eos_prob = hypos[b]["eos_prob"]
                data = plot_tts_output(
                    [targ.transpose(0, 1), pred.transpose(0, 1), attn],
                    [f"target (idx={idx})", "output", "alignment"],
                    eos_prob,
                    "eos prob",
                    ret_np=True,
                    suptitle=text,
                )

            tb_writer.add_image(
                f"inference_sample_{b}", data, num_updates, dataformats="HWC"
            )

            if hypos[b]["waveform"] is not None:
                targ_wave = hypos[b]["targ_waveform"].detach().cpu().float()
                pred_wave = hypos[b]["waveform"].detach().cpu().float()
                tb_writer.add_audio(
                    f"inference_targ_{b}", targ_wave, num_updates, sample_rate=self.sr
                )
                tb_writer.add_audio(
                    f"inference_pred_{b}", pred_wave, num_updates, sample_rate=self.sr
                )
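

# Example usage (a sketch; the flags mirror add_args above and the standard
# fairseq CLI, with ${DATA_ROOT} holding the manifest TSVs and config.yaml;
# the usual model/criterion flags still need to be supplied):
#
#   fairseq-train ${DATA_ROOT} --task text_to_speech \
#       --config-yaml config.yaml --eval-inference --vocoder griffin_lim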


def save_figure_to_numpy(fig):
    # np.fromstring is deprecated for binary input; np.frombuffer is the
    # documented replacement and reads the rendered RGB canvas the same way
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    return data


DEFAULT_V_MIN = np.log(1e-5)


def plot_tts_output(
    data_2d,
    title_2d,
    data_1d,
    title_1d,
    figsize=(24, 4),
    v_min=DEFAULT_V_MIN,
    v_max=3,
    ret_np=False,
    suptitle="",
):
    try:
        import matplotlib.pyplot as plt
        from mpl_toolkits.axes_grid1 import make_axes_locatable
    except ImportError:
        raise ImportError("Please install Matplotlib: pip install matplotlib")

    data_2d = [
        x.detach().cpu().float().numpy() if isinstance(x, torch.Tensor) else x
        for x in data_2d
    ]

    fig, axes = plt.subplots(1, len(data_2d) + 1, figsize=figsize)
    if suptitle:
        fig.suptitle(suptitle[:400])  # capped at 400 chars
    axes = [axes] if len(data_2d) == 0 else axes
    for ax, x, name in zip(axes, data_2d, title_2d):
        ax.set_title(name)
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad=0.05)
        im = ax.imshow(
            x,
            origin="lower",
            aspect="auto",
            vmin=max(x.min(), v_min),
            vmax=min(x.max(), v_max),
        )
        fig.colorbar(im, cax=cax, orientation="vertical")

    if isinstance(data_1d, torch.Tensor):
        data_1d = data_1d.detach().cpu().numpy()
    axes[-1].plot(data_1d)
    axes[-1].set_title(title_1d)
    plt.tight_layout()

    if ret_np:
        fig.canvas.draw()
        data = save_figure_to_numpy(fig)
        plt.close(fig)
        return data
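

# Example (sketch): render target/predicted spectrograms plus an EOS-probability
# curve into an HxWx3 uint8 image, e.g. for tb_writer.add_image above.
#
#   targ = torch.randn(80, 100)  # (n_mels, frames)
#   pred = torch.randn(80, 90)
#   eos_prob = torch.sigmoid(torch.randn(90))
#   img = plot_tts_output(
#       [targ, pred], ["target", "output"], eos_prob, "eos prob", ret_np=True
#   )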


def antidiag_indices(offset, min_i=0, max_i=None, min_j=0, max_j=None):
    """
    for a (3, 4) matrix with min_i=1, max_i=3, min_j=1, max_j=4, outputs

    offset=2 (1, 1),
    offset=3 (2, 1), (1, 2)
    offset=4 (2, 2), (1, 3)
    offset=5 (2, 3)

    constraints:
        i + j = offset
        min_j <= j < max_j
        min_i <= offset - j < max_i
    """
    if max_i is None:
        max_i = offset + 1
    if max_j is None:
        max_j = offset + 1
    min_j = max(min_j, offset - max_i + 1, 0)
    max_j = min(max_j, offset - min_i + 1, offset + 1)
    j = torch.arange(min_j, max_j)
    i = offset - j
    return torch.stack([i, j])
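

# Example: the docstring case above at offset=3 (interior of a (3, 4) matrix)
# yields the anti-diagonal cells (2, 1) and (1, 2), stacked as (i-row, j-row):
#
#   antidiag_indices(3, 1, 3, 1, 4)  # tensor([[2, 1], [1, 2]])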


def batch_dynamic_time_warping(distance, shapes=None):
    """full batched DTW without any constraints

    distance: (batchsize, max_M, max_N) matrix
    shapes: (batchsize,) vector specifying (M, N) for each entry
    """
    # ptr: 0=left, 1=up-left, 2=up
    ptr2dij = {0: (0, -1), 1: (-1, -1), 2: (-1, 0)}

    bsz, m, n = distance.size()
    cumdist = torch.zeros_like(distance)
    backptr = torch.zeros_like(distance).type(torch.int32) - 1

    # initialize: the first row/column can only be reached by moving
    # right/down, so their cumulative costs are prefix sums
    cumdist[:, 0, :] = distance[:, 0, :].cumsum(dim=-1)
    cumdist[:, :, 0] = distance[:, :, 0].cumsum(dim=-1)
    backptr[:, 0, :] = 0
    backptr[:, :, 0] = 2

    # DP with optimized anti-diagonal parallelization, O(M+N) steps: all cells
    # on one anti-diagonal depend only on earlier diagonals, so each diagonal
    # is updated in a single vectorized step
    for offset in range(2, m + n - 1):
        ind = antidiag_indices(offset, 1, m, 1, n)
        c = torch.stack(
            [
                cumdist[:, ind[0], ind[1] - 1],
                cumdist[:, ind[0] - 1, ind[1] - 1],
                cumdist[:, ind[0] - 1, ind[1]],
            ],
            dim=2,
        )
        v, b = c.min(dim=-1)
        backptr[:, ind[0], ind[1]] = b.int()
        cumdist[:, ind[0], ind[1]] = v + distance[:, ind[0], ind[1]]

    # backtrace from the bottom-right corner of each (possibly padded) entry
    pathmap = torch.zeros_like(backptr)
    for b in range(bsz):
        i = m - 1 if shapes is None else (shapes[b][0] - 1).item()
        j = n - 1 if shapes is None else (shapes[b][1] - 1).item()
        dtwpath = [(i, j)]
        while (i != 0 or j != 0) and len(dtwpath) < 10000:
            assert i >= 0 and j >= 0
            di, dj = ptr2dij[backptr[b, i, j].item()]
            i, j = i + di, j + dj
            dtwpath.append((i, j))
        dtwpath = dtwpath[::-1]
        indices = torch.from_numpy(np.array(dtwpath))
        pathmap[b, indices[:, 0], indices[:, 1]] = 1
    return cumdist, backptr, pathmap
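

# Example (sketch): DTW over a toy (1, 3, 3) distance matrix whose diagonal is
# zero; the recovered alignment is the identity path along the main diagonal.
#
#   dist = torch.tensor([[[0.0, 2.0, 4.0], [2.0, 0.0, 2.0], [4.0, 2.0, 0.0]]])
#   cumdist, backptr, pathmap = batch_dynamic_time_warping(dist)
#   pathmap[0]  # identity alignment: ones on the diagonal, zeros elsewhere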


def compute_l2_dist(x1, x2):
    """compute an (m, n) squared L2 distance matrix from (m, d) and (n, d) matrices"""
    return torch.cdist(x1.unsqueeze(0), x2.unsqueeze(0), p=2).squeeze(0).pow(2)


def compute_rms_dist(x1, x2):
    l2_dist = compute_l2_dist(x1, x2)
    return (l2_dist / x1.size(1)).pow(0.5)
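

# Example (sketch): pairwise RMS distance between two MFCC-like sequences.
#
#   x1 = torch.randn(10, 13)        # (m, d)
#   x2 = torch.randn(12, 13)        # (n, d)
#   compute_rms_dist(x1, x2).shape  # torch.Size([10, 12])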


def get_divisor(pathmap, normalize_type):
    if normalize_type is None:
        return 1
    elif normalize_type == "len1":
        return pathmap.size(0)
    elif normalize_type == "len2":
        return pathmap.size(1)
    elif normalize_type == "path":
        return pathmap.sum().item()
    else:
        raise ValueError(f"normalize_type {normalize_type} not supported")


def batch_compute_distortion(y1, y2, sr, feat_fn, dist_fn, normalize_type):
    d, s, x1, x2 = [], [], [], []
    for cur_y1, cur_y2 in zip(y1, y2):
        assert cur_y1.ndim == 1 and cur_y2.ndim == 1
        cur_x1 = feat_fn(cur_y1)
        cur_x2 = feat_fn(cur_y2)
        x1.append(cur_x1)
        x2.append(cur_x2)

        cur_d = dist_fn(cur_x1, cur_x2)
        d.append(cur_d)
        s.append(d[-1].size())
    # pad all pairwise distance matrices to a common (max_m, max_n) shape so
    # DTW can run over the whole batch at once
    max_m = max(ss[0] for ss in s)
    max_n = max(ss[1] for ss in s)
    d = torch.stack(
        [F.pad(dd, (0, max_n - dd.size(1), 0, max_m - dd.size(0))) for dd in d]
    )
    s = torch.LongTensor(s).to(d.device)
    cumdists, backptrs, pathmaps = batch_dynamic_time_warping(d, s)

    rets = []
    itr = zip(s, x1, x2, d, cumdists, backptrs, pathmaps)
    for (m, n), cur_x1, cur_x2, dist, cumdist, backptr, pathmap in itr:
        # crop the padding back off before reading out the distortion
        cumdist = cumdist[:m, :n]
        backptr = backptr[:m, :n]
        pathmap = pathmap[:m, :n]
        divisor = get_divisor(pathmap, normalize_type)

        distortion = cumdist[-1, -1] / divisor
        ret = distortion, (cur_x1, cur_x2, dist, cumdist, backptr, pathmap)
        rets.append(ret)
    return rets


def batch_mel_cepstral_distortion(y1, y2, sr, normalize_type="path", mfcc_fn=None):
    """
    https://arxiv.org/pdf/2011.03568.pdf

    The root mean squared error computed on 13-dimensional MFCC using DTW for
    alignment. MFCC features are computed from an 80-channel log-mel
    spectrogram using a 50ms Hann window and hop of 12.5ms.

    y1: list of waveforms
    y2: list of waveforms
    sr: sampling rate
    """
    try:
        import torchaudio
    except ImportError:
        raise ImportError("Please install torchaudio: pip install torchaudio")

    if mfcc_fn is None or mfcc_fn.sample_rate != sr:
        melkwargs = {
            "n_fft": int(0.05 * sr),
            "win_length": int(0.05 * sr),
            "hop_length": int(0.0125 * sr),
            "f_min": 20,
            "n_mels": 80,
            "window_fn": torch.hann_window,
        }
        mfcc_fn = torchaudio.transforms.MFCC(
            sr, n_mfcc=13, log_mels=True, melkwargs=melkwargs
        ).to(y1[0].device)
    return batch_compute_distortion(
        y1,
        y2,
        sr,
        lambda y: mfcc_fn(y).transpose(-1, -2),
        compute_rms_dist,
        normalize_type,
    )
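

if __name__ == "__main__":
    # Minimal smoke test (a sketch): MCD between two random 1-second
    # "waveforms" at 16kHz. Real usage would pass reference and predicted
    # audio (1-D tensors), as valid_step_with_inference does above.
    _sr = 16000
    _y1 = [torch.randn(_sr)]
    _y2 = [torch.randn(_sr)]
    _rets = batch_mel_cepstral_distortion(_y1, _y2, _sr)
    _mcd, (_, _, _, _, _, _pathmap) = _rets[0]
    print(f"MCD: {_mcd.item():.3f}, path length: {int(_pathmap.sum().item())}")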