Update modeling_Llamoe.py
modeling_Llamoe.py  (CHANGED: +9 -10)
@@ -646,14 +646,14 @@ class LlamoeFlashAttention2(LlamoeAttention):
         )


-class LlamoeSdpaAttention(LlamoeAttention):
+class LlamoeSdpaAttention(GemmoeAttention):
     """
-
-    `
+    Gemmoe attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+    `GemmoeAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
     SDPA API.
     """

-    #
+    # Ignore copy
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -667,7 +667,7 @@ class LlamoeSdpaAttention(LlamoeAttention):
         if output_attentions:
             # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
             logger.warning_once(
-                "
+                "LlamoeModel is using LlamoeSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
@@ -690,10 +690,9 @@ class LlamoeSdpaAttention(LlamoeAttention):
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

-        cos, sin = self.rotary_emb(value_states, position_ids)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)

-        # In case static cache is used, it is an instance attribute.
        past_key_value = getattr(self, "past_key_value", past_key_value)

        if past_key_value is not None:
@@ -724,12 +723,12 @@ class LlamoeSdpaAttention(LlamoeAttention):
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.view(bsz, q_len,
+        attn_output = attn_output.view(bsz, q_len, -1)

        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value
-
+


LLAMOE_ATTENTION_CLASSES = {
    "eager": LlamoeAttention,
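For context, the added lines apply rotary position embeddings to the query/key projections and then flatten the attention heads with view(bsz, q_len, -1) before o_proj. Below is a minimal, hedged sketch of that flow using torch.nn.functional.scaled_dot_product_attention; the rotate_half/apply_rotary_pos_emb helpers, tensor shapes, and random inputs are simplified stand-ins for illustration, not the repository's actual implementation.

# Minimal sketch of the SDPA attention pattern this diff adjusts.
# Assumptions: the RoPE helpers and shapes below are illustrative only.
import torch
import torch.nn.functional as F


def rotate_half(x):
    # Standard RoPE helper: swap and negate the two halves of the last dimension.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    # Simplified stand-in for the model's helper: cos/sin have shape
    # (bsz, q_len, head_dim) and are broadcast over the head dimension.
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin


bsz, q_len, num_heads, head_dim = 2, 8, 4, 16

# Projected hidden states, already reshaped to (bsz, num_heads, q_len, head_dim),
# mirroring the .view(...).transpose(1, 2) calls in the diff.
query_states = torch.randn(bsz, num_heads, q_len, head_dim)
key_states = torch.randn(bsz, num_heads, q_len, head_dim)
value_states = torch.randn(bsz, num_heads, q_len, head_dim)

# Stand-in rotary tables; the model derives these from position_ids.
cos = torch.randn(bsz, q_len, head_dim)
sin = torch.randn(bsz, q_len, head_dim)

query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

# SDPA consumes (bsz, num_heads, q_len, head_dim) and returns the same layout.
attn_output = F.scaled_dot_product_attention(
    query_states, key_states, value_states, is_causal=True
)

# Same output reshape as the new line in the diff: flatten the heads with -1
# instead of hard-coding a hidden size, then the result would feed o_proj.
attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, -1)
print(attn_output.shape)  # torch.Size([2, 8, 64])

Note that SDPA never returns attention weights, which is why the forward pass in the diff falls back to the parent (eager) implementation when output_attentions=True and emits the warning shown above.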