damerajee committed on
Commit a9c0951
1 parent: f9a1080

Update modeling_Llamoe.py

Files changed (1)
  1. modeling_Llamoe.py +0 -84
modeling_Llamoe.py CHANGED
@@ -646,95 +646,11 @@ class LlamoeFlashAttention2(LlamoeAttention):
 )
 
 
-class LlamoeSdpaAttention(LlamoeAttention):
-    """
-    Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
-    `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
-    SDPA API.
-    """
-
-    # Adapted from LlamaAttention.forward
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Cache] = None,
-        output_attentions: bool = False,
-        use_cache: bool = False,
-        cache_position: Optional[torch.LongTensor] = None,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        if output_attentions:
-            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
-            logger.warning_once(
-                "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
-                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
-            )
-            return super().forward(
-                hidden_states=hidden_states,
-                attention_mask=attention_mask,
-                position_ids=position_ids,
-                past_key_value=past_key_value,
-                output_attentions=output_attentions,
-                use_cache=use_cache,
-                cache_position=cache_position,
-            )
-
-        bsz, q_len, _ = hidden_states.size()
-
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
-
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
-        cos, sin = self.rotary_emb(value_states, position_ids)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
-
-        # In case static cache is used, it is an instance attribute.
-        past_key_value = getattr(self, "past_key_value", past_key_value)
-
-        if past_key_value is not None:
-            # sin and cos are specific to RoPE models; cache_position needed for the static cache
-            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
-            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
-
-        key_states = repeat_kv(key_states, self.num_key_value_groups)
-        value_states = repeat_kv(value_states, self.num_key_value_groups)
-
-        causal_mask = attention_mask
-        if attention_mask is not None and cache_position is not None:
-            causal_mask = causal_mask[:, :, cache_position, : key_states.shape[-2]]
-
-        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
-        # Reference: https://github.com/pytorch/pytorch/issues/112577.
-        if query_states.device.type == "cuda" and causal_mask is not None:
-            query_states = query_states.contiguous()
-            key_states = key_states.contiguous()
-            value_states = value_states.contiguous()
-
-        attn_output = torch.nn.functional.scaled_dot_product_attention(
-            query_states,
-            key_states,
-            value_states,
-            attn_mask=causal_mask,
-            dropout_p=self.attention_dropout if self.training else 0.0,
-        )
-
-        attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
-
-        attn_output = self.o_proj(attn_output)
-
-        return attn_output, None, past_key_value
 
 
 LLAMOE_ATTENTION_CLASSES = {
     "eager": LlamoeAttention,
     "flash_attention_2": LlamoeFlashAttention2,
-    "sdpa": LlamoeSdpaAttention,
 }
 
 
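For context, the removed LlamoeSdpaAttention class was a thin wrapper around torch.nn.functional.scaled_dot_product_attention. Below is a minimal sketch of that core call, assuming PyTorch >= 2.0; the tensor shapes are made up purely for illustration and are not taken from the model config.

import torch
import torch.nn.functional as F

# Hypothetical shapes for illustration only: (batch, num_heads, seq_len, head_dim)
bsz, num_heads, q_len, head_dim = 1, 8, 16, 64
query = torch.randn(bsz, num_heads, q_len, head_dim)
key = torch.randn(bsz, num_heads, q_len, head_dim)
value = torch.randn(bsz, num_heads, q_len, head_dim)

# Boolean lower-triangular mask, analogous to the causal_mask sliced in the removed forward pass.
causal_mask = torch.ones(q_len, q_len).tril().bool().view(1, 1, q_len, q_len)

# The call the removed class delegated to; there, dropout was applied only in training mode.
attn_output = F.scaled_dot_product_attention(
    query,
    key,
    value,
    attn_mask=causal_mask,
    dropout_p=0.0,
)

# Reshape back to (batch, seq_len, hidden_size), as the removed code did before o_proj.
attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, num_heads * head_dim)
print(attn_output.shape)  # torch.Size([1, 16, 512])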
 
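With "sdpa" no longer registered in LLAMOE_ATTENTION_CLASSES after this commit, loading a checkpoint that uses this modeling file would presumably have to request one of the remaining implementations. A hedged sketch follows; the repo id is a placeholder, not the actual model path.

from transformers import AutoModelForCausalLM

# "your-org/llamoe-checkpoint" is a hypothetical placeholder, not the real repo id.
model = AutoModelForCausalLM.from_pretrained(
    "your-org/llamoe-checkpoint",
    attn_implementation="eager",  # "flash_attention_2" also remains registered per the diff above
    trust_remote_code=True,       # needed so the custom modeling_Llamoe.py code is used
)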