LightChen2333 committed on
Commit
37b9e99
1 Parent(s): 0882d77

Upload 34 files

app.py ADDED
@@ -0,0 +1,42 @@
+ import gradio as gr
+ import os
+ from common.config import Config
+ from common.model_manager import ModelManager
+
+ config = Config.load_from_yaml("config/app.yaml")
+ model_manager = ModelManager(config)
+ model_manager.load()
+
+
+ def text_analysis(text):
+     print(text)
+     data = model_manager.predict(text)
+     html = """<link href="https://cdn.staticfile.org/twitter-bootstrap/5.1.1/css/bootstrap.min.css" rel="stylesheet">
+ <script src="https://cdn.staticfile.org/twitter-bootstrap/5.1.1/js/bootstrap.bundle.min.js"></script>"""
+     html += """<div style="background: white; padding: 16px;"><b>Intent:</b>"""
+
+     for intent in data["intent"]:
+         html += """<button type="button" class="btn btn-white">
+ <span class="badge text-dark btn-light">""" + intent + """</span> </button>"""
+     html += """<br /> <b>Slot:</b>"""
+     for t, slot in zip(data["text"], data["slot"]):
+         html += """<button type="button" class="btn btn-white">""" + t + """<span class="badge text-dark" style="background-color: rgb(255, 255, 255);
+ color: rgb(62 62 62);
+ box-shadow: 2px 2px 7px 1px rgba(210, 210, 210, 0.42);">""" + slot + \
+             """</span>
+ </button>"""
+     html += "</div>"
+     return html
+
+
+ demo = gr.Interface(
+     text_analysis,
+     gr.Textbox(placeholder="Enter sentence here..."),
+     ["html"],
+     examples=[
+         ["What a beautiful morning for a walk!"],
+         ["It was the best of times, it was the worst of times."],
+     ],
+ )
+
+ demo.launch()
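
Note: text_analysis above assumes ModelManager.predict returns a dict of token list, slot sequence and intent list. A hypothetical sample (illustrative values, not output from this commit):

    {
        "text": ["show", "flights", "from", "boston"],
        "slot": ["O", "O", "O", "B-fromloc.city_name"],
        "intent": ["atis_flight"],
    }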
common/__init__.py ADDED
@@ -0,0 +1 @@
+
common/config.py ADDED
@@ -0,0 +1,191 @@
+ '''
+ Author: Qiguang Chen
+ Date: 2023-01-11 10:39:26
+ LastEditors: Qiguang Chen
+ LastEditTime: 2023-01-26 10:55:43
+ Description: Configuration class to manage all processes in OpenSLU, such as model construction, the learning process, and so on.
+
+ '''
+ import re
+
+ from ruamel import yaml
+ import datetime
+
+ class Config(dict):
+     def __init__(self, *args, **kwargs):
+         """ init with dicts as args
+         """
+         dict.__init__(self, *args, **kwargs)
+         self.__dict__ = self
+         self.start_time = datetime.datetime.now().strftime('%Y%m%d%H%M%S%f')
+         self.__autowired()
+
+     @staticmethod
+     def load_from_yaml(file_path: str) -> "Config":
+         """load a config file from the given path
+
+         Args:
+             file_path (str): yaml configuration file path.
+
+         Returns:
+             Config: config object.
+         """
+         with open(file_path) as stream:
+             try:
+                 return Config(yaml.safe_load(stream))
+             except yaml.YAMLError as exc:
+                 print(exc)
+
+     @staticmethod
+     def load_from_args(args) -> "Config":
+         """ load args to replace item values in the config file assigned with '--config_path' or '--model'
+
+         Args:
+             args (Any): command-line args.
+
+         Returns:
+             Config: config object updated with the given args.
+         """
+         if args.model is not None:
+             args.config_path = "config/" + args.model + ".yaml"
+         config = Config.load_from_yaml(args.config_path)
+         if args.dataset is not None:
+             config.__update_dataset(args.dataset)
+         if args.device is not None:
+             config["base"]["device"] = args.device
+         if args.learning_rate is not None:
+             config["optimizer"]["lr"] = args.learning_rate
+         if args.epoch_num is not None:
+             config["base"]["epoch_num"] = args.epoch_num
+         return config
+
+     def autoload_template(self):
+         """ search for '{*}' templates to execute as Python code; any configuration item can be used as a replacement variable
+         """
+         self.__autoload_template(self.__dict__)
+
+     def __get_autoload_value(self, matched):
+         keys = matched.group()[1:-1].split(".")
+         temp = self.__dict__
+         for k in keys:
+             temp = temp[k]
+         return str(temp)
+
+     def __autoload_template(self, config: dict):
+         for k in config:
+             if isinstance(config, dict):
+                 sub_config = config[k]
+             elif isinstance(config, list):
+                 sub_config = k
+             else:
+                 continue
+             if isinstance(sub_config, dict) or isinstance(sub_config, list):
+                 self.__autoload_template(sub_config)
+             if isinstance(sub_config, str) and "{" in sub_config and "}" in sub_config:
+                 res = re.sub(r'{.*?}', self.__get_autoload_value, config[k])
+                 res_dict = {"res": None}
+                 exec("res=" + res, res_dict)
+                 config[k] = res_dict["res"]
+
+     def __update_dataset(self, dataset_name):
+         if dataset_name is not None and isinstance(dataset_name, str):
+             self.__dict__["dataset"]["dataset_name"] = dataset_name
+
+     def get_model_config(self):
+         return self.__dict__["model"]
+
+     def __autowired(self):
+         # Set encoder
+         encoder_config = self.__dict__["model"]["encoder"]
+         encoder_type = encoder_config["_model_target_"].split(".")[-1]
+
+         def get_output_dim(encoder_config):
+             encoder_type = encoder_config["_model_target_"].split(".")[-1]
+             if (encoder_type == "AutoEncoder" and encoder_config["encoder_name"] in ["lstm", "self-attention-lstm",
+                                                                                      "bi-encoder"]) or encoder_type == "NoPretrainedEncoder":
+                 output_dim = 0
+                 if encoder_config.get("lstm"):
+                     output_dim += encoder_config["lstm"]["output_dim"]
+                 if encoder_config.get("attention"):
+                     output_dim += encoder_config["attention"]["output_dim"]
+                 return output_dim
+             else:
+                 return encoder_config["output_dim"]
+
+         if encoder_type == "BiEncoder":
+             output_dim = get_output_dim(encoder_config["intent_encoder"]) + \
+                 get_output_dim(encoder_config["slot_encoder"])
+         else:
+             output_dim = get_output_dim(encoder_config)
+         self.__dict__["model"]["encoder"]["output_dim"] = output_dim
+
+         # Set interaction
+         if "interaction" in self.__dict__["model"]["decoder"] and self.__dict__["model"]["decoder"]["interaction"].get(
+                 "input_dim") is None:
+             self.__dict__["model"]["decoder"]["interaction"]["input_dim"] = output_dim
+             interaction_type = self.__dict__["model"]["decoder"]["interaction"]["_model_target_"].split(".")[-1]
+             if not ((encoder_type == "AutoEncoder" and encoder_config[
+                     "encoder_name"] == "self-attention-lstm") or encoder_type == "SelfAttentionLSTMEncoder") and interaction_type != "BiModelWithoutDecoderInteraction":
+                 output_dim = self.__dict__["model"]["decoder"]["interaction"]["output_dim"]
+
+         # Set classifier
+         if "slot_classifier" in self.__dict__["model"]["decoder"]:
+             if self.__dict__["model"]["decoder"]["slot_classifier"].get("input_dim") is None:
+                 self.__dict__["model"]["decoder"]["slot_classifier"]["input_dim"] = output_dim
+             self.__dict__["model"]["decoder"]["slot_classifier"]["use_slot"] = True
+         if "intent_classifier" in self.__dict__["model"]["decoder"]:
+             if self.__dict__["model"]["decoder"]["intent_classifier"].get("input_dim") is None:
+                 self.__dict__["model"]["decoder"]["intent_classifier"]["input_dim"] = output_dim
+             self.__dict__["model"]["decoder"]["intent_classifier"]["use_intent"] = True
+
+     def get_intent_label_num(self):
+         """ get the number of intent labels.
+         """
+         classifier_conf = self.__dict__["model"]["decoder"]["intent_classifier"]
+         return classifier_conf["intent_label_num"] if "intent_label_num" in classifier_conf else 0
+
+     def get_slot_label_num(self):
+         """ get the number of slot labels.
+         """
+         classifier_conf = self.__dict__["model"]["decoder"]["slot_classifier"]
+         return classifier_conf["slot_label_num"] if "slot_label_num" in classifier_conf else 0
+
+     def set_intent_label_num(self, intent_label_num):
+         """ set the number of intent labels.
+
+         Args:
+             intent_label_num (int): the number of intent labels
+         """
+         self.__dict__["base"]["intent_label_num"] = intent_label_num
+         self.__dict__["model"]["decoder"]["intent_classifier"]["intent_label_num"] = intent_label_num
+         if "interaction" in self.__dict__["model"]["decoder"]:
+             self.__dict__["model"]["decoder"]["interaction"]["intent_label_num"] = intent_label_num
+             if self.__dict__["model"]["decoder"]["interaction"]["_model_target_"].split(".")[
+                     -1] == "StackInteraction":
+                 self.__dict__["model"]["decoder"]["slot_classifier"]["input_dim"] += intent_label_num
+
+     def set_slot_label_num(self, slot_label_num: int) -> None:
+         """set the number of slot labels
+
+         Args:
+             slot_label_num (int): the number of slot labels
+         """
+         self.__dict__["base"]["slot_label_num"] = slot_label_num
+         self.__dict__["model"]["decoder"]["slot_classifier"]["slot_label_num"] = slot_label_num
+         if "interaction" in self.__dict__["model"]["decoder"]:
+             self.__dict__["model"]["decoder"]["interaction"]["slot_label_num"] = slot_label_num
+
+     def set_vocab_size(self, vocab_size):
+         """set the vocabulary size for a non-pretrained tokenizer
+
+         Args:
+             vocab_size (int): the size of the vocabulary
+         """
+         encoder_type = self.__dict__["model"]["encoder"]["_model_target_"].split(".")[-1]
+         encoder_name = self.__dict__["model"]["encoder"].get("encoder_name")
+         if encoder_type == "BiEncoder" or (encoder_type == "AutoEncoder" and encoder_name == "bi-encoder"):
+             self.__dict__["model"]["encoder"]["intent_encoder"]["embedding"]["vocab_size"] = vocab_size
+             self.__dict__["model"]["encoder"]["slot_encoder"]["embedding"]["vocab_size"] = vocab_size
+         elif self.__dict__["model"]["encoder"].get("embedding"):
+             self.__dict__["model"]["encoder"]["embedding"]["vocab_size"] = vocab_size
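
A minimal sketch of the '{*}' template mechanism implemented by autoload_template: each dotted key inside braces is resolved against the config, and the substituted string is then evaluated with exec. All values below are hypothetical.

    config = Config({
        "base": {"epoch_num": 100},
        "model": {"encoder": {"_model_target_": "model.encoder.AutoEncoder",
                              "encoder_name": "lstm",
                              "lstm": {"output_dim": 256}},
                  "decoder": {}},
        "scheduler": {"num_warmup_steps": "{base.epoch_num} * 2"},
    })
    config.autoload_template()
    # "{base.epoch_num}" -> "100", then exec("res=100 * 2") yields 200
    assert config["scheduler"]["num_warmup_steps"] == 200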
common/loader.py ADDED
@@ -0,0 +1,333 @@
+ '''
+ Author: Qiguang Chen
+ Date: 2023-01-11 10:39:26
+ LastEditors: Qiguang Chen
+ LastEditTime: 2023-02-07 19:26:06
+ Description: all classes for loading data.
+
+ '''
+ import os
+ import torch
+ import json
+ from datasets import load_dataset, Dataset
+ from torch.utils.data import DataLoader
+
+ from common.utils import InputData
+
+ ABS_PATH = os.path.join(os.path.abspath(os.path.dirname(__file__)), "../")
+
+ class DataFactory(object):
+     def __init__(self, tokenizer, use_multi_intent=False, to_lower_case=True):
+         """create a data factory to load, encode and align data
+
+         Args:
+             tokenizer (Tokenizer): tokenizer used to encode the input text.
+             use_multi_intent (bool, optional): whether to enable multi-intent mode. Defaults to False.
+             to_lower_case (bool, optional): whether to lower-case the input text. Defaults to True.
+         """
+         self.tokenizer = tokenizer
+         self.slot_label_list = []
+         self.intent_label_list = []
+         self.use_multi = use_multi_intent
+         self.to_lower_case = to_lower_case
+         self.slot_label_dict = None
+         self.intent_label_dict = None
+
+     def __is_supported_datasets(self, dataset_name: str) -> bool:
+         return dataset_name.lower() in ["atis", "snips", "mix-atis", "mix-snips"]
+
+     def load_dataset(self, dataset_config, split="train"):
+         # TODO: disable use_auth_token
+         dataset_name = None
+         if split not in dataset_config:
+             dataset_name = dataset_config.get("dataset_name")
+         elif self.__is_supported_datasets(dataset_config[split]):
+             dataset_name = dataset_config[split].lower()
+         if dataset_name is not None:
+             return load_dataset("LightChen2333/OpenSLU", dataset_name, split=split, use_auth_token=True)
+         else:
+             data_file = dataset_config[split]
+             data_dict = {"text": [], "slot": [], "intent": []}
+             with open(data_file, encoding="utf-8") as f:
+                 for line in f:
+                     row = json.loads(line)
+                     data_dict["text"].append(row["text"])
+                     data_dict["slot"].append(row["slot"])
+                     data_dict["intent"].append(row["intent"])
+             return Dataset.from_dict(data_dict)
+
+     def update_label_names(self, dataset):
+         for intent_labels in dataset["intent"]:
+             if self.use_multi:
+                 intent_label = intent_labels.split("#")
+             else:
+                 intent_label = [intent_labels]
+             for x in intent_label:
+                 if x not in self.intent_label_list:
+                     self.intent_label_list.append(x)
+         for slot_label in dataset["slot"]:
+             for x in slot_label:
+                 if x not in self.slot_label_list:
+                     self.slot_label_list.append(x)
+         self.intent_label_dict = {key: index for index,
+                                   key in enumerate(self.intent_label_list)}
+         self.slot_label_dict = {key: index for index,
+                                 key in enumerate(self.slot_label_list)}
+
+     def update_vocabulary(self, dataset):
+         if self.tokenizer.name_or_path in ["word_tokenizer"]:
+             for data in dataset:
+                 self.tokenizer.add_instance(data["text"])
+
+     @staticmethod
+     def fast_align_data(text, padding_side="right"):
+         for i in range(len(text.input_ids)):
+             desired_output = []
+             for word_id in text.word_ids(i):
+                 if word_id is not None:
+                     start, end = text.word_to_tokens(
+                         i, word_id, sequence_index=0 if padding_side == "right" else 1)
+                     if start == end - 1:
+                         tokens = [start]
+                     else:
+                         tokens = [start, end - 1]
+                     if len(desired_output) == 0 or desired_output[-1] != tokens:
+                         desired_output.append(tokens)
+             yield desired_output
+
+     def fast_align(self,
+                    batch,
+                    ignore_index=-100,
+                    device="cuda",
+                    config=None,
+                    enable_label=True,
+                    label2tensor=True):
+         if self.to_lower_case:
+             input_list = [[t.lower() for t in x["text"]] for x in batch]
+         else:
+             input_list = [x["text"] for x in batch]
+         text = self.tokenizer(input_list,
+                               return_tensors="pt",
+                               padding=True,
+                               is_split_into_words=True,
+                               truncation=True,
+                               **config).to(device)
+         if enable_label:
+             if label2tensor:
+                 slot_mask = torch.ones_like(text.input_ids) * ignore_index
+                 for i, offsets in enumerate(
+                         DataFactory.fast_align_data(text, padding_side=self.tokenizer.padding_side)):
+                     num = 0
+                     assert len(offsets) == len(batch[i]["text"])
+                     assert len(offsets) == len(batch[i]["slot"])
+                     for off in offsets:
+                         slot_mask[i][off[0]] = self.slot_label_dict[batch[i]["slot"][num]]
+                         num += 1
+                 slot = slot_mask.clone()
+                 attention_id = 0 if self.tokenizer.padding_side == "right" else 1
+                 for i, slot_batch in enumerate(slot):
+                     for j, x in enumerate(slot_batch):
+                         if x == ignore_index and text.attention_mask[i][j] == attention_id and (text.input_ids[i][
+                                 j] not in self.tokenizer.all_special_ids or text.input_ids[i][j] == self.tokenizer.unk_token_id):
+                             slot[i][j] = slot[i][j - 1]
+                 slot = slot.to(device)
+                 if not self.use_multi:
+                     intent = torch.tensor(
+                         [self.intent_label_dict[x["intent"]] for x in batch]).to(device)
+                 else:
+                     one_hot = torch.zeros(
+                         (len(batch), len(self.intent_label_list)), dtype=torch.float)
+                     for index, b in enumerate(batch):
+                         for x in b["intent"].split("#"):
+                             one_hot[index][self.intent_label_dict[x]] = 1.
+                     intent = one_hot.to(device)
+             else:
+                 slot_mask = None
+                 slot = [['#' for _ in range(text.input_ids.shape[1])]
+                         for _ in range(text.input_ids.shape[0])]
+                 for i, offsets in enumerate(DataFactory.fast_align_data(text)):
+                     num = 0
+                     for off in offsets:
+                         slot[i][off[0]] = batch[i]["slot"][num]
+                         num += 1
+                 if not self.use_multi:
+                     intent = [x["intent"] for x in batch]
+                 else:
+                     intent = [
+                         [x for x in b["intent"].split("#")] for b in batch]
+             return InputData((text, slot, intent))
+         else:
+             return InputData((text, None, None))
+
+     def general_align_data(self, split_text_list, raw_text_list, encoded_text):
+         for i in range(len(split_text_list)):
+             desired_output = []
+             jdx = 0
+             offset = encoded_text.offset_mapping[i].tolist()
+             split_texts = split_text_list[i]
+             raw_text = raw_text_list[i]
+             last = 0
+             temp_offset = []
+             for off in offset:
+                 s, e = off
+                 if len(temp_offset) > 0 and (e != 0 and last == s):
+                     len_1 = off[1] - off[0]
+                     len_2 = temp_offset[-1][1] - temp_offset[-1][0]
+                     if len_1 > len_2:
+                         temp_offset.pop(-1)
+                         temp_offset.append([0, 0])
+                         temp_offset.append(off)
+                     continue
+                 temp_offset.append(off)
+                 last = s
+             offset = temp_offset
+             for split_text in split_texts:
+                 while jdx < len(offset) and offset[jdx][0] == 0 and offset[jdx][1] == 0:
+                     jdx += 1
+                 if jdx == len(offset):
+                     continue
+                 start_, end_ = offset[jdx]
+                 tokens = None
+                 if split_text == raw_text[start_:end_].strip():
+                     tokens = [jdx]
+                 else:
+                     # Compute "xxx" -> "xx" "#x"
+                     temp_jdx = jdx
+                     last_str = raw_text[start_:end_].strip()
+                     while last_str != split_text and temp_jdx < len(offset) - 1:
+                         temp_jdx += 1
+                         last_str += raw_text[offset[temp_jdx][0]:offset[temp_jdx][1]].strip()
+
+                     if temp_jdx == jdx:
+                         raise ValueError("Illegal Input data")
+                     elif last_str == split_text:
+                         tokens = [jdx, temp_jdx]
+                         jdx = temp_jdx
+                     else:
+                         jdx -= 1
+                 jdx += 1
+                 if tokens is not None:
+                     desired_output.append(tokens)
+             yield desired_output
+
+     def general_align(self,
+                       batch,
+                       ignore_index=-100,
+                       device="cuda",
+                       config=None,
+                       enable_label=True,
+                       label2tensor=True,
+                       locale="en-US"):
+         if self.to_lower_case:
+             raw_data = [" ".join(x["text"]).lower() if locale not in ['ja-JP', 'zh-CN', 'zh-TW'] else "".join(x["text"])
+                         for x in batch]
+             input_list = [[t.lower() for t in x["text"]] for x in batch]
+         else:
+             input_list = [x["text"] for x in batch]
+             raw_data = [" ".join(x["text"]) if locale not in ['ja-JP', 'zh-CN', 'zh-TW'] else "".join(x["text"])
+                         for x in batch]
+         text = self.tokenizer(raw_data,
+                               return_tensors="pt",
+                               padding=True,
+                               truncation=True,
+                               return_offsets_mapping=True,
+                               **config).to(device)
+         if enable_label:
+             if label2tensor:
+                 slot_mask = torch.ones_like(text.input_ids) * ignore_index
+                 for i, offsets in enumerate(
+                         self.general_align_data(input_list, raw_data, encoded_text=text)):
+                     num = 0
+                     for off in offsets:
+                         slot_mask[i][off[0]] = self.slot_label_dict[batch[i]["slot"][num]]
+                         num += 1
+                 slot = slot_mask.to(device)
+                 if not self.use_multi:
+                     intent = torch.tensor(
+                         [self.intent_label_dict[x["intent"]] for x in batch]).to(device)
+                 else:
+                     one_hot = torch.zeros(
+                         (len(batch), len(self.intent_label_list)), dtype=torch.float)
+                     for index, b in enumerate(batch):
+                         for x in b["intent"].split("#"):
+                             one_hot[index][self.intent_label_dict[x]] = 1.
+                     intent = one_hot.to(device)
+             else:
+                 slot_mask = None
+                 slot = [['#' for _ in range(text.input_ids.shape[1])]
+                         for _ in range(text.input_ids.shape[0])]
+                 for i, offsets in enumerate(self.general_align_data(input_list, raw_data, encoded_text=text)):
+                     num = 0
+                     for off in offsets:
+                         slot[i][off[0]] = batch[i]["slot"][num]
+                         num += 1
+                 if not self.use_multi:
+                     intent = [x["intent"] for x in batch]
+                 else:
+                     intent = [
+                         [x for x in b["intent"].split("#")] for b in batch]
+             return InputData((text, slot, intent))
+         else:
+             return InputData((text, None, None))
+
+     def batch_fn(self,
+                  batch,
+                  ignore_index=-100,
+                  device="cuda",
+                  config=None,
+                  align_mode="fast",
+                  enable_label=True,
+                  label2tensor=True):
+         if align_mode == "fast":
+             return self.fast_align(batch,
+                                    ignore_index=ignore_index,
+                                    device=device,
+                                    config=config,
+                                    enable_label=enable_label,
+                                    label2tensor=label2tensor)
+         else:
+             return self.general_align(batch,
+                                       ignore_index=ignore_index,
+                                       device=device,
+                                       config=config,
+                                       enable_label=enable_label,
+                                       label2tensor=label2tensor)
+
+     def get_data_loader(self,
+                         dataset,
+                         batch_size,
+                         shuffle=False,
+                         device="cuda",
+                         enable_label=True,
+                         align_mode="fast",
+                         label2tensor=True, **config):
+         data_loader = DataLoader(dataset,
+                                  shuffle=shuffle,
+                                  batch_size=batch_size,
+                                  collate_fn=lambda x: self.batch_fn(x,
+                                                                     device=device,
+                                                                     config=config,
+                                                                     enable_label=enable_label,
+                                                                     align_mode=align_mode,
+                                                                     label2tensor=label2tensor))
+         return data_loader
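
A typical end-to-end use of DataFactory, mirroring how ModelManager drives it (the dataset name and loader arguments are illustrative):

    from common.loader import DataFactory
    from common.tokenizer import get_tokenizer

    tokenizer = get_tokenizer("word_tokenizer")
    df = DataFactory(tokenizer=tokenizer, use_multi_intent=False)
    train_dataset = df.load_dataset({"dataset_name": "atis"}, split="train")
    df.update_label_names(train_dataset)  # build the intent/slot label vocabularies
    df.update_vocabulary(train_dataset)   # extend the word tokenizer's vocabulary
    train_loader = df.get_data_loader(train_dataset, batch_size=16, shuffle=True,
                                      device="cpu", align_mode="fast", label2tensor=True)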
common/logger.py ADDED
@@ -0,0 +1,197 @@
+ '''
+ Author: Qiguang Chen
+ Date: 2023-01-11 10:39:26
+ LastEditors: Qiguang Chen
+ LastEditTime: 2023-02-02 16:29:13
+ Description: log manager
+
+ '''
+ import json
+ import os
+ import time
+ from common.config import Config
+
+ def mkdirs(dir_names):
+     for dir_name in dir_names:
+         if not os.path.exists(dir_name):
+             os.mkdir(dir_name)
+
+
+ class Logger():
+     """ log information with [wandb, fitlog, local file]
+     """
+     def __init__(self,
+                  logger_type: str,
+                  logger_name: str,
+                  logging_level="INFO",
+                  start_time='',
+                  accelerator=None):
+         """ create logger
+
+         Args:
+             logger_type (str): supported types = ["wandb", "fitlog", "local"]
+             logger_name (str): logger name, i.e. the project name in wandb and the logging file name
+             logging_level (str, optional): logging level. Defaults to "INFO".
+             start_time (str, optional): start time string. Defaults to ''.
+         """
+         self.logger_type = logger_type
+         times = time.localtime()
+         self.output_dir = "logs/" + logger_name + "/" + str(times.tm_year) + start_time
+         self.accelerator = accelerator
+         self.logger_name = logger_name
+         if accelerator is not None:
+             from accelerate.logging import get_logger
+             self.logging = get_logger(logger_name)
+         else:
+             if self.logger_type == "wandb":
+                 import wandb
+                 self.logger = wandb
+                 mkdirs(["logs", "logs/" + logger_name, self.output_dir])
+                 self.logger.init(project=logger_name)
+             elif self.logger_type == "fitlog":
+                 import fitlog
+                 self.logger = fitlog
+                 mkdirs(["logs", "logs/" + logger_name, self.output_dir])
+                 self.logger.set_log_dir("logs/" + logger_name)
+             else:
+                 mkdirs(["logs", "logs/" + logger_name, self.output_dir])
+                 self.config_file = os.path.join(self.output_dir, "config.jsonl")
+                 with open(self.config_file, "w", encoding="utf8") as f:
+                     print(f"Config will be written to {self.config_file}")
+
+                 self.loss_file = os.path.join(self.output_dir, "loss.jsonl")
+                 with open(self.loss_file, "w", encoding="utf8") as f:
+                     print(f"Loss Result will be written to {self.loss_file}")
+
+                 self.metric_file = os.path.join(self.output_dir, "metric.jsonl")
+                 with open(self.metric_file, "w", encoding="utf8") as f:
+                     print(f"Metric Result will be written to {self.metric_file}")
+
+                 self.other_log_file = os.path.join(self.output_dir, "other_log.jsonl")
+                 with open(self.other_log_file, "w", encoding="utf8") as f:
+                     print(f"Other Log Result will be written to {self.other_log_file}")
+             import logging
+             LOGGING_LEVEL_MAP = {
+                 "CRITICAL": logging.CRITICAL,
+                 "FATAL": logging.FATAL,
+                 "ERROR": logging.ERROR,
+                 "WARNING": logging.WARNING,
+                 "WARN": logging.WARN,
+                 "INFO": logging.INFO,
+                 "DEBUG": logging.DEBUG,
+                 "NOTSET": logging.NOTSET,
+             }
+             logging.basicConfig(format='[%(levelname)s - %(asctime)s]\t%(message)s', datefmt='%m/%d/%Y %I:%M:%S %p',
+                                 filename=os.path.join(self.output_dir, "log.log"), level=LOGGING_LEVEL_MAP[logging_level])
+             self.logging = logging
+
+     def set_config(self, config: Config):
+         """save config
+
+         Args:
+             config (Config): configuration object to save
+         """
+         if self.accelerator is not None:
+             self.accelerator.init_trackers(self.logger_name, config=config)
+         elif self.logger_type == "wandb":
+             self.logger.config.update(config)
+         elif self.logger_type == "fitlog":
+             self.logger.add_hyper(config)
+         else:
+             with open(self.config_file, "a", encoding="utf8") as f:
+                 f.write(json.dumps(config) + "\n")
+
+     def log(self, data, step=0):
+         """log data at a given step
+
+         Args:
+             data (Any): data to log
+             step (int, optional): step num. Defaults to 0.
+         """
+         if self.accelerator is not None:
+             self.accelerator.log(data, step=step)
+         elif self.logger_type == "wandb":
+             self.logger.log(data, step=step)
+         elif self.logger_type == "fitlog":
+             self.logger.add_other({"data": data, "step": step})
+         else:
+             with open(self.other_log_file, "a", encoding="utf8") as f:
+                 f.write(json.dumps({"data": data, "step": step}) + "\n")
+
+     def log_metric(self, metric, metric_split="dev", step=0):
+         """log metric
+
+         Args:
+             metric (Any): metric
+             metric_split (str, optional): dataset split. Defaults to 'dev'.
+             step (int, optional): step num. Defaults to 0.
+         """
+         if self.accelerator is not None:
+             self.accelerator.log({metric_split: metric}, step=step)
+         elif self.logger_type == "wandb":
+             self.logger.log({metric_split: metric}, step=step)
+         elif self.logger_type == "fitlog":
+             self.logger.add_metric({metric_split: metric}, step=step)
+         else:
+             with open(self.metric_file, "a", encoding="utf8") as f:
+                 f.write(json.dumps({metric_split: metric, "step": step}) + "\n")
+
+     def log_loss(self, loss, loss_name="Loss", step=0):
+         """log loss
+
+         Args:
+             loss (Any): loss
+             loss_name (str, optional): loss description. Defaults to 'Loss'.
+             step (int, optional): step num. Defaults to 0.
+         """
+         if self.accelerator is not None:
+             self.accelerator.log({loss_name: loss}, step=step)
+         elif self.logger_type == "wandb":
+             self.logger.log({loss_name: loss}, step=step)
+         elif self.logger_type == "fitlog":
+             self.logger.add_loss(loss, name=loss_name, step=step)
+         else:
+             with open(self.loss_file, "a", encoding="utf8") as f:
+                 f.write(json.dumps({loss_name: loss, "step": step}) + "\n")
+
+     def finish(self):
+         """finish logging
+         """
+         if self.logger_type == "fitlog":
+             self.logger.finish()
+
+     def info(self, message: str):
+         """ Log a message with severity 'INFO' to the local file / console.
+
+         Args:
+             message (str): message to log
+         """
+         self.logging.info(message)
+
+     def warning(self, message):
+         """ Log a message with severity 'WARNING' to the local file / console.
+
+         Args:
+             message (str): message to log
+         """
+         self.logging.warning(message)
+
+     def error(self, message):
+         """ Log a message with severity 'ERROR' to the local file / console.
+
+         Args:
+             message (str): message to log
+         """
+         self.logging.error(message)
+
+     def debug(self, message):
+         """ Log a message with severity 'DEBUG' to the local file / console.
+
+         Args:
+             message (str): message to log
+         """
+         self.logging.debug(message)
+
+     def critical(self, message):
+         """ Log a message with severity 'CRITICAL' to the local file / console.
+
+         Args:
+             message (str): message to log
+         """
+         self.logging.critical(message)
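
A short sketch of the local-file backend (any logger_type other than "wandb"/"fitlog" falls through to it; names and values are illustrative):

    from common.logger import Logger

    logger = Logger(logger_type="local", logger_name="demo-run")
    logger.set_config({"lr": 1e-3, "batch_size": 16})              # appended to config.jsonl
    logger.log_loss(0.42, "Loss", step=10)                         # appended to loss.jsonl
    logger.log_metric({"EMA": 0.71}, metric_split="dev", step=10)  # appended to metric.jsonl
    logger.info("finished step 10")                                # written to log.log
    logger.finish()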
common/metric.py ADDED
@@ -0,0 +1,343 @@
+ '''
+ Author: Qiguang Chen
+ Date: 2023-01-11 10:39:26
+ LastEditors: Qiguang Chen
+ LastEditTime: 2023-01-26 12:12:55
+ Description: Metric calculation class
+
+ '''
+ from collections import Counter
+ from typing import List, Dict
+
+ import numpy as np
+ from sklearn.metrics import f1_score
+
+ from common.utils import InputData, OutputData
+
+
+ class Evaluator(object):
+     """Evaluation metric function library class
+     supported metrics:
+         - slot_f1
+         - intent_acc
+         - exactly_match_accuracy
+         - intent_f1 (defaults to "macro_intent_f1")
+         - macro_intent_f1
+         - micro_intent_f1
+     """
+     @staticmethod
+     def exactly_match_accuracy(pred_slot: List[List[str or int]],
+                                real_slot: List[List[str or int]],
+                                pred_intent: List[List[str or int] or str or int],
+                                real_intent: List[List[str or int] or str or int]) -> float:
+         """Compute the accuracy based on the whole prediction for a given sentence, including both slot and intent.
+         (str and int indexes are both supported as the representation of slots and intents)
+         Args:
+             pred_slot (List[List[str or int]]): predicted slot sequence list.
+             real_slot (List[List[str or int]]): golden slot sequence list.
+             pred_intent (List[List[str or int] or str or int]): predicted intent list / predicted multi-intent list.
+             real_intent (List[List[str or int] or str or int]): golden intent list / golden multi-intent list.
+
+         Returns:
+             float: exact match accuracy score
+         """
+         total_count, correct_count = 0.0, 0.0
+         for p_slot, r_slot, p_intent, r_intent in zip(pred_slot, real_slot, pred_intent, real_intent):
+             if isinstance(p_intent, list):
+                 p_intent, r_intent = set(p_intent), set(r_intent)
+             if p_slot == r_slot and p_intent == r_intent:
+                 correct_count += 1.0
+             total_count += 1.0
+
+         return 1.0 * correct_count / total_count
+
+     @staticmethod
+     def intent_accuracy(pred_list: List, real_list: List) -> float:
+         """Get intent accuracy measured by predictions and ground truths. Supports both multi intent and single intent.
+
+         Args:
+             pred_list (List): predicted intent list
+             real_list (List): golden intent list
+
+         Returns:
+             float: intent accuracy score
+         """
+         total_count, correct_count = 0.0, 0.0
+         for p_intent, r_intent in zip(pred_list, real_list):
+             if isinstance(p_intent, list):
+                 p_intent, r_intent = set(p_intent), set(r_intent)
+             if p_intent == r_intent:
+                 correct_count += 1.0
+             total_count += 1.0
+
+         return 1.0 * correct_count / total_count
+
+     @staticmethod
+     def intent_f1(pred_list: List[List[int]], real_list: List[List[int]], num_intent: int, average='macro') -> float:
+         """Get intent f1 measured by predictions and ground truths.
+         (Only multi intent is supported for now, but you can pass [[intent1], [intent2], ...] to compute intent f1 for single intent.)
+         Args:
+             pred_list (List[List[int]]): predicted multi-intent list.
+             real_list (List[List[int]]): golden multi-intent list.
+             num_intent (int): the total number of intent labels.
+             average (str): supports "micro" and "macro".
+
+         Returns:
+             float: intent f1 score
+         """
+         return f1_score(Evaluator.__instance2onehot(num_intent, real_list),
+                         Evaluator.__instance2onehot(num_intent, pred_list),
+                         average=average,
+                         zero_division=0)
+
+     @staticmethod
+     def __multilabel2one_hot(labels, nums):
+         res = [0.] * nums
+         if len(labels) == 0:
+             return res
+         if isinstance(labels[0], list):
+             for label in labels[0]:
+                 res[label] = 1.
+             return res
+         for label in labels:
+             res[label] = 1.
+         return res
+
+     @staticmethod
+     def __instance2onehot(num_intent, data):
+         res = []
+         for intents in data:
+             res.append(Evaluator.__multilabel2one_hot(intents, num_intent))
+         return np.array(res)
+
+     @staticmethod
+     def __startOfChunk(prevTag, tag, prevTagType, tagType, chunkStart=False):
+         if prevTag == 'B' and tag == 'B':
+             chunkStart = True
+         if prevTag == 'I' and tag == 'B':
+             chunkStart = True
+         if prevTag == 'O' and tag == 'B':
+             chunkStart = True
+         if prevTag == 'O' and tag == 'I':
+             chunkStart = True
+
+         if prevTag == 'E' and tag == 'E':
+             chunkStart = True
+         if prevTag == 'E' and tag == 'I':
+             chunkStart = True
+         if prevTag == 'O' and tag == 'E':
+             chunkStart = True
+         if prevTag == 'O' and tag == 'I':
+             chunkStart = True
+
+         if tag != 'O' and tag != '.' and prevTagType != tagType:
+             chunkStart = True
+         return chunkStart
+
+     @staticmethod
+     def __endOfChunk(prevTag, tag, prevTagType, tagType, chunkEnd=False):
+         if prevTag == 'B' and tag == 'B':
+             chunkEnd = True
+         if prevTag == 'B' and tag == 'O':
+             chunkEnd = True
+         if prevTag == 'I' and tag == 'B':
+             chunkEnd = True
+         if prevTag == 'I' and tag == 'O':
+             chunkEnd = True
+
+         if prevTag == 'E' and tag == 'E':
+             chunkEnd = True
+         if prevTag == 'E' and tag == 'I':
+             chunkEnd = True
+         if prevTag == 'E' and tag == 'O':
+             chunkEnd = True
+         if prevTag == 'I' and tag == 'O':
+             chunkEnd = True
+
+         if prevTag != 'O' and prevTag != '.' and prevTagType != tagType:
+             chunkEnd = True
+         return chunkEnd
+
+     @staticmethod
+     def __splitTagType(tag):
+         s = tag.split('-')
+         if len(s) > 2 or len(s) == 0:
+             raise ValueError('tag format wrong. it must be B-xxx.xxx')
+         if len(s) == 1:
+             tag = s[0]
+             tagType = ""
+         else:
+             tag = s[0]
+             tagType = s[1]
+         return tag, tagType
+
+     @staticmethod
+     def computeF1Score(correct_slots: List[List[str]], pred_slots: List[List[str]]) -> float:
+         """Compute the f1 score; modified from conlleval.pl.
+
+         Args:
+             correct_slots (List[List[str]]): golden slot string list
+             pred_slots (List[List[str]]): predicted slot string list
+
+         Returns:
+             float: slot f1 score
+         """
+         correctChunk = {}
+         correctChunkCnt = 0.0
+         foundCorrect = {}
+         foundCorrectCnt = 0.0
+         foundPred = {}
+         foundPredCnt = 0.0
+         correctTags = 0.0
+         tokenCount = 0.0
+         for correct_slot, pred_slot in zip(correct_slots, pred_slots):
+             inCorrect = False
+             lastCorrectTag = 'O'
+             lastCorrectType = ''
+             lastPredTag = 'O'
+             lastPredType = ''
+             for c, p in zip(correct_slot, pred_slot):
+                 correctTag, correctType = Evaluator.__splitTagType(c)
+                 predTag, predType = Evaluator.__splitTagType(p)
+
+                 if inCorrect == True:
+                     if Evaluator.__endOfChunk(lastCorrectTag, correctTag, lastCorrectType, correctType) == True and \
+                             Evaluator.__endOfChunk(lastPredTag, predTag, lastPredType, predType) == True and \
+                             (lastCorrectType == lastPredType):
+                         inCorrect = False
+                         correctChunkCnt += 1.0
+                         if lastCorrectType in correctChunk:
+                             correctChunk[lastCorrectType] += 1.0
+                         else:
+                             correctChunk[lastCorrectType] = 1.0
+                     elif Evaluator.__endOfChunk(lastCorrectTag, correctTag, lastCorrectType, correctType) != \
+                             Evaluator.__endOfChunk(lastPredTag, predTag, lastPredType, predType) or \
+                             (correctType != predType):
+                         inCorrect = False
+
+                 if Evaluator.__startOfChunk(lastCorrectTag, correctTag, lastCorrectType, correctType) == True and \
+                         Evaluator.__startOfChunk(lastPredTag, predTag, lastPredType, predType) == True and \
+                         (correctType == predType):
+                     inCorrect = True
+
+                 if Evaluator.__startOfChunk(lastCorrectTag, correctTag, lastCorrectType, correctType) == True:
+                     foundCorrectCnt += 1
+                     if correctType in foundCorrect:
+                         foundCorrect[correctType] += 1.0
+                     else:
+                         foundCorrect[correctType] = 1.0
+
+                 if Evaluator.__startOfChunk(lastPredTag, predTag, lastPredType, predType) == True:
+                     foundPredCnt += 1.0
+                     if predType in foundPred:
+                         foundPred[predType] += 1.0
+                     else:
+                         foundPred[predType] = 1.0
+
+                 if correctTag == predTag and correctType == predType:
+                     correctTags += 1.0
+
+                 tokenCount += 1.0
+
+                 lastCorrectTag = correctTag
+                 lastCorrectType = correctType
+                 lastPredTag = predTag
+                 lastPredType = predType
+
+             if inCorrect == True:
+                 correctChunkCnt += 1.0
+                 if lastCorrectType in correctChunk:
+                     correctChunk[lastCorrectType] += 1.0
+                 else:
+                     correctChunk[lastCorrectType] = 1.0
+
+         if foundPredCnt > 0:
+             precision = 1.0 * correctChunkCnt / foundPredCnt
+         else:
+             precision = 0
+
+         if foundCorrectCnt > 0:
+             recall = 1.0 * correctChunkCnt / foundCorrectCnt
+         else:
+             recall = 0
+
+         if (precision + recall) > 0:
+             f1 = (2.0 * precision * recall) / (precision + recall)
+         else:
+             f1 = 0
+
+         return f1
+
+     @staticmethod
+     def max_freq_predict(sample):
+         """Predict the most frequent label in each sample.
+         """
+         predict = []
+         for items in sample:
+             predict.append(Counter(items).most_common(1)[0][0])
+         return predict
+
+     @staticmethod
+     def __token_map(indexes, token_label_map):
+         return [[token_label_map[idx] if idx in token_label_map else -1 for idx in index] for index in indexes]
+
+     @staticmethod
+     def compute_all_metric(inps: InputData,
+                            output: OutputData,
+                            intent_label_map: dict = None,
+                            metric_list: List = None) -> Dict:
+         """Automatically compute all metrics mentioned in 'metric_list'
+
+         Args:
+             inps (InputData): input golden slot and intent labels
+             output (OutputData): output predicted slot and intent labels
+             intent_label_map (dict, Optional): a dict like {"intent1": 0, "intent2": 1, ...} that maps intent strings to indexes
+             metric_list (List): supported metrics are ["slot_f1", "intent_acc", "intent_f1", "macro_intent_f1", "micro_intent_f1", "EMA"]
+
+         Returns:
+             Dict: all metrics mentioned in 'metric_list', like {'EMA': 0.7, ...}
+
+         Example:
+             to compute slot metrics:
+
+                 inps.slot = [["slot1", "slot2", ...], ...]; output.slot_ids = [["slot1", "slot2", ...], ...]
+
+             to compute intent metrics:
+
+                 [Multi Intent] inps.intent = [["intent1", "intent2", ...], ...]; output.intent_ids = [["intent1", "intent2", ...], ...]
+
+                 [Single Intent] inps.intent = ["intent1", ...]; output.intent_ids = ["intent1", ...]
+         """
+         if not metric_list:
+             metric_list = ["slot_f1", "intent_acc", "EMA"]
+         res_dict = {}
+         use_slot = output.slot_ids is not None and len(output.slot_ids) > 0
+         use_intent = output.intent_ids is not None and len(
+             output.intent_ids) > 0
+         if use_slot and "slot_f1" in metric_list:
+             res_dict["slot_f1"] = Evaluator.computeF1Score(
+                 output.slot_ids, inps.slot)
+         if use_intent and "intent_acc" in metric_list:
+             res_dict["intent_acc"] = Evaluator.intent_accuracy(
+                 output.intent_ids, inps.intent)
+             if isinstance(output.intent_ids[0], list):
+                 if "intent_f1" in metric_list:
+                     res_dict["intent_f1"] = Evaluator.intent_f1(
+                         Evaluator.__token_map(output.intent_ids, intent_label_map),
+                         Evaluator.__token_map(inps.intent, intent_label_map),
+                         len(intent_label_map.keys()))
+                 elif "macro_intent_f1" in metric_list:
+                     res_dict["macro_intent_f1"] = Evaluator.intent_f1(
+                         Evaluator.__token_map(output.intent_ids, intent_label_map),
+                         Evaluator.__token_map(inps.intent, intent_label_map),
+                         len(intent_label_map.keys()), average="macro")
+                 if "micro_intent_f1" in metric_list:
+                     res_dict["micro_intent_f1"] = Evaluator.intent_f1(
+                         Evaluator.__token_map(output.intent_ids, intent_label_map),
+                         Evaluator.__token_map(inps.intent, intent_label_map),
+                         len(intent_label_map.keys()), average="micro")
+
+         if use_slot and use_intent and "EMA" in metric_list:
+             res_dict["EMA"] = Evaluator.exactly_match_accuracy(output.slot_ids, inps.slot, output.intent_ids,
+                                                                inps.intent)
+         return res_dict
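
A worked example of the chunk-based slot f1 and the exact-match accuracy (hypothetical BIO sequences):

    pred_slots = [["B-city", "I-city", "O"], ["O", "B-date"]]
    gold_slots = [["B-city", "I-city", "O"], ["O", "B-time"]]
    pred_intents = ["book_flight", "get_weather"]
    gold_intents = ["book_flight", "get_weather"]

    Evaluator.computeF1Score(gold_slots, pred_slots)       # 0.5: one of two chunks matched on each side
    Evaluator.intent_accuracy(pred_intents, gold_intents)  # 1.0
    Evaluator.exactly_match_accuracy(pred_slots, gold_slots,
                                     pred_intents, gold_intents)  # 0.5: only the first sentence is fully correct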
common/model_manager.py ADDED
@@ -0,0 +1,324 @@
+ '''
+ Author: Qiguang Chen
+ Date: 2023-01-11 10:39:26
+ LastEditors: Qiguang Chen
+ LastEditTime: 2023-02-07 21:36:06
+ Description: manage all processes of model training and prediction.
+
+ '''
+ import os
+ import random
+
+ import numpy as np
+ import torch
+ from tqdm import tqdm
+
+ from common import utils
+ from common.loader import DataFactory
+ from common.logger import Logger
+ from common.metric import Evaluator
+ from common.tokenizer import get_tokenizer, get_tokenizer_class, load_embedding
+ from common.utils import InputData, instantiate
+ from common.utils import OutputData
+ from common.config import Config
+ import dill
+
+
+ class ModelManager(object):
+     def __init__(self, config: Config):
+         """create a model manager from config
+
+         Args:
+             config (Config): configuration to manage all processes in OpenSLU
+         """
+         # init config
+         self.config = config
+         self.__set_seed(self.config.base.get("seed"))
+         self.device = self.config.base.get("device")
+
+         # enable accelerator
+         if "accelerator" in self.config and self.config["accelerator"].get("use_accelerator"):
+             from accelerate import Accelerator
+             self.accelerator = Accelerator(log_with="wandb")
+         else:
+             self.accelerator = None
+         if self.config.base.get("train"):
+             self.tokenizer = get_tokenizer(
+                 self.config.tokenizer.get("_tokenizer_name_"))
+             self.logger = Logger(
+                 "wandb", self.config.base["name"], start_time=config.start_time, accelerator=self.accelerator)
+
+             # init dataloader & load data
+             if self.config.base.get("save_dir"):
+                 self.model_save_dir = self.config.base["save_dir"]
+             else:
+                 if not os.path.exists("save/"):
+                     os.mkdir("save/")
+                 self.model_save_dir = "save/" + config.start_time
+             if not os.path.exists(self.model_save_dir):
+                 os.mkdir(self.model_save_dir)
+             batch_size = self.config.base["batch_size"]
+             df = DataFactory(tokenizer=self.tokenizer,
+                              use_multi_intent=self.config.base.get("multi_intent"),
+                              to_lower_case=self.config.base.get("_to_lower_case_"))
+             train_dataset = df.load_dataset(self.config.dataset, split="train")
+
+             # update label and vocabulary
+             df.update_label_names(train_dataset)
+             df.update_vocabulary(train_dataset)
+
+             # init tokenizer config and dataloaders
+             tokenizer_config = {key: self.config.tokenizer[key]
+                                 for key in self.config.tokenizer if key[0] != "_" and key[-1] != "_"}
+             self.train_dataloader = df.get_data_loader(train_dataset,
+                                                        batch_size,
+                                                        shuffle=True,
+                                                        device=self.device,
+                                                        enable_label=True,
+                                                        align_mode=self.config.tokenizer.get("_align_mode_"),
+                                                        label2tensor=True,
+                                                        **tokenizer_config)
+             dev_dataset = df.load_dataset(
+                 self.config.dataset, split="validation")
+             self.dev_dataloader = df.get_data_loader(dev_dataset,
+                                                      batch_size,
+                                                      shuffle=False,
+                                                      device=self.device,
+                                                      enable_label=True,
+                                                      align_mode=self.config.tokenizer.get("_align_mode_"),
+                                                      label2tensor=False,
+                                                      **tokenizer_config)
+             df.update_vocabulary(dev_dataset)
+             # add intent label num and slot label num to config
+             if int(self.config.get_intent_label_num()) == 0 or int(self.config.get_slot_label_num()) == 0:
+                 self.intent_list = df.intent_label_list
+                 self.intent_dict = df.intent_label_dict
+                 self.config.set_intent_label_num(len(self.intent_list))
+                 self.slot_list = df.slot_label_list
+                 self.slot_dict = df.slot_label_dict
+                 self.config.set_slot_label_num(len(self.slot_list))
+                 self.config.set_vocab_size(self.tokenizer.vocab_size)
+
+             # autoload embedding for non-pretrained encoder
+             if self.config["model"]["encoder"].get("embedding") and self.config["model"]["encoder"]["embedding"].get(
+                     "load_embedding_name"):
+                 self.config["model"]["encoder"]["embedding"]["embedding_matrix"] = load_embedding(
+                     self.tokenizer,
+                     self.config["model"]["encoder"]["embedding"].get("load_embedding_name"))
+             # fill templates in config
+             self.config.autoload_template()
+             # save config
+             self.logger.set_config(self.config)
+
+             self.model = None
+             self.optimizer = None
+             self.total_step = None
+             self.lr_scheduler = None
+             if self.config.tokenizer.get("_tokenizer_name_") == "word_tokenizer":
+                 self.tokenizer.save(os.path.join(self.model_save_dir, "tokenizer.json"))
+             utils.save_json(os.path.join(
+                 self.model_save_dir, "label.json"), {"intent": self.intent_list, "slot": self.slot_list})
+             if self.config.base.get("test"):
+                 self.test_dataset = df.load_dataset(
+                     self.config.dataset, split="test")
+                 self.test_dataloader = df.get_data_loader(self.test_dataset,
+                                                           batch_size,
+                                                           shuffle=False,
+                                                           device=self.device,
+                                                           enable_label=True,
+                                                           align_mode=self.config.tokenizer.get("_align_mode_"),
+                                                           label2tensor=False,
+                                                           **tokenizer_config)
+
+     def init_model(self, model):
+         """init model, optimizer and lr_scheduler
+
+         Args:
+             model (Any): pytorch model
+         """
+         self.model = model
+         self.model.to(self.device)
+         if self.config.base.get("train"):
+             self.optimizer = instantiate(
+                 self.config["optimizer"])(self.model.parameters())
+             self.total_step = int(self.config.base.get(
+                 "epoch_num")) * len(self.train_dataloader)
+             self.lr_scheduler = instantiate(self.config["scheduler"])(
+                 optimizer=self.optimizer,
+                 num_training_steps=self.total_step
+             )
+             if self.accelerator is not None:
+                 self.model, self.optimizer, self.train_dataloader, self.lr_scheduler = self.accelerator.prepare(
+                     self.model, self.optimizer, self.train_dataloader, self.lr_scheduler)
+                 if self.config.base.get("load_dir_path"):
+                     self.accelerator.load_state(self.config.base.get("load_dir_path"))
+
+     def eval(self, step: int, best_metric: float) -> float:
+         """ evaluate the model.
+
+         Args:
+             step (int): the step the model has been trained to
+             best_metric (float): the last best metric value, used to decide whether to test and save the model
+
+         Returns:
+             float: updated best metric value
+         """
+         # TODO: save dev
+         _, res = self.__evaluate(self.model, self.dev_dataloader)
+         self.logger.log_metric(res, metric_split="dev", step=step)
+         if res[self.config.base.get("best_key")] > best_metric:
+             best_metric = res[self.config.base.get("best_key")]
+             outputs, test_res = self.__evaluate(
+                 self.model, self.test_dataloader)
+             if not os.path.exists(self.model_save_dir):
+                 os.mkdir(self.model_save_dir)
+             if self.accelerator is None:
+                 torch.save(self.model, os.path.join(
+                     self.model_save_dir, "model.pkl"))
+                 torch.save(self.optimizer, os.path.join(
+                     self.model_save_dir, "optimizer.pkl"))
+                 torch.save(self.lr_scheduler, os.path.join(
+                     self.model_save_dir, "lr_scheduler.pkl"), pickle_module=dill)
+                 torch.save(step, os.path.join(
+                     self.model_save_dir, "step.pkl"))
+             else:
+                 self.accelerator.wait_for_everyone()
+                 unwrapped_model = self.accelerator.unwrap_model(self.model)
+                 self.accelerator.save(unwrapped_model.state_dict(),
+                                       os.path.join(self.model_save_dir, "model.pkl"))
+                 self.accelerator.save_state(output_dir=self.model_save_dir)
+             outputs.save(self.model_save_dir, self.test_dataset)
+             self.logger.log_metric(test_res, metric_split="test", step=step)
+         return best_metric
+
+     def train(self) -> float:
+         """ train the model.
+
+         Returns:
+             float: updated best metric value
+         """
+         step = 0
+         best_metric = 0
+         progress_bar = tqdm(range(self.total_step))
+         for _ in range(int(self.config.base.get("epoch_num"))):
+             for data in self.train_dataloader:
+                 if step == 0:
+                     self.logger.info(data.get_item(
+                         0, tokenizer=self.tokenizer, intent_map=self.intent_list, slot_map=self.slot_list))
+                 output = self.model(data)
+                 if self.accelerator is not None and hasattr(self.model, "module"):
+                     loss, intent_loss, slot_loss = self.model.module.compute_loss(
+                         pred=output, target=data)
+                 else:
+                     loss, intent_loss, slot_loss = self.model.compute_loss(
+                         pred=output, target=data)
+                 self.logger.log_loss(loss, "Loss", step=step)
+                 self.logger.log_loss(intent_loss, "Intent Loss", step=step)
+                 self.logger.log_loss(slot_loss, "Slot Loss", step=step)
+                 self.optimizer.zero_grad()
+
+                 if self.accelerator is not None:
+                     self.accelerator.backward(loss)
+                 else:
+                     loss.backward()
+                 self.optimizer.step()
+                 self.lr_scheduler.step()
+                 if not self.config.base.get("eval_by_epoch") and step % self.config.base.get(
+                         "eval_step") == 0 and step != 0:
+                     best_metric = self.eval(step, best_metric)
+                 step += 1
+                 progress_bar.update(1)
+             if self.config.base.get("eval_by_epoch"):
+                 best_metric = self.eval(step, best_metric)
+         self.logger.finish()
+         return best_metric
+
+     def __set_seed(self, seed_value: int):
+         """Manually set random seeds.
+
+         Args:
+             seed_value (int): random seed
+         """
+         random.seed(seed_value)
+         np.random.seed(seed_value)
+         torch.manual_seed(seed_value)
+         torch.random.manual_seed(seed_value)
+         os.environ['PYTHONHASHSEED'] = str(seed_value)
+         if torch.cuda.is_available():
+             torch.cuda.manual_seed(seed_value)
+             torch.cuda.manual_seed_all(seed_value)
+             torch.backends.cudnn.deterministic = True
+             # benchmark disabled to keep cuDNN algorithm selection reproducible
+             torch.backends.cudnn.benchmark = False
+         return
+
+     def __evaluate(self, model, dataloader):
+         model.eval()
+         inps = InputData()
+         outputs = OutputData()
+         for data in dataloader:
+             torch.cuda.empty_cache()
+             output = model(data)
+             if self.accelerator is not None and hasattr(self.model, "module"):
+                 decode_output = model.module.decode(output, data)
+             else:
+                 decode_output = model.decode(output, data)
+
+             decode_output.map_output(slot_map=self.slot_list,
+                                      intent_map=self.intent_list)
+             data, decode_output = utils.remove_slot_ignore_index(
+                 data, decode_output, ignore_index="#")
+
+             inps.merge_input_data(data)
+             outputs.merge_output_data(decode_output)
+         if "metric" in self.config:
+             res = Evaluator.compute_all_metric(
+                 inps, outputs, intent_label_map=self.intent_dict, metric_list=self.config.metric)
+         else:
+             res = Evaluator.compute_all_metric(
+                 inps, outputs, intent_label_map=self.intent_dict)
+         model.train()
+         return outputs, res
+
+     def load(self):
+         """load the trained model, tokenizer and label maps from config.base["model_dir"]."""
+         self.model = torch.load(os.path.join(self.config.base["model_dir"], "model.pkl"))
+         if self.config.tokenizer["_tokenizer_name_"] == "word_tokenizer":
+             self.tokenizer = get_tokenizer_class(self.config.tokenizer["_tokenizer_name_"]).from_file(
+                 os.path.join(self.config.base["model_dir"], "tokenizer.json"))
+         else:
+             self.tokenizer = get_tokenizer(self.config.tokenizer["_tokenizer_name_"])
+         self.model.to(self.device)
+         label = utils.load_json(os.path.join(self.config.base["model_dir"], "label.json"))
+         self.intent_list = label["intent"]
+         self.slot_list = label["slot"]
+         self.data_factory = DataFactory(tokenizer=self.tokenizer,
+                                         use_multi_intent=self.config.base.get("multi_intent"),
+                                         to_lower_case=self.config.tokenizer.get("_to_lower_case_"))
+
+     def predict(self, text_data):
+         """predict the intents and slots of a raw text utterance."""
+         self.model.eval()
+         tokenizer_config = {key: self.config.tokenizer[key]
+                             for key in self.config.tokenizer if key[0] != "_" and key[-1] != "_"}
+         align_mode = self.config.tokenizer.get("_align_mode_")
+         inputs = self.data_factory.batch_fn(batch=[{"text": text_data.split(" ")}],
+                                             device=self.device,
+                                             config=tokenizer_config,
+                                             enable_label=False,
+                                             align_mode=align_mode if align_mode is not None else "general",
+                                             label2tensor=False)
+         output = self.model(inputs)
+         decode_output = self.model.decode(output, inputs)
+         decode_output.map_output(slot_map=self.slot_list,
+                                  intent_map=self.intent_list)
+         if self.config.base.get("multi_intent"):
+             intent = decode_output.intent_ids[0]
+         else:
+             intent = [decode_output.intent_ids[0]]
+         return {"intent": intent, "slot": decode_output.slot_ids[0], "text": self.tokenizer.decode(inputs.input_ids[0])}
common/tokenizer.py ADDED
@@ -0,0 +1,311 @@
1
+ import json
2
+ import os
3
+ from collections import Counter
4
+ from collections import OrderedDict
5
+ from typing import List
6
+
7
+ import torch
8
+ from ordered_set import OrderedSet
9
+ from transformers import AutoTokenizer
10
+
11
+ from common.utils import download, unzip_file
12
+
13
+
14
+ def get_tokenizer(tokenizer_name:str):
15
+ """auto get tokenizer
16
+
17
+ Args:
18
+ tokenizer_name (str): support "word_tokenizer" and other pretrained tokenizer in hugging face.
19
+
20
+ Returns:
21
+ Any: Tokenizer Object
22
+ """
23
+ if tokenizer_name == "word_tokenizer":
24
+ return WordTokenizer(tokenizer_name)
25
+ else:
26
+ return AutoTokenizer.from_pretrained(tokenizer_name)
27
+
28
+ def get_tokenizer_class(tokenizer_name:str):
29
+ """auto get tokenizer class
30
+
31
+ Args:
32
+ tokenizer_name (str): support "word_tokenizer" and other pretrained tokenizer in hugging face.
33
+
34
+ Returns:
35
+ Any: Tokenizer Class
36
+ """
37
+ if tokenizer_name == "word_tokenizer":
38
+ return WordTokenizer
39
+ else:
40
+ return AutoTokenizer.from_pretrained
41
+
42
+ BATCH_STATE = 1
43
+ INSTANCE_STATE = 2
44
+
45
+
46
+ class WordTokenizer(object):
47
+
48
+ def __init__(self, name):
49
+ self.__name = name
50
+ self.index2instance = OrderedSet()
51
+ self.instance2index = OrderedDict()
52
+ # Counter object recording the frequency
53
+ # of each element in the raw text.
54
+ self.counter = Counter()
55
+
56
+ self.__sign_pad = "[PAD]"
57
+ self.add_instance(self.__sign_pad)
58
+ self.__sign_unk = "[UNK]"
59
+ self.add_instance(self.__sign_unk)
60
+
61
+ @property
62
+ def padding_side(self):
63
+ return "right"
64
+ @property
65
+ def all_special_ids(self):
66
+ return [self.unk_token_id, self.pad_token_id]
67
+
68
+ @property
69
+ def name_or_path(self):
70
+ return self.__name
71
+
72
+ @property
73
+ def vocab_size(self):
74
+ return len(self.instance2index)
75
+
76
+ @property
77
+ def pad_token_id(self):
78
+ return self.instance2index[self.__sign_pad]
79
+
80
+ @property
81
+ def unk_token_id(self):
82
+ return self.instance2index[self.__sign_unk]
83
+
84
+ def add_instance(self, instance):
85
+ """ Add instances to alphabet.
86
+
87
+ 1, We support any iterable data structure whose
88
+ elements are of str type.
89
+
90
+ 2, We count every added instance, which influences
91
+ the serialization of unknown instances.
92
+
93
+ Args:
94
+ instance: is given instance or a list of it.
95
+ """
96
+
97
+ if isinstance(instance, (list, tuple)):
98
+ for element in instance:
99
+ self.add_instance(element)
100
+ return
101
+
102
+ # We only support elements of str type.
103
+ assert isinstance(instance, str)
104
+
105
+ # count the frequency of instances.
106
+ self.counter[instance] += 1
107
+
108
+ if instance not in self.index2instance:
109
+ self.instance2index[instance] = len(self.index2instance)
110
+ self.index2instance.append(instance)
111
+
112
+ def __call__(self, instance,
113
+ return_tensors="pt",
114
+ is_split_into_words=True,
115
+ padding=True,
116
+ add_special_tokens=False,
117
+ truncation=True,
118
+ max_length=512,
119
+ **config):
120
+ if isinstance(instance, (list, tuple)) and isinstance(instance[0], (str)) and is_split_into_words:
121
+ res = self.get_index(instance)
122
+ state = INSTANCE_STATE
123
+ elif isinstance(instance, str) and not is_split_into_words:
124
+ res = self.get_index(instance.split(" "))
125
+ state = INSTANCE_STATE
126
+ elif not is_split_into_words and isinstance(instance, (list, tuple)):
127
+ res = [self.get_index(ins.split(" ")) for ins in instance]
128
+ state = BATCH_STATE
129
+ else:
130
+ res = [self.get_index(ins) for ins in instance]
131
+ state = BATCH_STATE
132
+ res = [r[:max_length] for r in res] if state == BATCH_STATE else res[:max_length]  # truncate; a single instance is a flat id list
133
+ pad_id = self.get_index(self.__sign_pad)
134
+ if padding and state == BATCH_STATE:
135
+ max_len = max([len(x) for x in res])  # pad to the longest (possibly truncated) sequence
136
+
137
+ for i in range(len(res)):
138
+ res[i] = res[i] + [pad_id] * (max_len - len(res[i]))
139
+ if return_tensors == "pt":
140
+ input_ids = torch.Tensor(res).long()
141
+ attention_mask = (input_ids != pad_id).long()
142
+ elif state == BATCH_STATE:
143
+ input_ids = res
144
+ attention_mask = [[1 if r != pad_id else 0 for r in batch] for batch in res]  # keep one mask per sequence
145
+ else:
146
+ input_ids = res
147
+ attention_mask = [1 if r != pad_id else 0 for r in res]
148
+ return TokenizedData(input_ids, token_type_ids=attention_mask, attention_mask=attention_mask)
149
+
150
+ def get_index(self, instance):
151
+ """ Serialize given instance and return.
152
+
153
+ For unknown words, the index of the "[UNK]" token
154
+ is returned: any instance not present in the
155
+ alphabet falls back to "[UNK]".
159
+
160
+ Args:
161
+ instance (Any): is given instance or a list of it.
162
+ Return:
163
+ Any: the serialization of query instance.
164
+ """
165
+
166
+ if isinstance(instance, (list, tuple)):
167
+ return [self.get_index(elem) for elem in instance]
168
+
169
+ assert isinstance(instance, str)
170
+
171
+ try:
172
+ return self.instance2index[instance]
173
+ except KeyError:
174
+ return self.instance2index[self.__sign_unk]
175
+
176
+ def decode(self, index):
177
+ """ Get corresponding instance of query index.
178
+
179
+ If the index is invalid, an exception is raised.
180
+
181
+ Args:
182
+ index (int): is query index, possibly iterable.
183
+ Returns:
184
+ is corresponding instance.
185
+ """
186
+
187
+ if isinstance(index, list):
188
+ return [self.decode(elem) for elem in index]
189
+ if isinstance(index, torch.Tensor):
190
+ index = index.tolist()
191
+ return self.decode(index)
192
+ return self.index2instance[index]
193
+
194
+ def save(self, path):
195
+ """ Save the content of alphabet to files.
196
+
197
+ The alphabet is saved as a single JSON file
198
+ containing the tokenizer name and the
199
+ token-to-index map.
203
+
204
+ Args:
205
+ path (str): is the path to save object.
206
+ """
207
+
208
+ with open(path, 'w', encoding="utf8") as fw:
209
+ fw.write(json.dumps({"name": self.__name, "token_map": self.instance2index}))
210
+
211
+ @staticmethod
212
+ def from_file(path):
213
+ with open(path, 'r', encoding="utf8") as fw:
214
+ obj = json.load(fw)
215
+ tokenizer = WordTokenizer(obj["name"])
216
+ tokenizer.instance2index = OrderedDict(obj["token_map"])
217
+ tokenizer.counter = Counter(tokenizer.instance2index.keys())  # original frequencies are not persisted
218
+ tokenizer.index2instance = OrderedSet(tokenizer.instance2index.keys())
219
+ return tokenizer
220
+
221
+ def __len__(self):
222
+ return len(self.index2instance)
223
+
224
+ def __str__(self):
225
+ return 'Alphabet {} contains about {} words: \n\t{}'.format(self.name_or_path, len(self), self.index2instance)
226
+
227
+ def convert_tokens_to_ids(self, tokens):
228
+ """convert token sequence to intput ids sequence
229
+
230
+ Args:
231
+ tokens (Any): token sequence
232
+
233
+ Returns:
234
+ Any: input ids sequence
235
+ """
236
+ try:
237
+ if isinstance(tokens, (list, tuple)):
238
+ return [self.instance2index[x] for x in tokens]
239
+ return self.instance2index[tokens]
240
+
241
+ except KeyError:
242
+ return self.instance2index[self.__sign_unk]
243
+
244
+
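A minimal usage sketch of WordTokenizer as defined above; the vocabulary contents are illustrative:

tokenizer = WordTokenizer("word_tokenizer")
tokenizer.add_instance(["show", "me", "flights"])   # build the vocabulary
batch = tokenizer([["show", "flights"], ["me"]])    # batch of pre-split token lists
print(batch.input_ids.shape)                        # torch.Size([2, 2]), right-padded with [PAD]
print(tokenizer.decode(batch.input_ids[0]))         # ['show', 'flights']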
245
+ class TokenizedData():
246
+ """tokenized output data with input_ids, token_type_ids, attention_mask
247
+ """
248
+ def __init__(self, input_ids, token_type_ids, attention_mask):
249
+ self.input_ids = input_ids
250
+ self.token_type_ids = token_type_ids
251
+ self.attention_mask = attention_mask
252
+
253
+ def word_ids(self, index: int) -> List[int or None]:
254
+ """ get word id list
255
+
256
+ Args:
257
+ index (int): word index in sequence
258
+
259
+ Returns:
260
+ List[int or None]: word id list
261
+ """
262
+ return [j if self.attention_mask[index][j] != 0 else None for j, x in enumerate(self.input_ids[index])]
263
+
264
+ def word_to_tokens(self, index, word_id, **kwargs):
265
+ """map word and tokens
266
+
267
+ Args:
268
+ index (int): unused
269
+ word_id (int): word index in sequence
270
+ """
271
+ return (word_id, word_id + 1)
272
+
273
+ def to(self, device):
274
+ """set device
275
+
276
+ Args:
277
+ device (str): support ["cpu", "cuda"]
278
+ """
279
+ self.input_ids = self.input_ids.to(device)
280
+ self.token_type_ids = self.token_type_ids.to(device)
281
+ self.attention_mask = self.attention_mask.to(device)
282
+ return self
283
+
284
+
285
+ def load_embedding(tokenizer: WordTokenizer, glove_name:str):
286
+ """ load embedding from standford server or local cache.
287
+
288
+ Args:
289
+ tokenizer (WordTokenizer): non-pretrained tokenizer
290
+ glove_name (str): GloVe file name, e.g. "glove.6B.300d.txt".
291
+
292
+ Returns:
293
+ Any: word embedding
294
+ """
295
+ save_path = "save/" + glove_name + ".zip"
296
+ if not os.path.exists(save_path):
297
+ download("http://downloads.cs.stanford.edu/nlp/data/glove.6B.zip#" + glove_name, save_path)
298
+ unzip_file(save_path, "save/" + glove_name)
299
+ dim = int(glove_name.split(".")[-2][:-1])
300
+ embedding_list = torch.rand((tokenizer.vocab_size, dim))
301
+ embedding_list[tokenizer.pad_token_id] = torch.zeros((1, dim))
302
+ with open("save/" + glove_name + "/" + glove_name, "r", encoding="utf8") as f:
303
+ for line in f.readlines():
304
+ datas = line.split(" ")
305
+ word = datas[0]
306
+ embedding = torch.Tensor([float(datas[i + 1]) for i in range(len(datas) - 1)])
307
+ tokenized = tokenizer.convert_tokens_to_ids(word)
308
+ if isinstance(tokenized, int) and tokenized != tokenizer.unk_token_id:
309
+ embedding_list[tokenized] = embedding
310
+
311
+ return embedding_list
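A hedged usage sketch, assuming the standard GloVe 6B archive layout in which "glove.6B.300d.txt" is a member of glove.6B.zip:

# embedding = load_embedding(word_tokenizer, "glove.6B.300d.txt")
# embedding.shape -> (vocab_size, 300); rows for words absent from GloVe stay
# randomly initialized, and the [PAD] row is zeroed out.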
common/utils.py ADDED
@@ -0,0 +1,489 @@
1
+ import functools
2
+ import importlib
3
+ import json
4
+ import os
5
+ import tarfile
6
+ from typing import List, Tuple
7
+ import zipfile
8
+ from collections.abc import Callable
9
+ from ruamel import yaml
10
+ import requests
11
+ import torch
12
+ from torch.nn.utils.rnn import pad_sequence
13
+ from tqdm import tqdm
14
+ from torch import Tensor
15
+
16
+ class InputData():
17
+ """input datas class
18
+ """
19
+ def __init__(self, inputs: List = None):
20
+ """init input data class
21
+
22
+ if inputs is None:
23
+ this class can be used to save all InputData in the history by 'merge_input_data(X:InputData)'
24
+ else:
25
+ this class can be used for model input.
26
+
27
+ Args:
28
+ inputs (List, optional): inputs with [tokenized_data, slot, intent]. Defaults to None.
29
+ """
30
+ if inputs is None:
31
+ self.slot = []
32
+ self.intent = []
33
+ self.input_ids = None
34
+ self.token_type_ids = None
35
+ self.attention_mask = None
36
+ self.seq_lens = None
37
+ else:
38
+ self.input_ids = inputs[0].input_ids
39
+ self.token_type_ids = None
40
+ if hasattr(inputs[0], "token_type_ids"):
41
+ self.token_type_ids = inputs[0].token_type_ids
42
+ self.attention_mask = inputs[0].attention_mask
43
+ if len(inputs)>=2:
44
+ self.slot = inputs[1]
45
+ if len(inputs)>=3:
46
+ self.intent = inputs[2]
47
+ self.seq_lens = self.attention_mask.sum(-1)
48
+
49
+ def get_inputs(self):
50
+ """ get tokenized_data
51
+
52
+ Returns:
53
+ dict: tokenized data
54
+ """
55
+ res = {
56
+ "input_ids": self.input_ids,
57
+ "attention_mask": self.attention_mask
58
+ }
59
+ if self.token_type_ids is not None:
60
+ res["token_type_ids"] = self.token_type_ids
61
+ return res
62
+
63
+ def merge_input_data(self, inp: "InputData"):
64
+ """merge another InputData object with slot and intent
65
+
66
+ Args:
67
+ inp (InputData): another InputData object
68
+ """
69
+ self.slot += inp.slot
70
+ self.intent += inp.intent
71
+
72
+ def get_slot_mask(self, ignore_index:int)->Tensor:
73
+ """get slot mask
74
+
75
+ Args:
76
+ ignore_index (int): ignore index used in slot padding
77
+
78
+ Returns:
79
+ Tensor: mask tensor
80
+ """
81
+ mask = self.slot != ignore_index
82
+ mask[:, 0] = torch.ones_like(mask[:, 0]).to(self.slot.device)
83
+ return mask
84
+
85
+ def get_item(self, index, tokenizer=None, intent_map=None, slot_map=None, ignore_index = -100):
86
+ res = {"input_ids": self.input_ids[index]}
87
+ if tokenizer is not None:
88
+ res["tokens"] = [tokenizer.decode(x) for x in self.input_ids[index]]
89
+ if intent_map is not None:
90
+ intents = self.intent.tolist()
91
+ if isinstance(intents[index], list):
92
+ res["intent"] = [intent_map[int(x)] for x in intents[index]]
93
+ else:
94
+ res["intent"] = intent_map[intents[index]]
95
+ if slot_map is not None:
96
+ res["slot"] = [slot_map[x] if x != ignore_index else "#" for x in self.slot.tolist()[index]]
97
+ return res
98
+
99
+ class OutputData():
100
+ """output data class
101
+ """
102
+ def __init__(self, intent_ids=None, slot_ids=None):
103
+ """init output data class
104
+
105
+ if intent_ids is None and slot_ids is None:
106
+ this class can be used to save all OutputData in the history by 'merge_output_data(X:OutputData)'
107
+ else:
108
+ this class can be used to model output management.
109
+
110
+ Args:
111
+ intent_ids (Any, optional): list(Tensor) of intent ids / logits / strings. Defaults to None.
112
+ slot_ids (Any, optional): list(Tensor) of slot ids / logits / strings. Defaults to None.
113
+ """
114
+ if intent_ids is None and slot_ids is None:
115
+ self.intent_ids = []
116
+ self.slot_ids = []
117
+ else:
118
+ if isinstance(intent_ids, ClassifierOutputData):
119
+ self.intent_ids = intent_ids.classifier_output
120
+ else:
121
+ self.intent_ids = intent_ids
122
+ if isinstance(slot_ids, ClassifierOutputData):
123
+ self.slot_ids = slot_ids.classifier_output
124
+ else:
125
+ self.slot_ids = slot_ids
126
+
127
+ def map_output(self, slot_map=None, intent_map=None):
128
+ """ map intent or slot ids to intent or slot string.
129
+
130
+ Args:
131
+ slot_map (dict, optional): slot id-to-string map. Defaults to None.
132
+ intent_map (dict, optional): intent id-to-string map. Defaults to None.
133
+ """
134
+ if self.slot_ids is not None:
135
+ if slot_map:
136
+ self.slot_ids = [[slot_map[x] if x >= 0 else "#" for x in sid] for sid in self.slot_ids]
137
+ if self.intent_ids is not None:
138
+ if intent_map:
139
+ self.intent_ids = [[intent_map[x] for x in sid] if isinstance(sid, list) else intent_map[sid] for sid in
140
+ self.intent_ids]
141
+
142
+ def merge_output_data(self, output:"OutputData"):
143
+ """merge another OutData object with slot and intent
144
+
145
+ Args:
146
+ output (OutputData): another OutputData object
147
+ """
148
+ if output.slot_ids is not None:
149
+ self.slot_ids += output.slot_ids
150
+ if output.intent_ids is not None:
151
+ self.intent_ids += output.intent_ids
152
+
153
+ def save(self, path:str, original_dataset=None):
154
+ """ save all OutputData in the history
155
+
156
+ Args:
157
+ path (str): save dir path
158
+ original_dataset(Iterable): original dataset
159
+ """
160
+ # with open(f"{path}/intent.jsonl", "w") as f:
161
+ # for x in self.intent_ids:
162
+ # f.write(json.dumps(x) + "\n")
163
+ with open(f"{path}/outputs.jsonl", "w") as f:
164
+ if original_dataset is not None:
165
+ for i, s, d in zip(self.intent_ids, self.slot_ids, original_dataset):
166
+ f.write(json.dumps({"pred_intent": i, "pred_slot": s, "text": d["text"], "golden_intent":d["intent"], "golden_slot":d["slot"]}) + "\n")
167
+ else:
168
+ for i, s in zip(self.intent_ids, self.slot_ids):
169
+ f.write(json.dumps({"pred_intent": i, "pred_slot": s}) + "\n")
170
+
171
+
172
+ class HiddenData():
173
+ """Interactive data structure for all model components
174
+ """
175
+ def __init__(self, intent_hidden, slot_hidden):
176
+ """init hidden data structure
177
+
178
+ Args:
179
+ intent_hidden (Any): sentence-level or intent hidden state
180
+ slot_hidden (Any): token-level or slot hidden state
181
+ """
182
+ self.intent_hidden = intent_hidden
183
+ self.slot_hidden = slot_hidden
184
+ self.inputs = None
185
+ self.embedding = None
186
+
187
+ def get_intent_hidden_state(self):
188
+ """get intent hidden state
189
+
190
+ Returns:
191
+ Any: intent hidden state
192
+ """
193
+ return self.intent_hidden
194
+
195
+ def get_slot_hidden_state(self):
196
+ """get slot hidden state
197
+
198
+ Returns:
199
+ Any: slot hidden state
200
+ """
201
+ return self.slot_hidden
202
+
203
+ def update_slot_hidden_state(self, hidden_state):
204
+ """update slot hidden state
205
+
206
+ Args:
207
+ hidden_state (Any): slot hidden state to update
208
+ """
209
+ self.slot_hidden = hidden_state
210
+
211
+ def update_intent_hidden_state(self, hidden_state):
212
+ """update intent hidden state
213
+
214
+ Args:
215
+ hidden_state (Any): intent hidden state to update
216
+ """
217
+ self.intent_hidden = hidden_state
218
+
219
+ def add_input(self, inputs: InputData or "HiddenData"):
220
+ """add last model component input information to next model component
221
+
222
+ Args:
223
+ inputs (InputData or HiddenData): last model component input
224
+ """
225
+ self.inputs = inputs
226
+
227
+ def add_embedding(self, embedding):
228
+ self.embedding = embedding
229
+
230
+
231
+ class ClassifierOutputData():
232
+ """Classifier output data structure of all classifier components
233
+ """
234
+ def __init__(self, classifier_output):
235
+ self.classifier_output = classifier_output
236
+ self.output_embedding = None
237
+
238
+ def remove_slot_ignore_index(inputs:InputData, outputs:OutputData, ignore_index=-100):
239
+ """ remove padding or extra token in input id and output id
240
+
241
+ Args:
242
+ inputs (InputData): input data with input id
243
+ outputs (OutputData): output data with decoded output id
244
+ ignore_index (int, optional): ignore index used in the slot labels. Defaults to -100.
245
+
246
+ Returns:
247
+ InputData: input data with ignored positions removed
248
+ OutputData: output data with ignored positions removed
249
+ """
250
+ for index, (inp_ss, out_ss) in enumerate(zip(inputs.slot, outputs.slot_ids)):
251
+ temp_inp = []
252
+ temp_out = []
253
+ for inp_s, out_s in zip(list(inp_ss), list(out_ss)):
254
+ if inp_s != ignore_index:
255
+ temp_inp.append(inp_s)
256
+ temp_out.append(out_s)
257
+
258
+ inputs.slot[index] = temp_inp
259
+ outputs.slot_ids[index] = temp_out
260
+ return inputs, outputs
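A worked sketch of the alignment performed above:

# Given one sample with an ignored position in the golden slots:
#   inputs.slot      = [[12, -100, 7]]
#   outputs.slot_ids = [[12,    3, 7]]
# remove_slot_ignore_index(inputs, outputs) keeps only aligned positions:
#   inputs.slot -> [[12, 7]], outputs.slot_ids -> [[12, 7]]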
261
+
262
+
263
+ def pack_sequence(inputs:Tensor, seq_len:Tensor or List) -> Tensor:
264
+ """pack sequence data to packed data without padding.
265
+
266
+ Args:
267
+ inputs (Tensor): list(Tensor) of packed sequence inputs
268
+ seq_len (Tensor or List): list(Tensor) of sequence length
269
+
270
+ Returns:
271
+ Tensor: packed inputs
272
+
273
+ Examples:
274
+ inputs = [[x, y, z, PAD, PAD], [x, y, PAD, PAD, PAD]]
275
+
276
+ seq_len = [3,2]
277
+
278
+ return -> [x, y, z, x, y]
279
+ """
280
+ output = []
281
+ for index, batch in enumerate(inputs):
282
+ output.append(batch[:seq_len[index]])
283
+ return torch.cat(output, dim=0)
284
+
285
+
286
+ def unpack_sequence(inputs:Tensor, seq_lens:Tensor or List, padding_value=0) -> Tensor:
287
+ """unpack sequence data.
288
+
289
+ Args:
290
+ inputs (Tensor): list(Tensor) of packed sequence inputs
291
+ seq_lens (Tensor or List): list(Tensor) of sequence length
292
+ padding_value (int, optional): padding value. Defaults to 0.
293
+
294
+ Returns:
295
+ Tensor: unpacked inputs
296
+
297
+ Examples:
298
+ inputs = [x, y, z, x, y]
299
+
300
+ seq_len = [3,2]
301
+
302
+ return -> [[x, y, z, PAD, PAD], [x, y, PAD, PAD, PAD]]
303
+ """
304
+ last_idx = 0
305
+ output = []
306
+ for _, seq_len in enumerate(seq_lens):
307
+ output.append(inputs[last_idx:last_idx + seq_len])
308
+ last_idx = last_idx + seq_len
309
+ return pad_sequence(output, batch_first=True, padding_value=padding_value)
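A runnable round-trip sketch of pack_sequence and unpack_sequence:

import torch

x = torch.tensor([[1, 2, 3, 0, 0],
                  [4, 5, 0, 0, 0]])
packed = pack_sequence(x, [3, 2])            # tensor([1, 2, 3, 4, 5])
restored = unpack_sequence(packed, [3, 2])   # tensor([[1, 2, 3], [4, 5, 0]])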
310
+
311
+
312
+ def get_dict_with_key_prefix(input_dict: dict, prefix=""):
313
+ res = {}
314
+ for t in input_dict:
315
+ res[prefix + t] = input_dict[t]  # prepend the prefix, as the function name implies
316
+ return res
317
+
318
+
319
+ def download(url: str, fname: str):
320
+ """download file from url to fname
321
+
322
+ Args:
323
+ url (str): remote server url path
324
+ fname (str): local path to save
325
+ """
326
+ resp = requests.get(url, stream=True)
327
+ total = int(resp.headers.get('content-length', 0))
328
+ with open(fname, 'wb') as file, tqdm(
329
+ desc=fname,
330
+ total=total,
331
+ unit='iB',
332
+ unit_scale=True,
333
+ unit_divisor=1024,
334
+ ) as bar:
335
+ for data in resp.iter_content(chunk_size=1024):
336
+ size = file.write(data)
337
+ bar.update(size)
338
+
339
+
340
+ def tar_gz_data(file_name:str):
341
+ """use "tar.gz" format to compress data
342
+
343
+ Args:
344
+ file_name (str): file path to tar
345
+ """
346
+ t = tarfile.open(f"{file_name}.tar.gz", "w:gz")
347
+
348
+ for root, dirs, files in os.walk(f"{file_name}"):
349
+ for file in files:
351
+ fullpath = os.path.join(root, file)
352
+ t.add(fullpath)
353
+ t.close()
354
+
355
+
356
+ def untar(fname:str, dirs:str):
357
+ """ uncompress "tar.gz" file
358
+
359
+ Args:
360
+ fname (str): file path to untar
361
+ dirs (str): target dir path
362
+ """
363
+ t = tarfile.open(fname)
364
+ t.extractall(path=dirs)
365
+
366
+
367
+ def unzip_file(zip_src:str, dst_dir:str):
368
+ """ uncompress "zip" file
369
+
370
+ Args:
371
+ fname (str): file path to unzip
372
+ dirs (str): target dir path
373
+ """
374
+ r = zipfile.is_zipfile(zip_src)
375
+ if r:
376
+ if not os.path.exists(dst_dir):
377
+ os.mkdir(dst_dir)
378
+ fz = zipfile.ZipFile(zip_src, 'r')
379
+ for file in fz.namelist():
380
+ fz.extract(file, dst_dir)
381
+ else:
382
+ print('This is not a zip file')
383
+
384
+
385
+ def find_callable(target: str) -> Callable:
386
+ """ find callable function / class to instantiate
387
+
388
+ Args:
389
+ target (str): class/module path
390
+
391
+ Raises:
392
+ ImportError: if the module cannot be imported
393
+
394
+ Returns:
395
+ Callable: return function / class
396
+ """
397
+ target_module_path, target_callable_path = target.rsplit(".", 1)
398
+ target_callable_paths = [target_callable_path]
399
+
400
+ # importlib either succeeds or raises here, so a retry loop adds nothing.
401
+ target_module = importlib.import_module(target_module_path)
407
+ target_callable = target_module
408
+ for attr in reversed(target_callable_paths):
409
+ target_callable = getattr(target_callable, attr)
410
+
411
+ return target_callable
412
+
413
+
414
+ def instantiate(config, target="_model_target_", partial="_model_partial_"):
415
+ """ instantiate object by config.
416
+
417
+ Modified from https://github.com/HIT-SCIR/ltp/blob/main/python/core/ltp_core/models/utils/instantiate.py.
418
+
419
+ Args:
420
+ config (Any): configuration
421
+ target (str, optional): key to assign the class to be instantiated. Defaults to "_model_target_".
422
+ partial (str, optional): key to judge object whether should be instantiated partially. Defaults to "_model_partial_".
423
+
424
+ Returns:
425
+ Any: instantiated object
426
+ """
427
+ if isinstance(config, dict) and target in config:
428
+ target_path = config.get(target)
429
+ target_callable = find_callable(target_path)
430
+
431
+ is_partial = config.get(partial, False)
432
+ target_args = {
433
+ key: instantiate(value)
434
+ for key, value in config.items()
435
+ if key not in [target, partial]
436
+ }
437
+
438
+ if is_partial:
439
+ return functools.partial(target_callable, **target_args)
440
+ else:
441
+ return target_callable(**target_args)
442
+ elif isinstance(config, dict):
443
+ return {key: instantiate(value) for key, value in config.items()}
444
+ else:
445
+ return config
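A minimal sketch of config-driven instantiation; the config dicts are illustrative:

layer = instantiate({
    "_model_target_": "torch.nn.Linear",
    "in_features": 4,
    "out_features": 2,
})  # equivalent to torch.nn.Linear(4, 2)

dropout_factory = instantiate({
    "_model_target_": "torch.nn.Dropout",
    "_model_partial_": True,
})  # functools.partial(torch.nn.Dropout); call later, e.g. dropout_factory(p=0.1)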
446
+
447
+
448
+ def load_yaml(file):
449
+ """ load data from yaml files.
450
+
451
+ Args:
452
+ file (str): yaml file path.
453
+
454
+ Returns:
455
+ Any: data
456
+ """
457
+ with open(file, encoding="utf-8") as stream:
458
+ try:
459
+ return yaml.safe_load(stream)
460
+ except yaml.YAMLError as exc:
461
+ raise exc
462
+
463
+ def from_configured(configure_name_or_file:str, model_class:Callable, config_prefix="./config/", **input_config):
464
+ """load module from pre-configured data
465
+
466
+ Args:
467
+ configure_name_or_file (str): config path -> {config_prefix}/{configure_name_or_file}.yaml
468
+ model_class (Callable): module class
469
+ config_prefix (str, optional): configuration root path. Defaults to "./config/".
470
+
471
+ Returns:
472
+ Any: instantiated object.
473
+ """
474
+ if os.path.exists(configure_name_or_file):
475
+ configure_file=configure_name_or_file
476
+ else:
477
+ configure_file= os.path.join(config_prefix, configure_name_or_file+".yaml")
478
+ config = load_yaml(configure_file)
479
+ config.update(input_config)
480
+ return model_class(**config)
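A hedged usage sketch; the YAML name and model class here are hypothetical:

# Assuming config/my_encoder.yaml holds the constructor arguments of MyEncoder:
# encoder = from_configured("my_encoder", model_class=MyEncoder, dropout_rate=0.1)
# Keyword arguments passed directly override values loaded from the YAML file.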
481
+
482
+ def save_json(file_path, obj):
483
+ with open(file_path, 'w', encoding="utf8") as fw:
484
+ fw.write(json.dumps(obj))
485
+
486
+ def load_json(file_path):
487
+ with open(file_path, 'r', encoding="utf8") as fw:
488
+ res = json.load(fw)
489
+ return res
model/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from model.open_slu_model import OpenSLUModel
2
+
3
+ __all__ = ["OpenSLUModel"]
model/decoder/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from model.decoder.agif_decoder import AGIFDecoder
2
+ from model.decoder.base_decoder import StackPropagationDecoder, BaseDecoder, DCANetDecoder
3
+ from model.decoder.gl_gin_decoder import GLGINDecoder
4
+
5
+ __all__ = ["StackPropagationDecoder", "BaseDecoder", "DCANetDecoder", "AGIFDecoder", "GLGINDecoder"]
model/decoder/agif_decoder.py ADDED
@@ -0,0 +1,16 @@
1
+ from common.utils import HiddenData, OutputData
2
+ from model.decoder.base_decoder import BaseDecoder
3
+
4
+
5
+ class AGIFDecoder(BaseDecoder):
6
+ def forward(self, hidden: HiddenData, **kwargs):
7
+ # hidden = self.interaction(hidden)
8
+ pred_intent = self.intent_classifier(hidden)
9
+ intent_index = self.intent_classifier.decode(OutputData(pred_intent, None),
10
+ return_list=False,
11
+ return_sentence_level=True)
12
+ interact_args = {"intent_index": intent_index,
13
+ "batch_size": pred_intent.classifier_output.shape[0],
14
+ "intent_label_num": self.intent_classifier.config["intent_label_num"]}
15
+ pred_slot = self.slot_classifier(hidden, internal_interaction=self.interaction, **interact_args)
16
+ return OutputData(pred_intent, pred_slot)
model/decoder/base_decoder.py ADDED
@@ -0,0 +1,94 @@
1
+ '''
2
+ Author: Qiguang Chen
3
+ Date: 2023-01-11 10:39:26
4
+ LastEditors: Qiguang Chen
5
+ LastEditTime: 2023-01-31 18:22:36
6
+ Description:
7
+
8
+ '''
9
+ from torch import nn
10
+
11
+ from common.utils import HiddenData, OutputData, InputData
12
+
13
+
14
+ class BaseDecoder(nn.Module):
15
+ """Base class for all decoder module.
16
+
17
+ Notice: it is often only necessary to change this module and its sub-modules.
18
+ """
19
+ def __init__(self, intent_classifier, slot_classifier, interaction=None):
20
+ super().__init__()
21
+ self.intent_classifier = intent_classifier
22
+ self.slot_classifier = slot_classifier
23
+ self.interaction = interaction
24
+
25
+ def forward(self, hidden: HiddenData):
26
+ """forward
27
+
28
+ Args:
29
+ hidden (HiddenData): encoded data
30
+
31
+ Returns:
32
+ OutputData: prediction logits
33
+ """
34
+ if self.interaction is not None:
35
+ hidden = self.interaction(hidden)
36
+ return OutputData(self.intent_classifier(hidden), self.slot_classifier(hidden))
37
+
38
+ def decode(self, output: OutputData, target: InputData = None):
39
+ """decode output logits
40
+
41
+ Args:
42
+ output (OutputData): output logits data
43
+ target (InputData, optional): input data with attention mask. Defaults to None.
44
+
45
+ Returns:
46
+ List: decoded sequence ids
47
+ """
48
+ return OutputData(self.intent_classifier.decode(output, target), self.slot_classifier.decode(output, target))
49
+
50
+ def compute_loss(self, pred: OutputData, target: InputData, compute_intent_loss=True, compute_slot_loss=True):
51
+ """compute loss.
52
+ Notice: can set intent and slot loss weight by adding 'weight' config item in corresponding classifier configuration.
53
+
54
+ Args:
55
+ pred (OutputData): output logits data
56
+ target (InputData): input golden data
57
+ compute_intent_loss (bool, optional): whether to compute intent loss. Defaults to True.
58
+ compute_slot_loss (bool, optional): whether to compute slot loss. Defaults to True.
59
+
60
+ Returns:
61
+ Tensor: loss result
62
+ """
63
+ intent_loss = self.intent_classifier.compute_loss(pred, target) if compute_intent_loss else None
64
+ slot_loss = self.slot_classifier.compute_loss(pred, target) if compute_slot_loss else None
65
+ slot_weight = self.slot_classifier.config.get("weight")
66
+ slot_weight = slot_weight if slot_weight is not None else 1.
67
+ intent_weight = self.intent_classifier.config.get("weight")
68
+ intent_weight = intent_weight if intent_weight is not None else 1.
69
+ loss = 0
70
+ if intent_loss is not None:
71
+ loss += intent_loss * intent_weight
72
+ if slot_loss is not None:
73
+ loss += slot_loss * slot_weight
74
+ return loss, intent_loss, slot_loss
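A hedged arithmetic sketch of the weighting; the config values are hypothetical:

# With "weight": 0.2 in the intent classifier config and "weight": 0.8 in the
# slot classifier config, compute_loss returns
#     loss = 0.2 * intent_loss + 0.8 * slot_loss
# Classifiers without a "weight" entry default to a weight of 1.0.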
75
+
76
+
77
+ class StackPropagationDecoder(BaseDecoder):
78
+
79
+ def forward(self, hidden: HiddenData):
80
+ # hidden = self.interaction(hidden)
81
+ pred_intent = self.intent_classifier(hidden)
82
+ # embedding = pred_intent.output_embedding if pred_intent.output_embedding is not None else pred_intent.classifier_output
83
+ # hidden.update_intent_hidden_state(torch.cat([hidden.get_slot_hidden_state(), embedding], dim=-1))
84
+ hidden = self.interaction(pred_intent, hidden)
85
+ pred_slot = self.slot_classifier(hidden)
86
+ return OutputData(pred_intent, pred_slot)
87
+
88
+ class DCANetDecoder(BaseDecoder):
89
+
90
+ def forward(self, hidden: HiddenData):
91
+ if self.interaction is not None:
92
+ hidden = self.interaction(hidden, intent_emb=self.intent_classifier, slot_emb=self.slot_classifier)
93
+ return OutputData(self.intent_classifier(hidden), self.slot_classifier(hidden))
94
+
model/decoder/classifier.py ADDED
@@ -0,0 +1,321 @@
1
+ '''
2
+ Author: Qiguang Chen
3
+ Date: 2023-01-11 10:39:26
4
+ LastEditors: Qiguang Chen
5
+ LastEditTime: 2023-01-31 20:07:00
6
+ Description:
7
+
8
+ '''
9
+ import random
10
+
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from torch import nn
14
+ from torch.nn import CrossEntropyLoss
15
+
16
+ from model.decoder import decoder_utils
17
+
18
+ from torchcrf import CRF
19
+
20
+ from common.utils import HiddenData, OutputData, InputData, ClassifierOutputData, unpack_sequence, pack_sequence, \
21
+ instantiate
22
+
23
+
24
+ class BaseClassifier(nn.Module):
25
+ """Base class for all classifier module
26
+ """
27
+ def __init__(self, **config):
28
+ super().__init__()
29
+ self.config = config
30
+ if config.get("loss_fn"):
31
+ self.loss_fn = instantiate(config.get("loss_fn"))
32
+ else:
33
+ self.loss_fn = CrossEntropyLoss(ignore_index=self.config.get("ignore_index"))
34
+
35
+ def forward(self, *args, **kwargs):
36
+ raise NotImplementedError("No implemented classifier.")
37
+
38
+ def decode(self, output: OutputData,
39
+ target: InputData = None,
40
+ return_list=True,
41
+ return_sentence_level=None):
42
+ """decode output logits
43
+
44
+ Args:
45
+ output (OutputData): output logits data
46
+ target (InputData, optional): input data with attention mask. Defaults to None.
47
+ return_list (bool, optional): if True return list else return torch Tensor. Defaults to True.
48
+ return_sentence_level (bool, optional): if True decode sentence-level intent else decode token-level intent. Defaults to None.
49
+
50
+ Returns:
51
+ List or Tensor: decoded sequence ids
52
+ """
53
+ if self.config.get("return_sentence_level") is not None and return_sentence_level is None:
54
+ return_sentence_level = self.config.get("return_sentence_level")
55
+ elif self.config.get("return_sentence_level") is None and return_sentence_level is None:
56
+ return_sentence_level = False
57
+ return decoder_utils.decode(output, target,
58
+ return_list=return_list,
59
+ return_sentence_level=return_sentence_level,
60
+ pred_type=self.config.get("mode"),
61
+ use_multi=self.config.get("use_multi"),
62
+ multi_threshold=self.config.get("multi_threshold"))
63
+
64
+ def compute_loss(self, pred: OutputData, target: InputData):
65
+ """compute loss
66
+
67
+ Args:
68
+ pred (OutputData): output logits data
69
+ target (InputData): input golden data
70
+
71
+ Returns:
72
+ Tensor: loss result
73
+ """
74
+ _CRF = None
75
+ if self.config.get("use_crf"):
76
+ _CRF = self.CRF
77
+ return decoder_utils.compute_loss(pred, target, criterion_type=self.config["mode"],
78
+ use_crf=_CRF is not None,
79
+ ignore_index=self.config["ignore_index"],
80
+ use_multi=self.config.get("use_multi"),
81
+ loss_fn=self.loss_fn,
82
+ CRF=_CRF)
83
+
84
+
85
+ class LinearClassifier(BaseClassifier):
86
+ """
87
+ Decoder structure based on Linear.
88
+ """
89
+ def __init__(self, **config):
90
+ """Construction function for LinearClassifier
91
+
92
+ Args:
93
+ config (dict):
94
+ input_dim (int): hidden state dim.
95
+ use_slot (bool): whether to classify slot label.
96
+ slot_label_num (int, optional): the number of slot label. Enabled if use_slot is True.
97
+ use_intent (bool): whether to classify intent label.
98
+ intent_label_num (int, optional): the number of intent label. Enabled if use_intent is True.
99
+ use_crf (bool): whether to use crf for slot.
100
+ """
101
+ super().__init__(**config)
102
+ self.config = config
103
+ if config.get("use_slot"):
104
+ self.slot_classifier = nn.Linear(config["input_dim"], config["slot_label_num"])
105
+ if self.config.get("use_crf"):
106
+ self.CRF = CRF(num_tags=config["slot_label_num"], batch_first=True)
107
+ if config.get("use_intent"):
108
+ self.intent_classifier = nn.Linear(config["input_dim"], config["intent_label_num"])
109
+
110
+ def forward(self, hidden: HiddenData):
111
+ if self.config.get("use_intent"):
112
+ return ClassifierOutputData(self.intent_classifier(hidden.get_intent_hidden_state()))
113
+ if self.config.get("use_slot"):
114
+ return ClassifierOutputData(self.slot_classifier(hidden.get_slot_hidden_state()))
115
+
116
+
117
+
118
+ class AutoregressiveLSTMClassifier(BaseClassifier):
119
+ """
120
+ Decoder structure based on unidirectional LSTM.
121
+ """
122
+
123
+ def __init__(self, **config):
124
+ """ Construction function for Decoder.
125
+
126
+ Args:
127
+ config (dict):
128
+ input_dim (int): input dimension of Decoder. In fact, it's encoder hidden size.
129
+ use_slot (bool): whether to classify slot label.
130
+ slot_label_num (int, optional): the number of slot label. Enabled if use_slot is True.
131
+ use_intent (bool): whether to classify intent label.
132
+ intent_label_num (int, optional): the number of intent label. Enabled if use_intent is True.
133
+ use_crf (bool): whether to use crf for slot.
134
+ hidden_dim (int): hidden dimension of iterative LSTM.
135
+ embedding_dim (int): if not None, the previous prediction's embedding is fed back into the next input (autoregressive decoding).
136
+ dropout_rate (float): dropout rate of network which is only useful for embedding.
137
+ """
138
+
139
+ super(AutoregressiveLSTMClassifier, self).__init__(**config)
140
+ if config.get("use_slot") and config.get("use_crf"):
141
+ self.CRF = CRF(num_tags=config["slot_label_num"], batch_first=True)
142
+ self.input_dim = config["input_dim"]
143
+ self.hidden_dim = config["hidden_dim"]
144
+ if config.get("use_intent"):
145
+ self.output_dim = config["intent_label_num"]
146
+ if config.get("use_slot"):
147
+ self.output_dim = config["slot_label_num"]
148
+ self.dropout_rate = config["dropout_rate"]
149
+ self.embedding_dim = config.get("embedding_dim")
150
+ self.force_ratio = config.get("force_ratio")
151
+ self.config = config
152
+ self.ignore_index = config.get("ignore_index") if config.get("ignore_index") is not None else -100
153
+ # If embedding_dim is not None, the previous output is embedded
154
+ # and fed back into the next input step.
155
+ if self.embedding_dim is not None:
156
+ self.embedding_layer = nn.Embedding(self.output_dim, self.embedding_dim)
157
+ self.init_tensor = nn.Parameter(
158
+ torch.randn(1, self.embedding_dim),
159
+ requires_grad=True
160
+ )
161
+
162
+ # Make sure the input dimension of iterative LSTM.
163
+ if self.embedding_dim is not None:
164
+ lstm_input_dim = self.input_dim + self.embedding_dim
165
+ else:
166
+ lstm_input_dim = self.input_dim
167
+
168
+ # Network parameter definition.
169
+ self.dropout_layer = nn.Dropout(self.dropout_rate)
170
+ self.lstm_layer = nn.LSTM(
171
+ input_size=lstm_input_dim,
172
+ hidden_size=self.hidden_dim,
173
+ batch_first=True,
174
+ bidirectional=self.config["bidirectional"],
175
+ dropout=self.dropout_rate,
176
+ num_layers=self.config["layer_num"]
177
+ )
178
+ self.linear_layer = nn.Linear(
179
+ self.hidden_dim,
180
+ self.output_dim
181
+ )
182
+ # self.loss_fn = CrossEntropyLoss(ignore_index=self.ignore_index)
183
+
184
+ def forward(self, hidden: HiddenData, internal_interaction=None, **interaction_args):
185
+ """ Forward process for decoder.
186
+
187
+ :param internal_interaction:
188
+ :param hidden:
189
+ :return: is distribution of prediction labels.
190
+ """
191
+ input_tensor = hidden.slot_hidden
192
+ seq_lens = hidden.inputs.attention_mask.sum(-1).detach().cpu().tolist()
193
+ output_tensor_list, sent_start_pos = [], 0
194
+ input_tensor = pack_sequence(input_tensor, seq_lens)
195
+ forced_input = None
196
+ if self.training:
197
+ if random.random() < self.force_ratio:
198
+ if self.config["mode"]=="slot":
199
+
200
+ forced_slot = pack_sequence(hidden.inputs.slot, seq_lens)
201
+ temp_slot = []
202
+ for index, x in enumerate(forced_slot):
203
+ if index == 0:
204
+ temp_slot.append(x.reshape(1))
205
+ elif x == self.ignore_index:
206
+ temp_slot.append(temp_slot[-1])
207
+ else:
208
+ temp_slot.append(x.reshape(1))
209
+ forced_input = torch.cat(temp_slot, 0)
210
+ if self.config["mode"]=="token-level-intent":
211
+ forced_intent = hidden.inputs.intent.unsqueeze(1).repeat(1, hidden.inputs.slot.shape[1])
212
+ forced_input = pack_sequence(forced_intent, seq_lens)
213
+ if self.embedding_dim is None or forced_input is not None:
214
+
215
+ for sent_i in range(0, len(seq_lens)):
216
+ sent_end_pos = sent_start_pos + seq_lens[sent_i]
217
+
218
+ # Segment input hidden tensors.
219
+ seg_hiddens = input_tensor[sent_start_pos: sent_end_pos, :]
220
+
221
+ if self.embedding_dim is not None and forced_input is not None:
222
+ if seq_lens[sent_i] > 1:
223
+ seg_forced_input = forced_input[sent_start_pos: sent_end_pos]
224
+
225
+ seg_forced_tensor = self.embedding_layer(seg_forced_input)[:-1]
226
+ seg_prev_tensor = torch.cat([self.init_tensor, seg_forced_tensor], dim=0)
227
+ else:
228
+ seg_prev_tensor = self.init_tensor
229
+
230
+ # Concatenate forced target tensor.
231
+ combined_input = torch.cat([seg_hiddens, seg_prev_tensor], dim=1)
232
+ else:
233
+ combined_input = seg_hiddens
234
+ dropout_input = self.dropout_layer(combined_input)
235
+ lstm_out, _ = self.lstm_layer(dropout_input.view(1, seq_lens[sent_i], -1))
236
+ if internal_interaction is not None:
237
+ interaction_args["sent_id"] = sent_i
238
+ lstm_out = internal_interaction(torch.transpose(lstm_out, 0, 1), **interaction_args)[:, 0]
239
+ linear_out = self.linear_layer(lstm_out.view(seq_lens[sent_i], -1))
240
+
241
+ output_tensor_list.append(linear_out)
242
+ sent_start_pos = sent_end_pos
243
+ else:
244
+ for sent_i in range(0, len(seq_lens)):
245
+ prev_tensor = self.init_tensor
246
+
247
+ # It's necessary to remember h and c state
248
+ # when output prediction every single step.
249
+ last_h, last_c = None, None
250
+
251
+ sent_end_pos = sent_start_pos + seq_lens[sent_i]
252
+ for word_i in range(sent_start_pos, sent_end_pos):
253
+ seg_input = input_tensor[[word_i], :]
254
+ combined_input = torch.cat([seg_input, prev_tensor], dim=1)
255
+ dropout_input = self.dropout_layer(combined_input).view(1, 1, -1)
256
+ if last_h is None and last_c is None:
257
+ lstm_out, (last_h, last_c) = self.lstm_layer(dropout_input)
258
+ else:
259
+ lstm_out, (last_h, last_c) = self.lstm_layer(dropout_input, (last_h, last_c))
260
+
261
+ if internal_interaction is not None:
262
+ interaction_args["sent_id"] = sent_i
263
+ lstm_out = internal_interaction(lstm_out, **interaction_args)[:, 0]
264
+
265
+ lstm_out = self.linear_layer(lstm_out.view(1, -1))
266
+ output_tensor_list.append(lstm_out)
267
+
268
+ _, index = lstm_out.topk(1, dim=1)
269
+ prev_tensor = self.embedding_layer(index).view(1, -1)
270
+ sent_start_pos = sent_end_pos
271
+ seq_unpacked = unpack_sequence(torch.cat(output_tensor_list, dim=0), seq_lens)
272
+ # TODO: support softmax in both branches
273
+ if self.config.get("use_multi"):
274
+ pred_output = ClassifierOutputData(seq_unpacked)
275
+ else:
276
+ pred_output = ClassifierOutputData(F.log_softmax(seq_unpacked, dim=-1))
277
+ return pred_output
278
+
279
+
280
+ class MLPClassifier(BaseClassifier):
281
+ """
282
+ Decoder structure based on MLP.
283
+ """
284
+ def __init__(self, **config):
285
+ """ Construction function for Decoder.
286
+
287
+ Args:
288
+ config (dict):
289
+ use_slot (bool): whether to classify slot label.
290
+ use_intent (bool): whether to classify intent label.
291
+ mlp (List):
292
+
293
+ - _model_target_: torch.nn.Linear
294
+
295
+ in_features (int): input feature dim
296
+
297
+ out_features (int): output feature dim
298
+
299
+ - _model_target_: torch.nn.LeakyReLU
300
+
301
+ negative_slope: 0.2
302
+
303
+ - ...
304
+ """
305
+ super(MLPClassifier, self).__init__(**config)
306
+ self.config = config
307
+ for i, x in enumerate(config["mlp"]):
308
+ if isinstance(x.get("in_features"), str):
309
+ config["mlp"][i]["in_features"] = self.config[x["in_features"][1:-1]]
310
+ if isinstance(x.get("out_features"), str):
311
+ config["mlp"][i]["out_features"] = self.config[x["out_features"][1:-1]]
312
+ mlp = [instantiate(x) for x in config["mlp"]]
313
+ self.seq = nn.Sequential(*mlp)
314
+
315
+
316
+ def forward(self, hidden: HiddenData):
317
+ if self.config.get("use_intent"):
318
+ res = self.seq(hidden.intent_hidden)
319
+ else:
320
+ res = self.seq(hidden.slot_hidden)
321
+ return ClassifierOutputData(res)
model/decoder/decoder_utils.py ADDED
@@ -0,0 +1,155 @@
1
+ from typing import List
2
+ import torch
3
+
4
+ from common import utils
5
+ from common.utils import OutputData, InputData
6
+ from torch import Tensor
7
+
8
+ def argmax_for_seq_len(inputs, seq_lens, padding_value=-100):
9
+ packed_inputs = utils.pack_sequence(inputs, seq_lens)
10
+ outputs = torch.argmax(packed_inputs, dim=-1, keepdim=True)
11
+ return utils.unpack_sequence(outputs, seq_lens, padding_value).squeeze(-1)
12
+
13
+
14
+ def decode(output: OutputData,
15
+ target: InputData = None,
16
+ pred_type="slot",
17
+ multi_threshold=0.5,
18
+ ignore_index=-100,
19
+ return_list=True,
20
+ return_sentence_level=True,
21
+ use_multi=False,
22
+ use_crf=False,
23
+ CRF=None) -> List or Tensor:
24
+ """ decode output logits
25
+
26
+ Args:
27
+ output (OutputData): output logits data
28
+ target (InputData, optional): input data with attention mask. Defaults to None.
29
+ pred_type (str, optional): prediction type in ["slot", "intent", "token-level-intent"]. Defaults to "slot".
30
+ multi_threshold (float, optional): multi intent decode threshold. Defaults to 0.5.
31
+ ignore_index (int, optional): align and pad token with ignore index. Defaults to -100.
32
+ return_list (bool, optional): if True return list else return torch Tensor. Defaults to True.
33
+ return_sentence_level (bool, optional): if True decode sentence level intent else decode token level intent. Defaults to True.
34
+ use_multi (bool, optional): whether to decode to multi intent. Defaults to False.
35
+ use_crf (bool, optional): whether to use crf. Defaults to False.
36
+ CRF (CRF, optional): CRF function. Defaults to None.
37
+
38
+ Returns:
39
+ List or Tensor: decoded sequence ids
40
+ """
41
+ if pred_type == "slot":
42
+ inputs = output.slot_ids
43
+ else:
44
+ inputs = output.intent_ids
45
+
46
+ if pred_type == "slot":
47
+ if not use_multi:
48
+ if use_crf:
49
+ res = CRF.decode(inputs, mask=target.attention_mask)
50
+ else:
51
+ res = torch.argmax(inputs, dim=-1)
52
+ else:
53
+ raise NotImplementedError("Multi-slot prediction is not supported.")
54
+ elif pred_type == "intent":
55
+ if not use_multi:
56
+ res = torch.argmax(inputs, dim=-1)
57
+ else:
58
+ res = (torch.sigmoid(inputs) > multi_threshold).nonzero()
59
+ if return_list:
60
+ res_index = res.detach().cpu().tolist()
61
+ res_list = [[] for _ in range(len(target.seq_lens))]
62
+ for item in res_index:
63
+ res_list[item[0]].append(item[1])
64
+ return res_list
65
+ else:
66
+ return res
67
+ elif pred_type == "token-level-intent":
68
+ if not use_multi:
69
+ res = torch.argmax(inputs, dim=-1)
70
+ if not return_sentence_level:
71
+ return res
72
+ if return_list:
73
+ res = res.detach().cpu().tolist()
74
+ attention_mask = target.attention_mask
75
+ for i in range(attention_mask.shape[0]):
76
+ temp = []
77
+ for j in range(attention_mask.shape[1]):
78
+ if attention_mask[i][j] == 1:
79
+ temp.append(res[i][j])
80
+ else:
81
+ break
82
+ res[i] = temp
83
+ return [max(it, key=lambda v: it.count(v)) for it in res]
84
+ else:
85
+ seq_lens = target.seq_lens
86
+
87
+ if not return_sentence_level:
88
+ token_res = torch.cat([
89
+ torch.sigmoid(inputs[i, 0:seq_lens[i], :]) > multi_threshold
90
+ for i in range(len(seq_lens))],
91
+ dim=0)
92
+ return utils.unpack_sequence(token_res, seq_lens, padding_value=ignore_index)
93
+
94
+ intent_index_sum = torch.cat([
95
+ torch.sum(torch.sigmoid(inputs[i, 0:seq_lens[i], :]) > multi_threshold, dim=0).unsqueeze(0)
96
+ for i in range(len(seq_lens))],
97
+ dim=0)
98
+
99
+ res = (intent_index_sum > torch.div(seq_lens, 2, rounding_mode='floor').unsqueeze(1)).nonzero()
100
+ if return_list:
101
+ res_index = res.detach().cpu().tolist()
102
+ res_list = [[] for _ in range(len(seq_lens))]
103
+ for item in res_index:
104
+ res_list[item[0]].append(item[1])
105
+ return res_list
106
+ else:
107
+ return res
108
+ else:
109
+ raise NotImplementedError("Prediction mode except ['slot','intent','token-level-intent'] is not supported.")
110
+ if return_list:
111
+ res = res.detach().cpu().tolist()
112
+ return res
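A worked sketch of the multi-intent branch; the logits are illustrative:

# pred_type="intent", use_multi=True, multi_threshold=0.5:
#   logits = [[ 2.0, -3.0,  1.0],   # sigmoid -> [0.88, 0.05, 0.73]
#             [-1.0,  4.0, -2.0]]   # sigmoid -> [0.27, 0.98, 0.12]
# decode(..., return_list=True) -> [[0, 2], [1]]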
113
+
114
+
115
+ def compute_loss(pred: OutputData,
116
+ target: InputData,
117
+ criterion_type="slot",
118
+ use_crf=False,
119
+ ignore_index=-100,
120
+ loss_fn=None,
121
+ use_multi=False,
122
+ CRF=None):
123
+ """ compute loss
124
+
125
+ Args:
126
+ pred (OutputData): output logits data
127
+ target (InputData): input golden data
128
+ criterion_type (str, optional): criterion type in ["slot", "intent", "token-level-intent"]. Defaults to "slot".
129
+ ignore_index (int, optional): compute loss with ignore index. Defaults to -100.
130
+ loss_fn (_type_, optional): loss function. Defaults to None.
131
+ use_crf (bool, optional): whether to use crf. Defaults to False.
132
+ CRF (CRF, optional): CRF function. Defaults to None.
133
+
134
+ Returns:
135
+ Tensor: loss result
136
+ """
137
+ if criterion_type == "slot":
138
+ if use_crf:
139
+ return -1 * CRF(pred.slot_ids, target.slot, target.get_slot_mask(ignore_index).byte())
140
+ else:
141
+ pred_slot = utils.pack_sequence(pred.slot_ids, target.seq_lens)
142
+ target_slot = utils.pack_sequence(target.slot, target.seq_lens)
143
+ return loss_fn(pred_slot, target_slot)
144
+ elif criterion_type == "token-level-intent":
145
+ # TODO: split into two separate decode functions
146
+ intent_target = target.intent.unsqueeze(1)
147
+ if not use_multi:
148
+ intent_target = intent_target.repeat(1, pred.intent_ids.shape[1])
149
+ else:
150
+ intent_target = intent_target.repeat(1, pred.intent_ids.shape[1], 1)
151
+ intent_pred = utils.pack_sequence(pred.intent_ids, target.seq_lens)
152
+ intent_target = utils.pack_sequence(intent_target, target.seq_lens)
153
+ return loss_fn(intent_pred, intent_target)
154
+ else:
155
+ return loss_fn(pred.intent_ids, target.intent)
model/decoder/gl_gin_decoder.py ADDED
@@ -0,0 +1,47 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch import nn
4
+
5
+ from common.utils import HiddenData, OutputData, InputData
6
+ from model.decoder import BaseDecoder
7
+ from model.decoder.interaction.gl_gin_interaction import LSTMEncoder
8
+
9
+
10
+ class IntentEncoder(nn.Module):
11
+ def __init__(self, input_dim, dropout_rate):
12
+ super().__init__()
13
+ self.dropout_rate = dropout_rate
14
+ self.__intent_lstm = LSTMEncoder(
15
+ input_dim,
16
+ input_dim,
17
+ dropout_rate
18
+ )
19
+
20
+ def forward(self, g_hiddens, seq_lens):
21
+ intent_lstm_out = self.__intent_lstm(g_hiddens, seq_lens)
22
+ return F.dropout(intent_lstm_out, p=self.dropout_rate, training=self.training)
23
+
24
+
25
+ class GLGINDecoder(BaseDecoder):
26
+ def __init__(self, intent_classifier, slot_classifier, interaction=None, **config):
27
+ super().__init__(intent_classifier, slot_classifier, interaction)
28
+ self.config = config
29
+ self.intent_encoder = IntentEncoder(self.intent_classifier.config["input_dim"], self.config["dropout_rate"])
30
+
31
+ def forward(self, hidden: HiddenData, forced_slot=None, forced_intent=None, differentiable=None):
32
+ seq_lens = hidden.inputs.attention_mask.sum(-1)
33
+ intent_lstm_out = self.intent_encoder(hidden.slot_hidden, seq_lens)
34
+ hidden.update_intent_hidden_state(intent_lstm_out)
35
+ pred_intent = self.intent_classifier(hidden)
36
+ intent_index = self.intent_classifier.decode(OutputData(pred_intent, None), hidden.inputs,
37
+ return_list=False,
38
+ return_sentence_level=True)
39
+ slot_hidden = self.interaction(
40
+ hidden,
41
+ pred_intent=pred_intent,
42
+ intent_index=intent_index,
43
+ )
44
+ pred_slot = self.slot_classifier(slot_hidden)
45
+ num_intent = self.intent_classifier.config["intent_label_num"]
46
+ pred_slot = pred_slot.classifier_output[:, num_intent:]
47
+ return OutputData(pred_intent, F.log_softmax(pred_slot, dim=1))
model/decoder/interaction/__init__.py ADDED
@@ -0,0 +1,10 @@
1
+ from model.decoder.interaction.agif_interaction import AGIFInteraction
2
+ from model.decoder.interaction.base_interaction import BaseInteraction
3
+ from model.decoder.interaction.bi_model_interaction import BiModelInteraction, BiModelWithoutDecoderInteraction
4
+ from model.decoder.interaction.dca_net_interaction import DCANetInteraction
5
+ from model.decoder.interaction.gl_gin_interaction import GLGINInteraction
6
+ from model.decoder.interaction.slot_gated_interaction import SlotGatedInteraction
7
+ from model.decoder.interaction.stack_interaction import StackInteraction
8
+
9
+ __all__ = ["BaseInteraction", "BiModelInteraction", "BiModelWithoutDecoderInteraction", "DCANetInteraction",
10
+ "StackInteraction", "SlotGatedInteraction", "AGIFInteraction", "GLGINInteraction"]
model/decoder/interaction/agif_interaction.py ADDED
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from model.decoder.interaction.base_interaction import BaseInteraction
6
+
7
+
8
+ class GraphAttentionLayer(nn.Module):
9
+ """
10
+ Simple GAT layer, similar to https://arxiv.org/abs/1710.10903
11
+ """
12
+
13
+ def __init__(self, in_features, out_features, dropout, alpha, concat=True):
14
+ super(GraphAttentionLayer, self).__init__()
15
+ self.dropout = dropout
16
+ self.in_features = in_features
17
+ self.out_features = out_features
18
+ self.alpha = alpha
19
+ self.concat = concat
20
+
21
+ self.W = nn.Parameter(torch.zeros(size=(in_features, out_features)))
22
+ nn.init.xavier_uniform_(self.W.data, gain=1.414)
23
+ self.a = nn.Parameter(torch.zeros(size=(2 * out_features, 1)))
24
+ nn.init.xavier_uniform_(self.a.data, gain=1.414)
25
+
26
+ self.leakyrelu = nn.LeakyReLU(self.alpha)
27
+
28
+ def forward(self, input, adj):
29
+ h = torch.matmul(input, self.W)
30
+ B, N = h.size()[0], h.size()[1]
31
+
32
+ a_input = torch.cat([h.repeat(1, 1, N).view(B, N * N, -1), h.repeat(1, N, 1)], dim=2).view(B, N, -1,
33
+ 2 * self.out_features)
34
+ e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(3))
35
+
36
+ zero_vec = -9e15 * torch.ones_like(e)
37
+ attention = torch.where(adj > 0, e, zero_vec)
38
+ attention = F.softmax(attention, dim=2)
39
+ attention = F.dropout(attention, self.dropout, training=self.training)
40
+ h_prime = torch.matmul(attention, h)
41
+
42
+ if self.concat:
43
+ return F.elu(h_prime)
44
+ else:
45
+ return h_prime
46
+
47
+
48
+ class GAT(nn.Module):
49
+ def __init__(self, nfeat, nhid, nclass, dropout, alpha, nheads, nlayers=2):
50
+ """Dense version of GAT."""
51
+ super(GAT, self).__init__()
52
+ self.dropout = dropout
53
+ self.nlayers = nlayers
54
+ self.nheads = nheads
55
+ self.attentions = [GraphAttentionLayer(nfeat, nhid, dropout=dropout, alpha=alpha, concat=True) for _ in
56
+ range(nheads)]
57
+ for i, attention in enumerate(self.attentions):
58
+ self.add_module('attention_{}'.format(i), attention)
59
+ if self.nlayers > 2:
60
+ for i in range(self.nlayers - 2):
61
+ for j in range(self.nheads):
62
+ self.add_module('attention_{}_{}'.format(i + 1, j),
63
+ GraphAttentionLayer(nhid * nheads, nhid, dropout=dropout, alpha=alpha, concat=True))
64
+
65
+ self.out_att = GraphAttentionLayer(nhid * nheads, nclass, dropout=dropout, alpha=alpha, concat=False)
66
+
67
+ def forward(self, x, adj):
68
+ x = F.dropout(x, self.dropout, training=self.training)
69
+ input = x
70
+ x = torch.cat([att(x, adj) for att in self.attentions], dim=2)
71
+ if self.nlayers > 2:
72
+ for i in range(self.nlayers - 2):
73
+ temp = []
74
+ x = F.dropout(x, self.dropout, training=self.training)
75
+ cur_input = x
76
+ for j in range(self.nheads):
77
+ temp.append(self.__getattr__('attention_{}_{}'.format(i + 1, j))(x, adj))
78
+ x = torch.cat(temp, dim=2) + cur_input
79
+ x = F.dropout(x, self.dropout, training=self.training)
80
+ x = F.elu(self.out_att(x, adj))
81
+ return x + input
82
+
83
+
84
+ def normalize_adj(mx):
85
+ """
86
+ Row-normalize matrix D^{-1}A
87
+ torch.diag_embed: https://github.com/pytorch/pytorch/pull/12447
88
+ """
89
+ mx = mx.float()
90
+ rowsum = mx.sum(2)
91
+ r_inv = torch.pow(rowsum, -1)
92
+ r_inv[torch.isinf(r_inv)] = 0.
93
+ r_mat_inv = torch.diag_embed(r_inv, 0)
94
+ mx = r_mat_inv.matmul(mx)
95
+ return mx
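A runnable sketch of the row normalization:

import torch

adj = torch.tensor([[[1., 1., 0.],
                     [0., 1., 1.],
                     [0., 0., 1.]]])
print(normalize_adj(adj))
# tensor([[[0.5, 0.5, 0.0],
#          [0.0, 0.5, 0.5],
#          [0.0, 0.0, 1.0]]])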
96
+
97
+
98
+ class AGIFInteraction(BaseInteraction):
99
+ def __init__(self, **config):
100
+ super().__init__(**config)
101
+ self.intent_embedding = nn.Parameter(
102
+ torch.FloatTensor(self.config["intent_label_num"], self.config["intent_embedding_dim"])) # 191, 32
103
+ nn.init.normal_(self.intent_embedding.data)
104
+ self.adj = None
105
+ self.graph = GAT(
106
+ config["output_dim"],
107
+ config["hidden_dim"],
108
+ config["output_dim"],
109
+ config["dropout_rate"],
110
+ config["alpha"],
111
+ config["num_heads"],
112
+ config["num_layers"])
113
+
114
+ def generate_adj_gat(self, index, batch, intent_label_num):
115
+ intent_idx_ = [[torch.tensor(0)] for i in range(batch)]
116
+ for item in index:
117
+ intent_idx_[item[0]].append(item[1] + 1)
118
+ intent_idx = intent_idx_
119
+ self.adj = torch.cat([torch.eye(intent_label_num + 1).unsqueeze(0) for i in range(batch)])
120
+ for i in range(batch):
121
+ for j in intent_idx[i]:
122
+ self.adj[i, j, intent_idx[i]] = 1.
123
+ if self.config["row_normalized"]:
124
+ self.adj = normalize_adj(self.adj)
125
+ self.adj = self.adj.to(self.intent_embedding.device)
126
+
127
+ def forward(self, encode_hidden, **interaction_args):
128
+ if self.adj is None or interaction_args["sent_id"] == 0:
129
+ self.generate_adj_gat(interaction_args["intent_index"], interaction_args["batch_size"], interaction_args["intent_label_num"])
130
+ lstm_out = torch.cat((encode_hidden,
131
+ self.intent_embedding.unsqueeze(0).repeat(encode_hidden.shape[0], 1, 1)), dim=1)
132
+ return self.graph(lstm_out, self.adj[interaction_args["sent_id"]])
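
For reference, a minimal sketch of the graph `generate_adj_gat` builds: every utterance gets a self-looped graph over `intent_label_num + 1` nodes (node 0 is the extra root, so predicted intent ids are shifted by one), and the predicted intents of each sample are fully connected to one another. The batch size, label count, and intent index below are illustrative assumptions, not values from the repo.

import torch

batch, intent_label_num = 2, 3
index = [(0, 0), (0, 2), (1, 1)]          # (sample_id, intent_id) pairs

intent_idx = [[0] for _ in range(batch)]
for sample_id, intent_id in index:
    intent_idx[sample_id].append(intent_id + 1)   # shift past root node 0

adj = torch.eye(intent_label_num + 1).repeat(batch, 1, 1)  # self loops
for i in range(batch):
    for j in intent_idx[i]:
        adj[i, j, intent_idx[i]] = 1.              # link predicted intents

print(adj[0])  # rows/columns 0, 1 and 3 are now connected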
model/decoder/interaction/base_interaction.py ADDED
@@ -0,0 +1,9 @@
+ from torch import nn
+
+
+ class BaseInteraction(nn.Module):
+     def __init__(self, **config):
+         super().__init__()
+         self.config = config
+
+     def forward(self, hidden1, hidden2):
+         raise NotImplementedError("forward must be implemented by subclasses")
model/decoder/interaction/bi_model_interaction.py ADDED
@@ -0,0 +1,74 @@
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+
+ from common.utils import HiddenData
+ from model.decoder.interaction.base_interaction import BaseInteraction
+
+
+ class BiModelInteraction(BaseInteraction):
+     def __init__(self, **config):
+         super().__init__(**config)
+         self.intent_lstm = nn.LSTM(input_size=self.config["input_dim"], hidden_size=self.config["output_dim"],
+                                    batch_first=True,
+                                    num_layers=1)
+         self.slot_lstm = nn.LSTM(input_size=self.config["input_dim"] + self.config["output_dim"],
+                                  hidden_size=self.config["output_dim"], num_layers=1)
+
+     def forward(self, encode_hidden: HiddenData, **kwargs):
+         slot_hidden = encode_hidden.get_slot_hidden_state()
+         intent_hidden_detached = encode_hidden.get_intent_hidden_state().clone().detach()
+         seq_lens = encode_hidden.inputs.attention_mask.sum(-1)
+         batch = slot_hidden.size(0)
+         length = slot_hidden.size(1)
+         dec_init_out = torch.zeros(batch, 1, self.config["output_dim"]).to(slot_hidden.device)
+         hidden_state = (torch.zeros(1, 1, self.config["output_dim"]).to(slot_hidden.device),
+                         torch.zeros(1, 1, self.config["output_dim"]).to(slot_hidden.device))
+         slot_hidden = torch.cat((slot_hidden, intent_hidden_detached), dim=-1).transpose(1, 0)  # 50 x batch x feature_size
+         slot_drop = F.dropout(slot_hidden, self.config["dropout_rate"])
+         all_out = []
+         for i in range(length):
+             if i == 0:
+                 out, hidden_state = self.slot_lstm(torch.cat((slot_drop[i].unsqueeze(1), dec_init_out), dim=-1),
+                                                    hidden_state)
+             else:
+                 out, hidden_state = self.slot_lstm(torch.cat((slot_drop[i].unsqueeze(1), out), dim=-1), hidden_state)
+             all_out.append(out)
+         slot_output = torch.cat(all_out, dim=1)  # batch x 50 x feature_size
+
+         intent_hidden = torch.cat((encode_hidden.get_intent_hidden_state(),
+                                    encode_hidden.get_slot_hidden_state().clone().detach()),
+                                   dim=-1)
+         intent_drop = F.dropout(intent_hidden, self.config["dropout_rate"])
+         intent_lstm_output, _ = self.intent_lstm(intent_drop)
+         intent_output = F.dropout(intent_lstm_output, self.config["dropout_rate"])
+         output_list = []
+         for index, slen in enumerate(seq_lens):
+             output_list.append(intent_output[index, slen - 1, :].unsqueeze(0))
+
+         encode_hidden.update_intent_hidden_state(torch.cat(output_list, dim=0))
+         encode_hidden.update_slot_hidden_state(slot_output)
+
+         return encode_hidden
+
+
+ class BiModelWithoutDecoderInteraction(BaseInteraction):
+     def forward(self, encode_hidden: HiddenData, **kwargs):
+         slot_hidden = encode_hidden.get_slot_hidden_state()
+         intent_hidden_detached = encode_hidden.get_intent_hidden_state().clone().detach()
+         seq_lens = encode_hidden.inputs.attention_mask.sum(-1)
+         slot_hidden = torch.cat((slot_hidden, intent_hidden_detached), dim=-1)  # batch x 50 x feature_size
+         slot_output = F.dropout(slot_hidden, self.config["dropout_rate"])
+
+         intent_hidden = torch.cat((encode_hidden.get_intent_hidden_state(),
+                                    encode_hidden.get_slot_hidden_state().clone().detach()),
+                                   dim=-1)
+         intent_output = F.dropout(intent_hidden, self.config["dropout_rate"])
+         output_list = []
+         for index, slen in enumerate(seq_lens):
+             output_list.append(intent_output[index, slen - 1, :].unsqueeze(0))
+
+         encode_hidden.update_intent_hidden_state(torch.cat(output_list, dim=0))
+         encode_hidden.update_slot_hidden_state(slot_output)
+
+         return encode_hidden
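
The slot decoder in `BiModelInteraction` feeds each step's LSTM output back in as part of the next step's input. A self-contained sketch of that decoding pattern on plain tensors (all sizes are illustrative, and `batch_first=True` is used here for clarity):

import torch
from torch import nn

batch, length, feat_dim, hidden_dim = 2, 4, 5, 8
step_lstm = nn.LSTM(input_size=feat_dim + hidden_dim, hidden_size=hidden_dim,
                    batch_first=True)

feats = torch.randn(batch, length, feat_dim)   # encoder features per token
out = torch.zeros(batch, 1, hidden_dim)        # plays the role of dec_init_out
state = None
steps = []
for i in range(length):
    # each step sees the current token feature plus the previous step's output
    step_in = torch.cat((feats[:, i:i + 1, :], out), dim=-1)
    out, state = step_lstm(step_in, state)
    steps.append(out)

slot_output = torch.cat(steps, dim=1)          # (batch, length, hidden_dim)
print(slot_output.shape)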
model/decoder/interaction/dca_net_interaction.py ADDED
@@ -0,0 +1,176 @@
+ import math
+
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+ from torch.nn import LayerNorm
+
+ from common.utils import HiddenData
+ from model.decoder.interaction import BaseInteraction
+
+
+ class DCANetInteraction(BaseInteraction):
+     def __init__(self, **config):
+         super().__init__(**config)
+         self.I_S_Emb = Label_Attention()
+         self.T_block1 = I_S_Block(self.config["input_dim"], self.config["attention_dropout"], self.config["num_attention_heads"])
+         self.T_block2 = I_S_Block(self.config["input_dim"], self.config["attention_dropout"], self.config["num_attention_heads"])
+
+     def forward(self, encode_hidden: HiddenData, **kwargs):
+         mask = encode_hidden.inputs.attention_mask
+         H = encode_hidden.slot_hidden
+         H_I, H_S = self.I_S_Emb(H, H, kwargs["intent_emb"], kwargs["slot_emb"])
+         H_I, H_S = self.T_block1(H_I + H, H_S + H, mask)
+         H_I_1, H_S_1 = self.I_S_Emb(H_I, H_S, kwargs["intent_emb"], kwargs["slot_emb"])
+         H_I, H_S = self.T_block2(H_I + H_I_1, H_S + H_S_1, mask)
+         encode_hidden.update_intent_hidden_state(F.max_pool1d((H_I + H).transpose(1, 2), H_I.size(1)).squeeze(2))
+         encode_hidden.update_slot_hidden_state(H_S + H)
+         return encode_hidden
+
+
+ class Label_Attention(nn.Module):
+     def __init__(self):
+         super(Label_Attention, self).__init__()
+
+     def forward(self, input_intent, input_slot, intent_emb, slot_emb):
+         self.W_intent_emb = intent_emb.intent_classifier.weight
+         self.W_slot_emb = slot_emb.slot_classifier.weight
+         intent_score = torch.matmul(input_intent, self.W_intent_emb.t())
+         slot_score = torch.matmul(input_slot, self.W_slot_emb.t())
+         intent_probs = nn.Softmax(dim=-1)(intent_score)
+         slot_probs = nn.Softmax(dim=-1)(slot_score)
+         intent_res = torch.matmul(intent_probs, self.W_intent_emb)
+         slot_res = torch.matmul(slot_probs, self.W_slot_emb)
+
+         return intent_res, slot_res
+
+
+ class I_S_Block(nn.Module):
+     def __init__(self, hidden_size, attention_dropout, num_attention_heads):
+         super(I_S_Block, self).__init__()
+         self.I_S_Attention = I_S_SelfAttention(hidden_size, 2 * hidden_size, hidden_size, attention_dropout, num_attention_heads)
+         self.I_Out = SelfOutput(hidden_size, attention_dropout)
+         self.S_Out = SelfOutput(hidden_size, attention_dropout)
+         self.I_S_Feed_forward = Intermediate_I_S(hidden_size, hidden_size, attention_dropout)
+
+     def forward(self, H_intent_input, H_slot_input, mask):
+         H_slot, H_intent = self.I_S_Attention(H_intent_input, H_slot_input, mask)
+         H_slot = self.S_Out(H_slot, H_slot_input)
+         H_intent = self.I_Out(H_intent, H_intent_input)
+         H_intent, H_slot = self.I_S_Feed_forward(H_intent, H_slot)
+
+         return H_intent, H_slot
+
+
+ class I_S_SelfAttention(nn.Module):
+     def __init__(self, input_size, hidden_size, out_size, attention_dropout, num_attention_heads):
+         super(I_S_SelfAttention, self).__init__()
+
+         self.num_attention_heads = num_attention_heads
+         self.attention_head_size = int(hidden_size / self.num_attention_heads)
+
+         self.all_head_size = self.num_attention_heads * self.attention_head_size
+         self.out_size = out_size
+         self.query = nn.Linear(input_size, self.all_head_size)
+         self.query_slot = nn.Linear(input_size, self.all_head_size)
+         self.key = nn.Linear(input_size, self.all_head_size)
+         self.key_slot = nn.Linear(input_size, self.all_head_size)
+         self.value = nn.Linear(input_size, self.out_size)
+         self.value_slot = nn.Linear(input_size, self.out_size)
+         self.dropout = nn.Dropout(attention_dropout)
+
+     def transpose_for_scores(self, x):
+         last_dim = int(x.size()[-1] / self.num_attention_heads)
+         new_x_shape = x.size()[:-1] + (self.num_attention_heads, last_dim)
+         x = x.view(*new_x_shape)
+         return x.permute(0, 2, 1, 3)
+
+     def forward(self, intent, slot, mask):
+         extended_attention_mask = mask.unsqueeze(1).unsqueeze(2)
+         extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
+         attention_mask = (1.0 - extended_attention_mask) * -10000.0
+
+         mixed_query_layer = self.query(intent)
+         mixed_key_layer = self.key(slot)
+         mixed_value_layer = self.value(slot)
+
+         mixed_query_layer_slot = self.query_slot(slot)
+         mixed_key_layer_slot = self.key_slot(intent)
+         mixed_value_layer_slot = self.value_slot(intent)
+
+         query_layer = self.transpose_for_scores(mixed_query_layer)
+         query_layer_slot = self.transpose_for_scores(mixed_query_layer_slot)
+         key_layer = self.transpose_for_scores(mixed_key_layer)
+         key_layer_slot = self.transpose_for_scores(mixed_key_layer_slot)
+         value_layer = self.transpose_for_scores(mixed_value_layer)
+         value_layer_slot = self.transpose_for_scores(mixed_value_layer_slot)
+
+         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+         attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+         # attention_scores_slot = torch.matmul(query_slot, key_slot.transpose(1,0))
+         attention_scores_slot = torch.matmul(query_layer_slot, key_layer_slot.transpose(-1, -2))
+         attention_scores_slot = attention_scores_slot / math.sqrt(self.attention_head_size)
+         attention_scores_intent = attention_scores + attention_mask
+
+         attention_scores_slot = attention_scores_slot + attention_mask
+
+         # Normalize the attention scores to probabilities.
+         attention_probs_slot = nn.Softmax(dim=-1)(attention_scores_slot)
+         attention_probs_intent = nn.Softmax(dim=-1)(attention_scores_intent)
+
+         attention_probs_slot = self.dropout(attention_probs_slot)
+         attention_probs_intent = self.dropout(attention_probs_intent)
+
+         context_layer_slot = torch.matmul(attention_probs_slot, value_layer_slot)
+         context_layer_intent = torch.matmul(attention_probs_intent, value_layer)
+
+         context_layer = context_layer_slot.permute(0, 2, 1, 3).contiguous()
+         context_layer_intent = context_layer_intent.permute(0, 2, 1, 3).contiguous()
+         new_context_layer_shape = context_layer.size()[:-2] + (self.out_size,)
+         new_context_layer_shape_intent = context_layer_intent.size()[:-2] + (self.out_size,)
+
+         context_layer = context_layer.view(*new_context_layer_shape)
+         context_layer_intent = context_layer_intent.view(*new_context_layer_shape_intent)
+         return context_layer, context_layer_intent
+
+
+ class SelfOutput(nn.Module):
+     def __init__(self, hidden_size, hidden_dropout_prob):
+         super(SelfOutput, self).__init__()
+         self.dense = nn.Linear(hidden_size, hidden_size)
+         self.LayerNorm = LayerNorm(hidden_size, eps=1e-12)
+         self.dropout = nn.Dropout(hidden_dropout_prob)
+
+     def forward(self, hidden_states, input_tensor):
+         hidden_states = self.dense(hidden_states)
+         hidden_states = self.dropout(hidden_states)
+         hidden_states = self.LayerNorm(hidden_states + input_tensor)
+         return hidden_states
+
+
+ class Intermediate_I_S(nn.Module):
+     def __init__(self, intermediate_size, hidden_size, attention_dropout):
+         super(Intermediate_I_S, self).__init__()
+         self.dense_in = nn.Linear(hidden_size * 6, intermediate_size)
+         self.intermediate_act_fn = nn.ReLU()
+         self.dense_out = nn.Linear(intermediate_size, hidden_size)
+         self.LayerNorm_I = LayerNorm(hidden_size, eps=1e-12)
+         self.LayerNorm_S = LayerNorm(hidden_size, eps=1e-12)
+         self.dropout = nn.Dropout(attention_dropout)
+
+     def forward(self, hidden_states_I, hidden_states_S):
+         hidden_states_in = torch.cat([hidden_states_I, hidden_states_S], dim=2)
+         batch_size, max_length, hidden_size = hidden_states_in.size()
+         h_pad = torch.zeros(batch_size, 1, hidden_size).to(hidden_states_I.device)
+         h_left = torch.cat([h_pad, hidden_states_in[:, :max_length - 1, :]], dim=1)
+         h_right = torch.cat([hidden_states_in[:, 1:, :], h_pad], dim=1)
+         hidden_states_in = torch.cat([hidden_states_in, h_left, h_right], dim=2)
+
+         hidden_states = self.dense_in(hidden_states_in)
+         hidden_states = self.intermediate_act_fn(hidden_states)
+         hidden_states = self.dense_out(hidden_states)
+         hidden_states = self.dropout(hidden_states)
+         hidden_states_I_NEW = self.LayerNorm_I(hidden_states + hidden_states_I)
+         hidden_states_S_NEW = self.LayerNorm_S(hidden_states + hidden_states_S)
+         return hidden_states_I_NEW, hidden_states_S_NEW
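
`Intermediate_I_S` widens each position with its left and right neighbors before the feed-forward layer, which is why `dense_in` takes `hidden_size * 6`: concatenating intent and slot states doubles the width, and stacking the current, left-shifted, and right-shifted sequences triples it again. A toy shape check (sizes are illustrative):

import torch

batch, max_len, hidden = 2, 4, 3
h = torch.randn(batch, max_len, 2 * hidden)   # concat of intent and slot states

pad = torch.zeros(batch, 1, 2 * hidden)
h_left = torch.cat([pad, h[:, :max_len - 1, :]], dim=1)   # left neighbor of each token
h_right = torch.cat([h[:, 1:, :], pad], dim=1)            # right neighbor of each token
windowed = torch.cat([h, h_left, h_right], dim=2)

print(windowed.shape)  # torch.Size([2, 4, 18]) == (batch, max_len, hidden * 6)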
model/decoder/interaction/gl_gin_interaction.py ADDED
@@ -0,0 +1,227 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
+
+ from common.utils import HiddenData, ClassifierOutputData
+ from model.decoder.interaction import BaseInteraction
+
+
+ class LSTMEncoder(nn.Module):
+     """
+     Encoder structure based on bidirectional LSTM.
+     """
+
+     def __init__(self, embedding_dim, hidden_dim, dropout_rate):
+         super(LSTMEncoder, self).__init__()
+
+         # Parameter recording.
+         self.__embedding_dim = embedding_dim
+         self.__hidden_dim = hidden_dim // 2
+         self.__dropout_rate = dropout_rate
+
+         # Network attributes.
+         self.__dropout_layer = nn.Dropout(self.__dropout_rate)
+         self.__lstm_layer = nn.LSTM(
+             input_size=self.__embedding_dim,
+             hidden_size=self.__hidden_dim,
+             batch_first=True,
+             bidirectional=True,
+             dropout=self.__dropout_rate,
+             num_layers=1
+         )
+
+     def forward(self, embedded_text, seq_lens):
+         """ Forward process for LSTM Encoder.
+
+         (batch_size, max_sent_len)
+         -> (batch_size, max_sent_len, word_dim)
+         -> (batch_size, max_sent_len, hidden_dim)
+
+         :param embedded_text: padded and embedded input text.
+         :param seq_lens: lengths of the original input texts.
+         :return: encoded word hidden vectors.
+         """
+
+         # Padded_text should be instance of LongTensor.
+         dropout_text = self.__dropout_layer(embedded_text)
+
+         # Pack and Pad process for input of variable length.
+         packed_text = pack_padded_sequence(dropout_text, seq_lens.cpu(), batch_first=True, enforce_sorted=False)
+         lstm_hiddens, (h_last, c_last) = self.__lstm_layer(packed_text)
+         padded_hiddens, _ = pad_packed_sequence(lstm_hiddens, batch_first=True)
+
+         return padded_hiddens
+
+
+ class GraphAttentionLayer(nn.Module):
+     """
+     Simple GAT layer, similar to https://arxiv.org/abs/1710.10903
+     """
+
+     def __init__(self, in_features, out_features, dropout, alpha, concat=True):
+         super(GraphAttentionLayer, self).__init__()
+         self.dropout = dropout
+         self.in_features = in_features
+         self.out_features = out_features
+         self.alpha = alpha
+         self.concat = concat
+
+         self.W = nn.Parameter(torch.zeros(size=(in_features, out_features)))
+         nn.init.xavier_uniform_(self.W.data, gain=1.414)
+         self.a = nn.Parameter(torch.zeros(size=(2 * out_features, 1)))
+         nn.init.xavier_uniform_(self.a.data, gain=1.414)
+
+         self.leakyrelu = nn.LeakyReLU(self.alpha)
+
+     def forward(self, input, adj):
+         h = torch.matmul(input, self.W)
+         B, N = h.size()[0], h.size()[1]
+
+         a_input = torch.cat([h.repeat(1, 1, N).view(B, N * N, -1), h.repeat(1, N, 1)], dim=2).view(B, N, -1, 2 * self.out_features)
+         e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(3))
+
+         zero_vec = -9e15 * torch.ones_like(e)
+         attention = torch.where(adj > 0, e, zero_vec)
+         attention = F.softmax(attention, dim=2)
+         attention = F.dropout(attention, self.dropout, training=self.training)
+         h_prime = torch.matmul(attention, h)
+
+         if self.concat:
+             return F.elu(h_prime)
+         else:
+             return h_prime
+
+
+ class GAT(nn.Module):
+     def __init__(self, nfeat, nhid, nclass, dropout, alpha, nheads, nlayers=2):
+         """Dense version of GAT."""
+         super(GAT, self).__init__()
+         self.dropout = dropout
+         self.nlayers = nlayers
+         self.nheads = nheads
+         self.attentions = [GraphAttentionLayer(nfeat, nhid, dropout=dropout, alpha=alpha, concat=True) for _ in range(nheads)]
+         for i, attention in enumerate(self.attentions):
+             self.add_module('attention_{}'.format(i), attention)
+         if self.nlayers > 2:
+             for i in range(self.nlayers - 2):
+                 for j in range(self.nheads):
+                     self.add_module('attention_{}_{}'.format(i + 1, j),
+                                     GraphAttentionLayer(nhid * nheads, nhid, dropout=dropout, alpha=alpha, concat=True))
+
+         self.out_att = GraphAttentionLayer(nhid * nheads, nclass, dropout=dropout, alpha=alpha, concat=False)
+
+     def forward(self, x, adj):
+         x = F.dropout(x, self.dropout, training=self.training)
+         input = x
+         x = torch.cat([att(x, adj) for att in self.attentions], dim=2)
+         if self.nlayers > 2:
+             for i in range(self.nlayers - 2):
+                 temp = []
+                 x = F.dropout(x, self.dropout, training=self.training)
+                 cur_input = x
+                 for j in range(self.nheads):
+                     temp.append(self.__getattr__('attention_{}_{}'.format(i + 1, j))(x, adj))
+                 x = torch.cat(temp, dim=2) + cur_input
+         x = F.dropout(x, self.dropout, training=self.training)
+         x = F.elu(self.out_att(x, adj))
+         return x + input
+
+
+ def normalize_adj(mx):
+     """
+     Row-normalize matrix D^{-1}A
+     torch.diag_embed: https://github.com/pytorch/pytorch/pull/12447
+     """
+     mx = mx.float()
+     rowsum = mx.sum(2)
+     r_inv = torch.pow(rowsum, -1)
+     r_inv[torch.isinf(r_inv)] = 0.
+     r_mat_inv = torch.diag_embed(r_inv, 0)
+     mx = r_mat_inv.matmul(mx)
+     return mx
+
+
+ class GLGINInteraction(BaseInteraction):
+     def __init__(self, **config):
+         super().__init__(**config)
+         self.intent_embedding = nn.Parameter(
+             torch.FloatTensor(self.config["intent_label_num"], self.config["intent_embedding_dim"]))  # 191, 32
+         nn.init.normal_(self.intent_embedding.data)
+         self.adj = None
+         self.__slot_lstm = LSTMEncoder(
+             self.config["input_dim"] + self.config["intent_label_num"],
+             config["output_dim"],
+             config["dropout_rate"]
+         )
+         self.__slot_graph = GAT(
+             config["output_dim"],
+             config["hidden_dim"],
+             config["output_dim"],
+             config["dropout_rate"],
+             config["alpha"],
+             config["num_heads"],
+             config["num_layers"])
+
+         self.__global_graph = GAT(
+             config["output_dim"],
+             config["hidden_dim"],
+             config["output_dim"],
+             config["dropout_rate"],
+             config["alpha"],
+             config["num_heads"],
+             config["num_layers"])
+
+     def generate_global_adj_gat(self, seq_len, index, batch, window):
+         global_intent_idx = [[] for i in range(batch)]
+         global_slot_idx = [[] for i in range(batch)]
+         for item in index:
+             global_intent_idx[item[0]].append(item[1])
+
+         for i, length in enumerate(seq_len):
+             global_slot_idx[i].extend(
+                 list(range(self.config["intent_label_num"], self.config["intent_label_num"] + length)))
+
+         adj = torch.cat([torch.eye(self.config["intent_label_num"] + max(seq_len)).unsqueeze(0) for i in range(batch)])
+         for i in range(batch):
+             for j in global_intent_idx[i]:
+                 adj[i, j, global_slot_idx[i]] = 1.
+                 adj[i, j, global_intent_idx[i]] = 1.
+             for j in global_slot_idx[i]:
+                 adj[i, j, global_intent_idx[i]] = 1.
+
+         for i in range(batch):
+             for j in range(self.config["intent_label_num"], self.config["intent_label_num"] + seq_len[i]):
+                 adj[i, j, max(self.config["intent_label_num"], j - window):min(seq_len[i] + self.config["intent_label_num"], j + window + 1)] = 1.
+
+         if self.config["row_normalized"]:
+             adj = normalize_adj(adj)
+         adj = adj.to(self.intent_embedding.device)
+         return adj
+
+     def generate_slot_adj_gat(self, seq_len, batch, window):
+         slot_idx_ = [[] for i in range(batch)]
+         adj = torch.cat([torch.eye(max(seq_len)).unsqueeze(0) for i in range(batch)])
+         for i in range(batch):
+             for j in range(seq_len[i]):
+                 adj[i, j, max(0, j - window):min(seq_len[i], j + window + 1)] = 1.
+         if self.config["row_normalized"]:
+             adj = normalize_adj(adj)
+         adj = adj.to(self.intent_embedding.device)
+         return adj
+
+     def forward(self, encode_hidden: HiddenData, pred_intent: ClassifierOutputData = None, intent_index=None):
+         seq_lens = encode_hidden.inputs.attention_mask.sum(-1)
+         slot_lstm_out = self.__slot_lstm(torch.cat([encode_hidden.slot_hidden, pred_intent.classifier_output], dim=-1),
+                                          seq_lens)
+         global_adj = self.generate_global_adj_gat(seq_lens, intent_index, len(seq_lens),
+                                                   self.config["slot_graph_window"])
+         slot_adj = self.generate_slot_adj_gat(seq_lens, len(seq_lens), self.config["slot_graph_window"])
+         batch = len(seq_lens)
+         slot_graph_out = self.__slot_graph(slot_lstm_out, slot_adj)
+         intent_in = self.intent_embedding.unsqueeze(0).repeat(batch, 1, 1)
+         global_graph_in = torch.cat([intent_in, slot_graph_out], dim=1)
+         encode_hidden.update_slot_hidden_state(self.__global_graph(global_graph_in, global_adj))
+         return encode_hidden
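
A quick look at the local slot graph that `generate_slot_adj_gat` produces: each token connects to itself and to its neighbors within `slot_graph_window`. The lengths and window below are toy assumptions.

import torch

seq_len, batch, window = [4, 3], 2, 1
adj = torch.eye(max(seq_len)).repeat(batch, 1, 1)
for i in range(batch):
    for j in range(seq_len[i]):
        adj[i, j, max(0, j - window):min(seq_len[i], j + window + 1)] = 1.

print(adj[0])
# tensor([[1., 1., 0., 0.],
#         [1., 1., 1., 0.],
#         [0., 1., 1., 1.],
#         [0., 0., 1., 1.]])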
model/decoder/interaction/slot_gated_interaction.py ADDED
@@ -0,0 +1,59 @@
+ import math
+
+ import einops
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+ from torch.nn import LayerNorm
+
+ from common.utils import HiddenData
+ from model.decoder.interaction import BaseInteraction
+
+
+ class SlotGatedInteraction(BaseInteraction):
+     def __init__(self, **config):
+         super().__init__(**config)
+         self.intent_linear = nn.Linear(self.config["input_dim"], 1, bias=False)
+         self.slot_linear1 = nn.Linear(self.config["input_dim"], 1, bias=False)
+         self.slot_linear2 = nn.Linear(self.config["input_dim"], 1, bias=False)
+         self.remove_slot_attn = self.config["remove_slot_attn"]
+         self.slot_gate = SlotGate(**config)
+
+     def forward(self, encode_hidden: HiddenData, **kwargs):
+         input_hidden = encode_hidden.get_slot_hidden_state()
+
+         seq_lens = encode_hidden.inputs.attention_mask.sum(-1)
+         output_list = []
+         for index, slen in enumerate(seq_lens):
+             output_list.append(input_hidden[index, slen - 1, :].unsqueeze(0))
+         intent_input = torch.cat(output_list, dim=0)
+         e_I = torch.tanh(self.intent_linear(intent_input)).squeeze(1)
+         alpha_I = einops.repeat(e_I, 'b -> b h', h=intent_input.shape[-1])
+         c_I = alpha_I * intent_input
+         intent_hidden = intent_input + c_I
+         if not self.remove_slot_attn:
+             # slot attention
+             h_k = einops.repeat(self.slot_linear1(input_hidden), 'b l h -> b l c h', c=input_hidden.shape[1])
+             h_i = einops.repeat(self.slot_linear2(input_hidden), 'b l h -> b l c h', c=input_hidden.shape[1]).transpose(1, 2)
+             e_S = torch.tanh(h_k + h_i)
+             alpha_S = torch.softmax(e_S, dim=2).squeeze(3)
+             alpha_S = einops.repeat(alpha_S, 'b l1 l2 -> b l1 l2 h', h=input_hidden.shape[-1])
+             map_input_hidden = einops.repeat(input_hidden, 'b l h -> b l c h', c=input_hidden.shape[1])
+             c_S = torch.sum(alpha_S * map_input_hidden, dim=2)
+         else:
+             c_S = input_hidden
+         slot_hidden = input_hidden + c_S * self.slot_gate(c_S, c_I)
+         encode_hidden.update_intent_hidden_state(intent_hidden)
+         encode_hidden.update_slot_hidden_state(slot_hidden)
+         return encode_hidden
+
+
+ class SlotGate(nn.Module):
+     def __init__(self, **config):
+         super().__init__()
+         self.linear = nn.Linear(config["input_dim"], config["output_dim"], bias=False)
+         self.v = nn.Parameter(torch.rand(size=[1]))
+
+     def forward(self, slot_context, intent_context):
+         intent_gate = self.linear(intent_context)
+         intent_gate = einops.repeat(intent_gate, 'b h -> b l h', l=slot_context.shape[1])
+         return self.v * torch.tanh(slot_context + intent_gate)
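
The gate itself is just g = v * tanh(c_S + W @ c_I), with the sentence-level intent context broadcast over the sequence before the slot features are reweighted. A standalone sketch (dimensions are illustrative):

import torch
from torch import nn

batch, length, dim = 2, 5, 8
c_S = torch.randn(batch, length, dim)   # per-token slot context
c_I = torch.randn(batch, dim)           # sentence-level intent context

linear = nn.Linear(dim, dim, bias=False)
v = nn.Parameter(torch.rand(1))

gate = v * torch.tanh(c_S + linear(c_I).unsqueeze(1))  # broadcast over length
gated_slots = c_S * gate                               # (batch, length, dim)
print(gated_slots.shape)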
model/decoder/interaction/stack_interaction.py ADDED
@@ -0,0 +1,36 @@
+ import os
+ import torch
+ from torch import nn
+
+ from common import utils
+ from common.utils import ClassifierOutputData, HiddenData
+ from model.decoder.interaction.base_interaction import BaseInteraction
+
+
+ class StackInteraction(BaseInteraction):
+     def __init__(self, **config):
+         super().__init__(**config)
+         self.intent_embedding = nn.Embedding(
+             self.config["intent_label_num"], self.config["intent_label_num"]
+         )
+         self.differentiable = config.get("differentiable")
+         self.intent_embedding.weight.data = torch.eye(self.config["intent_label_num"])
+         self.intent_embedding.weight.requires_grad = False
+
+     def forward(self, intent_output: ClassifierOutputData, encode_hidden: HiddenData):
+         if not self.differentiable:
+             _, idx_intent = intent_output.classifier_output.topk(1, dim=-1)
+             feed_intent = self.intent_embedding(idx_intent.squeeze(2))
+         else:
+             feed_intent = intent_output.classifier_output
+         encode_hidden.update_slot_hidden_state(
+             torch.cat([encode_hidden.get_slot_hidden_state(), feed_intent], dim=-1))
+         return encode_hidden
+
+     @staticmethod
+     def from_configured(configure_name_or_file="stack-interaction", **input_config):
+         return utils.from_configured(configure_name_or_file,
+                                      model_class=StackInteraction,
+                                      config_prefix="./config/decoder/interaction",
+                                      **input_config)
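
Since `StackInteraction` freezes the intent embedding to an identity matrix, the non-differentiable branch simply appends a one-hot encoding of the argmax intent to each slot hidden state. For instance:

import torch
from torch import nn

intent_label_num = 4
emb = nn.Embedding(intent_label_num, intent_label_num)
emb.weight.data = torch.eye(intent_label_num)
emb.weight.requires_grad = False

idx = torch.tensor([[2, 2, 0]])   # toy per-token intent ids, shape (batch, len)
print(emb(idx))                   # each row is one-hot, e.g. [0., 0., 1., 0.]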
model/encoder/__init__.py ADDED
@@ -0,0 +1,5 @@
+ from model.encoder.pretrained_encoder import PretrainedEncoder
+ from model.encoder.non_pretrained_encoder import NonPretrainedEncoder
+ from model.encoder.base_encoder import BiEncoder
+ from model.encoder.auto_encoder import AutoEncoder
+
+ __all__ = ["PretrainedEncoder", "NonPretrainedEncoder", "AutoEncoder", "BiEncoder"]
model/encoder/auto_encoder.py ADDED
@@ -0,0 +1,37 @@
+ '''
+ Author: Qiguang Chen
+ Date: 2023-01-11 10:39:26
+ LastEditors: Qiguang Chen
+ LastEditTime: 2023-01-26 17:46:10
+ Description: automatically load an encoder by config
+
+ '''
+ from common.utils import InputData
+ from model.encoder.base_encoder import BaseEncoder, BiEncoder
+ from model.encoder.pretrained_encoder import PretrainedEncoder
+ from model.encoder.non_pretrained_encoder import NonPretrainedEncoder
+
+
+ class AutoEncoder(BaseEncoder):
+
+     def __init__(self, **config):
+         """ Automatically load an encoder by 'encoder_name'.
+
+         Args:
+             config (dict):
+                 encoder_name (str): supports ["lstm", "self-attention-lstm", "bi-encoder"] and any other pretrained model name on Hugging Face.
+                 **args (Any): other configuration items corresponding to each module.
+         """
+         super().__init__()
+         self.config = config
+         if config.get("encoder_name"):
+             encoder_name = config.get("encoder_name").lower()
+             if encoder_name in ["lstm", "self-attention-lstm"]:
+                 self.__encoder = NonPretrainedEncoder(**config)
+             elif encoder_name == "bi-encoder":
+                 # build separate encoders for the intent and slot channels
+                 self.__encoder = BiEncoder(AutoEncoder(**config["intent_encoder"]),
+                                            AutoEncoder(**config["slot_encoder"]))
+             else:
+                 self.__encoder = PretrainedEncoder(**config)
+         else:
+             raise ValueError("There is no Encoder Name in config.")
+
+     def forward(self, inputs: InputData):
+         return self.__encoder(inputs)
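
To make the dispatch concrete, here is how an `encoder_name` would be routed (the config dict is only illustrative of the idea, not the full schema):

config = {"encoder_name": "lstm"}   # e.g. loaded from a YAML file

name = config.get("encoder_name", "").lower()
if name in ["lstm", "self-attention-lstm"]:
    print("-> NonPretrainedEncoder")
elif name == "bi-encoder":
    print("-> BiEncoder(intent_encoder, slot_encoder)")
else:
    print("-> PretrainedEncoder (treated as a Hugging Face model name)")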
model/encoder/base_encoder.py ADDED
@@ -0,0 +1,41 @@
+ '''
+ Author: Qiguang Chen
+ Date: 2023-01-11 10:39:26
+ LastEditors: Qiguang Chen
+ LastEditTime: 2023-01-26 17:25:17
+ Description: Base encoder and bi encoder
+
+ '''
+ from torch import nn
+
+ from common.utils import InputData
+
+
+ class BaseEncoder(nn.Module):
+     """Base class for all encoder modules.
+     """
+     def __init__(self, **config):
+         super().__init__()
+         self.config = config
+
+     def forward(self, inputs: InputData):
+         return self.encoder(inputs.input_ids)
+
+
+ class BiEncoder(nn.Module):
+     """BiEncoder for encoding intent and slot separately.
+     """
+     def __init__(self, intent_encoder: BaseEncoder, slot_encoder: BaseEncoder, **config):
+         super().__init__()
+         self.intent_encoder = intent_encoder
+         self.slot_encoder = slot_encoder
+
+     def forward(self, inputs: InputData):
+         hidden_slot = self.slot_encoder(inputs)
+         hidden_intent = self.intent_encoder(inputs)
+         if not self.intent_encoder.config["return_sentence_level_hidden"]:
+             hidden_slot.update_intent_hidden_state(hidden_intent.get_slot_hidden_state())
+         else:
+             hidden_slot.update_intent_hidden_state(hidden_intent.get_intent_hidden_state())
+         return hidden_slot
model/encoder/non_pretrained_encoder.py ADDED
@@ -0,0 +1,212 @@
+ '''
+ Author: Qiguang Chen
+ Date: 2023-01-11 10:39:26
+ LastEditors: Qiguang Chen
+ LastEditTime: 2023-01-30 15:00:29
+ Description: non-pretrained encoder model
+
+ '''
+ import math
+
+ import einops
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
+
+ from common.utils import HiddenData, InputData
+ from model.encoder.base_encoder import BaseEncoder
+
+
+ class NonPretrainedEncoder(BaseEncoder):
+     """
+     Encoder structure based on bidirectional LSTM and self-attention.
+     """
+
+     def __init__(self, **config):
+         """ Init the non-pretrained encoder.
+
+         Args:
+             config (dict):
+                 embedding (dict):
+                     dropout_rate (float): dropout rate.
+                     load_embedding_name (str): None if no pretrained embedding is used, otherwise an embedding name like "glove.6B.300d.txt".
+                     embedding_matrix (Tensor, Optional): embedding matrix tensor. Enabled if load_embedding_name is not None.
+                     vocab_size (str, Optional): vocabulary size. Enabled if load_embedding_name is None.
+                 lstm (dict):
+                     output_dim (int): lstm output dim.
+                     bidirectional (bool): whether to use BiLSTM or LSTM.
+                     layer_num (int): number of layers.
+                     dropout_rate (float): dropout rate.
+                 attention (dict, Optional):
+                     dropout_rate (float): dropout rate.
+                     hidden_dim (int): attention hidden dim.
+                     output_dim (int): attention output dim.
+                 unflat_attention (dict, optional): Enabled if attention is not None.
+                     dropout_rate (float): dropout rate.
+         """
+         super(NonPretrainedEncoder, self).__init__()
+         self.config = config
+         # Embedding Initialization
+         embed_config = config["embedding"]
+         self.__embedding_dim = embed_config["embedding_dim"]
+         if embed_config.get("load_embedding_name"):
+             self.__embedding_layer = nn.Embedding.from_pretrained(embed_config["embedding_matrix"], padding_idx=0)
+         else:
+             self.__embedding_layer = nn.Embedding(
+                 embed_config["vocab_size"], self.__embedding_dim
+             )
+         self.__embedding_dropout_layer = nn.Dropout(embed_config["dropout_rate"])
+
+         # LSTM Initialization
+         lstm_config = config["lstm"]
+         self.__hidden_size = lstm_config["output_dim"]
+         self.__lstm_layer = nn.LSTM(
+             input_size=self.__embedding_dim,
+             hidden_size=lstm_config["output_dim"] // 2,
+             batch_first=True,
+             bidirectional=lstm_config["bidirectional"],
+             dropout=lstm_config["dropout_rate"],
+             num_layers=lstm_config["layer_num"]
+         )
+         if self.config.get("attention"):
+             # Attention Initialization
+             att_config = config["attention"]
+             self.__attention_dropout_layer = nn.Dropout(att_config["dropout_rate"])
+             self.__attention_layer = QKVAttention(
+                 self.__embedding_dim, self.__embedding_dim, self.__embedding_dim,
+                 att_config["hidden_dim"], att_config["output_dim"], att_config["dropout_rate"]
+             )
+             if self.config.get("unflat_attention"):
+                 unflat_att_config = config["unflat_attention"]
+                 self.__sentattention = UnflatSelfAttention(
+                     lstm_config["output_dim"] + att_config["output_dim"],
+                     unflat_att_config["dropout_rate"]
+                 )
+
+     def forward(self, inputs: InputData):
+         """ Forward process for the Non-Pretrained Encoder.
+
+         Args:
+             inputs: padded input ids and masks.
+         Returns:
+             encoded hidden vectors.
+         """
+
+         # LSTM Encoder
+         # Padded_text should be instance of LongTensor.
+         embedded_text = self.__embedding_layer(inputs.input_ids)
+         dropout_text = self.__embedding_dropout_layer(embedded_text)
+         seq_lens = inputs.attention_mask.sum(-1).detach().cpu()
+         # Pack and Pad process for input of variable length.
+         packed_text = pack_padded_sequence(dropout_text, seq_lens, batch_first=True, enforce_sorted=False)
+         lstm_hiddens, (h_last, c_last) = self.__lstm_layer(packed_text)
+         padded_hiddens, _ = pad_packed_sequence(lstm_hiddens, batch_first=True)
+
+         if self.config.get("attention"):
+             # Attention Encoder
+             dropout_text = self.__attention_dropout_layer(embedded_text)
+             attention_hiddens = self.__attention_layer(
+                 dropout_text, dropout_text, dropout_text, mask=inputs.attention_mask
+             )
+
+             # Attention + LSTM
+             hiddens = torch.cat([attention_hiddens, padded_hiddens], dim=-1)
+             hidden = HiddenData(None, hiddens)
+             if self.config.get("return_with_input"):
+                 hidden.add_input(inputs)
+             if self.config.get("return_sentence_level_hidden"):
+                 if self.config.get("unflat_attention"):
+                     sentence = self.__sentattention(hiddens, seq_lens)
+                 else:
+                     sentence = hiddens[:, 0, :]
+                 hidden.update_intent_hidden_state(sentence)
+         else:
+             sentence_hidden = None
+             if self.config.get("return_sentence_level_hidden"):
+                 # concatenate the final forward/backward hidden and cell states
+                 sentence_hidden = torch.cat((h_last[-1], h_last[-2], c_last[-1], c_last[-2]), dim=-1)
+             hidden = HiddenData(sentence_hidden, padded_hiddens)
+             if self.config.get("return_with_input"):
+                 hidden.add_input(inputs)
+
+         return hidden
+
+
+ class QKVAttention(nn.Module):
+     """
+     Attention mechanism based on Query-Key-Value architecture. In
+     particular, when query == key == value, it is self-attention.
+     """
+
+     def __init__(self, query_dim, key_dim, value_dim, hidden_dim, output_dim, dropout_rate):
+         super(QKVAttention, self).__init__()
+
+         # Record hyper-parameters.
+         self.__query_dim = query_dim
+         self.__key_dim = key_dim
+         self.__value_dim = value_dim
+         self.__hidden_dim = hidden_dim
+         self.__output_dim = output_dim
+         self.__dropout_rate = dropout_rate
+
+         # Declare network structures.
+         self.__query_layer = nn.Linear(self.__query_dim, self.__hidden_dim)
+         self.__key_layer = nn.Linear(self.__key_dim, self.__hidden_dim)
+         self.__value_layer = nn.Linear(self.__value_dim, self.__output_dim)
+         self.__dropout_layer = nn.Dropout(p=self.__dropout_rate)
+
+     def forward(self, input_query, input_key, input_value, mask=None):
+         """ The forward propagation of attention.
+
+         Here we require the first dimension of input key
+         and value to be equal.
+
+         Args:
+             input_query: query tensor, (n, d_q)
+             input_key: key tensor, (m, d_k)
+             input_value: value tensor, (m, d_v)
+
+         Returns:
+             attention based tensor, (n, d_h)
+         """
+
+         # Linear transform to fine-tune dimension.
+         linear_query = self.__query_layer(input_query)
+         linear_key = self.__key_layer(input_key)
+         linear_value = self.__value_layer(input_value)
+
+         score_tensor = torch.matmul(
+             linear_query,
+             linear_key.transpose(-2, -1)
+         ) / math.sqrt(self.__hidden_dim)
+         if mask is not None:
+             attn_mask = einops.repeat((mask == 0), "b l -> b l h", h=score_tensor.shape[-1])
+             score_tensor = score_tensor.masked_fill_(attn_mask, -float(1e20))
+         score_tensor = F.softmax(score_tensor, dim=-1)
+         forced_tensor = torch.matmul(score_tensor, linear_value)
+         forced_tensor = self.__dropout_layer(forced_tensor)
+
+         return forced_tensor
+
+
+ class UnflatSelfAttention(nn.Module):
+     """
+     Scores each element of the sequence with a linear layer and uses the normalized scores to compute a context over the sequence.
+     """
+
+     def __init__(self, d_hid, dropout=0.):
+         super().__init__()
+         self.scorer = nn.Linear(d_hid, 1)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, inp, lens):
+         batch_size, seq_len, d_feat = inp.size()
+         inp = self.dropout(inp)
+         scores = self.scorer(inp.contiguous().view(-1, d_feat)).view(batch_size, seq_len)
+         max_len = max(lens)
+         for i, l in enumerate(lens):
+             if l < max_len:
+                 scores.data[i, l:] = -np.inf
+         scores = F.softmax(scores, dim=1)
+         context = scores.unsqueeze(2).expand_as(inp).mul(inp).sum(1)
+         return context
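
The LSTM path above relies on pack_padded_sequence / pad_packed_sequence so that padding positions never enter the recurrence. A minimal round trip on toy data (shapes are illustrative):

import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

lstm = nn.LSTM(input_size=4, hidden_size=3, batch_first=True, bidirectional=True)

x = torch.randn(2, 5, 4)        # batch of 2, padded to length 5
lens = torch.tensor([5, 3])     # true lengths

packed = pack_padded_sequence(x, lens, batch_first=True, enforce_sorted=False)
out, (h_last, c_last) = lstm(packed)
padded, out_lens = pad_packed_sequence(out, batch_first=True)

print(padded.shape)                 # (2, 5, 6): output dim = 2 * hidden_size
print(padded[1, 3:].abs().sum())    # padding positions stay exactly zero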
model/encoder/pretrained_encoder.py ADDED
@@ -0,0 +1,39 @@
+ '''
+ Author: Qiguang Chen
+ Date: 2023-01-11 10:39:26
+ LastEditors: Qiguang Chen
+ LastEditTime: 2023-01-26 17:18:01
+ Description: pretrained encoder model
+
+ '''
+ from transformers import AutoModel
+
+ from common.utils import InputData, HiddenData
+ from model.encoder.base_encoder import BaseEncoder
+
+
+ class PretrainedEncoder(BaseEncoder):
+     def __init__(self, **config):
+         """ Init the pretrained encoder.
+
+         Args:
+             config (dict):
+                 encoder_name (str): pretrained model name on Hugging Face.
+         """
+         super().__init__(**config)
+         self.encoder = AutoModel.from_pretrained(config["encoder_name"])
+
+     def forward(self, inputs: InputData):
+         output = self.encoder(**inputs.get_inputs())
+         hidden = HiddenData(None, output.last_hidden_state)
+         if self.config.get("return_with_input"):
+             hidden.add_input(inputs)
+         if self.config.get("return_sentence_level_hidden"):
+             padding_side = self.config.get("padding_side")
+             if hasattr(output, "pooler_output"):
+                 hidden.update_intent_hidden_state(output.pooler_output)
+             elif padding_side is not None and padding_side == "left":
+                 hidden.update_intent_hidden_state(output.last_hidden_state[:, -1])
+             else:
+                 hidden.update_intent_hidden_state(output.last_hidden_state[:, 0])
+         return hidden
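
For models that expose a `pooler_output` (such as BERT) the sentence vector comes straight from the pooler; otherwise the encoder falls back to the last or first token of `last_hidden_state`, depending on the padding side. A quick check (downloads `bert-base-uncased`, used here purely as an example):

import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased")

inputs = tokenizer("book a flight to denver", return_tensors="pt")
with torch.no_grad():
    output = model(**inputs)

print(hasattr(output, "pooler_output"))      # True for BERT
print(output.pooler_output.shape)            # torch.Size([1, 768])
print(output.last_hidden_state[:, 0].shape)  # the [CLS] position, also (1, 768)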
model/open_slu_model.py ADDED
@@ -0,0 +1,64 @@
+ '''
+ Author: Qiguang Chen
+ Date: 2023-01-11 10:39:26
+ LastEditors: Qiguang Chen
+ LastEditTime: 2023-01-26 17:18:22
+ Description: Root Model Module
+
+ '''
+ from torch import nn
+
+ from common.utils import OutputData, InputData
+ from model.decoder.base_decoder import BaseDecoder
+ from model.encoder.base_encoder import BaseEncoder
+
+
+ class OpenSLUModel(nn.Module):
+     def __init__(self, encoder: BaseEncoder, decoder: BaseDecoder, **config):
+         """Create the model automatically.
+
+         Args:
+             encoder (BaseEncoder): encoder created by config
+             decoder (BaseDecoder): decoder created by config
+             config (dict): any other args
+         """
+         super().__init__()
+         self.encoder = encoder
+         self.decoder = decoder
+         self.config = config
+
+     def forward(self, inp: InputData) -> OutputData:
+         """ Model forward.
+
+         Args:
+             inp (InputData): input ids and other information
+
+         Returns:
+             OutputData: pred logits
+         """
+         return self.decoder(self.encoder(inp))
+
+     def decode(self, output: OutputData, target: InputData = None):
+         """ Decode output.
+
+         Args:
+             output (OutputData): pred logits data
+             target (InputData): golden data
+
+         Returns: decoded ids
+         """
+         return self.decoder.decode(output, target)
+
+     def compute_loss(self, pred: OutputData, target: InputData, compute_intent_loss=True, compute_slot_loss=True):
+         """ Compute loss.
+
+         Args:
+             pred (OutputData): pred logits data
+             target (InputData): golden data
+             compute_intent_loss (bool, optional): whether to compute intent loss. Defaults to True.
+             compute_slot_loss (bool, optional): whether to compute slot loss. Defaults to True.
+
+         Returns: loss value
+         """
+         return self.decoder.compute_loss(pred, target, compute_intent_loss=compute_intent_loss,
+                                          compute_slot_loss=compute_slot_loss)
save/stack/label.json ADDED
@@ -0,0 +1 @@
+ {"intent": ["atis_flight", "atis_airfare", "atis_airline", "atis_ground_service", "atis_quantity", "atis_city", "atis_flight#atis_airfare", "atis_abbreviation", "atis_aircraft", "atis_distance", "atis_ground_fare", "atis_capacity", "atis_flight_time", "atis_meal", "atis_aircraft#atis_flight#atis_flight_no", "atis_flight_no", "atis_restriction", "atis_airport", "atis_airline#atis_flight_no", "atis_cheapest", "atis_ground_service#atis_ground_fare"], "slot": ["O", "B-fromloc.city_name", "B-toloc.city_name", "B-round_trip", "I-round_trip", "B-cost_relative", "B-fare_amount", "I-fare_amount", "B-arrive_date.month_name", "B-arrive_date.day_number", "I-fromloc.city_name", "B-stoploc.city_name", "B-arrive_time.time_relative", "B-arrive_time.time", "I-arrive_time.time", "B-toloc.state_code", "I-toloc.city_name", "I-stoploc.city_name", "B-meal_description", "B-depart_date.month_name", "B-depart_date.day_number", "B-airline_name", "I-airline_name", "B-depart_time.period_of_day", "B-depart_date.day_name", "B-toloc.state_name", "B-depart_time.time_relative", "B-depart_time.time", "B-toloc.airport_name", "I-toloc.airport_name", "B-depart_date.date_relative", "B-or", "B-airline_code", "B-class_type", "I-class_type", "I-cost_relative", "I-depart_time.time", "B-fromloc.airport_name", "I-fromloc.airport_name", "B-city_name", "B-flight_mod", "B-meal", "B-economy", "B-fare_basis_code", "I-depart_date.day_number", "B-depart_date.today_relative", "B-flight_stop", "B-airport_code", "B-fromloc.state_name", "I-fromloc.state_name", "I-city_name", "B-connect", "B-arrive_date.day_name", "B-fromloc.state_code", "B-arrive_date.today_relative", "B-depart_date.year", "B-depart_time.start_time", "I-depart_time.start_time", "B-depart_time.end_time", "I-depart_time.end_time", "B-arrive_time.start_time", "B-arrive_time.end_time", "I-arrive_time.end_time", "I-flight_mod", "B-flight_days", "B-mod", "B-flight_number", "I-toloc.state_name", "B-meal_code", "I-meal_code", "B-airport_name", "I-airport_name", "I-flight_stop", "B-transport_type", "I-transport_type", "B-state_code", "B-aircraft_code", "B-toloc.country_name", "I-arrive_date.day_number", "B-toloc.airport_code", "B-return_date.date_relative", "I-return_date.date_relative", "B-flight_time", "I-economy", "B-fromloc.airport_code", "B-arrive_time.period_of_day", "B-depart_time.period_mod", "I-flight_time", "B-return_date.day_name", "B-arrive_date.date_relative", "B-restriction_code", "I-restriction_code", "B-arrive_time.period_mod", "I-arrive_time.period_of_day", "B-period_of_day", "B-stoploc.state_code", "I-depart_date.today_relative", "I-fare_basis_code", "I-arrive_time.start_time", "B-time", "B-today_relative", "I-today_relative", "B-state_name", "B-days_code", "I-depart_time.period_of_day", "I-arrive_time.time_relative", "B-time_relative", "I-time", "B-return_date.month_name", "B-return_date.day_number", "I-depart_time.time_relative", "B-stoploc.airport_name", "B-day_name", "B-month_name", "B-day_number", "B-return_time.period_mod", "B-return_time.period_of_day", "B-return_date.today_relative", "I-return_date.today_relative", "I-meal_description"]}
save/stack/model.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9710de3d7d5c8a34fe55ef4dc36dc8a851863d1fb3bb14871d914a4e945c96ef
+ size 5793644
save/stack/outputs.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
save/stack/tokenizer.json ADDED
@@ -0,0 +1 @@
+ {"name": "word_tokenizer", "token_map": {"[PAD]": 0, "[UNK]": 1, "i": 2, "want": 3, "to": 4, "fly": 5, "from": 6, "baltimore": 7, "dallas": 8, "round": 9, "trip": 10, "fares": 11, "philadelphia": 12, "less": 13, "than": 14, "1000": 15, "dollars": 16, "denver": 17, "pittsburgh": 18, "show": 19, "me": 20, "the": 21, "flights": 22, "arriving": 23, "on": 24, "june": 25, "fourteenth": 26, "what": 27, "are": 28, "which": 29, "depart": 30, "san": 31, "francisco": 32, "washington": 33, "via": 34, "indianapolis": 35, "and": 36, "arrive": 37, "by": 38, "9": 39, "pm": 40, "airlines": 41, "boston": 42, "dc": 43, "other": 44, "cities": 45, "i'm": 46, "looking": 47, "for": 48, "a": 49, "flight": 50, "charlotte": 51, "las": 52, "vegas": 53, "that": 54, "stops": 55, "in": 56, "st.": 57, "louis": 58, "hopefully": 59, "dinner": 60, "how": 61, "can": 62, "find": 63, "out": 64, "okay": 65, "then": 66, "i'd": 67, "like": 68, "travel": 69, "atlanta": 70, "september": 71, "fourth": 72, "all": 73, "cincinnati": 74, "us": 75, "air": 76, "diego": 77, "afternoon": 78, "what's": 79, "available": 80, "tuesday": 81, "leave": 82, "phoenix": 83, "paul": 84, "minnesota": 85, "after": 86, "noon": 87, "american": 88, "chicago": 89, "los": 90, "angeles": 91, "morning": 92, "types": 93, "of": 94, "ground": 95, "transportation": 96, "there": 97, "airport": 98, "next": 99, "two": 100, "days": 101, "nashville": 102, "jose": 103, "or": 104, "tacoma": 105, "does": 106, "continental": 107, "milwaukee": 108, "many": 109, "twa": 110, "have": 111, "business": 112, "class": 113, "first": 114, "least": 115, "expensive": 116, "one": 117, "way": 118, "fare": 119, "booking": 120, "classes": 121, "wednesday": 122, "nineteenth": 123, "july": 124, "fifth": 125, "7": 126, "please": 127, "list": 128, "departing": 129, "general": 130, "mitchell": 131, "international": 132, "time": 133, "zone": 134, "is": 135, "serves": 136, "meal": 137, "seattle": 138, "salt": 139, "lake": 140, "city": 141, "you": 142, "with": 143, "economy": 144, "leaving": 145, "miami": 146, "cleveland": 147, "give": 148, "between": 149, "their": 150, "cost": 151, "code": 152, "y": 153, "mean": 154, "could": 155, "tell": 156, "leaves": 157, "united": 158, "airline": 159, "over": 160, "departures": 161, "arrivals": 162, "earliest": 163, "latest": 164, "return": 165, "within": 166, "same": 167, "day": 168, "orlando": 169, "either": 170, "evening": 171, "thursday": 172, "originating": 173, "going": 174, "order": 175, "snack": 176, "do": 177, "at": 178, "645": 179, "am": 180, "into": 181, "atlanta's": 182, "friday": 183, "qx": 184, "would": 185, "information": 186, "but": 187, "stopover": 188, "some": 189, "oakland": 190, "monday": 191, "know": 192, "type": 193, "aircraft": 194, "used": 195, "detroit": 196, "twenty": 197, "eighth": 198, "petersburg": 199, "take": 200, "begins": 201, "lands": 202, "fort": 203, "worth": 204, "stop": 205, "tomorrow": 206, "noontime": 207, "around": 208, "need": 209, "northwest": 210, "toronto": 211, "memphis": 212, "thirtieth": 213, "nonstop": 214, "houston": 215, "august": 216, "twentieth": 217, "ewr": 218, "seventh": 219, "newark": 220, "11": 221, "lowest": 222, "delta": 223, "has": 224, "go": 225, "any": 226, "jet": 227, "mco": 228, "new": 229, "jersey": 230, "ontario": 231, "saturday": 232, "york": 233, "long": 234, "it": 235, "get": 236, "prices": 237, "an": 238, "inexpensive": 239, "breakfast": 240, "direct": 241, "what're": 242, "sunday": 243, "colorado": 244, "see": 245, "serving": 246, "before": 247, "o'clock": 248, "january": 249, "1992": 
250, "10": 251, "2": 252, "445": 253, "515": 254, "6": 255, "8": 256, "coach": 257, "only": 258, "weekdays": 259, "la": 260, "3": 261, "my": 262, "choices": 263, "early": 264, "minneapolis": 265, "cheapest": 266, "flying": 267, "possible": 268, "daily": 269, "beach": 270, "stopping": 271, "kansas": 272, "night": 273, "serve": 274, "meals": 275, "heading": 276, "kind": 277, "once": 278, "mornings": 279, "ff": 280, "arrangements": 281, "served": 282, "canadian": 283, "california": 284, "about": 285, "530": 286, "kinds": 287, "requesting": 288, "landing": 289, "distance": 290, "downtown": 291, "love": 292, "field": 293, "this": 294, "tampa": 295, "florida": 296, "5": 297, "must": 298, "be": 299, "advertises": 300, "having": 301, "more": 302, "november": 303, "eleventh": 304, "services": 305, "qo": 306, "american's": 307, "last": 308, "using": 309, "dl": 310, "217": 311, "montreal": 312, "service": 313, "ticket": 314, "should": 315, "near": 316, "also": 317, "missouri": 318, "utah": 319, "interested": 320, "those": 321, "when": 322, "north": 323, "carolina": 324, "415": 325, "200": 326, "explain": 327, "codes": 328, "sd": 329, "d": 330, "trying": 331, "plane": 332, "flies": 333, "weekday": 334, "columbus": 335, "1991": 336, "carries": 337, "smallest": 338, "number": 339, "passengers": 340, "takeoffs": 341, "landings": 342, "book": 343, "shortest": 344, "both": 345, "nationair": 346, "823": 347, "guardia": 348, "as": 349, "well": 350, "sixteenth": 351, "rental": 352, "car": 353, "rates": 354, "through": 355, "boeing": 356, "757": 357, "limousine": 358, "listing": 359, "canada": 360, "much": 361, "71": 362, "airfare": 363, "12": 364, "third": 365, "seating": 366, "capacity": 367, "arrives": 368, "bwi": 369, "ninth": 370, "late": 371, "nevada": 372, "4": 373, "price": 374, "fifteenth": 375, "eighteenth": 376, "returning": 377, "following": 378, "capacities": 379, "planes": 380, "1145": 381, "use": 382, "burbank": 383, "may": 384, "america": 385, "west": 386, "now": 387, "eastern": 388, "825": 389, "555": 390, "area": 391, "schedule": 392, "dfw": 393, "these": 394, "connecting": 395, "make": 396, "connection": 397, "lunch": 398, "f": 399, "belong": 400, "most": 401, "tickets": 402, "logan": 403, "vicinity": 404, "210": 405, "wednesdays": 406, "thursdays": 407, "yes": 408, "will": 409, "continuing": 410, "1039": 411, "southwest": 412, "times": 413, "400": 414, "week": 415, "if": 416, "813": 417, "enroute": 418, "another": 419, "twelfth": 420, "turboprop": 421, "420": 422, "today": 423, "1": 424, "we're": 425, "westchester": 426, "county": 427, "various": 428, "airplanes": 429, "uses": 430, "yn": 431, "852": 432, "transport": 433, "display": 434, "under": 435, "500": 436, "airfares": 437, "back": 438, "hours": 439, "fn": 440, "options": 441, "december": 442, "second": 443, "april": 444, "ohio": 445, "departs": 446, "2153": 447, "schedules": 448, "who": 449, "restriction": 450, "ap": 451, "57": 452, "layover": 453, "abbreviation": 454, "stands": 455, "1291": 456, "324": 457, "again": 458, "offer": 459, "dc10": 460, "currently": 461, "represented": 462, "database": 463, "arizona": 464, "1505": 465, "sixth": 466, "3724": 467, "three": 468, "including": 469, "connections": 470, "numbers": 471, "six": 472, "1100": 473, "destination": 474, "838": 475, "no": 476, "h": 477, "traveling": 478, "ap57": 479, "far": 480, "lufthansa": 481, "abbreviations": 482, "such": 483, "aa": 484, "459": 485, "where": 486, "ua": 487, "281": 488, "your": 489, "texas": 490, "1500": 491, "bound": 492, "includes": 493, "right": 
494, "airports": 495, "eight": 496, "sixteen": 497, "trips": 498, "seventeenth": 499, "thrift": 500, "delta's": 501, "departure": 502, "listed": 503, "1055": 504, "405": 505, "midnight": 506, "hi": 507, "630": 508, "question": 509, "live": 510, "stand": 511, "ten": 512, "people": 513, "during": 514, "2100": 515, "gets": 516, "just": 517, "philly": 518, "21": 519, "airplane": 520, "1765": 521, "iah": 522, "737": 523, "midwest": 524, "express": 525, "s": 526, "designate": 527, "747": 528, "650": 529, "goes": 530, "reaches": 531, "seventeen": 532, "sorry": 533, "anywhere": 534, "provided": 535, "d10": 536, "toward": 537, "preferably": 538, "rate": 539, "difference": 540, "q": 541, "b": 542, "ac": 543, "tower": 544, "tenth": 545, "hp": 546, "4400": 547, "georgia": 548, "offers": 549, "fine": 550, "201": 551, "343": 552, "october": 553, "ea": 554, "jfk": 555, "name": 556, "arrange": 557, "largest": 558, "connect": 559, "operating": 560, "sundays": 561, "720": 562, "land": 563, "final": 564, "don't": 565, "stopovers": 566, "total": 567, "friday's": 568, "755": 569, "cheap": 570, "sfo": 571, "thirty": 572, "across": 573, "continent": 574, "makes": 575, "1220": 576, "co": 577, "1209": 578, "wanted": 579, "1850": 580, "without": 581, "listings": 582, "local": 583, "wish": 584, "bring": 585, "up": 586, "home": 587, "417": 588, "approximately": 589, "actually": 590, "1200": 591, "230": 592, "819": 593, "serviced": 594, "928": 595, "reservation": 596, "limousines": 597, "taxi": 598, "fit": 599, "72s": 600, "352": 601, "1133": 602, "43": 603, "define": 604, "directly": 605, "m80": 606, "close": 607, "restrictions": 608, "430": 609, "718": 610, "hou": 611, "costs": 612, "466": 613, "march": 614, "1026": 615, "1024": 616, "different": 617, "rentals": 618, "each": 619, "arrival": 620, "say": 621, "mealtime": 622, "932": 623, "1115": 624, "1245": 625, "include": 626, "whether": 627, "offered": 628, "130": 629, "alaska": 630, "296": 631, "they": 632, "106": 633, "york's": 634, "497766": 635, "itinerary": 636, "coming": 637, "month": 638, "bur": 639, "travels": 640, "pennsylvania": 641, "usa": 642, "1288": 643, "c": 644, "names": 645, "sure": 646, "meaning": 647, "ap80": 648, "269": 649, "reservations": 650, "d9s": 651, "sunday's": 652, "f28": 653, "934": 654, "earlier": 655, "1017": 656, "date": 657, "thank": 658, "oak": 659, "atl": 660, "cp": 661, "3357": 662, "1045": 663, "limo": 664, "845": 665, "sometime": 666, "1222": 667, "i'll": 668, "tennessee": 669, "0900": 670, "hello": 671, "let": 672, "repeat": 673, "provide": 674, "still": 675, "along": 676, "operation": 677, "year": 678, "one's": 679, "great": 680, "too": 681, "nighttime": 682, "1300": 683, "saturdays": 684, "416": 685, "four": 686, "257": 687, "minimum": 688, "intercontinental": 689, "february": 690, "spend": 691, "lastest": 692, "thing": 693, "originate": 694, "describe": 695, "concerning": 696, "sa": 697, "help": 698, "1700": 699, "225": 700, "1158": 701, "equipment": 702, "let's": 703, "wednesday's": 704, "quebec": 705, "highest": 706, "starting": 707, "taking": 708, "311": 709, "1230": 710, "able": 711, "put": 712, "later": 713, "takes": 714, "amount": 715, "qw": 716, "seven": 717, "maximum": 718, "yyz": 719, "it's": 720, "80": 721, "place": 722, "equal": 723, "while": 724, "train": 725, "look": 726, "815": 727, "takeoff": 728, "plan": 729, "2134": 730, "297": 731, "323": 732, "229": 733, "329": 734, "runs": 735, "730": 736, "closest": 737, "dulles": 738, "73s": 739, "so": 740, "economic": 741, "single": 742, "supper": 743, "110": 744, 
"calling": 745, "1205": 746, "55": 747, "michigan": 748, "proper": 749, "regarding": 750, "seats": 751, "19": 752, "m": 753, "midway": 754, "besides": 755, "reverse": 756, "1993": 757, "402": 758, "level": 759, "reaching": 760, "771": 761, "straight": 762, "located": 763, "305": 764, "repeating": 765, "indiana": 766, "connects": 767, "beginning": 768, "staying": 769, "town": 770, "cars": 771, "nonstops": 772, "300": 773, "345": 774, "dinnertime": 775, "sort": 776, "route": 777, "j31": 778, "tuesdays": 779, "212": 780, "705": 781, "red": 782, "eye": 783, "laying": 784, "friends": 785, "visit": 786, "here": 787, "them": 788, "lives": 789, "rent": 790, "279": 791, "137338": 792, "transcontinental": 793, "trans": 794, "world": 795, "1030": 796, "1130": 797, "come": 798, "727": 799, "1020": 800, "505": 801, "that's": 802, "163": 803, "ls": 804, "greatest": 805, "i've": 806, "got": 807, "somebody": 808, "else": 809, "wants": 810, "charges": 811, "734": 812, "carried": 813, "thirteenth": 814, "making": 815, "733": 816, "everywhere": 817, "prefer": 818, "run": 819, "non": 820, "315": 821, "746": 822, "companies": 823, "buy": 824, "very": 825, "270": 826, "locate": 827, "hartfield": 828, "start": 829, "98": 830, "inform": 831, "oh": 832, "82": 833, "139": 834, "1600": 835, "eleven": 836, "ord": 837, "mia": 838, "qualify": 839, "doesn't": 840, "mondays": 841, "catch": 842, "priced": 843, "bna": 844, "being": 845, "working": 846, "scenario": 847, "767": 848, "1940": 849, "150": 850, "100": 851, "afternoons": 852, "provides": 853, "723": 854, "1110": 855, "symbols": 856, "grounds": 857, "nw": 858, "539": 859, "soon": 860, "thereafter": 861, "scheduled": 862, "instead": 863, "810": 864, "lester": 865, "pearson": 866, "stapleton": 867, "615": 868, "twelve": 869, "bay": 870, "sounds": 871, "o'hare": 872, "ap68": 873, "fridays": 874, "try": 875, "fifteen": 876, "nights": 877, "determine": 878, "hold": 879, "lax": 880, "seat": 881, "k": 882, "planning": 883, "discount": 884, "summer": 885, "cover": 886, "271": 887, "tonight": 888, "off": 889, "124": 890, "thanks": 891, "longest": 892, "kindly": 893, "afterwards": 894, "overnight": 895, "1083": 896, "428": 897, "anything": 898, "1059": 899}}