seonil committed on
Commit
a1a10ca
1 Parent(s): 37d452a
__pycache__/harim_scorer.cpython-39.pyc CHANGED
Binary files a/__pycache__/harim_scorer.cpython-39.pyc and b/__pycache__/harim_scorer.cpython-39.pyc differ
 
harim_plus.py CHANGED
@@ -207,18 +207,19 @@ class Harimplus_Scorer:
207
  emp_in = emp_in.to(self._device)
208
  tgt_in = tgt_in.to(self._device)
209
  tgt_mask = tgt_mask.to(self._device)
 
210
 
211
  with torch.no_grad():
212
  # token_type_ids attribute causes error
213
  s2s_logits = self._encdec_model.forward(
214
  input_ids = src_in.input_ids,
215
  attention_mask = src_in.attention_mask,
216
- labels = tgt_in.input_ids,
217
  return_dict=True).logits
218
  lm_logits = self._encdec_model.forward(
219
  input_ids = emp_in.input_ids,
220
  attention_mask = emp_in.attention_mask,
221
- labels = tgt_in.input_ids,
222
  return_dict=True).logits
223
  sent_lengths = tgt_mask.sum(-1)
224
  ll_tok = self.log_likelihoods(s2s_logits, tgt_in.input_ids, tgt_mask)
 
207
  emp_in = emp_in.to(self._device)
208
  tgt_in = tgt_in.to(self._device)
209
  tgt_mask = tgt_mask.to(self._device)
210
+ fill_ignore_mask = ~(tgt_mask.bool())
211
 
212
  with torch.no_grad():
213
  # token_type_ids attribute causes error
214
  s2s_logits = self._encdec_model.forward(
215
  input_ids = src_in.input_ids,
216
  attention_mask = src_in.attention_mask,
217
+ labels = tgt_in.input_ids.masked_fill(fill_ignore_mask, -100),
218
  return_dict=True).logits
219
  lm_logits = self._encdec_model.forward(
220
  input_ids = emp_in.input_ids,
221
  attention_mask = emp_in.attention_mask,
222
+ labels = tgt_in.input_ids.masked_fill(fill_ignore_mask, -100),
223
  return_dict=True).logits
224
  sent_lengths = tgt_mask.sum(-1)
225
  ll_tok = self.log_likelihoods(s2s_logits, tgt_in.input_ids, tgt_mask)
harim_scorer.py CHANGED
@@ -141,18 +141,19 @@ class Harimplus_Scorer:
141
  emp_in = emp_in.to(self._device)
142
  tgt_in = tgt_in.to(self._device)
143
  tgt_mask = tgt_mask.to(self._device)
 
144
 
145
  with torch.no_grad():
146
  # token_type_ids attribute causes error
147
  s2s_logits = self._encdec_model.forward(
148
  input_ids = src_in.input_ids,
149
  attention_mask = src_in.attention_mask,
150
- labels = tgt_in.input_ids,
151
  return_dict=True).logits
152
  lm_logits = self._encdec_model.forward(
153
  input_ids = emp_in.input_ids,
154
  attention_mask = emp_in.attention_mask,
155
- labels = tgt_in.input_ids,
156
  return_dict=True).logits
157
  sent_lengths = tgt_mask.sum(-1)
158
  ll_tok = self.log_likelihoods(s2s_logits, tgt_in.input_ids, tgt_mask)
 
141
  emp_in = emp_in.to(self._device)
142
  tgt_in = tgt_in.to(self._device)
143
  tgt_mask = tgt_mask.to(self._device)
144
+ fill_ignore_mask = ~(tgt_mask.bool())
145
 
146
  with torch.no_grad():
147
  # token_type_ids attribute causes error
148
  s2s_logits = self._encdec_model.forward(
149
  input_ids = src_in.input_ids,
150
  attention_mask = src_in.attention_mask,
151
+ labels = tgt_in.input_ids.masked_fill(fill_ignore_mask, -100),
152
  return_dict=True).logits
153
  lm_logits = self._encdec_model.forward(
154
  input_ids = emp_in.input_ids,
155
  attention_mask = emp_in.attention_mask,
156
+ labels = tgt_in.input_ids.masked_fill(fill_ignore_mask, -100),
157
  return_dict=True).logits
158
  sent_lengths = tgt_mask.sum(-1)
159
  ll_tok = self.log_likelihoods(s2s_logits, tgt_in.input_ids, tgt_mask)