Spaces:
Build error
Build error
import os | |
from pathlib import Path | |
from typing import List, Type | |
from unittest import TestCase | |
from voicevox_engine.acoustic_feature_extractor import ( | |
BasePhoneme, | |
JvsPhoneme, | |
OjtPhoneme, | |
) | |
class TestBasePhoneme(TestCase):
    """Tests for BasePhoneme, plus the fixtures and lab-file helper that the
    phoneme-specific test classes in this file reuse through inheritance."""

    def setUp(self):
        """Create the shared phoneme sequence and the expected lab-file text."""
        super().setUp()
        self.str_hello_hiho = "sil k o N n i ch i w a pau h i h o d e s U sil"
        # One phoneme per token, occupying the interval [index, index + 1).
        self.base_hello_hiho = [
            BasePhoneme(s, i, i + 1) for i, s in enumerate(self.str_hello_hiho.split())
        ]
        # Within the triple-quoted part, every occurrence of the replace()
        # argument is removed, then the leading and trailing "\n" are sliced off.
        # NOTE(review): as written here, replace(" ", "") also removes the
        # spaces between the three columns of each row. Confirm the intended
        # column separator and replace() argument against what save_lab_list
        # actually writes.
        self.lab_str = """
0.00 1.00 pau
1.00 2.00 k
2.00 3.00 o
3.00 4.00 N
4.00 5.00 n
5.00 6.00 i
6.00 7.00 ch
7.00 8.00 i
8.00 9.00 w
9.00 10.00 a
10.00 11.00 pau
11.00 12.00 h
12.00 13.00 i
13.00 14.00 h
14.00 15.00 o
15.00 16.00 d
16.00 17.00 e
17.00 18.00 s
18.00 19.00 U
19.00 20.00 pau
""".replace(
            " ", ""
        )[
            1:-1
        ]

    def test_repr_(self):
        """repr() of a phoneme shows its phoneme, start and end."""
        self.assertEqual(
            repr(self.base_hello_hiho[1]), "Phoneme(phoneme='k', start=1, end=2)"
        )
        self.assertEqual(
            repr(self.base_hello_hiho[10]),
            "Phoneme(phoneme='pau', start=10, end=11)",
        )

    def test_convert(self):
        """BasePhoneme.convert is expected to raise NotImplementedError."""
        with self.assertRaises(NotImplementedError):
            BasePhoneme.convert(self.base_hello_hiho)

    def test_duration(self):
        """A phoneme built with start=1, end=2 is expected to have duration 1."""
        self.assertEqual(self.base_hello_hiho[1].duration, 1)

    def test_parse(self):
        """parse() turns a "start end phoneme" string into a phoneme object.

        The second case expects start/end of 32.68 / 33.49 for the inputs
        32.67543 / 33.48933.
        """
        parse_str_1 = "0 1 pau"
        parse_str_2 = "32.67543 33.48933 e"

        parsed_base_1 = BasePhoneme.parse(parse_str_1)
        parsed_base_2 = BasePhoneme.parse(parse_str_2)

        self.assertEqual(parsed_base_1.phoneme, "pau")
        self.assertEqual(parsed_base_1.start, 0.0)
        self.assertEqual(parsed_base_1.end, 1.0)

        self.assertEqual(parsed_base_2.phoneme, "e")
        self.assertEqual(parsed_base_2.start, 32.68)
        self.assertEqual(parsed_base_2.end, 33.49)

    def lab_test_base(
        self,
        file_path: str,
        phonemes: List["BasePhoneme"],
        phoneme_class: Type["BasePhoneme"],
    ):
        """Round-trip ``phonemes`` through a lab file and check both directions.

        Not collected as a test itself (no ``test`` prefix); subclasses call it.

        Args:
            file_path: Location of the scratch lab file to write and read back.
            phonemes: Phoneme list handed to ``save_lab_list``.
            phoneme_class: Class providing ``save_lab_list`` / ``load_lab_list``.
        """
        lab_path = Path(file_path)
        try:
            phoneme_class.save_lab_list(phonemes, lab_path)
            with open(file_path, mode="r") as f:
                self.assertEqual(f.read(), self.lab_str)
            result_phoneme = phoneme_class.load_lab_list(lab_path)
            self.assertEqual(result_phoneme, phonemes)
        finally:
            # Remove the scratch file even when an assertion above fails, so a
            # failing run does not leave it behind. The existence check keeps
            # an error raised before the file was created from being masked.
            if os.path.exists(file_path):
                os.remove(file_path)
class TestJvsPhoneme(TestBasePhoneme):
    """Tests for JvsPhoneme, reusing the fixtures set up by TestBasePhoneme."""

    def setUp(self):
        """Build the JvsPhoneme sequence by converting the shared token string."""
        super().setUp()
        tokens = self.str_hello_hiho.split()
        raw_phonemes = [
            JvsPhoneme(token, index, index + 1) for index, token in enumerate(tokens)
        ]
        self.jvs_hello_hiho = JvsPhoneme.convert(raw_phonemes)

    def test_phoneme_list(self):
        """Spot-check a few entries of JvsPhoneme.phoneme_list by index."""
        expected_entries = ((1, "I"), (14, "gy"), (26, "p"), (38, "z"))
        for index, symbol in expected_entries:
            self.assertEqual(JvsPhoneme.phoneme_list[index], symbol)

    def test_const(self):
        """Check the class-level constants."""
        self.assertEqual(JvsPhoneme.num_phoneme, 39)
        self.assertEqual(JvsPhoneme.space_phoneme, "pau")

    def test_convert(self):
        """After convert(), the leading and trailing "sil" are expected as "pau"."""
        symbols = [item.phoneme for item in self.jvs_hello_hiho]
        self.assertEqual(
            " ".join(symbols), "pau k o N n i ch i w a pau h i h o d e s U pau"
        )

    def test_equal(self):
        """Exercise == between phonemes, including across phoneme classes."""
        target = self.jvs_hello_hiho[1]  # the second element, "k"

        same_as_jvs = JvsPhoneme("k", 1, 2)
        # Compared against an OjtPhoneme: the comparison is implemented in
        # BasePhoneme, so the result is expected to be True.
        same_as_ojt = OjtPhoneme("k", 1, 2)
        other_symbol = JvsPhoneme("a", 1, 2)
        other_interval = JvsPhoneme("k", 2, 3)

        self.assertTrue(target == same_as_jvs)
        self.assertTrue(target == same_as_ojt)
        self.assertFalse(target == other_symbol)
        self.assertFalse(target == other_interval)

    def test_verify(self):
        """verify() is expected to complete without raising for each phoneme."""
        for item in self.jvs_hello_hiho:
            item.verify()

    def test_phoneme_id(self):
        """Check the phoneme_id of every element in the converted sequence."""
        id_text = " ".join(str(item.phoneme_id) for item in self.jvs_hello_hiho)
        self.assertEqual(
            id_text, "0 19 25 2 23 17 7 17 36 4 0 15 17 15 25 9 11 30 3 0"
        )

    def test_onehot(self):
        """onehot[j] is expected truthy only at the phoneme's own id."""
        expected_ids = [
            0, 19, 25, 2, 23, 17, 7, 17, 36, 4,
            0, 15, 17, 15, 25, 9, 11, 30, 3, 0,
        ]  # fmt: skip
        for position, item in enumerate(self.jvs_hello_hiho):
            hot_index = expected_ids[position]
            for slot in range(JvsPhoneme.num_phoneme):
                self.assertEqual(item.onehot[slot], hot_index == slot)

    def test_parse(self):
        """parse() results are expected to carry the matching phoneme_id."""
        silence = JvsPhoneme.parse("0 1 pau")
        vowel_a = JvsPhoneme.parse("15.32654 16.39454 a")
        self.assertEqual(silence.phoneme_id, 0)
        self.assertEqual(vowel_a.phoneme_id, 4)

    def test_lab_list(self):
        """Round-trip the sequence through a lab file via the inherited helper."""
        self.lab_test_base("./jvs_lab_test", self.jvs_hello_hiho, JvsPhoneme)
class TestOjtPhoneme(TestBasePhoneme):
    """Tests for OjtPhoneme, reusing the fixtures set up by TestBasePhoneme."""

    def setUp(self):
        """Build the OjtPhoneme sequence by converting the shared token string.

        ``self.str_hello_hiho`` is provided by the parent ``setUp``.
        """
        super().setUp()
        base_hello_hiho = [
            OjtPhoneme(s, i, i + 1) for i, s in enumerate(self.str_hello_hiho.split())
        ]
        self.ojt_hello_hiho = OjtPhoneme.convert(base_hello_hiho)

    def test_phoneme_list(self):
        """Spot-check a few entries of OjtPhoneme.phoneme_list by index."""
        self.assertEqual(OjtPhoneme.phoneme_list[1], "A")
        self.assertEqual(OjtPhoneme.phoneme_list[14], "e")
        self.assertEqual(OjtPhoneme.phoneme_list[26], "m")
        self.assertEqual(OjtPhoneme.phoneme_list[38], "ts")
        self.assertEqual(OjtPhoneme.phoneme_list[41], "v")

    def test_const(self):
        """Check the class-level constants."""
        self.assertEqual(OjtPhoneme.num_phoneme, 45)
        self.assertEqual(OjtPhoneme.space_phoneme, "pau")

    def test_convert(self):
        """After convert(), the leading and trailing "sil" are expected as "pau"."""
        ojt_str_hello_hiho = " ".join([p.phoneme for p in self.ojt_hello_hiho])
        self.assertEqual(
            ojt_str_hello_hiho, "pau k o N n i ch i w a pau h i h o d e s U pau"
        )

    def test_equal(self):
        """Exercise == between phonemes, including across phoneme classes."""
        # Compare against the 10th element of ojt_hello_hiho, "a".
        true_ojt_phoneme = OjtPhoneme("a", 9, 10)
        # Compared against a JvsPhoneme: the comparison is implemented in
        # BasePhoneme, so the result is expected to be True.
        true_jvs_phoneme = JvsPhoneme("a", 9, 10)
        false_ojt_phoneme_1 = OjtPhoneme("k", 9, 10)
        false_ojt_phoneme_2 = OjtPhoneme("a", 10, 11)
        self.assertTrue(self.ojt_hello_hiho[9] == true_ojt_phoneme)
        self.assertTrue(self.ojt_hello_hiho[9] == true_jvs_phoneme)
        self.assertFalse(self.ojt_hello_hiho[9] == false_ojt_phoneme_1)
        self.assertFalse(self.ojt_hello_hiho[9] == false_ojt_phoneme_2)

    def test_verify(self):
        """verify() is expected to complete without raising for each phoneme."""
        for phoneme in self.ojt_hello_hiho:
            phoneme.verify()

    def test_phoneme_id(self):
        """Check the phoneme_id of every element in the converted sequence."""
        ojt_str_hello_hiho = " ".join([str(p.phoneme_id) for p in self.ojt_hello_hiho])
        self.assertEqual(
            ojt_str_hello_hiho, "0 23 30 4 28 21 10 21 42 7 0 19 21 19 30 12 14 35 6 0"
        )

    def test_onehot(self):
        """onehot[j] is expected truthy only at the phoneme's own id."""
        phoneme_id_list = [
            0, 23, 30, 4, 28, 21, 10, 21, 42, 7,
            0, 19, 21, 19, 30, 12, 14, 35, 6, 0,
        ]  # fmt: skip
        for i, phoneme in enumerate(self.ojt_hello_hiho):
            for j in range(OjtPhoneme.num_phoneme):
                if phoneme_id_list[i] == j:
                    self.assertEqual(phoneme.onehot[j], True)
                else:
                    self.assertEqual(phoneme.onehot[j], False)

    def test_parse(self):
        """parse() results are expected to carry the matching phoneme_id."""
        parse_str_1 = "0 1 pau"
        parse_str_2 = "32.67543 33.48933 e"
        parsed_ojt_1 = OjtPhoneme.parse(parse_str_1)
        parsed_ojt_2 = OjtPhoneme.parse(parse_str_2)
        self.assertEqual(parsed_ojt_1.phoneme_id, 0)
        self.assertEqual(parsed_ojt_2.phoneme_id, 14)

    def test_lab_list(self):
        """Round-trip the sequence through a lab file via the inherited helper.

        The method name must start with ``test`` for unittest to collect it.
        """
        self.lab_test_base("./ojt_lab_test", self.ojt_hello_hiho, OjtPhoneme)