damerajee committed
Commit 3792fda
1 Parent(s): d5806c5

Update modeling_Llamoe.py

Files changed (1):
  1. modeling_Llamoe.py +9 -10
modeling_Llamoe.py CHANGED
@@ -646,14 +646,14 @@ class LlamoeFlashAttention2(LlamoeAttention):
         )
 
 
-class LlamoeSdpaAttention(LlamoeAttention):
+class LlamoeSdpaAttention(GemmoeAttention):
     """
-    Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
-    `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
+    Gemmoe attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+    `GemmoeAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
     SDPA API.
     """
 
-    # Adapted from LlamaAttention.forward
+    # Ignore copy
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -667,7 +667,7 @@ class LlamoeSdpaAttention(LlamoeAttention):
         if output_attentions:
             # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
             logger.warning_once(
-                "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+                "LlamoeModel is using LlamoeSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
             )
             return super().forward(
@@ -690,10 +690,9 @@ class LlamoeSdpaAttention(LlamoeAttention):
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
-        cos, sin = self.rotary_emb(value_states, position_ids)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
 
-        # In case static cache is used, it is an instance attribute.
         past_key_value = getattr(self, "past_key_value", past_key_value)
 
         if past_key_value is not None:
@@ -724,12 +723,12 @@ class LlamoeSdpaAttention(LlamoeAttention):
         )
 
         attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
+        attn_output = attn_output.view(bsz, q_len, -1)
 
         attn_output = self.o_proj(attn_output)
 
         return attn_output, None, past_key_value
-
+
 
 LLAMOE_ATTENTION_CLASSES = {
     "eager": LlamoeAttention,