surdan committed on
Commit
055be6c
1 Parent(s): 13f3280

Upload Train_model.ipynb

Files changed (1)
  1. Train_model.ipynb +327 -0
Train_model.ipynb ADDED
@@ -0,0 +1,327 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "3ca08817",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# !pip install seqeval"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "c5958200",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# import torch\n",
+    "# torch.cuda.is_available(), torch.cuda.device_count()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "590c3f48",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import warnings\n",
+    "warnings.filterwarnings('ignore')\n",
+    "\n",
+    "import pickle\n",
+    "import numpy as np\n",
+    "import transformers\n",
+    "from transformers import Trainer\n",
+    "from datasets import load_metric\n",
+    "from datasets import load_dataset\n",
+    "from transformers import AutoTokenizer\n",
+    "from transformers import TrainingArguments\n",
+    "from transformers import AutoModelForTokenClassification\n",
+    "from transformers import DataCollatorForTokenClassification"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "44d7c35c",
+   "metadata": {},
+   "source": [
+    "## Helper functions"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "5c9e36d9",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def align_labels_with_tokens(labels, word_ids):\n",
+    "    return [-100 if i is None else labels[i] for i in word_ids]\n",
+    "\n",
+    "def tokenize_and_align_labels(examples):\n",
+    "    tokenized_inputs = tokenizer(\n",
+    "        examples[\"sequences\"], truncation=True, is_split_into_words=True\n",
+    "    )\n",
+    "    all_labels = examples[\"ids\"]\n",
+    "    new_labels = []\n",
+    "    for i, labels in enumerate(all_labels):\n",
+    "        word_ids = tokenized_inputs.word_ids(i)\n",
+    "        new_labels.append(align_labels_with_tokens(labels, word_ids))\n",
+    "\n",
+    "    tokenized_inputs[\"labels\"] = new_labels\n",
+    "    return tokenized_inputs\n",
+    "\n",
+    "def compute_metrics(eval_preds):\n",
+    "    logits, labels = eval_preds\n",
+    "    predictions = np.argmax(logits, axis=-1)\n",
+    "\n",
+    "    # Remove ignored index (special tokens) and convert to labels\n",
+    "    true_labels = [[label_names[l] for l in label if l != -100] for label in labels]\n",
+    "    true_predictions = [\n",
+    "        [label_names[p] for (p, l) in zip(prediction, label) if l != -100]\n",
+    "        for prediction, label in zip(predictions, labels)\n",
+    "    ]\n",
+    "    all_metrics = metric.compute(predictions=true_predictions, references=true_labels)\n",
+    "    return {\n",
+    "        \"precision\": all_metrics[\"overall_precision\"],\n",
+    "        \"recall\": all_metrics[\"overall_recall\"],\n",
+    "        \"f1\": all_metrics[\"overall_f1\"],\n",
+    "        \"accuracy\": all_metrics[\"overall_accuracy\"],\n",
+    "    }"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "8760e709",
+   "metadata": {},
+   "source": [
+    "## Load Data"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "e8c723f7",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "raw_datasets = load_dataset(\"surdan/nerel_short\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "e540a898",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "raw_datasets"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "5a4947d1",
+   "metadata": {},
+   "source": [
+    "## Preprocess data"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "8829557e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model_checkpoint = \"cointegrated/LaBSE-en-ru\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b6c13ad1",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "ea2c1a9e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "tokenized_datasets = raw_datasets.map(\n",
+    "    tokenize_and_align_labels,\n",
+    "    batched=True,\n",
+    "    remove_columns=raw_datasets[\"train\"].column_names,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "e9b5b9b1",
+   "metadata": {},
+   "source": [
+    "## Init Training pipeline"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b24d86e3",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "with open('id_to_label_map.pickle', 'rb') as f:\n",
+    "    map_id_to_label = pickle.load(f)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "1d90a6d9",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "3d890df2",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "id2label = {str(k): v for k, v in map_id_to_label.items()}\n",
+    "label2id = {v: k for k, v in id2label.items()}\n",
+    "label_names = list(id2label.values())"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "31bcfd6c",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model = AutoModelForTokenClassification.from_pretrained(\n",
+    "    model_checkpoint,\n",
+    "    id2label=id2label,\n",
+    "    label2id=label2id,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "84497580",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model.config.num_labels"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "1ccfbf74",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "args = TrainingArguments(\n",
+    "    \"LaBSE_ner_nerel\",\n",
+    "    evaluation_strategy=\"epoch\",\n",
+    "    save_strategy=\"no\",\n",
+    "    learning_rate=2e-5,\n",
+    "    num_train_epochs=25,\n",
+    "    weight_decay=0.01,\n",
+    "    push_to_hub=False,\n",
+    "    per_device_train_batch_size=4,  # adjust to fit your GPU memory\n",
+    ")"
248
+ ]
249
+ },
250
+ {
251
+ "cell_type": "markdown",
252
+ "id": "c798d567",
253
+ "metadata": {},
254
+ "source": [
255
+ "## Train model"
256
+ ]
257
+ },
258
+ {
259
+ "cell_type": "code",
260
+ "execution_count": null,
261
+ "id": "1348d188",
262
+ "metadata": {},
263
+ "outputs": [],
264
+ "source": [
265
+ "metric = load_metric(\"seqeval\")"
266
+ ]
267
+ },
268
+ {
269
+ "cell_type": "code",
270
+ "execution_count": null,
271
+ "id": "5cff0367",
272
+ "metadata": {},
273
+ "outputs": [],
274
+ "source": [
275
+ "trainer = Trainer(\n",
276
+ " model=model,\n",
277
+ " args=args,\n",
278
+ " train_dataset=tokenized_datasets[\"train\"],\n",
279
+ " eval_dataset=tokenized_datasets[\"dev\"],\n",
280
+ " data_collator=data_collator,\n",
281
+ " compute_metrics=compute_metrics,\n",
282
+ " tokenizer=tokenizer,\n",
283
+ ")\n",
284
+ "trainer.train()"
285
+ ]
286
+ },
287
+ {
288
+ "cell_type": "code",
289
+ "execution_count": null,
290
+ "id": "576a10f4",
291
+ "metadata": {},
292
+ "outputs": [],
293
+ "source": [
294
+ "trainer.save_model(\"LaBSE_nerel_last_checkpoint\")"
295
+ ]
296
+ },
297
+ {
298
+ "cell_type": "code",
299
+ "execution_count": null,
300
+ "id": "451d6db1",
301
+ "metadata": {},
302
+ "outputs": [],
303
+ "source": []
304
+ }
305
+ ],
306
+ "metadata": {
307
+ "kernelspec": {
308
+ "display_name": "hf_env",
309
+ "language": "python",
310
+ "name": "hf_env"
311
+ },
312
+ "language_info": {
313
+ "codemirror_mode": {
314
+ "name": "ipython",
315
+ "version": 3
316
+ },
317
+ "file_extension": ".py",
318
+ "mimetype": "text/x-python",
319
+ "name": "python",
320
+ "nbconvert_exporter": "python",
321
+ "pygments_lexer": "ipython3",
322
+ "version": "3.8.10"
323
+ }
324
+ },
325
+ "nbformat": 4,
326
+ "nbformat_minor": 5
327
+ }
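
A minimal follow-up sketch, not part of the committed notebook: the checkpoint written by trainer.save_model() above can be loaded for inference with the transformers token-classification pipeline. The input sentence is a hypothetical example, and aggregation_strategy="simple" merges sub-word predictions into whole entity spans.

from transformers import pipeline

# Load the fine-tuned checkpoint saved by trainer.save_model() above
ner = pipeline(
    "token-classification",
    model="LaBSE_nerel_last_checkpoint",
    aggregation_strategy="simple",  # merge word pieces into entity spans
)

# Hypothetical input; returns entities with labels, scores, and char offsets
print(ner("Москва - столица России"))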