# Source: voicevox test/test_user_dict.py — commit f1f433f ("init") by 2ndelement
import json
from copy import deepcopy
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Dict
from unittest import TestCase
from fastapi import HTTPException
from pyopenjtalk import g2p, unset_user_dict
from voicevox_engine.model import UserDictWord, WordTypes
from voicevox_engine.part_of_speech_data import MAX_PRIORITY, part_of_speech_data
from voicevox_engine.user_dict import (
apply_word,
create_word,
delete_word,
import_user_dict,
read_dict,
rewrite_word,
update_dict,
)
# Correctly formatted user dictionary data, in the shape it is saved to JSON.
# Keys are word UUIDs; each word stores an internal "cost" instead of a priority.
valid_dict_dict_json = {
    "aab7dda2-0d97-43c8-8cb7-3f440dab9b4e": dict(
        surface="test",
        cost=part_of_speech_data[WordTypes.PROPER_NOUN].cost_candidates[5],
        part_of_speech="名詞",
        part_of_speech_detail_1="固有名詞",
        part_of_speech_detail_2="一般",
        part_of_speech_detail_3="*",
        inflectional_type="*",
        inflectional_form="*",
        stem="*",
        yomi="テスト",
        pronunciation="テスト",
        accent_type=1,
        accent_associative_rule="*",
    ),
}

# The same data in the shape exchanged through the API: the "cost" field is
# dropped and a "priority" of 5 is appended in its place (the tests expect
# cost_candidates[5] above to read back as priority 5).
valid_dict_dict_api = {
    word_uuid: {
        **{field: value for field, value in word.items() if field != "cost"},
        "priority": 5,
    }
    for word_uuid, word in valid_dict_dict_json.items()
}
# A well-formed word that the import tests hand to import_user_dict, either
# under a fresh UUID or under the UUID of the existing fixture word.
import_word = UserDictWord(
    **{
        "surface": "test2",
        "yomi": "テストツー",
        "pronunciation": "テストツー",
        "accent_type": 1,
        "accent_associative_rule": "*",
        "priority": 5,
        "part_of_speech": "名詞",
        "part_of_speech_detail_1": "固有名詞",
        "part_of_speech_detail_2": "一般",
        "part_of_speech_detail_3": "*",
        "inflectional_type": "*",
        "inflectional_form": "*",
        "stem": "*",
    }
)
def get_new_word(user_dict: Dict[str, UserDictWord]) -> UserDictWord:
    """Return the single word in *user_dict* that is not the fixture word.

    The fixture word (UUID ``aab7dda2-0d97-43c8-8cb7-3f440dab9b4e``) may or
    may not be present; besides it, the dictionary must hold exactly one word.

    Raises:
        AssertionError: if the dictionary does not contain exactly one word
            other than the fixture word.
    """
    fixture_uuid = "aab7dda2-0d97-43c8-8cb7-3f440dab9b4e"
    new_words = [
        word for word_uuid, word in user_dict.items() if word_uuid != fixture_uuid
    ]
    # Explicit raise instead of `assert` so the check survives `python -O`,
    # and so a dict with two unknown words is rejected rather than returning
    # an arbitrary one.
    if len(new_words) != 1:
        raise AssertionError(
            f"expected exactly one word besides the fixture, found {len(new_words)}"
        )
    return new_words[0]
class TestUserDict(TestCase):
    """Tests for the user dictionary functions of ``voicevox_engine.user_dict``.

    Each test works inside its own temporary directory, so the JSON dictionary
    and compiled ``.dic`` files it writes are discarded afterwards.
    """

    # UUID of the single word stored in ``valid_dict_dict_json``.
    VALID_WORD_UUID = "aab7dda2-0d97-43c8-8cb7-3f440dab9b4e"
    # A UUID that does not appear in ``valid_dict_dict_json``.
    UNKNOWN_WORD_UUID = "c2be4dc5-d07d-4767-8be1-04a1bb3f05a9"

    def setUp(self):
        self.tmp_dir = TemporaryDirectory()
        # Registered as a cleanup (run after tearDown, even if tearDown
        # raises) so the temporary directory is never leaked.
        self.addCleanup(self.tmp_dir.cleanup)
        self.tmp_dir_path = Path(self.tmp_dir.name)

    def tearDown(self):
        # Unload the user dictionary from pyopenjtalk; test_update_dict shows
        # that loaded words change g2p output, so it must not carry over.
        unset_user_dict()

    def _write_valid_dict_json(self, user_dict_path: Path) -> None:
        """Write ``valid_dict_dict_json`` to *user_dict_path* as UTF-8 JSON."""
        user_dict_path.write_text(
            json.dumps(valid_dict_dict_json, ensure_ascii=False), encoding="utf-8"
        )

    def test_read_not_exist_json(self):
        """Reading a missing dictionary file yields an empty dictionary."""
        self.assertEqual(
            read_dict(user_dict_path=(self.tmp_dir_path / "not_exist.json")),
            {},
        )

    def test_create_word(self):
        """create_word fills in the proper-noun defaults."""
        # Add more cases here once more parts of speech etc. are supported.
        self.assertEqual(
            create_word(surface="test", pronunciation="テスト", accent_type=1),
            UserDictWord(
                surface="test",
                priority=5,
                part_of_speech="名詞",
                part_of_speech_detail_1="固有名詞",
                part_of_speech_detail_2="一般",
                part_of_speech_detail_3="*",
                inflectional_type="*",
                inflectional_form="*",
                stem="*",
                yomi="テスト",
                pronunciation="テスト",
                accent_type=1,
                accent_associative_rule="*",
            ),
        )

    def test_apply_word_without_json(self):
        """apply_word creates the dictionary file when it does not exist."""
        user_dict_path = self.tmp_dir_path / "test_apply_word_without_json.json"
        apply_word(
            surface="test",
            pronunciation="テスト",
            accent_type=1,
            user_dict_path=user_dict_path,
            compiled_dict_path=(self.tmp_dir_path / "test_apply_word_without_json.dic"),
        )
        res = read_dict(user_dict_path=user_dict_path)
        self.assertEqual(len(res), 1)
        new_word = get_new_word(res)
        self.assertEqual(
            (
                new_word.surface,
                new_word.pronunciation,
                new_word.accent_type,
            ),
            ("test", "テスト", 1),
        )

    def test_apply_word_with_json(self):
        """apply_word appends to an existing dictionary file."""
        user_dict_path = self.tmp_dir_path / "test_apply_word_with_json.json"
        self._write_valid_dict_json(user_dict_path)
        apply_word(
            surface="test2",
            pronunciation="テストツー",
            accent_type=3,
            user_dict_path=user_dict_path,
            compiled_dict_path=(self.tmp_dir_path / "test_apply_word_with_json.dic"),
        )
        res = read_dict(user_dict_path=user_dict_path)
        self.assertEqual(len(res), 2)
        new_word = get_new_word(res)
        self.assertEqual(
            (
                new_word.surface,
                new_word.pronunciation,
                new_word.accent_type,
            ),
            ("test2", "テストツー", 3),
        )

    def test_rewrite_word_invalid_id(self):
        """rewrite_word raises HTTPException for an unknown word UUID."""
        user_dict_path = self.tmp_dir_path / "test_rewrite_word_invalid_id.json"
        self._write_valid_dict_json(user_dict_path)
        with self.assertRaises(HTTPException):
            rewrite_word(
                word_uuid=self.UNKNOWN_WORD_UUID,
                surface="test2",
                pronunciation="テストツー",
                accent_type=2,
                user_dict_path=user_dict_path,
                compiled_dict_path=(
                    self.tmp_dir_path / "test_rewrite_word_invalid_id.dic"
                ),
            )

    def test_rewrite_word_valid_id(self):
        """rewrite_word replaces the fields of an existing word."""
        user_dict_path = self.tmp_dir_path / "test_rewrite_word_valid_id.json"
        self._write_valid_dict_json(user_dict_path)
        rewrite_word(
            word_uuid=self.VALID_WORD_UUID,
            surface="test2",
            pronunciation="テストツー",
            accent_type=2,
            user_dict_path=user_dict_path,
            compiled_dict_path=(self.tmp_dir_path / "test_rewrite_word_valid_id.dic"),
        )
        new_word = read_dict(user_dict_path=user_dict_path)[self.VALID_WORD_UUID]
        self.assertEqual(
            (new_word.surface, new_word.pronunciation, new_word.accent_type),
            ("test2", "テストツー", 2),
        )

    def test_delete_word_invalid_id(self):
        """delete_word raises HTTPException for an unknown word UUID."""
        user_dict_path = self.tmp_dir_path / "test_delete_word_invalid_id.json"
        self._write_valid_dict_json(user_dict_path)
        with self.assertRaises(HTTPException):
            delete_word(
                word_uuid=self.UNKNOWN_WORD_UUID,
                user_dict_path=user_dict_path,
                compiled_dict_path=(
                    self.tmp_dir_path / "test_delete_word_invalid_id.dic"
                ),
            )

    def test_delete_word_valid_id(self):
        """delete_word removes an existing word."""
        user_dict_path = self.tmp_dir_path / "test_delete_word_valid_id.json"
        self._write_valid_dict_json(user_dict_path)
        delete_word(
            word_uuid=self.VALID_WORD_UUID,
            user_dict_path=user_dict_path,
            compiled_dict_path=(self.tmp_dir_path / "test_delete_word_valid_id.dic"),
        )
        self.assertEqual(len(read_dict(user_dict_path=user_dict_path)), 0)

    def test_priority(self):
        """create_word round-trips every priority 0..MAX_PRIORITY for every word type."""
        for pos in part_of_speech_data:
            for i in range(MAX_PRIORITY + 1):
                # subTest pinpoints the failing (word type, priority) pair and
                # keeps checking the remaining combinations.
                with self.subTest(word_type=pos, priority=i):
                    self.assertEqual(
                        create_word(
                            surface="test",
                            pronunciation="テスト",
                            accent_type=1,
                            word_type=pos,
                            priority=i,
                        ).priority,
                        i,
                    )

    def test_import_dict(self):
        """Importing a new UUID adds the word and keeps the existing one."""
        user_dict_path = self.tmp_dir_path / "test_import_dict.json"
        compiled_dict_path = self.tmp_dir_path / "test_import_dict.dic"
        self._write_valid_dict_json(user_dict_path)
        import_user_dict(
            {"b1affe2a-d5f0-4050-926c-f28e0c1d9a98": import_word},
            override=False,
            user_dict_path=user_dict_path,
            compiled_dict_path=compiled_dict_path,
        )
        user_dict = read_dict(user_dict_path)
        self.assertEqual(
            user_dict["b1affe2a-d5f0-4050-926c-f28e0c1d9a98"],
            import_word,
        )
        self.assertEqual(
            user_dict[self.VALID_WORD_UUID],
            UserDictWord(**valid_dict_dict_api[self.VALID_WORD_UUID]),
        )

    def test_import_dict_no_override(self):
        """With override=False an existing UUID keeps its original word."""
        user_dict_path = self.tmp_dir_path / "test_import_dict_no_override.json"
        compiled_dict_path = self.tmp_dir_path / "test_import_dict_no_override.dic"
        self._write_valid_dict_json(user_dict_path)
        import_user_dict(
            {self.VALID_WORD_UUID: import_word},
            override=False,
            user_dict_path=user_dict_path,
            compiled_dict_path=compiled_dict_path,
        )
        self.assertEqual(
            read_dict(user_dict_path)[self.VALID_WORD_UUID],
            UserDictWord(**valid_dict_dict_api[self.VALID_WORD_UUID]),
        )

    def test_import_dict_override(self):
        """With override=True an existing UUID is replaced by the imported word."""
        user_dict_path = self.tmp_dir_path / "test_import_dict_override.json"
        compiled_dict_path = self.tmp_dir_path / "test_import_dict_override.dic"
        self._write_valid_dict_json(user_dict_path)
        import_user_dict(
            {self.VALID_WORD_UUID: import_word},
            override=True,
            user_dict_path=user_dict_path,
            compiled_dict_path=compiled_dict_path,
        )
        self.assertEqual(
            read_dict(user_dict_path)[self.VALID_WORD_UUID],
            import_word,
        )

    def test_import_invalid_word(self):
        """import_user_dict rejects words with invalid fields."""
        user_dict_path = self.tmp_dir_path / "test_import_invalid_dict.json"
        compiled_dict_path = self.tmp_dir_path / "test_import_invalid_dict.dic"
        invalid_accent_associative_rule_word = deepcopy(import_word)
        invalid_accent_associative_rule_word.accent_associative_rule = "invalid"
        self._write_valid_dict_json(user_dict_path)
        with self.assertRaises(AssertionError):
            import_user_dict(
                {self.VALID_WORD_UUID: invalid_accent_associative_rule_word},
                override=True,
                user_dict_path=user_dict_path,
                compiled_dict_path=compiled_dict_path,
            )

        # A word whose part of speech (フィラー, "filler") import_user_dict is
        # expected to reject with ValueError.
        invalid_pos_word = deepcopy(import_word)
        invalid_pos_word.context_id = 2
        invalid_pos_word.part_of_speech = "フィラー"
        invalid_pos_word.part_of_speech_detail_1 = "*"
        invalid_pos_word.part_of_speech_detail_2 = "*"
        invalid_pos_word.part_of_speech_detail_3 = "*"
        with self.assertRaises(ValueError):
            import_user_dict(
                {self.VALID_WORD_UUID: invalid_pos_word},
                override=True,
                user_dict_path=user_dict_path,
                compiled_dict_path=compiled_dict_path,
            )

    def test_update_dict(self):
        """A registered word affects g2p and is reloaded by update_dict."""
        user_dict_path = self.tmp_dir_path / "test_update_dict.json"
        compiled_dict_path = self.tmp_dir_path / "test_update_dict.dic"
        update_dict(
            user_dict_path=user_dict_path, compiled_dict_path=compiled_dict_path
        )
        test_text = "テスト用の文字列"
        success_pronunciation = "デフォルトノジショデハゼッタイニセイセイサレナイヨミ"

        # Make sure the reading is not already produced before registration.
        self.assertNotEqual(g2p(text=test_text, kana=True), success_pronunciation)

        apply_word(
            surface=test_text,
            pronunciation=success_pronunciation,
            accent_type=1,
            priority=10,
            user_dict_path=user_dict_path,
            compiled_dict_path=compiled_dict_path,
        )
        self.assertEqual(g2p(text=test_text, kana=True), success_pronunciation)

        # Simulate an engine restart: unload, then reload from disk.
        unset_user_dict()
        update_dict(
            user_dict_path=user_dict_path, compiled_dict_path=compiled_dict_path
        )
        self.assertEqual(g2p(text=test_text, kana=True), success_pronunciation)