Spaces:
Running
Running
File size: 6,800 Bytes
ecca75f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 |
# coding=utf-8
from copy import deepcopy
from .get_dataset import CreateDataset
from .logger import LoggerFactory
from .retrieve_dialog import RetrieveDialog
from .utils import load_json, load_txt, save_to_json
import logging
import os
# Module-level logger shared by the classes below; created through the
# project's LoggerFactory with the name "test" and level logging.INFO.
# NOTE(review): the logger.debug call in _add_simi_dialog is presumably
# suppressed at this level — confirm against LoggerFactory.
logger = LoggerFactory.create_logger(name="test", level=logging.INFO)
class GetManualTestSamples:
    """Build manual test samples for one role and save them as JSON.

    Each sample pairs a user question with the role description and the
    similar dialogs retrieved for that question. The finished list of samples
    is written to ``save_samples_path`` inside ``save_samples_dir``.
    """

    def __init__(
            self,
            role_name,
            role_data_path,
            save_samples_dir,
            save_samples_path=None,
            prompt_path="dataset_character.txt",
            max_seq_len=4000,
            retrieve_num=20,
    ):
        """Load the role data and prompt template and set up retrieval.

        Args:
            role_name: Name of the role; surrounding whitespace is stripped.
            role_data_path: JSON file whose first entry carries "role_info"
                and whose entries each carry a "dialog".
            save_samples_dir: Directory the samples file is written to.
            save_samples_path: File name inside ``save_samples_dir``;
                defaults to "<role_name>.json".
            prompt_path: Text template containing ``${role_name}``,
                ``${role_info}``, ``${dialog}`` and ``${user_name}``
                placeholders.
            max_seq_len: Length budget shared by the prompt and the
                retrieved examples (measured with ``len`` on strings).
            retrieve_num: Number of dialogs requested from the retriever.
        """
        self.role_name = role_name.strip()
        self.role_data = load_json(role_data_path)
        self.role_info = self.role_data[0]["role_info"].strip()

        # Fill in the role-specific placeholders once; ${dialog} and
        # ${user_name} stay in the template and are resolved per question.
        prompt = load_txt(prompt_path)
        prompt = prompt.replace("${role_name}", self.role_name)
        prompt = prompt.replace(
            "${role_info}",
            f"以下是{self.role_name}的人设:\n{self.role_info}\n",
        )
        self.prompt = prompt.strip()

        self.retrieve_num = retrieve_num
        self.retrieve = RetrieveDialog(
            role_name=self.role_name,
            raw_dialog_list=[d["dialog"] for d in self.role_data],
            retrieve_num=retrieve_num,
        )
        self.max_seq_len = max_seq_len

        if not save_samples_path:
            save_samples_path = f"{self.role_name}.json"
        self.save_samples_path = os.path.join(save_samples_dir, save_samples_path)

    def _add_simi_dialog(self, history: list, content_length):
        """Retrieve dialogs similar to ``history`` and fit them to the budget.

        Args:
            history: Dialog turns used as the retrieval query.
            content_length: Length already used by the prompt; the examples
                get ``max_seq_len - content_length``.

        Returns:
            A ``(simi_dialogs, retrieve_results)`` tuple: the selected
            examples and the untouched retrieval output.
        """
        retrieve_results = self.retrieve.get_retrieve_res(history, self.retrieve_num)
        # Work on a copy so the raw retrieval results are returned unmodified.
        simi_dialogs = deepcopy(retrieve_results)
        if simi_dialogs:
            # NOTE(review): choose_examples is defined elsewhere; presumably it
            # trims the examples to max_length — confirm.
            simi_dialogs = CreateDataset.choose_examples(
                simi_dialogs,
                max_length=self.max_seq_len - content_length,
                train_flag=False,
            )
        logger.debug(f"retrieve_results: {retrieve_results}\nsimi_dialogs: {simi_dialogs}.")
        return simi_dialogs, retrieve_results

    def _build_sample(self, question, user_name, keep_retrieve_results_flag, separator):
        """Turn one raw question into a sample dict (shared by both entry points).

        Args:
            question: Raw question text; a literal backslash-n sequence is
                converted to a real newline.
            user_name: Speaker name used when the question has no ":" in it.
            keep_retrieve_results_flag: Whether to attach the raw retrieval
                output under "retrieve_results" (only when it is non-empty).
            separator: Text placed between ``user_name`` and the question.

        Returns:
            The sample dictionary for this question.
        """
        question = question.replace("\\n", "\n")
        # A question that already contains ":" is taken to carry its own
        # speaker prefix and is used verbatim.
        query = question if ":" in question else f"{user_name}{separator}{question}"

        # The filled-in prompt is only needed for its length, which decides
        # how much room is left for retrieved examples.
        content = self.prompt.replace("${dialog}", query)
        content = content.replace("${user_name}", user_name).strip()

        history = [query]
        simi_dialogs, retrieve_results = self._add_simi_dialog(history, len(content))
        sample = {
            "role_name": self.role_name,
            "role_info": self.role_info,
            "user_name": user_name,
            "dialog": history,
            "simi_dialogs": simi_dialogs,
        }
        if keep_retrieve_results_flag and retrieve_results:
            sample["retrieve_results"] = retrieve_results
        return sample

    def get_qa_samples_by_file(self,
                               questions_path,
                               user_name="user",
                               keep_retrieve_results_flag=False
                               ):
        """Build one sample per non-blank line of ``questions_path`` and save them.

        Args:
            questions_path: Text file with one question per line.
            user_name: Speaker name prefixed to questions without a ":".
            keep_retrieve_results_flag: Keep the raw retrieval output in
                each sample.
        """
        questions = load_txt(questions_path).splitlines()
        # Blank lines carry no question; skipping them avoids samples whose
        # dialog is just the bare "user:" prefix.
        # NOTE(review): this method joins with ":" while
        # get_qa_samples_by_query joins with ": " — kept as-is, confirm
        # which form is intended.
        samples = [
            self._build_sample(question, user_name, keep_retrieve_results_flag, ":")
            for question in questions
            if question.strip()
        ]
        self._save_samples(samples)

    def get_qa_samples_by_query(self,
                                questions_query,
                                user_name="user",
                                keep_retrieve_results_flag=False
                                ):
        """Build a single sample from ``questions_query`` and save it.

        Args:
            questions_query: The question text.
            user_name: Speaker name prefixed to the question if it has no ":".
            keep_retrieve_results_flag: Keep the raw retrieval output in
                the sample.
        """
        sample = self._build_sample(
            questions_query, user_name, keep_retrieve_results_flag, ": "
        )
        self._save_samples([sample])

    def _save_samples(self, samples):
        """Write ``samples`` to ``self.save_samples_path`` as JSON."""
        # Ensure the target directory exists so the write cannot fail on a
        # missing folder; an empty dirname means the current directory.
        target_dir = os.path.dirname(self.save_samples_path)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)
        save_to_json(samples, self.save_samples_path)
class CreateTestDataset:
    """Turn saved test samples into model-ready input texts for one role."""

    def __init__(self,
                 role_name,
                 role_samples_path=None,
                 role_data_path=None,
                 prompt_path="dataset_character.txt",
                 max_seq_len=4000):
        """Load the prompt template and the role description.

        Args:
            role_name: Name substituted for ``${role_name}`` in the prompt.
            role_samples_path: JSON file of samples read by ``load_samples``.
                A warning is printed when it is missing; ``load_samples``
                then raises.
            role_data_path: JSON file whose first entry carries "role_info".
            prompt_path: Text template with ``${role_name}``,
                ``${role_info}``, ``${simi_dialog}``, ``${user_name}`` and
                ``${dialog}`` placeholders.
            max_seq_len: Upper bound (exclusive) on the length of each
                generated input text.

        Raises:
            ValueError: If ``role_data_path`` is not given or does not exist.
        """
        # Validate before any file I/O. Previously a missing path only
        # printed a warning and then crashed inside os.path.exists(None).
        if not role_data_path:
            raise ValueError("need role_data_path, check please!")
        if not os.path.exists(role_data_path):
            raise ValueError(f"{role_name} didn't find role_info.")

        self.max_seq_len = max_seq_len
        self.role_name = role_name
        # Fallback used by load_samples when a sample has no "simi_dialogs";
        # nothing in this class ever assigns a real value to it.
        self.default_simi_dialogs = None

        prompt = load_txt(prompt_path)
        prompt = prompt.replace("${role_name}", role_name).strip()

        self.role_info = load_json(role_data_path)[0]["role_info"]
        self.prompt = prompt.replace(
            "${role_info}",
            f"以下是{self.role_name}的人设:\n{self.role_info}\n",
        ).strip()

        # Always define the attribute so load_samples can report a clear
        # error instead of failing with AttributeError.
        self.role_samples_path = role_samples_path if role_samples_path else None
        if self.role_samples_path is None:
            print("check role_samples_path please!")

    def load_samples(self):
        """Read the samples file and render one input text per sample.

        Returns:
            A list of ``{"input_text": ...}`` dictionaries, one per sample.

        Raises:
            ValueError: If no samples path was configured, or a sample has
                no similar dialogs and no default is available.
            AssertionError: If a rendered input text is not shorter than
                ``max_seq_len``.
        """
        if not self.role_samples_path:
            raise ValueError("check role_samples_path please!")

        samples = load_json(self.role_samples_path)
        results = []
        for sample in samples:
            input_text = self.prompt

            simi_dialogs = sample.get("simi_dialogs", None) or self.default_simi_dialogs
            if not simi_dialogs:
                raise ValueError("didn't find simi_dialogs.")
            # NOTE(review): choose_examples is defined elsewhere; its result is
            # passed to str.replace below, so it presumably returns a string.
            simi_dialogs = CreateDataset.choose_examples(
                simi_dialogs,
                max_length=self.max_seq_len - len(input_text),
                train_flag=False,
            )

            # Replacement order is kept as in the original: examples first,
            # then the user name, then the dialog itself.
            input_text = input_text.replace("${simi_dialog}", simi_dialogs)
            user_name = sample.get("user_name", "user")
            input_text = input_text.replace("${user_name}", user_name)

            dialog = sample["dialog"]
            if isinstance(dialog, list):
                dialog = "\n".join(dialog)
            input_text = input_text.replace("${dialog}", dialog)

            # Raised explicitly (rather than via `assert`) so the check
            # still runs under `python -O`; the exception type is unchanged.
            if len(input_text) >= self.max_seq_len:
                raise AssertionError(
                    f"input_text length {len(input_text)} must be smaller "
                    f"than max_seq_len {self.max_seq_len}"
                )
            results.append({
                "input_text": input_text,
            })
        return results
|