Upload interact_mmi.py
#9
by chenmingxuan · opened
- interact_mmi.py +238 -0
interact_mmi.py
ADDED
@@ -0,0 +1,238 @@
import transformers
import torch
import os
import json
import random
import numpy as np
import argparse
# from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
from tqdm import tqdm
from torch.nn import DataParallel
import logging
# the old path `transformers.modeling_gpt2` only exists in transformers < 4.0;
# importing from the package root works across versions
from transformers import GPT2Config, GPT2LMHeadModel
from transformers import BertTokenizer
from os.path import join, exists
from itertools import zip_longest, chain
from dataset import MyDataset
from torch.utils.data import Dataset, DataLoader
from torch.nn import CrossEntropyLoss
from sklearn.model_selection import train_test_split
from train import create_model
import torch.nn.functional as F
import copy

PAD = '[PAD]'
pad_id = 0

def set_interact_args():
    """
    Sets up the interaction arguments.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', default='1', type=str, required=False, help='device used for generation')
    parser.add_argument('--temperature', default=1, type=float, required=False, help='sampling temperature')
    parser.add_argument('--topk', default=8, type=int, required=False, help='sample from the k highest-probability tokens')
    parser.add_argument('--topp', default=0, type=float, required=False, help='nucleus (top-p) cumulative-probability threshold')
    parser.add_argument('--model_config', default='config/model_config_dialogue_small.json', type=str, required=False,
                        help='model configuration file')
    parser.add_argument('--log_path', default='data/interacting_mmi.log', type=str, required=False,
                        help='where to write the interact_mmi log')
    parser.add_argument('--voca_path', default='vocabulary/vocab_small.txt', type=str, required=False, help='vocabulary file')
    parser.add_argument('--dialogue_model_path', default='dialogue_model/', type=str, required=False,
                        help='path of the dialogue model')
    parser.add_argument('--mmi_model_path', default='mmi_model/', type=str, required=False,
                        help='path of the mutual-information (MMI) model')
    parser.add_argument('--save_samples_path', default="sample/", type=str, required=False, help="directory in which chat transcripts are saved")
    parser.add_argument('--repetition_penalty', default=1.0, type=float, required=False,
                        help="repetition penalty; raise it if the generated replies repeat themselves")
    parser.add_argument('--seed', type=int, default=None, help='random seed, to make generation deterministic')
    parser.add_argument('--max_len', type=int, default=25, help='maximum length of each utterance; longer ones are truncated')
    parser.add_argument('--max_history_len', type=int, default=5, help="maximum number of utterances kept in the dialogue history")
    parser.add_argument('--no_cuda', action='store_true', help='do not use the GPU')
    parser.add_argument('--batch_size', type=int, default=5, help='number of candidate responses generated per turn, to be reranked by the MMI model')
    parser.add_argument('--debug', action='store_true', help='print every candidate response together with its loss')
    return parser.parse_args()

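# Example invocation (a sketch, not part of the original file; it assumes trained
# checkpoints already exist under dialogue_model/ and mmi_model/, e.g. as produced
# by train.py):
#
#   python interact_mmi.py --no_cuda --topk 8 --topp 0 --batch_size 5 --max_len 25
#
# All flags shown are defined in set_interact_args() above and use their defaults here.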
def create_logger(args):
    """
    Write log output both to a log file and to the console.
    """
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)

    formatter = logging.Formatter(
        '%(asctime)s - %(levelname)s - %(message)s')

    # handler that writes to the log file
    file_handler = logging.FileHandler(
        filename=args.log_path)
    file_handler.setFormatter(formatter)
    file_handler.setLevel(logging.INFO)
    logger.addHandler(file_handler)

    # handler that writes to the console
    console = logging.StreamHandler()
    console.setLevel(logging.DEBUG)
    console.setFormatter(formatter)
    logger.addHandler(console)

    return logger

def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
            logits: logits distribution, shape (batch size, vocabulary size)
            top_k > 0: keep only top k tokens with highest probability (top-k filtering).
            top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
    """
    assert logits.dim() == 2
    top_k = min(top_k, logits[0].size(-1))  # safety check
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k.
        # torch.topk() returns the top_k largest elements of the last dimension as (values, indices);
        # '...' lets PyTorch infer the remaining dimensions.
        for logit in logits:
            indices_to_remove = logit < torch.topk(logit, top_k)[0][..., -1, None]
            logit[indices_to_remove] = filter_value  # set the logits of every token outside the top k to -inf

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)  # sort the logits in descending order
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        for index, logit in enumerate(logits):
            indices_to_remove = sorted_indices[index][sorted_indices_to_remove[index]]
            logit[indices_to_remove] = filter_value
    return logits

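# A small worked example of the filter above (illustration only, not part of the
# original script): for a single row of logits [1.0, 2.0, 3.0, 4.0],
#   - top_k=2 pushes the two smallest logits to -inf, so a later softmax puts all
#     probability mass on the tokens with logits 3.0 and 4.0 (roughly 0.27 and 0.73);
#   - top_p=0.9 sorts the softmax probabilities (~0.64, 0.24, 0.09, 0.03) and keeps
#     the shortest prefix whose cumulative probability exceeds 0.9, i.e. the top
#     three tokens, masking only the last one.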
def main():
    args = set_interact_args()
    logger = create_logger(args)
    # CUDA_VISIBLE_DEVICES must be set before the first CUDA call, or it has no effect
    os.environ["CUDA_VISIBLE_DEVICES"] = args.device
    # use the GPU when the user asked for it and one is available
    args.cuda = torch.cuda.is_available() and not args.no_cuda
    device = 'cuda' if args.cuda else 'cpu'
    logger.info('using device:{}'.format(device))

    tokenizer = BertTokenizer(vocab_file=args.voca_path)
    # dialogue model
    dialogue_model = GPT2LMHeadModel.from_pretrained(args.dialogue_model_path)
    dialogue_model.to(device)
    dialogue_model.eval()
    # mutual-information (MMI) model
    mmi_model = GPT2LMHeadModel.from_pretrained(args.mmi_model_path)
    mmi_model.to(device)
    mmi_model.eval()
    if args.save_samples_path:
        if not os.path.exists(args.save_samples_path):
            os.makedirs(args.save_samples_path)
        samples_file = open(args.save_samples_path + '/mmi_samples.txt', 'a', encoding='utf8')
        samples_file.write("chat transcript {}:\n".format(datetime.now()))
    # chat history; every utterance is stored as a list of token ids
    history = []
    print('Start chatting with the chatbot; press Ctrl+C to exit')

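    # Each turn below: read user input, append it to the history, build the
    # context "[CLS] u1 [SEP] u2 [SEP] ...", sample batch_size candidate
    # responses from dialogue_model, then keep the candidate to which the MMI
    # model assigns the lowest loss.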
    while True:
        try:
            text = input("user:")
            if args.save_samples_path:
                samples_file.write("user:{}\n".format(text))
            history.append(tokenizer.encode(text))
            input_ids = [tokenizer.cls_token_id]  # every input starts with [CLS]
            for history_id, history_utr in enumerate(history[-args.max_history_len:]):
                input_ids.extend(history_utr)
                input_ids.append(tokenizer.sep_token_id)
            # replicate the input for batched generation; shape (batch_size, token_len)
            input_ids = [copy.deepcopy(input_ids) for _ in range(args.batch_size)]

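            # For example, with history [u1, u2] each input row looks like
            #   [CLS] u1-token-ids [SEP] u2-token-ids [SEP]
            # and the row is repeated batch_size times so that batch_size
            # candidate responses can be sampled in parallel.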
            curr_input_tensors = torch.tensor(input_ids).long().to(device)
            generated = []  # 2-d list of shape (response length, batch_size); generated[i][j] is the id of token i of response j
            finish_set = set()  # indices of the responses that have already produced sep_token_id, i.e. finished generating
            # generate at most max_len tokens
            for _ in range(args.max_len):
                outputs = dialogue_model(input_ids=curr_input_tensors)
                # print("outputs", outputs)
                next_token_logits = outputs[0][:, -1, :]
                # apply a repetition penalty to every token already generated, lowering its probability
                for index in range(args.batch_size):
                    for token_id in set([token_ids[index] for token_ids in generated]):
                        next_token_logits[index][token_id] /= args.repetition_penalty
                next_token_logits = next_token_logits / args.temperature
                # set the logit of [UNK] to -inf so that the model can never predict it
                for next_token_logit in next_token_logits:
                    next_token_logit[tokenizer.convert_tokens_to_ids('[UNK]')] = -float('Inf')
                filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=args.topk, top_p=args.topp)
                # torch.multinomial draws num_samples elements without replacement; the higher the weight, the more likely an element is drawn; it returns the indices
                next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
                # mark every response that has generated [SEP]
                for index, token_id in enumerate(next_token[:, 0]):
                    if token_id == tokenizer.sep_token_id:
                        finish_set.add(index)
                # check whether every response has generated [SEP]
                finish_flag = True
                for index in range(args.batch_size):
                    if index not in finish_set:  # at least one response is still unfinished
                        finish_flag = False
                        break
                if finish_flag:
                    break
                generated.append([token.item() for token in next_token[:, 0]])
                # append the newly generated tokens to the previous input
                curr_input_tensors = torch.cat((curr_input_tensors, next_token), dim=-1)
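            # Note that a response that finishes early still receives sampled
            # tokens on later steps; its [SEP] stays in `generated`, and the
            # extraction below simply cuts each candidate at its first [SEP].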
            candidate_responses = []  # all generated candidate responses
            for batch_index in range(args.batch_size):
                response = []
                for token_index in range(len(generated)):
                    if generated[token_index][batch_index] != tokenizer.sep_token_id:
                        response.append(generated[token_index][batch_index])
                    else:
                        break
                candidate_responses.append(response)

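            # MMI reranking (in the spirit of DialoGPT's MMI scoring): each
            # candidate is fed to mmi_model as "[CLS] response [SEP] reversed
            # history", and the model's language-modeling loss on that sequence
            # is used as the score, so picking the minimum loss roughly prefers
            # the response that best predicts the dialogue history.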
            # input of the MMI model
            if args.debug:
                print("candidate response:")
                if args.save_samples_path:  # samples_file only exists when save_samples_path is set
                    samples_file.write("candidate response:\n")
            min_loss = float('Inf')
            best_response = ""
            for response in candidate_responses:
                mmi_input_id = [tokenizer.cls_token_id]  # every input starts with [CLS]
                mmi_input_id.extend(response)
                mmi_input_id.append(tokenizer.sep_token_id)
                for history_utr in reversed(history[-args.max_history_len:]):
                    mmi_input_id.extend(history_utr)
                    mmi_input_id.append(tokenizer.sep_token_id)
                mmi_input_tensor = torch.tensor(mmi_input_id).long().to(device)
                out = mmi_model(input_ids=mmi_input_tensor, labels=mmi_input_tensor)
                loss = out[0].item()
                if args.debug:
                    text = tokenizer.convert_ids_to_tokens(response)
                    print("{} loss:{}".format("".join(text), loss))
                    if args.save_samples_path:
                        samples_file.write("{} loss:{}\n".format("".join(text), loss))
                if loss < min_loss:
                    best_response = response
                    min_loss = loss
            history.append(best_response)
            text = tokenizer.convert_ids_to_tokens(best_response)
            print("chatbot:" + "".join(text))
            if args.save_samples_path:
                samples_file.write("chatbot:{}\n".format("".join(text)))
        except KeyboardInterrupt:
            if args.save_samples_path:
                samples_file.close()
            break


if __name__ == '__main__':
    main()