File size: 1,908 Bytes
8273cb9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import unittest
from pathlib import Path
import torch
# Public bucket hosting the fairseq speech test fixtures; each dataset lives in
# a subdirectory appended to this base URL.
S3_BASE_URL = "https://dl.fbaipublicfiles.com/fairseq"


class TestFairseqSpeech(unittest.TestCase):
    """Base test case that fetches and caches small speech test datasets.

    Subclasses can call one of the ``set_up_*`` helpers to download a
    dataset's files into ``~/.cache/fairseq/<dataset>`` and switch the current
    working directory to that cache directory.
    """

    @classmethod
    def download(cls, base_url: str, out_root: Path, filename: str) -> Path:
        """Download ``{base_url}/{filename}`` into ``out_root`` unless cached.

        Args:
            base_url: URL prefix that ``filename`` is appended to with ``/``.
            out_root: Local directory to store the file in; created if missing.
            filename: Remote file name, also used as the local file name.

        Returns:
            The local path ``out_root / filename``. An existing file at that
            path is reused as-is and nothing is downloaded.
        """
        url = f"{base_url}/{filename}"
        path = out_root / filename
        if not path.exists():
            # Make the helper usable on its own, not only after a set_up_*
            # call has already created the cache directory.
            out_root.mkdir(parents=True, exist_ok=True)
            torch.hub.download_url_to_file(url, path.as_posix(), progress=True)
        return path

    def _set_up(self, dataset_id: str, s3_dir: str, data_filenames: list):
        """Fetch one dataset's files into its cache dir and ``chdir`` there.

        Sets ``use_cuda``, ``root``, ``data_filenames`` and ``base_url`` on
        the instance.

        Args:
            dataset_id: Cache subdirectory name under ``~/.cache/fairseq``.
            s3_dir: Dataset directory relative to ``S3_BASE_URL``.
            data_filenames: Names of the files to download from that directory.
        """
        self.use_cuda = torch.cuda.is_available()
        self.root = Path.home() / ".cache" / "fairseq" / dataset_id
        self.root.mkdir(exist_ok=True, parents=True)
        # NOTE: changes the process-wide working directory and never restores
        # it; presumably so the fixtures can be referenced by relative path
        # (e.g. from the cfg yaml) -- confirm before changing.
        os.chdir(self.root)
        self.data_filenames = list(data_filenames)
        self.base_url = f"{S3_BASE_URL}/{s3_dir}"
        for filename in self.data_filenames:
            self.download(self.base_url, self.root, filename)

    def set_up_librispeech(self):
        """Download the LibriSpeech ``test-other`` fixtures (``s2t/librispeech``)."""
        self._set_up(
            "librispeech",
            "s2t/librispeech",
            [
                "cfg_librispeech.yaml",
                "spm_librispeech_unigram10000.model",
                "spm_librispeech_unigram10000.txt",
                "librispeech_test-other.tsv",
                "librispeech_test-other.zip",
            ],
        )

    def set_up_ljspeech(self):
        """Download the LJSpeech G2P test fixtures (``s2/ljspeech``)."""
        self._set_up(
            "ljspeech",
            "s2/ljspeech",
            [
                "cfg_ljspeech_g2p.yaml",
                "ljspeech_g2p_gcmvn_stats.npz",
                "ljspeech_g2p.txt",
                "ljspeech_test.tsv",
                "ljspeech_test.zip",
            ],
        )
|