Update modeling_Llamoe.py
modeling_Llamoe.py  (+84, -0)  CHANGED
@@ -646,11 +646,95 @@ class LlamoeFlashAttention2(LlamoeAttention):
 )
 
 
+class LlamoeSdpaAttention(LlamoeAttention):
+    """
+    Llamoe attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+    `LlamoeAttention`, as the weights of the module stay untouched. The only changes are on the forward pass to adapt
+    to the SDPA API.
+    """
+
+    # Adapted from LlamaAttention.forward
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        if output_attentions:
+            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+            logger.warning_once(
+                "LlamoeModel is using LlamoeSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+            )
+            return super().forward(
+                hidden_states=hidden_states,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                cache_position=cache_position,
+            )
+
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        cos, sin = self.rotary_emb(value_states, position_ids)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+        # In case a static cache is used, it is an instance attribute.
+        past_key_value = getattr(self, "past_key_value", past_key_value)
+
+        if past_key_value is not None:
+            # sin and cos are specific to RoPE models; cache_position is needed for the static cache
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+        causal_mask = attention_mask
+        if attention_mask is not None and cache_position is not None:
+            causal_mask = causal_mask[:, :, cache_position, : key_states.shape[-2]]
+
+        # SDPA with the memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs and a custom attn_mask.
+        # Reference: https://github.com/pytorch/pytorch/issues/112577.
+        if query_states.device.type == "cuda" and causal_mask is not None:
+            query_states = query_states.contiguous()
+            key_states = key_states.contiguous()
+            value_states = value_states.contiguous()
+
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query_states,
+            key_states,
+            value_states,
+            attn_mask=causal_mask,
+            dropout_p=self.attention_dropout if self.training else 0.0,
+        )
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
+
+        attn_output = self.o_proj(attn_output)
+
+        return attn_output, None, past_key_value
 
 
 LLAMOE_ATTENTION_CLASSES = {
     "eager": LlamoeAttention,
     "flash_attention_2": LlamoeFlashAttention2,
+    "sdpa": LlamoeSdpaAttention,
 }
 
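For context, a dispatch table like LLAMOE_ATTENTION_CLASSES is normally read when each decoder layer builds its attention module, with the key taken from `config._attn_implementation`. The rest of modeling_Llamoe.py is not part of this diff, so the sketch below only illustrates that usual transformers pattern; the `LlamoeDecoderLayer` name and the `(config, layer_idx)` constructor signature are assumptions, not code from this commit.

# Minimal sketch, not part of this diff: how the new "sdpa" entry would typically be
# consumed elsewhere in modeling_Llamoe.py. Only LLAMOE_ATTENTION_CLASSES comes from
# the change above; LlamoeDecoderLayer and the (config, layer_idx) signature are assumed.
from torch import nn

class LlamoeDecoderLayer(nn.Module):
    def __init__(self, config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        # transformers sets config._attn_implementation from the `attn_implementation`
        # argument at load time ("eager", "flash_attention_2", or now "sdpa"), so this
        # lookup routes to LlamoeSdpaAttention when SDPA is requested.
        self.self_attn = LLAMOE_ATTENTION_CLASSES[config._attn_implementation](
            config=config, layer_idx=layer_idx
        )

Under that assumption, loading the checkpoint with `AutoModelForCausalLM.from_pretrained(repo_id, attn_implementation="sdpa", trust_remote_code=True)` would exercise the new `LlamoeSdpaAttention.forward` path.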