sbmaruf committed on
Commit
6d3bf36
1 Parent(s): 3d14a58
Files changed (4)
  1. run.sh +36 -0
  2. run_t5_mlm_flax.py +790 -0
  3. tokenizer/config.json +56 -0
  4. tokenizer/tokenizer.json +0 -0
run.sh ADDED
@@ -0,0 +1,36 @@
+ EXP_FOLDER="dumped/bengali_t5_base"
+ CACHE_DIR=$EXP_FOLDER/
+ MODEL_CKPT=$EXP_FOLDER/
+ mkdir -p $EXP_FOLDER
+ mkdir -p $CACHE_DIR
+ mkdir -p $MODEL_CKPT
+
+ TOKENIZER_DIR="dumped/bengali_t5_base/tokenizer"
+ MODEL_CONFIG="t5-base"
+ MAX_SEQ_LEN=512
+
+ NUM_THREAD=50
+ DATASET_NAME="mc4"
+ DATASET_CONFIG_NAME="bn"
+
+ python -u run_t5_mlm_flax.py \
+ --output_dir ${MODEL_CKPT} \
+ --model_type "t5" \
+ --config_name $MODEL_CONFIG \
+ --tokenizer_name ${TOKENIZER_DIR} \
+ --dataset_name $DATASET_NAME \
+ --dataset_config_name $DATASET_CONFIG_NAME \
+ --max_seq_length $MAX_SEQ_LEN \
+ --per_device_train_batch_size 8 \
+ --per_device_eval_batch_size 8 \
+ --adafactor \
+ --learning_rate 1e-3 \
+ --weight_decay 0.001 \
+ --warmup_steps 5000 \
+ --overwrite_output_dir \
+ --num_train_epochs 10 \
+ --logging_steps 500 \
+ --save_steps 2500 \
+ --eval_steps 7500 \
+ --preprocessing_num_workers $NUM_THREAD \
+ --dtype bfloat16
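With --per_device_train_batch_size 8, the effective global batch size depends on how many accelerator devices JAX can see. A minimal sketch of the arithmetic the training script applies (a device count of 8, e.g. a TPU v3-8, is an assumption):

import jax

per_device_batch_size = 8                                        # --per_device_train_batch_size above
global_batch_size = per_device_batch_size * jax.device_count()   # 64 if 8 devices are visible
tokens_per_step = global_batch_size * 512                        # --max_seq_length above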
run_t5_mlm_flax.py ADDED
@@ -0,0 +1,790 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2021 The HuggingFace Team All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Pretraining the library models for T5-like span-masked language modeling on a text file or a dataset.
18
+ Here is the full list of checkpoints on the hub that can be pretrained by this script:
19
+ https://huggingface.co/models?filter=t5
20
+ """
21
+ # You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
22
+ import logging
23
+ import os
24
+ import sys
25
+ import time
26
+ from dataclasses import dataclass, field
27
+ from pathlib import Path
28
+ from typing import Dict, List, Optional
29
+
30
+ import numpy as np
31
+ from datasets import load_dataset
32
+ from tqdm import tqdm
33
+
34
+ import flax
35
+ import jax
36
+ import jax.numpy as jnp
37
+ import optax
38
+ from flax import jax_utils, traverse_util
39
+ from flax.training import train_state
40
+ from flax.training.common_utils import get_metrics, onehot, shard
41
+ from transformers import (
42
+ CONFIG_MAPPING,
43
+ FLAX_MODEL_FOR_MASKED_LM_MAPPING,
44
+ AutoTokenizer,
45
+ BatchEncoding,
46
+ FlaxT5ForConditionalGeneration,
47
+ HfArgumentParser,
48
+ PreTrainedTokenizerBase,
49
+ T5Config,
50
+ TrainingArguments,
51
+ is_tensorboard_available,
52
+ set_seed,
53
+ )
54
+ from transformers.models.t5.modeling_flax_t5 import shift_tokens_right
55
+
56
+
57
+ MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
58
+ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
59
+
60
+
61
+ @dataclass
62
+ class ModelArguments:
63
+ """
64
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
65
+ """
66
+
67
+ model_name_or_path: Optional[str] = field(
68
+ default=None,
69
+ metadata={
70
+ "help": "The model checkpoint for weights initialization."
71
+ "Don't set if you want to train a model from scratch."
72
+ },
73
+ )
74
+ model_type: Optional[str] = field(
75
+ default=None,
76
+ metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
77
+ )
78
+ config_name: Optional[str] = field(
79
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
80
+ )
81
+ tokenizer_name: Optional[str] = field(
82
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
83
+ )
84
+ cache_dir: Optional[str] = field(
85
+ default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
86
+ )
87
+ use_fast_tokenizer: bool = field(
88
+ default=True,
89
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
90
+ )
91
+ dtype: Optional[str] = field(
92
+ default="bfloat16",
93
+ metadata={
94
+ "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
95
+ },
96
+ )
97
+
98
+
99
+ @dataclass
100
+ class DataTrainingArguments:
101
+ """
102
+ Arguments pertaining to what data we are going to input our model for training and eval.
103
+ """
104
+
105
+ dataset_name: Optional[str] = field(
106
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
107
+ )
108
+ dataset_config_name: Optional[str] = field(
109
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
110
+ )
111
+ train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
112
+ validation_file: Optional[str] = field(
113
+ default=None,
114
+ metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
115
+ )
116
+ train_ref_file: Optional[str] = field(
117
+ default=None,
118
+ metadata={"help": "An optional input train ref data file for whole word masking in Chinese."},
119
+ )
120
+ validation_ref_file: Optional[str] = field(
121
+ default=None,
122
+ metadata={"help": "An optional input validation ref data file for whole word masking in Chinese."},
123
+ )
124
+ overwrite_cache: bool = field(
125
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
126
+ )
127
+ validation_split_percentage: Optional[int] = field(
128
+ default=5,
129
+ metadata={
130
+ "help": "The percentage of the train set used as validation set in case there's no validation split"
131
+ },
132
+ )
133
+ max_seq_length: Optional[int] = field(
134
+ default=None,
135
+ metadata={
136
+ "help": "The maximum total input sequence length after tokenization and masking. Sequences longer than this will be truncated. Default to the max input length of the model."
137
+ },
138
+ )
139
+ preprocessing_num_workers: Optional[int] = field(
140
+ default=None,
141
+ metadata={"help": "The number of processes to use for the preprocessing."},
142
+ )
143
+ mlm_probability: float = field(
144
+ default=0.15, metadata={"help": "Ratio of tokens to mask for span masked language modeling loss"}
145
+ )
146
+ mean_noise_span_length: float = field(
147
+ default=3.0,
148
+ metadata={"help": "Mean span length of masked tokens"},
149
+ )
150
+
151
+ def __post_init__(self):
152
+ if self.dataset_name is None and self.train_file is None and self.validation_file is None:
153
+ raise ValueError("Need either a dataset name or a training/validation file.")
154
+ else:
155
+ if self.train_file is not None:
156
+ extension = self.train_file.split(".")[-1]
157
+ assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
158
+ if self.validation_file is not None:
159
+ extension = self.validation_file.split(".")[-1]
160
+ assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
161
+
162
+
163
+ def compute_input_and_target_lengths(inputs_length, noise_density, mean_noise_span_length):
164
+ """This function is copy of `random_spans_helper <https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2466>`__ .
165
+ Training parameters to avoid padding with random_spans_noise_mask.
166
+ When training a model with random_spans_noise_mask, we would like to set the other
167
+ training hyperparameters in a way that avoids padding.
168
+ This function helps us compute these hyperparameters.
169
+ We assume that each noise span in the input is replaced by extra_tokens_per_span_inputs sentinel tokens,
170
+ and each non-noise span in the targets is replaced by extra_tokens_per_span_targets sentinel tokens.
171
+ This function tells us the required number of tokens in the raw example (for split_tokens())
172
+ as well as the length of the encoded targets. Note that this function assumes
173
+ the inputs and targets will have EOS appended and includes that in the reported length.
174
+ Args:
175
+ inputs_length: an integer - desired length of the tokenized inputs sequence
176
+ noise_density: a float
177
+ mean_noise_span_length: a float
178
+ Returns:
179
+ tokens_length: length of original text in tokens
180
+ targets_length: an integer - length in tokens of encoded targets sequence
181
+ """
182
+
183
+ def _tokens_length_to_inputs_length_targets_length(tokens_length):
184
+ num_noise_tokens = int(round(tokens_length * noise_density))
185
+ num_nonnoise_tokens = tokens_length - num_noise_tokens
186
+ num_noise_spans = int(round(num_noise_tokens / mean_noise_span_length))
187
+ # inputs contain all nonnoise tokens, sentinels for all noise spans
188
+ # and one EOS token.
189
+ _input_length = num_nonnoise_tokens + num_noise_spans + 1
190
+ _output_length = num_noise_tokens + num_noise_spans + 1
191
+ return _input_length, _output_length
192
+
193
+ tokens_length = inputs_length
194
+
195
+ while _tokens_length_to_inputs_length_targets_length(tokens_length + 1)[0] <= inputs_length:
196
+ tokens_length += 1
197
+
198
+ inputs_length, targets_length = _tokens_length_to_inputs_length_targets_length(tokens_length)
199
+
200
+ # minor hack to get the targets length to be equal to inputs length
201
+ # which is more likely to have been set to a nice round number.
202
+ if noise_density == 0.5 and targets_length > inputs_length:
203
+ tokens_length -= 1
204
+ targets_length -= 1
205
+ return tokens_length, targets_length
206
+
207
+
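As a quick sanity check, with the settings from run.sh (max_seq_length 512, mlm_probability 0.15, mean_noise_span_length 3.0) the search above settles on an expanded input length of 568 and a target length of 114:

expanded_inputs_length, targets_length = compute_input_and_target_lengths(
    inputs_length=512, noise_density=0.15, mean_noise_span_length=3.0
)
assert (expanded_inputs_length, targets_length) == (568, 114)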
208
+ @flax.struct.dataclass
209
+ class FlaxDataCollatorForT5MLM:
210
+ """
211
+ Data collator used for T5 span-masked language modeling.
212
+ It ensures that after masking the inputs have length `data_args.max_seq_length` and the targets also have a fixed length.
213
+ For more information on how T5 span-masked language modeling works, one can take a look
214
+ at the `official paper <https://arxiv.org/pdf/1910.10683.pdf>`__
215
+ or the `official code for preprocessing <https://github.com/google-research/text-to-text-transfer-transformer/blob/master/t5/data/preprocessors.py>`__ .
216
+ Args:
217
+ tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
218
+ The tokenizer used for encoding the data.
219
+ noise_density (:obj:`float`):
220
+ The probability with which to (randomly) mask tokens in the input.
221
+ mean_noise_span_length (:obj:`float`):
222
+ The average span length of the masked tokens.
223
+ input_length (:obj:`int`):
224
+ The expected input length after masking.
225
+ target_length (:obj:`int`):
226
+ The expected target length after masking.
227
+ pad_token_id: (:obj:`int`):
228
+ The pad token id of the model
229
+ decoder_start_token_id: (:obj:`int`):
230
+ The decoder start token id of the model
231
+ """
232
+
233
+ tokenizer: PreTrainedTokenizerBase
234
+ noise_density: float
235
+ mean_noise_span_length: float
236
+ input_length: int
237
+ target_length: int
238
+ pad_token_id: int
239
+ decoder_start_token_id: int
240
+
241
+ def __call__(self, examples: List[Dict[str, np.ndarray]]) -> Dict[str, np.ndarray]:
242
+
243
+ # convert list to dict and tensorize input
244
+ batch = BatchEncoding(
245
+ {k: np.array([examples[i][k] for i in range(len(examples))]) for k, v in examples[0].items()}
246
+ )
247
+
248
+ input_ids = batch["input_ids"]
249
+ batch_size, expanded_input_length = input_ids.shape
250
+
251
+ mask_indices = np.asarray([self.random_spans_noise_mask(expanded_input_length) for i in range(batch_size)])
252
+ labels_mask = ~mask_indices
253
+
254
+ input_ids_sentinel = self.create_sentinel_ids(mask_indices.astype(np.int8))
255
+ labels_sentinel = self.create_sentinel_ids(labels_mask.astype(np.int8))
256
+
257
+ batch["input_ids"] = self.filter_input_ids(input_ids, input_ids_sentinel)
258
+ batch["labels"] = self.filter_input_ids(input_ids, labels_sentinel)
259
+
260
+ if batch["input_ids"].shape[-1] != self.input_length:
261
+ raise ValueError(
262
+ f"`input_ids` are incorrectly preprocessed. `input_ids` length is {batch['input_ids'].shape[-1]}, but should be {self.target_length}."
263
+ )
264
+
265
+ if batch["labels"].shape[-1] != self.target_length:
266
+ raise ValueError(
267
+ f"`labels` are incorrectly preprocessed. `labels` length is {batch['labels'].shape[-1]}, but should be {self.target_length}."
268
+ )
269
+
270
+ # to check that tokens are correctly preprocessed, one can run `self.tokenizer.batch_decode(input_ids)` and `self.tokenizer.batch_decode(labels)` here...
271
+ batch["decoder_input_ids"] = shift_tokens_right(
272
+ batch["labels"], self.pad_token_id, self.decoder_start_token_id
273
+ )
274
+
275
+ return batch
276
+
277
+ def create_sentinel_ids(self, mask_indices):
278
+ """
279
+ Sentinel ids creation given the indices that should be masked.
280
+ The start indices of each mask are replaced by the sentinel ids in increasing
281
+ order. Consecutive mask indices to be deleted are replaced with `-1`.
282
+ """
283
+ start_indices = mask_indices - np.roll(mask_indices, 1, axis=-1) * mask_indices
284
+ start_indices[:, 0] = mask_indices[:, 0]
285
+
286
+ sentinel_ids = np.where(start_indices != 0, np.cumsum(start_indices, axis=-1), start_indices)
287
+ sentinel_ids = np.where(sentinel_ids != 0, (sentinel_ids + self.tokenizer.vocab_size - 1), 0)
288
+ sentinel_ids -= mask_indices - start_indices
289
+
290
+ return sentinel_ids
291
+
292
+ def filter_input_ids(self, input_ids, sentinel_ids):
293
+ """
294
+ Puts the sentinel mask on `input_ids` and fuses consecutive mask tokens into a single mask token by deleting them.
295
+ This will reduce the sequence length from `expanded_inputs_length` to `input_length`.
296
+ """
297
+ batch_size = input_ids.shape[0]
298
+
299
+ input_ids_full = np.where(sentinel_ids != 0, sentinel_ids, input_ids)
300
+ input_ids = input_ids_full[input_ids_full > 0].reshape((batch_size, -1))
301
+ input_ids = np.concatenate(
302
+ [input_ids, np.full((batch_size, 1), self.tokenizer.eos_token_id, dtype=np.int32)], axis=-1
303
+ )
304
+ return input_ids
305
+
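For the run.sh settings this filtering is what brings the 568 expanded tokens back down to the advertised lengths: 85 masked tokens fall into 28 spans, each span collapses to one sentinel, and both sides gain an EOS token. A small arithmetic check (numbers follow from the defaults above):

expanded, masked, spans = 568, 85, 28
assert (expanded - masked) + spans + 1 == 512   # filtered input_ids: kept tokens + sentinels + EOS
assert masked + spans + 1 == 114                # labels: masked tokens + sentinels + EOS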
306
+ def random_spans_noise_mask(self, length):
307
+
308
+ """This function is copy of `random_spans_helper <https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2682>`__ .
309
+ Noise mask consisting of random spans of noise tokens.
310
+ The number of noise tokens and the number of noise spans and non-noise spans
311
+ are determined deterministically as follows:
312
+ num_noise_tokens = round(length * noise_density)
313
+ num_nonnoise_spans = num_noise_spans = round(num_noise_tokens / mean_noise_span_length)
314
+ Spans alternate between non-noise and noise, beginning with non-noise.
315
+ Subject to the above restrictions, all masks are equally likely.
316
+ Args:
317
+ length: an int32 scalar (length of the incoming token sequence)
318
+ noise_density: a float - approximate density of output mask
319
+ mean_noise_span_length: a number
320
+ Returns:
321
+ a boolean tensor with shape [length]
322
+ """
323
+
324
+ orig_length = length
325
+
326
+ num_noise_tokens = int(np.round(length * self.noise_density))
327
+ # avoid degeneracy by ensuring positive numbers of noise and nonnoise tokens.
328
+ num_noise_tokens = min(max(num_noise_tokens, 1), length - 1)
329
+ num_noise_spans = int(np.round(num_noise_tokens / self.mean_noise_span_length))
330
+
331
+ # avoid degeneracy by ensuring positive number of noise spans
332
+ num_noise_spans = max(num_noise_spans, 1)
333
+ num_nonnoise_tokens = length - num_noise_tokens
334
+
335
+ # pick the lengths of the noise spans and the non-noise spans
336
+ def _random_segmentation(num_items, num_segments):
337
+ """Partition a sequence of items randomly into non-empty segments.
338
+ Args:
339
+ num_items: an integer scalar > 0
340
+ num_segments: an integer scalar in [1, num_items]
341
+ Returns:
342
+ a Tensor with shape [num_segments] containing positive integers that add
343
+ up to num_items
344
+ """
345
+ mask_indices = np.arange(num_items - 1) < (num_segments - 1)
346
+ np.random.shuffle(mask_indices)
347
+ first_in_segment = np.pad(mask_indices, [[1, 0]])
348
+ segment_id = np.cumsum(first_in_segment)
349
+ segment_length = np.asarray(jax.ops.segment_sum(np.ones_like(segment_id), segment_id))
350
+ return segment_length
351
+
352
+ noise_span_lengths = _random_segmentation(num_noise_tokens, num_noise_spans)
353
+ nonnoise_span_lengths = _random_segmentation(num_nonnoise_tokens, num_noise_spans)
354
+
355
+ interleaved_span_lengths = np.reshape(
356
+ np.stack([nonnoise_span_lengths, noise_span_lengths], axis=1), [num_noise_spans * 2]
357
+ )
358
+ span_starts = np.cumsum(interleaved_span_lengths)[:-1]
359
+ span_start_indicator = np.zeros((length,), dtype=np.int8)
360
+ span_start_indicator[span_starts] = True
361
+ span_num = np.cumsum(span_start_indicator)
362
+ is_noise = np.equal(span_num % 2, 1)
363
+
364
+ return is_noise[:orig_length]
365
+
366
+
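The mask can be exercised on its own, outside the training loop. A hedged sketch: the t5-small tokenizer from the Hub is only a stand-in (any T5-style tokenizer works), and the pad/decoder-start ids are illustrative:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("t5-small")   # stand-in tokenizer, not the one in this repo
collator = FlaxDataCollatorForT5MLM(
    tokenizer=tok, noise_density=0.15, mean_noise_span_length=3.0,
    input_length=512, target_length=114, pad_token_id=0, decoder_start_token_id=0,
)
mask = collator.random_spans_noise_mask(568)
assert mask.sum() == 85                           # ~15% of 568 tokens, grouped into 28 alternating spans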
367
+ def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
368
+ num_samples = len(samples_idx)
369
+ samples_to_remove = num_samples % batch_size
370
+
371
+ if samples_to_remove != 0:
372
+ samples_idx = samples_idx[:-samples_to_remove]
373
+ sections_split = num_samples // batch_size
374
+ batch_idx = np.split(samples_idx, sections_split)
375
+ return batch_idx
376
+
377
+
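Note that the helper silently drops whatever does not fill a complete batch; a toy illustration with made-up numbers:

batches = generate_batch_splits(np.arange(10), batch_size=4)
assert len(batches) == 2 and all(len(b) == 4 for b in batches)   # the last 2 indices are dropped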
378
+ def write_train_metric(summary_writer, train_metrics, train_time, step):
379
+ summary_writer.scalar("train_time", train_time, step)
380
+
381
+ train_metrics = get_metrics(train_metrics)
382
+ for key, vals in train_metrics.items():
383
+ tag = f"train_{key}"
384
+ for i, val in enumerate(vals):
385
+ summary_writer.scalar(tag, val, step - len(vals) + i + 1)
386
+
387
+
388
+ def write_eval_metric(summary_writer, eval_metrics, step):
389
+ for metric_name, value in eval_metrics.items():
390
+ summary_writer.scalar(f"eval_{metric_name}", value, step)
391
+
392
+
393
+ if __name__ == "__main__":
394
+ # See all possible arguments in src/transformers/training_args.py
395
+ # or by passing the --help flag to this script.
396
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
397
+
398
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
399
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
400
+ # If we pass only one argument to the script and it's the path to a json file,
401
+ # let's parse it to get our arguments.
402
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
403
+ else:
404
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
405
+
406
+ if (
407
+ os.path.exists(training_args.output_dir)
408
+ and os.listdir(training_args.output_dir)
409
+ and training_args.do_train
410
+ and not training_args.overwrite_output_dir
411
+ ):
412
+ raise ValueError(
413
+ f"Output directory ({training_args.output_dir}) already exists and is not empty."
414
+ "Use --overwrite_output_dir to overcome."
415
+ )
416
+
417
+ # Setup logging
418
+ logging.basicConfig(
419
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
420
+ level="NOTSET",
421
+ datefmt="[%X]",
422
+ )
423
+
424
+ # Log on each process the small summary:
425
+ logger = logging.getLogger(__name__)
426
+
427
+ # Set the verbosity to info of the Transformers logger (on main process only):
428
+ logger.info(f"Training/evaluation parameters {training_args}")
429
+
430
+ # Set seed before initializing model.
431
+ set_seed(training_args.seed)
432
+
433
+ # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
434
+ # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
435
+ # (the dataset will be downloaded automatically from the datasets Hub).
436
+ #
437
+ # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
438
+ # 'text' is found. You can easily tweak this behavior (see below).
439
+ if data_args.dataset_name is not None:
440
+ # Downloading and loading a dataset from the hub.
441
+ datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
442
+
443
+ if "validation" not in datasets.keys():
444
+ datasets["validation"] = load_dataset(
445
+ data_args.dataset_name,
446
+ data_args.dataset_config_name,
447
+ split=f"train[:{data_args.validation_split_percentage}%]",
448
+ cache_dir=model_args.cache_dir,
449
+ )
450
+ datasets["train"] = load_dataset(
451
+ data_args.dataset_name,
452
+ data_args.dataset_config_name,
453
+ split=f"train[{data_args.validation_split_percentage}%:]",
454
+ cache_dir=model_args.cache_dir,
455
+ )
456
+ else:
457
+ data_files = {}
458
+ if data_args.train_file is not None:
459
+ data_files["train"] = data_args.train_file
460
+ if data_args.validation_file is not None:
461
+ data_files["validation"] = data_args.validation_file
462
+ extension = data_args.train_file.split(".")[-1]
463
+ if extension == "txt":
464
+ extension = "text"
465
+ datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
466
+
467
+ # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
468
+ # https://huggingface.co/docs/datasets/loading_datasets.html.
469
+
470
+ # Load pretrained model and tokenizer
471
+
472
+ if model_args.tokenizer_name:
473
+ tokenizer = AutoTokenizer.from_pretrained(
474
+ model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
475
+ )
476
+ elif model_args.model_name_or_path:
477
+ tokenizer = AutoTokenizer.from_pretrained(
478
+ model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
479
+ )
480
+ else:
481
+ raise ValueError(
482
+ "You are instantiating a new tokenizer from scratch. This is not supported by this script."
483
+ "You can do it from another script, save it, and load it from here, using --tokenizer_name."
484
+ )
485
+
486
+ if model_args.config_name:
487
+ config = T5Config.from_pretrained(
488
+ model_args.config_name, cache_dir=model_args.cache_dir, vocab_size=len(tokenizer)
489
+ )
490
+ elif model_args.model_name_or_path:
491
+ config = T5Config.from_pretrained(
492
+ model_args.model_name_or_path, cache_dir=model_args.cache_dir, vocab_size=len(tokenizer)
493
+ )
494
+ else:
495
+ config = CONFIG_MAPPING[model_args.model_type]()
496
+ logger.warning("You are instantiating a new config instance from scratch.")
497
+
498
+ # Preprocessing the datasets.
499
+ # First we tokenize all the texts.
500
+ if training_args.do_train:
501
+ column_names = datasets["train"].column_names
502
+ else:
503
+ column_names = datasets["validation"].column_names
504
+ text_column_name = "text" if "text" in column_names else column_names[0]
505
+
506
+ max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
507
+
508
+ # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
509
+ # Since we make sure that all sequences are of the same length, no attention_mask is needed.
510
+ def tokenize_function(examples):
511
+ return tokenizer(examples[text_column_name], return_attention_mask=False)
512
+
513
+ tokenized_datasets = datasets.map(
514
+ tokenize_function,
515
+ batched=True,
516
+ num_proc=data_args.preprocessing_num_workers,
517
+ remove_columns=column_names,
518
+ load_from_cache_file=not data_args.overwrite_cache,
519
+ )
520
+
521
+ # T5-like span masked language modeling will fuse consecutively masked tokens to a single sentinel token.
522
+ # To ensure that the input length is `max_seq_length`, we need to increase the maximum length
523
+ # according to `mlm_probability` and `mean_noise_span_length`. We can also define the label length accordingly.
524
+ expanded_inputs_length, targets_length = compute_input_and_target_lengths(
525
+ inputs_length=max_seq_length,
526
+ noise_density=data_args.mlm_probability,
527
+ mean_noise_span_length=data_args.mean_noise_span_length,
528
+ )
529
+
530
+ # Main data processing function that will concatenate all texts from our dataset and generate chunks of expanded_inputs_length.
531
+ def group_texts(examples):
532
+ # Concatenate all texts.
533
+ concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
534
+ total_length = len(concatenated_examples[list(examples.keys())[0]])
535
+ # We drop the small remainder; we could pad instead of dropping if the model supported it. You can
536
+ # customize this part to your needs.
537
+ if total_length >= expanded_inputs_length:
538
+ total_length = (total_length // expanded_inputs_length) * expanded_inputs_length
539
+ # Split by chunks of max_len.
540
+ result = {
541
+ k: [t[i : i + expanded_inputs_length] for i in range(0, total_length, expanded_inputs_length)]
542
+ for k, t in concatenated_examples.items()
543
+ }
544
+ return result
545
+
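The chunking is easiest to see on a tiny example (hypothetical ids, with expanded_inputs_length shrunk to 4 for readability):

# {"input_ids": [[1, 2, 3], [4, 5, 6, 7, 8]]}
#   -> concatenated to [1, 2, 3, 4, 5, 6, 7, 8]
#   -> re-split into {"input_ids": [[1, 2, 3, 4], [5, 6, 7, 8]]}
# A ninth token would be dropped as the remainder.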
546
+ # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
547
+ # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
548
+ # might be slower to preprocess.
549
+ #
550
+ # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
551
+ # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
552
+ tokenized_datasets = tokenized_datasets.map(
553
+ group_texts,
554
+ batched=True,
555
+ num_proc=data_args.preprocessing_num_workers,
556
+ load_from_cache_file=not data_args.overwrite_cache,
557
+ )
558
+
559
+ # Enable tensorboard only on the master node
560
+ has_tensorboard = is_tensorboard_available()
561
+ if has_tensorboard and jax.process_index() == 0:
562
+ try:
563
+ from flax.metrics.tensorboard import SummaryWriter
564
+
565
+ summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
566
+ except ImportError as ie:
567
+ has_tensorboard = False
568
+ logger.warning(
569
+ f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
570
+ )
571
+ else:
572
+ logger.warning(
573
+ "Unable to display metrics through TensorBoard because the package is not installed: "
574
+ "Please run pip install tensorboard to enable."
575
+ )
576
+
577
+ # Initialize our training
578
+ rng = jax.random.PRNGKey(training_args.seed)
579
+ dropout_rngs = jax.random.split(rng, jax.local_device_count())
580
+
581
+ if model_args.model_name_or_path:
582
+ model = FlaxT5ForConditionalGeneration.from_pretrained(
583
+ model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
584
+ )
585
+ else:
586
+ model = FlaxT5ForConditionalGeneration(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
587
+
588
+ # Data collator
589
+ # This one will take care of randomly masking the tokens.
590
+ data_collator = FlaxDataCollatorForT5MLM(
591
+ tokenizer=tokenizer,
592
+ noise_density=data_args.mlm_probability,
593
+ mean_noise_span_length=data_args.mean_noise_span_length,
594
+ input_length=max_seq_length,
595
+ target_length=targets_length,
596
+ pad_token_id=model.config.pad_token_id,
597
+ decoder_start_token_id=model.config.decoder_start_token_id,
598
+ )
599
+
600
+ # Store some constant
601
+ num_epochs = int(training_args.num_train_epochs)
602
+ train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
603
+ eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
604
+
605
+ num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs
606
+
607
+ # Create learning rate schedule
608
+ warmup_fn = optax.linear_schedule(
609
+ init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
610
+ )
611
+ decay_fn = optax.linear_schedule(
612
+ init_value=training_args.learning_rate,
613
+ end_value=0,
614
+ transition_steps=num_train_steps - training_args.warmup_steps,
615
+ )
616
+ linear_decay_lr_schedule_fn = optax.join_schedules(
617
+ schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]
618
+ )
619
+
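A spot check of the joined schedule: with the run.sh flags the learning rate ramps from 0 to 1e-3 over the first 5000 steps and then decays linearly to 0 at num_train_steps (which depends on the dataset size):

for step in (0, training_args.warmup_steps, num_train_steps):
    print(int(step), float(linear_decay_lr_schedule_fn(step)))   # 0.0 -> learning_rate -> 0.0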
620
+ # We use Optax's "masking" functionality to not apply weight decay
621
+ # to bias and LayerNorm scale parameters. decay_mask_fn returns a
622
+ # mask boolean with the same structure as the parameters.
623
+ # The mask is True for parameters that should be decayed.
624
+ def decay_mask_fn(params):
625
+ flat_params = traverse_util.flatten_dict(params)
626
+ flat_mask = {
627
+ path: (path[-1] != "bias" and path[-2:] not in [("layer_norm", "scale"), ("final_layer_norm", "scale")])
628
+ for path in flat_params
629
+ }
630
+ return traverse_util.unflatten_dict(flat_mask)
631
+
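The mask is easiest to read on a toy parameter tree (the names below are illustrative, not the real T5 parameter paths):

toy_params = {"block": {"layer_norm": {"scale": 1.0}, "dense": {"kernel": 1.0, "bias": 1.0}}}
assert decay_mask_fn(toy_params) == {
    "block": {"layer_norm": {"scale": False}, "dense": {"kernel": True, "bias": False}}
}

Note that run.sh passes --adafactor, so the adamw branch below (and with it this mask and --weight_decay) is only exercised when Adafactor is disabled.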
632
+ # create the optimizer: Adafactor if --adafactor is passed, AdamW otherwise
633
+ if training_args.adafactor:
634
+ # We use the default parameters here to initialize adafactor,
635
+ # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
636
+ optimizer = optax.adafactor(
637
+ learning_rate=linear_decay_lr_schedule_fn,
638
+ )
639
+ else:
640
+ optimizer = optax.adamw(
641
+ learning_rate=linear_decay_lr_schedule_fn,
642
+ b1=training_args.adam_beta1,
643
+ b2=training_args.adam_beta2,
644
+ weight_decay=training_args.weight_decay,
645
+ mask=decay_mask_fn,
646
+ )
647
+
648
+ # Setup train state
649
+ state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer)
650
+
651
+ # Define gradient update step fn
652
+ def train_step(state, batch, dropout_rng):
653
+ dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
654
+
655
+ def loss_fn(params):
656
+ labels = batch.pop("labels")
657
+
658
+ logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
659
+
660
+ # compute loss
661
+ loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])).mean()
662
+
663
+ return loss
664
+
665
+ grad_fn = jax.value_and_grad(loss_fn)
666
+ loss, grad = grad_fn(state.params)
667
+ grad = jax.lax.pmean(grad, "batch")
668
+ new_state = state.apply_gradients(grads=grad)
669
+
670
+ metrics = jax.lax.pmean(
671
+ {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch"
672
+ )
673
+
674
+ return new_state, metrics, new_dropout_rng
675
+
676
+ # Create parallel version of the train step
677
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
678
+
679
+ # Define eval fn
680
+ def eval_step(params, batch):
681
+ labels = batch.pop("labels")
682
+
683
+ logits = model(**batch, params=params, train=False)[0]
684
+
685
+ # compute loss
686
+ loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1]))
687
+
688
+ # compute accuracy
689
+ accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels)
690
+
691
+ # summarize metrics
692
+ metrics = {"loss": loss.mean(), "accuracy": accuracy.mean()}
693
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
694
+
695
+ return metrics
696
+
697
+ p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0,))
698
+
699
+ # Replicate the train state on each device
700
+ state = jax_utils.replicate(state)
701
+
702
+ train_time = 0
703
+ epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
704
+ for epoch in epochs:
705
+ # ======================== Training ================================
706
+ train_start = time.time()
707
+ train_metrics = []
708
+
709
+ # Create sampling rng
710
+ rng, input_rng = jax.random.split(rng)
711
+
712
+ # Generate an epoch by shuffling sampling indices from the train dataset
713
+ num_train_samples = len(tokenized_datasets["train"])
714
+ train_samples_idx = jax.random.permutation(input_rng, jnp.arange(num_train_samples))
715
+ train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
716
+ err_cnt = 0
717
+ # Gather the indexes for creating the batch and do a training step
718
+ for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
719
+ try:
720
+ samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
721
+ except Exception:
722
+ err_cnt += 1
723
+ continue
724
+ model_inputs = data_collator(samples)
725
+
726
+ # Model forward
727
+ model_inputs = shard(model_inputs.data)
728
+ state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
729
+ train_metrics.append(train_metric)
730
+
731
+ cur_step = epoch * (num_train_samples // train_batch_size) + step
732
+
733
+ if cur_step % training_args.logging_steps == 0 and cur_step > 0:
734
+ # Save metrics
735
+ train_metric = jax_utils.unreplicate(train_metric)
736
+ train_time += time.time() - train_start
737
+ if has_tensorboard and jax.process_index() == 0:
738
+ write_train_metric(summary_writer, train_metrics, train_time, cur_step)
739
+
740
+ epochs.write(
741
+ f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
742
+ )
743
+ epochs.write("total error sample {}".format(err_cnt))
744
+ train_metrics = []
745
+
746
+ if cur_step % training_args.save_steps == 0:
747
+ logger.info("Model saved")
748
+ # save checkpoint every save_steps and optionally push it to the hub
749
+ if jax.process_index() == 0:
750
+ params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
751
+ model.save_pretrained(
752
+ training_args.output_dir,
753
+ params=params,
754
+ push_to_hub=training_args.push_to_hub,
755
+ commit_message=f"Saving weights and logs of step {cur_step}",
756
+ )
757
+ if cur_step % training_args.eval_steps == 0:
758
+ # ======================== Evaluating ==============================
759
+ num_eval_samples = len(tokenized_datasets["validation"])
760
+ eval_samples_idx = jnp.arange(num_eval_samples)
761
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
762
+
763
+ eval_metrics = []
764
+ eval_err_cnt = 0
765
+ for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
766
+ try:
767
+ samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
768
+ except Exception:
769
+ eval_err_cnt += 1
770
+ continue
771
+ model_inputs = data_collator(samples)
772
+
773
+ # Model forward
774
+ model_inputs = shard(model_inputs.data)
775
+ metrics = p_eval_step(state.params, model_inputs)
776
+ eval_metrics.append(metrics)
777
+
778
+ # get eval metrics
779
+ eval_metrics = get_metrics(eval_metrics)
780
+ eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
781
+
782
+ # Update progress bar
783
+ epochs.write(f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})")
784
+ epochs.write("Eval error sample {}".format(eval_err_cnt))
785
+
786
+ # Save metrics
787
+ if has_tensorboard and jax.process_index() == 0:
788
+ write_eval_metric(summary_writer, eval_metrics, cur_step)
789
+
790
+
tokenizer/config.json ADDED
@@ -0,0 +1,56 @@
+ {
+   "architectures": [
+     "T5WithLMHeadModel"
+   ],
+   "d_ff": 2048,
+   "d_kv": 64,
+   "d_model": 512,
+   "decoder_start_token_id": 0,
+   "dropout_rate": 0.1,
+   "eos_token_id": 1,
+   "feed_forward_proj": "relu",
+   "gradient_checkpointing": false,
+   "initializer_factor": 1.0,
+   "is_encoder_decoder": true,
+   "layer_norm_epsilon": 1e-06,
+   "model_type": "t5",
+   "n_positions": 512,
+   "num_decoder_layers": 6,
+   "num_heads": 8,
+   "num_layers": 6,
+   "output_past": true,
+   "pad_token_id": 0,
+   "relative_attention_num_buckets": 32,
+   "task_specific_params": {
+     "summarization": {
+       "early_stopping": true,
+       "length_penalty": 2.0,
+       "max_length": 200,
+       "min_length": 30,
+       "no_repeat_ngram_size": 3,
+       "num_beams": 4,
+       "prefix": "summarize: "
+     },
+     "translation_en_to_de": {
+       "early_stopping": true,
+       "max_length": 300,
+       "num_beams": 4,
+       "prefix": "translate English to German: "
+     },
+     "translation_en_to_fr": {
+       "early_stopping": true,
+       "max_length": 300,
+       "num_beams": 4,
+       "prefix": "translate English to French: "
+     },
+     "translation_en_to_ro": {
+       "early_stopping": true,
+       "max_length": 300,
+       "num_beams": 4,
+       "prefix": "translate English to Romanian: "
+     }
+   },
+   "transformers_version": "4.9.0.dev0",
+   "use_cache": true,
+   "vocab_size": 32128
+ }
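For reference, the committed artifacts would typically be loaded as below; run.sh points --config_name at t5-base rather than at this file, and the training script overrides vocab_size with len(tokenizer). A hedged sketch, assuming it is run from the directory containing tokenizer/:

from transformers import AutoTokenizer, T5Config

tokenizer = AutoTokenizer.from_pretrained("tokenizer")                   # tokenizer/tokenizer.json (+ config.json)
config = T5Config.from_pretrained("t5-base", vocab_size=len(tokenizer))  # mirrors run.sh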
tokenizer/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff