import os from pathlib import Path from typing import List, Type from unittest import TestCase from voicevox_engine.acoustic_feature_extractor import ( BasePhoneme, JvsPhoneme, OjtPhoneme, ) class TestBasePhoneme(TestCase): def setUp(self): super().setUp() self.str_hello_hiho = "sil k o N n i ch i w a pau h i h o d e s U sil" self.base_hello_hiho = [ BasePhoneme(s, i, i + 1) for i, s in enumerate(self.str_hello_hiho.split()) ] self.lab_str = """ 0.00 1.00 pau 1.00 2.00 k 2.00 3.00 o 3.00 4.00 N 4.00 5.00 n 5.00 6.00 i 6.00 7.00 ch 7.00 8.00 i 8.00 9.00 w 9.00 10.00 a 10.00 11.00 pau 11.00 12.00 h 12.00 13.00 i 13.00 14.00 h 14.00 15.00 o 15.00 16.00 d 16.00 17.00 e 17.00 18.00 s 18.00 19.00 U 19.00 20.00 pau """.replace( " ", "" )[ 1:-1 ] # ダブルクオーテーションx3で囲われている部分で、空白をすべて置き換え、先頭と最後の"\n"を除外する def test_repr_(self): self.assertEqual( self.base_hello_hiho[1].__repr__(), "Phoneme(phoneme='k', start=1, end=2)" ) self.assertEqual( self.base_hello_hiho[10].__repr__(), "Phoneme(phoneme='pau', start=10, end=11)", ) def test_convert(self): with self.assertRaises(NotImplementedError): BasePhoneme.convert(self.base_hello_hiho) def test_duration(self): self.assertEqual(self.base_hello_hiho[1].duration, 1) def test_parse(self): parse_str_1 = "0 1 pau" parse_str_2 = "32.67543 33.48933 e" parsed_base_1 = BasePhoneme.parse(parse_str_1) parsed_base_2 = BasePhoneme.parse(parse_str_2) self.assertEqual(parsed_base_1.phoneme, "pau") self.assertEqual(parsed_base_1.start, 0.0) self.assertEqual(parsed_base_1.end, 1.0) self.assertEqual(parsed_base_2.phoneme, "e") self.assertEqual(parsed_base_2.start, 32.68) self.assertEqual(parsed_base_2.end, 33.49) def lab_test_base( self, file_path: str, phonemes: List["BasePhoneme"], phoneme_class: Type["BasePhoneme"], ): phoneme_class.save_lab_list(phonemes, Path(file_path)) with open(file_path, mode="r") as f: self.assertEqual(f.read(), self.lab_str) result_phoneme = phoneme_class.load_lab_list(Path(file_path)) self.assertEqual(result_phoneme, phonemes) os.remove(file_path) class TestJvsPhoneme(TestBasePhoneme): def setUp(self): super().setUp() base_hello_hiho = [ JvsPhoneme(s, i, i + 1) for i, s in enumerate(self.str_hello_hiho.split()) ] self.jvs_hello_hiho = JvsPhoneme.convert(base_hello_hiho) def test_phoneme_list(self): self.assertEqual(JvsPhoneme.phoneme_list[1], "I") self.assertEqual(JvsPhoneme.phoneme_list[14], "gy") self.assertEqual(JvsPhoneme.phoneme_list[26], "p") self.assertEqual(JvsPhoneme.phoneme_list[38], "z") def test_const(self): self.assertEqual(JvsPhoneme.num_phoneme, 39) self.assertEqual(JvsPhoneme.space_phoneme, "pau") def test_convert(self): converted_str_hello_hiho = " ".join([p.phoneme for p in self.jvs_hello_hiho]) self.assertEqual( converted_str_hello_hiho, "pau k o N n i ch i w a pau h i h o d e s U pau" ) def test_equal(self): # jvs_hello_hihoの2番目の"k"と比較 true_jvs_phoneme = JvsPhoneme("k", 1, 2) # OjtPhonemeと比べる、比較はBasePhoneme内で実装されているので、比較結果はTrue true_ojt_phoneme = OjtPhoneme("k", 1, 2) false_jvs_phoneme_1 = JvsPhoneme("a", 1, 2) false_jvs_phoneme_2 = JvsPhoneme("k", 2, 3) self.assertTrue(self.jvs_hello_hiho[1] == true_jvs_phoneme) self.assertTrue(self.jvs_hello_hiho[1] == true_ojt_phoneme) self.assertFalse(self.jvs_hello_hiho[1] == false_jvs_phoneme_1) self.assertFalse(self.jvs_hello_hiho[1] == false_jvs_phoneme_2) def test_verify(self): for phoneme in self.jvs_hello_hiho: phoneme.verify() def test_phoneme_id(self): jvs_str_hello_hiho = " ".join([str(p.phoneme_id) for p in self.jvs_hello_hiho]) self.assertEqual( jvs_str_hello_hiho, "0 19 25 2 23 17 7 17 36 4 0 15 17 15 25 9 11 30 3 0" ) def test_onehot(self): phoneme_id_list = [ 0, 19, 25, 2, 23, 17, 7, 17, 36, 4, 0, 15, 17, 15, 25, 9, 11, 30, 3, 0, ] for i, phoneme in enumerate(self.jvs_hello_hiho): for j in range(JvsPhoneme.num_phoneme): if phoneme_id_list[i] == j: self.assertEqual(phoneme.onehot[j], True) else: self.assertEqual(phoneme.onehot[j], False) def test_parse(self): parse_str_1 = "0 1 pau" parse_str_2 = "15.32654 16.39454 a" parsed_jvs_1 = JvsPhoneme.parse(parse_str_1) parsed_jvs_2 = JvsPhoneme.parse(parse_str_2) self.assertEqual(parsed_jvs_1.phoneme_id, 0) self.assertEqual(parsed_jvs_2.phoneme_id, 4) def test_lab_list(self): self.lab_test_base("./jvs_lab_test", self.jvs_hello_hiho, JvsPhoneme) class TestOjtPhoneme(TestBasePhoneme): def setUp(self): super().setUp() self.str_hello_hiho = "sil k o N n i ch i w a pau h i h o d e s U sil" base_hello_hiho = [ OjtPhoneme(s, i, i + 1) for i, s in enumerate(self.str_hello_hiho.split()) ] self.ojt_hello_hiho = OjtPhoneme.convert(base_hello_hiho) def test_phoneme_list(self): self.assertEqual(OjtPhoneme.phoneme_list[1], "A") self.assertEqual(OjtPhoneme.phoneme_list[14], "e") self.assertEqual(OjtPhoneme.phoneme_list[26], "m") self.assertEqual(OjtPhoneme.phoneme_list[38], "ts") self.assertEqual(OjtPhoneme.phoneme_list[41], "v") def test_const(self): self.assertEqual(OjtPhoneme.num_phoneme, 45) self.assertEqual(OjtPhoneme.space_phoneme, "pau") def test_convert(self): ojt_str_hello_hiho = " ".join([p.phoneme for p in self.ojt_hello_hiho]) self.assertEqual( ojt_str_hello_hiho, "pau k o N n i ch i w a pau h i h o d e s U pau" ) def test_equal(self): # ojt_hello_hihoの10番目の"a"と比較 true_ojt_phoneme = OjtPhoneme("a", 9, 10) # JvsPhonemeと比べる、比較はBasePhoneme内で実装されているので、比較結果はTrue true_jvs_phoneme = JvsPhoneme("a", 9, 10) false_ojt_phoneme_1 = OjtPhoneme("k", 9, 10) false_ojt_phoneme_2 = OjtPhoneme("a", 10, 11) self.assertTrue(self.ojt_hello_hiho[9] == true_ojt_phoneme) self.assertTrue(self.ojt_hello_hiho[9] == true_jvs_phoneme) self.assertFalse(self.ojt_hello_hiho[9] == false_ojt_phoneme_1) self.assertFalse(self.ojt_hello_hiho[9] == false_ojt_phoneme_2) def test_verify(self): for phoneme in self.ojt_hello_hiho: phoneme.verify() def test_phoneme_id(self): ojt_str_hello_hiho = " ".join([str(p.phoneme_id) for p in self.ojt_hello_hiho]) self.assertEqual( ojt_str_hello_hiho, "0 23 30 4 28 21 10 21 42 7 0 19 21 19 30 12 14 35 6 0" ) def test_onehot(self): phoneme_id_list = [ 0, 23, 30, 4, 28, 21, 10, 21, 42, 7, 0, 19, 21, 19, 30, 12, 14, 35, 6, 0, ] for i, phoneme in enumerate(self.ojt_hello_hiho): for j in range(OjtPhoneme.num_phoneme): if phoneme_id_list[i] == j: self.assertEqual(phoneme.onehot[j], True) else: self.assertEqual(phoneme.onehot[j], False) def test_parse(self): parse_str_1 = "0 1 pau" parse_str_2 = "32.67543 33.48933 e" parsed_ojt_1 = OjtPhoneme.parse(parse_str_1) parsed_ojt_2 = OjtPhoneme.parse(parse_str_2) self.assertEqual(parsed_ojt_1.phoneme_id, 0) self.assertEqual(parsed_ojt_2.phoneme_id, 14) def tes_lab_list(self): self.lab_test_base("./ojt_lab_test", self.ojt_hello_hiho, OjtPhoneme)