sonsus commited on
Commit
82e28d8
1 Parent(s): d623514

Update harim_plus.py

Browse files
Files changed (1) hide show
  1. harim_plus.py +2 -2
harim_plus.py CHANGED
@@ -232,10 +232,10 @@ class Harimplus_Scorer:
232
  labels = tgt_in.input_ids.masked_fill(fill_ignore_mask, -100),
233
  return_dict=True).logits
234
  sent_lengths = tgt_mask.sum(-1)
235
- ll_tok = self.log_likelihoods(s2s_logits, tgt_in.input_ids, 1)#tgt_mask)
236
  ll = ll_tok.sum(-1) / sent_lengths
237
 
238
- harim_tok = self.harim(s2s_logits, lm_logits, tgt_in.input_ids, 1)#tgt_mask)
239
  harim = harim_tok.sum(-1) / sent_lengths
240
 
241
  harim_plus_normalized = (ll + self._lambda * harim) # loglikelihood + lambda * negative_harim (negative harim=-1* risk)
 
232
  labels = tgt_in.input_ids.masked_fill(fill_ignore_mask, -100),
233
  return_dict=True).logits
234
  sent_lengths = tgt_mask.sum(-1)
235
+ ll_tok = self.log_likelihoods(s2s_logits, tgt_in.input_ids, tgt_mask)
236
  ll = ll_tok.sum(-1) / sent_lengths
237
 
238
+ harim_tok = self.harim(s2s_logits, lm_logits, tgt_in.input_ids, tgt_mask)
239
  harim = harim_tok.sum(-1) / sent_lengths
240
 
241
  harim_plus_normalized = (ll + self._lambda * harim) # loglikelihood + lambda * negative_harim (negative harim=-1* risk)