w11wo committed on
Commit 60d5285 · 1 Parent(s): 2e61b14
.gitattributes CHANGED
@@ -15,3 +15,4 @@
  *.pt filter=lfs diff=lfs merge=lfs -text
  *.pth filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ nohup.out filter=lfs diff=lfs merge=lfs -text
events.out.tfevents.1626286603.t1v-n-b95d739e-w-0.590614.3.v2 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:69d0b0c11510581415e3ad84919fcc5857dd72e276dfa98d90a601a31995e9d7
+ size 4897718
nohup.out ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f2645f6739234c77a54e6320ab13a6dcdd86e2decd09fe90e309506975ad0b0b
+ size 4470375
run.sh ADDED
@@ -0,0 +1,21 @@
+ #!/usr/bin/env bash
+ python3 run_mlm_flax.py \
+     --output_dir="./" \
+     --model_type="roberta" \
+     --config_name="./" \
+     --tokenizer_name="./" \
+     --dataset_language="su" \
+     --max_seq_length="128" \
+     --preprocessing_num_workers="64" \
+     --weight_decay="0.0" \
+     --per_device_train_batch_size="128" \
+     --per_device_eval_batch_size="128" \
+     --learning_rate="2e-4" \
+     --warmup_steps="1000" \
+     --overwrite_output_dir \
+     --pad_to_max_length \
+     --num_train_epochs="50" \
+     --adam_beta1="0.9" \
+     --adam_beta2="0.999" \
+     --adam_epsilon="1e-8" \
+     --push_to_hub
run_mlm_flax.py ADDED
@@ -0,0 +1,823 @@
+ #!/usr/bin/env python
+ # coding=utf-8
+ # Copyright 2021 The HuggingFace Team All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """
+ Fine-tuning the library models for masked language modeling (BERT, ALBERT, RoBERTa...) with whole word masking on a
+ text file or a dataset.
+
+ Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
+ https://huggingface.co/models?filter=masked-lm
+ """
+ import logging
+ import os
+ import sys
+ import time
+ from dataclasses import dataclass, field
+
+ # You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
+ from pathlib import Path
+ from typing import Dict, List, Optional, Tuple
+
+ import numpy as np
+ from datasets import load_dataset, concatenate_datasets
+ from tqdm import tqdm
+
+ import flax
+ import jax
+ import jax.numpy as jnp
+ import optax
+ from flax import jax_utils, traverse_util
+ from flax.training import train_state
+ from flax.training.common_utils import get_metrics, onehot, shard
+ from transformers import (
+     CONFIG_MAPPING,
+     FLAX_MODEL_FOR_MASKED_LM_MAPPING,
+     AutoConfig,
+     AutoTokenizer,
+     FlaxAutoModelForMaskedLM,
+     HfArgumentParser,
+     PreTrainedTokenizerBase,
+     TensorType,
+     TrainingArguments,
+     is_tensorboard_available,
+     set_seed,
+ )
+
+
+ MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
+ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
+
+
+ @dataclass
+ class ModelArguments:
+     """
+     Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
+     """
+
+     model_name_or_path: Optional[str] = field(
+         default=None,
+         metadata={
+             "help": "The model checkpoint for weights initialization."
+             "Don't set if you want to train a model from scratch."
+         },
+     )
+     model_type: Optional[str] = field(
+         default=None,
+         metadata={
+             "help": "If training from scratch, pass a model type from the list: "
+             + ", ".join(MODEL_TYPES)
+         },
+     )
+     config_name: Optional[str] = field(
+         default=None,
+         metadata={
+             "help": "Pretrained config name or path if not the same as model_name"
+         },
+     )
+     tokenizer_name: Optional[str] = field(
+         default=None,
+         metadata={
+             "help": "Pretrained tokenizer name or path if not the same as model_name"
+         },
+     )
+     cache_dir: Optional[str] = field(
+         default=None,
+         metadata={
+             "help": "Where do you want to store the pretrained models downloaded from s3"
+         },
+     )
+     use_fast_tokenizer: bool = field(
+         default=True,
+         metadata={
+             "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."
+         },
+     )
+     dtype: Optional[str] = field(
+         default="float32",
+         metadata={
+             "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
+         },
+     )
+
+
+ @dataclass
+ class DataTrainingArguments:
+     """
+     Arguments pertaining to what data we are going to input our model for training and eval.
+     """
+
+     dataset_language: Optional[str] = field(
+         default=None,
+         metadata={
+             "help": "The language of the OSCAR, MC4, CC100 dataset to use (via the datasets library)."
+         },
+     )
+     # dataset_name: Optional[str] = field(
+     #     default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
+     # )
+     # dataset_config_name: Optional[str] = field(
+     #     default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
+     # )
+     train_file: Optional[str] = field(
+         default=None, metadata={"help": "The input training data file (a text file)."}
+     )
+     validation_file: Optional[str] = field(
+         default=None,
+         metadata={
+             "help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."
+         },
+     )
+     train_ref_file: Optional[str] = field(
+         default=None,
+         metadata={
+             "help": "An optional input train ref data file for whole word masking in Chinese."
+         },
+     )
+     validation_ref_file: Optional[str] = field(
+         default=None,
+         metadata={
+             "help": "An optional input validation ref data file for whole word masking in Chinese."
+         },
+     )
+     overwrite_cache: bool = field(
+         default=False,
+         metadata={"help": "Overwrite the cached training and evaluation sets"},
+     )
+     validation_split_percentage: Optional[int] = field(
+         default=10,
+         metadata={
+             "help": "The percentage of the train set used as validation set in case there's no validation split"
+         },
+     )
+     max_seq_length: Optional[int] = field(
+         default=None,
+         metadata={
+             "help": "The maximum total input sequence length after tokenization. Sequences longer "
+             "than this will be truncated. Default to the max input length of the model."
+         },
+     )
+     preprocessing_num_workers: Optional[int] = field(
+         default=None,
+         metadata={"help": "The number of processes to use for the preprocessing."},
+     )
+     mlm_probability: float = field(
+         default=0.15,
+         metadata={"help": "Ratio of tokens to mask for masked language modeling loss"},
+     )
+     pad_to_max_length: bool = field(
+         default=False,
+         metadata={
+             "help": "Whether to pad all samples to `max_seq_length`. "
+             "If False, will pad the samples dynamically when batching to the maximum length in the batch."
+         },
+     )
+     line_by_line: bool = field(
+         default=False,
+         metadata={
+             "help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."
+         },
+     )
+
+     def __post_init__(self):
+         if (
+             self.dataset_language is None
+             and self.train_file is None
+             and self.validation_file is None
+         ):
+             raise ValueError(
+                 "Need either a dataset name or a training/validation file."
+             )
+         else:
+             if self.train_file is not None:
+                 extension = self.train_file.split(".")[-1]
+                 assert extension in [
+                     "csv",
+                     "json",
+                     "txt",
+                 ], "`train_file` should be a csv, a json or a txt file."
+             if self.validation_file is not None:
+                 extension = self.validation_file.split(".")[-1]
+                 assert extension in [
+                     "csv",
+                     "json",
+                     "txt",
+                 ], "`validation_file` should be a csv, a json or a txt file."
+
+
+ @flax.struct.dataclass
+ class FlaxDataCollatorForLanguageModeling:
+     """
+     Data collator used for language modeling. Inputs are dynamically padded to the maximum length of a batch if they
+     are not all of the same length.
+
+     Args:
+         tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
+             The tokenizer used for encoding the data.
+         mlm_probability (:obj:`float`, `optional`, defaults to 0.15):
+             The probability with which to (randomly) mask tokens in the input.
+
+     .. note::
+
+         For best performance, this data collator should be used with a dataset having items that are dictionaries or
+         BatchEncoding, with the :obj:`"special_tokens_mask"` key, as returned by a
+         :class:`~transformers.PreTrainedTokenizer` or a :class:`~transformers.PreTrainedTokenizerFast` with the
+         argument :obj:`return_special_tokens_mask=True`.
+     """
+
+     tokenizer: PreTrainedTokenizerBase
+     mlm_probability: float = 0.15
+
+     def __post_init__(self):
+         if self.tokenizer.mask_token is None:
+             raise ValueError(
+                 "This tokenizer does not have a mask token which is necessary for masked language modeling. "
+                 "You should pass `mlm=False` to train on causal language modeling instead."
+             )
+
+     def __call__(
+         self, examples: List[Dict[str, np.ndarray]], pad_to_multiple_of: int
+     ) -> Dict[str, np.ndarray]:
+         # Handle dict or lists with proper padding and conversion to tensor.
+         batch = self.tokenizer.pad(
+             examples,
+             pad_to_multiple_of=pad_to_multiple_of,
+             return_tensors=TensorType.NUMPY,
+         )
+
+         # If special token mask has been preprocessed, pop it from the dict.
+         special_tokens_mask = batch.pop("special_tokens_mask", None)
+
+         batch["input_ids"], batch["labels"] = self.mask_tokens(
+             batch["input_ids"], special_tokens_mask=special_tokens_mask
+         )
+         return batch
+
+     def mask_tokens(
+         self, inputs: np.ndarray, special_tokens_mask: Optional[np.ndarray]
+     ) -> Tuple[jnp.ndarray, jnp.ndarray]:
+         """
+         Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
+         """
+         labels = inputs.copy()
+         # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
+         probability_matrix = np.full(labels.shape, self.mlm_probability)
+         special_tokens_mask = special_tokens_mask.astype("bool")
+
+         probability_matrix[special_tokens_mask] = 0.0
+         masked_indices = np.random.binomial(1, probability_matrix).astype("bool")
+         labels[~masked_indices] = -100  # We only compute loss on masked tokens
+
+         # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
+         indices_replaced = (
+             np.random.binomial(1, np.full(labels.shape, 0.8)).astype("bool")
+             & masked_indices
+         )
+         inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(
+             self.tokenizer.mask_token
+         )
+
+         # 10% of the time, we replace masked input tokens with random word
+         indices_random = np.random.binomial(1, np.full(labels.shape, 0.5)).astype(
+             "bool"
+         )
+         indices_random &= masked_indices & ~indices_replaced
+
+         random_words = np.random.randint(
+             self.tokenizer.vocab_size, size=labels.shape, dtype="i4"
+         )
+         inputs[indices_random] = random_words[indices_random]
+
+         # The rest of the time (10% of the time) we keep the masked input tokens unchanged
+         return inputs, labels
+
+
+ def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
+     num_samples = len(samples_idx)
+     samples_to_remove = num_samples % batch_size
+
+     if samples_to_remove != 0:
+         samples_idx = samples_idx[:-samples_to_remove]
+     sections_split = num_samples // batch_size
+     batch_idx = np.split(samples_idx, sections_split)
+     return batch_idx
+
+
+ def write_train_metric(summary_writer, train_metrics, train_time, step):
+     summary_writer.scalar("train_time", train_time, step)
+
+     train_metrics = get_metrics(train_metrics)
+     for key, vals in train_metrics.items():
+         tag = f"train_{key}"
+         for i, val in enumerate(vals):
+             summary_writer.scalar(tag, val, step - len(vals) + i + 1)
+
+
+ def write_eval_metric(summary_writer, eval_metrics, step):
+     for metric_name, value in eval_metrics.items():
+         summary_writer.scalar(f"eval_{metric_name}", value, step)
+
+
+ if __name__ == "__main__":
+     # See all possible arguments in src/transformers/training_args.py
+     # or by passing the --help flag to this script.
+     # We now keep distinct sets of args, for a cleaner separation of concerns.
+
+     parser = HfArgumentParser(
+         (ModelArguments, DataTrainingArguments, TrainingArguments)
+     )
+     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+         # If we pass only one argument to the script and it's the path to a json file,
+         # let's parse it to get our arguments.
+         model_args, data_args, training_args = parser.parse_json_file(
+             json_file=os.path.abspath(sys.argv[1])
+         )
+     else:
+         model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+
+     if (
+         os.path.exists(training_args.output_dir)
+         and os.listdir(training_args.output_dir)
+         and training_args.do_train
+         and not training_args.overwrite_output_dir
+     ):
+         raise ValueError(
+             f"Output directory ({training_args.output_dir}) already exists and is not empty."
+             "Use --overwrite_output_dir to overcome."
+         )
+
+     # Setup logging
+     logging.basicConfig(
+         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+         level="NOTSET",
+         datefmt="[%X]",
+     )
+
+     # Log on each process the small summary:
+     logger = logging.getLogger(__name__)
+
+     # Set the verbosity to info of the Transformers logger (on main process only):
+     logger.info(f"Training/evaluation parameters {training_args}")
+
+     # Set seed before initializing model.
+     set_seed(training_args.seed)
+
+     # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
+     # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
+     # (the dataset will be downloaded automatically from the datasets Hub).
+     #
+     # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
+     # 'text' is found. You can easily tweak this behavior (see below).
+     #
+     # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+     # download the dataset.
+     if data_args.dataset_language is not None:
+         # Downloading and loading a dataset from the hub.
+         oscar = load_dataset(
+             "oscar",
+             f"unshuffled_deduplicated_{data_args.dataset_language}",
+             split="train",
+             cache_dir=model_args.cache_dir,
+         )
+
+         cc100 = load_dataset(
+             "cc100",
+             lang=data_args.dataset_language,
+             split="train",
+             cache_dir=model_args.cache_dir,
+         )
+
+         mc4 = load_dataset(
+             "mc4",
+             data_args.dataset_language,
+             split="train",
+             cache_dir=model_args.cache_dir,
+         )
+
+         wiki_files = [str(x) for x in Path("../docs").glob("*.txt")]
+         wiki = load_dataset("text", data_files=wiki_files)
+
+         # want: text column only!
+         oscar = oscar.remove_columns("id")
+         mc4 = mc4.remove_columns(["url", "timestamp"])
+         cc100 = cc100.remove_columns("id")
+
+         # combine datasets
+         datasets = concatenate_datasets([oscar, mc4, cc100, wiki["train"]])
+         # split train and validation
+         # note: renamed `validation` key to `test` everywhere else in the script
+         datasets = datasets.train_test_split(
+             test_size=data_args.validation_split_percentage / 100, seed=42
+         )
+
+     else:
+         data_files = {}
+         if data_args.train_file is not None:
+             data_files["train"] = data_args.train_file
+         if data_args.validation_file is not None:
+             data_files["test"] = data_args.validation_file
+         extension = data_args.train_file.split(".")[-1]
+         if extension == "txt":
+             extension = "text"
+         datasets = load_dataset(
+             extension, data_files=data_files, cache_dir=model_args.cache_dir
+         )
+     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
+     # https://huggingface.co/docs/datasets/loading_datasets.html.
+
+     # Load pretrained model and tokenizer
+
+     # Distributed training:
+     # The .from_pretrained methods guarantee that only one local process can concurrently
+     # download model & vocab.
+     if model_args.config_name:
+         config = AutoConfig.from_pretrained(
+             model_args.config_name, cache_dir=model_args.cache_dir
+         )
+     elif model_args.model_name_or_path:
+         config = AutoConfig.from_pretrained(
+             model_args.model_name_or_path, cache_dir=model_args.cache_dir
+         )
+     else:
+         config = CONFIG_MAPPING[model_args.model_type]()
+         logger.warning("You are instantiating a new config instance from scratch.")
+
+     if model_args.tokenizer_name:
+         tokenizer = AutoTokenizer.from_pretrained(
+             model_args.tokenizer_name,
+             cache_dir=model_args.cache_dir,
+             use_fast=model_args.use_fast_tokenizer,
+         )
+     elif model_args.model_name_or_path:
+         tokenizer = AutoTokenizer.from_pretrained(
+             model_args.model_name_or_path,
+             cache_dir=model_args.cache_dir,
+             use_fast=model_args.use_fast_tokenizer,
+         )
+     else:
+         raise ValueError(
+             "You are instantiating a new tokenizer from scratch. This is not supported by this script."
+             "You can do it from another script, save it, and load it from here, using --tokenizer_name."
+         )
+
+     # Preprocessing the datasets.
+     # First we tokenize all the texts.
+     if training_args.do_train:
+         column_names = datasets["train"].column_names
+     else:
+         column_names = datasets["test"].column_names
+     text_column_name = "text" if "text" in column_names else column_names[0]
+
+     max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
+
+     if data_args.line_by_line:
+         # When using line_by_line, we just tokenize each nonempty line.
+         padding = "max_length" if data_args.pad_to_max_length else False
+
+         def tokenize_function(examples):
+             # Remove empty lines
+             examples = [
+                 line for line in examples if len(line) > 0 and not line.isspace()
+             ]
+             return tokenizer(
+                 examples,
+                 return_special_tokens_mask=True,
+                 padding=padding,
+                 truncation=True,
+                 max_length=max_seq_length,
+             )
+
+         tokenized_datasets = datasets.map(
+             tokenize_function,
+             input_columns=[text_column_name],
+             batched=True,
+             num_proc=data_args.preprocessing_num_workers,
+             remove_columns=column_names,
+             load_from_cache_file=not data_args.overwrite_cache,
+         )
+
+     else:
+         # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
+         # We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more
+         # efficient when it receives the `special_tokens_mask`.
+         def tokenize_function(examples):
+             return tokenizer(
+                 examples[text_column_name], return_special_tokens_mask=True
+             )
+
+         tokenized_datasets = datasets.map(
+             tokenize_function,
+             batched=True,
+             num_proc=data_args.preprocessing_num_workers,
+             remove_columns=column_names,
+             load_from_cache_file=not data_args.overwrite_cache,
+         )
+
+         # Main data processing function that will concatenate all texts from our dataset and generate chunks of
+         # max_seq_length.
+         def group_texts(examples):
+             # Concatenate all texts.
+             concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
+             total_length = len(concatenated_examples[list(examples.keys())[0]])
+             # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
+             # customize this part to your needs.
+             total_length = (total_length // max_seq_length) * max_seq_length
+             # Split by chunks of max_len.
+             result = {
+                 k: [
+                     t[i : i + max_seq_length]
+                     for i in range(0, total_length, max_seq_length)
+                 ]
+                 for k, t in concatenated_examples.items()
+             }
+             return result
+
+         # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
+         # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
+         # might be slower to preprocess.
+         #
+         # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
+         # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
+         tokenized_datasets = tokenized_datasets.map(
+             group_texts,
+             batched=True,
+             num_proc=data_args.preprocessing_num_workers,
+             load_from_cache_file=not data_args.overwrite_cache,
+         )
+
+     # Enable tensorboard only on the master node
+     has_tensorboard = is_tensorboard_available()
+     if has_tensorboard and jax.process_index() == 0:
+         try:
+             from flax.metrics.tensorboard import SummaryWriter
+
+             summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
+         except ImportError as ie:
+             has_tensorboard = False
+             logger.warning(
+                 f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
+             )
+     else:
+         logger.warning(
+             "Unable to display metrics through TensorBoard because the package is not installed: "
+             "Please run pip install tensorboard to enable."
+         )
+
+     # Data collator
+     # This one will take care of randomly masking the tokens.
+     data_collator = FlaxDataCollatorForLanguageModeling(
+         tokenizer=tokenizer, mlm_probability=data_args.mlm_probability
+     )
+
+     # Initialize our training
+     rng = jax.random.PRNGKey(training_args.seed)
+     dropout_rngs = jax.random.split(rng, jax.local_device_count())
+
+     if model_args.model_name_or_path:
+         model = FlaxAutoModelForMaskedLM.from_pretrained(
+             model_args.model_name_or_path,
+             config=config,
+             seed=training_args.seed,
+             dtype=getattr(jnp, model_args.dtype),
+         )
+     else:
+         model = FlaxAutoModelForMaskedLM.from_config(
+             config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
+         )
+
+     # Store some constant
+     num_epochs = int(training_args.num_train_epochs)
+     train_batch_size = (
+         int(training_args.per_device_train_batch_size) * jax.device_count()
+     )
+     eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
+
+     num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs
+
+     # Create learning rate schedule
+     warmup_fn = optax.linear_schedule(
+         init_value=0.0,
+         end_value=training_args.learning_rate,
+         transition_steps=training_args.warmup_steps,
+     )
+     decay_fn = optax.linear_schedule(
+         init_value=training_args.learning_rate,
+         end_value=0,
+         transition_steps=num_train_steps - training_args.warmup_steps,
+     )
+     linear_decay_lr_schedule_fn = optax.join_schedules(
+         schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]
+     )
+
+     # We use Optax's "masking" functionality to not apply weight decay
+     # to bias and LayerNorm scale parameters. decay_mask_fn returns a
+     # mask boolean with the same structure as the parameters.
+     # The mask is True for parameters that should be decayed.
+     # Note that this mask is specifically adapted for FlaxBERT-like models.
+     # For other models, one should correct the layer norm parameter naming
+     # accordingly.
+     def decay_mask_fn(params):
+         flat_params = traverse_util.flatten_dict(params)
+         flat_mask = {
+             path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale"))
+             for path in flat_params
+         }
+         return traverse_util.unflatten_dict(flat_mask)
+
+     # create adam optimizer
+     if training_args.adafactor:
+         # We use the default parameters here to initialize adafactor,
+         # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
+         optimizer = optax.adafactor(learning_rate=linear_decay_lr_schedule_fn,)
+     else:
+         optimizer = optax.adamw(
+             learning_rate=linear_decay_lr_schedule_fn,
+             b1=training_args.adam_beta1,
+             b2=training_args.adam_beta2,
+             eps=training_args.adam_epsilon,
+             weight_decay=training_args.weight_decay,
+             mask=decay_mask_fn,
+         )
+
+     # Setup train state
+     state = train_state.TrainState.create(
+         apply_fn=model.__call__, params=model.params, tx=optimizer
+     )
+
+     # Define gradient update step fn
+     def train_step(state, batch, dropout_rng):
+         dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
+
+         def loss_fn(params):
+             labels = batch.pop("labels")
+
+             logits = state.apply_fn(
+                 **batch, params=params, dropout_rng=dropout_rng, train=True
+             )[0]
+
+             # compute loss, ignore padded input tokens
+             label_mask = jnp.where(labels > 0, 1.0, 0.0)
+             loss = (
+                 optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1]))
+                 * label_mask
+             )
+
+             # take average
+             loss = loss.sum() / label_mask.sum()
+
+             return loss
+
+         grad_fn = jax.value_and_grad(loss_fn)
+         loss, grad = grad_fn(state.params)
+         grad = jax.lax.pmean(grad, "batch")
+         new_state = state.apply_gradients(grads=grad)
+
+         metrics = jax.lax.pmean(
+             {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)},
+             axis_name="batch",
+         )
+
+         return new_state, metrics, new_dropout_rng
+
+     # Create parallel version of the train step
+     p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
+
+     # Define eval fn
+     def eval_step(params, batch):
+         labels = batch.pop("labels")
+
+         logits = model(**batch, params=params, train=False)[0]
+
+         # compute loss, ignore padded input tokens
+         label_mask = jnp.where(labels > 0, 1.0, 0.0)
+         loss = (
+             optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1]))
+             * label_mask
+         )
+
+         # compute accuracy
+         accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels) * label_mask
+
+         # summarize metrics
+         metrics = {
+             "loss": loss.sum(),
+             "accuracy": accuracy.sum(),
+             "normalizer": label_mask.sum(),
+         }
+         metrics = jax.lax.psum(metrics, axis_name="batch")
+
+         return metrics
+
+     p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0,))
+
+     # Replicate the train state on each device
+     state = jax_utils.replicate(state)
+
+     train_time = 0
+     epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
+     for epoch in epochs:
+         # ======================== Training ================================
+         train_start = time.time()
+         train_metrics = []
+
+         # Create sampling rng
+         rng, input_rng = jax.random.split(rng)
+
+         # Generate an epoch by shuffling sampling indices from the train dataset
+         num_train_samples = len(tokenized_datasets["train"])
+         train_samples_idx = jax.random.permutation(
+             input_rng, jnp.arange(num_train_samples)
+         )
+         train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
+
+         # Gather the indexes for creating the batch and do a training step
+         for step, batch_idx in enumerate(
+             tqdm(train_batch_idx, desc="Training...", position=1)
+         ):
+             samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
+             model_inputs = data_collator(samples, pad_to_multiple_of=16)
+
+             # Model forward
+             model_inputs = shard(model_inputs.data)
+             state, train_metric, dropout_rngs = p_train_step(
+                 state, model_inputs, dropout_rngs
+             )
+             train_metrics.append(train_metric)
+
+             cur_step = epoch * (num_train_samples // train_batch_size) + step
+
+             if cur_step % training_args.logging_steps == 0 and cur_step > 0:
+                 # Save metrics
+                 train_metric = jax_utils.unreplicate(train_metric)
+                 train_time += time.time() - train_start
+                 if has_tensorboard and jax.process_index() == 0:
+                     write_train_metric(
+                         summary_writer, train_metrics, train_time, cur_step
+                     )
+
+                 epochs.write(
+                     f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
+                 )
+
+                 train_metrics = []
+
+             if cur_step % training_args.eval_steps == 0 and cur_step > 0:
+                 # ======================== Evaluating ==============================
+                 num_eval_samples = len(tokenized_datasets["test"])
+                 eval_samples_idx = jnp.arange(num_eval_samples)
+                 eval_batch_idx = generate_batch_splits(
+                     eval_samples_idx, eval_batch_size
+                 )
+
+                 eval_metrics = []
+                 for i, batch_idx in enumerate(
+                     tqdm(eval_batch_idx, desc="Evaluating ...", position=2)
+                 ):
+                     samples = [
+                         tokenized_datasets["test"][int(idx)] for idx in batch_idx
+                     ]
+                     model_inputs = data_collator(samples, pad_to_multiple_of=16)
+
+                     # Model forward
+                     model_inputs = shard(model_inputs.data)
+                     metrics = p_eval_step(state.params, model_inputs)
+                     eval_metrics.append(metrics)
+
+                 # normalize eval metrics
+                 eval_metrics = get_metrics(eval_metrics)
+                 eval_metrics = jax.tree_map(jnp.sum, eval_metrics)
+                 eval_normalizer = eval_metrics.pop("normalizer")
+                 eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
+
+                 # Update progress bar
+                 epochs.desc = f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
+
+                 # Save metrics
+                 if has_tensorboard and jax.process_index() == 0:
+                     cur_step = epoch * (
+                         len(tokenized_datasets["train"]) // train_batch_size
+                     )
+                     write_eval_metric(summary_writer, eval_metrics, cur_step)
+
+             if cur_step % training_args.save_steps == 0 and cur_step > 0:
+                 # save checkpoint after each epoch and push checkpoint to the hub
+                 if jax.process_index() == 0:
+                     params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
+                     model.save_pretrained(
+                         training_args.output_dir,
+                         params=params,
+                         push_to_hub=training_args.push_to_hub,
+                         commit_message=f"Saving weights and logs of step {cur_step}",
+                     )
+