damerajee committed on
Commit 0c54f9c
1 Parent(s): af76b52

Update modeling_Llamoe.py

Files changed (1):
  1. modeling_Llamoe.py +4 -11
modeling_Llamoe.py CHANGED
@@ -525,7 +525,7 @@ class LlamoeFlashAttention2(LlamoeAttention):
 
 
 # Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Gemmoe
-class LlamoeSdpaAttention(LlamoeAttention):
+class LlamoeSdpaAttention(GemmoeAttention):
     """
     Gemmoe attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
     `GemmoeAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
@@ -564,17 +564,10 @@ class LlamoeSdpaAttention(LlamoeAttention):
         query_states = self.q_proj(hidden_states)
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
-        print("query :",query_states.shape)
-        print("key :",key_states.shape)
-        print("value :",value_states.shape)
-
-        query_states = query_states.view(bsz, self.num_heads, q_len, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, self.num_key_value_heads, q_len, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, self.num_key_value_heads, q_len, self.head_dim).transpose(1, 2)
-        print("queryafter :",query_states.shape)
-        print("ketafter :",key_states.shape)
-        print("valueafter :",value_states.shape)
 
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
         cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
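
A note on why the reshape fix above matters: the q/k/v projections emit tensors of shape (bsz, q_len, num_heads * head_dim), so the last axis must be split into heads first, via view(bsz, q_len, num_heads, head_dim), and only then transposed into the (bsz, num_heads, q_len, head_dim) layout attention expects. The removed view(bsz, num_heads, q_len, head_dim) has the same element count, so PyTorch raises no error, but it silently interleaves token and head data. A minimal standalone sketch (the small dimensions are illustrative, not taken from the model config):

    import torch

    bsz, q_len, num_heads, head_dim = 2, 5, 4, 8

    # A projection output is laid out as (bsz, q_len, num_heads * head_dim):
    # each token's per-head slices sit contiguously along the last axis.
    hidden = torch.randn(bsz, q_len, num_heads * head_dim)

    # Fixed version: split the last axis into heads, then swap the token and
    # head axes to reach the (bsz, num_heads, q_len, head_dim) layout.
    fixed = hidden.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)

    # Removed version: the element count matches, so .view() does not error,
    # but token and head data are scrambled and the final shape is wrong too.
    buggy = hidden.view(bsz, num_heads, q_len, head_dim).transpose(1, 2)

    # Head 0 of token 1 in batch 0 should be that token's first head_dim features.
    expected = hidden[0, 1, :head_dim]
    print(torch.equal(fixed[0, 0, 1], expected))   # True
    print(fixed.shape)                             # torch.Size([2, 4, 5, 8])
    print(buggy.shape)                             # torch.Size([2, 5, 4, 8])

The same reasoning applies to key_states and value_states, except that under grouped-query attention they use num_key_value_heads in place of num_heads.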