Saving weights and logs of epoch 1

Files changed:
- .gitattributes +1 -0
- events.out.tfevents.1626286603.t1v-n-b95d739e-w-0.590614.3.v2 → events.out.tfevents.1626318482.t1v-n-b95d739e-w-0.622701.3.v2 +2 -2
- flax_model.msgpack +3 -0
- flax_to_torch.py +7 -0
- nohup.out +2 -2
- run.sh +1 -0
- run_mlm_flax.py +86 -110
.gitattributes CHANGED
@@ -16,3 +16,4 @@
 *.pth filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
 nohup.out filter=lfs diff=lfs merge=lfs -text
+flax_model.msgpack filter=lfs diff=lfs merge=lfs -text
events.out.tfevents.1626286603.t1v-n-b95d739e-w-0.590614.3.v2 → events.out.tfevents.1626318482.t1v-n-b95d739e-w-0.622701.3.v2 RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:c0c8833b4f4649f58ab0d01d47c772f8c05080f371d5f9d57e7134d997e944a1
+size 157187
flax_model.msgpack ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:91922c046bb159618da797c8c5076e9684aa45c6ded263ff8c60dab3cb008059
+size 498796983
flax_to_torch.py ADDED
@@ -0,0 +1,7 @@
+from transformers import RobertaForMaskedLM, AutoTokenizer
+
+model = RobertaForMaskedLM.from_pretrained("./", from_flax=True)
+model.save_pretrained("./")
+
+tokenizer = AutoTokenizer.from_pretrained("./")
+tokenizer.save_pretrained("./")
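The new flax_to_torch.py converts the committed Flax checkpoint to PyTorch by loading it with from_flax=True and re-saving the model and tokenizer in place. Below is a hedged sketch of the same conversion with explicit directories instead of "./"; the path names are illustrative and not part of the commit.

# Hedged sketch: same conversion as flax_to_torch.py, with explicit paths.
# "checkpoint_dir" and "export_dir" are illustrative names, not part of the commit.
from transformers import RobertaForMaskedLM, AutoTokenizer

checkpoint_dir = "."             # contains config.json, flax_model.msgpack and the tokenizer files
export_dir = "pytorch_export"    # hypothetical output directory

# from_flax=True makes from_pretrained read flax_model.msgpack and convert the weights
model = RobertaForMaskedLM.from_pretrained(checkpoint_dir, from_flax=True)
model.save_pretrained(export_dir)        # writes pytorch_model.bin and config.json

tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)
tokenizer.save_pretrained(export_dir)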
nohup.out CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:99c46650710b372548e97ab2d4a123983e2b495c3ceb094847f500b4ac3a64f7
+size 193918
run.sh CHANGED
@@ -1,5 +1,6 @@
 #!/usr/bin/env bash
 python3 run_mlm_flax.py \
+    --model_name_or_path="flax_model.msgpack" \
     --output_dir="./" \
     --model_type="roberta" \
    --config_name="./" \
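The only change to run.sh is the new --model_name_or_path flag pointing at the Flax checkpoint saved in this commit. Independently of how run_mlm_flax.py consumes that flag, the committed weights can be loaded directly from the repository directory; a minimal sketch, assuming config.json and flax_model.msgpack sit in "./" (the dtype choice is illustrative):

# Hedged sketch: load the Flax checkpoint committed here (flax_model.msgpack + config.json in "./").
import jax.numpy as jnp
from transformers import FlaxAutoModelForMaskedLM

model = FlaxAutoModelForMaskedLM.from_pretrained("./", dtype=jnp.float32)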
run_mlm_flax.py CHANGED
@@ -56,6 +56,24 @@ from transformers import (
 )


+# Cache the result
+has_tensorboard = is_tensorboard_available()
+if has_tensorboard:
+    try:
+        from flax.metrics.tensorboard import SummaryWriter
+    except ImportError as ie:
+        has_tensorboard = False
+        print(
+            f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
+        )
+
+else:
+    print(
+        "Unable to display metrics through TensorBoard because the package is not installed: "
+        "Please run pip install tensorboard to enable."
+    )
+
+
 MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
 MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)

@@ -156,7 +174,7 @@ class DataTrainingArguments:
         metadata={"help": "Overwrite the cached training and evaluation sets"},
     )
     validation_split_percentage: Optional[int] = field(
-        default=
+        default=5,
         metadata={
             "help": "The percentage of the train set used as validation set in case there's no validation split"
         },
@@ -314,7 +332,7 @@ def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
     return batch_idx


-def write_train_metric(summary_writer, train_metrics, train_time, step):
+def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
     summary_writer.scalar("train_time", train_time, step)

     train_metrics = get_metrics(train_metrics)
@@ -323,8 +341,6 @@ def write_train_metric(summary_writer, train_metrics, train_time, step):
     for i, val in enumerate(vals):
         summary_writer.scalar(tag, val, step - len(vals) + i + 1)

-
-def write_eval_metric(summary_writer, eval_metrics, step):
     for metric_name, value in eval_metrics.items():
         summary_writer.scalar(f"eval_{metric_name}", value, step)

@@ -366,6 +382,10 @@ if __name__ == "__main__":

     # Log on each process the small summary:
     logger = logging.getLogger(__name__)
+    logger.warning(
+        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
+        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
+    )

     # Set the verbosity to info of the Transformers logger (on main process only):
     logger.info(f"Training/evaluation parameters {training_args}")
@@ -557,22 +577,8 @@ if __name__ == "__main__":
     )

     # Enable tensorboard only on the master node
-    has_tensorboard = is_tensorboard_available()
     if has_tensorboard and jax.process_index() == 0:
-        try:
-            from flax.metrics.tensorboard import SummaryWriter
-
-            summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
-        except ImportError as ie:
-            has_tensorboard = False
-            logger.warning(
-                f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
-            )
-    else:
-        logger.warning(
-            "Unable to display metrics through TensorBoard because the package is not installed: "
-            "Please run pip install tensorboard to enable."
-        )
+        summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))

     # Data collator
     # This one will take care of randomly masking the tokens.
@@ -584,17 +590,9 @@ if __name__ == "__main__":
     rng = jax.random.PRNGKey(training_args.seed)
     dropout_rngs = jax.random.split(rng, jax.local_device_count())

-    if model_args.model_name_or_path:
-        model = FlaxAutoModelForMaskedLM.from_pretrained(
-            model_args.model_name_or_path,
-            config=config,
-            seed=training_args.seed,
-            dtype=getattr(jnp, model_args.dtype),
-        )
-    else:
-        model = FlaxAutoModelForMaskedLM.from_config(
-            config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
-        )
+    model = FlaxAutoModelForMaskedLM.from_config(
+        config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
+    )

     # Store some constant
     num_epochs = int(training_args.num_train_epochs)
@@ -636,23 +634,18 @@ if __name__ == "__main__":
         return traverse_util.unflatten_dict(flat_mask)

     # create adam optimizer
-    … (removed lines not captured in this diff view)
-        b2=training_args.adam_beta2,
-        eps=training_args.adam_epsilon,
-        weight_decay=training_args.weight_decay,
-        mask=decay_mask_fn,
-    )
+    adamw = optax.adamw(
+        learning_rate=linear_decay_lr_schedule_fn,
+        b1=training_args.adam_beta1,
+        b2=training_args.adam_beta2,
+        eps=1e-8,
+        weight_decay=training_args.weight_decay,
+        mask=decay_mask_fn,
+    )

     # Setup train state
     state = train_state.TrainState.create(
-        apply_fn=model.__call__, params=model.params, tx=
+        apply_fn=model.__call__, params=model.params, tx=adamw
     )

     # Define gradient update step fn
@@ -742,7 +735,7 @@ if __name__ == "__main__":
         train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)

         # Gather the indexes for creating the batch and do a training step
-        for
+        for i, batch_idx in enumerate(
             tqdm(train_batch_idx, desc="Training...", position=1)
         ):
             samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
@@ -755,69 +748,52 @@ if __name__ == "__main__":
             )
             train_metrics.append(train_metric)

-            … (removed lines not captured in this diff view)
-        # Save metrics
-        if has_tensorboard and jax.process_index() == 0:
-            cur_step = epoch * (
-                len(tokenized_datasets["train"]) // train_batch_size
-            )
-            write_eval_metric(summary_writer, eval_metrics, cur_step)
-
-        if cur_step % training_args.save_steps == 0 and cur_step > 0:
-            # save checkpoint after each epoch and push checkpoint to the hub
-            if jax.process_index() == 0:
-                params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
-                model.save_pretrained(
-                    training_args.output_dir,
-                    params=params,
-                    push_to_hub=training_args.push_to_hub,
-                    commit_message=f"Saving weights and logs of step {cur_step}",
-                )
+        train_time += time.time() - train_start
+
+        epochs.write(
+            f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
+        )
+
+        # ======================== Evaluating ==============================
+        num_eval_samples = len(tokenized_datasets["test"])
+        eval_samples_idx = jnp.arange(num_eval_samples)
+        eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
+
+        eval_metrics = []
+        for i, batch_idx in enumerate(
+            tqdm(eval_batch_idx, desc="Evaluating ...", position=2)
+        ):
+            samples = [tokenized_datasets["test"][int(idx)] for idx in batch_idx]
+            model_inputs = data_collator(samples, pad_to_multiple_of=16)
+
+            # Model forward
+            model_inputs = shard(model_inputs.data)
+            metrics = p_eval_step(state.params, model_inputs)
+            eval_metrics.append(metrics)
+
+        # normalize eval metrics
+        eval_metrics = get_metrics(eval_metrics)
+        eval_metrics = jax.tree_map(jnp.sum, eval_metrics)
+        eval_normalizer = eval_metrics.pop("normalizer")
+        eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
+
+        # Update progress bar
+        epochs.desc = f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
+
+        # Save metrics
+        if has_tensorboard and jax.process_index() == 0:
+            cur_step = epoch * (len(tokenized_datasets["train"]) // train_batch_size)
+            write_metric(
+                summary_writer, train_metrics, eval_metrics, train_time, cur_step
+            )
+
+        # save checkpoint after each epoch and push checkpoint to the hub
+        if jax.process_index() == 0:
+            params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
+            model.save_pretrained(
+                training_args.output_dir,
+                params=params,
+                push_to_hub=training_args.push_to_hub,
+                commit_message=f"Saving weights and logs of epoch {epoch+1}",
+            )
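For readers following the run_mlm_flax.py changes: the optimizer hunk now builds a single optax.adamw with a hard-coded eps=1e-8 (rather than training_args.adam_epsilon) and hands the resulting transform to the train state. A minimal, self-contained sketch of that pattern, with a toy parameter tree and a constant learning rate standing in for the script's model and linear-decay schedule:

# Hedged sketch of the optax.adamw + TrainState pattern shown in the diff.
# The tiny params dict and constant learning rate are stand-ins for the real model and schedule.
import jax.numpy as jnp
import optax
from flax.training import train_state

params = {"dense": {"kernel": jnp.ones((4, 4)), "bias": jnp.zeros(4)}}

adamw = optax.adamw(
    learning_rate=1e-3,     # run_mlm_flax.py passes a linear-decay schedule here
    b1=0.9,
    b2=0.999,
    eps=1e-8,               # hard-coded in this commit instead of training_args.adam_epsilon
    weight_decay=0.01,
)

def apply_fn(params, x):
    # Placeholder forward pass; the script uses model.__call__.
    return x @ params["dense"]["kernel"] + params["dense"]["bias"]

state = train_state.TrainState.create(apply_fn=apply_fn, params=params, tx=adamw)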
|