jaygala24 commited on
Commit
c90c65f
1 Parent(s): 728e02c

Update modeling_indictrans.py

Browse files
Files changed (1) hide show
  1. modeling_indictrans.py +1 -1
modeling_indictrans.py CHANGED
@@ -1213,7 +1213,7 @@ class IndicTransForConditionalGeneration(IndicTransPreTrainedModel):
1213
  # move labels to the correct device to enable PP
1214
  labels = labels.to(lm_logits.device)
1215
  loss_fct = nn.CrossEntropyLoss()
1216
- masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
1217
 
1218
  if not return_dict:
1219
  output = (lm_logits,) + outputs[1:]
 
1213
  # move labels to the correct device to enable PP
1214
  labels = labels.to(lm_logits.device)
1215
  loss_fct = nn.CrossEntropyLoss()
1216
+ masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.decoder_vocab_size), labels.view(-1))
1217
 
1218
  if not return_dict:
1219
  output = (lm_logits,) + outputs[1:]