boris commited on
Commit
0d94b71
1 Parent(s): 6c5fc6a

feat: remove unused metrics

Browse files

Former-commit-id: 00a582c7b2dc2f5d8c86bc8818bf8968d4903a70

Files changed (1) hide show
  1. dev/seq2seq/run_seq2seq_flax.py +3 -46
dev/seq2seq/run_seq2seq_flax.py CHANGED
@@ -218,7 +218,7 @@ class DataTrainingArguments:
218
  default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
219
  )
220
  predict_with_generate: bool = field(
221
- default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}
222
  )
223
  num_beams: Optional[int] = field(
224
  default=None,
@@ -605,35 +605,6 @@ def main():
605
  desc="Running tokenizer on prediction dataset",
606
  )
607
 
608
- # Metric
609
- #metric = load_metric("rouge")
610
-
611
- def postprocess_text(preds, labels):
612
- preds = [pred.strip() for pred in preds]
613
- labels = [label.strip() for label in labels]
614
-
615
- # rougeLSum expects newline after each sentence
616
- preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
617
- labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
618
-
619
- return preds, labels
620
-
621
- def compute_metrics(preds, labels):
622
- decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
623
- decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
624
-
625
- # Some simple post-processing
626
- decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
627
-
628
- result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
629
- # Extract a few results from ROUGE
630
- result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
631
-
632
- prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
633
- result["gen_len"] = np.mean(prediction_lens)
634
- result = {k: round(v, 4) for k, v in result.items()}
635
- return result
636
-
637
  # Initialize our training
638
  rng = jax.random.PRNGKey(training_args.seed)
639
  rng, dropout_rng = jax.random.split(rng)
@@ -819,15 +790,8 @@ def main():
819
  # log metrics
820
  wandb_log(eval_metrics, step=global_step, prefix='eval')
821
 
822
- # compute ROUGE metrics
823
- rouge_desc = ""
824
- # if data_args.predict_with_generate:
825
- # rouge_metrics = compute_metrics(eval_preds, eval_labels)
826
- # eval_metrics.update(rouge_metrics)
827
- # rouge_desc = " ".join([f"Eval {key}: {value} |" for key, value in rouge_metrics.items()])
828
-
829
  # Print metrics and update progress bar
830
- desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {rouge_desc})"
831
  epochs.write(desc)
832
  epochs.desc = desc
833
 
@@ -952,15 +916,8 @@ def main():
952
  pred_metrics = get_metrics(pred_metrics)
953
  pred_metrics = jax.tree_map(jnp.mean, pred_metrics)
954
 
955
- # compute ROUGE metrics
956
- rouge_desc = ""
957
- if data_args.predict_with_generate:
958
- rouge_metrics = compute_metrics(pred_generations, pred_labels)
959
- pred_metrics.update(rouge_metrics)
960
- rouge_desc = " ".join([f"Predict {key}: {value} |" for key, value in rouge_metrics.items()])
961
-
962
  # Print metrics
963
- desc = f"Predict Loss: {pred_metrics['loss']} | {rouge_desc})"
964
  logger.info(desc)
965
 
966
 
 
218
  default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
219
  )
220
  predict_with_generate: bool = field(
221
+ default=False, metadata={"help": "Whether to use generate to calculate generative metrics."}
222
  )
223
  num_beams: Optional[int] = field(
224
  default=None,
 
605
  desc="Running tokenizer on prediction dataset",
606
  )
607
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
608
  # Initialize our training
609
  rng = jax.random.PRNGKey(training_args.seed)
610
  rng, dropout_rng = jax.random.split(rng)
 
790
  # log metrics
791
  wandb_log(eval_metrics, step=global_step, prefix='eval')
792
 
 
 
 
 
 
 
 
793
  # Print metrics and update progress bar
794
+ desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
795
  epochs.write(desc)
796
  epochs.desc = desc
797
 
 
916
  pred_metrics = get_metrics(pred_metrics)
917
  pred_metrics = jax.tree_map(jnp.mean, pred_metrics)
918
 
 
 
 
 
 
 
 
919
  # Print metrics
920
+ desc = f"Predict Loss: {pred_metrics['loss']})"
921
  logger.info(desc)
922
 
923