File size: 2,404 Bytes
8273cb9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import unittest
import torch
from tqdm import tqdm
from fairseq import utils
from fairseq.checkpoint_utils import load_model_ensemble_and_task
from fairseq.tasks.text_to_speech import batch_mel_cepstral_distortion
from tests.speech import TestFairseqSpeech
@unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
class TestFastSpeech2(TestFairseqSpeech):
    """Regression test for the ``ljspeech_fastspeech2_g2p.pt`` checkpoint.

    Synthesizes the ``ljspeech_test`` split with the checkpoint and checks
    that the mean mel cepstral distortion (MCD) between generated and target
    waveforms stays close to a recorded reference value.
    """

    def setUp(self):
        # Fixture setup is delegated to the shared speech-test base class;
        # presumably it provides self.base_url / self.root / self.use_cuda /
        # self.download used below (TODO confirm in tests.speech).
        self.set_up_ljspeech()

    @torch.no_grad()
    def test_ljspeech_fastspeech2_checkpoint(self):
        """Check the mean MCD of the checkpoint on the LJSpeech test split.

        Fetches ``ljspeech_fastspeech2_g2p.pt`` via ``self.download``, loads
        it with the ``griffin_lim`` vocoder and fp16 disabled, generates
        waveforms for every batch of the test split, and asserts that the
        summed MCD divided by the number of utterances (rounded to one
        decimal) is within 0.1 of the reference value.
        """
        checkpoint_filename = "ljspeech_fastspeech2_g2p.pt"
        path = self.download(self.base_url, self.root, checkpoint_filename)
        models, cfg, task = load_model_ensemble_and_task(
            [path.as_posix()],
            arg_overrides={
                "data": self.root.as_posix(),
                "config_yaml": "cfg_ljspeech_g2p.yaml",
                "vocoder": "griffin_lim",
                "fp16": False,
            },
        )
        if self.use_cuda:
            for model in models:
                model.cuda()

        test_split = "ljspeech_test"
        task.load_dataset(test_split)
        batch_iterator = task.get_batch_iterator(
            dataset=task.dataset(test_split),
            max_tokens=65_536,
            max_positions=4_096,
            num_workers=1,
        ).next_epoch_itr(shuffle=False)
        progress = tqdm(batch_iterator, total=len(batch_iterator))
        generator = task.build_generator(models, cfg)

        mcd, n_samples = 0.0, 0
        for sample in progress:
            sample = utils.move_to_cuda(sample) if self.use_cuda else sample
            hypos = generator.generate(models[0], sample, has_targ=True)
            # Each entry of `rets` is unpacked as (distance, extra); only the
            # distance tensor is accumulated here.
            rets = batch_mel_cepstral_distortion(
                [hypo["targ_waveform"] for hypo in hypos],
                [hypo["waveform"] for hypo in hypos],
                sr=task.sr,
            )
            mcd += sum(d.item() for d, _ in rets)
            # Count utterances from the id tensor's first dimension directly,
            # rather than copying it into a Python list just to take its length.
            n_samples += len(sample["id"])

        # Fail with a clear message instead of a ZeroDivisionError when the
        # split produced no samples.
        self.assertGreater(
            n_samples, 0, f"no samples found in split '{test_split}'"
        )
        mcd = round(mcd / n_samples, 1)
        reference_mcd = 3.2
        print(f"MCD: {mcd} (reference: {reference_mcd})")
        # `mcd` is rounded to a 0.1 grid, so an exact 0.1 delta is at the mercy
        # of binary float error: abs(3.1 - 3.2) evaluates to
        # 0.10000000000000009 (rejected) while abs(3.3 - 3.2) evaluates to
        # 0.09999999999999964 (accepted). The tiny slack makes the intended
        # +/-0.1 window symmetric.
        self.assertAlmostEqual(mcd, reference_mcd, delta=0.1 + 1e-6)
# Script entry point: discover and run the test cases defined in this module.
if "__main__" == __name__:
    unittest.main(module="__main__")
|