damerajee committed on
Commit d5806c5
1 Parent(s): 6725561

Update modeling_Llamoe.py

Files changed (1)
  1. modeling_Llamoe.py +84 -0
modeling_Llamoe.py CHANGED
@@ -646,11 +646,95 @@ class LlamoeFlashAttention2(LlamoeAttention):
         )
 
 
+class LlamoeSdpaAttention(LlamoeAttention):
+    """
+    Llamoe attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+    `LlamoeAttention` as the weights of the module stay untouched. The only changes are on the forward pass, to adapt
+    it to the SDPA API.
+    """
+
+    # Adapted from LlamaAttention.forward
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        if output_attentions:
+            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+            logger.warning_once(
+                "LlamoeModel is using LlamoeSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+            )
+            return super().forward(
+                hidden_states=hidden_states,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                cache_position=cache_position,
+            )
+
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        cos, sin = self.rotary_emb(value_states, position_ids)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+        # In case static cache is used, it is an instance attribute.
+        past_key_value = getattr(self, "past_key_value", past_key_value)
+
+        if past_key_value is not None:
+            # sin and cos are specific to RoPE models; cache_position is needed for the static cache
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+        causal_mask = attention_mask
+        if attention_mask is not None and cache_position is not None:
+            causal_mask = causal_mask[:, :, cache_position, : key_states.shape[-2]]
+
+        # SDPA with the memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs and a custom attn_mask.
+        # Reference: https://github.com/pytorch/pytorch/issues/112577.
+        if query_states.device.type == "cuda" and causal_mask is not None:
+            query_states = query_states.contiguous()
+            key_states = key_states.contiguous()
+            value_states = value_states.contiguous()
+
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query_states,
+            key_states,
+            value_states,
+            attn_mask=causal_mask,
+            dropout_p=self.attention_dropout if self.training else 0.0,
+        )
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
+
+        attn_output = self.o_proj(attn_output)
+
+        return attn_output, None, past_key_value
 
 
 LLAMOE_ATTENTION_CLASSES = {
     "eager": LlamoeAttention,
     "flash_attention_2": LlamoeFlashAttention2,
+    "sdpa": LlamoeSdpaAttention,
 }
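
For reference, a minimal usage sketch (not part of this commit) of how the new path would be selected at load time: `attn_implementation="sdpa"` routes attention through the `"sdpa"` entry added to `LLAMOE_ATTENTION_CLASSES`, while `attn_implementation="eager"` keeps the manual path mentioned in the warning. The repo id and dtype below are placeholders, not values from this repository.

    # Hedged sketch: repo id is a placeholder; trust_remote_code is required because
    # modeling_Llamoe.py is custom code shipped inside the model repository.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = "your-namespace/your-llamoe-model"  # placeholder

    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,            # placeholder dtype
        attn_implementation="sdpa",            # selects LlamoeSdpaAttention
        trust_remote_code=True,
    )

    # As the warning in the diff notes, SDPA cannot return attention weights;
    # load with attn_implementation="eager" if output_attentions=True is needed.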
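
And a self-contained sketch, independent of this model, of what `torch.nn.functional.scaled_dot_product_attention` computes when given an additive causal mask: it matches the manual softmax(QK^T / sqrt(d) + mask) V attention that the eager fallback uses, which is why falling back only changes whether attention weights can be returned. Shapes here are arbitrary.

    # Minimal sketch comparing SDPA with a manual attention computation.
    import math
    import torch
    import torch.nn.functional as F

    bsz, n_heads, q_len, head_dim = 2, 4, 8, 16
    q = torch.randn(bsz, n_heads, q_len, head_dim)
    k = torch.randn(bsz, n_heads, q_len, head_dim)
    v = torch.randn(bsz, n_heads, q_len, head_dim)

    # Additive causal mask: 0 on allowed positions, -inf strictly above the diagonal,
    # analogous to the `causal_mask` slice passed as `attn_mask` in the diff.
    mask = torch.full((q_len, q_len), float("-inf")).triu(1)

    sdpa_out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

    weights = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(head_dim) + mask, dim=-1)
    manual_out = weights @ v

    print(torch.allclose(sdpa_out, manual_out, atol=1e-5))  # expected: True (up to float32 tolerance)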