Pedro Cuenca committed on
Commit
32dc2d8
1 Parent(s): df3c7bd

* Only perform validation if requested

Browse files

* Disable rouge metric
* Add sanity check for TPUs.
* Add training command.

Files changed (2) hide show
  1. seq2seq/do_run.sh +9 -0
  2. seq2seq/run_seq2seq_flax.py +38 -35
seq2seq/do_run.sh ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ python run_seq2seq_flax.py \
2
+ --max_source_length 128 \
3
+ --train_file /data/CC12M/encoded-small-train.tsv \
4
+ --validation_file /data/CC12M/encoded-small-valid.tsv \
5
+ --output_dir output \
6
+ --per_device_train_batch_size 16 \
7
+ --per_device_eval_batch_size 16 \
8
+ --do_train \
9
+ --do_eval \
seq2seq/run_seq2seq_flax.py CHANGED
@@ -413,6 +413,8 @@ def main():
413
  #config.min_length = data_args.max_target_length # Set only in decoder?
414
  #config.max_length = data_args.max_target_length # Set only in decoder?
415
 
 
 
416
 
417
  # Create a custom model and initialize it randomly
418
  model = CustomFlaxBartForConditionalGeneration(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
@@ -534,7 +536,7 @@ def main():
534
  )
535
 
536
  # Metric
537
- metric = load_metric("rouge")
538
 
539
  def postprocess_text(preds, labels):
540
  preds = [pred.strip() for pred in preds]
@@ -740,40 +742,41 @@ def main():
740
 
741
  # ======================== Evaluating ==============================
742
  eval_metrics = []
743
- eval_preds = []
744
- eval_labels = []
745
-
746
- eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
747
- eval_steps = len(eval_dataset) // eval_batch_size
748
- for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
749
- # Model forward
750
- batch = next(eval_loader)
751
- labels = batch["labels"]
752
-
753
- metrics = p_eval_step(state.params, batch)
754
- eval_metrics.append(metrics)
755
-
756
- # generation
757
- if data_args.predict_with_generate:
758
- generated_ids = p_generate_step(state.params, batch)
759
- eval_preds.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
760
- eval_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))
761
-
762
- # normalize eval metrics
763
- eval_metrics = get_metrics(eval_metrics)
764
- eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
765
-
766
- # compute ROUGE metrics
767
- rouge_desc = ""
768
- if data_args.predict_with_generate:
769
- rouge_metrics = compute_metrics(eval_preds, eval_labels)
770
- eval_metrics.update(rouge_metrics)
771
- rouge_desc = " ".join([f"Eval {key}: {value} |" for key, value in rouge_metrics.items()])
772
-
773
- # Print metrics and update progress bar
774
- desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {rouge_desc})"
775
- epochs.write(desc)
776
- epochs.desc = desc
 
777
 
778
  # Save metrics
779
  if has_tensorboard and jax.process_index() == 0:
 
413
  #config.min_length = data_args.max_target_length # Set only in decoder?
414
  #config.max_length = data_args.max_target_length # Set only in decoder?
415
 
416
+ print(f"TPUs: {jax.device_count()}")
417
+ assert jax.device_count() == 8, "TPUs in use, please check running processes"
418
 
419
  # Create a custom model and initialize it randomly
420
  model = CustomFlaxBartForConditionalGeneration(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
 
536
  )
537
 
538
  # Metric
539
+ #metric = load_metric("rouge")
540
 
541
  def postprocess_text(preds, labels):
542
  preds = [pred.strip() for pred in preds]
 
742
 
743
  # ======================== Evaluating ==============================
744
  eval_metrics = []
745
+ if training_args.do_eval:
746
+ eval_preds = []
747
+ eval_labels = []
748
+
749
+ eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
750
+ eval_steps = len(eval_dataset) // eval_batch_size
751
+ for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
752
+ # Model forward
753
+ batch = next(eval_loader)
754
+ labels = batch["labels"]
755
+
756
+ metrics = p_eval_step(state.params, batch)
757
+ eval_metrics.append(metrics)
758
+
759
+ # generation
760
+ if data_args.predict_with_generate:
761
+ generated_ids = p_generate_step(state.params, batch)
762
+ eval_preds.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
763
+ eval_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))
764
+
765
+ # normalize eval metrics
766
+ eval_metrics = get_metrics(eval_metrics)
767
+ eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
768
+
769
+ # compute ROUGE metrics
770
+ rouge_desc = ""
771
+ # if data_args.predict_with_generate:
772
+ # rouge_metrics = compute_metrics(eval_preds, eval_labels)
773
+ # eval_metrics.update(rouge_metrics)
774
+ # rouge_desc = " ".join([f"Eval {key}: {value} |" for key, value in rouge_metrics.items()])
775
+
776
+ # Print metrics and update progress bar
777
+ desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {rouge_desc})"
778
+ epochs.write(desc)
779
+ epochs.desc = desc
780
 
781
  # Save metrics
782
  if has_tensorboard and jax.process_index() == 0: