File size: 6,800 Bytes
ecca75f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
# coding=utf-8
from copy import deepcopy
from .get_dataset import CreateDataset
from .logger import LoggerFactory
from .retrieve_dialog import RetrieveDialog
from .utils import load_json, load_txt, save_to_json

import logging
import os

# Module-level logger created through the project's LoggerFactory at
# logging.INFO level; GetManualTestSamples emits its retrieval debug output here.
logger = LoggerFactory.create_logger(level=logging.INFO, name="test")


class GetManualTestSamples:
    """Build manual QA test samples for one role and save them to a JSON file.

    Each question becomes a one-turn dialog. Dialogs similar to it are
    retrieved from the role's own data (via ``RetrieveDialog``) and trimmed to
    the remaining ``max_seq_len`` budget with ``CreateDataset.choose_examples``.
    """

    def __init__(
        self,
        role_name,
        role_data_path,
        save_samples_dir,
        save_samples_path=None,
        prompt_path="dataset_character.txt",
        max_seq_len=4000,
        retrieve_num=20,
    ):
        """
        Args:
            role_name: Character name; surrounding whitespace is stripped.
            role_data_path: JSON list of records. ``role_info`` is read from the
                first record; every record's ``dialog`` is handed to the retriever.
            save_samples_dir: Directory the samples file is written into.
            save_samples_path: File name joined onto ``save_samples_dir``;
                defaults to ``"<role_name>.json"``.
            prompt_path: Prompt template containing ``${...}`` placeholders.
            max_seq_len: Length budget, measured with ``len`` on the prompt
                string, shared by the prompt and the similar dialogs.
            retrieve_num: Number of similar dialogs requested from the retriever.
        """
        self.role_name = role_name.strip()
        self.role_data = load_json(role_data_path)
        self.role_info = self.role_data[0]["role_info"].strip()

        prompt = load_txt(prompt_path)
        prompt = prompt.replace("${role_name}", self.role_name)
        self.prompt = prompt.replace("${role_info}",
                                     f"以下是{self.role_name}的人设:\n{self.role_info}\n").strip()

        self.retrieve_num = retrieve_num
        self.retrieve = RetrieveDialog(role_name=self.role_name,
                                       raw_dialog_list=[d["dialog"] for d in self.role_data],
                                       retrieve_num=retrieve_num)

        self.max_seq_len = max_seq_len
        self.save_samples_path = os.path.join(save_samples_dir,
                                              save_samples_path or f"{self.role_name}.json")

    def _add_simi_dialog(self, history: list, content_length):
        """Retrieve dialogs similar to *history* and trim them to the remaining budget.

        Args:
            history: Dialog turns used as the retrieval query.
            content_length: Length of the rendered prompt; subtracted from
                ``max_seq_len`` to get the budget for the similar dialogs.

        Returns:
            ``(simi_dialogs, retrieve_results)``: the trimmed examples (or the
            falsy retrieval result unchanged when nothing was retrieved), and
            the raw retrieval output.
        """
        retrieve_results = self.retrieve.get_retrieve_res(history, self.retrieve_num)
        # Trim a copy so retrieve_results stays intact for the optional
        # "retrieve_results" field (presumably guards against choose_examples
        # mutating its input).
        simi_dialogs = deepcopy(retrieve_results)

        if simi_dialogs:
            simi_dialogs = CreateDataset.choose_examples(simi_dialogs,
                                                         max_length=self.max_seq_len - content_length,
                                                         train_flag=False)
        logger.debug(f"retrieve_results: {retrieve_results}\nsimi_dialogs: {simi_dialogs}.")
        return simi_dialogs, retrieve_results

    @staticmethod
    def _format_query(question, user_name, sep=":"):
        """Turn a raw question into a speaker-prefixed dialog line.

        Literal ``\\n`` sequences are unescaped into real newlines. If the
        question already contains ``":"``, it is assumed to carry its own
        speaker prefix and is returned unchanged. Otherwise it is prefixed
        with ``user_name + sep``.
        """
        question = question.replace('\\n', "\n")
        if ":" in question:
            return question
        return f"{user_name}{sep}{question}"

    def _build_sample(self, query, user_name, keep_retrieve_results_flag):
        """Assemble the sample dict for one formatted *query*."""
        # The rendered prompt is only needed for its length: the part of the
        # max_seq_len budget that the similar dialogs cannot use.
        content = self.prompt.replace("${dialog}", query)
        content = content.replace("${user_name}", user_name).strip()

        history = [query]
        simi_dialogs, retrieve_results = self._add_simi_dialog(history, len(content))

        sample = {
            "role_name": self.role_name,
            "role_info": self.role_info,
            "user_name": user_name,
            "dialog": history,
            "simi_dialogs": simi_dialogs,
        }
        if keep_retrieve_results_flag and retrieve_results:
            sample["retrieve_results"] = retrieve_results
        return sample

    def get_qa_samples_by_file(self,
                               questions_path,
                               user_name="user",
                               keep_retrieve_results_flag=False
                               ):
        """Build one sample per non-blank line of *questions_path* and save them.

        Args:
            questions_path: Text file with one question per line.
            user_name: Speaker prefix for questions that have none.
            keep_retrieve_results_flag: Also store the raw retrieval output
                under ``"retrieve_results"`` when it is non-empty.

        Returns:
            The list of samples that was saved.
        """
        samples = []
        for question in load_txt(questions_path).splitlines():
            # A blank line would otherwise become a bare "user:" query and a
            # junk sample.
            if not question.strip():
                continue
            query = self._format_query(question, user_name, sep=":")
            samples.append(self._build_sample(query, user_name, keep_retrieve_results_flag))
        self._save_samples(samples)
        return samples

    def get_qa_samples_by_query(self,
                                questions_query,
                                user_name="user",
                                keep_retrieve_results_flag=False
                                ):
        """Build a single sample from one question string and save it.

        Args:
            questions_query: A single question.
            user_name: Speaker prefix when the question has none.
            keep_retrieve_results_flag: Also store the raw retrieval output
                under ``"retrieve_results"`` when it is non-empty.

        Returns:
            A one-element list containing the saved sample.
        """
        # NOTE(review): this path prefixes "user: " (with a space), while
        # get_qa_samples_by_file uses "user:". Kept as-is; confirm which form
        # the model's dialog format expects.
        query = self._format_query(questions_query, user_name, sep=": ")
        samples = [self._build_sample(query, user_name, keep_retrieve_results_flag)]
        self._save_samples(samples)
        return samples

    def _save_samples(self, samples):
        """Write *samples* to ``self.save_samples_path`` via ``save_to_json``."""
        save_to_json(samples, self.save_samples_path)


class CreateTestDataset:
    """Render saved test samples into fully formatted model input prompts.

    The prompt template is filled with the role name and persona. The persona
    (``role_info``) is read from the first record of ``role_data_path``.
    :meth:`load_samples` then fills each sample's ``${simi_dialog}``,
    ``${user_name}`` and ``${dialog}`` placeholders.
    """

    def __init__(self,
                 role_name,
                 role_samples_path=None,
                 role_data_path=None,
                 prompt_path="dataset_character.txt",
                 max_seq_len=4000,
                 default_simi_dialogs=None):
        """
        Args:
            role_name: Character name substituted for ``${role_name}``.
            role_samples_path: JSON file of samples read by :meth:`load_samples`.
            role_data_path: JSON role data; ``role_info`` is taken from its
                first record. Required despite the ``None`` default.
            prompt_path: Prompt template file.
            max_seq_len: Upper bound (exclusive, measured with ``len``) on
                each rendered ``input_text``.
            default_simi_dialogs: Fallback used when a sample has no (or empty)
                ``"simi_dialogs"``.

        Raises:
            ValueError: if ``role_data_path`` is missing or does not exist.
        """
        self.max_seq_len = max_seq_len
        self.role_name = role_name
        self.default_simi_dialogs = default_simi_dialogs

        # Validate up front: os.path.exists(None) would otherwise raise an
        # opaque TypeError after only a printed warning.
        if not role_data_path:
            raise ValueError("need role_data_path, check please!")
        if not os.path.exists(role_data_path):
            raise ValueError(f"{self.role_name} didn't find role_info.")
        self.role_info = load_json(role_data_path)[0]["role_info"]

        prompt = load_txt(prompt_path)
        prompt = prompt.replace("${role_name}", role_name).strip()
        self.prompt = prompt.replace("${role_info}", f"以下是{self.role_name}的人设:\n{self.role_info}\n").strip()

        # Always define the attribute so load_samples can report a clear error
        # instead of an AttributeError.
        self.role_samples_path = role_samples_path
        if not role_samples_path:
            print("check role_samples_path please!")

    @staticmethod
    def _format_dialog(dialog):
        """Join a list of dialog turns with newlines; return anything else unchanged."""
        return "\n".join(dialog) if isinstance(dialog, list) else dialog

    def load_samples(self):
        """Render every sample in ``role_samples_path`` into a model input.

        Returns:
            A list of ``{"input_text": str}`` dicts, one per sample.

        Raises:
            ValueError: if ``role_samples_path`` is unset, if a sample has no
                similar dialogs and no default is configured, or if a rendered
                text is not shorter than ``max_seq_len``.
        """
        if not self.role_samples_path:
            raise ValueError("check role_samples_path please!")
        samples = load_json(self.role_samples_path)
        results = []
        for idx, sample in enumerate(samples):
            simi_dialogs = sample.get("simi_dialogs") or self.default_simi_dialogs
            if not simi_dialogs:
                raise ValueError("didn't find simi_dialogs.")

            user_name = sample.get("user_name", "user")
            dialog = self._format_dialog(sample["dialog"])

            # Budget the examples against everything else that ends up in the
            # final text, including the dialog and user name inserted below.
            # The previous budget used the template with placeholders, which
            # ignored the dialog length and let long dialogs overflow.
            fixed_len = len(self.prompt.replace("${simi_dialog}", "")
                            .replace("${user_name}", user_name)
                            .replace("${dialog}", dialog))
            simi_dialogs = CreateDataset.choose_examples(simi_dialogs,
                                                         max_length=self.max_seq_len - fixed_len,
                                                         train_flag=False)

            # Same substitution order as before: ${user_name} is expanded
            # inside the inserted examples too, but not inside the dialog.
            input_text = self.prompt.replace("${simi_dialog}", simi_dialogs)
            input_text = input_text.replace("${user_name}", user_name)
            input_text = input_text.replace("${dialog}", dialog)

            # Explicit check instead of `assert`, which is stripped under -O.
            if len(input_text) >= self.max_seq_len:
                raise ValueError(
                    f"rendered input_text of sample {idx} has length {len(input_text)}, "
                    f"which is not below max_seq_len={self.max_seq_len}."
                )
            results.append({
                "input_text": input_text,
            })
        return results