jaygala24 committed
Commit
7d5c6bf
1 Parent(s): 7d773d1

Upload modeling_indictrans.py with huggingface_hub

Files changed (1)
  1. modeling_indictrans.py +270 -82
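
For context, an upload like the one in this commit is normally done with the `huggingface_hub` client. A minimal sketch is shown below; the repo id is an illustrative placeholder and authentication is assumed to be set up already (e.g. via `huggingface-cli login`):

```python
from huggingface_hub import HfApi

api = HfApi()
# Push the standalone modeling file into a model repo on the Hub.
api.upload_file(
    path_or_fileobj="modeling_indictrans.py",
    path_in_repo="modeling_indictrans.py",
    repo_id="your-username/your-indictrans-model",  # placeholder repo id
    commit_message="Upload modeling_indictrans.py with huggingface_hub",
)
```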
modeling_indictrans.py CHANGED
@@ -34,7 +34,7 @@ from transformers.modeling_outputs import (
 from transformers.utils import logging
 from transformers.modeling_utils import PreTrainedModel
 
-from .configuration_indictrans import IndicTransConfig
+from configuration_indictrans import IndicTransConfig
 
 
 logger = logging.get_logger(__name__)
@@ -45,7 +45,9 @@ INDICTRANS_PRETRAINED_MODEL_ARCHIVE_LIST = [""]
 
 
 # Copied from transformers.models.bart.modeling_bart.shift_tokens_right
-def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
+def shift_tokens_right(
+    input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int
+):
     """
     Shift input ids one token to the right.
     """
@@ -63,7 +65,10 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start
 
 # Copied from transformers.models.bart.modeling_bart._make_causal_mask
 def _make_causal_mask(
-    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
+    input_ids_shape: torch.Size,
+    dtype: torch.dtype,
+    device: torch.device,
+    past_key_values_length: int = 0,
 ):
     """
     Make causal mask used for bi-directional self-attention.
@@ -75,8 +80,18 @@ def _make_causal_mask(
     mask = mask.to(dtype)
 
     if past_key_values_length > 0:
-        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
-    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
+        mask = torch.cat(
+            [
+                torch.zeros(
+                    tgt_len, past_key_values_length, dtype=dtype, device=device
+                ),
+                mask,
+            ],
+            dim=-1,
+        )
+    return mask[None, None, :, :].expand(
+        bsz, 1, tgt_len, tgt_len + past_key_values_length
+    )
 
 
 # Copied from transformers.models.bart.modeling_bart._expand_mask
@@ -91,17 +106,23 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
 
     inverted_mask = 1.0 - expanded_mask
 
-    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
+    return inverted_mask.masked_fill(
+        inverted_mask.to(torch.bool), torch.finfo(dtype).min
+    )
 
 
-def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
+def create_position_ids_from_input_ids(
+    input_ids, padding_idx, past_key_values_length=0
+):
     """
     Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
     are ignored. This is modified from fairseq's `utils.make_positions`.
     """
     # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
     mask = input_ids.ne(padding_idx).int()
-    incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
+    incremental_indices = (
+        torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length
+    ) * mask
     return incremental_indices.long() + padding_idx
 
 
@@ -109,23 +130,31 @@ def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_l
 class IndicTransSinusoidalPositionalEmbedding(nn.Module):
     """This module produces sinusoidal positional embeddings of any length."""
 
-    def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
+    def __init__(
+        self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None
+    ):
         super().__init__()
         self.offset = 2
         self.embedding_dim = embedding_dim
         self.padding_idx = padding_idx
         self.make_weights(num_positions + self.offset, embedding_dim, padding_idx)
 
-    def make_weights(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
+    def make_weights(
+        self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None
+    ):
         emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx)
         if hasattr(self, "weights"):
             # in forward put the weights on the correct dtype and device of the param
-            emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)
+            emb_weights = emb_weights.to(
+                dtype=self.weights.dtype, device=self.weights.device
+            )
 
         self.register_buffer("weights", emb_weights, persistent=False)
 
     @staticmethod
-    def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
+    def get_embedding(
+        num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None
+    ):
         """
         Build sinusoidal embeddings.
 
@@ -135,8 +164,12 @@ class IndicTransSinusoidalPositionalEmbedding(nn.Module):
         half_dim = embedding_dim // 2
         emb = math.log(10000) / (half_dim - 1)
         emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
-        emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
-        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
+        emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(
+            1
+        ) * emb.unsqueeze(0)
+        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(
+            num_embeddings, -1
+        )
         if embedding_dim % 2 == 1:
             # zero pad
             emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
@@ -147,26 +180,39 @@ class IndicTransSinusoidalPositionalEmbedding(nn.Module):
 
     @torch.no_grad()
     def forward(
-        self, input_ids: torch.Tensor = None, inputs_embeds: torch.Tensor = None, past_key_values_length: int = 0
+        self,
+        input_ids: torch.Tensor = None,
+        inputs_embeds: torch.Tensor = None,
+        past_key_values_length: int = 0,
     ):
         if input_ids is not None:
             bsz, seq_len = input_ids.size()
             # Create the position ids from the input token ids. Any padded tokens remain padded.
-            position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length).to(
-                input_ids.device
-            )
+            position_ids = create_position_ids_from_input_ids(
+                input_ids, self.padding_idx, past_key_values_length
+            ).to(input_ids.device)
         else:
             bsz, seq_len = inputs_embeds.size()[:-1]
-            position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds, past_key_values_length)
+            position_ids = self.create_position_ids_from_inputs_embeds(
+                inputs_embeds, past_key_values_length
+            )
 
         # expand embeddings if needed
         max_pos = self.padding_idx + 1 + seq_len + past_key_values_length
         if max_pos > self.weights.size(0):
-            self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx)
+            self.make_weights(
+                max_pos + self.offset, self.embedding_dim, self.padding_idx
+            )
 
-        return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, self.weights.shape[-1]).detach()
+        return (
+            self.weights.index_select(0, position_ids.view(-1))
+            .view(bsz, seq_len, self.weights.shape[-1])
+            .detach()
+        )
 
-    def create_position_ids_from_inputs_embeds(self, inputs_embeds, past_key_values_length):
+    def create_position_ids_from_inputs_embeds(
+        self, inputs_embeds, past_key_values_length
+    ):
         """
         We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
 
@@ -179,9 +225,15 @@ class IndicTransSinusoidalPositionalEmbedding(nn.Module):
         sequence_length = input_shape[1]
 
         position_ids = torch.arange(
-            self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
+            self.padding_idx + 1,
+            sequence_length + self.padding_idx + 1,
+            dtype=torch.long,
+            device=inputs_embeds.device,
+        )
+        return (
+            position_ids.unsqueeze(0).expand(input_shape).contiguous()
+            + past_key_values_length
         )
-        return position_ids.unsqueeze(0).expand(input_shape).contiguous() + past_key_values_length
 
 
 # Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->IndicTrans
@@ -216,7 +268,11 @@ class IndicTransAttention(nn.Module):
         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
 
     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
-        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+        return (
+            tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
+            .transpose(1, 2)
+            .contiguous()
+        )
 
     def forward(
         self,
@@ -293,7 +349,10 @@ class IndicTransAttention(nn.Module):
                 raise ValueError(
                     f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                 )
-            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
+            attn_weights = (
+                attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+                + attention_mask
+            )
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
 
         attn_weights = F.softmax(attn_weights, dim=-1)
@@ -304,7 +363,9 @@ class IndicTransAttention(nn.Module):
                     f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                     f" {layer_head_mask.size()}"
                 )
-            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(
+                bsz, self.num_heads, tgt_len, src_len
+            )
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
 
         if output_attentions:
@@ -312,8 +373,12 @@ class IndicTransAttention(nn.Module):
             # make sure that attn_weights keeps its gradient.
             # In order to do so, attn_weights have to be reshaped
             # twice and have to be reused in the following
-            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
-            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
+            attn_weights_reshaped = attn_weights.view(
+                bsz, self.num_heads, tgt_len, src_len
+            )
+            attn_weights = attn_weights_reshaped.view(
+                bsz * self.num_heads, tgt_len, src_len
+            )
         else:
             attn_weights_reshaped = None
 
@@ -394,7 +459,9 @@ class IndicTransEncoderLayer(nn.Module):
         if self.normalize_before:
             hidden_states = self.final_layer_norm(hidden_states)
         hidden_states = self.activation_fn(self.fc1(hidden_states))
-        hidden_states = F.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+        hidden_states = F.dropout(
+            hidden_states, p=self.activation_dropout, training=self.training
+        )
         hidden_states = self.fc2(hidden_states)
         hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
         hidden_states = residual + hidden_states
@@ -405,7 +472,9 @@ class IndicTransEncoderLayer(nn.Module):
             torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
         ):
             clamp_value = torch.finfo(hidden_states.dtype).max - 1000
-            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
+            hidden_states = torch.clamp(
+                hidden_states, min=-clamp_value, max=clamp_value
+            )
 
         outputs = (hidden_states,)
 
@@ -480,7 +549,9 @@ class IndicTransDecoderLayer(nn.Module):
 
         # Self Attention
         # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
-        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+        self_attn_past_key_value = (
+            past_key_value[:2] if past_key_value is not None else None
+        )
         # add present self-attn cache to positions 1,2 of present_key_value tuple
         hidden_states, self_attn_weights, present_key_value = self.self_attn(
             hidden_states=hidden_states,
@@ -503,8 +574,14 @@ class IndicTransDecoderLayer(nn.Module):
                 hidden_states = self.encoder_attn_layer_norm(hidden_states)
 
             # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
-            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
-            hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
+            cross_attn_past_key_value = (
+                past_key_value[-2:] if past_key_value is not None else None
+            )
+            (
+                hidden_states,
+                cross_attn_weights,
+                cross_attn_present_key_value,
+            ) = self.encoder_attn(
                 hidden_states=hidden_states,
                 key_value_states=encoder_hidden_states,
                 attention_mask=encoder_attention_mask,
@@ -512,7 +589,9 @@ class IndicTransDecoderLayer(nn.Module):
                 past_key_value=cross_attn_past_key_value,
                 output_attentions=output_attentions,
             )
-            hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
+            hidden_states = F.dropout(
+                hidden_states, p=self.dropout, training=self.training
+            )
             hidden_states = residual + hidden_states
             if not self.normalize_before:
                 hidden_states = self.encoder_attn_layer_norm(hidden_states)
@@ -525,7 +604,9 @@ class IndicTransDecoderLayer(nn.Module):
         if self.normalize_before:
             hidden_states = self.final_layer_norm(hidden_states)
         hidden_states = self.activation_fn(self.fc1(hidden_states))
-        hidden_states = F.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+        hidden_states = F.dropout(
+            hidden_states, p=self.activation_dropout, training=self.training
+        )
         hidden_states = self.fc2(hidden_states)
         hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
         hidden_states = residual + hidden_states
@@ -577,7 +658,9 @@ class IndicTransEncoder(IndicTransPreTrainedModel):
         embed_tokens (nn.Embedding): output embedding
     """
 
-    def __init__(self, config: IndicTransConfig, embed_tokens: Optional[nn.Embedding] = None):
+    def __init__(
+        self, config: IndicTransConfig, embed_tokens: Optional[nn.Embedding] = None
+    ):
         super().__init__(config)
 
         self.dropout = config.dropout
@@ -588,7 +671,9 @@ class IndicTransEncoder(IndicTransPreTrainedModel):
         self.max_source_positions = config.max_source_positions
         self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
 
-        self.embed_tokens = nn.Embedding(config.encoder_vocab_size, embed_dim, self.padding_idx)
+        self.embed_tokens = nn.Embedding(
+            config.encoder_vocab_size, embed_dim, self.padding_idx
+        )
 
         if embed_tokens is not None:
             self.embed_tokens.weight = embed_tokens.weight
@@ -598,9 +683,15 @@ class IndicTransEncoder(IndicTransPreTrainedModel):
             embed_dim,
             self.padding_idx,
         )
-        self.layers = nn.ModuleList([IndicTransEncoderLayer(config) for _ in range(config.encoder_layers)])
-        self.layer_norm = nn.LayerNorm(embed_dim) if config.encoder_normalize_before else None
-        self.layernorm_embedding = nn.LayerNorm(embed_dim) if config.layernorm_embedding else None
+        self.layers = nn.ModuleList(
+            [IndicTransEncoderLayer(config) for _ in range(config.encoder_layers)]
+        )
+        self.layer_norm = (
+            nn.LayerNorm(embed_dim) if config.encoder_normalize_before else None
+        )
+        self.layernorm_embedding = (
+            nn.LayerNorm(embed_dim) if config.layernorm_embedding else None
+        )
 
         self.gradient_checkpointing = False
         # Initialize weights and apply final processing
@@ -652,15 +743,25 @@ class IndicTransEncoder(IndicTransPreTrainedModel):
             return_dict (`bool`, *optional*):
                 Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
         """
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_attentions = (
+            output_attentions
+            if output_attentions is not None
+            else self.config.output_attentions
+        )
         output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
+        )
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
         )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         # retrieve input_ids and inputs_embeds
         if input_ids is not None and inputs_embeds is not None:
-            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+            raise ValueError(
+                "You cannot specify both input_ids and inputs_embeds at the same time"
+            )
         elif input_ids is not None:
             self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
             input_shape = input_ids.size()
@@ -705,7 +806,11 @@ class IndicTransEncoder(IndicTransPreTrainedModel):
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
             dropout_probability = torch.rand([])
 
-            skip_the_layer = True if self.training and (dropout_probability < self.layerdrop) else False
+            skip_the_layer = (
+                True
+                if self.training and (dropout_probability < self.layerdrop)
+                else False
+            )
             if not skip_the_layer or deepspeed_zero3_is_enabled:
                 # under deepspeed zero3 all gpus must run in sync
 
@@ -727,7 +832,9 @@ class IndicTransEncoder(IndicTransPreTrainedModel):
                     layer_outputs = encoder_layer(
                         hidden_states,
                         attention_mask,
-                        layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                        layer_head_mask=(
+                            head_mask[idx] if head_mask is not None else None
+                        ),
                         output_attentions=output_attentions,
                     )
 
@@ -746,9 +853,15 @@ class IndicTransEncoder(IndicTransPreTrainedModel):
             encoder_states = encoder_states + (hidden_states,)
 
         if not return_dict:
-            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
+            return tuple(
+                v
+                for v in [hidden_states, encoder_states, all_attentions]
+                if v is not None
+            )
         return BaseModelOutput(
-            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
+            last_hidden_state=hidden_states,
+            hidden_states=encoder_states,
+            attentions=all_attentions,
         )
 
 
@@ -762,7 +875,9 @@ class IndicTransDecoder(IndicTransPreTrainedModel):
         embed_tokens (nn.Embedding): output embedding
     """
 
-    def __init__(self, config: IndicTransConfig, embed_tokens: Optional[nn.Embedding] = None):
+    def __init__(
+        self, config: IndicTransConfig, embed_tokens: Optional[nn.Embedding] = None
+    ):
         super().__init__(config)
         self.dropout = config.dropout
         self.layerdrop = config.decoder_layerdrop
@@ -772,7 +887,9 @@ class IndicTransDecoder(IndicTransPreTrainedModel):
         self.max_target_positions = config.max_target_positions
         self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
 
-        self.embed_tokens = nn.Embedding(config.decoder_vocab_size, embed_dim, self.padding_idx)
+        self.embed_tokens = nn.Embedding(
+            config.decoder_vocab_size, embed_dim, self.padding_idx
+        )
 
         if embed_tokens is not None:
             self.embed_tokens.weight = embed_tokens.weight
@@ -782,9 +899,15 @@ class IndicTransDecoder(IndicTransPreTrainedModel):
             embed_dim,
             self.padding_idx,
         )
-        self.layers = nn.ModuleList([IndicTransDecoderLayer(config) for _ in range(config.decoder_layers)])
-        self.layer_norm = nn.LayerNorm(embed_dim) if config.decoder_normalize_before else None
-        self.layernorm_embedding = nn.LayerNorm(embed_dim) if config.layernorm_embedding else None
+        self.layers = nn.ModuleList(
+            [IndicTransDecoderLayer(config) for _ in range(config.decoder_layers)]
+        )
+        self.layer_norm = (
+            nn.LayerNorm(embed_dim) if config.decoder_normalize_before else None
+        )
+        self.layernorm_embedding = (
+            nn.LayerNorm(embed_dim) if config.layernorm_embedding else None
+        )
 
         self.gradient_checkpointing = False
         # Initialize weights and apply final processing
@@ -870,26 +993,40 @@ class IndicTransDecoder(IndicTransPreTrainedModel):
             return_dict (`bool`, *optional*):
                 Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
         """
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_attentions = (
+            output_attentions
+            if output_attentions is not None
+            else self.config.output_attentions
+        )
         output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
         )
         use_cache = use_cache if use_cache is not None else self.config.use_cache
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
 
         # retrieve input_ids and inputs_embeds
         if input_ids is not None and inputs_embeds is not None:
-            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+            raise ValueError(
+                "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
+            )
         elif input_ids is not None:
             input_shape = input_ids.size()
             input_ids = input_ids.view(-1, input_shape[-1])
         elif inputs_embeds is not None:
             input_shape = inputs_embeds.size()[:-1]
         else:
-            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+            raise ValueError(
+                "You have to specify either decoder_input_ids or decoder_inputs_embeds"
+            )
 
         # past_key_values_length
-        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+        past_key_values_length = (
+            past_key_values[0][0].shape[2] if past_key_values is not None else 0
+        )
 
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
@@ -914,10 +1051,14 @@ class IndicTransDecoder(IndicTransPreTrainedModel):
         # expand encoder attention mask
         if encoder_hidden_states is not None and encoder_attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
+            encoder_attention_mask = _expand_mask(
+                encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+            )
 
         # embed positions
-        positions = self.embed_positions(input_ids, inputs_embeds, past_key_values_length)
+        positions = self.embed_positions(
+            input_ids, inputs_embeds, past_key_values_length
+        )
         positions = positions.to(inputs_embeds.device)
 
         hidden_states = inputs_embeds + positions
@@ -929,7 +1070,8 @@ class IndicTransDecoder(IndicTransPreTrainedModel):
         if self.gradient_checkpointing and self.training:
             if use_cache:
                 logger.warning_once(
-                    "`use_cache=True` is incompatible with gradient checkpointing. Setting" " `use_cache=False`..."
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting"
+                    " `use_cache=False`..."
                 )
                 use_cache = False
 
@@ -940,7 +1082,9 @@ class IndicTransDecoder(IndicTransPreTrainedModel):
         next_decoder_cache = () if use_cache else None
 
         # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
-        for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
+        for attn_mask, mask_name in zip(
+            [head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]
+        ):
             if attn_mask is not None:
                 if attn_mask.size()[0] != len(self.layers):
                     raise ValueError(
@@ -956,11 +1100,17 @@ class IndicTransDecoder(IndicTransPreTrainedModel):
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
             dropout_probability = torch.rand([])
 
-            skip_the_layer = True if self.training and (dropout_probability < self.layerdrop) else False
+            skip_the_layer = (
+                True
+                if self.training and (dropout_probability < self.layerdrop)
+                else False
+            )
             if not skip_the_layer or deepspeed_zero3_is_enabled:
                 # under deepspeed zero3 all gpus must run in sync
 
-                past_key_value = past_key_values[idx] if past_key_values is not None else None
+                past_key_value = (
+                    past_key_values[idx] if past_key_values is not None else None
+                )
 
                 if self.gradient_checkpointing and self.training:
 
@@ -978,7 +1128,9 @@ class IndicTransDecoder(IndicTransPreTrainedModel):
                         encoder_hidden_states,
                         encoder_attention_mask,
                         head_mask[idx] if head_mask is not None else None,
-                        cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
+                        cross_attn_head_mask[idx]
+                        if cross_attn_head_mask is not None
+                        else None,
                         None,
                     )
                 else:
@@ -987,9 +1139,13 @@ class IndicTransDecoder(IndicTransPreTrainedModel):
                         attention_mask=combined_attention_mask,
                         encoder_hidden_states=encoder_hidden_states,
                         encoder_attention_mask=encoder_attention_mask,
-                        layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                        layer_head_mask=(
+                            head_mask[idx] if head_mask is not None else None
+                        ),
                         cross_attn_layer_head_mask=(
-                            cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
+                            cross_attn_head_mask[idx]
+                            if cross_attn_head_mask is not None
+                            else None
                         ),
                         past_key_value=past_key_value,
                         output_attentions=output_attentions,
@@ -1019,7 +1175,13 @@ class IndicTransDecoder(IndicTransPreTrainedModel):
         if not return_dict:
             return tuple(
                 v
-                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
+                for v in [
+                    hidden_states,
+                    next_cache,
+                    all_hidden_states,
+                    all_self_attns,
+                    all_cross_attentions,
+                ]
                 if v is not None
             )
         return BaseModelOutputWithPastAndCrossAttentions(
@@ -1037,7 +1199,7 @@ class IndicTransModel(IndicTransPreTrainedModel):
 
     def __init__(self, config: IndicTransConfig):
         super().__init__(config)
-
+
         self.encoder = IndicTransEncoder(config)
         self.decoder = IndicTransDecoder(config)
 
@@ -1068,12 +1230,20 @@ class IndicTransModel(IndicTransPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]:
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_attentions = (
+            output_attentions
+            if output_attentions is not None
+            else self.config.output_attentions
+        )
         output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
         )
         use_cache = use_cache if use_cache is not None else self.config.use_cache
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
 
         if encoder_outputs is None:
             encoder_outputs = self.encoder(
@@ -1128,17 +1298,20 @@ class IndicTransForConditionalGeneration(IndicTransPreTrainedModel):
 class IndicTransForConditionalGeneration(IndicTransPreTrainedModel):
     base_model_prefix = "model"
     _tied_weights_keys = None
+    _label_smoothing = 0.0
 
     def __init__(self, config: IndicTransConfig):
         super().__init__(config)
         self.model = IndicTransModel(config)
-        self.lm_head = nn.Linear(config.decoder_embed_dim, config.decoder_vocab_size, bias=False)
+        self.lm_head = nn.Linear(
+            config.decoder_embed_dim, config.decoder_vocab_size, bias=False
+        )
 
         if config.share_decoder_input_output_embed:
             self.lm_head.weight = self.model.decoder.embed_tokens.weight
-
+
         self.post_init()
-
+
     def tie_weights(self):
         pass
 
@@ -1153,6 +1326,9 @@ class IndicTransForConditionalGeneration(IndicTransPreTrainedModel):
 
     def set_output_embeddings(self, new_embeddings):
         self.lm_head = new_embeddings
+
+    def set_label_smoothing(self, label_smoothing):
+        self._label_smoothing = label_smoothing
 
     def forward(
         self,
@@ -1181,7 +1357,9 @@ class IndicTransForConditionalGeneration(IndicTransPreTrainedModel):
 
         Returns:
         """
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
 
         if labels is not None:
             if decoder_input_ids is None:
@@ -1212,12 +1390,18 @@ class IndicTransForConditionalGeneration(IndicTransPreTrainedModel):
         if labels is not None:
             # move labels to the correct device to enable PP
             labels = labels.to(lm_logits.device)
-            loss_fct = nn.CrossEntropyLoss()
-            masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.decoder_vocab_size), labels.view(-1))
+            masked_lm_loss = F.cross_entropy(
+                input=lm_logits.view(-1, self.config.decoder_vocab_size),
+                target=labels.view(-1),
+                ignore_index=self.config.pad_token_id,
+                label_smoothing=self._label_smoothing,
+            )
 
         if not return_dict:
             output = (lm_logits,) + outputs[1:]
-            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+            return (
+                ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+            )
 
         return Seq2SeqLMOutput(
             loss=masked_lm_loss,
@@ -1263,5 +1447,9 @@ class IndicTransForConditionalGeneration(IndicTransPreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (
+                tuple(
+                    past_state.index_select(0, beam_idx) for past_state in layer_past
+                ),
+            )
         return reordered_past
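
Beyond the line wrapping, the substantive changes in this revision are the absolute configuration import (`from configuration_indictrans import IndicTransConfig`) and a reworked loss in `IndicTransForConditionalGeneration`: the cross-entropy is now computed with `F.cross_entropy`, ignoring `pad_token_id` and honouring a label-smoothing factor exposed through the new `set_label_smoothing` method (default `0.0`). A minimal usage sketch follows; the checkpoint id and the dummy batch are illustrative assumptions, not part of this commit:

```python
import torch
from transformers import AutoModelForSeq2SeqLM

# Illustrative checkpoint id; any Hub repo that bundles modeling_indictrans.py
# as remote code would work the same way.
model = AutoModelForSeq2SeqLM.from_pretrained(
    "ai4bharat/indictrans2-en-indic-dist-200M", trust_remote_code=True
)

# Opt in to the label smoothing added in this revision (the default stays 0.0).
model.set_label_smoothing(0.1)

# Dummy batch for illustration; real ids come from the IndicTrans tokenizer.
input_ids = torch.randint(4, 1000, (2, 8))
labels = torch.randint(4, 1000, (2, 8))

out = model(input_ids=input_ids, labels=labels)
print(out.loss)  # cross-entropy with ignore_index=pad_token_id, label_smoothing=0.1
```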