feat: handle streaming
- dev/seq2seq/do_big_run.sh +6 -2
- dev/seq2seq/do_small_run.sh +6 -3
- dev/seq2seq/run_seq2seq_flax.py +402 -268
dev/seq2seq/do_big_run.sh
CHANGED
@@ -1,7 +1,11 @@
 python run_seq2seq_flax.py \
     --max_source_length 128 \
-    --
-    --
+    --dataset_repo_or_path dalle-mini/encoded \
+    --train_file **/train/*/*.jsonl \
+    --validation_file **/valid/*/*.jsonl \
+    --streaming \
+    --len_train 1000000 \
+    --len_eval 100 \
     --output_dir output \
     --per_device_train_batch_size 56 \
     --per_device_eval_batch_size 56 \
dev/seq2seq/do_small_run.sh
CHANGED
@@ -1,7 +1,10 @@
 python run_seq2seq_flax.py \
-    --
-    --train_file
-    --validation_file
+    --dataset_repo_or_path dalle-mini/encoded \
+    --train_file **/train/*/*.jsonl \
+    --validation_file **/valid/*/*.jsonl \
+    --streaming \
+    --len_train 1000000 \
+    --len_eval 1000 \
     --output_dir output \
     --per_device_train_batch_size 56 \
     --per_device_eval_batch_size 56 \
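With --streaming, the encoded dataset is read as an iterable: shards stay on the Hub and examples are yielded lazily, so the split has no length. That is why the runs above now pass --len_train / --len_eval explicitly. A minimal sketch of the loading side, assuming only the repository id and glob patterns used in these scripts (the column names shown are the script defaults):

from datasets import load_dataset

# Streamed loading: nothing is downloaded and cached up front.
data_files = {"train": "**/train/*/*.jsonl", "validation": "**/valid/*/*.jsonl"}
dataset = load_dataset("dalle-mini/encoded", data_files=data_files, streaming=True)

first = next(iter(dataset["train"]))   # e.g. {"caption": ..., "encoding": [...]}
# len(dataset["train"]) raises TypeError: a streamed split exposes no length,
# hence the new --len_train / --len_eval flags.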
dev/seq2seq/run_seq2seq_flax.py
CHANGED
@@ -20,9 +20,8 @@ Script adapted from run_summarization_flax.py
 # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
 
 import os
-import logging as pylogging
+import logging as pylogging  # To avoid collision with transformers.utils.logging
 import sys
-import time
 from dataclasses import dataclass, field
 from functools import partial
 from pathlib import Path
@@ -30,7 +29,6 @@ from typing import Callable, Optional
 import json
 
 import datasets
-import nltk  # Here to have a nice missing dependency error message early on
 import numpy as np
 from datasets import Dataset, load_dataset, load_metric
 from tqdm import tqdm
@@ -47,9 +45,7 @@ from flax.jax_utils import unreplicate
 from flax.training import train_state
 from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
 from transformers import (
-    CONFIG_MAPPING,
     FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
-    AutoConfig,
     AutoTokenizer,
     FlaxAutoModelForSeq2SeqLM,
     FlaxBartForConditionalGeneration,
@@ -61,17 +57,9 @@ from transformers.file_utils import is_offline_mode
 
 import wandb
 
-try:
-    nltk.data.find("tokenizers/punkt")
-except (LookupError, OSError):
-    if is_offline_mode():
-        raise LookupError(
-            "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files"
-        )
-    with FileLock(".lock") as lock:
-        nltk.download("punkt", quiet=True)
+from dalle_mini.text import TextNormalizer
 
+logger = pylogging.getLogger(__name__)
 
 
 MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys())
@@ -83,7 +71,7 @@ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
 OUTPUT_VOCAB_SIZE = 16384 + 1  # encoded image token space + 1 for bos
 OUTPUT_LENGTH = 256 + 1  # number of encoded tokens + 1 for bos
 BOS_TOKEN_ID = 16384
-BASE_MODEL =
+BASE_MODEL = "facebook/bart-large-cnn"  # we currently have issues with bart-large
 
 
 @dataclass
@@ -101,20 +89,34 @@ class ModelArguments:
     )
     model_type: Optional[str] = field(
         default=None,
-        metadata={
+        metadata={
+            "help": "If training from scratch, pass a model type from the list: "
+            + ", ".join(MODEL_TYPES)
+        },
     )
     config_name: Optional[str] = field(
-        default=None,
+        default=None,
+        metadata={
+            "help": "Pretrained config name or path if not the same as model_name"
+        },
     )
     tokenizer_name: Optional[str] = field(
-        default=None,
+        default=None,
+        metadata={
+            "help": "Pretrained tokenizer name or path if not the same as model_name"
+        },
    )
     cache_dir: Optional[str] = field(
-        default=None,
+        default=None,
+        metadata={
+            "help": "Where do you want to store the pretrained models downloaded from s3"
+        },
     )
     use_fast_tokenizer: bool = field(
         default=True,
-        metadata={
+        metadata={
+            "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."
+        },
     )
     dtype: Optional[str] = field(
         default="float32",
@@ -137,27 +139,51 @@ class DataTrainingArguments:
     """
 
     dataset_name: Optional[str] = field(
-        default=None,
+        default=None,
+        metadata={"help": "The name of the dataset to use (via the datasets library)."},
     )
     dataset_config_name: Optional[str] = field(
-        default=None,
+        default=None,
+        metadata={
+            "help": "The configuration name of the dataset to use (via the datasets library)."
+        },
     )
     text_column: Optional[str] = field(
-        default=
-        metadata={
+        default="caption",
+        metadata={
+            "help": "The name of the column in the datasets containing the full texts (for summarization)."
+        },
     )
     encoding_column: Optional[str] = field(
-        default=
-        metadata={
+        default="encoding",
+        metadata={
+            "help": "The name of the column in the datasets containing the image encodings."
+        },
+    )
+    dataset_repo_or_path: Optional[str] = field(
+        default=None,
+        metadata={"help": "The dataset repository containing encoded files."},
+    )
-    train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
+    train_file: Optional[str] = field(
+        default=None, metadata={"help": "The input training data file (a text file)."}
+    )
     validation_file: Optional[str] = field(
         default=None,
-        metadata={
+        metadata={
+            "help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."
+        },
+    )
+    streaming: bool = field(
+        default=False,
+        metadata={"help": "Whether to stream the dataset."},
+    )
+    len_train: Optional[int] = field(
+        default=None,
+        metadata={"help": "Length of training dataset, required for streaming"},
+    )
+    len_eval: Optional[int] = field(
+        default=None,
+        metadata={"help": "Length of validation dataset, required for streaming"},
     )
     max_source_length: Optional[int] = field(
         default=128,
@@ -167,7 +193,8 @@ class DataTrainingArguments:
         },
     )
     no_decay: bool = field(
-        default=False,
+        default=False,
+        metadata={"help": "Whether to use decay in the learning rate scheduler."},
     )
     max_target_length: Optional[int] = field(
         default=OUTPUT_LENGTH,
@@ -199,60 +226,65 @@ class DataTrainingArguments:
             "value if set."
         },
     )
-        default=
-        metadata={
-            "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
-            "value if set."
-        },
+    normalize_text: bool = field(
+        default=False,
+        metadata={"help": "Normalize/Simplify text"},
     )
     preprocessing_num_workers: Optional[int] = field(
-        default=80,
+        default=80,  # ensure we have the same datasets cached data and avoid using too much space
         metadata={"help": "The number of processes to use for the preprocessing."},
     )
     source_prefix: Optional[str] = field(
-        default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
-    )
-    predict_with_generate: bool = field(
-        default=False, metadata={"help": "Whether to use generate to calculate generative metrics."}
-    )
-    num_beams: Optional[int] = field(
         default=None,
         metadata={
-            "help": "
-            "which is used during evaluation."
+            "help": "A prefix to add before every source text (useful for T5 models)."
         },
     )
     overwrite_cache: bool = field(
-        default=False,
+        default=False,
+        metadata={"help": "Overwrite the cached training and evaluation sets"},
     )
     log_interval: Optional[int] = field(
         default=40,
-        metadata={
-            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
-            "value if set."
-        },
+        metadata={"help": "Log frequency for metrics"},
     )
     log_model: bool = field(
-        default=False,
+        default=False,
+        metadata={"help": "Overwrite the cached training and evaluation sets"},
     )
     save_model_steps: Optional[int] = field(
-        default=3000,
+        default=3000,  # about once every hour in our experiments
         metadata={
             "help": "For logging the model more frequently. Used only when `log_model` is set."
         },
     )
 
     def __post_init__(self):
-        if
+        if (
+            self.dataset_name is None
+            and self.train_file is None
+            and self.validation_file is None
+        ):
+            raise ValueError(
+                "Need either a dataset name or a training/validation file."
+            )
         else:
             if self.train_file is not None:
                 extension = self.train_file.split(".")[-1]
-                assert extension in [
+                assert extension in [
+                    "tsv",
+                    "csv",
+                    "json",
+                    "jsonl",
+                ], "`train_file` should be a tsv, csv or json file."
             if self.validation_file is not None:
                 extension = self.validation_file.split(".")[-1]
-                assert extension in [
+                assert extension in [
+                    "tsv",
+                    "csv",
+                    "json",
+                    "jsonl",
+                ], "`validation_file` should be a tsv, csv or json file."
             if self.val_max_target_length is None:
                 self.val_max_target_length = self.max_target_length
 
@@ -263,14 +295,20 @@ class TrainState(train_state.TrainState):
     optimizer_step: int
 
     def replicate(self):
-        return jax_utils.replicate(self).replace(
+        return jax_utils.replicate(self).replace(
+            dropout_rng=shard_prng_key(self.dropout_rng)
+        )
 
 
 class CustomFlaxBartModule(FlaxBartModule):
     def setup(self):
         # check config is valid, otherwise set default values
-        self.config.vocab_size_output = getattr(
+        self.config.vocab_size_output = getattr(
+            self.config, "vocab_size_output", OUTPUT_VOCAB_SIZE
+        )
+        self.config.max_position_embeddings_decoder = getattr(
+            self.config, "max_position_embeddings_decoder", OUTPUT_LENGTH
+        )
 
         # we keep shared to easily load pre-trained weights
         self.shared = nn.Embed(
@@ -286,18 +324,29 @@ class CustomFlaxBartModule(FlaxBartModule):
             embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
             dtype=self.dtype,
         )
-        self.encoder = FlaxBartEncoder(
+        self.encoder = FlaxBartEncoder(
+            self.config, dtype=self.dtype, embed_tokens=self.shared
+        )
 
         # the decoder has a different config
         decoder_config = BartConfig(self.config.to_dict())
-        decoder_config.max_position_embeddings =
+        decoder_config.max_position_embeddings = (
+            self.config.max_position_embeddings_decoder
+        )
         decoder_config.vocab_size = self.config.vocab_size_output
-        self.decoder = FlaxBartDecoder(
+        self.decoder = FlaxBartDecoder(
+            decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed
+        )
 
 
-class CustomFlaxBartForConditionalGenerationModule(
+class CustomFlaxBartForConditionalGenerationModule(
+    FlaxBartForConditionalGenerationModule
+):
     def setup(self):
         # check config is valid, otherwise set default values
-        self.config.vocab_size_output = getattr(
+        self.config.vocab_size_output = getattr(
+            self.config, "vocab_size_output", OUTPUT_VOCAB_SIZE
+        )
 
         self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)
         self.lm_head = nn.Dense(
@@ -306,13 +355,18 @@ class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerat
             dtype=self.dtype,
             kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
         )
-        self.final_logits_bias = self.param(
+        self.final_logits_bias = self.param(
+            "final_logits_bias", self.bias_init, (1, self.config.vocab_size_output)
+        )
 
 
 class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):
     module_class = CustomFlaxBartForConditionalGenerationModule
 
 
-def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False):
+def data_loader(
+    rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False
+):
     """
     Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
     Shuffle batches if `shuffle` is `True`.
@@ -330,33 +384,58 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf
     for idx in batch_idx:
         batch = dataset[idx]
         batch = {k: jnp.array(v) for k, v in batch.items()}
-
         batch = shard(batch)
-
         yield batch
 
 
+def data_loader_streaming(dataset: Dataset, batch_size: int):
+    keys = ["input_ids", "attention_mask", "labels", "decoder_input_ids"]
+    batch = {k: [] for k in keys}
+    for item in dataset:
+        for k, v in item.items():
+            batch[k].append(v)
+        if len(batch[keys[0]]) == batch_size:
+            batch = {k: jnp.array(v) for k, v in batch.items()}
+            batch = shard(batch)
+            yield batch
+            batch = {k: [] for k in keys}
+
+
 def create_learning_rate_fn(
-    train_ds_size: int,
+    train_ds_size: int,
+    train_batch_size: int,
+    num_train_epochs: int,
+    num_warmup_steps: int,
+    learning_rate: float,
+    no_decay: bool,
 ) -> Callable[[int], jnp.array]:
     """Returns a linear warmup, linear_decay learning rate function."""
     steps_per_epoch = train_ds_size // train_batch_size
     num_train_steps = steps_per_epoch * num_train_epochs
-    warmup_fn = optax.linear_schedule(
+    warmup_fn = optax.linear_schedule(
+        init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps
+    )
     if no_decay:
         return warmup_fn
     decay_fn = optax.linear_schedule(
-        init_value=learning_rate,
+        init_value=learning_rate,
+        end_value=0,
+        transition_steps=num_train_steps - num_warmup_steps,
+    )
-    schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
+    schedule_fn = optax.join_schedules(
+        schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]
    )
     return schedule_fn
 
 
 def wandb_log(metrics, step=None, prefix=None):
     if jax.process_index() == 0:
-        log_metrics = {
+        log_metrics = {
+            f"{prefix}/{k}" if prefix is not None else k: jax.device_get(v)
+            for k, v in metrics.items()
+        }
         if step is not None:
-            log_metrics[
+            log_metrics["train/step"] = step
         wandb.log(log_metrics)
 
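The new data_loader_streaming above accumulates examples until batch_size of them are available, converts each field to a jnp array, shards it across local devices, and starts over; a trailing partial batch is silently dropped. A self-contained sketch of that batching behaviour, with a toy iterable standing in for the streamed dataset (the helper name and shapes are illustrative, not part of the script):

import jax.numpy as jnp
from flax.training.common_utils import shard

KEYS = ("input_ids", "attention_mask", "labels", "decoder_input_ids")

def batches(items, batch_size):
    # Same idea as data_loader_streaming: fill a dict of lists, emit when full.
    batch = {k: [] for k in KEYS}
    for item in items:
        for k in KEYS:
            batch[k].append(item[k])
        if len(batch[KEYS[0]]) == batch_size:
            yield shard({k: jnp.array(v) for k, v in batch.items()})
            batch = {k: [] for k in KEYS}

toy = ({k: [i] * 4 for k in KEYS} for i in range(10))
for b in batches(toy, batch_size=8):
    print(b["input_ids"].shape)  # (num_devices, 8 // num_devices, 4); the last 2 items are dropped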
@@ -365,11 +444,15 @@ def main():
     # or by passing the --help flag to this script.
     # We now keep distinct sets of args, for a cleaner separation of concerns.
 
-    parser = HfArgumentParser(
+    parser = HfArgumentParser(
+        (ModelArguments, DataTrainingArguments, TrainingArguments)
+    )
     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
         # If we pass only one argument to the script and it's the path to a json file,
         # let's parse it to get our arguments.
-        model_args, data_args, training_args = parser.parse_json_file(
+        model_args, data_args, training_args = parser.parse_json_file(
+            json_file=os.path.abspath(sys.argv[1])
+        )
     else:
         model_args, data_args, training_args = parser.parse_args_into_dataclasses()
 
@@ -383,18 +466,18 @@ def main():
             f"Output directory ({training_args.output_dir}) already exists and is not empty."
             "Use --overwrite_output_dir to overcome."
         )
 
     # Set up wandb run
     wandb.init(
-        entity=
-        project=
-        job_type=
-        config=parser.parse_args()
+        entity="dalle-mini",
+        project="dalle-mini",
+        job_type="Seq2Seq",
+        config=parser.parse_args(),
     )
 
     # set default x-axis as 'train/step'
-    wandb.define_metric(
-    wandb.define_metric(
+    wandb.define_metric("train/step")
+    wandb.define_metric("*", step_metric="train/step")
 
     # Make one log on every process with the configuration for debugging.
     pylogging.basicConfig(
@@ -418,16 +501,13 @@ def main():
     # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
     # (the dataset will be downloaded automatically from the datasets Hub).
     #
-    data_files = {
-    dataset = load_dataset("csv", data_files=data_files, cache_dir=model_args.cache_dir, delimiter="\t")
-    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
-    # https://huggingface.co/docs/datasets/loading_datasets.html.
+    data_files = {
+        "train": data_args.train_file,
+        "validation": data_args.validation_file,
+    }
+    dataset = load_dataset(
+        data_args.dataset_repo_or_path, data_files=data_files, streaming=True
+    )
 
     # Set up items to load or create
     tokenizer = None
@@ -435,17 +515,17 @@ def main():
 
     def restore_state(state, artifact_dir):
         # restore optimizer state
-        with (Path(artifact_dir) /
+        with (Path(artifact_dir) / "opt_state.msgpack").open("rb") as f:
             opt_state = from_bytes(state.opt_state, f.read())
 
         # restore steps
-        with (Path(artifact_dir) /
+        with (Path(artifact_dir) / "training_state.json").open("r") as f:
             training_state = json.load(f)
-        step = training_state[
+        step = training_state["step"]
         optimizer_step = step // training_args.gradient_accumulation_steps
 
         return step, optimizer_step, opt_state
 
     if model_args.from_checkpoint is not None:
         artifact = wandb.run.use_artifact(model_args.from_checkpoint)
         artifact_dir = artifact.download()
@@ -461,40 +541,54 @@ def main():
         config = model.config
 
         # load tokenizer if present
-        if (Path(artifact_dir) /
+        if (Path(artifact_dir) / "tokenizer_config.json").exists():
             tokenizer = AutoTokenizer.from_pretrained(
+                model_args.model_name_or_path,
+                cache_dir=model_args.cache_dir,
+                use_fast=model_args.use_fast_tokenizer,
+            )
 
     else:
         base_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
             model_args.model_name_or_path,
+            seed=training_args.seed,
+            dtype=getattr(jnp, model_args.dtype),
         )
         # Set up our new model config
         config = BartConfig.from_pretrained(model_args.model_name_or_path)
         config.tie_word_embeddings = False
         config.decoder_start_token_id = BOS_TOKEN_ID  # for first token
-        config.bos_token_id =
+        config.bos_token_id = (
+            BOS_TOKEN_ID  # should not be used (due to forced_bos_token_id)
+        )
+        config.pos_token_id = (
+            BOS_TOKEN_ID  # should not be needed (as we generate until max_length)
+        )
         config.eos_token_id = BOS_TOKEN_ID + 1  # unreachable
         config.forced_bos_token_id = None  # we don't need this token
         config.forced_eos_token_id = None  # we don't need this token
-        config.force_bos_token_to_be_generated =
+        config.force_bos_token_to_be_generated = (
+            False  # otherwise it sets bos_token_id at loading
+        )
         config.min_length = data_args.max_target_length
         config.max_length = data_args.max_target_length
 
         # Create a custom model and initialize it randomly
-        model = CustomFlaxBartForConditionalGeneration(
+        model = CustomFlaxBartForConditionalGeneration(
+            config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
+        )
 
         # Use pre-trained weights for encoder
-        model.params[
-        model.params[
+        model.params["model"]["encoder"] = base_model.params["model"]["encoder"]
+        model.params["model"]["shared"] = base_model.params["model"]["shared"]
         del base_model
 
     # Load tokenizer if it has not been set
     if tokenizer is None:
         tokenizer = AutoTokenizer.from_pretrained(
             model_args.model_name_or_path,
+            cache_dir=model_args.cache_dir,
+            use_fast=model_args.use_fast_tokenizer,
         )
 
     print(f"TPUs: {jax.device_count()}")
@@ -504,23 +598,11 @@ def main():
 
     # Preprocessing the datasets.
     # We need to tokenize inputs and targets.
-    if training_args.do_train:
-        column_names = dataset["train"].column_names
-    elif training_args.do_eval:
-        column_names = dataset["validation"].column_names
-    elif training_args.do_predict:
-        column_names = dataset["test"].column_names
-    else:
-        logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
-        return
 
     # Get the column names for input/target.
     text_column = data_args.text_column
     encoding_column = data_args.encoding_column
 
-    # Temporarily set max_target_length for training.
-    max_target_length = data_args.max_target_length
-
     def shift_tokens_right(input_ids: np.array, decoder_start_token_id: int):
         """
         Shift input ids one token to the right.
@@ -530,18 +612,28 @@ def main():
         shifted_input_ids[:, 0] = decoder_start_token_id
         return shifted_input_ids
 
+    text_normalizer = TextNormalizer() if data_args.normalize_text else None
+
+    def normalize_text(example):
+        example[text_column] = text_normalizer(example[text_column])
+        return example
+
     def preprocess_function(examples):
         inputs = examples[text_column]
-        inputs = [prefix + inp for inp in inputs]
-
+        inputs = [prefix + inp for inp in inputs] if prefix else inputs
+        # Setting padding="max_length" as we need fixed length inputs for jitted functions
         model_inputs = tokenizer(
-            inputs,
+            inputs,
+            max_length=data_args.max_source_length,
+            padding="max_length",
+            truncation=True,
+            return_tensors="np",
         )
 
         # set up targets
         # Note: labels correspond to our target indices
         # decoder input ids are the same but shifted to the right with bos at the beginning (and without last token)
-        labels =
+        labels = examples[encoding_column]
         labels = np.asarray(labels)
 
         # We need the labels, in addition to the decoder_input_ids, for the compute_loss function
@@ -558,46 +650,75 @@ def main():
             raise ValueError("--do_train requires a train dataset")
         train_dataset = dataset["train"]
         if data_args.max_train_samples is not None:
-            train_dataset =
+            train_dataset = (
+                train_dataset.take(data_args.max_train_samples)
+                if data_args.streaming
+                else train_dataset.select(range(data_args.max_train_samples))
+            )
+        if data_args.streaming:
+            train_dataset = train_dataset.shuffle(1000, training_args.seed)
+        if data_args.normalize_text:
+            train_dataset = (
+                train_dataset.map(text_normalizer)
+                if data_args.streaming
+                else train_dataset.map(
+                    normalize_text,
+                    num_proc=data_args.preprocessing_num_workers,
+                    load_from_cache_file=not data_args.overwrite_cache,
+                    desc="Normalizing the validation dataset",
+                )
+            )
+        train_dataset = (
+            train_dataset.map(
+                preprocess_function,
+                batched=True,
+            )
+            if data_args.streaming
+            else train_dataset.map(
+                preprocess_function,
+                batched=True,
+                num_proc=data_args.preprocessing_num_workers,
+                remove_columns=train_dataset.column_names,
+                load_from_cache_file=not data_args.overwrite_cache,
+                desc="Running tokenizer on validation dataset",
+            )
         )
 
     if training_args.do_eval:
-        max_target_length = data_args.val_max_target_length
         if "validation" not in dataset:
             raise ValueError("--do_eval requires a validation dataset")
         eval_dataset = dataset["validation"]
         if data_args.max_eval_samples is not None:
-            eval_dataset =
+            eval_dataset = (
+                eval_dataset.take(data_args.max_train_samples)
+                if data_args.streaming
+                else eval_dataset.select(range(data_args.max_train_samples))
+            )
+        if data_args.normalize_text:
+            eval_dataset = (
+                eval_dataset.map(text_normalizer)
+                if data_args.streaming
+                else eval_dataset.map(
+                    normalize_text,
+                    num_proc=data_args.preprocessing_num_workers,
+                    load_from_cache_file=not data_args.overwrite_cache,
+                    desc="Normalizing the validation dataset",
+                )
+            )
+        eval_dataset = (
+            eval_dataset.map(
+                preprocess_function,
+                batched=True,
+            )
+            if data_args.streaming
+            else eval_dataset.map(
+                preprocess_function,
+                batched=True,
+                num_proc=data_args.preprocessing_num_workers,
+                remove_columns=eval_dataset.column_names,
+                load_from_cache_file=not data_args.overwrite_cache,
+                desc="Running tokenizer on validation dataset",
+            )
         )
 
     # Initialize our training
@@ -606,21 +727,40 @@ def main():
 
     # Store some constant
     num_epochs = int(training_args.num_train_epochs)
-    train_batch_size =
+    train_batch_size = (
+        int(training_args.per_device_train_batch_size) * jax.device_count()
+    )
     total_batch_size = int(train_batch_size) * training_args.gradient_accumulation_steps
     eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
+    if data_args.streaming:
+        len_train_dataset = data_args.len_train
+        if (
+            data_args.max_train_samples is not None
+            and data_args.max_train_samples < len_train_dataset
+        ):
+            len_train_dataset = data_args.max_train_samples
+
+        len_eval_dataset = data_args.len_eval
+        if (
+            data_args.max_eval_samples is not None
+            and data_args.max_eval_samples < len_eval_dataset
+        ):
+            len_eval_dataset = data_args.max_eval_samples
+    else:
+        len_train_dataset = len(train_dataset)
+        len_eval_dataset = len(eval_dataset)
+    steps_per_epoch = len_train_dataset // train_batch_size
     total_steps = steps_per_epoch * num_epochs
-    total_optimization_steps = (
+    total_optimization_steps = (len_train_dataset // total_batch_size) * num_epochs
 
     # Create learning rate schedule
     linear_decay_lr_schedule_fn = create_learning_rate_fn(
+        len_train_dataset,
         total_batch_size,
         training_args.num_train_epochs,
         training_args.warmup_steps,
         training_args.learning_rate,
-        data_args.no_decay
+        data_args.no_decay,
     )
 
     # We use Optax's "masking" functionality to not apply weight decay
@@ -633,9 +773,17 @@ def main():
     def decay_mask_fn(params):
         flat_params = traverse_util.flatten_dict(params)
         layer_norm_params = [
-            (name, "scale")
+            (name, "scale")
+            for name in [
+                "self_attn_layer_norm",
+                "layernorm_embedding",
+                "final_layer_norm",
+            ]
         ]
-        flat_mask = {
+        flat_mask = {
+            path: (path[-1] != "bias" and path[-2:] not in layer_norm_params)
+            for path in flat_params
+        }
         return traverse_util.unflatten_dict(flat_mask)
 
     # create adam optimizer
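For intuition on the schedule returned by create_learning_rate_fn (linear warmup to the peak rate, then linear decay to zero unless --no_decay is set), a small self-contained check; the step counts and peak rate below are illustrative, not values used in these runs:

import optax

learning_rate = 3e-4     # illustrative peak rate
num_warmup_steps = 100
num_train_steps = 1000   # steps_per_epoch * num_train_epochs

warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
decay_fn = optax.linear_schedule(init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps)
schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])

print(schedule_fn(0))     # 0.0
print(schedule_fn(100))   # 3e-4, end of warmup
print(schedule_fn(550))   # 1.5e-4, halfway through the decay
print(schedule_fn(1000))  # 0.0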
@@ -667,7 +815,9 @@ def main():
     if model_args.from_checkpoint is not None:
         # restore optimizer state, step and optimizer_step
         step, optimizer_step, opt_state = restore_state(state, artifact_dir)
-        state = state.replace(
+        state = state.replace(
+            step=step, optimizer_step=optimizer_step, opt_state=opt_state
+        )
 
     # label smoothed cross entropy
     def loss_fn(logits, labels):
@@ -681,7 +831,9 @@ def main():
 
         def compute_loss(params):
             labels = batch.pop("labels")
-            logits = state.apply_fn(
+            logits = state.apply_fn(
+                **batch, params=params, dropout_rng=dropout_rng, train=True
+            )[0]
             loss = loss_fn(logits, labels)
             return loss
 
@@ -690,10 +842,14 @@ def main():
         grad_accum = jax.tree_multimap(lambda x, y: x + y, grads, state.grad_accum)
 
         def update_fn():
-            grads = jax.tree_map(
+            grads = jax.tree_map(
+                lambda x: x / training_args.gradient_accumulation_steps, grad_accum
+            )
             grads = jax.lax.pmean(grads, "batch")
             new_state = state.apply_gradients(
-                grads=grads,
+                grads=grads,
+                grad_accum=jax.tree_map(jnp.zeros_like, grads),
+                optimizer_step=state.optimizer_step + 1,
             )
             return new_state
 
@@ -704,7 +860,10 @@ def main():
             None,
         )
 
-        metrics = {
+        metrics = {
+            "loss": loss,
+            "learning_rate": linear_decay_lr_schedule_fn(state.optimizer_step),
+        }
         metrics = jax.lax.pmean(metrics, axis_name="batch")
 
         return new_state.replace(dropout_rng=new_dropout_rng), metrics
@@ -720,39 +879,25 @@ def main():
         metrics = jax.lax.pmean(metrics, axis_name="batch")
         return metrics
 
-    # Define generation function
-    max_length = (
-        data_args.val_max_target_length if data_args.val_max_target_length is not None else model.config.max_length
-    )
-    num_beams = data_args.num_beams if data_args.num_beams is not None else model.config.num_beams
-    gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
-
-    def generate_step(params, batch):
-        model.params = params
-        output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], **gen_kwargs)
-        return output_ids.sequences
-
     # Create parallel version of the train and eval step
-    p_train_step = jax.pmap(
-        train_step, "batch", donate_argnums=(0,)
-    )
+    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
     p_eval_step = jax.pmap(eval_step, "batch")
-    p_generate_step = jax.pmap(generate_step, "batch")
 
     # Replicate the train state on each device
     state = state.replicate()
 
     logger.info("***** Running training *****")
-    logger.info(f"  Num examples = {
+    logger.info(f"  Num examples = {len_train_dataset}")
     logger.info(f"  Num Epochs = {num_epochs}")
-    logger.info(
+    logger.info(
+        f"  Instantaneous batch size per device = {training_args.per_device_train_batch_size}"
+    )
     logger.info(
         f"  Total train batch size (w. parallel & distributed) = {train_batch_size * training_args.gradient_accumulation_steps}"
     )
     logger.info(f"  Total global steps = {total_steps}")
     logger.info(f"  Total optimization steps = {total_optimization_steps}")
 
-    train_time = 0
     epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
     global_step = 0
 
@@ -760,31 +905,31 @@ def main():
         # ======================== Evaluating ==============================
         eval_metrics = []
         if training_args.do_eval:
-            eval_steps =
-            for
+            if data_args.streaming:
+                eval_loader = data_loader_streaming(eval_dataset, eval_batch_size)
+            else:
+                eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
+            eval_steps = len_eval_dataset // eval_batch_size
+            for batch in tqdm(
+                eval_loader,
+                desc="Evaluating...",
+                position=2,
+                leave=False,
+                total=eval_steps,
+            ):
                 # Model forward
-                batch = next(eval_loader)
-                labels = batch["labels"]
-
                 metrics = p_eval_step(state.params, batch)
                 eval_metrics.append(metrics)
 
-                # generation
-                if data_args.predict_with_generate:
-                    generated_ids = p_generate_step(state.params, batch)
-                    eval_preds.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
-                    eval_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))
-
             # normalize eval metrics
+            breakpoint()
             eval_metrics = get_metrics(eval_metrics)
+            breakpoint()
             eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
+            breakpoint()
 
             # log metrics
-            wandb_log(eval_metrics, step=global_step, prefix=
+            wandb_log(eval_metrics, step=global_step, prefix="eval")
 
             # Print metrics and update progress bar
             desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
@@ -808,28 +953,42 @@ def main():
 
         # save state
         state = unreplicate(state)
-        with (Path(training_args.output_dir) /
+        with (Path(training_args.output_dir) / "opt_state.msgpack").open("wb") as f:
             f.write(to_bytes(state.opt_state))
-        with (Path(training_args.output_dir) /
+        with (Path(training_args.output_dir) / "training_state.json").open(
+            "w"
+        ) as f:
+            json.dump({"step": state.step.item()}, f)
 
         # save to W&B
         if data_args.log_model:
-            metadata = {
+            metadata = {"step": step, "epoch": epoch}
             if eval_metrics is not None:
-                metadata[
+                metadata["eval/loss"] = eval_metrics["loss"]
             artifact = wandb.Artifact(
                 name=f"model-{wandb.run.id}", type="bart_model", metadata=metadata
             )
-            artifact.add_file(
-            artifact.add_file(str(Path(training_args.output_dir) /
-            artifact.add_file(
-            artifact.add_file(
+            artifact.add_file(
+                str(Path(training_args.output_dir) / "flax_model.msgpack")
+            )
+            artifact.add_file(str(Path(training_args.output_dir) / "config.json"))
+            artifact.add_file(
+                str(Path(training_args.output_dir) / "tokenizer.json")
+            )
+            artifact.add_file(
+                str(Path(training_args.output_dir) / "tokenizer_config.json")
+            )
+            artifact.add_file(str(Path(training_args.output_dir) / "vocab.json"))
+            artifact.add_file(str(Path(training_args.output_dir) / "merges.txt"))
+            artifact.add_file(
+                str(Path(training_args.output_dir) / "special_tokens_map.json")
+            )
+            artifact.add_file(
+                str(Path(training_args.output_dir) / "opt_state.msgpack")
+            )
+            artifact.add_file(
+                str(Path(training_args.output_dir) / "training_state.json")
+            )
             wandb.run.log_artifact(artifact)
 
         # save some space
@@ -843,39 +1002,47 @@ def main():
             params=params,
             push_to_hub=training_args.push_to_hub,
             commit_message=f"Saving weights and logs of epoch {epoch+1}",
-            temp_dir=True  # avoid issues with being in a repository
+            temp_dir=True,  # avoid issues with being in a repository
         )
 
     for epoch in epochs:
         # ======================== Training ================================
-        train_start = time.time()
 
         # Create sampling rng
        rng, input_rng = jax.random.split(rng)
 
         # Generate an epoch by shuffling sampling indices from the train dataset
+        if data_args.streaming:
+            train_dataset.set_epoch(epoch)
+            train_loader = data_loader_streaming(train_dataset, train_batch_size)
+        else:
+            train_loader = data_loader(
+                input_rng, train_dataset, train_batch_size, shuffle=True
+            )
         # train
-        for
+        for batch in tqdm(
+            train_loader,
+            desc="Training...",
+            position=1,
+            leave=False,
+            total=steps_per_epoch,
+        ):
+            global_step += 1
             state, train_metric = p_train_step(state, batch)
 
             if global_step % data_args.log_interval == 0 and jax.process_index() == 0:
                 # log metrics
-                wandb_log(unreplicate(train_metric), step=global_step, prefix=
+                wandb_log(unreplicate(train_metric), step=global_step, prefix="train")
 
             if training_args.eval_steps and global_step % training_args.eval_steps == 0:
                 run_evaluation()
 
            if global_step % data_args.save_model_steps == 0:
                 run_save_model(state, global_step, epoch)
 
         # log final train metrics
-        wandb_log(unreplicate(train_metric), step=global_step, prefix=
+        wandb_log(unreplicate(train_metric), step=global_step, prefix="train")
 
-        train_time += time.time() - train_start
         train_metric = unreplicate(train_metric)
         epochs.write(
             f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
@@ -888,38 +1055,5 @@ def main():
             run_save_model(state, global_step, epoch, eval_metrics)
 
 
-    # ======================== Prediction loop ==============================
-    if training_args.do_predict:
-        logger.info("*** Predict ***")
-
-        pred_metrics = []
-        pred_generations = []
-        pred_labels = []
-
-        pred_loader = data_loader(input_rng, predict_dataset, eval_batch_size)
-        pred_steps = len(predict_dataset) // eval_batch_size
-        for _ in tqdm(range(pred_steps), desc="Predicting...", position=2, leave=False):
-            # Model forward
-            batch = next(pred_loader)
-            labels = batch["labels"]
-
-            metrics = p_eval_step(state.params, batch)
-            pred_metrics.append(metrics)
-
-            # generation
-            if data_args.predict_with_generate:
-                generated_ids = p_generate_step(state.params, batch)
-                pred_generations.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
-                pred_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))
-
-        # normalize prediction metrics
-        pred_metrics = get_metrics(pred_metrics)
-        pred_metrics = jax.tree_map(jnp.mean, pred_metrics)
-
-        # Print metrics
-        desc = f"Predict Loss: {pred_metrics['loss']})"
-        logger.info(desc)
-
-
 if __name__ == "__main__":
     main()
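Reading the new length handling together with do_big_run.sh gives a quick sanity check of the constants the script logs. The device count and gradient accumulation below are assumptions (one 8-device TPU host, no accumulation), not values recorded in this commit:

# Rough check of the logged step counts for do_big_run.sh.
len_train = 1_000_000                    # --len_train
per_device_train_batch_size = 56         # --per_device_train_batch_size
device_count = 8                         # assumption: jax.device_count() on a v3-8
gradient_accumulation_steps = 1          # assumption: default

train_batch_size = per_device_train_batch_size * device_count      # 448
total_batch_size = train_batch_size * gradient_accumulation_steps  # 448
steps_per_epoch = len_train // train_batch_size                    # 2232
total_optimization_steps = (len_train // total_batch_size) * 1     # 2232 for a single epoch
print(steps_per_epoch, total_optimization_steps)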