Aman K committed
Commit 2e5979b
1 Parent(s): 0b86536
Updated code to have different seed and reduced lr

Files changed:
- run.sh (+1 -1)
- run_mlm_flax.py (+10 -0)
run.sh
CHANGED
@@ -11,7 +11,7 @@
 --preprocessing_num_workers="64" \
 --per_device_train_batch_size="64" \
 --per_device_eval_batch_size="64" \
---learning_rate="
+--learning_rate="2e-4" \
 --warmup_steps="1000" \
 --overwrite_output_dir \
 --num_train_epochs="8" \
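The only functional change in run.sh is the peak learning rate, lowered to 2e-4 while the 1000-step warmup is kept. For context, the sketch below shows how such flags typically feed an optax warmup/decay schedule in a Flax MLM training script; the total step count and the use of adamw are illustrative assumptions, not values taken from this commit.

# Hedged sketch: linear warmup to the new peak LR, then linear decay, built
# with optax as the upstream run_mlm_flax.py example does. num_train_steps
# is an assumed placeholder, not a value from this repository.
import optax

peak_lr = 2e-4             # matches --learning_rate="2e-4"
warmup_steps = 1000        # matches --warmup_steps="1000"
num_train_steps = 100_000  # assumption, for illustration only

warmup_fn = optax.linear_schedule(init_value=0.0, end_value=peak_lr, transition_steps=warmup_steps)
decay_fn = optax.linear_schedule(init_value=peak_lr, end_value=0.0, transition_steps=num_train_steps - warmup_steps)
lr_schedule = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[warmup_steps])

optimizer = optax.adamw(learning_rate=lr_schedule)  # the schedule is queried once per optimizer step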
run_mlm_flax.py
CHANGED
@@ -324,6 +324,7 @@ if __name__ == "__main__":
     logger.info(f"Training/evaluation parameters {training_args}")
 
     # Set seed before initializing model.
+    training_args.seed = 42
     set_seed(training_args.seed)
 
     # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
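This hunk hard-codes the seed to 42 immediately before set_seed, overriding whatever --seed was passed on the command line. A rough sketch of what that value ends up controlling, assuming the usual structure of the Flax MLM script (variable names here are illustrative):

# Hedged sketch: set_seed covers Python's random and NumPy (plus torch/tf when
# installed); JAX randomness is driven separately by a PRNGKey built from the
# same seed, then split so each local device gets its own dropout key.
import jax
from transformers import set_seed

seed = 42                       # the value hard-coded by this commit
set_seed(seed)
rng = jax.random.PRNGKey(seed)  # init / masking / dropout randomness on the JAX side
dropout_rngs = jax.random.split(rng, jax.local_device_count())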
@@ -587,6 +588,7 @@ if __name__ == "__main__":
 
     train_time = 0
     epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
+    save_checkpoint=True
     for epoch in epochs:
         # ======================== Training ================================
         train_start = time.time()
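save_checkpoint is a one-shot guard: the block added in the next hunk flips it to False after the first save, so at most one mid-training snapshot is written per run. The threshold test it gates operates on a replicated metric; a small illustration of that shape, assuming the usual pmap setup with one loss entry per local device:

# Hedged illustration only: under pmap, train_metric['loss'] carries one value
# per local device, so `.all()` requires every device to be under the threshold.
import jax.numpy as jnp

per_device_loss = jnp.array([4.2, 4.1, 4.3, 4.2, 4.4, 4.1, 4.2, 4.3])  # e.g. 8 TPU cores
should_save = bool((per_device_loss < 5.).all())  # True: every entry is below 5.0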
@@ -609,6 +611,14 @@ if __name__ == "__main__":
             model_inputs = shard(model_inputs.data)
             state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
             train_metrics.append(train_metric)
+            if save_checkpoint and (train_metric['loss'] < 5.).all():
+                params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
+                model.save_pretrained(
+                    '/home/khandelia1000/checkpoints/',
+                    params=params,
+                    push_to_hub=False
+                )
+                save_checkpoint = False
 
         train_time += time.time() - train_start
 
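save_pretrained writes a standard Flax checkpoint directory (config plus flax_model.msgpack), so the snapshot can later be reloaded for inspection or to resume an experiment. A minimal sketch, assuming the saved config resolves to a masked-LM architecture; the auto class below is an assumption, not something this commit itself uses:

# Hedged sketch: reload the one-shot checkpoint written above.
from transformers import FlaxAutoModelForMaskedLM

model = FlaxAutoModelForMaskedLM.from_pretrained("/home/khandelia1000/checkpoints/")
print(model.config.model_type)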