feat: remove unused metrics
Former-commit-id: 00a582c7b2dc2f5d8c86bc8818bf8968d4903a70
dev/seq2seq/run_seq2seq_flax.py
CHANGED
@@ -218,7 +218,7 @@ class DataTrainingArguments:
         default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
     )
     predict_with_generate: bool = field(
-        default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}
+        default=False, metadata={"help": "Whether to use generate to calculate generative metrics."}
     )
     num_beams: Optional[int] = field(
         default=None,
@@ -605,35 +605,6 @@ def main():
         desc="Running tokenizer on prediction dataset",
     )

-    # Metric
-    #metric = load_metric("rouge")
-
-    def postprocess_text(preds, labels):
-        preds = [pred.strip() for pred in preds]
-        labels = [label.strip() for label in labels]
-
-        # rougeLSum expects newline after each sentence
-        preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
-        labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
-
-        return preds, labels
-
-    def compute_metrics(preds, labels):
-        decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
-        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
-
-        # Some simple post-processing
-        decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
-
-        result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
-        # Extract a few results from ROUGE
-        result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
-
-        prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
-        result["gen_len"] = np.mean(prediction_lens)
-        result = {k: round(v, 4) for k, v in result.items()}
-        return result
-
     # Initialize our training
     rng = jax.random.PRNGKey(training_args.seed)
     rng, dropout_rng = jax.random.split(rng)
@@ -819,15 +790,8 @@ def main():
         # log metrics
         wandb_log(eval_metrics, step=global_step, prefix='eval')

-        # compute ROUGE metrics
-        rouge_desc = ""
-        # if data_args.predict_with_generate:
-        #     rouge_metrics = compute_metrics(eval_preds, eval_labels)
-        #     eval_metrics.update(rouge_metrics)
-        #     rouge_desc = " ".join([f"Eval {key}: {value} |" for key, value in rouge_metrics.items()])
-
         # Print metrics and update progress bar
-        desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {rouge_desc})"
+        desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
         epochs.write(desc)
         epochs.desc = desc

@@ -952,15 +916,8 @@ def main():
         pred_metrics = get_metrics(pred_metrics)
         pred_metrics = jax.tree_map(jnp.mean, pred_metrics)

-        # compute ROUGE metrics
-        rouge_desc = ""
-        if data_args.predict_with_generate:
-            rouge_metrics = compute_metrics(pred_generations, pred_labels)
-            pred_metrics.update(rouge_metrics)
-            rouge_desc = " ".join([f"Predict {key}: {value} |" for key, value in rouge_metrics.items()])
-
         # Print metrics
-        desc = f"Predict Loss: {pred_metrics['loss']} | {rouge_desc})"
+        desc = f"Predict Loss: {pred_metrics['loss']})"
         logger.info(desc)
