voicevox / test /test_acoustic_feature_extractor.py
2ndelement's picture
init
f1f433f
raw
history blame
8.97 kB
import os
from pathlib import Path
from typing import List, Type
from unittest import TestCase
from voicevox_engine.acoustic_feature_extractor import (
BasePhoneme,
JvsPhoneme,
OjtPhoneme,
)
class TestBasePhoneme(TestCase):
    """Tests for BasePhoneme, plus the fixtures and lab-file helper reused by subclasses."""

    def setUp(self):
        """Build the phoneme fixture and the expected lab-file text."""
        super().setUp()
        self.str_hello_hiho = "sil k o N n i ch i w a pau h i h o d e s U sil"
        # One phoneme per token; the token at index i is given start=i, end=i + 1.
        self.base_hello_hiho = [
            BasePhoneme(s, i, i + 1) for i, s in enumerate(self.str_hello_hiho.split())
        ]
        # Expected lab-file text. Within the triple-quoted part every space
        # character is removed, then the leading and trailing "\n" are dropped.
        # NOTE(review): replace() strips *every* plain space, including any
        # between columns -- confirm the separators in this literal match what
        # save_lab_list actually writes (e.g. tabs).
        self.lab_str = """
0.00 1.00 pau
1.00 2.00 k
2.00 3.00 o
3.00 4.00 N
4.00 5.00 n
5.00 6.00 i
6.00 7.00 ch
7.00 8.00 i
8.00 9.00 w
9.00 10.00 a
10.00 11.00 pau
11.00 12.00 h
12.00 13.00 i
13.00 14.00 h
14.00 15.00 o
15.00 16.00 d
16.00 17.00 e
17.00 18.00 s
18.00 19.00 U
19.00 20.00 pau
""".replace(
            " ", ""
        )[
            1:-1
        ]

    def test_repr_(self):
        """repr() of a phoneme shows its symbol and its start/end values."""
        self.assertEqual(
            repr(self.base_hello_hiho[1]), "Phoneme(phoneme='k', start=1, end=2)"
        )
        self.assertEqual(
            repr(self.base_hello_hiho[10]),
            "Phoneme(phoneme='pau', start=10, end=11)",
        )

    def test_convert(self):
        """BasePhoneme.convert is expected to raise NotImplementedError."""
        with self.assertRaises(NotImplementedError):
            BasePhoneme.convert(self.base_hello_hiho)

    def test_duration(self):
        """The fixture phoneme built with start=1, end=2 reports a duration of 1."""
        self.assertEqual(self.base_hello_hiho[1].duration, 1)

    def test_parse(self):
        """parse() reads "<start> <end> <phoneme>" lines."""
        parse_str_1 = "0 1 pau"
        parse_str_2 = "32.67543 33.48933 e"
        parsed_base_1 = BasePhoneme.parse(parse_str_1)
        parsed_base_2 = BasePhoneme.parse(parse_str_2)
        self.assertEqual(parsed_base_1.phoneme, "pau")
        self.assertEqual(parsed_base_1.start, 0.0)
        self.assertEqual(parsed_base_1.end, 1.0)
        self.assertEqual(parsed_base_2.phoneme, "e")
        # The expected times are the inputs rounded to two decimal places.
        self.assertEqual(parsed_base_2.start, 32.68)
        self.assertEqual(parsed_base_2.end, 33.49)

    def lab_test_base(
        self,
        file_path: str,
        phonemes: List["BasePhoneme"],
        phoneme_class: Type["BasePhoneme"],
    ):
        """Round-trip ``phonemes`` through a lab file and check both directions.

        Not a test by itself (no ``test`` prefix); subclasses call it from
        their own ``test_lab_list``.

        Args:
            file_path: Location of the scratch lab file to write and read back.
            phonemes: Phonemes to save; also the expected result of loading.
            phoneme_class: Class providing ``save_lab_list`` / ``load_lab_list``.
        """
        try:
            phoneme_class.save_lab_list(phonemes, Path(file_path))
            with open(file_path, mode="r") as f:
                self.assertEqual(f.read(), self.lab_str)
            result_phoneme = phoneme_class.load_lab_list(Path(file_path))
            self.assertEqual(result_phoneme, phonemes)
        finally:
            # Clean up even when an assertion above fails, so a failed run does
            # not leave a stale file behind. The existence guard keeps a failure
            # inside save_lab_list from being masked by FileNotFoundError.
            if os.path.exists(file_path):
                os.remove(file_path)
class TestJvsPhoneme(TestBasePhoneme):
    """Tests for JvsPhoneme, built on the fixtures prepared by TestBasePhoneme."""

    def setUp(self):
        """Create the JvsPhoneme fixture from the shared phoneme string."""
        super().setUp()
        symbols = self.str_hello_hiho.split()
        raw_phonemes = [
            JvsPhoneme(symbol, index, index + 1) for index, symbol in enumerate(symbols)
        ]
        self.jvs_hello_hiho = JvsPhoneme.convert(raw_phonemes)

    def test_phoneme_list(self):
        """Spot-check a few entries of JvsPhoneme.phoneme_list."""
        expected_entries = ((1, "I"), (14, "gy"), (26, "p"), (38, "z"))
        for index, symbol in expected_entries:
            self.assertEqual(JvsPhoneme.phoneme_list[index], symbol)

    def test_const(self):
        """Check the class-level constants."""
        self.assertEqual(JvsPhoneme.num_phoneme, 39)
        self.assertEqual(JvsPhoneme.space_phoneme, "pau")

    def test_convert(self):
        """After convert(), the fixture starts and ends with "pau" instead of "sil"."""
        joined = " ".join(p.phoneme for p in self.jvs_hello_hiho)
        expected = "pau k o N n i ch i w a pau h i h o d e s U pau"
        self.assertEqual(joined, expected)

    def test_equal(self):
        """Equality depends on the symbol and both time values."""
        # Compare against the 2nd element of jvs_hello_hiho, "k".
        target = self.jvs_hello_hiho[1]
        same_as_jvs = JvsPhoneme("k", 1, 2)
        # Compare with an OjtPhoneme as well: the comparison is implemented in
        # BasePhoneme, so the result is True.
        same_as_ojt = OjtPhoneme("k", 1, 2)
        other_symbol = JvsPhoneme("a", 1, 2)
        other_times = JvsPhoneme("k", 2, 3)
        self.assertTrue(target == same_as_jvs)
        self.assertTrue(target == same_as_ojt)
        self.assertFalse(target == other_symbol)
        self.assertFalse(target == other_times)

    def test_verify(self):
        """verify() must not raise for any phoneme of the fixture."""
        for item in self.jvs_hello_hiho:
            item.verify()

    def test_phoneme_id(self):
        """phoneme_id of every fixture phoneme matches the expected sequence."""
        id_text = " ".join(str(p.phoneme_id) for p in self.jvs_hello_hiho)
        expected = "0 19 25 2 23 17 7 17 36 4 0 15 17 15 25 9 11 30 3 0"
        self.assertEqual(id_text, expected)

    def test_onehot(self):
        """onehot is True at the expected id and False at every other index."""
        expected_ids = [
            0, 19, 25, 2, 23, 17, 7, 17, 36, 4,
            0, 15, 17, 15, 25, 9, 11, 30, 3, 0,
        ]
        for position, item in enumerate(self.jvs_hello_hiho):
            for slot in range(JvsPhoneme.num_phoneme):
                # Expect True exactly at the phoneme's own id.
                self.assertEqual(item.onehot[slot], expected_ids[position] == slot)

    def test_parse(self):
        """parse() yields phonemes whose ids match the parsed symbols."""
        first = JvsPhoneme.parse("0 1 pau")
        second = JvsPhoneme.parse("15.32654 16.39454 a")
        self.assertEqual(first.phoneme_id, 0)
        self.assertEqual(second.phoneme_id, 4)

    def test_lab_list(self):
        """Save/load round trip of the fixture through a lab file."""
        self.lab_test_base("./jvs_lab_test", self.jvs_hello_hiho, JvsPhoneme)
class TestOjtPhoneme(TestBasePhoneme):
    """Tests for OjtPhoneme, built on the fixtures prepared by TestBasePhoneme."""

    def setUp(self):
        """Create the OjtPhoneme fixture from the shared phoneme string."""
        # super().setUp() already provides self.str_hello_hiho and self.lab_str.
        super().setUp()
        base_hello_hiho = [
            OjtPhoneme(s, i, i + 1) for i, s in enumerate(self.str_hello_hiho.split())
        ]
        self.ojt_hello_hiho = OjtPhoneme.convert(base_hello_hiho)

    def test_phoneme_list(self):
        """Spot-check a few entries of OjtPhoneme.phoneme_list."""
        self.assertEqual(OjtPhoneme.phoneme_list[1], "A")
        self.assertEqual(OjtPhoneme.phoneme_list[14], "e")
        self.assertEqual(OjtPhoneme.phoneme_list[26], "m")
        self.assertEqual(OjtPhoneme.phoneme_list[38], "ts")
        self.assertEqual(OjtPhoneme.phoneme_list[41], "v")

    def test_const(self):
        """Check the class-level constants."""
        self.assertEqual(OjtPhoneme.num_phoneme, 45)
        self.assertEqual(OjtPhoneme.space_phoneme, "pau")

    def test_convert(self):
        """After convert(), the fixture starts and ends with "pau" instead of "sil"."""
        ojt_str_hello_hiho = " ".join([p.phoneme for p in self.ojt_hello_hiho])
        self.assertEqual(
            ojt_str_hello_hiho, "pau k o N n i ch i w a pau h i h o d e s U pau"
        )

    def test_equal(self):
        """Equality depends on the symbol and both time values."""
        # Compare against the 10th element of ojt_hello_hiho, "a".
        true_ojt_phoneme = OjtPhoneme("a", 9, 10)
        # Compare with a JvsPhoneme as well: the comparison is implemented in
        # BasePhoneme, so the result is True.
        true_jvs_phoneme = JvsPhoneme("a", 9, 10)
        false_ojt_phoneme_1 = OjtPhoneme("k", 9, 10)
        false_ojt_phoneme_2 = OjtPhoneme("a", 10, 11)
        self.assertTrue(self.ojt_hello_hiho[9] == true_ojt_phoneme)
        self.assertTrue(self.ojt_hello_hiho[9] == true_jvs_phoneme)
        self.assertFalse(self.ojt_hello_hiho[9] == false_ojt_phoneme_1)
        self.assertFalse(self.ojt_hello_hiho[9] == false_ojt_phoneme_2)

    def test_verify(self):
        """verify() must not raise for any phoneme of the fixture."""
        for phoneme in self.ojt_hello_hiho:
            phoneme.verify()

    def test_phoneme_id(self):
        """phoneme_id of every fixture phoneme matches the expected sequence."""
        ojt_str_hello_hiho = " ".join([str(p.phoneme_id) for p in self.ojt_hello_hiho])
        self.assertEqual(
            ojt_str_hello_hiho, "0 23 30 4 28 21 10 21 42 7 0 19 21 19 30 12 14 35 6 0"
        )

    def test_onehot(self):
        """onehot is True at the expected id and False at every other index."""
        phoneme_id_list = [
            0,
            23,
            30,
            4,
            28,
            21,
            10,
            21,
            42,
            7,
            0,
            19,
            21,
            19,
            30,
            12,
            14,
            35,
            6,
            0,
        ]
        for i, phoneme in enumerate(self.ojt_hello_hiho):
            for j in range(OjtPhoneme.num_phoneme):
                if phoneme_id_list[i] == j:
                    self.assertEqual(phoneme.onehot[j], True)
                else:
                    self.assertEqual(phoneme.onehot[j], False)

    def test_parse(self):
        """parse() yields phonemes whose ids match the parsed symbols."""
        parse_str_1 = "0 1 pau"
        parse_str_2 = "32.67543 33.48933 e"
        parsed_ojt_1 = OjtPhoneme.parse(parse_str_1)
        parsed_ojt_2 = OjtPhoneme.parse(parse_str_2)
        self.assertEqual(parsed_ojt_1.phoneme_id, 0)
        self.assertEqual(parsed_ojt_2.phoneme_id, 14)

    def test_lab_list(self):
        """Save/load round trip of the fixture through a lab file.

        The method name must start with ``test`` for unittest to collect it;
        the former spelling ``tes_lab_list`` meant this check never ran.
        """
        self.lab_test_base("./ojt_lab_test", self.ojt_hello_hiho, OjtPhoneme)