Upload prompt_tune_phi3.ipynb with huggingface_hub
Browse files- prompt_tune_phi3.ipynb +338 -73
prompt_tune_phi3.ipynb
CHANGED
@@ -90,7 +90,7 @@
|
|
90 |
},
|
91 |
{
|
92 |
"cell_type": "code",
|
93 |
-
"execution_count":
|
94 |
"id": "6cad1e5c-038f-4e75-8c3f-8ce0a43713a4",
|
95 |
"metadata": {},
|
96 |
"outputs": [],
|
@@ -117,23 +117,23 @@
|
|
117 |
"label_col = 'text_label'\n",
|
118 |
"max_len = 64\n",
|
119 |
"lr = 3e-2\n",
|
120 |
-
"epochs =
|
121 |
"batch_size = 8"
|
122 |
]
|
123 |
},
|
124 |
{
|
125 |
"cell_type": "code",
|
126 |
-
"execution_count":
|
127 |
"id": "6f677839-ef23-428a-bcfe-f596590804ca",
|
128 |
"metadata": {},
|
129 |
"outputs": [],
|
130 |
"source": [
|
131 |
-
"dataset = load_dataset('ought/raft', dataset_name
|
132 |
]
|
133 |
},
|
134 |
{
|
135 |
"cell_type": "code",
|
136 |
-
"execution_count":
|
137 |
"id": "c0c05613-7941-4959-ada9-49ed1093bec4",
|
138 |
"metadata": {},
|
139 |
"outputs": [
|
@@ -143,22 +143,36 @@
|
|
143 |
"['Unlabeled', 'complaint', 'no complaint']"
|
144 |
]
|
145 |
},
|
146 |
-
"execution_count":
|
147 |
"metadata": {},
|
148 |
"output_type": "execute_result"
|
149 |
}
|
150 |
],
|
151 |
"source": [
|
152 |
-
"dataset.features['Label'].names\n",
|
153 |
"#>>> ['Unlabeled', 'complaint', 'no complaint']"
|
154 |
]
|
155 |
},
|
156 |
{
|
157 |
"cell_type": "code",
|
158 |
-
"execution_count":
|
159 |
"id": "14e2bc8b-b4e3-49c9-ae2b-5946e412caa5",
|
160 |
"metadata": {},
|
161 |
"outputs": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
162 |
{
|
163 |
"data": {
|
164 |
"text/plain": [
|
@@ -168,26 +182,26 @@
|
|
168 |
" 'text_label': 'no complaint'}"
|
169 |
]
|
170 |
},
|
171 |
-
"execution_count":
|
172 |
"metadata": {},
|
173 |
"output_type": "execute_result"
|
174 |
}
|
175 |
],
|
176 |
"source": [
|
177 |
"# Create lambda function\n",
|
178 |
-
"classes = [k.replace('_', ' ') for k in dataset.features['Label'].names]\n",
|
179 |
"dataset = dataset.map(\n",
|
180 |
" lambda x: {'text_label': [classes[label] for label in x['Label']]},\n",
|
181 |
" batched=True,\n",
|
182 |
" num_proc=10,\n",
|
183 |
")\n",
|
184 |
"\n",
|
185 |
-
"dataset[0]"
|
186 |
]
|
187 |
},
|
188 |
{
|
189 |
"cell_type": "code",
|
190 |
-
"execution_count":
|
191 |
"id": "19f0865d-e490-4c9f-a5f4-e781ed270f47",
|
192 |
"metadata": {},
|
193 |
"outputs": [
|
@@ -204,7 +218,7 @@
|
|
204 |
"[1, 853, 29880, 24025]"
|
205 |
]
|
206 |
},
|
207 |
-
"execution_count":
|
208 |
"metadata": {},
|
209 |
"output_type": "execute_result"
|
210 |
}
|
@@ -236,7 +250,7 @@
|
|
236 |
},
|
237 |
{
|
238 |
"cell_type": "code",
|
239 |
-
"execution_count":
|
240 |
"id": "03f05467-dce3-4e42-ab3b-c39ba620e164",
|
241 |
"metadata": {},
|
242 |
"outputs": [],
|
@@ -261,19 +275,30 @@
|
|
261 |
" #>>> -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000\n",
|
262 |
" # Pad the beginning of the sequence with n -100s (ignore tokens)\n",
|
263 |
" model_inputs[\"attention_mask\"][i] = [1] * len(model_inputs[\"input_ids\"][i])\n",
|
264 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
265 |
]
|
266 |
},
|
267 |
{
|
268 |
"cell_type": "code",
|
269 |
-
"execution_count":
|
270 |
"id": "72ddca5f-7bce-4342-9414-9dd9d41d9dec",
|
271 |
"metadata": {},
|
272 |
"outputs": [
|
273 |
{
|
274 |
"data": {
|
275 |
"application/vnd.jupyter.widget-view+json": {
|
276 |
-
"model_id": "
|
277 |
"version_major": 2,
|
278 |
"version_minor": 0
|
279 |
},
|
@@ -285,60 +310,18 @@
|
|
285 |
"output_type": "display_data"
|
286 |
},
|
287 |
{
|
288 |
-
"
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
|
294 |
-
"
|
295 |
-
|
296 |
-
|
297 |
-
|
298 |
-
|
299 |
-
|
300 |
-
"{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [1, 694, 15313, 524], [1, 694, 15313, 524], [1, 694, 15313, 524]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]}{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [1, 694, 15313, 524]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1, 1]]}\n",
|
301 |
-
"\n",
|
302 |
-
"{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [1, 694, 15313, 524], [1, 694, 15313, 524]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]}{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1, 1]]}{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [1, 694, 15313, 524], [1, 694, 15313, 524], [1, 694, 15313, 524], [1, 15313, 524]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1]]}\n",
|
303 |
-
"\n",
|
304 |
-
"\n",
|
305 |
-
"{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [1, 694, 15313, 524]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]}{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [1, 694, 15313, 524], [1, 694, 15313, 524], [1, 15313, 524]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1]]}\n",
|
306 |
-
"\n",
|
307 |
-
"{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]}{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [1, 694, 15313, 524], [1, 15313, 524]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1]]}{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [1, 15313, 524], [1, 694, 15313, 524], [1, 15313, 524], [1, 694, 15313, 524]], 'attention_mask': [[1, 1, 1], [1, 1, 1], [1, 1, 1, 1], [1, 1, 1], [1, 1, 1, 1]]}\n",
|
308 |
-
"\n",
|
309 |
-
"\n",
|
310 |
-
"{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [1, 15313, 524]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1]]}{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [1, 694, 15313, 524], [1, 15313, 524], [1, 694, 15313, 524]], 'attention_mask': [[1, 1, 1], [1, 1, 1], [1, 1, 1, 1], [1, 1, 1], [1, 1, 1, 1]]}\n",
|
311 |
-
"\n",
|
312 |
-
"{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1]]}{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [1, 15313, 524], [1, 694, 15313, 524]], 'attention_mask': [[1, 1, 1], [1, 1, 1], [1, 1, 1, 1], [1, 1, 1], [1, 1, 1, 1]]}\n",
|
313 |
-
"\n",
|
314 |
-
"{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [1, 694, 15313, 524], [1, 694, 15313, 524], [1, 15313, 524], [1, 694, 15313, 524]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1], [1, 1, 1, 1]]}{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [1, 694, 15313, 524]], 'attention_mask': [[1, 1, 1], [1, 1, 1], [1, 1, 1, 1], [1, 1, 1], [1, 1, 1, 1]]}\n",
|
315 |
-
"\n",
|
316 |
-
"{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [1, 694, 15313, 524], [1, 15313, 524], [1, 694, 15313, 524]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1], [1, 1, 1, 1]]}{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000]], 'attention_mask': [[1, 1, 1], [1, 1, 1], [1, 1, 1, 1], [1, 1, 1], [1, 1, 1, 1]]}\n",
|
317 |
-
"{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [1, 15313, 524], [1, 15313, 524], [1, 15313, 524], [1, 694, 15313, 524]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1, 1]]}\n",
|
318 |
-
"{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [1, 15313, 524], [1, 694, 15313, 524]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1], [1, 1, 1, 1]]}\n",
|
319 |
-
"\n",
|
320 |
-
"{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [1, 15313, 524], [1, 15313, 524], [1, 694, 15313, 524]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1, 1]]}{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [1, 694, 15313, 524]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1], [1, 1, 1, 1]]}\n",
|
321 |
-
"\n",
|
322 |
-
"{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [1, 15313, 524], [1, 694, 15313, 524]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1, 1]]}{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1], [1, 1, 1, 1]]}{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [1, 15313, 524], [1, 694, 15313, 524], [1, 694, 15313, 524], [1, 15313, 524]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1]]}\n",
|
323 |
-
"\n",
|
324 |
-
"{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [1, 694, 15313, 524]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1, 1]]}\n",
|
325 |
-
"\n",
|
326 |
-
"{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [1, 694, 15313, 524], [1, 694, 15313, 524], [1, 15313, 524]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1]]}\n",
|
327 |
-
"{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1, 1]]}{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [1, 694, 15313, 524], [1, 694, 15313, 524], [1, 694, 15313, 524], [1, 694, 15313, 524]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]}\n",
|
328 |
-
"{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [1, 694, 15313, 524], [1, 15313, 524]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1]]}\n",
|
329 |
-
"\n",
|
330 |
-
"{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [1, 694, 15313, 524], [1, 694, 15313, 524], [1, 694, 15313, 524]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]}{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [1, 15313, 524]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1]]}\n",
|
331 |
-
"\n",
|
332 |
-
"{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [1, 694, 15313, 524], [1, 694, 15313, 524]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]}{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1]]}\n",
|
333 |
-
"\n",
|
334 |
-
"{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [1, 694, 15313, 524]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]}\n",
|
335 |
-
"{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]}{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [1, 694, 15313, 524], [1, 15313, 524], [1, 15313, 524], [1, 15313, 524]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1]]}\n",
|
336 |
-
"\n",
|
337 |
-
"{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [1, 15313, 524], [1, 15313, 524], [1, 15313, 524]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1]]}\n",
|
338 |
-
"{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [1, 15313, 524], [1, 15313, 524]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1]]}\n",
|
339 |
-
"{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [1, 15313, 524]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1]]}\n",
|
340 |
-
"{'input_ids': [[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000], [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1, 15313, 524, 32000]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1]]}\n"
|
341 |
-
]
|
342 |
}
|
343 |
],
|
344 |
"source": [
|
@@ -346,7 +329,7 @@
|
|
346 |
" preproc,\n",
|
347 |
" batched=True, # uses default batch size\n",
|
348 |
" num_proc=10,\n",
|
349 |
-
" remove_columns=dataset.column_names, # All columns from the original dataset will be removed in the new dataset\n",
|
350 |
" load_from_cache_file=False,\n",
|
351 |
" desc=\"Preprocessing dataset\"\n",
|
352 |
")"
|
@@ -354,10 +337,292 @@
|
|
354 |
},
|
355 |
{
|
356 |
"cell_type": "code",
|
357 |
-
"execution_count":
|
358 |
"id": "40cea6bc-e898-4d86-a6bf-5afc3a647e07",
|
359 |
"metadata": {},
|
360 |
"outputs": [],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
361 |
"source": []
|
362 |
}
|
363 |
],
|
|
|
90 |
},
|
91 |
{
|
92 |
"cell_type": "code",
|
93 |
+
"execution_count": 54,
|
94 |
"id": "6cad1e5c-038f-4e75-8c3f-8ce0a43713a4",
|
95 |
"metadata": {},
|
96 |
"outputs": [],
|
|
|
117 |
"label_col = 'text_label'\n",
|
118 |
"max_len = 64\n",
|
119 |
"lr = 3e-2\n",
|
120 |
+
"epochs = 5\n",
|
121 |
"batch_size = 8"
|
122 |
]
|
123 |
},
|
124 |
{
|
125 |
"cell_type": "code",
|
126 |
+
"execution_count": 28,
|
127 |
"id": "6f677839-ef23-428a-bcfe-f596590804ca",
|
128 |
"metadata": {},
|
129 |
"outputs": [],
|
130 |
"source": [
|
131 |
+
"dataset = load_dataset('ought/raft', dataset_name)"
|
132 |
]
|
133 |
},
|
134 |
{
|
135 |
"cell_type": "code",
|
136 |
+
"execution_count": 30,
|
137 |
"id": "c0c05613-7941-4959-ada9-49ed1093bec4",
|
138 |
"metadata": {},
|
139 |
"outputs": [
|
|
|
143 |
"['Unlabeled', 'complaint', 'no complaint']"
|
144 |
]
|
145 |
},
|
146 |
+
"execution_count": 30,
|
147 |
"metadata": {},
|
148 |
"output_type": "execute_result"
|
149 |
}
|
150 |
],
|
151 |
"source": [
|
152 |
+
"dataset['train'].features['Label'].names\n",
|
153 |
"#>>> ['Unlabeled', 'complaint', 'no complaint']"
|
154 |
]
|
155 |
},
|
156 |
{
|
157 |
"cell_type": "code",
|
158 |
+
"execution_count": 32,
|
159 |
"id": "14e2bc8b-b4e3-49c9-ae2b-5946e412caa5",
|
160 |
"metadata": {},
|
161 |
"outputs": [
|
162 |
+
{
|
163 |
+
"data": {
|
164 |
+
"application/vnd.jupyter.widget-view+json": {
|
165 |
+
"model_id": "11da1eb81527428a95c41816f5bf459f",
|
166 |
+
"version_major": 2,
|
167 |
+
"version_minor": 0
|
168 |
+
},
|
169 |
+
"text/plain": [
|
170 |
+
"Map (num_proc=10): 0%| | 0/3399 [00:00<?, ? examples/s]"
|
171 |
+
]
|
172 |
+
},
|
173 |
+
"metadata": {},
|
174 |
+
"output_type": "display_data"
|
175 |
+
},
|
176 |
{
|
177 |
"data": {
|
178 |
"text/plain": [
|
|
|
182 |
" 'text_label': 'no complaint'}"
|
183 |
]
|
184 |
},
|
185 |
+
"execution_count": 32,
|
186 |
"metadata": {},
|
187 |
"output_type": "execute_result"
|
188 |
}
|
189 |
],
|
190 |
"source": [
|
191 |
"# Create lambda function\n",
|
192 |
+
"classes = [k.replace('_', ' ') for k in dataset['train'].features['Label'].names]\n",
|
193 |
"dataset = dataset.map(\n",
|
194 |
" lambda x: {'text_label': [classes[label] for label in x['Label']]},\n",
|
195 |
" batched=True,\n",
|
196 |
" num_proc=10,\n",
|
197 |
")\n",
|
198 |
"\n",
|
199 |
+
"dataset['train'][0]"
|
200 |
]
|
201 |
},
|
202 |
{
|
203 |
"cell_type": "code",
|
204 |
+
"execution_count": 41,
|
205 |
"id": "19f0865d-e490-4c9f-a5f4-e781ed270f47",
|
206 |
"metadata": {},
|
207 |
"outputs": [
|
|
|
218 |
"[1, 853, 29880, 24025]"
|
219 |
]
|
220 |
},
|
221 |
+
"execution_count": 41,
|
222 |
"metadata": {},
|
223 |
"output_type": "execute_result"
|
224 |
}
|
|
|
250 |
},
|
251 |
{
|
252 |
"cell_type": "code",
|
253 |
+
"execution_count": 26,
|
254 |
"id": "03f05467-dce3-4e42-ab3b-c39ba620e164",
|
255 |
"metadata": {},
|
256 |
"outputs": [],
|
|
|
275 |
" #>>> -100, -100, -100, -100, -100, -100, -100, -100, 1, 694, 15313, 524, 32000\n",
|
276 |
" # Pad the beginning of the sequence with n -100s (ignore tokens)\n",
|
277 |
" model_inputs[\"attention_mask\"][i] = [1] * len(model_inputs[\"input_ids\"][i])\n",
|
278 |
+
"\n",
|
279 |
+
" for i in range(batch_size):\n",
|
280 |
+
" sample_input_ids = model_inputs[\"input_ids\"][i]\n",
|
281 |
+
" label_input_ids = labels[\"input_ids\"][i]\n",
|
282 |
+
" model_inputs[\"input_ids\"][i] = [tokenizer.pad_token_id] * (target_max_len - len(sample_input_ids)) + sample_input_ids\n",
|
283 |
+
" model_inputs[\"attention_mask\"][i] = [0] * (target_max_len - len(sample_input_ids)) + model_inputs[\"attention_mask\"][i]\n",
|
284 |
+
" labels[\"input_ids\"][i] = [-100] * (target_max_len - len(sample_input_ids)) + label_input_ids\n",
|
285 |
+
" model_inputs[\"input_ids\"][i] = torch.tensor(model_inputs[\"input_ids\"][i][:target_max_len])\n",
|
286 |
+
" model_inputs[\"attention_mask\"][i] = torch.tensor(model_inputs[\"attention_mask\"][i][:target_max_len])\n",
|
287 |
+
" labels[\"input_ids\"][i] = torch.tensor(labels[\"input_ids\"][i][:target_max_len])\n",
|
288 |
+
" model_inputs[\"labels\"] = labels[\"input_ids\"]\n",
|
289 |
+
" return model_inputs"
|
290 |
]
|
291 |
},
|
292 |
{
|
293 |
"cell_type": "code",
|
294 |
+
"execution_count": 33,
|
295 |
"id": "72ddca5f-7bce-4342-9414-9dd9d41d9dec",
|
296 |
"metadata": {},
|
297 |
"outputs": [
|
298 |
{
|
299 |
"data": {
|
300 |
"application/vnd.jupyter.widget-view+json": {
|
301 |
+
"model_id": "05958c1cf67d413b9085622ace0cb799",
|
302 |
"version_major": 2,
|
303 |
"version_minor": 0
|
304 |
},
|
|
|
310 |
"output_type": "display_data"
|
311 |
},
|
312 |
{
|
313 |
+
"data": {
|
314 |
+
"application/vnd.jupyter.widget-view+json": {
|
315 |
+
"model_id": "05e7c3181c20464492f2ec4ced190fd4",
|
316 |
+
"version_major": 2,
|
317 |
+
"version_minor": 0
|
318 |
+
},
|
319 |
+
"text/plain": [
|
320 |
+
"Preprocessing dataset (num_proc=10): 0%| | 0/3399 [00:00<?, ? examples/s]"
|
321 |
+
]
|
322 |
+
},
|
323 |
+
"metadata": {},
|
324 |
+
"output_type": "display_data"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
325 |
}
|
326 |
],
|
327 |
"source": [
|
|
|
329 |
" preproc,\n",
|
330 |
" batched=True, # uses default batch size\n",
|
331 |
" num_proc=10,\n",
|
332 |
+
" remove_columns=dataset[\"train\"].column_names, # All columns from the original dataset will be removed in the new dataset\n",
|
333 |
" load_from_cache_file=False,\n",
|
334 |
" desc=\"Preprocessing dataset\"\n",
|
335 |
")"
|
|
|
337 |
},
|
338 |
{
|
339 |
"cell_type": "code",
|
340 |
+
"execution_count": 43,
|
341 |
"id": "40cea6bc-e898-4d86-a6bf-5afc3a647e07",
|
342 |
"metadata": {},
|
343 |
"outputs": [],
|
344 |
+
"source": [
|
345 |
+
"train_dataset = processed_datasets[\"train\"]\n",
|
346 |
+
"eval_dataset = processed_datasets[\"test\"]\n",
|
347 |
+
"\n",
|
348 |
+
"train_dataloader = DataLoader(train_dataset,\n",
|
349 |
+
" shuffle=True, # shuffling is unneccasary since we are not training\n",
|
350 |
+
" collate_fn=default_data_collator,\n",
|
351 |
+
" batch_size=batch_size,\n",
|
352 |
+
" pin_memory=True # pin memory when using a GPU, makes loading data faster\n",
|
353 |
+
" )\n",
|
354 |
+
"\n",
|
355 |
+
"eval_dataloader = DataLoader(eval_dataset,\n",
|
356 |
+
" shuffle=False,\n",
|
357 |
+
" collate_fn=default_data_collator,\n",
|
358 |
+
" batch_size=batch_size,\n",
|
359 |
+
" pin_memory=True\n",
|
360 |
+
" )"
|
361 |
+
]
|
362 |
+
},
|
363 |
+
{
|
364 |
+
"cell_type": "code",
|
365 |
+
"execution_count": 51,
|
366 |
+
"id": "a4c529e4-d8ae-42b2-a658-f76d183bb264",
|
367 |
+
"metadata": {},
|
368 |
+
"outputs": [
|
369 |
+
{
|
370 |
+
"data": {
|
371 |
+
"application/vnd.jupyter.widget-view+json": {
|
372 |
+
"model_id": "58f2ef57b8ea49c2a26d4361ce4a5983",
|
373 |
+
"version_major": 2,
|
374 |
+
"version_minor": 0
|
375 |
+
},
|
376 |
+
"text/plain": [
|
377 |
+
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
|
378 |
+
]
|
379 |
+
},
|
380 |
+
"metadata": {},
|
381 |
+
"output_type": "display_data"
|
382 |
+
},
|
383 |
+
{
|
384 |
+
"name": "stderr",
|
385 |
+
"output_type": "stream",
|
386 |
+
"text": [
|
387 |
+
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
|
388 |
+
]
|
389 |
+
},
|
390 |
+
{
|
391 |
+
"name": "stdout",
|
392 |
+
"output_type": "stream",
|
393 |
+
"text": [
|
394 |
+
"trainable params: 24,576 || all params: 3,821,104,128 || trainable%: 0.0006\n",
|
395 |
+
"None\n"
|
396 |
+
]
|
397 |
+
}
|
398 |
+
],
|
399 |
+
"source": [
|
400 |
+
"model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation=\"flash_attention_2\", torch_dtype=torch.bfloat16)\n",
|
401 |
+
"model = get_peft_model(model, peft_conf)\n",
|
402 |
+
"\n",
|
403 |
+
"# the rest of the model is frozen\n",
|
404 |
+
"print(model.print_trainable_parameters())"
|
405 |
+
]
|
406 |
+
},
|
407 |
+
{
|
408 |
+
"cell_type": "code",
|
409 |
+
"execution_count": 52,
|
410 |
+
"id": "3289e4e3-9b9a-4256-921b-5df21d18344e",
|
411 |
+
"metadata": {},
|
412 |
+
"outputs": [],
|
413 |
+
"source": [
|
414 |
+
"optimizer = torch.optim.AdamW(model.parameters(), lr=lr)\n",
|
415 |
+
"lr_scheduler = get_linear_schedule_with_warmup(\n",
|
416 |
+
" optimizer=optimizer,\n",
|
417 |
+
" num_warmup_steps=0,\n",
|
418 |
+
" num_training_steps=(len(train_dataloader) * epochs),\n",
|
419 |
+
")"
|
420 |
+
]
|
421 |
+
},
|
422 |
+
{
|
423 |
+
"cell_type": "code",
|
424 |
+
"execution_count": 55,
|
425 |
+
"id": "e7939d75-c6b9-47a8-b1a3-88f7c33ff121",
|
426 |
+
"metadata": {},
|
427 |
+
"outputs": [
|
428 |
+
{
|
429 |
+
"name": "stderr",
|
430 |
+
"output_type": "stream",
|
431 |
+
"text": [
|
432 |
+
"100%|ββββββββββ| 7/7 [00:00<00:00, 10.97it/s]\n",
|
433 |
+
"100%|ββββββββββ| 425/425 [00:13<00:00, 31.61it/s]\n"
|
434 |
+
]
|
435 |
+
},
|
436 |
+
{
|
437 |
+
"name": "stdout",
|
438 |
+
"output_type": "stream",
|
439 |
+
"text": [
|
440 |
+
"epoch=0: train_ppl=tensor(nan, device='cuda:0') train_epoch_loss=tensor(nan, device='cuda:0') eval_ppl=tensor(nan, device='cuda:0') eval_epoch_loss=tensor(nan, device='cuda:0')\n"
|
441 |
+
]
|
442 |
+
},
|
443 |
+
{
|
444 |
+
"name": "stderr",
|
445 |
+
"output_type": "stream",
|
446 |
+
"text": [
|
447 |
+
"100%|ββββββββββ| 7/7 [00:00<00:00, 12.02it/s]\n",
|
448 |
+
"100%|ββββββββββ| 425/425 [00:13<00:00, 31.35it/s]\n"
|
449 |
+
]
|
450 |
+
},
|
451 |
+
{
|
452 |
+
"name": "stdout",
|
453 |
+
"output_type": "stream",
|
454 |
+
"text": [
|
455 |
+
"epoch=1: train_ppl=tensor(nan, device='cuda:0') train_epoch_loss=tensor(nan, device='cuda:0') eval_ppl=tensor(nan, device='cuda:0') eval_epoch_loss=tensor(nan, device='cuda:0')\n"
|
456 |
+
]
|
457 |
+
},
|
458 |
+
{
|
459 |
+
"name": "stderr",
|
460 |
+
"output_type": "stream",
|
461 |
+
"text": [
|
462 |
+
"100%|ββββββββββ| 7/7 [00:00<00:00, 12.70it/s]\n",
|
463 |
+
"100%|ββββββββββ| 425/425 [00:13<00:00, 31.66it/s]\n"
|
464 |
+
]
|
465 |
+
},
|
466 |
+
{
|
467 |
+
"name": "stdout",
|
468 |
+
"output_type": "stream",
|
469 |
+
"text": [
|
470 |
+
"epoch=2: train_ppl=tensor(nan, device='cuda:0') train_epoch_loss=tensor(nan, device='cuda:0') eval_ppl=tensor(nan, device='cuda:0') eval_epoch_loss=tensor(nan, device='cuda:0')\n"
|
471 |
+
]
|
472 |
+
},
|
473 |
+
{
|
474 |
+
"name": "stderr",
|
475 |
+
"output_type": "stream",
|
476 |
+
"text": [
|
477 |
+
"100%|ββββββββββ| 7/7 [00:00<00:00, 11.85it/s]\n",
|
478 |
+
"100%|ββββββββββ| 425/425 [00:13<00:00, 32.45it/s]\n"
|
479 |
+
]
|
480 |
+
},
|
481 |
+
{
|
482 |
+
"name": "stdout",
|
483 |
+
"output_type": "stream",
|
484 |
+
"text": [
|
485 |
+
"epoch=3: train_ppl=tensor(nan, device='cuda:0') train_epoch_loss=tensor(nan, device='cuda:0') eval_ppl=tensor(nan, device='cuda:0') eval_epoch_loss=tensor(nan, device='cuda:0')\n"
|
486 |
+
]
|
487 |
+
},
|
488 |
+
{
|
489 |
+
"name": "stderr",
|
490 |
+
"output_type": "stream",
|
491 |
+
"text": [
|
492 |
+
"100%|ββββββββββ| 7/7 [00:00<00:00, 12.53it/s]\n",
|
493 |
+
"100%|ββββββββββ| 425/425 [00:13<00:00, 32.38it/s]"
|
494 |
+
]
|
495 |
+
},
|
496 |
+
{
|
497 |
+
"name": "stdout",
|
498 |
+
"output_type": "stream",
|
499 |
+
"text": [
|
500 |
+
"epoch=4: train_ppl=tensor(nan, device='cuda:0') train_epoch_loss=tensor(nan, device='cuda:0') eval_ppl=tensor(nan, device='cuda:0') eval_epoch_loss=tensor(nan, device='cuda:0')\n"
|
501 |
+
]
|
502 |
+
},
|
503 |
+
{
|
504 |
+
"name": "stderr",
|
505 |
+
"output_type": "stream",
|
506 |
+
"text": [
|
507 |
+
"\n"
|
508 |
+
]
|
509 |
+
}
|
510 |
+
],
|
511 |
+
"source": [
|
512 |
+
"model = model.to(device)\n",
|
513 |
+
"\n",
|
514 |
+
"for epoch in range(epochs):\n",
|
515 |
+
" model.train()\n",
|
516 |
+
" total_loss = 0\n",
|
517 |
+
" for step, batch in enumerate(tqdm(train_dataloader)):\n",
|
518 |
+
" batch = {k: v.to(device) for k, v in batch.items()}\n",
|
519 |
+
" outputs = model(**batch)\n",
|
520 |
+
" loss = outputs.loss\n",
|
521 |
+
" total_loss += loss.detach().float()\n",
|
522 |
+
" loss.backward()\n",
|
523 |
+
" optimizer.step()\n",
|
524 |
+
" lr_scheduler.step()\n",
|
525 |
+
" optimizer.zero_grad()\n",
|
526 |
+
"\n",
|
527 |
+
" model.eval()\n",
|
528 |
+
" eval_loss = 0\n",
|
529 |
+
" eval_preds = []\n",
|
530 |
+
" for step, batch in enumerate(tqdm(eval_dataloader)):\n",
|
531 |
+
" batch = {k: v.to(device) for k, v in batch.items()}\n",
|
532 |
+
" with torch.no_grad():\n",
|
533 |
+
" outputs = model(**batch)\n",
|
534 |
+
" loss = outputs.loss\n",
|
535 |
+
" eval_loss += loss.detach().float()\n",
|
536 |
+
" eval_preds.extend(\n",
|
537 |
+
" tokenizer.batch_decode(torch.argmax(outputs.logits, -1).detach().cpu().numpy(), skip_special_tokens=True)\n",
|
538 |
+
" )\n",
|
539 |
+
"\n",
|
540 |
+
" eval_epoch_loss = eval_loss / len(eval_dataloader)\n",
|
541 |
+
" eval_ppl = torch.exp(eval_epoch_loss)\n",
|
542 |
+
" train_epoch_loss = total_loss / len(train_dataloader)\n",
|
543 |
+
" train_ppl = torch.exp(train_epoch_loss)\n",
|
544 |
+
" print(f\"{epoch=}: {train_ppl=} {train_epoch_loss=} {eval_ppl=} {eval_epoch_loss=}\")"
|
545 |
+
]
|
546 |
+
},
|
547 |
+
{
|
548 |
+
"cell_type": "code",
|
549 |
+
"execution_count": 59,
|
550 |
+
"id": "806d36f8-499e-4af8-b717-68e5d849866d",
|
551 |
+
"metadata": {},
|
552 |
+
"outputs": [],
|
553 |
+
"source": [
|
554 |
+
"model.save_pretrained('model')"
|
555 |
+
]
|
556 |
+
},
|
557 |
+
{
|
558 |
+
"cell_type": "code",
|
559 |
+
"execution_count": 1,
|
560 |
+
"id": "13db780a-fe20-4b23-b6cb-17118f7b695e",
|
561 |
+
"metadata": {},
|
562 |
+
"outputs": [
|
563 |
+
{
|
564 |
+
"data": {
|
565 |
+
"application/vnd.jupyter.widget-view+json": {
|
566 |
+
"model_id": "d8f94426025f4ad89847ac7e983cec42",
|
567 |
+
"version_major": 2,
|
568 |
+
"version_minor": 0
|
569 |
+
},
|
570 |
+
"text/plain": [
|
571 |
+
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
|
572 |
+
]
|
573 |
+
},
|
574 |
+
"metadata": {},
|
575 |
+
"output_type": "display_data"
|
576 |
+
},
|
577 |
+
{
|
578 |
+
"name": "stderr",
|
579 |
+
"output_type": "stream",
|
580 |
+
"text": [
|
581 |
+
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
|
582 |
+
]
|
583 |
+
}
|
584 |
+
],
|
585 |
+
"source": [
|
586 |
+
"from transformers import pipeline\n",
|
587 |
+
"device = 'cuda'\n",
|
588 |
+
"pipe = pipeline(model='model', device=device, max_length=100)"
|
589 |
+
]
|
590 |
+
},
|
591 |
+
{
|
592 |
+
"cell_type": "code",
|
593 |
+
"execution_count": 2,
|
594 |
+
"id": "26438301-3601-44f4-bbe4-3c573a1c28be",
|
595 |
+
"metadata": {},
|
596 |
+
"outputs": [
|
597 |
+
{
|
598 |
+
"name": "stderr",
|
599 |
+
"output_type": "stream",
|
600 |
+
"text": [
|
601 |
+
"Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.\n",
|
602 |
+
"You are not running the flash-attention implementation, expect numerical differences.\n"
|
603 |
+
]
|
604 |
+
},
|
605 |
+
{
|
606 |
+
"data": {
|
607 |
+
"text/plain": [
|
608 |
+
"[{'generated_text': \"@nationalgridus I have no water and the bill is current and paid. Can you do something about this?\\n\\n### response\\nI understand your situation and I'm here to help. First, it's important to clarify that as an AI developed by Microsoft, I don't have the authority to directly intervene with your utility bills or the National Grid. However, I can guide you through the steps you should take to address this issue.\\n\\n1\"}]"
|
609 |
+
]
|
610 |
+
},
|
611 |
+
"execution_count": 2,
|
612 |
+
"metadata": {},
|
613 |
+
"output_type": "execute_result"
|
614 |
+
}
|
615 |
+
],
|
616 |
+
"source": [
|
617 |
+
"pipe(\"@nationalgridus I have no water and the bill is current and paid. Can you do something about this?\")"
|
618 |
+
]
|
619 |
+
},
|
620 |
+
{
|
621 |
+
"cell_type": "code",
|
622 |
+
"execution_count": null,
|
623 |
+
"id": "f83e960d-ab80-406e-9ba9-e9533fe9d033",
|
624 |
+
"metadata": {},
|
625 |
+
"outputs": [],
|
626 |
"source": []
|
627 |
}
|
628 |
],
|