Update modeling_Llamoe.py
modeling_Llamoe.py  (+4 -11)
CHANGED
@@ -525,7 +525,7 @@ class LlamoeFlashAttention2(LlamoeAttention):
 
 
 # Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Gemmoe
-class LlamoeSdpaAttention(
+class LlamoeSdpaAttention(GemmoeAttention):
     """
     Gemmoe attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
     `GemmoeAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
@@ -564,17 +564,10 @@ class LlamoeSdpaAttention(LlamoeAttention):
         query_states = self.q_proj(hidden_states)
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
-        print("query :",query_states.shape)
-        print("key :",key_states.shape)
-        print("value :",value_states.shape)
-
-        query_states = query_states.view(bsz, self.num_heads, q_len, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, self.num_key_value_heads, q_len, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, self.num_key_value_heads, q_len, self.head_dim).transpose(1, 2)
-        print("queryafter :",query_states.shape)
-        print("ketafter :",key_states.shape)
-        print("valueafter :",value_states.shape)
 
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
         cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
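Context for the second hunk (not part of the commit itself): the projections return tensors of shape (bsz, q_len, num_heads * head_dim), so the heads must be split while seq_len is still dimension 1 and then swapped into position, i.e. view(bsz, q_len, num_heads, head_dim).transpose(1, 2). The old view(bsz, num_heads, q_len, head_dim) reinterpreted the same memory with positions and heads mixed up, and after the transpose the tensor was still not in the (bsz, num_heads, q_len, head_dim) layout that scaled_dot_product_attention expects. The sketch below is a minimal standalone illustration of the corrected reshape; the dimension values and the q_proj/k_proj/v_proj Linear stand-ins are made up for the example, and num_key_value_heads is kept equal to num_heads so plain SDPA can be called without grouped-query handling.

import torch
import torch.nn.functional as F

# Hypothetical sizes, chosen only for illustration.
bsz, q_len, num_heads, num_key_value_heads, head_dim = 2, 8, 4, 4, 16
hidden_size = num_heads * head_dim

hidden_states = torch.randn(bsz, q_len, hidden_size)

# Stand-ins for the module's q_proj / k_proj / v_proj layers.
q_proj = torch.nn.Linear(hidden_size, num_heads * head_dim)
k_proj = torch.nn.Linear(hidden_size, num_key_value_heads * head_dim)
v_proj = torch.nn.Linear(hidden_size, num_key_value_heads * head_dim)

query_states = q_proj(hidden_states)   # (bsz, q_len, num_heads * head_dim)
key_states = k_proj(hidden_states)
value_states = v_proj(hidden_states)

# Corrected order from the commit: split heads with seq_len still at dim 1,
# then transpose so attention sees (bsz, num_heads, q_len, head_dim).
query_states = query_states.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, num_key_value_heads, head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, num_key_value_heads, head_dim).transpose(1, 2)

out = F.scaled_dot_product_attention(query_states, key_states, value_states)
print(out.shape)  # torch.Size([2, 4, 8, 16]) -> (bsz, num_heads, q_len, head_dim)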