seonil commited on
Commit
f469daa
1 Parent(s): 3f218d0

return details

Browse files
Files changed (2) hide show
  1. harim_plus.py +5 -1
  2. harim_scorer.py +4 -1
harim_plus.py CHANGED
@@ -128,5 +128,9 @@ class Harimplus(evaluate.Metric):
128
  return_details=False):
129
  summaries = predictions
130
  articles = references
131
- scores = self.scorer.compute(predictions=summaries, references=articles, use_aggregator=use_aggregator, bsz=bsz, tokenwise_score=tokenwise_score, return_details=return_details)
 
 
 
 
132
  return scores
 
128
  return_details=False):
129
  summaries = predictions
130
  articles = references
131
+ scores = self.scorer.compute(predictions=summaries,
132
+ references=articles,
133
+ use_aggregator=use_aggregator,
134
+ bsz=bsz, tokenwise_score=tokenwise_score,
135
+ return_details=return_details)
136
  return scores
harim_scorer.py CHANGED
@@ -102,6 +102,7 @@ class Harimplus_Scorer:
102
  references:List[str],
103
  bsz:int=32,
104
  use_aggregator:bool=False,
 
105
  tokenwise_score:bool=False,
106
  ):
107
  '''
@@ -190,6 +191,8 @@ class Harimplus_Scorer:
190
  if not k.startswith('tok_'):
191
  scores[k] = sum(v)/len(v) # aggregate (mean)
192
  scores['lambda'] = self._lambda
 
 
193
  return scores
194
 
195
 
@@ -221,7 +224,7 @@ def test(bsz = 16, pretrained_name='facebook/bart-large-cnn', tokenizer=None):
221
  ]
222
  articles = [ art1 ]*5 + [art2 ]*4
223
  # set_trace()
224
- hp_score = scorer.compute(predictions=summaries, references=articles, use_aggregator=False, bsz=bsz)
225
  # pprint(f"{articles=}")
226
  # pprint(f"{summaries=}")
227
  pprint(hp_score)
 
102
  references:List[str],
103
  bsz:int=32,
104
  use_aggregator:bool=False,
105
+ return_details:bool=False,
106
  tokenwise_score:bool=False,
107
  ):
108
  '''
 
191
  if not k.startswith('tok_'):
192
  scores[k] = sum(v)/len(v) # aggregate (mean)
193
  scores['lambda'] = self._lambda
194
+ if not return_details:
195
+ scores = scores['harim+']
196
  return scores
197
 
198
 
 
224
  ]
225
  articles = [ art1 ]*5 + [art2 ]*4
226
  # set_trace()
227
+ hp_score = scorer.compute(predictions=summaries, references=articles, use_aggregator=False, bsz=bsz, return_details=False, tokenwise_score=False)
228
  # pprint(f"{articles=}")
229
  # pprint(f"{summaries=}")
230
  pprint(hp_score)