Spaces:
Build error
Build error
import json | |
from copy import deepcopy | |
from pathlib import Path | |
from tempfile import TemporaryDirectory | |
from typing import Dict | |
from unittest import TestCase | |
from fastapi import HTTPException | |
from pyopenjtalk import g2p, unset_user_dict | |
from voicevox_engine.model import UserDictWord, WordTypes | |
from voicevox_engine.part_of_speech_data import MAX_PRIORITY, part_of_speech_data | |
from voicevox_engine.user_dict import ( | |
apply_word, | |
create_word, | |
delete_word, | |
import_user_dict, | |
read_dict, | |
rewrite_word, | |
update_dict, | |
) | |
# Valid user-dictionary data in the form that is saved to the JSON file.
# Unlike the API form, the entry stores an internal ``cost`` (taken from the
# proper-noun cost candidates) rather than a ``priority``.
valid_dict_dict_json = {
    "aab7dda2-0d97-43c8-8cb7-3f440dab9b4e": dict(
        surface="test",
        cost=part_of_speech_data[WordTypes.PROPER_NOUN].cost_candidates[5],
        part_of_speech="名詞",
        part_of_speech_detail_1="固有名詞",
        part_of_speech_detail_2="一般",
        part_of_speech_detail_3="*",
        inflectional_type="*",
        inflectional_form="*",
        stem="*",
        yomi="テスト",
        pronunciation="テスト",
        accent_type=1,
        accent_associative_rule="*",
    ),
}
# Valid user-dictionary data in the form exchanged over the API: the same word
# as ``valid_dict_dict_json``, with the internal ``cost`` swapped for a
# user-facing ``priority`` (presumably the priority that maps to
# ``cost_candidates[5]`` — TODO confirm).
valid_dict_dict_api = deepcopy(valid_dict_dict_json)
valid_dict_dict_api["aab7dda2-0d97-43c8-8cb7-3f440dab9b4e"].pop("cost")
valid_dict_dict_api["aab7dda2-0d97-43c8-8cb7-3f440dab9b4e"].update(priority=5)
# A valid word, distinct from the one in ``valid_dict_dict_json``, used as the
# payload in the import_user_dict tests.
import_word = UserDictWord(
    # Surface form and its reading.
    surface="test2",
    yomi="テストツー",
    pronunciation="テストツー",
    # Accent information.
    accent_type=1,
    accent_associative_rule="*",
    # Part of speech: proper noun, general.
    part_of_speech="名詞",
    part_of_speech_detail_1="固有名詞",
    part_of_speech_detail_2="一般",
    part_of_speech_detail_3="*",
    # Inflection fields are left unset ("*").
    inflectional_type="*",
    inflectional_form="*",
    stem="*",
    priority=5,
)
def get_new_word(user_dict: Dict[str, UserDictWord]) -> UserDictWord:
    """Return the word in *user_dict* that is not the pre-seeded fixture word.

    The fixture word ``aab7dda2-0d97-43c8-8cb7-3f440dab9b4e`` (the entry of
    ``valid_dict_dict_json``) is skipped and the first other word is returned.

    Args:
        user_dict: Mapping of word UUID to word. Expected to hold either two
            entries, or a single entry that is not the fixture word.

    Returns:
        The first entry whose UUID differs from the fixture UUID.

    Raises:
        AssertionError: If *user_dict* does not have one of the expected
            shapes, or if it contains no non-fixture word.
    """
    fixture_uuid = "aab7dda2-0d97-43c8-8cb7-3f440dab9b4e"
    assert len(user_dict) == 2 or (
        len(user_dict) == 1 and fixture_uuid not in user_dict
    ), f"unexpected user dictionary contents: {list(user_dict)}"
    for word_uuid, word in user_dict.items():
        if word_uuid != fixture_uuid:
            return word
    # Only reachable when assertions are stripped (python -O) and the
    # dictionary is empty or holds nothing but the fixture word.
    raise AssertionError("no newly added word found in the user dictionary")
class TestUserDict(TestCase):
    """Tests for the user-dictionary functions of ``voicevox_engine.user_dict``.

    Every test gets its own temporary directory for the JSON user dictionary
    and the compiled dictionary file, so tests do not share state on disk.
    """

    # UUID of the single word stored in ``valid_dict_dict_json``.
    _EXISTING_WORD_UUID = "aab7dda2-0d97-43c8-8cb7-3f440dab9b4e"
    # UUID that does not appear in ``valid_dict_dict_json``.
    _UNKNOWN_WORD_UUID = "c2be4dc5-d07d-4767-8be1-04a1bb3f05a9"
    # UUID under which ``import_word`` is imported as a brand-new entry.
    _NEW_WORD_UUID = "b1affe2a-d5f0-4050-926c-f28e0c1d9a98"

    def setUp(self):
        self.tmp_dir = TemporaryDirectory()
        self.tmp_dir_path = Path(self.tmp_dir.name)

    def tearDown(self):
        # Unset any user dictionary the test loaded into pyopenjtalk. The
        # try/finally guarantees the temporary directory is still removed if
        # unset_user_dict() raises.
        try:
            unset_user_dict()
        finally:
            self.tmp_dir.cleanup()

    def _write_valid_dict(self, file_name: str) -> Path:
        """Write ``valid_dict_dict_json`` to *file_name* inside the temp dir.

        Returns:
            Path of the written JSON user dictionary.
        """
        user_dict_path = self.tmp_dir_path / file_name
        user_dict_path.write_text(
            json.dumps(valid_dict_dict_json, ensure_ascii=False), encoding="utf-8"
        )
        return user_dict_path

    def test_read_not_exist_json(self):
        # A user dictionary file that does not exist reads as an empty dict.
        self.assertEqual(
            read_dict(user_dict_path=(self.tmp_dir_path / "not_exist.json")),
            {},
        )

    def test_create_word(self):
        # Add more cases here when more parts of speech etc. are supported.
        # With only surface/pronunciation/accent_type given, create_word is
        # expected to produce a general proper noun with priority 5.
        self.assertEqual(
            create_word(surface="test", pronunciation="テスト", accent_type=1),
            UserDictWord(
                surface="test",
                priority=5,
                part_of_speech="名詞",
                part_of_speech_detail_1="固有名詞",
                part_of_speech_detail_2="一般",
                part_of_speech_detail_3="*",
                inflectional_type="*",
                inflectional_form="*",
                stem="*",
                yomi="テスト",
                pronunciation="テスト",
                accent_type=1,
                accent_associative_rule="*",
            ),
        )

    def test_apply_word_without_json(self):
        # Starting without a JSON file, apply_word should leave exactly one
        # word in the dictionary.
        user_dict_path = self.tmp_dir_path / "test_apply_word_without_json.json"
        apply_word(
            surface="test",
            pronunciation="テスト",
            accent_type=1,
            user_dict_path=user_dict_path,
            compiled_dict_path=(self.tmp_dir_path / "test_apply_word_without_json.dic"),
        )
        res = read_dict(user_dict_path=user_dict_path)
        self.assertEqual(len(res), 1)
        new_word = get_new_word(res)
        self.assertEqual(
            (new_word.surface, new_word.pronunciation, new_word.accent_type),
            ("test", "テスト", 1),
        )

    def test_apply_word_with_json(self):
        # Starting from the one-word fixture, apply_word adds a second word.
        user_dict_path = self._write_valid_dict("test_apply_word_with_json.json")
        apply_word(
            surface="test2",
            pronunciation="テストツー",
            accent_type=3,
            user_dict_path=user_dict_path,
            compiled_dict_path=(self.tmp_dir_path / "test_apply_word_with_json.dic"),
        )
        res = read_dict(user_dict_path=user_dict_path)
        self.assertEqual(len(res), 2)
        new_word = get_new_word(res)
        self.assertEqual(
            (new_word.surface, new_word.pronunciation, new_word.accent_type),
            ("test2", "テストツー", 3),
        )

    def test_rewrite_word_invalid_id(self):
        # Rewriting a UUID that is not in the dictionary raises HTTPException.
        user_dict_path = self._write_valid_dict("test_rewrite_word_invalid_id.json")
        with self.assertRaises(HTTPException):
            rewrite_word(
                word_uuid=self._UNKNOWN_WORD_UUID,
                surface="test2",
                pronunciation="テストツー",
                accent_type=2,
                user_dict_path=user_dict_path,
                compiled_dict_path=(
                    self.tmp_dir_path / "test_rewrite_word_invalid_id.dic"
                ),
            )

    def test_rewrite_word_valid_id(self):
        # Rewriting an existing UUID replaces that word's fields in place.
        user_dict_path = self._write_valid_dict("test_rewrite_word_valid_id.json")
        rewrite_word(
            word_uuid=self._EXISTING_WORD_UUID,
            surface="test2",
            pronunciation="テストツー",
            accent_type=2,
            user_dict_path=user_dict_path,
            compiled_dict_path=(self.tmp_dir_path / "test_rewrite_word_valid_id.dic"),
        )
        new_word = read_dict(user_dict_path=user_dict_path)[self._EXISTING_WORD_UUID]
        self.assertEqual(
            (new_word.surface, new_word.pronunciation, new_word.accent_type),
            ("test2", "テストツー", 2),
        )

    def test_delete_word_invalid_id(self):
        # Deleting a UUID that is not in the dictionary raises HTTPException.
        user_dict_path = self._write_valid_dict("test_delete_word_invalid_id.json")
        with self.assertRaises(HTTPException):
            delete_word(
                word_uuid=self._UNKNOWN_WORD_UUID,
                user_dict_path=user_dict_path,
                compiled_dict_path=(
                    self.tmp_dir_path / "test_delete_word_invalid_id.dic"
                ),
            )

    def test_delete_word_valid_id(self):
        # Deleting the fixture's only word leaves an empty dictionary.
        user_dict_path = self._write_valid_dict("test_delete_word_valid_id.json")
        delete_word(
            word_uuid=self._EXISTING_WORD_UUID,
            user_dict_path=user_dict_path,
            compiled_dict_path=(self.tmp_dir_path / "test_delete_word_valid_id.dic"),
        )
        self.assertEqual(len(read_dict(user_dict_path=user_dict_path)), 0)

    def test_priority(self):
        # Every priority in 0..MAX_PRIORITY must round-trip through create_word
        # for every word type. subTest reports the failing combination.
        for pos in part_of_speech_data:
            for priority in range(MAX_PRIORITY + 1):
                with self.subTest(word_type=pos, priority=priority):
                    self.assertEqual(
                        create_word(
                            surface="test",
                            pronunciation="テスト",
                            accent_type=1,
                            word_type=pos,
                            priority=priority,
                        ).priority,
                        priority,
                    )

    def test_import_dict(self):
        # Importing a word under a new UUID adds it and keeps the existing one.
        user_dict_path = self._write_valid_dict("test_import_dict.json")
        compiled_dict_path = self.tmp_dir_path / "test_import_dict.dic"
        import_user_dict(
            {self._NEW_WORD_UUID: import_word},
            override=False,
            user_dict_path=user_dict_path,
            compiled_dict_path=compiled_dict_path,
        )
        user_dict = read_dict(user_dict_path)
        self.assertEqual(user_dict[self._NEW_WORD_UUID], import_word)
        self.assertEqual(
            user_dict[self._EXISTING_WORD_UUID],
            UserDictWord(**valid_dict_dict_api[self._EXISTING_WORD_UUID]),
        )

    def test_import_dict_no_override(self):
        # With override=False, an existing entry with the same UUID is kept.
        user_dict_path = self._write_valid_dict("test_import_dict_no_override.json")
        compiled_dict_path = self.tmp_dir_path / "test_import_dict_no_override.dic"
        import_user_dict(
            {self._EXISTING_WORD_UUID: import_word},
            override=False,
            user_dict_path=user_dict_path,
            compiled_dict_path=compiled_dict_path,
        )
        self.assertEqual(
            read_dict(user_dict_path)[self._EXISTING_WORD_UUID],
            UserDictWord(**valid_dict_dict_api[self._EXISTING_WORD_UUID]),
        )

    def test_import_dict_override(self):
        # With override=True, the existing entry is replaced by the import.
        user_dict_path = self._write_valid_dict("test_import_dict_override.json")
        compiled_dict_path = self.tmp_dir_path / "test_import_dict_override.dic"
        import_user_dict(
            {self._EXISTING_WORD_UUID: import_word},
            override=True,
            user_dict_path=user_dict_path,
            compiled_dict_path=compiled_dict_path,
        )
        self.assertEqual(
            read_dict(user_dict_path)[self._EXISTING_WORD_UUID],
            import_word,
        )

    def test_import_invalid_word(self):
        user_dict_path = self._write_valid_dict("test_import_invalid_dict.json")
        compiled_dict_path = self.tmp_dir_path / "test_import_invalid_dict.dic"

        # An unknown accent associative rule is expected to be rejected with
        # AssertionError.
        invalid_accent_associative_rule_word = deepcopy(import_word)
        invalid_accent_associative_rule_word.accent_associative_rule = "invalid"
        with self.assertRaises(AssertionError):
            import_user_dict(
                {self._EXISTING_WORD_UUID: invalid_accent_associative_rule_word},
                override=True,
                user_dict_path=user_dict_path,
                compiled_dict_path=compiled_dict_path,
            )

        # A word whose part-of-speech fields / context_id combination is not
        # accepted (here フィラー with context_id=2) is expected to be rejected
        # with ValueError.
        invalid_pos_word = deepcopy(import_word)
        invalid_pos_word.context_id = 2
        invalid_pos_word.part_of_speech = "フィラー"
        invalid_pos_word.part_of_speech_detail_1 = "*"
        invalid_pos_word.part_of_speech_detail_2 = "*"
        invalid_pos_word.part_of_speech_detail_3 = "*"
        with self.assertRaises(ValueError):
            import_user_dict(
                {self._EXISTING_WORD_UUID: invalid_pos_word},
                override=True,
                user_dict_path=user_dict_path,
                compiled_dict_path=compiled_dict_path,
            )

    def test_update_dict(self):
        user_dict_path = self.tmp_dir_path / "test_update_dict.json"
        compiled_dict_path = self.tmp_dir_path / "test_update_dict.dic"
        update_dict(
            user_dict_path=user_dict_path, compiled_dict_path=compiled_dict_path
        )
        test_text = "テスト用の文字列"
        # A reading chosen so that the default dictionary would never produce it.
        success_pronunciation = "デフォルトノジショデハゼッタイニセイセイサレナイヨミ"

        # Make sure the word is not already registered in the dictionary.
        self.assertNotEqual(g2p(text=test_text, kana=True), success_pronunciation)

        apply_word(
            surface=test_text,
            pronunciation=success_pronunciation,
            accent_type=1,
            priority=10,
            user_dict_path=user_dict_path,
            compiled_dict_path=compiled_dict_path,
        )
        self.assertEqual(g2p(text=test_text, kana=True), success_pronunciation)

        # Simulate an engine restart: unset the loaded dictionary, then reload
        # it via update_dict from the same paths.
        unset_user_dict()
        update_dict(
            user_dict_path=user_dict_path, compiled_dict_path=compiled_dict_path
        )
        self.assertEqual(g2p(text=test_text, kana=True), success_pronunciation)