# Copyright 2024 **AUTHORS_TODO**
# License: Apache-2.0

# Copyright 2022 MosaicML Examples authors
# SPDX-License-Identifier: Apache-2.0

# Copyright 2023 MosaicML Examples authors
# SPDX-License-Identifier: Apache-2.0

# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018-2021, NVIDIA CORPORATION.  All rights reserved.
# Copyright (c) 2023, Tri Dao.

import importlib.metadata
import logging
import math
import warnings
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from .bert_padding import pad_input, unpad_input_only, index_first_axis
from .configuration_bert import FlexBertConfig, maybe_add_padding
from .normalization import get_norm_layer
from .initialization import ModuleType, init_weights

IMPL_USE_FLASH3 = False
IMPL_USE_FLASH2 = False


def _version_tuple(version):
    """Return the leading numeric components of a version string as a tuple of ints.

    Parsing stops at the first component that does not start with a digit, so
    ``"2.5.7.post1"`` -> ``(2, 5, 7)`` and ``"2.10.0+cu121"`` -> ``(2, 10, 0)``.
    Comparing these tuples orders releases numerically; comparing the raw strings
    does not (``"2.10.0" < "2.5.7"`` is True for strings).
    """
    components = []
    for piece in version.split("."):
        digits = ""
        for char in piece:
            if not char.isdigit():
                break
            digits += char
        if not digits:
            break
        components.append(int(digits))
    return tuple(components)


try:
    from flash_attn_interface import flash_attn_varlen_func

    IMPL_USE_FLASH3 = True
except ImportError:
    pass

# Import Flash Attention 2, which supports ALiBi https://github.com/Dao-AILab/flash-attention
try:
    from flash_attn import flash_attn_varlen_qkvpacked_func, flash_attn_qkvpacked_func  # type: ignore

    installed_version = importlib.metadata.version("flash_attn")  # type: ignore
    if _version_tuple(installed_version) < (2, 5, 7):
        raise ImportError("newer version of flash_attn required (>= 2.5.7)")
    IMPL_USE_FLASH2 = True
except ImportError:
    pass

try:
    from flash_attn.layers.rotary import RotaryEmbedding  # type: ignore
    from .rotary import UnpaddedRotaryEmbedding  # type: ignore
except ImportError:
    RotaryEmbedding = None
    UnpaddedRotaryEmbedding = None

logger = logging.getLogger(__name__)

_FA2_UNAVAILABLE_MSG = (
    "Unable to import flash_attn; defaulting FlexBERT attention implementation to PyTorch's"
    " SDPA kernel. This requires padding and unpadding inputs, which will add some overhead."
)
_SDPA_NO_MASK_MSG = (
    "SDPA attention is being used without an attention mask. Including padding in the"
    " attention calculation may cause differences from the Flash Attention implementation."
)
_SDPA_WITH_MASK_MSG = (
    "SDPA attention with an attention mask doesn't use the Flash Attention kernel and will"
    " use more memory during the backward pass. Use the FA2 backend for linear memory scaling"
    " with sequence length."
)
_FA2_IGNORES_MASK_MSG = (
    "Flash Attention 2 does not support attention masks. Use unpadded attention for"
    " the equivalent functionality of masking out padding tokens."
)
_SDPA_SLIDING_WINDOW_MSG = "Sliding window is not implemented for the PyTorch SDPA path. Use the FA2 backend."
_CAUSAL_UNSUPPORTED_MSG = "causal mask not implemented here yet"

# Messages already emitted by `_warn_once` when the logger has no `warn_once` of its own.
_EMITTED_WARNINGS = set()


def _warn_once(message):
    """Log `message` as a warning at most once per process.

    A standard `logging.Logger` has no `warn_once` method. If one has been attached to
    the logger elsewhere it is used as-is; otherwise fall back to `logger.warning`
    guarded by a module-level set of messages that were already emitted.
    """
    patched = getattr(logger, "warn_once", None)
    if patched is not None:
        patched(message)
        return
    if message not in _EMITTED_WARNINGS:
        _EMITTED_WARNINGS.add(message)
        logger.warning(message)


def _require_divisible_heads(config):
    """Raise ValueError unless hidden_size splits evenly across the attention heads.

    The check is skipped when the config defines `embedding_size`.
    """
    if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
        raise ValueError(
            f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
            f"heads ({config.num_attention_heads})"
        )


def _resolve_sliding_window(config, layer_id):
    """Return the (left, right) attention window for the layer with index `layer_id`.

    With `global_attn_every_n_layers > 0`, every n-th layer (layer_id % n == 0) gets
    global attention `(-1, -1)` and all other layers get half of `config.sliding_window`
    on each side. Otherwise every layer uses the half window (which is `(-1, -1)` when
    `config.sliding_window == -1`).

    Raises:
        ValueError: if global/local alternation is requested without a sliding window
            or without a `layer_id`.
    """
    half_window = (config.sliding_window // 2, config.sliding_window // 2)
    if config.global_attn_every_n_layers <= 0:
        return half_window
    if config.sliding_window == -1:
        raise ValueError("`global_attn_every_n_layers` requires `sliding_window` to be set")
    if layer_id is None:
        raise ValueError("`global_attn_every_n_layers` requires `layer_id` to be set")
    if layer_id % config.global_attn_every_n_layers != 0:
        return half_window
    return (-1, -1)


def _select_unpadded_backend(use_fa2, use_sdpa_attn_mask, sliding_window):
    """Decide whether an unpadded attention layer can use Flash Attention 2.

    Returns the effective `use_fa2` flag: the requested value, forced to False when
    flash_attn could not be imported. Emits the matching one-time warnings for the
    SDPA fallback.

    Raises:
        ValueError: if the SDPA path is selected together with a sliding window.
    """
    if use_fa2 and not IMPL_USE_FLASH2:
        # Warn if defaulting to pytorch because of import issues
        _warn_once(_FA2_UNAVAILABLE_MSG)
        use_fa2 = False
    if not use_fa2:
        _warn_once(_SDPA_WITH_MASK_MSG if use_sdpa_attn_mask else _SDPA_NO_MASK_MSG)
        if sliding_window[0] > 0:
            raise ValueError(_SDPA_SLIDING_WINDOW_MSG)
    return use_fa2


def _flash_in_half_precision(attn_fn, qkv, **kwargs):
    """Call a packed-QKV flash attention function, casting to bf16 and back if needed.

    FA2 implementation only supports fp16 and bf16. If FA2 is supported, bfloat16 must
    be supported as of FA2 2.5.7. (Turing GPUs not supported)
    """
    orig_dtype = qkv.dtype
    if orig_dtype in (torch.float16, torch.bfloat16):
        return attn_fn(qkv, **kwargs)
    return attn_fn(qkv.to(torch.bfloat16), **kwargs).to(orig_dtype)  # type: ignore


def _sdpa_from_packed(qkv, attn_mask, dropout_p, use_attn_mask, dim):
    """Run PyTorch SDPA on padded, packed QKV and return (batch, seqlen, dim).

    Args:
        qkv: (batch, seqlen, 3, nheads, headdim)
        attn_mask: (batch, >= seqlen) padding mask, or None
        dropout_p: attention dropout probability to apply
        use_attn_mask: whether to mask out padding keys
        dim: size of the merged (nheads * headdim) output axis
    """
    batch, seqlen = qkv.shape[:2]
    q, k, v = qkv.transpose(3, 1).unbind(dim=2)  # each: b h s d
    mask = None
    if use_attn_mask and attn_mask is not None:
        mask = attn_mask[:, None, None, :seqlen].to(torch.bool).expand(batch, 1, seqlen, seqlen)
    attn = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p, attn_mask=mask)
    # reshape (not view): the transposed output is not guaranteed to be contiguous.
    return attn.transpose(1, 2).reshape(batch, -1, dim)  # b s (h d)


class BertAlibiUnpadSelfAttention(nn.Module):
    """Performs multi-headed self attention on a batch of unpadded sequences.

    If Flash Attention 2 is installed, this module uses Flash Attention to greatly improve throughput.
    The Flash Attention implementation used in MosaicBERT supports arbitrary attention biases (which
    we use to implement ALiBi). If Flash Attention 2 is not installed the implementation will default
    to a math-equivalent pytorch version, which is much slower.

    See `forward` method for additional details.
    """

    def __init__(self, config):
        super().__init__()
        _require_divisible_heads(config)

        self.is_causal = config.causal_mask
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.p_dropout = config.attention_probs_dropout_prob
        self.Wqkv = nn.Linear(self.all_head_size, 3 * config.hidden_size)
        self.deterministic_fa2 = getattr(config, "deterministic_fa2", False)

        # Warn if defaulting to pytorch because of import issues
        if not IMPL_USE_FLASH2:
            warnings.warn(
                "Unable to import flash_attn; defaulting MosaicBERT attention implementation to "
                "vanilla PyTorch (this will reduce throughput when using this model)."
            )

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        max_seqlen: int,
        indices: torch.Tensor,
        attn_mask: torch.Tensor,
        bias: torch.Tensor,
        slopes: torch.Tensor,
    ) -> torch.Tensor:
        """Perform self-attention.

        There are two attention implementations: vanilla attention with ALiBi, and Flash Attention 2 with ALiBi

        The arguments are unpadded. The vanilla implementation of attention requires padded arguments while the
        Flash Attention implementation does not. If using vanilla we first call `pad_input`. Once we compute
        attention, we re-unpad our outputs for the other layers. The pad/unpad operations add overhead, but not
        sending pad tokens through ffs saves compute.

        Args:
            hidden_states: (total_nnz, dim)
            cu_seqlens: (batch + 1,)
            max_seqlen: int
            indices: (total_nnz,)
            attn_mask: (batch, max_seqlen)
            bias: (batch, heads, max_seqlen, max_seqlen)
            slopes: (heads) or (batch, heads)

        Returns:
            attention: (total_nnz, dim)

        Raises:
            NotImplementedError: if the PyTorch fallback is used with a causal mask.
        """
        bs, dim = hidden_states.shape
        qkv = self.Wqkv(hidden_states)

        if IMPL_USE_FLASH2:
            # Option 1: Flash Attention with ALiBi
            qkv = qkv.view(-1, 3, self.num_attention_heads, self.attention_head_size)
            assert 1 <= len(slopes.shape) <= 2, f"{slopes=}"
            assert slopes.shape[-1] == self.num_attention_heads, f"{slopes=}"
            attention = _flash_in_half_precision(
                flash_attn_varlen_qkvpacked_func,
                qkv,
                cu_seqlens=cu_seqlens,
                max_seqlen=max_seqlen,
                # The kernel applies dropout whenever dropout_p > 0, so disable it outside training.
                dropout_p=self.p_dropout if self.training else 0.0,
                deterministic=self.deterministic_fa2,
                alibi_slopes=slopes,
                causal=self.is_causal,
            )
        else:
            # Option 2: math-equivalent attention in plain PyTorch on padded inputs
            if self.is_causal:
                raise NotImplementedError(_CAUSAL_UNSUPPORTED_MSG)
            qkv = pad_input(qkv, indices, cu_seqlens.shape[0] - 1, max_seqlen)  # batch, max_seqlen, thd
            unpad_bs, *_ = qkv.shape
            qkv = qkv.view(unpad_bs, -1, 3, self.num_attention_heads, self.attention_head_size)
            q = qkv[:, :, 0, :, :].permute(0, 2, 1, 3)  # b h s d
            k = qkv[:, :, 1, :, :].permute(0, 2, 3, 1)  # b h d s
            v = qkv[:, :, 2, :, :].permute(0, 2, 1, 3)  # b h s d
            attention_scores = torch.matmul(q, k) / math.sqrt(self.attention_head_size)
            attention_scores = attention_scores + bias
            attention_probs = nn.functional.softmax(attention_scores, dim=-1)
            attention_probs = self.dropout(attention_probs)
            attention = torch.matmul(attention_probs, v).permute(0, 2, 1, 3)  # b s h d
            attention = unpad_input_only(attention, torch.squeeze(attn_mask) == 1)

        # Merge the head axes: (total_nnz, heads, head_dim) -> (total_nnz, dim)
        return attention.reshape(bs, dim)


# Copy of transformer's library BertSelfOutput that will not be caught by surgery methods looking for HF BERT modules.
class BertSelfOutput(nn.Module):
    """Computes the output of the attention layer.

    This module is modeled after the Hugging Face BERT's
    :class:`~transformers.model.bert.modeling_bert.BertSelfOutput`.
    The implementation is identical. Rather than use the original module
    directly, we re-implement it here so that Mosaic BERT's modules will not
    be affected by any Composer surgery algorithm that modifies Hugging Face
    BERT modules.
    """

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = get_norm_layer(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """Project, apply dropout, add the residual `input_tensor`, and normalize."""
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class BertAlibiUnpadAttention(nn.Module):
    """Chains attention, Dropout, and LayerNorm for Mosaic BERT."""

    def __init__(self, config):
        super().__init__()
        self.self = BertAlibiUnpadSelfAttention(config)
        self.output = BertSelfOutput(config)

    def forward(
        self,
        input_tensor: torch.Tensor,
        cu_seqlens: torch.Tensor,
        max_s: int,
        subset_idx: Optional[torch.Tensor] = None,
        indices: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
        bias: Optional[torch.Tensor] = None,
        slopes: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass for scaled self-attention without padding.

        Arguments:
            input_tensor: (total_nnz, dim)
            cu_seqlens: (batch + 1,)
            max_s: int
            subset_idx: () set of indices whose values we care about at the end of the layer
                        (e.g., the masked tokens, if this is the final layer).
            indices: None or (total_nnz,)
            attn_mask: None or (batch, max_seqlen)
            bias: None or (batch, heads, max_seqlen, max_seqlen)
            slopes: None or (batch, heads) or (heads,)
        """
        # bias and slopes describe the same ALiBi term; they must be given (or omitted) together.
        assert (bias is None) == (slopes is None), f"{bias=}, {slopes=}"
        self_output = self.self(input_tensor, cu_seqlens, max_s, indices, attn_mask, bias, slopes)
        if subset_idx is None:
            return self.output(self_output, input_tensor)
        # Only keep the requested rows of both the attention output and the residual.
        return self.output(
            index_first_axis(self_output, subset_idx),
            index_first_axis(input_tensor, subset_idx),
        )


class FlexBertAttentionBase(nn.Module):
    """A FlexBERT attention base class for type hints."""

    def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_id = layer_id

    def _init_weights(self, reset_params: bool = False):
        raise NotImplementedError("This is a base class and should not be used directly.")

    def forward(self, hidden_states: torch.Tensor, attn_mask: torch.Tensor, **kwargs) -> torch.Tensor:
        raise NotImplementedError("This is a base class and should not be used directly.")

    def extra_repr(self) -> str:
        """Summarize whichever of the common attention attributes the subclass defines."""
        parts = []
        if hasattr(self, "num_attention_heads"):
            parts.append(f"num_attention_heads={self.num_attention_heads}")
        if hasattr(self, "attn_head_size"):
            parts.append(f"attn_head_size={self.attn_head_size}")
        if hasattr(self, "sliding_window"):
            window = self.sliding_window if self.sliding_window != (-1, -1) else "False"
            parts.append(f"sliding_window={window}")
        if hasattr(self, "use_fa2"):
            parts.append(f"use_fa2={self.use_fa2}")
        if hasattr(self, "deterministic_fa2"):
            parts.append(f"deterministic_fa2={self.deterministic_fa2}")
        return ", ".join(parts)


class FlexBertUnpadAttention(FlexBertAttentionBase):
    """Performs multi-headed self attention on a batch of unpadded sequences.

    If Flash Attention 2 is installed, this module uses Flash Attention to improve throughput.
    If Flash Attention 2 is not installed, the implementation will use PyTorch's SDPA kernel,
    which requires padding and unpadding inputs, adding some overhead.

    See `forward` method for additional detail.
    """

    def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
        super().__init__(config=config, layer_id=layer_id)
        _require_divisible_heads(config)

        self.is_causal = config.causal_mask
        self.num_attention_heads = config.num_attention_heads
        self.attn_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attn_head_size
        self.p_dropout = config.attention_probs_dropout_prob
        self.Wqkv = nn.Linear(config.hidden_size, 3 * self.all_head_size, bias=config.attn_qkv_bias)
        self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attn_out_bias)
        self.out_drop = (
            nn.Dropout(config.attn_out_dropout_prob) if config.attn_out_dropout_prob > 0.0 else nn.Identity()
        )
        self.deterministic_fa2 = config.deterministic_fa2
        self.use_sdpa_attn_mask = config.use_sdpa_attn_mask
        self.sliding_window = _resolve_sliding_window(config, layer_id)
        self.use_fa2 = _select_unpadded_backend(config.use_fa2, self.use_sdpa_attn_mask, self.sliding_window)

    def _init_weights(self, reset_params: bool = False):
        init_weights(
            self.config,
            self.Wqkv,
            layer_dim=self.config.hidden_size,
            layer_id=None,
            type_of_module=ModuleType.in_module,
        )
        init_weights(
            self.config,
            self.Wo,
            layer_dim=self.config.hidden_size,
            layer_id=self.layer_id,
            type_of_module=ModuleType.out_module,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        max_seqlen: int,
        indices: torch.Tensor,
        attn_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Perform self-attention.

        There are two attention implementations supported: PyTorch's SDPA attention and Flash Attention 2.

        The arguments are unpadded. The SDPA implementation of attention requires padded arguments while the
        Flash Attention implementation does not. If using SDPA we first call `pad_input`. Once we compute
        attention, we re-unpad our outputs for the other layers. The pad/unpad operations add overhead, but not
        sending pad tokens through ffs saves compute.

        Args:
            hidden_states: (total_nnz, dim)
            cu_seqlens: (batch + 1,)
            max_seqlen: int
            indices: (total_nnz,)
            attn_mask: (batch, max_seqlen)

        Returns:
            attention: (total_nnz, dim)

        Raises:
            NotImplementedError: if the SDPA path is used with a causal mask.
        """
        bs, dim = hidden_states.shape
        qkv = self.Wqkv(hidden_states)
        # Both kernels apply dropout whenever dropout_p > 0, so disable it outside training.
        dropout_p = self.p_dropout if self.training else 0.0

        if self.use_fa2:
            qkv = qkv.view(-1, 3, self.num_attention_heads, self.attn_head_size)
            attn = _flash_in_half_precision(
                flash_attn_varlen_qkvpacked_func,
                qkv,
                cu_seqlens=cu_seqlens,
                max_seqlen=max_seqlen,
                dropout_p=dropout_p,
                deterministic=self.deterministic_fa2,
                window_size=self.sliding_window,
                causal=self.is_causal,
            )
            attn = attn.view(bs, dim)
        else:
            if self.is_causal:
                raise NotImplementedError(_CAUSAL_UNSUPPORTED_MSG)
            qkv = pad_input(qkv, indices, cu_seqlens.shape[0] - 1, max_seqlen)  # batch, max_seqlen, thd
            qkv = qkv.view(qkv.shape[0], -1, 3, self.num_attention_heads, self.attn_head_size)
            attn = _sdpa_from_packed(qkv, attn_mask, dropout_p, self.use_sdpa_attn_mask, dim)
            attn = unpad_input_only(attn, torch.squeeze(attn_mask) == 1)

        return self.out_drop(self.Wo(attn))


class FlexBertUnpadParallelAttention(FlexBertAttentionBase):
    """Computes the output of the multi-headed self parallel attention on a batch of unpadded sequences

    Unlike `FlexBertUnpadAttention`, this module owns no QKV projection: `forward` receives
    the already-projected `qkv` tensor.

    If Flash Attention 2 is installed, this module uses Flash Attention to improve throughput.
    If Flash Attention 2 is not installed, the implementation will use PyTorch's SDPA kernel,
    which requires padding and unpadding inputs, adding some overhead.

    See `forward` method for additional detail.
    """

    def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
        super().__init__(config=config, layer_id=layer_id)
        _require_divisible_heads(config)

        self.is_causal = config.causal_mask
        self.num_attention_heads = config.num_attention_heads
        self.attn_head_size = int(config.hidden_size / config.num_attention_heads)
        self.hidden_size = config.hidden_size
        self.p_dropout = config.attention_probs_dropout_prob
        self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attn_out_bias)
        self.out_drop = (
            nn.Dropout(config.attn_out_dropout_prob) if config.attn_out_dropout_prob > 0.0 else nn.Identity()
        )
        self.deterministic_fa2 = config.deterministic_fa2
        self.use_sdpa_attn_mask = config.use_sdpa_attn_mask
        self.sliding_window = _resolve_sliding_window(config, layer_id)
        self.use_fa2 = _select_unpadded_backend(config.use_fa2, self.use_sdpa_attn_mask, self.sliding_window)

    def _init_weights(self, reset_params: bool = False):
        init_weights(
            self.config,
            self.Wo,
            layer_dim=self.config.hidden_size,
            layer_id=self.layer_id,
            type_of_module=ModuleType.out_module,
        )

    def forward(
        self,
        qkv: torch.Tensor,
        cu_seqlens: torch.Tensor,
        max_seqlen: int,
        indices: torch.Tensor,
        attn_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Perform self-attention.

        There are two attention implementations supported: PyTorch's SDPA attention and Flash Attention 2.

        The arguments are unpadded. The SDPA implementation of attention requires padded arguments while the
        Flash Attention implementation does not. If using SDPA we first call `pad_input`. Once we compute
        attention, we re-unpad our outputs for the other layers. The pad/unpad operations add overhead, but not
        sending pad tokens through ffs saves compute.

        Args:
            qkv: (total_nnz, 3 * dim)
            cu_seqlens: (batch + 1,)
            max_seqlen: int
            indices: (total_nnz,)
            attn_mask: (batch, max_seqlen)

        Returns:
            attention: (total_nnz, dim)

        Raises:
            NotImplementedError: if the SDPA path is used with a causal mask.
        """
        bs = qkv.shape[0]
        dim = self.hidden_size
        # Both kernels apply dropout whenever dropout_p > 0, so disable it outside training.
        dropout_p = self.p_dropout if self.training else 0.0

        if self.use_fa2:
            qkv = qkv.view(-1, 3, self.num_attention_heads, self.attn_head_size)
            attn = _flash_in_half_precision(
                flash_attn_varlen_qkvpacked_func,
                qkv,
                cu_seqlens=cu_seqlens,
                max_seqlen=max_seqlen,
                dropout_p=dropout_p,
                deterministic=self.deterministic_fa2,
                window_size=self.sliding_window,
                causal=self.is_causal,
            )
        else:
            if self.is_causal:
                raise NotImplementedError(_CAUSAL_UNSUPPORTED_MSG)
            qkv = pad_input(qkv, indices, cu_seqlens.shape[0] - 1, max_seqlen)  # batch, max_seqlen, thd
            qkv = qkv.view(qkv.shape[0], -1, 3, self.num_attention_heads, self.attn_head_size)
            attn = _sdpa_from_packed(qkv, attn_mask, dropout_p, self.use_sdpa_attn_mask, dim)
            attn = unpad_input_only(attn, torch.squeeze(attn_mask) == 1)

        return self.out_drop(self.Wo(attn.reshape(bs, dim)))


class FlexBertPaddedAttention(FlexBertAttentionBase):
    """Performs multi-headed self attention on a batch of padded sequences.

    This module supports two attention implementations:
    1. Flash Attention 2 (if installed), which improves throughput.
    2. PyTorch's scaled_dot_product_attention.

    See `forward` method for additional detail.
    """

    def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
        super().__init__(config=config, layer_id=layer_id)
        _require_divisible_heads(config)

        self.is_causal = config.causal_mask
        self.num_attention_heads = config.num_attention_heads
        self.attn_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attn_head_size
        self.p_dropout = config.attention_probs_dropout_prob
        self.Wqkv = nn.Linear(config.hidden_size, 3 * self.all_head_size, bias=config.attn_qkv_bias)
        self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attn_out_bias)
        self.out_drop = (
            nn.Dropout(config.attn_out_dropout_prob) if config.attn_out_dropout_prob > 0.0 else nn.Identity()
        )
        self.use_fa2 = config.use_fa2
        self.deterministic_fa2 = config.deterministic_fa2
        self.use_sdpa_attn_mask = config.use_sdpa_attn_mask
        self.sliding_window = _resolve_sliding_window(config, layer_id)

        # Silently fall back to SDPA when flash_attn could not be imported.
        if not IMPL_USE_FLASH2 and self.use_fa2:
            self.use_fa2 = False
        if self.use_fa2 and self.use_sdpa_attn_mask:
            _warn_once(_FA2_IGNORES_MASK_MSG)
        if not self.use_fa2 and self.sliding_window[0] > 0:
            raise ValueError(_SDPA_SLIDING_WINDOW_MSG)

    def _init_weights(self, reset_params: bool = False):
        init_weights(
            self.config,
            self.Wqkv,
            layer_dim=self.config.hidden_size,
            layer_id=None,
            type_of_module=ModuleType.in_module,
        )
        init_weights(
            self.config,
            self.Wo,
            layer_dim=self.config.hidden_size,
            layer_id=self.layer_id,
            type_of_module=ModuleType.out_module,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Perform self-attention.

        There are two attention implementations supported:
        Flash Attention 2 and PyTorch's scaled_dot_product_attention.

        Args:
            hidden_states: (batch, seqlen, dim)
            attn_mask: (batch, seqlen); only used on the SDPA path when `use_sdpa_attn_mask`
                is set. When None, no padding mask is applied.

        Returns:
            attention: (batch, seqlen, dim)

        Raises:
            NotImplementedError: if the SDPA path is used with a causal mask.
        """
        bs, seqlen, dim = hidden_states.shape
        qkv = self.Wqkv(hidden_states)
        qkv = qkv.view(bs, seqlen, 3, self.num_attention_heads, self.attn_head_size)
        # Both kernels apply dropout whenever dropout_p > 0, so disable it outside training.
        dropout_p = self.p_dropout if self.training else 0.0

        if self.use_fa2:
            attn = _flash_in_half_precision(
                flash_attn_qkvpacked_func,
                qkv,
                dropout_p=dropout_p,
                deterministic=self.deterministic_fa2,
                window_size=self.sliding_window,
                causal=self.is_causal,
            )
            attn = attn.view(bs, seqlen, dim)
        else:
            if self.is_causal:
                raise NotImplementedError(_CAUSAL_UNSUPPORTED_MSG)
            attn = _sdpa_from_packed(qkv, attn_mask, dropout_p, self.use_sdpa_attn_mask, dim)

        return self.out_drop(self.Wo(attn))


class FlexBertUnpadRopeAttention(FlexBertAttentionBase):
    """Performs multi-headed self attention on a batch of unpadded sequences.

    If Flash Attention 2 is installed, this module uses Flash Attention to improve throughput.
    If Flash Attention 2 is not installed, the implementation will use PyTorch's SDPA kernel,
    which requires padding and unpadding inputs, adding some overhead.

    See `forward` method for additional details.
    """

    def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
        super().__init__(config=config, layer_id=layer_id)
        _require_divisible_heads(config)

        self.is_causal = config.causal_mask
        self.num_attention_heads = config.num_attention_heads
        self.attn_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attn_head_size
        self.p_dropout = config.attention_probs_dropout_prob
        self.Wqkv = nn.Linear(config.hidden_size, 3 * self.all_head_size, bias=config.attn_qkv_bias)
        self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attn_out_bias)
        self.out_drop = (
            nn.Dropout(config.attn_out_dropout_prob) if config.attn_out_dropout_prob > 0.0 else nn.Identity()
        )
        self.sliding_window = _resolve_sliding_window(config, layer_id)

        # NOTE: this writes the default rotary dimension back onto the shared config object.
        if config.rotary_emb_dim is None:
            config.rotary_emb_dim = self.attn_head_size

        rotary_base = config.rotary_emb_base
        rotary_dim = config.rotary_emb_dim
        if self.sliding_window != (-1, -1):
            # Local-attention layers may override the rotary base and dimension.
            if config.local_attn_rotary_emb_base != -1:
                rotary_base = config.local_attn_rotary_emb_base
            if config.local_attn_rotary_emb_dim is not None:
                rotary_dim = config.local_attn_rotary_emb_dim

        if UnpaddedRotaryEmbedding is None:
            raise ImportError("rotary_emb is not installed")
        self.rotary_emb = UnpaddedRotaryEmbedding(
            dim=rotary_dim,
            base=rotary_base,
            scale_base=config.rotary_emb_scale_base,  # If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
            interleaved=config.rotary_emb_interleaved,
        )

        # flash attention 3 only supports global attention
        self.use_fa3 = config.use_fa2 and self.sliding_window == (-1, -1) and IMPL_USE_FLASH3
        self.deterministic_fa2 = config.deterministic_fa2
        self.use_sdpa_attn_mask = config.use_sdpa_attn_mask
        self.use_fa2 = _select_unpadded_backend(config.use_fa2, self.use_sdpa_attn_mask, self.sliding_window)

    def _init_weights(self, reset_params: bool = False):
        init_weights(
            self.config,
            self.Wqkv,
            layer_dim=self.config.hidden_size,
            layer_id=None,
            type_of_module=ModuleType.in_module,
        )
        init_weights(
            self.config,
            self.Wo,
            layer_dim=self.config.hidden_size,
            layer_id=self.layer_id,
            type_of_module=ModuleType.out_module,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        max_seqlen: int,
        indices: torch.Tensor,
        attn_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Perform self-attention.

        There are two attention implementations supported: PyTorch's SDPA attention and Flash Attention 2.
        When Flash Attention 3 is importable it is used instead for global-attention layers.

        The arguments are unpadded. The SDPA implementation of attention requires padded arguments while the
        Flash Attention implementation does not. If using SDPA we first call `pad_input`. Once we compute
        attention, we re-unpad our outputs for the other layers. The pad/unpad operations add overhead, but not
        sending pad tokens through ffs saves compute.

        Args:
            hidden_states: (total_nnz, dim)
            cu_seqlens: (batch + 1,)
            max_seqlen: int
            indices: (total_nnz,)
            attn_mask: (batch, max_seqlen)

        Returns:
            attention: (total_nnz, dim)

        Raises:
            NotImplementedError: if the SDPA path is used with a causal mask.
        """
        bs, dim = hidden_states.shape
        qkv = self.Wqkv(hidden_states)
        # Both FA2 and SDPA apply dropout whenever dropout_p > 0, so disable it outside training.
        dropout_p = self.p_dropout if self.training else 0.0

        # only needed for inference when we have KV cache
        seqlen_offset = 0

        # (total_seqlen, 3, nheads, headdim)
        qkv = qkv.view(-1, 3, self.num_attention_heads, self.attn_head_size)
        qkv = self.rotary_emb(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, seqlen_offset=seqlen_offset)

        if self.use_fa3:
            # Same fp16/bf16-only restriction as FA2: cast up front and restore afterwards.
            orig_dtype = qkv.dtype
            if orig_dtype not in (torch.float16, torch.bfloat16):
                qkv = qkv.to(torch.bfloat16)
            q, k, v = qkv.view(-1, 3, self.num_attention_heads, self.attn_head_size).unbind(dim=1)
            attn, _ = flash_attn_varlen_func(
                q=q,
                k=k,
                v=v,
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_k=cu_seqlens,
                max_seqlen_q=max_seqlen,
                max_seqlen_k=max_seqlen,
                deterministic=self.deterministic_fa2,
                causal=self.is_causal,
            )
            attn = attn.to(orig_dtype)  # type: ignore
            attn = attn.view(bs, dim)
        elif self.use_fa2:
            attn = _flash_in_half_precision(
                flash_attn_varlen_qkvpacked_func,
                qkv,
                cu_seqlens=cu_seqlens,
                max_seqlen=max_seqlen,
                dropout_p=dropout_p,
                deterministic=self.deterministic_fa2,
                window_size=self.sliding_window,
                causal=self.is_causal,
            )
            attn = attn.view(bs, dim)
        else:
            if self.is_causal:
                raise NotImplementedError(_CAUSAL_UNSUPPORTED_MSG)
            qkv = pad_input(
                qkv, indices, cu_seqlens.shape[0] - 1, attn_mask.shape[-1]
            )  # batch, max_seqlen, 3, nheads, headdim
            attn = _sdpa_from_packed(qkv, attn_mask, dropout_p, self.use_sdpa_attn_mask, dim)
            attn = unpad_input_only(attn, torch.squeeze(attn_mask) == 1)

        return self.out_drop(self.Wo(attn))


class FlexBertPaddedRopeAttention(FlexBertAttentionBase):
    """Performs multi-headed self attention on a batch of padded sequences.

    This module supports two attention implementations:
    1. Flash Attention 2 (if installed), which improves throughput.
    2. PyTorch's scaled_dot_product_attention.

    See `forward` method for additional details.
    """
""" def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None): super().__init__(config=config, layer_id=layer_id) if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " f"heads ({config.num_attention_heads})" ) self.is_causal = config.causal_mask self.num_attention_heads = config.num_attention_heads self.attn_head_size = int(config.hidden_size / config.num_attention_heads) self.all_head_size = self.num_attention_heads * self.attn_head_size self.p_dropout = config.attention_probs_dropout_prob self.Wqkv = nn.Linear(config.hidden_size, 3 * self.all_head_size, bias=config.attn_qkv_bias) self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attn_out_bias) self.out_drop = ( nn.Dropout(config.attn_out_dropout_prob) if config.attn_out_dropout_prob > 0.0 else nn.Identity() ) self.use_fa2 = config.use_fa2 self.deterministic_fa2 = config.deterministic_fa2 self.use_sdpa_attn_mask = config.use_sdpa_attn_mask if config.global_attn_every_n_layers > 0: if config.sliding_window == -1: raise ValueError("global_attn_every_n_layers` requires `sliding_window` to be set") if layer_id % config.global_attn_every_n_layers != 0: self.sliding_window = (config.sliding_window // 2, config.sliding_window // 2) else: self.sliding_window = (-1, -1) else: self.sliding_window = (config.sliding_window // 2, config.sliding_window // 2) if config.rotary_emb_dim is None: config.rotary_emb_dim = self.attn_head_size rotary_base = config.rotary_emb_base rotary_dim = config.rotary_emb_dim if self.sliding_window != (-1, -1): if config.local_attn_rotary_emb_base != -1: rotary_base = config.local_attn_rotary_emb_base if config.local_attn_rotary_emb_dim is not None: rotary_dim = config.local_attn_rotary_emb_dim assert RotaryEmbedding is not None, "rotary_emb is not installed" self.rotary_emb = RotaryEmbedding( dim=rotary_dim, 
base=rotary_base, scale_base=config.rotary_emb_scale_base, # If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554). interleaved=config.rotary_emb_interleaved, ) if not IMPL_USE_FLASH2 and self.use_fa2: self.use_fa2 = False if self.use_fa2 and self.use_sdpa_attn_mask: logger.warn_once( "Flash Attention 2 does not support attention masks. Use unpadded attention " "the equivalent functionality of masking out padding tokens." ) if not self.use_fa2 and self.sliding_window[0] > 0: raise ValueError("Sliding window is not implemented for the PyTorch SDPA path. Use the FA2 backend.") def _init_weights(self, reset_params: bool = False): init_weights( self.config, self.Wqkv, layer_dim=self.config.hidden_size, layer_id=None, type_of_module=ModuleType.in_module, ) init_weights( self.config, self.Wo, layer_dim=self.config.hidden_size, layer_id=self.layer_id, type_of_module=ModuleType.out_module, ) def forward( self, hidden_states: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Perform self-attention. There are two attention implementations supported: Flash Attention 2 and PyTorch's scaled_dot_product_attention. Args: hidden_states: (batch, seqlen, dim) attn_mask: (batch, seqlen) Returns: attention: (batch, seqlen, dim) """ bs, seqlen, dim = hidden_states.shape qkv = self.Wqkv(hidden_states) seqlen_offset = 0 # Reshape to (batch, seqlen, 3, nheads, headdim) qkv = qkv.view(bs, seqlen, 3, self.num_attention_heads, self.attn_head_size) if IMPL_USE_FLASH2: # Apply RoPE qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset, max_seqlen=None) convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16) if convert_dtype: # FA2 implementation only supports fp16 and bf16. If FA2 is supported, # bfloat16 must be supported as of FA2 2.5.7. 
(Turing GPUs not supported) orig_dtype = qkv.dtype qkv = qkv.to(torch.bfloat16) attn = flash_attn_qkvpacked_func( qkv, dropout_p=self.p_dropout, deterministic=self.deterministic_fa2, window_size=self.sliding_window, causal=self.is_causal, ) attn = attn.to(orig_dtype) # type: ignore else: attn = flash_attn_qkvpacked_func( qkv, dropout_p=self.p_dropout, deterministic=self.deterministic_fa2, window_size=self.sliding_window, causal=self.is_causal ) else: assert not self.is_causal, f"causal mask not implemented here yet" assert False qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset, max_seqlen=None) q, k, v = qkv.transpose(3, 1).unbind(dim=2) attn = F.scaled_dot_product_attention( q, k, v, dropout_p=self.p_dropout, attn_mask=attn_mask[:, None, None, :seqlen].to(torch.bool).expand(bs, 1, seqlen, seqlen) if self.use_sdpa_attn_mask else None, ).transpose(1, 2) attn = attn.view(bs, seqlen, dim) return self.out_drop(self.Wo(attn)) class FlexBertUnpadRopeParallelAttention(FlexBertAttentionBase): """Performs multi-headed self attention on a batch of unpadded sequences. If Flash Attention 2 is installed, this module uses Flash Attention to improve throughput. If Flash Attention 2 is not installed, the implementation will use PyTorch's SDPA kernel, which requires padding and unpadding inputs, adding some overhead. See `forward` method for additional details. 
""" def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None): super().__init__(config=config, layer_id=layer_id) if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " f"heads ({config.num_attention_heads})" ) self.is_causal = config.causal_mask self.num_attention_heads = config.num_attention_heads self.attn_head_size = int(config.hidden_size / config.num_attention_heads) self.hidden_size = config.hidden_size self.p_dropout = config.attention_probs_dropout_prob self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attn_out_bias) self.out_drop = ( nn.Dropout(config.attn_out_dropout_prob) if config.attn_out_dropout_prob > 0.0 else nn.Identity() ) if config.global_attn_every_n_layers > 0: if config.sliding_window == -1: raise ValueError("global_attn_every_n_layers` requires `sliding_window` to be set") if layer_id % config.global_attn_every_n_layers != 0: self.sliding_window = (config.sliding_window // 2, config.sliding_window // 2) else: self.sliding_window = (-1, -1) else: self.sliding_window = (config.sliding_window // 2, config.sliding_window // 2) if config.rotary_emb_dim is None: config.rotary_emb_dim = self.attn_head_size rotary_base = config.rotary_emb_base rotary_dim = config.rotary_emb_dim if self.sliding_window != (-1, -1): if config.local_attn_rotary_emb_base != -1: rotary_base = config.local_attn_rotary_emb_base if config.local_attn_rotary_emb_dim is not None: rotary_dim = config.local_attn_rotary_emb_dim assert UnpaddedRotaryEmbedding is not None, "rotary_emb is not installed" self.rotary_emb = UnpaddedRotaryEmbedding( dim=rotary_dim, base=rotary_base, scale_base=config.rotary_emb_scale_base, # If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554). 
interleaved=config.rotary_emb_interleaved, ) self.use_fa2 = config.use_fa2 self.deterministic_fa2 = config.deterministic_fa2 self.use_sdpa_attn_mask = config.use_sdpa_attn_mask # Warn if defaulting to pytorch because of import issues if not IMPL_USE_FLASH2 and self.use_fa2: logger.warn_once( "Unable to import flash_attn; defaulting FlexBERT attention implementation to PyTorch's" " SDPA kernel. This requires padding and unpadding inputs, which will add some overhead." ) self.use_fa2 = False if not self.use_fa2: if not self.use_sdpa_attn_mask: logger.warn_once( "SDPA attention is being used without an attention mask. Including padding in the " " attention calculation may cause differences from the Flash Attention implementation." ) else: logger.warn_once( "SDPA attention with an attention mask doesn't use the Flash Attention kernel and will" " use more memory during the backward pass. Use the FA2 backend for linear memory scaling" " with sequence length." ) if self.sliding_window[0] > 0: raise ValueError("Sliding window is not implemented for the PyTorch SDPA path. Use the FA2 backend.") def _init_weights(self, reset_params: bool = False): init_weights( self.config, self.Wo, layer_dim=self.config.hidden_size, layer_id=self.layer_id, type_of_module=ModuleType.out_module, ) def forward( self, qkv: torch.Tensor, cu_seqlens: torch.Tensor, max_seqlen: int, indices: torch.Tensor, attn_mask: torch.Tensor, ) -> torch.Tensor: """Perform self-attention. There are two attention implementations supported: PyTorch's SDPA attention and Flash Attention 2. The arguments are unpadded. The SDPA implementation of attention requires padded arguments while the Flash Attention implementation does not. If using SDPA we first call `pad_input`. Once we compute attention, we re-unpad our outputs for the other layers. The pad/unpad operations add overhead, but not sending pad tokens through ffs saves compute. 
Args: qkv: (total_nnz, 3 * dim) cu_seqlens: (batch + 1,) max_seqlen: int indices: (total_nnz,) attn_mask: (batch, max_seqlen) Returns: attention: (total_nnz, dim) """ bs = qkv.shape[0] dim = self.hidden_size # only needed for inference when we have KV cache seqlen_offset = 0 # (total_seqlen, 3, nheads, headdim) qkv = qkv.view(-1, 3, self.num_attention_heads, self.attn_head_size) qkv = self.rotary_emb(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, seqlen_offset=seqlen_offset) if self.use_fa2: convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16) if convert_dtype: # FA2 implementation only supports fp16 and bf16. If FA2 is supported, # bfloat16 must be supported as of FA2 2.5.7. (Turing GPUs not supported) orig_dtype = qkv.dtype qkv = qkv.to(torch.bfloat16) attn = flash_attn_varlen_qkvpacked_func( qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, dropout_p=self.p_dropout, deterministic=self.deterministic_fa2, window_size=self.sliding_window, causal=self.is_causal, ) attn = attn.to(orig_dtype) # type: ignore else: attn = flash_attn_varlen_qkvpacked_func( qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, dropout_p=self.p_dropout, deterministic=self.deterministic_fa2, window_size=self.sliding_window, causal=self.is_causal, ) attn = attn.view(bs, dim) else: assert not self.is_causal, f"causal mask not implemented here yet" assert False qkv = bert_padding.pad_input( qkv, indices, cu_seqlens.shape[0] - 1, attn_mask.shape[-1] ) # batch, max_seqlen, thd unpad_bs, seqlen, *_ = qkv.shape q, k, v = qkv.transpose(3, 1).unbind(dim=2) # b h s d attn = F.scaled_dot_product_attention( q, k, v, dropout_p=self.p_dropout, attn_mask=attn_mask[:, None, None, :seqlen].to(torch.bool).expand(unpad_bs, 1, seqlen, seqlen) if self.use_sdpa_attn_mask else None, ) attn = attn.transpose(1, 2).view(unpad_bs, -1, dim) # b s h d attn = bert_padding.unpad_input_only(attn, torch.squeeze(attn_mask) == 1) return self.out_drop(self.Wo(attn)) class 
FlexBertPaddedRopeParallelAttention(FlexBertAttentionBase): """Performs multi-headed self attention on a batch of padded sequences. This module supports two attention implementations: 1. Flash Attention 2 (if installed), which improves throughput. 2. PyTorch's scaled_dot_product_attention. See `forward` method for additional details. """ def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None): super().__init__(config=config, layer_id=layer_id) if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " f"heads ({config.num_attention_heads})" ) self.is_causal = config.causal_mask self.num_attention_heads = config.num_attention_heads self.attn_head_size = int(config.hidden_size / config.num_attention_heads) self.hidden_size = config.hidden_size self.p_dropout = config.attention_probs_dropout_prob self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attn_out_bias) self.out_drop = ( nn.Dropout(config.attn_out_dropout_prob) if config.attn_out_dropout_prob > 0.0 else nn.Identity() ) self.use_fa2 = config.use_fa2 self.deterministic_fa2 = config.deterministic_fa2 self.use_sdpa_attn_mask = config.use_sdpa_attn_mask if not IMPL_USE_FLASH2 and self.use_fa2: self.use_fa2 = False if config.global_attn_every_n_layers > 0: if config.sliding_window == -1: raise ValueError("global_attn_every_n_layers` requires `sliding_window` to be set") if layer_id % config.global_attn_every_n_layers != 0: self.sliding_window = (config.sliding_window // 2, config.sliding_window // 2) else: self.sliding_window = (-1, -1) else: self.sliding_window = (config.sliding_window // 2, config.sliding_window // 2) if config.rotary_emb_dim is None: config.rotary_emb_dim = self.attn_head_size rotary_base = config.rotary_emb_base rotary_dim = config.rotary_emb_dim if self.sliding_window != (-1, -1): if 
config.local_attn_rotary_emb_base != -1: rotary_base = config.local_attn_rotary_emb_base if config.local_attn_rotary_emb_dim is not None: rotary_dim = config.local_attn_rotary_emb_dim assert RotaryEmbedding is not None, "rotary_emb is not installed" self.rotary_emb = RotaryEmbedding( dim=rotary_dim, base=rotary_base, scale_base=config.rotary_emb_scale_base, # If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554). interleaved=config.rotary_emb_interleaved, ) if not IMPL_USE_FLASH2 and self.use_fa2: self.use_fa2 = False if self.use_fa2 and self.use_sdpa_attn_mask: logger.warn_once( "Flash Attention 2 does not support attention masks. Use unpadded attention " "the equivalent functionality of masking out padding tokens." ) if not self.use_fa2 and self.sliding_window[0] > 0: raise ValueError("Sliding window is not implemented for the PyTorch SDPA path. Use the FA2 backend.") def _init_weights(self, reset_params: bool = False): init_weights( self.config, self.Wo, layer_dim=self.config.hidden_size, layer_id=self.layer_id, type_of_module=ModuleType.out_module, ) def forward( self, qkv: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Perform self-attention. There are two attention implementations supported: Flash Attention 2 and PyTorch's scaled_dot_product_attention. Args: qkv: (batch, seqlen, 3 * dim) attn_mask: (batch, seqlen) Returns: attention: (batch, seqlen, dim) """ bs, seqlen, _ = qkv.shape dim = self.hidden_size seqlen_offset = 0 # Reshape to (batch, seqlen, 3, nheads, headdim) qkv = qkv.view(bs, seqlen, 3, self.num_attention_heads, self.attn_head_size) if self.use_fa2: # Apply RoPE qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset, max_seqlen=None) convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16) if convert_dtype: # FA2 implementation only supports fp16 and bf16. If FA2 is supported, # bfloat16 must be supported as of FA2 2.5.7. 
(Turing GPUs not supported) orig_dtype = qkv.dtype qkv = qkv.to(torch.bfloat16) attn = flash_attn_qkvpacked_func( qkv, dropout_p=self.p_dropout, deterministic=self.deterministic_fa2, window_size=self.sliding_window, causal=self.is_causal ) attn = attn.to(orig_dtype) # type: ignore else: attn = flash_attn_qkvpacked_func( qkv, dropout_p=self.p_dropout, deterministic=self.deterministic_fa2, window_size=self.sliding_window, causal=self.is_causal ) else: assert not self.is_causal, f"causal mask not implemented here yet" assert False qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset, max_seqlen=None) q, k, v = qkv.transpose(3, 1).unbind(dim=2) attn = F.scaled_dot_product_attention( q, k, v, dropout_p=self.p_dropout, attn_mask=attn_mask[:, None, None, :seqlen].to(torch.bool).expand(bs, 1, seqlen, seqlen) if self.use_sdpa_attn_mask else None, ).transpose(1, 2) attn = attn.view(bs, seqlen, dim) return self.out_drop(self.Wo(attn)) class FlexBertPaddedParallelAttention(FlexBertAttentionBase): """Performs multi-headed self attention on a batch of padded sequences. This module supports two attention implementations: 1. Flash Attention 2 (if installed), which improves throughput. 2. PyTorch's scaled_dot_product_attention. See `forward` method for additional detail. 
""" def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None): super().__init__(config=config, layer_id=layer_id) if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " f"heads ({config.num_attention_heads})" ) self.is_causal = config.causal_mask self.num_attention_heads = config.num_attention_heads self.attn_head_size = int(config.hidden_size / config.num_attention_heads) self.hidden_size = config.hidden_size self.p_dropout = config.attention_probs_dropout_prob self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attn_out_bias) self.out_drop = ( nn.Dropout(config.attn_out_dropout_prob) if config.attn_out_dropout_prob > 0.0 else nn.Identity() ) self.use_fa2 = config.use_fa2 self.deterministic_fa2 = config.deterministic_fa2 self.use_sdpa_attn_mask = config.use_sdpa_attn_mask if config.global_attn_every_n_layers > 0: if config.sliding_window == -1: raise ValueError("global_attn_every_n_layers` requires `sliding_window` to be set") if layer_id % config.global_attn_every_n_layers != 0: self.sliding_window = (config.sliding_window // 2, config.sliding_window // 2) else: self.sliding_window = (-1, -1) else: self.sliding_window = (config.sliding_window // 2, config.sliding_window // 2) if not IMPL_USE_FLASH2 and self.use_fa2: self.use_fa2 = False if self.use_fa2 and self.use_sdpa_attn_mask: logger.warn_once( "Flash Attention 2 does not support attention masks. Use unpadded attention " "the equivalent functionality of masking out padding tokens." ) if not self.use_fa2 and self.sliding_window[0] > 0: raise ValueError("Sliding window is not implemented for the PyTorch SDPA path. 
Use the FA2 backend.") def _init_weights(self, reset_params: bool = False): init_weights( self.config, self.Wo, layer_dim=self.config.hidden_size, layer_id=self.layer_id, type_of_module=ModuleType.out_module, ) def forward( self, qkv: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Perform self-attention. There are two attention implementations supported: Flash Attention 2 and PyTorch's scaled_dot_product_attention. Args: qkv: (batch, seqlen, 3 * dim) attn_mask: (batch, seqlen) Returns: attention: (batch, seqlen, dim) """ bs, seqlen, _ = qkv.shape dim = self.hidden_size if self.use_fa2: qkv = qkv.view(bs, seqlen, 3, self.num_attention_heads, self.attn_head_size) convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16) if convert_dtype: # FA2 implementation only supports fp16 and bf16. If FA2 is supported, # bfloat16 must be supported as of FA2 2.5.7. (Turing GPUs not supported) orig_dtype = qkv.dtype qkv = qkv.to(torch.bfloat16) attn = flash_attn_qkvpacked_func( qkv, dropout_p=self.p_dropout, deterministic=self.deterministic_fa2, window_size=self.sliding_window, causal=self.is_causal ) attn = attn.to(orig_dtype) # type: ignore else: attn = flash_attn_qkvpacked_func( qkv, dropout_p=self.p_dropout, deterministic=self.deterministic_fa2, window_size=self.sliding_window, causal=self.is_causal ) else: assert not self.is_causal, f"causal attention mask not yet implemented here" assert False qkv = qkv.view(bs, seqlen, 3, self.num_attention_heads, self.attn_head_size) q, k, v = qkv.transpose(3, 1).unbind(dim=2) # b h s d attn = F.scaled_dot_product_attention( q, k, v, dropout_p=self.p_dropout, attn_mask=attn_mask[:, None, None, :seqlen].to(torch.bool).expand(bs, 1, seqlen, seqlen) if self.use_sdpa_attn_mask else None, ).transpose(1, 2) attn = attn.view(bs, seqlen, dim) return self.out_drop(self.Wo(attn)) ATTN2CLS = { "unpadded_base": FlexBertUnpadAttention, "padded_base": FlexBertPaddedAttention, "unpadded_parallel": 
FlexBertUnpadParallelAttention, "padded_parallel": FlexBertPaddedParallelAttention, "unpadded_rope": FlexBertUnpadRopeAttention, "padded_rope": FlexBertPaddedRopeAttention, "unpadded_rope_parallel": FlexBertUnpadRopeParallelAttention, "padded_rope_parallel": FlexBertPaddedRopeParallelAttention, } def get_attention_layer(config: FlexBertConfig, layer_id: Optional[int] = None) -> FlexBertAttentionBase: try: attention_layer = ( config.initial_attention_layer if layer_id < config.num_initial_layers and getattr(config, "initial_attention_layer", None) is not None else config.attention_layer ) return ATTN2CLS[maybe_add_padding(config, attention_layer)](config, layer_id=layer_id) except KeyError: if layer_id < config.num_initial_layers and getattr(config, "initial_attention_layer", None) is not None: raise ValueError( f"Invalid attention layer type: {config.initial_attention_layer=}, must be one of {ATTN2CLS.keys()}." f"{config.padding=} will be automatically prepended to `config.attention_layer` if unspecified." ) else: raise ValueError( f"Invalid attention layer type: {config.attention_layer=}, must be one of {ATTN2CLS.keys()}. " f"{config.padding=} will be automatically prepended to `config.attention_layer` if unspecified." )