Aman K committed
Commit 2e5979b
Parent(s): 0b86536

Updated code to use a different seed and a reduced learning rate

Files changed (2):
  1. run.sh +1 -1
  2. run_mlm_flax.py +10 -0
run.sh CHANGED
@@ -11,7 +11,7 @@
 --preprocessing_num_workers="64" \
 --per_device_train_batch_size="64" \
 --per_device_eval_batch_size="64" \
---learning_rate="3e-4" \
+--learning_rate="2e-4" \
 --warmup_steps="1000" \
 --overwrite_output_dir \
 --num_train_epochs="8" \
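The peak learning rate drops from 3e-4 to 2e-4 while the 1000-step warmup stays the same. For context, here is a minimal sketch of the warmup-then-linear-decay schedule that Flax MLM training scripts of this kind typically assemble with optax; the total step count below is an illustrative assumption, not a value taken from this repo:

    import optax

    learning_rate = 2e-4      # --learning_rate after this commit
    num_warmup_steps = 1000   # --warmup_steps
    num_train_steps = 50_000  # hypothetical total optimizer steps

    # Ramp linearly from 0 to the peak rate over the warmup window...
    warmup_fn = optax.linear_schedule(
        init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps
    )
    # ...then decay linearly back to 0 for the remainder of training.
    decay_fn = optax.linear_schedule(
        init_value=learning_rate,
        end_value=0.0,
        transition_steps=num_train_steps - num_warmup_steps,
    )
    schedule_fn = optax.join_schedules(
        schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]
    )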
run_mlm_flax.py CHANGED
@@ -324,6 +324,7 @@ if __name__ == "__main__":
     logger.info(f"Training/evaluation parameters {training_args}")
 
     # Set seed before initializing model.
+    training_args.seed = 42
     set_seed(training_args.seed)
 
     # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
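Hard-coding training_args.seed = 42 overrides whatever --seed was passed on the command line before set_seed consumes it. In a Flax script the seed also has a second job: JAX randomness is explicit, so the same value typically derives the per-device dropout keys used later in the training loop. A rough sketch of that flow (illustrative, not the exact script code):

    import jax

    # set_seed(42) seeds the Python/NumPy/framework RNGs; JAX needs its own key.
    rng = jax.random.PRNGKey(42)  # training_args.seed after this change
    # One independent dropout key per local device, as pmap expects.
    dropout_rngs = jax.random.split(rng, jax.local_device_count())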
 
@@ -587,6 +588,7 @@ if __name__ == "__main__":
 
     train_time = 0
     epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
+    save_checkpoint = True
     for epoch in epochs:
         # ======================== Training ================================
         train_start = time.time()
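Because save_checkpoint is initialized once, before the epoch loop, the checkpoint branch added below fires at most once per run; initializing it inside the loop instead would save the first sub-5-loss batch of every epoch.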
 
@@ -609,6 +611,14 @@ if __name__ == "__main__":
             model_inputs = shard(model_inputs.data)
             state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
             train_metrics.append(train_metric)
+            if save_checkpoint and (train_metric['loss'] < 5.).all():
+                params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
+                model.save_pretrained(
+                    '/home/khandelia1000/checkpoints/',
+                    params=params,
+                    push_to_hub=False
+                )
+                save_checkpoint = False
 
         train_time += time.time() - train_start
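Both the guard and the save deal with pmap replication. train_metric comes back from p_train_step with one value per device, so (train_metric['loss'] < 5.).all() requires every device to be under the threshold; likewise state.params carries a leading device axis, which jax.tree_map(lambda x: x[0], ...) strips by keeping device 0's copy before jax.device_get pulls it to host memory for save_pretrained. A self-contained sketch of that un-replication step, using flax.jax_utils.replicate to stand in for what pmap does:

    import jax
    import jax.numpy as jnp
    from flax.jax_utils import replicate

    # Toy pytree standing in for state.params.
    params = {"w": jnp.ones((3, 4))}

    # pmap-style replication adds a leading device axis.
    replicated = replicate(params)
    print(replicated["w"].shape)  # (num_devices, 3, 4)

    # Keep device 0's copy and move it to host memory, as the commit does.
    unreplicated = jax.device_get(jax.tree_map(lambda x: x[0], replicated))
    print(unreplicated["w"].shape)  # (3, 4)

Newer JAX releases prefer jax.tree_util.tree_map over the jax.tree_map alias used in the commit.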