Spaces:
Running
Running
Pedro Cuenca
commited on
Commit
•
32dc2d8
1
Parent(s):
df3c7bd
* Only perform validation if requested
Browse files* Disable rouge metric
* Add sanity check for tpus.
* Add training command.
- seq2seq/do_run.sh +9 -0
- seq2seq/run_seq2seq_flax.py +38 -35
seq2seq/do_run.sh
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
python run_seq2seq_flax.py \
|
2 |
+
--max_source_length 128 \
|
3 |
+
--train_file /data/CC12M/encoded-small-train.tsv \
|
4 |
+
--validation_file /data/CC12M/encoded-small-valid.tsv \
|
5 |
+
--output_dir output \
|
6 |
+
--per_device_train_batch_size 16 \
|
7 |
+
--per_device_eval_batch_size 16 \
|
8 |
+
--do_train \
|
9 |
+
--do_eval \
|
seq2seq/run_seq2seq_flax.py
CHANGED
@@ -413,6 +413,8 @@ def main():
|
|
413 |
#config.min_length = data_args.max_target_length # Set only in decoder?
|
414 |
#config.max_length = data_args.max_target_length # Set only in decoder?
|
415 |
|
|
|
|
|
416 |
|
417 |
# Create a custom model and initialize it randomly
|
418 |
model = CustomFlaxBartForConditionalGeneration(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
|
@@ -534,7 +536,7 @@ def main():
|
|
534 |
)
|
535 |
|
536 |
# Metric
|
537 |
-
metric = load_metric("rouge")
|
538 |
|
539 |
def postprocess_text(preds, labels):
|
540 |
preds = [pred.strip() for pred in preds]
|
@@ -740,40 +742,41 @@ def main():
|
|
740 |
|
741 |
# ======================== Evaluating ==============================
|
742 |
eval_metrics = []
|
743 |
-
|
744 |
-
|
745 |
-
|
746 |
-
|
747 |
-
|
748 |
-
|
749 |
-
|
750 |
-
|
751 |
-
|
752 |
-
|
753 |
-
|
754 |
-
|
755 |
-
|
756 |
-
|
757 |
-
|
758 |
-
|
759 |
-
|
760 |
-
|
761 |
-
|
762 |
-
|
763 |
-
|
764 |
-
|
765 |
-
|
766 |
-
|
767 |
-
|
768 |
-
|
769 |
-
|
770 |
-
|
771 |
-
|
772 |
-
|
773 |
-
|
774 |
-
|
775 |
-
|
776 |
-
|
|
|
777 |
|
778 |
# Save metrics
|
779 |
if has_tensorboard and jax.process_index() == 0:
|
|
|
413 |
#config.min_length = data_args.max_target_length # Set only in decoder?
|
414 |
#config.max_length = data_args.max_target_length # Set only in decoder?
|
415 |
|
416 |
+
print(f"TPUs: {jax.device_count()}")
|
417 |
+
assert jax.device_count() == 8, "TPUs in use, please check running processes"
|
418 |
|
419 |
# Create a custom model and initialize it randomly
|
420 |
model = CustomFlaxBartForConditionalGeneration(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
|
|
|
536 |
)
|
537 |
|
538 |
# Metric
|
539 |
+
#metric = load_metric("rouge")
|
540 |
|
541 |
def postprocess_text(preds, labels):
|
542 |
preds = [pred.strip() for pred in preds]
|
|
|
742 |
|
743 |
# ======================== Evaluating ==============================
|
744 |
eval_metrics = []
|
745 |
+
if training_args.do_eval:
|
746 |
+
eval_preds = []
|
747 |
+
eval_labels = []
|
748 |
+
|
749 |
+
eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
|
750 |
+
eval_steps = len(eval_dataset) // eval_batch_size
|
751 |
+
for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
|
752 |
+
# Model forward
|
753 |
+
batch = next(eval_loader)
|
754 |
+
labels = batch["labels"]
|
755 |
+
|
756 |
+
metrics = p_eval_step(state.params, batch)
|
757 |
+
eval_metrics.append(metrics)
|
758 |
+
|
759 |
+
# generation
|
760 |
+
if data_args.predict_with_generate:
|
761 |
+
generated_ids = p_generate_step(state.params, batch)
|
762 |
+
eval_preds.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
|
763 |
+
eval_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))
|
764 |
+
|
765 |
+
# normalize eval metrics
|
766 |
+
eval_metrics = get_metrics(eval_metrics)
|
767 |
+
eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
|
768 |
+
|
769 |
+
# compute ROUGE metrics
|
770 |
+
rouge_desc = ""
|
771 |
+
# if data_args.predict_with_generate:
|
772 |
+
# rouge_metrics = compute_metrics(eval_preds, eval_labels)
|
773 |
+
# eval_metrics.update(rouge_metrics)
|
774 |
+
# rouge_desc = " ".join([f"Eval {key}: {value} |" for key, value in rouge_metrics.items()])
|
775 |
+
|
776 |
+
# Print metrics and update progress bar
|
777 |
+
desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {rouge_desc})"
|
778 |
+
epochs.write(desc)
|
779 |
+
epochs.desc = desc
|
780 |
|
781 |
# Save metrics
|
782 |
if has_tensorboard and jax.process_index() == 0:
|