oweller2 committed
Commit 204da06
1 Parent(s): c46937d

added in file

Files changed (17)
  1. README.md +2 -32
  2. __init__.py +74 -0
  3. activation.py +60 -0
  4. attention.py +1601 -0
  5. bert_padding.py +141 -0
  6. config.json +3 -1
  7. configuration_bert.py +272 -0
  8. embeddings.py +218 -0
  9. initialization.py +551 -0
  10. layers.py +700 -0
  11. mlp.py +214 -0
  12. modeling_flexbert.py +1920 -0
  13. normalization.py +116 -0
  14. options.py +32 -0
  15. padding.py +87 -0
  16. rotary.py +297 -0
  17. utils.py +38 -0
README.md CHANGED
@@ -1,33 +1,3 @@
---
- language:
- - en
- pipeline_tag: fill-mask
- ---
-
- ## How to run:
- Install these requirements
- ```
- pip install flash_attn
- pip install transformers==4.45.2 # (probably works with newer/older but tested with >=4.45.2)
- ```
- Then you can load with
- ```
- from transformers import AutoModel, AutoTokenizer
- import torch
-
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- print(f"Using device: {device}")
-
- model = AutoModel.from_pretrained("ModernBERT/bert24-base-v2-2ep-decay_100B-0.08-lr", trust_remote_code=True)
- model = model.to(device)
- tokenizer = AutoTokenizer.from_pretrained("ModernBERT/bert24-base-v2-2ep-decay_100B-0.08-lr", trust_remote_code=True)
-
- # test it out and encode some text
- text = "This is a test sentence"
- inputs = tokenizer(text, return_tensors="pt")
- inputs = {k: v.to(device) for k, v in inputs.items()}
-
- outputs = model(**inputs)
- last_hidden_states = outputs.last_hidden_state
- print(last_hidden_states.shape)
- ```
+ license: apache-2.0
+ ---
__init__.py ADDED
@@ -0,0 +1,74 @@
+ from .attention import (
+     BertAlibiUnpadAttention,
+     BertAlibiUnpadSelfAttention,
+     BertSelfOutput,
+     FlexBertPaddedAttention,
+     FlexBertUnpadAttention,
+ )
+ from .embeddings import (
+     BertAlibiEmbeddings,
+     FlexBertAbsoluteEmbeddings,
+     FlexBertSansPositionEmbeddings,
+ )
+ from .layers import (
+     BertAlibiEncoder,
+     BertAlibiLayer,
+     BertResidualGLU,
+     FlexBertPaddedPreNormLayer,
+     FlexBertPaddedPostNormLayer,
+     FlexBertUnpadPostNormLayer,
+     FlexBertUnpadPreNormLayer,
+ )
+ from .modeling_flexbert import (
+     BertLMPredictionHead,
+     BertModel,
+     BertForMaskedLM,
+     BertForSequenceClassification,
+     BertForMultipleChoice,
+     BertOnlyMLMHead,
+     BertOnlyNSPHead,
+     BertPooler,
+     BertPredictionHeadTransform,
+     FlexBertModel,
+     FlexBertForMaskedLM,
+     FlexBertForSequenceClassification,
+     FlexBertForMultipleChoice,
+     FlexBertForCasualLM,
+ )
+ from .bert_padding import (
+     IndexFirstAxis,
+     IndexPutFirstAxis,
+ )
+
+ __all__ = [
+     "BertAlibiEmbeddings",
+     "BertAlibiEncoder",
+     "BertForMaskedLM",
+     "BertForSequenceClassification",
+     "BertForMultipleChoice",
+     "BertResidualGLU",
+     "BertAlibiLayer",
+     "BertLMPredictionHead",
+     "BertModel",
+     "BertOnlyMLMHead",
+     "BertOnlyNSPHead",
+     "BertPooler",
+     "BertPredictionHeadTransform",
+     "BertSelfOutput",
+     "BertAlibiUnpadAttention",
+     "BertAlibiUnpadSelfAttention",
+     "FlexBertPaddedAttention",
+     "FlexBertUnpadAttention",
+     "FlexBertAbsoluteEmbeddings",
+     "FlexBertSansPositionEmbeddings",
+     "FlexBertPaddedPreNormLayer",
+     "FlexBertPaddedPostNormLayer",
+     "FlexBertUnpadPostNormLayer",
+     "FlexBertUnpadPreNormLayer",
+     "FlexBertModel",
+     "FlexBertForMaskedLM",
+     "FlexBertForSequenceClassification",
+     "FlexBertForMultipleChoice",
+     "IndexFirstAxis",
+     "IndexPutFirstAxis",
+ ]
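The `__init__.py` above flattens the package namespace, so downstream code can import the public classes from the package root rather than from the individual modules. A minimal sketch, assuming the files in this commit ship as a package directory named `flexbert` (the directory name is an assumption, not part of the commit):

```
# Hypothetical usage of the flattened namespace defined by __init__.py above.
# Assumes the files in this commit live in a package directory called `flexbert`.
from flexbert import FlexBertModel, FlexBertForMaskedLM  # re-exported via __all__

print(FlexBertModel.__module__)        # -> flexbert.modeling_flexbert
print(FlexBertForMaskedLM.__module__)  # -> flexbert.modeling_flexbert
```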
activation.py ADDED
@@ -0,0 +1,60 @@
+ # Copyright 2024 **AUTHORS_TODO**
+ # License: Apache-2.0
+
+ # Copyright 2020 The HuggingFace Team.
+ # License: Apache-2.0
+
+ from collections import OrderedDict
+ from typing import Union
+ import torch.nn as nn
+ from .configuration_bert import FlexBertConfig
+
+
+ class ClassInstantier(OrderedDict):
+     def __getitem__(self, key):
+         content = super().__getitem__(key)
+         cls, kwargs = content if isinstance(content, tuple) else (content, {})
+         return cls(**kwargs)
+
+
+ ACT2CLS = {
+     "celu": nn.CELU,
+     "elu": nn.ELU,
+     "gelu": nn.GELU,
+     "gelu_tanh": (nn.GELU, {"approximate": "tanh"}),
+     "hardtanh": nn.Hardtanh,
+     "hardsigmoid": nn.Hardsigmoid,
+     "hardshrink": nn.Hardshrink,
+     "hardswish": nn.Hardswish,
+     "leaky_relu": nn.LeakyReLU,
+     "logsigmoid": nn.LogSigmoid,
+     "mish": nn.Mish,
+     "prelu": nn.PReLU,
+     "relu": nn.ReLU,
+     "relu6": nn.ReLU6,
+     "rrelu": nn.RReLU,
+     "selu": nn.SELU,
+     "sigmoid": nn.Sigmoid,
+     "silu": nn.SiLU,
+     "softmin": nn.Softmin,
+     "softplus": nn.Softplus,
+     "softshrink": nn.Softshrink,
+     "softsign": nn.Softsign,
+     "swish": nn.SiLU,
+     "tanh": nn.Tanh,
+     "tanhshrink": nn.Tanhshrink,
+     "threshold": nn.Threshold,
+ }
+ ACT2FN = ClassInstantier(ACT2CLS)
+
+
+ def get_act_fn(config: Union[FlexBertConfig, str]) -> nn.Module:
+     try:
+         if isinstance(config, str):
+             return ACT2FN[config]
+         return ACT2FN[config.hidden_act]
+     except KeyError:
+         if isinstance(config, str):
+             raise ValueError(f"Invalid activation function type: {config}, must be one of {ACT2FN.keys()}.")
+         else:
+             raise ValueError(f"Invalid activation function type: {config.hidden_act=}, must be one of {ACT2FN.keys()}.")
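`ClassInstantier` makes `ACT2FN` return a fresh module instance on every lookup, so `get_act_fn` can be driven either by a plain string or by a config's `hidden_act` field. A small usage sketch (run from within the package; the activation names shown are only examples):

```
# Illustrative use of the activation registry defined above.
import torch
from .activation import ACT2FN, get_act_fn  # relative import: run inside the package

act = ACT2FN["gelu_tanh"]   # ClassInstantier builds nn.GELU(approximate="tanh") on lookup
y = act(torch.randn(4, 8))  # behaves like any nn.Module activation

silu = get_act_fn("silu")   # string path; a FlexBertConfig with hidden_act="silu" resolves the same way
```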
attention.py ADDED
@@ -0,0 +1,1601 @@
1
+ # Copyright 2024 **AUTHORS_TODO**
2
+ # License: Apache-2.0
3
+
4
+ # Copyright 2022 MosaicML Examples authors
5
+ # SPDX-License-Identifier: Apache-2.0
6
+
7
+ # Copyright 2023 MosaicML Examples authors
8
+ # SPDX-License-Identifier: Apache-2.0
9
+
10
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
11
+ # Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.
12
+ # Copyright (c) 2023, Tri Dao.
13
+
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ import warnings
19
+ from typing import Optional
20
+ import importlib.metadata
21
+ import logging
22
+ import math
23
+
24
+ from . import bert_padding
25
+ from .configuration_bert import FlexBertConfig, maybe_add_padding
26
+ from .normalization import get_norm_layer
27
+ from .initialization import ModuleType, init_weights
28
+ from . import utils  # noqa: F401
29
+
30
+ IMPL_USE_FLASH3 = False
31
+ IMPL_USE_FLASH2 = False
32
+ try:
33
+ from flash_attn_interface import flash_attn_varlen_func
34
+
35
+ IMPL_USE_FLASH3 = True
36
+ except ImportError:
37
+ pass
38
+ # Import Flash Attention 2, which supports ALiBi https://github.com/Dao-AILab/flash-attention
39
+ try:
40
+ from flash_attn import flash_attn_varlen_qkvpacked_func, flash_attn_qkvpacked_func # type: ignore
41
+
42
+ installed_version = importlib.metadata.version("flash_attn") # type: ignore
43
+ if installed_version < "2.5.7":
44
+ raise ImportError("newer version of flash_attn required (>= 2.5.7)")
45
+ IMPL_USE_FLASH2 = True
46
+ except ImportError:
47
+ pass
48
+
49
+ try:
50
+ from flash_attn.layers.rotary import RotaryEmbedding # type: ignore
51
+ from .rotary import UnpaddedRotaryEmbedding # type: ignore
52
+
53
+ except ImportError:
54
+ RotaryEmbedding = None
55
+ UnpaddedRotaryEmbedding = None
56
+
57
+ logger = logging.getLogger(__name__)
58
+
59
+
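One caveat with the import guard above: `installed_version < "2.5.7"` compares version strings lexicographically, so a hypothetical `flash_attn` 2.10.x would sort before 2.5.7 and be rejected. A hedged alternative sketch using `packaging` (treating it as an extra dependency is an assumption; this is not applied in the commit):

```
# Version-aware variant of the flash_attn check above.
import importlib.metadata
from packaging.version import Version

if Version(importlib.metadata.version("flash_attn")) < Version("2.5.7"):
    raise ImportError("newer version of flash_attn required (>= 2.5.7)")
```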
60
+ class BertAlibiUnpadSelfAttention(nn.Module):
61
+ """Performs multi-headed self attention on a batch of unpadded sequences.
62
+
63
+ If Flash Attention 2 is installed, this module uses Flash Attention to greatly improve throughput.
64
+ The Flash Attention implementation used in MosaicBERT supports arbitrary attention biases (which
65
+ we use to implement ALiBi). If Flash Attention 2 is not installed, the implementation will
66
+ default to a math-equivalent pytorch version, which is much slower.
67
+
68
+ See `forward` method for additional details.
69
+ """
70
+
71
+ def __init__(self, config):
72
+ super().__init__()
73
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
74
+ raise ValueError(
75
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
76
+ f"heads ({config.num_attention_heads})"
77
+ )
78
+
79
+ self.is_casual = config.casual_mask
80
+ self.num_attention_heads = config.num_attention_heads
81
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
82
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
83
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
84
+ self.p_dropout = config.attention_probs_dropout_prob
85
+ self.Wqkv = nn.Linear(self.all_head_size, 3 * config.hidden_size)
86
+ self.deterministic_fa2 = getattr(config, "deterministic_fa2", False)
87
+
88
+ # Warn if defaulting to pytorch because of import issues
89
+ if not IMPL_USE_FLASH2:
90
+ warnings.warn(
91
+ "Unable to import flash_attn; defaulting MosaicBERT attention implementation to "
92
+ "vanilla PyTorch (this will reduce throughput when using this model)."
93
+ )
94
+
95
+ def forward(
96
+ self,
97
+ hidden_states: torch.Tensor,
98
+ cu_seqlens: torch.Tensor,
99
+ max_seqlen: int,
100
+ indices: torch.Tensor,
101
+ attn_mask: torch.Tensor,
102
+ bias: torch.Tensor,
103
+ slopes: torch.Tensor,
104
+ ) -> torch.Tensor:
105
+ """Perform self-attention.
106
+
107
+ There are two attention implementations: vanilla attention with ALiBi, and Flash Attention 2 with ALiBi
108
+
109
+ The arguments are unpadded. The vanilla implementation of attention requires padded arguments while the
110
+ Flash Attention implementation does not. If using vanilla we first call `pad_input`. Once we compute
111
+ attention, we re-unpad our outputs for the other layers. The pad/unpad operations add overhead, but not
112
+ sending pad tokens through the feed-forward layers saves compute.
113
+
114
+ Args:
115
+ hidden_states: (total_nnz, dim)
116
+ cu_seqlens: (batch + 1,)
117
+ max_seqlen: int
118
+ indices: (total_nnz,)
119
+ attn_mask: (batch, max_seqlen)
120
+ bias: (batch, heads, max_seqlen, max_seqlen)
121
+ slopes: (heads) or (batch, heads)
122
+
123
+ Returns:
124
+ attention: (total_nnz, dim)
125
+ """
126
+ bs, dim = hidden_states.shape
127
+ qkv = self.Wqkv(hidden_states)
128
+
129
+ # Option 1: Flash Attention with ALiBi
130
+ if IMPL_USE_FLASH2:
131
+ qkv = qkv.view(-1, 3, self.num_attention_heads, self.attention_head_size)
132
+ assert 1 <= len(slopes.shape) <= 2, f"{slopes=}"
133
+ assert slopes.shape[-1] == self.num_attention_heads, f"{slopes=}"
134
+
135
+ convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16)
136
+ if convert_dtype:
137
+ # FA2 implementation only supports fp16 and bf16
138
+ # If FA2 is supported, bfloat16 must be supported
139
+ # as of FA2 2.4.2. (Turing GPUs not supported)
140
+ orig_dtype = qkv.dtype
141
+ qkv = qkv.to(torch.bfloat16)
142
+
143
+ attention = flash_attn_varlen_qkvpacked_func(
144
+ qkv,
145
+ cu_seqlens=cu_seqlens,
146
+ max_seqlen=max_seqlen,
147
+ dropout_p=self.p_dropout,
148
+ deterministic=self.deterministic_fa2,
149
+ alibi_slopes=slopes,
150
+ causal=self.is_casual
151
+ )
152
+ attention = attention.to(orig_dtype) # type: ignore
153
+ else:
154
+ attention = flash_attn_varlen_qkvpacked_func(
155
+ qkv,
156
+ cu_seqlens=cu_seqlens,
157
+ max_seqlen=max_seqlen,
158
+ dropout_p=self.p_dropout,
159
+ deterministic=self.deterministic_fa2,
160
+ alibi_slopes=slopes,
161
+ causal=self.is_casual
162
+ )
163
+ else:
164
+ assert not self.is_casual, "Causal mask not implemented here yet"
165
+ qkv = bert_padding.pad_input(qkv, indices, cu_seqlens.shape[0] - 1, max_seqlen) # batch, max_seqlen, thd
166
+ unpad_bs, *_ = qkv.shape
167
+ qkv = qkv.view(unpad_bs, -1, 3, self.num_attention_heads, self.attention_head_size)
168
+ # Fall back to a math-equivalent PyTorch implementation of attention with the ALiBi bias
169
+ q = qkv[:, :, 0, :, :].permute(0, 2, 1, 3) # b h s d
170
+ k = qkv[:, :, 1, :, :].permute(0, 2, 3, 1) # b h d s
171
+ v = qkv[:, :, 2, :, :].permute(0, 2, 1, 3) # b h s d
172
+ attention_scores = torch.matmul(q, k) / math.sqrt(self.attention_head_size)
173
+ attention_scores = attention_scores + bias
174
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
175
+ attention_probs = self.dropout(attention_probs)
176
+ attention = torch.matmul(attention_probs, v).permute(0, 2, 1, 3) # b s h d
177
+
178
+ attention = bert_padding.unpad_input_only(attention, torch.squeeze(attn_mask) == 1)
179
+
180
+ return attention.view(bs, dim)
181
+
182
+
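The docstrings above describe unpadded `(total_nnz, dim)` inputs indexed by `cu_seqlens`. For readers unfamiliar with that layout, here is a hedged sketch of how a padded batch maps onto it (shapes only; the real conversion lives in `bert_padding.py`):

```
# Sketch of the unpadded layout used throughout this file (illustrative,
# not the actual bert_padding implementation).
import torch

attn_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])   # batch=2, max_seqlen=4
seqlens = attn_mask.sum(dim=1)                            # tensor([3, 2])
cu_seqlens = torch.nn.functional.pad(seqlens.cumsum(0), (1, 0)).to(torch.int32)
# cu_seqlens == tensor([0, 3, 5]); total_nnz == 5
indices = attn_mask.flatten().nonzero(as_tuple=False).flatten()
# hidden_states for this module would then have shape (total_nnz, dim) == (5, dim),
# and sequence i occupies rows cu_seqlens[i]:cu_seqlens[i + 1].
```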
183
+ # Copy of the transformers library's BertSelfOutput that will not be caught by surgery methods looking for HF BERT modules.
184
+ class BertSelfOutput(nn.Module):
185
+ """Computes the output of the attention layer.
186
+
187
+ This module is modeled after the Hugging Face BERT's
188
+ :class:`~transformers.model.bert.modeling_bert.BertSelfOutput`.
189
+ The implementation is identical. Rather than use the original module
190
+ directly, we re-implement it here so that Mosaic BERT's modules will not
191
+ be affected by any Composer surgery algorithm that modifies Hugging Face
192
+ BERT modules.
193
+ """
194
+
195
+ def __init__(self, config):
196
+ super().__init__()
197
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
198
+ self.LayerNorm = get_norm_layer(config)
199
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
200
+
201
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
202
+ hidden_states = self.dense(hidden_states)
203
+ hidden_states = self.dropout(hidden_states)
204
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
205
+ return hidden_states
206
+
207
+
208
+ class BertAlibiUnpadAttention(nn.Module):
209
+ """Chains attention, Dropout, and LayerNorm for Mosaic BERT."""
210
+
211
+ def __init__(self, config):
212
+ super().__init__()
213
+ self.self = BertAlibiUnpadSelfAttention(config)
214
+ self.output = BertSelfOutput(config)
215
+
216
+ def forward(
217
+ self,
218
+ input_tensor: torch.Tensor,
219
+ cu_seqlens: torch.Tensor,
220
+ max_s: int,
221
+ subset_idx: Optional[torch.Tensor] = None,
222
+ indices: Optional[torch.Tensor] = None,
223
+ attn_mask: Optional[torch.Tensor] = None,
224
+ bias: Optional[torch.Tensor] = None,
225
+ slopes: Optional[torch.Tensor] = None,
226
+ ) -> torch.Tensor:
227
+ """Forward pass for scaled self-attention without padding.
228
+
229
+ Arguments:
230
+ input_tensor: (total_nnz, dim)
231
+ cu_seqlens: (batch + 1,)
232
+ max_s: int
233
+ subset_idx: () set of indices whose values we care about at the end of the layer
234
+ (e.g., the masked tokens, if this is the final layer).
235
+ indices: None or (total_nnz,)
236
+ attn_mask: None or (batch, max_seqlen)
237
+ bias: None or (batch, heads, max_seqlen, max_seqlen)
238
+ slopes: None or (batch, heads) or (heads,)
239
+ """
240
+ assert (bias is None) == (slopes is None), f"{bias=}, {slopes=}"
241
+ self_output = self.self(input_tensor, cu_seqlens, max_s, indices, attn_mask, bias, slopes)
242
+ if subset_idx is not None:
243
+ return self.output(
244
+ bert_padding.index_first_axis(self_output, subset_idx),
245
+ bert_padding.index_first_axis(input_tensor, subset_idx),
246
+ )
247
+ else:
248
+ return self.output(self_output, input_tensor)
249
+
250
+
251
+ class FlexBertAttentionBase(nn.Module):
252
+ """A FlexBERT attention base class for type hints."""
253
+
254
+ def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
255
+ super().__init__()
256
+ self.config = config
257
+ self.layer_id = layer_id
258
+
259
+ def _init_weights(self, reset_params: bool = False):
260
+ raise NotImplementedError("This is a base class and should not be used directly.")
261
+
262
+ def forward(self, hidden_states: torch.Tensor, attn_mask: torch.Tensor, **kwargs) -> torch.Tensor:
263
+ raise NotImplementedError("This is a base class and should not be used directly.")
264
+
265
+ def extra_repr(self) -> str:
266
+ repr = ""
267
+ if hasattr(self, "num_attention_heads"):
268
+ repr += f"num_attention_heads={self.num_attention_heads}"
269
+ if hasattr(self, "attn_head_size"):
270
+ repr += f", attn_head_size={self.attn_head_size}"
271
+ if hasattr(self, "sliding_window"):
272
+ repr += f", sliding_window={self.sliding_window if self.sliding_window != (-1, -1) else 'False'}"
273
+ if hasattr(self, "use_fa2"):
274
+ repr += f", use_fa2={self.use_fa2}"
275
+ if hasattr(self, "deterministic_fa2"):
276
+ repr += f", deterministic_fa2={self.deterministic_fa2}"
277
+ return repr
278
+
279
+
280
+ class FlexBertUnpadAttention(FlexBertAttentionBase):
281
+ """Performs multi-headed self attention on a batch of unpadded sequences.
282
+
283
+ If Flash Attention 2 is installed, this module uses Flash Attention to improve throughput.
284
+ If Flash Attention 2 is not installed, the implementation will use PyTorch's SDPA kernel,
285
+ which requires padding and unpadding inputs, adding some overhead.
286
+
287
+ See `forward` method for additional detail.
288
+ """
289
+
290
+ def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
291
+ super().__init__(config=config, layer_id=layer_id)
292
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
293
+ raise ValueError(
294
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
295
+ f"heads ({config.num_attention_heads})"
296
+ )
297
+
298
+ self.is_casual = config.casual_mask
299
+ self.num_attention_heads = config.num_attention_heads
300
+ self.attn_head_size = int(config.hidden_size / config.num_attention_heads)
301
+ self.all_head_size = self.num_attention_heads * self.attn_head_size
302
+ self.p_dropout = config.attention_probs_dropout_prob
303
+ self.Wqkv = nn.Linear(config.hidden_size, 3 * self.all_head_size, bias=config.attn_qkv_bias)
304
+ self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attn_out_bias)
305
+ self.out_drop = (
306
+ nn.Dropout(config.attn_out_dropout_prob) if config.attn_out_dropout_prob > 0.0 else nn.Identity()
307
+ )
308
+ self.use_fa2 = config.use_fa2
309
+ self.deterministic_fa2 = config.deterministic_fa2
310
+ self.use_sdpa_attn_mask = config.use_sdpa_attn_mask
311
+
312
+ if config.global_attn_every_n_layers > 0:
313
+ if config.sliding_window == -1:
314
+ raise ValueError("`global_attn_every_n_layers` requires `sliding_window` to be set")
315
+ if layer_id % config.global_attn_every_n_layers != 0:
316
+ self.sliding_window = (config.sliding_window // 2, config.sliding_window // 2)
317
+ else:
318
+ self.sliding_window = (-1, -1)
319
+ else:
320
+ self.sliding_window = (config.sliding_window // 2, config.sliding_window // 2)
321
+
322
+ # Warn if defaulting to pytorch because of import issues
323
+ if not IMPL_USE_FLASH2 and self.use_fa2:
324
+ logger.warn_once(
325
+ "Unable to import flash_attn; defaulting FlexBERT attention implementation to PyTorch's"
326
+ " SDPA kernel. This requires padding and unpadding inputs, which will add some overhead."
327
+ )
328
+ self.use_fa2 = False
329
+ if not self.use_fa2:
330
+ if not self.use_sdpa_attn_mask:
331
+ logger.warn_once(
332
+ "SDPA attention is being used without an attention mask. Including padding in the "
333
+ " attention calculation may cause differences from the Flash Attention implementation."
334
+ )
335
+ else:
336
+ logger.warn_once(
337
+ "SDPA attention with an attention mask doesn't use the Flash Attention kernel and will"
338
+ " use more memory during the backward pass. Use the FA2 backend for linear memory scaling"
339
+ " with sequence length."
340
+ )
341
+ if self.sliding_window[0] > 0:
342
+ raise ValueError("Sliding window is not implemented for the PyTorch SDPA path. Use the FA2 backend.")
343
+
344
+ def _init_weights(self, reset_params: bool = False):
345
+ init_weights(
346
+ self.config,
347
+ self.Wqkv,
348
+ layer_dim=self.config.hidden_size,
349
+ layer_id=None,
350
+ type_of_module=ModuleType.in_module,
351
+ )
352
+ init_weights(
353
+ self.config,
354
+ self.Wo,
355
+ layer_dim=self.config.hidden_size,
356
+ layer_id=self.layer_id,
357
+ type_of_module=ModuleType.out_module,
358
+ )
359
+
360
+ def forward(
361
+ self,
362
+ hidden_states: torch.Tensor,
363
+ cu_seqlens: torch.Tensor,
364
+ max_seqlen: int,
365
+ indices: torch.Tensor,
366
+ attn_mask: torch.Tensor,
367
+ ) -> torch.Tensor:
368
+ """Perform self-attention.
369
+
370
+ There are two attention implementations supported: PyTorch's SDPA attention and Flash Attention 2.
371
+
372
+ The arguments are unpadded. The SDPA implementation of attention requires padded arguments while the
373
+ Flash Attention implementation does not. If using SDPA we first call `pad_input`. Once we compute
374
+ attention, we re-unpad our outputs for the other layers. The pad/unpad operations add overhead, but not
375
+ sending pad tokens through the feed-forward layers saves compute.
376
+
377
+ Args:
378
+ hidden_states: (total_nnz, dim)
379
+ cu_seqlens: (batch + 1,)
380
+ max_seqlen: int
381
+ indices: (total_nnz,)
382
+ attn_mask: (batch, max_seqlen)
383
+
384
+ Returns:
385
+ attention: (total_nnz, dim)
386
+ """
387
+ bs, dim = hidden_states.shape
388
+ qkv = self.Wqkv(hidden_states)
389
+
390
+ if self.use_fa2:
391
+ qkv = qkv.view(-1, 3, self.num_attention_heads, self.attn_head_size)
392
+
393
+ convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16)
394
+ if convert_dtype:
395
+ # FA2 implementation only supports fp16 and bf16. If FA2 is supported,
396
+ # bfloat16 must be supported as of FA2 2.5.7. (Turing GPUs not supported)
397
+ orig_dtype = qkv.dtype
398
+ qkv = qkv.to(torch.bfloat16)
399
+
400
+ attn = flash_attn_varlen_qkvpacked_func(
401
+ qkv,
402
+ cu_seqlens=cu_seqlens,
403
+ max_seqlen=max_seqlen,
404
+ dropout_p=self.p_dropout,
405
+ deterministic=self.deterministic_fa2,
406
+ window_size=self.sliding_window,
407
+ causal=self.is_casual
408
+ )
409
+ attn = attn.to(orig_dtype) # type: ignore
410
+ else:
411
+ attn = flash_attn_varlen_qkvpacked_func(
412
+ qkv,
413
+ cu_seqlens=cu_seqlens,
414
+ max_seqlen=max_seqlen,
415
+ dropout_p=self.p_dropout,
416
+ deterministic=self.deterministic_fa2,
417
+ window_size=self.sliding_window,
418
+ causal=self.is_casual
419
+ )
420
+ attn = attn.view(bs, dim)
421
+ else:
422
+ assert not self.is_casual, "Causal mask not implemented here yet"
423
+ qkv = bert_padding.pad_input(qkv, indices, cu_seqlens.shape[0] - 1, max_seqlen) # batch, max_seqlen, thd
424
+ unpad_bs, seqlen, _ = qkv.shape
425
+
426
+ qkv = qkv.view(unpad_bs, -1, 3, self.num_attention_heads, self.attn_head_size)
427
+ q, k, v = qkv.transpose(3, 1).unbind(dim=2) # b h s d
428
+ attn = F.scaled_dot_product_attention(
429
+ q,
430
+ k,
431
+ v,
432
+ dropout_p=self.p_dropout,
433
+ attn_mask=attn_mask[:, None, None, :seqlen].to(torch.bool).expand(unpad_bs, 1, seqlen, seqlen)
434
+ if self.use_sdpa_attn_mask
435
+ else None,
436
+ )
437
+ attn = attn.transpose(1, 2).view(unpad_bs, -1, dim) # b s h d
438
+ attn = bert_padding.unpad_input_only(attn, torch.squeeze(attn_mask) == 1)
439
+
440
+ return self.out_drop(self.Wo(attn))
441
+
442
+
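The `global_attn_every_n_layers` / `sliding_window` logic above alternates local and global attention across layers: layers whose `layer_id` is a multiple of `global_attn_every_n_layers` get `(-1, -1)` (global attention), all others get a symmetric window of `sliding_window // 2` tokens on each side. A small sketch of that schedule, under assumed config values (see below):

```
# Illustrative reproduction of the window schedule used by the FlexBert attention
# classes above; the config values are assumptions for the example.
global_attn_every_n_layers = 3
sliding_window = 128

def window_for_layer(layer_id: int) -> tuple:
    if global_attn_every_n_layers > 0 and layer_id % global_attn_every_n_layers == 0:
        return (-1, -1)                                   # global attention
    return (sliding_window // 2, sliding_window // 2)     # local (sliding window) attention

print([window_for_layer(i) for i in range(6)])
# [(-1, -1), (64, 64), (64, 64), (-1, -1), (64, 64), (64, 64)]
```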
443
+ class FlexBertUnpadParallelAttention(FlexBertAttentionBase):
444
+ """Performs multi-headed self attention on a batch of unpadded sequences, taking a precomputed QKV projection (parallel attention block).
445
+
446
+ If Flash Attention 2 is installed, this module uses Flash Attention to improve throughput.
447
+ If Flash Attention 2 is not installed, the implementation will use PyTorch's SDPA kernel,
448
+ which requires padding and unpadding inputs, adding some overhead.
449
+
450
+ See `forward` method for additional detail.
451
+ """
452
+
453
+ def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
454
+ super().__init__(config=config, layer_id=layer_id)
455
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
456
+ raise ValueError(
457
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
458
+ f"heads ({config.num_attention_heads})"
459
+ )
460
+
461
+ self.is_casual = config.casual_mask
462
+ self.num_attention_heads = config.num_attention_heads
463
+ self.attn_head_size = int(config.hidden_size / config.num_attention_heads)
464
+ self.hidden_size = config.hidden_size
465
+ self.p_dropout = config.attention_probs_dropout_prob
466
+ self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attn_out_bias)
467
+ self.out_drop = (
468
+ nn.Dropout(config.attn_out_dropout_prob) if config.attn_out_dropout_prob > 0.0 else nn.Identity()
469
+ )
470
+ self.use_fa2 = config.use_fa2
471
+ self.deterministic_fa2 = config.deterministic_fa2
472
+ self.use_sdpa_attn_mask = config.use_sdpa_attn_mask
473
+
474
+ if config.global_attn_every_n_layers > 0:
475
+ if config.sliding_window == -1:
476
+ raise ValueError("`global_attn_every_n_layers` requires `sliding_window` to be set")
477
+ if layer_id % config.global_attn_every_n_layers != 0:
478
+ self.sliding_window = (config.sliding_window // 2, config.sliding_window // 2)
479
+ else:
480
+ self.sliding_window = (-1, -1)
481
+ else:
482
+ self.sliding_window = (config.sliding_window // 2, config.sliding_window // 2)
483
+
484
+ # Warn if defaulting to pytorch because of import issues
485
+ if not IMPL_USE_FLASH2 and self.use_fa2:
486
+ logger.warn_once(
487
+ "Unable to import flash_attn; defaulting FlexBERT attention implementation to PyTorch's"
488
+ " SDPA kernel. This requires padding and unpadding inputs, which will add some overhead."
489
+ )
490
+ self.use_fa2 = False
491
+ if not self.use_fa2:
492
+ if not self.use_sdpa_attn_mask:
493
+ logger.warn_once(
494
+ "SDPA attention is being used without an attention mask. Including padding in the "
495
+ " attention calculation may cause differences from the Flash Attention implementation."
496
+ )
497
+ else:
498
+ logger.warn_once(
499
+ "SDPA attention with an attention mask doesn't use the Flash Attention kernel and will"
500
+ " use more memory during the backward pass. Use the FA2 backend for linear memory scaling"
501
+ " with sequence length."
502
+ )
503
+ if self.sliding_window[0] > 0:
504
+ raise ValueError("Sliding window is not implemented for the PyTorch SDPA path. Use the FA2 backend.")
505
+
506
+ def _init_weights(self, reset_params: bool = False):
507
+ init_weights(
508
+ self.config,
509
+ self.Wo,
510
+ layer_dim=self.config.hidden_size,
511
+ layer_id=self.layer_id,
512
+ type_of_module=ModuleType.out_module,
513
+ )
514
+
515
+ def forward(
516
+ self,
517
+ qkv: torch.Tensor,
518
+ cu_seqlens: torch.Tensor,
519
+ max_seqlen: int,
520
+ indices: torch.Tensor,
521
+ attn_mask: torch.Tensor,
522
+ ) -> torch.Tensor:
523
+ """Perform self-attention.
524
+
525
+ There are two attention implementations supported: PyTorch's SDPA attention and Flash Attention 2.
526
+
527
+ The arguments are unpadded. The SDPA implementation of attention requires padded arguments while the
528
+ Flash Attention implementation does not. If using SDPA we first call `pad_input`. Once we compute
529
+ attention, we re-unpad our outputs for the other layers. The pad/unpad operations add overhead, but not
530
+ sending pad tokens through the feed-forward layers saves compute.
531
+
532
+ Args:
533
+ qkv: (total_nnz, 3 * dim)
534
+ cu_seqlens: (batch + 1,)
535
+ max_seqlen: int
536
+ indices: (total_nnz,)
537
+ attn_mask: (batch, max_seqlen)
538
+
539
+ Returns:
540
+ attention: (total_nnz, dim)
541
+ """
542
+ bs = qkv.shape[0]
543
+ dim = self.hidden_size
544
+ if self.use_fa2:
545
+ qkv = qkv.view(-1, 3, self.num_attention_heads, self.attn_head_size)
546
+
547
+ convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16)
548
+ if convert_dtype:
549
+ # FA2 implementation only supports fp16 and bf16. If FA2 is supported,
550
+ # bfloat16 must be supported as of FA2 2.5.7. (Turing GPUs not supported)
551
+ orig_dtype = qkv.dtype
552
+ qkv = qkv.to(torch.bfloat16)
553
+
554
+ attn = flash_attn_varlen_qkvpacked_func(
555
+ qkv,
556
+ cu_seqlens=cu_seqlens,
557
+ max_seqlen=max_seqlen,
558
+ dropout_p=self.p_dropout,
559
+ deterministic=self.deterministic_fa2,
560
+ window_size=self.sliding_window,
561
+ causal=self.is_casual
562
+ )
563
+ attn = attn.to(orig_dtype) # type: ignore
564
+ else:
565
+ attn = flash_attn_varlen_qkvpacked_func(
566
+ qkv,
567
+ cu_seqlens=cu_seqlens,
568
+ max_seqlen=max_seqlen,
569
+ dropout_p=self.p_dropout,
570
+ deterministic=self.deterministic_fa2,
571
+ window_size=self.sliding_window,
572
+ causal=self.is_casual
573
+ )
574
+ attn = attn.view(bs, dim)
575
+ else:
576
+ assert not self.is_casual, "Causal mask not implemented here yet"
577
+ qkv = bert_padding.pad_input(qkv, indices, cu_seqlens.shape[0] - 1, max_seqlen) # batch, max_seqlen, thd
578
+ unpad_bs, seqlen, _ = qkv.shape
579
+
580
+ qkv = qkv.view(unpad_bs, -1, 3, self.num_attention_heads, self.attn_head_size)
581
+ q, k, v = qkv.transpose(3, 1).unbind(dim=2) # b h s d
582
+ attn = F.scaled_dot_product_attention(
583
+ q,
584
+ k,
585
+ v,
586
+ dropout_p=self.p_dropout,
587
+ attn_mask=attn_mask[:, None, None, :seqlen].to(torch.bool).expand(unpad_bs, 1, seqlen, seqlen)
588
+ if self.use_sdpa_attn_mask
589
+ else None,
590
+ )
591
+ attn = attn.transpose(1, 2).view(unpad_bs, -1, dim) # b s h d
592
+ attn = bert_padding.unpad_input_only(attn, torch.squeeze(attn_mask) == 1)
593
+
594
+ return self.out_drop(self.Wo(attn.view(bs, dim)))
595
+
596
+
597
+ class FlexBertPaddedAttention(FlexBertAttentionBase):
598
+ """Performs multi-headed self attention on a batch of padded sequences.
599
+
600
+ This module supports two attention implementations:
601
+ 1. Flash Attention 2 (if installed), which improves throughput.
602
+ 2. PyTorch's scaled_dot_product_attention.
603
+
604
+ See `forward` method for additional detail.
605
+ """
606
+
607
+ def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
608
+ super().__init__(config=config, layer_id=layer_id)
609
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
610
+ raise ValueError(
611
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
612
+ f"heads ({config.num_attention_heads})"
613
+ )
614
+
615
+ self.is_casual = config.casual_mask
616
+ self.num_attention_heads = config.num_attention_heads
617
+ self.attn_head_size = int(config.hidden_size / config.num_attention_heads)
618
+ self.all_head_size = self.num_attention_heads * self.attn_head_size
619
+ self.p_dropout = config.attention_probs_dropout_prob
620
+ self.Wqkv = nn.Linear(config.hidden_size, 3 * self.all_head_size, bias=config.attn_qkv_bias)
621
+ self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attn_out_bias)
622
+ self.out_drop = (
623
+ nn.Dropout(config.attn_out_dropout_prob) if config.attn_out_dropout_prob > 0.0 else nn.Identity()
624
+ )
625
+ self.use_fa2 = config.use_fa2
626
+ self.deterministic_fa2 = config.deterministic_fa2
627
+ self.use_sdpa_attn_mask = config.use_sdpa_attn_mask
628
+
629
+ if config.global_attn_every_n_layers > 0:
630
+ if config.sliding_window == -1:
631
+ raise ValueError("`global_attn_every_n_layers` requires `sliding_window` to be set")
632
+ if layer_id % config.global_attn_every_n_layers != 0:
633
+ self.sliding_window = (config.sliding_window // 2, config.sliding_window // 2)
634
+ else:
635
+ self.sliding_window = (-1, -1)
636
+ else:
637
+ self.sliding_window = (config.sliding_window // 2, config.sliding_window // 2)
638
+
639
+ if not IMPL_USE_FLASH2 and self.use_fa2:
640
+ self.use_fa2 = False
641
+ if self.use_fa2 and self.use_sdpa_attn_mask:
642
+ logger.warn_once(
643
+ "Flash Attention 2 does not support attention masks. Use unpadded attention "
644
+ "for the equivalent functionality of masking out padding tokens."
645
+ )
646
+ if not self.use_fa2 and self.sliding_window[0] > 0:
647
+ raise ValueError("Sliding window is not implemented for the PyTorch SDPA path. Use the FA2 backend.")
648
+
649
+ def _init_weights(self, reset_params: bool = False):
650
+ init_weights(
651
+ self.config,
652
+ self.Wqkv,
653
+ layer_dim=self.config.hidden_size,
654
+ layer_id=None,
655
+ type_of_module=ModuleType.in_module,
656
+ )
657
+ init_weights(
658
+ self.config,
659
+ self.Wo,
660
+ layer_dim=self.config.hidden_size,
661
+ layer_id=self.layer_id,
662
+ type_of_module=ModuleType.out_module,
663
+ )
664
+
665
+ def forward(
666
+ self,
667
+ hidden_states: torch.Tensor,
668
+ attn_mask: Optional[torch.Tensor] = None,
669
+ ) -> torch.Tensor:
670
+ """Perform self-attention.
671
+
672
+ There are two attention implementations supported:
673
+ Flash Attention 2 and PyTorch's scaled_dot_product_attention.
674
+
675
+ Args:
676
+ hidden_states: (batch, seqlen, dim)
677
+ attn_mask: (batch, seqlen)
678
+
679
+ Returns:
680
+ attention: (batch, seqlen, dim)
681
+ """
682
+ bs, seqlen, dim = hidden_states.shape
683
+ qkv = self.Wqkv(hidden_states)
684
+
685
+ if self.use_fa2:
686
+ qkv = qkv.view(bs, seqlen, 3, self.num_attention_heads, self.attn_head_size)
687
+
688
+ convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16)
689
+ if convert_dtype:
690
+ # FA2 implementation only supports fp16 and bf16. If FA2 is supported,
691
+ # bfloat16 must be supported as of FA2 2.5.7. (Turing GPUs not supported)
692
+ orig_dtype = qkv.dtype
693
+ qkv = qkv.to(torch.bfloat16)
694
+
695
+ attn = flash_attn_qkvpacked_func(
696
+ qkv,
697
+ dropout_p=self.p_dropout,
698
+ deterministic=self.deterministic_fa2,
699
+ window_size=self.sliding_window,
700
+ causal=self.is_casual
701
+ )
702
+ attn = attn.to(orig_dtype) # type: ignore
703
+ else:
704
+ attn = flash_attn_qkvpacked_func(
705
+ qkv,
706
+ dropout_p=self.p_dropout,
707
+ deterministic=self.deterministic_fa2,
708
+ window_size=self.sliding_window,
709
+ causal=self.is_casual
710
+ )
711
+ else:
712
+ assert not self.is_casual, "Causal mask not implemented here yet"
713
+ qkv = qkv.view(bs, seqlen, 3, self.num_attention_heads, self.attn_head_size)
714
+
715
+ q, k, v = qkv.transpose(3, 1).unbind(dim=2)
716
+ attn = F.scaled_dot_product_attention(
717
+ q,
718
+ k,
719
+ v,
720
+ dropout_p=self.p_dropout,
721
+ attn_mask=attn_mask[:, None, None, :seqlen].to(torch.bool).expand(bs, 1, seqlen, seqlen)
722
+ if self.use_sdpa_attn_mask
723
+ else None,
724
+ ).transpose(1, 2)
725
+
726
+ attn = attn.view(bs, seqlen, dim)
727
+ return self.out_drop(self.Wo(attn))
728
+
729
+
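In the SDPA fallback paths above, `use_sdpa_attn_mask` controls whether the padding mask is handed to `F.scaled_dot_product_attention` as a broadcast boolean mask (True = attend). A hedged sketch of just that mask expansion, with toy shapes:

```
# Sketch of the attn_mask expansion used in the SDPA branches above (shapes only).
import torch
import torch.nn.functional as F

bs, seqlen, nheads, headdim = 2, 4, 2, 8
attn_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])   # (batch, seqlen), 1 = real token

# (batch, seqlen) -> (batch, 1, seqlen, seqlen) boolean mask, True = take part in attention
sdpa_mask = attn_mask[:, None, None, :seqlen].to(torch.bool).expand(bs, 1, seqlen, seqlen)

q = k = v = torch.randn(bs, nheads, seqlen, headdim)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=sdpa_mask)  # (bs, nheads, seqlen, headdim)
```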
730
+ class FlexBertUnpadRopeAttention(FlexBertAttentionBase):
731
+ """Performs multi-headed self attention on a batch of unpadded sequences.
732
+
733
+ If Flash Attention 2 is installed, this module uses Flash Attention to improve throughput.
734
+ If Flash Attention 2 is not installed, the implementation will use PyTorch's SDPA kernel,
735
+ which requires padding and unpadding inputs, adding some overhead.
736
+
737
+ See `forward` method for additional details.
738
+ """
739
+
740
+ def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
741
+ super().__init__(config=config, layer_id=layer_id)
742
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
743
+ raise ValueError(
744
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
745
+ f"heads ({config.num_attention_heads})"
746
+ )
747
+
748
+ self.is_casual = config.casual_mask
749
+ self.num_attention_heads = config.num_attention_heads
750
+ self.attn_head_size = int(config.hidden_size / config.num_attention_heads)
751
+ self.all_head_size = self.num_attention_heads * self.attn_head_size
752
+ self.p_dropout = config.attention_probs_dropout_prob
753
+ self.Wqkv = nn.Linear(config.hidden_size, 3 * self.all_head_size, bias=config.attn_qkv_bias)
754
+ self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attn_out_bias)
755
+ self.out_drop = (
756
+ nn.Dropout(config.attn_out_dropout_prob) if config.attn_out_dropout_prob > 0.0 else nn.Identity()
757
+ )
758
+
759
+ if config.global_attn_every_n_layers > 0:
760
+ if config.sliding_window == -1:
761
+ raise ValueError("`global_attn_every_n_layers` requires `sliding_window` to be set")
762
+ if layer_id % config.global_attn_every_n_layers != 0:
763
+ self.sliding_window = (config.sliding_window // 2, config.sliding_window // 2)
764
+ else:
765
+ self.sliding_window = (-1, -1)
766
+ else:
767
+ self.sliding_window = (config.sliding_window // 2, config.sliding_window // 2)
768
+
769
+ if config.rotary_emb_dim is None:
770
+ config.rotary_emb_dim = self.attn_head_size
771
+
772
+ rotary_base = config.rotary_emb_base
773
+ rotary_dim = config.rotary_emb_dim
774
+ if self.sliding_window != (-1, -1):
775
+ if config.local_attn_rotary_emb_base != -1:
776
+ rotary_base = config.local_attn_rotary_emb_base
777
+ if config.local_attn_rotary_emb_dim is not None:
778
+ rotary_dim = config.local_attn_rotary_emb_dim
779
+
780
+ assert UnpaddedRotaryEmbedding is not None, "rotary_emb is not installed"
781
+ self.rotary_emb = UnpaddedRotaryEmbedding(
782
+ dim=rotary_dim,
783
+ base=rotary_base,
784
+ scale_base=config.rotary_emb_scale_base, # If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
785
+ interleaved=config.rotary_emb_interleaved,
786
+ )
787
+
788
+ self.use_fa2 = config.use_fa2
789
+ # flash attention 3 only supports global attention
790
+ self.use_fa3 = config.use_fa2 and self.sliding_window == (-1, -1) and IMPL_USE_FLASH3
791
+ self.deterministic_fa2 = config.deterministic_fa2
792
+ self.use_sdpa_attn_mask = config.use_sdpa_attn_mask
793
+
794
+ # Warn if defaulting to pytorch because of import issues
795
+ if not IMPL_USE_FLASH2 and self.use_fa2:
796
+ logger.warn_once(
797
+ "Unable to import flash_attn; defaulting FlexBERT attention implementation to PyTorch's"
798
+ " SDPA kernel. This requires padding and unpadding inputs, which will add some overhead."
799
+ )
800
+ self.use_fa2 = False
801
+ if not self.use_fa2:
802
+ if not self.use_sdpa_attn_mask:
803
+ logger.warn_once(
804
+ "SDPA attention is being used without an attention mask. Including padding in the "
805
+ " attention calculation may cause differences from the Flash Attention implementation."
806
+ )
807
+ else:
808
+ logger.warn_once(
809
+ "SDPA attention with an attention mask doesn't use the Flash Attention kernel and will"
810
+ " use more memory during the backward pass. Use the FA2 backend for linear memory scaling"
811
+ " with sequence length."
812
+ )
813
+ if self.sliding_window[0] > 0:
814
+ raise ValueError("Sliding window is not implemented for the PyTorch SDPA path. Use the FA2 backend.")
815
+
816
+ def _init_weights(self, reset_params: bool = False):
817
+ init_weights(
818
+ self.config,
819
+ self.Wqkv,
820
+ layer_dim=self.config.hidden_size,
821
+ layer_id=None,
822
+ type_of_module=ModuleType.in_module,
823
+ )
824
+ init_weights(
825
+ self.config,
826
+ self.Wo,
827
+ layer_dim=self.config.hidden_size,
828
+ layer_id=self.layer_id,
829
+ type_of_module=ModuleType.out_module,
830
+ )
831
+
832
+ def forward(
833
+ self,
834
+ hidden_states: torch.Tensor,
835
+ cu_seqlens: torch.Tensor,
836
+ max_seqlen: int,
837
+ indices: torch.Tensor,
838
+ attn_mask: torch.Tensor,
839
+ ) -> torch.Tensor:
840
+ """Perform self-attention.
841
+
842
+ There are two attention implementations supported: PyTorch's SDPA attention and Flash Attention 2.
843
+
844
+ The arguments are unpadded. The SDPA implementation of attention requires padded arguments while the
845
+ Flash Attention implementation does not. If using SDPA we first call `pad_input`. Once we compute
846
+ attention, we re-unpad our outputs for the other layers. The pad/unpad operations add overhead, but not
847
+ sending pad tokens through the feed-forward layers saves compute.
848
+
849
+ Args:
850
+ hidden_states: (total_nnz, dim)
851
+ cu_seqlens: (batch + 1,)
852
+ max_seqlen: int
853
+ indices: (total_nnz,)
854
+ attn_mask: (batch, max_seqlen)
855
+
856
+ Returns:
857
+ attention: (total_nnz, dim)
858
+ """
859
+ bs, dim = hidden_states.shape
860
+ qkv = self.Wqkv(hidden_states)
861
+
862
+ # only needed for inference when we have KV cache
863
+ seqlen_offset = 0
864
+
865
+ # (total_seqlen, 3, nheads, headdim)
866
+ qkv = qkv.view(-1, 3, self.num_attention_heads, self.attn_head_size)
867
+ qkv = self.rotary_emb(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, seqlen_offset=seqlen_offset)
868
+
869
+ if self.use_fa3:
870
+ convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16)
871
+ if convert_dtype:
872
+ # FA2 implementation only supports fp16 and bf16. If FA2 is supported,
873
+ # bfloat16 must be supported as of FA2 2.5.7. (Turing GPUs not supported)
874
+ orig_dtype = qkv.dtype
875
+ qkv = qkv.to(torch.bfloat16)
876
+ q, k, v = qkv.view(-1, 3, self.num_attention_heads, self.attn_head_size).unbind(dim=1)
877
+
878
+ attn, _ = flash_attn_varlen_func(
879
+ q=q,
880
+ k=k,
881
+ v=v,
882
+ cu_seqlens_q=cu_seqlens,
883
+ cu_seqlens_k=cu_seqlens,
884
+ max_seqlen_q=max_seqlen,
885
+ max_seqlen_k=max_seqlen,
886
+ deterministic=self.deterministic_fa2,
887
+ causal=self.is_casual,
888
+ )
889
+ attn = attn.to(orig_dtype) # type: ignore
890
+ else:
891
+ q, k, v = qkv.view(-1, 3, self.num_attention_heads, self.attn_head_size).unbind(dim=1)
892
+ attn, _ = flash_attn_varlen_func(
893
+ q=q,
894
+ k=k,
895
+ v=v,
896
+ cu_seqlens_q=cu_seqlens,
897
+ cu_seqlens_k=cu_seqlens,
898
+ max_seqlen_q=max_seqlen,
899
+ max_seqlen_k=max_seqlen,
900
+ deterministic=self.deterministic_fa2,
901
+ causal=self.is_casual,
902
+ )
903
+ attn = attn.view(bs, dim)
904
+ elif self.use_fa2:
905
+ convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16)
906
+ if convert_dtype:
907
+ # FA2 implementation only supports fp16 and bf16. If FA2 is supported,
908
+ # bfloat16 must be supported as of FA2 2.5.7. (Turing GPUs not supported)
909
+ orig_dtype = qkv.dtype
910
+ qkv = qkv.to(torch.bfloat16)
911
+
912
+ attn = flash_attn_varlen_qkvpacked_func(
913
+ qkv,
914
+ cu_seqlens=cu_seqlens,
915
+ max_seqlen=max_seqlen,
916
+ dropout_p=self.p_dropout,
917
+ deterministic=self.deterministic_fa2,
918
+ window_size=self.sliding_window,
919
+ causal=self.is_casual,
920
+ )
921
+ attn = attn.to(orig_dtype) # type: ignore
922
+ else:
923
+ attn = flash_attn_varlen_qkvpacked_func(
924
+ qkv,
925
+ cu_seqlens=cu_seqlens,
926
+ max_seqlen=max_seqlen,
927
+ dropout_p=self.p_dropout,
928
+ deterministic=self.deterministic_fa2,
929
+ window_size=self.sliding_window,
930
+ causal=self.is_casual,
931
+ )
932
+ attn = attn.view(bs, dim)
933
+ else:
934
+ assert not self.is_casual, "Causal mask not implemented here yet"
935
+ qkv = bert_padding.pad_input(
936
+ qkv, indices, cu_seqlens.shape[0] - 1, attn_mask.shape[-1]
937
+ ) # batch, max_seqlen, thd
938
+ unpad_bs, seqlen, *_ = qkv.shape
939
+
940
+ q, k, v = qkv.transpose(3, 1).unbind(dim=2) # b h s d
941
+ attn = F.scaled_dot_product_attention(
942
+ q,
943
+ k,
944
+ v,
945
+ dropout_p=self.p_dropout,
946
+ attn_mask=attn_mask[:, None, None, :seqlen].to(torch.bool).expand(unpad_bs, 1, seqlen, seqlen)
947
+ if self.use_sdpa_attn_mask
948
+ else None,
949
+ )
950
+ attn = attn.transpose(1, 2).view(unpad_bs, -1, dim) # b s h d
951
+ attn = bert_padding.unpad_input_only(attn, torch.squeeze(attn_mask) == 1)
952
+
953
+ return self.out_drop(self.Wo(attn))
954
+
955
+
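The Rope attention classes above delegate rotary position embedding to `rotary.py` and flash_attn's `RotaryEmbedding`. As a conceptual reference only (not the implementation used here), the core rotation applied along the query/key head dimension looks roughly like this:

```
# Minimal, non-interleaved RoPE sketch for intuition; the real code uses the
# (Unpadded)RotaryEmbedding classes imported at the top of this file.
import torch

def rope_rotate(x: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    # x: (seqlen, nheads, headdim) with an even headdim
    seqlen, _, headdim = x.shape
    half = headdim // 2
    inv_freq = 1.0 / (base ** (torch.arange(0, half, dtype=torch.float32) / half))
    angles = torch.arange(seqlen, dtype=torch.float32)[:, None] * inv_freq[None, :]
    cos, sin = angles.cos()[:, None, :], angles.sin()[:, None, :]
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

q = torch.randn(16, 12, 64)
q_rot = rope_rotate(q)  # same shape; token positions are now encoded in the rotation
```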
956
+ class FlexBertPaddedRopeAttention(FlexBertAttentionBase):
957
+ """Performs multi-headed self attention on a batch of padded sequences.
958
+
959
+ This module supports two attention implementations:
960
+ 1. Flash Attention 2 (if installed), which improves throughput.
961
+ 2. PyTorch's scaled_dot_product_attention.
962
+
963
+ See `forward` method for additional details.
964
+ """
965
+
966
+ def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
967
+ super().__init__(config=config, layer_id=layer_id)
968
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
969
+ raise ValueError(
970
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
971
+ f"heads ({config.num_attention_heads})"
972
+ )
973
+
974
+ self.is_casual = config.casual_mask
975
+ self.num_attention_heads = config.num_attention_heads
976
+ self.attn_head_size = int(config.hidden_size / config.num_attention_heads)
977
+ self.all_head_size = self.num_attention_heads * self.attn_head_size
978
+ self.p_dropout = config.attention_probs_dropout_prob
979
+ self.Wqkv = nn.Linear(config.hidden_size, 3 * self.all_head_size, bias=config.attn_qkv_bias)
980
+ self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attn_out_bias)
981
+ self.out_drop = (
982
+ nn.Dropout(config.attn_out_dropout_prob) if config.attn_out_dropout_prob > 0.0 else nn.Identity()
983
+ )
984
+
985
+ self.use_fa2 = config.use_fa2
986
+ self.deterministic_fa2 = config.deterministic_fa2
987
+ self.use_sdpa_attn_mask = config.use_sdpa_attn_mask
988
+
989
+ if config.global_attn_every_n_layers > 0:
990
+ if config.sliding_window == -1:
991
+ raise ValueError("`global_attn_every_n_layers` requires `sliding_window` to be set")
992
+ if layer_id % config.global_attn_every_n_layers != 0:
993
+ self.sliding_window = (config.sliding_window // 2, config.sliding_window // 2)
994
+ else:
995
+ self.sliding_window = (-1, -1)
996
+ else:
997
+ self.sliding_window = (config.sliding_window // 2, config.sliding_window // 2)
998
+
999
+ if config.rotary_emb_dim is None:
1000
+ config.rotary_emb_dim = self.attn_head_size
1001
+
1002
+ rotary_base = config.rotary_emb_base
1003
+ rotary_dim = config.rotary_emb_dim
1004
+ if self.sliding_window != (-1, -1):
1005
+ if config.local_attn_rotary_emb_base != -1:
1006
+ rotary_base = config.local_attn_rotary_emb_base
1007
+ if config.local_attn_rotary_emb_dim is not None:
1008
+ rotary_dim = config.local_attn_rotary_emb_dim
1009
+
1010
+ assert RotaryEmbedding is not None, "rotary_emb is not installed"
1011
+ self.rotary_emb = RotaryEmbedding(
1012
+ dim=rotary_dim,
1013
+ base=rotary_base,
1014
+ scale_base=config.rotary_emb_scale_base, # If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
1015
+ interleaved=config.rotary_emb_interleaved,
1016
+ )
1017
+
1018
+ if not IMPL_USE_FLASH2 and self.use_fa2:
1019
+ self.use_fa2 = False
1020
+ if self.use_fa2 and self.use_sdpa_attn_mask:
1021
+ logger.warn_once(
1022
+ "Flash Attention 2 does not support attention masks. Use unpadded attention "
1023
+ "for the equivalent functionality of masking out padding tokens."
1024
+ )
1025
+ if not self.use_fa2 and self.sliding_window[0] > 0:
1026
+ raise ValueError("Sliding window is not implemented for the PyTorch SDPA path. Use the FA2 backend.")
1027
+
1028
+ def _init_weights(self, reset_params: bool = False):
1029
+ init_weights(
1030
+ self.config,
1031
+ self.Wqkv,
1032
+ layer_dim=self.config.hidden_size,
1033
+ layer_id=None,
1034
+ type_of_module=ModuleType.in_module,
1035
+ )
1036
+ init_weights(
1037
+ self.config,
1038
+ self.Wo,
1039
+ layer_dim=self.config.hidden_size,
1040
+ layer_id=self.layer_id,
1041
+ type_of_module=ModuleType.out_module,
1042
+ )
1043
+
1044
+ def forward(
1045
+ self,
1046
+ hidden_states: torch.Tensor,
1047
+ attn_mask: Optional[torch.Tensor] = None,
1048
+ ) -> torch.Tensor:
1049
+ """Perform self-attention.
1050
+
1051
+ There are two attention implementations supported:
1052
+ Flash Attention 2 and PyTorch's scaled_dot_product_attention.
1053
+
1054
+ Args:
1055
+ hidden_states: (batch, seqlen, dim)
1056
+ attn_mask: (batch, seqlen)
1057
+
1058
+ Returns:
1059
+ attention: (batch, seqlen, dim)
1060
+ """
1061
+ bs, seqlen, dim = hidden_states.shape
1062
+ qkv = self.Wqkv(hidden_states)
1063
+
1064
+ seqlen_offset = 0
1065
+
1066
+ # Reshape to (batch, seqlen, 3, nheads, headdim)
1067
+ qkv = qkv.view(bs, seqlen, 3, self.num_attention_heads, self.attn_head_size)
1068
+
1069
+ if IMPL_USE_FLASH2:
1070
+ # Apply RoPE
1071
+ qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset, max_seqlen=None)
1072
+
1073
+ convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16)
1074
+ if convert_dtype:
1075
+ # FA2 implementation only supports fp16 and bf16. If FA2 is supported,
1076
+ # bfloat16 must be supported as of FA2 2.5.7. (Turing GPUs not supported)
1077
+ orig_dtype = qkv.dtype
1078
+ qkv = qkv.to(torch.bfloat16)
1079
+
1080
+ attn = flash_attn_qkvpacked_func(
1081
+ qkv,
1082
+ dropout_p=self.p_dropout,
1083
+ deterministic=self.deterministic_fa2,
1084
+ window_size=self.sliding_window,
1085
+ causal=self.is_casual,
1086
+ )
1087
+ attn = attn.to(orig_dtype) # type: ignore
1088
+ else:
1089
+ attn = flash_attn_qkvpacked_func(
1090
+ qkv,
1091
+ dropout_p=self.p_dropout,
1092
+ deterministic=self.deterministic_fa2,
1093
+ window_size=self.sliding_window,
1094
+ causal=self.is_casual
1095
+ )
1096
+ else:
1097
+ assert not self.is_casual, "Causal mask not implemented here yet"
1098
+ qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset, max_seqlen=None)
1099
+ q, k, v = qkv.transpose(3, 1).unbind(dim=2)
1100
+ attn = F.scaled_dot_product_attention(
1101
+ q,
1102
+ k,
1103
+ v,
1104
+ dropout_p=self.p_dropout,
1105
+ attn_mask=attn_mask[:, None, None, :seqlen].to(torch.bool).expand(bs, 1, seqlen, seqlen)
1106
+ if self.use_sdpa_attn_mask
1107
+ else None,
1108
+ ).transpose(1, 2)
1109
+
1110
+ attn = attn.view(bs, seqlen, dim)
1111
+ return self.out_drop(self.Wo(attn))
1112
+
1113
+
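Each FA2 branch above wraps the kernel call in the same dtype guard: FA2 only accepts fp16/bf16, so fp32 activations are cast down for the kernel and restored afterwards. Distilled to its shape (the helper name and the lambda stand-in for the kernel are illustrative only):

```
# The dtype guard repeated in the FA2 branches above, shown in isolation.
import torch

def run_in_bf16(qkv: torch.Tensor, kernel) -> torch.Tensor:
    convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16)
    if convert_dtype:
        orig_dtype = qkv.dtype
        out = kernel(qkv.to(torch.bfloat16))   # FA2 kernels require fp16/bf16
        return out.to(orig_dtype)              # restore the caller's dtype
    return kernel(qkv)

out = run_in_bf16(torch.randn(8, 3, 12, 64), lambda t: t * 1.0)  # stand-in for the FA2 call
```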
1114
+ class FlexBertUnpadRopeParallelAttention(FlexBertAttentionBase):
1115
+ """Performs multi-headed self attention on a batch of unpadded sequences.
1116
+
1117
+ If Flash Attention 2 is installed, this module uses Flash Attention to improve throughput.
1118
+ If Flash Attention 2 is not installed, the implementation will use PyTorch's SDPA kernel,
1119
+ which requires padding and unpadding inputs, adding some overhead.
1120
+
1121
+ See `forward` method for additional details.
1122
+ """
1123
+
1124
+ def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
1125
+ super().__init__(config=config, layer_id=layer_id)
1126
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
1127
+ raise ValueError(
1128
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
1129
+ f"heads ({config.num_attention_heads})"
1130
+ )
1131
+
1132
+ self.is_casual = config.casual_mask
1133
+ self.num_attention_heads = config.num_attention_heads
1134
+ self.attn_head_size = int(config.hidden_size / config.num_attention_heads)
1135
+ self.hidden_size = config.hidden_size
1136
+ self.p_dropout = config.attention_probs_dropout_prob
1137
+ self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attn_out_bias)
1138
+ self.out_drop = (
1139
+ nn.Dropout(config.attn_out_dropout_prob) if config.attn_out_dropout_prob > 0.0 else nn.Identity()
1140
+ )
1141
+
1142
+ if config.global_attn_every_n_layers > 0:
1143
+ if config.sliding_window == -1:
1144
+ raise ValueError("`global_attn_every_n_layers` requires `sliding_window` to be set")
1145
+ if layer_id % config.global_attn_every_n_layers != 0:
1146
+ self.sliding_window = (config.sliding_window // 2, config.sliding_window // 2)
1147
+ else:
1148
+ self.sliding_window = (-1, -1)
1149
+ else:
1150
+ self.sliding_window = (config.sliding_window // 2, config.sliding_window // 2)
1151
+
1152
+ if config.rotary_emb_dim is None:
1153
+ config.rotary_emb_dim = self.attn_head_size
1154
+
1155
+ rotary_base = config.rotary_emb_base
1156
+ rotary_dim = config.rotary_emb_dim
1157
+ if self.sliding_window != (-1, -1):
1158
+ if config.local_attn_rotary_emb_base != -1:
1159
+ rotary_base = config.local_attn_rotary_emb_base
1160
+ if config.local_attn_rotary_emb_dim is not None:
1161
+ rotary_dim = config.local_attn_rotary_emb_dim
1162
+
1163
+ assert UnpaddedRotaryEmbedding is not None, "rotary_emb is not installed"
1164
+ self.rotary_emb = UnpaddedRotaryEmbedding(
1165
+ dim=rotary_dim,
1166
+ base=rotary_base,
1167
+ scale_base=config.rotary_emb_scale_base, # If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
1168
+ interleaved=config.rotary_emb_interleaved,
1169
+ )
1170
+
1171
+ self.use_fa2 = config.use_fa2
1172
+ self.deterministic_fa2 = config.deterministic_fa2
1173
+ self.use_sdpa_attn_mask = config.use_sdpa_attn_mask
1174
+
1175
+ # Warn if defaulting to pytorch because of import issues
1176
+ if not IMPL_USE_FLASH2 and self.use_fa2:
1177
+ logger.warn_once(
1178
+ "Unable to import flash_attn; defaulting FlexBERT attention implementation to PyTorch's"
1179
+ " SDPA kernel. This requires padding and unpadding inputs, which will add some overhead."
1180
+ )
1181
+ self.use_fa2 = False
1182
+ if not self.use_fa2:
1183
+ if not self.use_sdpa_attn_mask:
1184
+ logger.warn_once(
1185
+ "SDPA attention is being used without an attention mask. Including padding in the "
1186
+ " attention calculation may cause differences from the Flash Attention implementation."
1187
+ )
1188
+ else:
1189
+ logger.warn_once(
1190
+ "SDPA attention with an attention mask doesn't use the Flash Attention kernel and will"
1191
+ " use more memory during the backward pass. Use the FA2 backend for linear memory scaling"
1192
+ " with sequence length."
1193
+ )
1194
+ if self.sliding_window[0] > 0:
1195
+ raise ValueError("Sliding window is not implemented for the PyTorch SDPA path. Use the FA2 backend.")
1196
+
1197
+ def _init_weights(self, reset_params: bool = False):
1198
+ init_weights(
1199
+ self.config,
1200
+ self.Wo,
1201
+ layer_dim=self.config.hidden_size,
1202
+ layer_id=self.layer_id,
1203
+ type_of_module=ModuleType.out_module,
1204
+ )
1205
+
1206
+ def forward(
1207
+ self,
1208
+ qkv: torch.Tensor,
1209
+ cu_seqlens: torch.Tensor,
1210
+ max_seqlen: int,
1211
+ indices: torch.Tensor,
1212
+ attn_mask: torch.Tensor,
1213
+ ) -> torch.Tensor:
1214
+ """Perform self-attention.
1215
+
1216
+ There are two attention implementations supported: PyTorch's SDPA attention and Flash Attention 2.
1217
+
1218
+ The arguments are unpadded. The SDPA implementation of attention requires padded arguments while the
1219
+ Flash Attention implementation does not. If using SDPA we first call `pad_input`. Once we compute
1220
+ attention, we re-unpad our outputs for the other layers. The pad/unpad operations add overhead, but not
1221
+ sending pad tokens through the feed-forward layers saves compute.
1222
+
1223
+ Args:
1224
+ qkv: (total_nnz, 3 * dim)
1225
+ cu_seqlens: (batch + 1,)
1226
+ max_seqlen: int
1227
+ indices: (total_nnz,)
1228
+ attn_mask: (batch, max_seqlen)
1229
+
1230
+ Returns:
1231
+ attention: (total_nnz, dim)
1232
+ """
1233
+ bs = qkv.shape[0]
1234
+ dim = self.hidden_size
1235
+
1236
+ # only needed for inference when we have KV cache
1237
+ seqlen_offset = 0
1238
+
1239
+ # (total_seqlen, 3, nheads, headdim)
1240
+ qkv = qkv.view(-1, 3, self.num_attention_heads, self.attn_head_size)
1241
+ qkv = self.rotary_emb(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, seqlen_offset=seqlen_offset)
1242
+
1243
+ if self.use_fa2:
1244
+ convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16)
1245
+ if convert_dtype:
1246
+ # FA2 implementation only supports fp16 and bf16. If FA2 is supported,
1247
+ # bfloat16 must be supported as of FA2 2.5.7. (Turing GPUs not supported)
1248
+ orig_dtype = qkv.dtype
1249
+ qkv = qkv.to(torch.bfloat16)
1250
+
1251
+ attn = flash_attn_varlen_qkvpacked_func(
1252
+ qkv,
1253
+ cu_seqlens=cu_seqlens,
1254
+ max_seqlen=max_seqlen,
1255
+ dropout_p=self.p_dropout,
1256
+ deterministic=self.deterministic_fa2,
1257
+ window_size=self.sliding_window,
1258
+ causal=self.is_casual,
1259
+ )
1260
+ attn = attn.to(orig_dtype) # type: ignore
1261
+ else:
1262
+ attn = flash_attn_varlen_qkvpacked_func(
1263
+ qkv,
1264
+ cu_seqlens=cu_seqlens,
1265
+ max_seqlen=max_seqlen,
1266
+ dropout_p=self.p_dropout,
1267
+ deterministic=self.deterministic_fa2,
1268
+ window_size=self.sliding_window,
1269
+ causal=self.is_casual,
1270
+ )
1271
+ attn = attn.view(bs, dim)
1272
+ else:
1273
+ assert not self.is_casual, "Causal mask is not yet implemented for the SDPA path"
1274
+ qkv = bert_padding.pad_input(
1275
+ qkv, indices, cu_seqlens.shape[0] - 1, attn_mask.shape[-1]
1276
+ ) # batch, max_seqlen, thd
1277
+ unpad_bs, seqlen, *_ = qkv.shape
1278
+
1279
+ q, k, v = qkv.transpose(3, 1).unbind(dim=2) # b h s d
1280
+ attn = F.scaled_dot_product_attention(
1281
+ q,
1282
+ k,
1283
+ v,
1284
+ dropout_p=self.p_dropout,
1285
+ attn_mask=attn_mask[:, None, None, :seqlen].to(torch.bool).expand(unpad_bs, 1, seqlen, seqlen)
1286
+ if self.use_sdpa_attn_mask
1287
+ else None,
1288
+ )
1289
+ attn = attn.transpose(1, 2).view(unpad_bs, -1, dim) # b s h d
1290
+ attn = bert_padding.unpad_input_only(attn, torch.squeeze(attn_mask) == 1)
1291
+
1292
+ return self.out_drop(self.Wo(attn))
1293
+
1294
+
1295
+ class FlexBertPaddedRopeParallelAttention(FlexBertAttentionBase):
1296
+ """Performs multi-headed self attention on a batch of padded sequences.
1297
+
1298
+ This module supports two attention implementations:
1299
+ 1. Flash Attention 2 (if installed), which improves throughput.
1300
+ 2. PyTorch's scaled_dot_product_attention.
1301
+
1302
+ See `forward` method for additional details.
1303
+ """
1304
+
1305
+ def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
1306
+ super().__init__(config=config, layer_id=layer_id)
1307
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
1308
+ raise ValueError(
1309
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
1310
+ f"heads ({config.num_attention_heads})"
1311
+ )
1312
+
1313
+ self.is_casual = config.casual_mask
1314
+ self.num_attention_heads = config.num_attention_heads
1315
+ self.attn_head_size = int(config.hidden_size / config.num_attention_heads)
1316
+ self.hidden_size = config.hidden_size
1317
+ self.p_dropout = config.attention_probs_dropout_prob
1318
+ self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attn_out_bias)
1319
+ self.out_drop = (
1320
+ nn.Dropout(config.attn_out_dropout_prob) if config.attn_out_dropout_prob > 0.0 else nn.Identity()
1321
+ )
1322
+
1323
+ self.use_fa2 = config.use_fa2
1324
+ self.deterministic_fa2 = config.deterministic_fa2
1325
+ self.use_sdpa_attn_mask = config.use_sdpa_attn_mask
1326
+ if not IMPL_USE_FLASH2 and self.use_fa2:
1327
+ self.use_fa2 = False
1328
+
1329
+ if config.global_attn_every_n_layers > 0:
1330
+ if config.sliding_window == -1:
1331
+ raise ValueError("global_attn_every_n_layers` requires `sliding_window` to be set")
1332
+ if layer_id % config.global_attn_every_n_layers != 0:
1333
+ self.sliding_window = (config.sliding_window // 2, config.sliding_window // 2)
1334
+ else:
1335
+ self.sliding_window = (-1, -1)
1336
+ else:
1337
+ self.sliding_window = (config.sliding_window // 2, config.sliding_window // 2)
1338
+
1339
+ if config.rotary_emb_dim is None:
1340
+ config.rotary_emb_dim = self.attn_head_size
1341
+
1342
+ rotary_base = config.rotary_emb_base
1343
+ rotary_dim = config.rotary_emb_dim
1344
+ if self.sliding_window != (-1, -1):
1345
+ if config.local_attn_rotary_emb_base != -1:
1346
+ rotary_base = config.local_attn_rotary_emb_base
1347
+ if config.local_attn_rotary_emb_dim is not None:
1348
+ rotary_dim = config.local_attn_rotary_emb_dim
1349
+
1350
+ assert RotaryEmbedding is not None, "rotary_emb is not installed"
1351
+ self.rotary_emb = RotaryEmbedding(
1352
+ dim=rotary_dim,
1353
+ base=rotary_base,
1354
+ scale_base=config.rotary_emb_scale_base, # If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
1355
+ interleaved=config.rotary_emb_interleaved,
1356
+ )
1357
+
1358
+ if not IMPL_USE_FLASH2 and self.use_fa2:
1359
+ self.use_fa2 = False
1360
+ if self.use_fa2 and self.use_sdpa_attn_mask:
1361
+ logger.warn_once(
1362
+ "Flash Attention 2 does not support attention masks. Use unpadded attention "
1363
+ "the equivalent functionality of masking out padding tokens."
1364
+ )
1365
+ if not self.use_fa2 and self.sliding_window[0] > 0:
1366
+ raise ValueError("Sliding window is not implemented for the PyTorch SDPA path. Use the FA2 backend.")
1367
+
1368
+ def _init_weights(self, reset_params: bool = False):
1369
+ init_weights(
1370
+ self.config,
1371
+ self.Wo,
1372
+ layer_dim=self.config.hidden_size,
1373
+ layer_id=self.layer_id,
1374
+ type_of_module=ModuleType.out_module,
1375
+ )
1376
+
1377
+ def forward(
1378
+ self,
1379
+ qkv: torch.Tensor,
1380
+ attn_mask: Optional[torch.Tensor] = None,
1381
+ ) -> torch.Tensor:
1382
+ """Perform self-attention.
1383
+
1384
+ There are two attention implementations supported:
1385
+ Flash Attention 2 and PyTorch's scaled_dot_product_attention.
1386
+
1387
+ Args:
1388
+ qkv: (batch, seqlen, 3 * dim)
1389
+ attn_mask: (batch, seqlen)
1390
+
1391
+ Returns:
1392
+ attention: (batch, seqlen, dim)
1393
+ """
1394
+ bs, seqlen, _ = qkv.shape
1395
+ dim = self.hidden_size
1396
+
1397
+ seqlen_offset = 0
1398
+
1399
+ # Reshape to (batch, seqlen, 3, nheads, headdim)
1400
+ qkv = qkv.view(bs, seqlen, 3, self.num_attention_heads, self.attn_head_size)
1401
+
1402
+ if self.use_fa2:
1403
+ # Apply RoPE
1404
+ qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset, max_seqlen=None)
1405
+
1406
+ convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16)
1407
+ if convert_dtype:
1408
+ # FA2 implementation only supports fp16 and bf16. If FA2 is supported,
1409
+ # bfloat16 must be supported as of FA2 2.5.7. (Turing GPUs not supported)
1410
+ orig_dtype = qkv.dtype
1411
+ qkv = qkv.to(torch.bfloat16)
1412
+
1413
+ attn = flash_attn_qkvpacked_func(
1414
+ qkv,
1415
+ dropout_p=self.p_dropout,
1416
+ deterministic=self.deterministic_fa2,
1417
+ window_size=self.sliding_window,
1418
+ causal=self.is_casual
1419
+ )
1420
+ attn = attn.to(orig_dtype) # type: ignore
1421
+ else:
1422
+ attn = flash_attn_qkvpacked_func(
1423
+ qkv,
1424
+ dropout_p=self.p_dropout,
1425
+ deterministic=self.deterministic_fa2,
1426
+ window_size=self.sliding_window,
1427
+ causal=self.is_casual
1428
+ )
1429
+ else:
1430
+ assert not self.is_casual, "Causal mask is not yet implemented for the SDPA path"
1431
+ qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset, max_seqlen=None)
1432
+ q, k, v = qkv.transpose(3, 1).unbind(dim=2)
1433
+ attn = F.scaled_dot_product_attention(
1434
+ q,
1435
+ k,
1436
+ v,
1437
+ dropout_p=self.p_dropout,
1438
+ attn_mask=attn_mask[:, None, None, :seqlen].to(torch.bool).expand(bs, 1, seqlen, seqlen)
1439
+ if self.use_sdpa_attn_mask
1440
+ else None,
1441
+ ).transpose(1, 2)
1442
+
1443
+ attn = attn.view(bs, seqlen, dim)
1444
+ return self.out_drop(self.Wo(attn))
1445
+
1446
+
1447
+ class FlexBertPaddedParallelAttention(FlexBertAttentionBase):
1448
+ """Performs multi-headed self attention on a batch of padded sequences.
1449
+
1450
+ This module supports two attention implementations:
1451
+ 1. Flash Attention 2 (if installed), which improves throughput.
1452
+ 2. PyTorch's scaled_dot_product_attention.
1453
+
1454
+ See `forward` method for additional detail.
1455
+ """
1456
+
1457
+ def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
1458
+ super().__init__(config=config, layer_id=layer_id)
1459
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
1460
+ raise ValueError(
1461
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
1462
+ f"heads ({config.num_attention_heads})"
1463
+ )
1464
+
1465
+ self.is_casual = config.casual_mask
1466
+ self.num_attention_heads = config.num_attention_heads
1467
+ self.attn_head_size = int(config.hidden_size / config.num_attention_heads)
1468
+ self.hidden_size = config.hidden_size
1469
+ self.p_dropout = config.attention_probs_dropout_prob
1470
+ self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attn_out_bias)
1471
+ self.out_drop = (
1472
+ nn.Dropout(config.attn_out_dropout_prob) if config.attn_out_dropout_prob > 0.0 else nn.Identity()
1473
+ )
1474
+ self.use_fa2 = config.use_fa2
1475
+ self.deterministic_fa2 = config.deterministic_fa2
1476
+ self.use_sdpa_attn_mask = config.use_sdpa_attn_mask
1477
+
1478
+ if config.global_attn_every_n_layers > 0:
1479
+ if config.sliding_window == -1:
1480
+ raise ValueError("global_attn_every_n_layers` requires `sliding_window` to be set")
1481
+ if layer_id % config.global_attn_every_n_layers != 0:
1482
+ self.sliding_window = (config.sliding_window // 2, config.sliding_window // 2)
1483
+ else:
1484
+ self.sliding_window = (-1, -1)
1485
+ else:
1486
+ self.sliding_window = (config.sliding_window // 2, config.sliding_window // 2)
1487
+
1488
+ if not IMPL_USE_FLASH2 and self.use_fa2:
1489
+ self.use_fa2 = False
1490
+ if self.use_fa2 and self.use_sdpa_attn_mask:
1491
+ logger.warn_once(
1492
+ "Flash Attention 2 does not support attention masks. Use unpadded attention "
1493
+ "the equivalent functionality of masking out padding tokens."
1494
+ )
1495
+ if not self.use_fa2 and self.sliding_window[0] > 0:
1496
+ raise ValueError("Sliding window is not implemented for the PyTorch SDPA path. Use the FA2 backend.")
1497
+
1498
+ def _init_weights(self, reset_params: bool = False):
1499
+ init_weights(
1500
+ self.config,
1501
+ self.Wo,
1502
+ layer_dim=self.config.hidden_size,
1503
+ layer_id=self.layer_id,
1504
+ type_of_module=ModuleType.out_module,
1505
+ )
1506
+
1507
+ def forward(
1508
+ self,
1509
+ qkv: torch.Tensor,
1510
+ attn_mask: Optional[torch.Tensor] = None,
1511
+ ) -> torch.Tensor:
1512
+ """Perform self-attention.
1513
+
1514
+ There are two attention implementations supported:
1515
+ Flash Attention 2 and PyTorch's scaled_dot_product_attention.
1516
+
1517
+ Args:
1518
+ qkv: (batch, seqlen, 3 * dim)
1519
+ attn_mask: (batch, seqlen)
1520
+
1521
+ Returns:
1522
+ attention: (batch, seqlen, dim)
1523
+ """
1524
+ bs, seqlen, _ = qkv.shape
1525
+ dim = self.hidden_size
1526
+
1527
+ if self.use_fa2:
1528
+ qkv = qkv.view(bs, seqlen, 3, self.num_attention_heads, self.attn_head_size)
1529
+
1530
+ convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16)
1531
+ if convert_dtype:
1532
+ # FA2 implementation only supports fp16 and bf16. If FA2 is supported,
1533
+ # bfloat16 must be supported as of FA2 2.5.7. (Turing GPUs not supported)
1534
+ orig_dtype = qkv.dtype
1535
+ qkv = qkv.to(torch.bfloat16)
1536
+
1537
+ attn = flash_attn_qkvpacked_func(
1538
+ qkv,
1539
+ dropout_p=self.p_dropout,
1540
+ deterministic=self.deterministic_fa2,
1541
+ window_size=self.sliding_window,
1542
+ causal=self.is_casual
1543
+ )
1544
+ attn = attn.to(orig_dtype) # type: ignore
1545
+ else:
1546
+ attn = flash_attn_qkvpacked_func(
1547
+ qkv,
1548
+ dropout_p=self.p_dropout,
1549
+ deterministic=self.deterministic_fa2,
1550
+ window_size=self.sliding_window,
1551
+ causal=self.is_casual
1552
+ )
1553
+ else:
1554
+ assert not self.is_casual, "Causal attention mask is not yet implemented for the SDPA path"
1555
+ qkv = qkv.view(bs, seqlen, 3, self.num_attention_heads, self.attn_head_size)
1556
+ q, k, v = qkv.transpose(3, 1).unbind(dim=2) # b h s d
1557
+ attn = F.scaled_dot_product_attention(
1558
+ q,
1559
+ k,
1560
+ v,
1561
+ dropout_p=self.p_dropout,
1562
+ attn_mask=attn_mask[:, None, None, :seqlen].to(torch.bool).expand(bs, 1, seqlen, seqlen)
1563
+ if self.use_sdpa_attn_mask
1564
+ else None,
1565
+ ).transpose(1, 2)
1566
+
1567
+ attn = attn.view(bs, seqlen, dim)
1568
+ return self.out_drop(self.Wo(attn))
1569
+
1570
+
1571
+ ATTN2CLS = {
1572
+ "unpadded_base": FlexBertUnpadAttention,
1573
+ "padded_base": FlexBertPaddedAttention,
1574
+ "unpadded_parallel": FlexBertUnpadParallelAttention,
1575
+ "padded_parallel": FlexBertPaddedParallelAttention,
1576
+ "unpadded_rope": FlexBertUnpadRopeAttention,
1577
+ "padded_rope": FlexBertPaddedRopeAttention,
1578
+ "unpadded_rope_parallel": FlexBertUnpadRopeParallelAttention,
1579
+ "padded_rope_parallel": FlexBertPaddedRopeParallelAttention,
1580
+ }
1581
+
1582
+
1583
+ def get_attention_layer(config: FlexBertConfig, layer_id: Optional[int] = None) -> FlexBertAttentionBase:
1584
+ try:
1585
+ attention_layer = (
1586
+ config.initial_attention_layer
1587
+ if layer_id < config.num_initial_layers and getattr(config, "initial_attention_layer", None) is not None
1588
+ else config.attention_layer
1589
+ )
1590
+ return ATTN2CLS[maybe_add_padding(config, attention_layer)](config, layer_id=layer_id)
1591
+ except KeyError:
1592
+ if layer_id < config.num_initial_layers and getattr(config, "initial_attention_layer", None) is not None:
1593
+ raise ValueError(
1594
+ f"Invalid attention layer type: {config.initial_attention_layer=}, must be one of {ATTN2CLS.keys()}."
1595
+ f"{config.padding=} will be automatically prepended to `config.attention_layer` if unspecified."
1596
+ )
1597
+ else:
1598
+ raise ValueError(
1599
+ f"Invalid attention layer type: {config.attention_layer=}, must be one of {ATTN2CLS.keys()}. "
1600
+ f"{config.padding=} will be automatically prepended to `config.attention_layer` if unspecified."
1601
+ )
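Not part of the commit, but as a reader aid: the docstrings above describe the unpadded calling convention (`qkv`, `cu_seqlens`, `max_seqlen`, `indices`). The sketch below shows how those inputs are typically derived from a padded batch and its attention mask, mirroring the logic in `bert_padding.py`. The toy shapes and values are illustrative assumptions, not values from this repository.

```python
import torch
import torch.nn.functional as F

# Toy batch: two sequences of lengths 3 and 5, hidden size 8.
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])
seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)                      # tensor([3, 5])
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))  # tensor([0, 3, 8])
max_seqlen = int(seqlens.max())                                              # 5
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()  # flat positions of real tokens

# The unpadded attention modules expect a packed qkv of shape (total_nnz, 3 * hidden_size),
# where total_nnz = cu_seqlens[-1] is the number of non-padding tokens in the batch.
hidden_size = 8
qkv = torch.randn(int(cu_seqlens[-1]), 3 * hidden_size)
```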
bert_padding.py ADDED
@@ -0,0 +1,141 @@
1
+ # Copyright 2022 MosaicML Examples authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ # Adapted from https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
5
+ # Which was adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py
6
+
7
+ """Helper functions for padding and unpadding batches.
8
+
9
+ These functions are used extensively throughout the Mosaic BERT implementation
10
+ in `bert_layers.py`.
11
+ """
12
+
13
+ from typing import Tuple, cast
14
+
15
+ import torch
16
+ import torch.nn.functional as F
17
+ from einops import rearrange, repeat
18
+
19
+
20
+ class IndexFirstAxis(torch.autograd.Function):
21
+ @staticmethod
22
+ def forward(ctx, input: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
23
+ """Get just the values of `input` which are at `indices`.
24
+
25
+ Arguments:
26
+ ctx: the autograd context object
27
+ input: (b, ...) 2+ dimensional tensor
28
+ indices: (num_idx) 1D tensor
29
+ """
30
+ ctx.save_for_backward(indices)
31
+ assert input.ndim >= 2
32
+ ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:] # type: ignore
33
+ second_dim = other_shape.numel() # product of sizes of all but first dimension
34
+ # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
35
+ return torch.gather(
36
+ rearrange(input, "b ... -> b (...)"), # (b, ...) -> (b, second_dim)
37
+ 0,
38
+ repeat(indices, "z -> z d", d=second_dim), # (indices,) -> (indices, second_dim)
39
+ ).reshape(-1, *other_shape) # (num_idx, ...)
40
+
41
+ @staticmethod
42
+ def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, None]:
43
+ (indices,) = ctx.saved_tensors
44
+ assert grad_output.ndim >= 2
45
+ other_shape = grad_output.shape[1:]
46
+ grad_output = rearrange(grad_output, "b ... -> b (...)")
47
+ grad_input = torch.zeros(
48
+ [ctx.first_axis_dim, grad_output.shape[1]], device=grad_output.device, dtype=grad_output.dtype
49
+ )
50
+ # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
51
+ # grad_input[indices] = grad_output
52
+ grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output)
53
+ return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
54
+
55
+
56
+ index_first_axis = IndexFirstAxis.apply
57
+
58
+
59
+ class IndexPutFirstAxis(torch.autograd.Function):
60
+ @staticmethod
61
+ def forward(ctx, values: torch.Tensor, indices: torch.Tensor, first_axis_dim) -> torch.Tensor:
62
+ ctx.save_for_backward(indices)
63
+ assert indices.ndim == 1
64
+ assert values.ndim >= 2
65
+ output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype)
66
+ output[indices] = values
67
+ return output
68
+
69
+ @staticmethod
70
+ def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, None, None]:
71
+ (indices,) = ctx.saved_tensors
72
+ grad_values = grad_output[indices]
73
+ return grad_values, None, None
74
+
75
+
76
+ index_put_first_axis = IndexPutFirstAxis.apply
77
+
78
+
79
+ def unpad_input(
80
+ hidden_states: torch.Tensor,
81
+ attention_mask: torch.Tensor,
82
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
83
+ """Remove padding from input sequences.
84
+
85
+ Arguments:
86
+ hidden_states: (batch, seqlen, ...)
87
+ attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
88
+
89
+ Returns:
90
+ hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected by attention_mask.
91
+ indices: (total_nnz)
92
+ cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
93
+ max_seqlen_in_batch: int
94
+ """
95
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
96
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
97
+ max_seqlen_in_batch = int(seqlens_in_batch.max().item())
98
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
99
+ # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
100
+ # bool mask, then call nonzero to get the indices, then index with those. The intermediate indices tensor is `dim`
101
+ # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
102
+ # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
103
+ # so we write custom forward and backward to make it a bit faster.
104
+ hidden_states = cast(torch.Tensor, index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices))
105
+ return hidden_states, indices, cu_seqlens, max_seqlen_in_batch
106
+
107
+
108
+ def unpad_input_only(
109
+ hidden_states: torch.Tensor,
110
+ attention_mask: torch.Tensor,
111
+ ) -> torch.Tensor:
112
+ """Like unpad_input, but only return the unpadded first tensor.
113
+
114
+ Saves a small amount of overhead.
115
+
116
+ Arguments:
117
+ hidden_states: (batch, seqlen, ...)
118
+ attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
119
+
120
+ Returns:
121
+ hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected by attention_mask.
122
+ """
123
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
124
+ rearranged = rearrange(hidden_states, "b s ... -> (b s) ...")
125
+ return index_first_axis(rearranged, indices) # type: ignore
126
+
127
+
128
+ def pad_input(hidden_states: torch.Tensor, indices: torch.Tensor, batch: int, seqlen: int) -> torch.Tensor:
129
+ """Add padding to sequences.
130
+
131
+ Arguments:
132
+ hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected by attention_mask.
133
+ indices: (total_nnz)
134
+ batch: int batch_size
135
+ seqlen: int max sequence length
136
+
137
+ Returns:
138
+ hidden_states: (batch, seqlen, ...)
139
+ """
140
+ output = index_put_first_axis(hidden_states, indices, batch * seqlen)
141
+ return rearrange(output, "(b s) ... -> b s ...", b=batch) # type: ignore
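As a usage note (not part of the commit), the sketch below round-trips a padded batch through `unpad_input` and `pad_input` defined above. The import path and toy shapes are assumptions for illustration.

```python
import torch

# Assumes this module is importable as `bert_padding`; adjust the import to your packaging.
from bert_padding import pad_input, unpad_input

batch, seqlen, hidden = 2, 4, 8
hidden_states = torch.randn(batch, seqlen, hidden)
attention_mask = torch.tensor([[1, 1, 0, 0],
                               [1, 1, 1, 1]])

unpadded, indices, cu_seqlens, max_seqlen = unpad_input(hidden_states, attention_mask)
# unpadded: (6, 8) since six tokens are valid; cu_seqlens == tensor([0, 2, 6]); max_seqlen == 4

repadded = pad_input(unpadded, indices, batch, seqlen)  # back to (2, 4, 8), zeros at padding
assert torch.equal(repadded[attention_mask.bool()], unpadded)
```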
config.json CHANGED
@@ -1,11 +1,13 @@
1
  {
2
  "allow_embedding_resizing": true,
3
  "architectures": [
 
4
  "FlexBertForCasualLM"
5
  ],
6
  "auto_map": {
7
  "AutoConfig": "orionweller/FlexGPT--configuration_bert.FlexBertConfig",
8
- "AutoModel": "lightonai/FlexGPT--modeling_flexbert.FlexBertForCasualLM"
 
9
  },
10
  "attention_layer": "rope",
11
  "attention_probs_dropout_prob": 0.0,
 
1
  {
2
  "allow_embedding_resizing": true,
3
  "architectures": [
4
+ "FlexBertModel",
5
  "FlexBertForCasualLM"
6
  ],
7
  "auto_map": {
8
  "AutoConfig": "orionweller/FlexGPT--configuration_bert.FlexBertConfig",
9
+ "AutoModel": "lightonai/FlexGPT--modeling_flexbert.FlexBertModel",
10
+ "AutoModelForCasualLM": "lightonai/FlexGPT--modeling_flexbert.FlexBertForCasualLM",
11
  },
12
  "attention_layer": "rope",
13
  "attention_probs_dropout_prob": 0.0,
configuration_bert.py ADDED
@@ -0,0 +1,272 @@
1
+ # Copyright 2022 MosaicML Examples authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import warnings
5
+
6
+ from transformers import BertConfig as TransformersBertConfig
7
+
8
+
9
+ class BertConfig(TransformersBertConfig):
10
+ def __init__(
11
+ self,
12
+ alibi_starting_size: int = 512,
13
+ normalization: str = "layernorm",
14
+ attention_probs_dropout_prob: float = 0.0,
15
+ head_pred_act: str = "gelu",
16
+ deterministic_fa2: bool = False,
17
+ allow_embedding_resizing: bool = False,
18
+ **kwargs,
19
+ ):
20
+ """Configuration class for MosaicBert.
21
+
22
+ Args:
23
+ alibi_starting_size (int): Use `alibi_starting_size` to determine how large of an alibi tensor to
24
+ create when initializing the model. You should be able to ignore this parameter in most cases.
25
+ Defaults to 512.
26
+ attention_probs_dropout_prob (float): By default, turn off attention dropout in MosaicBERT
27
+ Note that the custom Triton Flash Attention with ALiBi implementation does not support dropout.
28
+ However, Flash Attention 2 supports ALiBi and dropout https://github.com/Dao-AILab/flash-attention
29
+ embed_dropout_prob (float): Dropout probability for the embedding layer.
30
+ attn_out_dropout_prob (float): Dropout probability for the attention output layer.
31
+ mlp_dropout_prob (float): Dropout probability for the MLP layer.
32
+ allow_embedding_resizing (bool): Embeddings will be automatically resized when they are smaller than the tokenizer vocab size.
33
+ """
34
+ super().__init__(attention_probs_dropout_prob=attention_probs_dropout_prob, **kwargs)
35
+ self.alibi_starting_size = alibi_starting_size
36
+ self.normalization = normalization
37
+ self.head_pred_act = head_pred_act
38
+ self.deterministic_fa2 = deterministic_fa2
39
+ self.allow_embedding_resizing = allow_embedding_resizing
40
+
41
+
42
+ class FlexBertConfig(TransformersBertConfig):
43
+ def __init__(
44
+ self,
45
+ attention_layer: str = "base",
46
+ attention_probs_dropout_prob: float = 0.0,
47
+ attn_out_bias: bool = False,
48
+ attn_out_dropout_prob: float = 0.0,
49
+ attn_qkv_bias: bool = False,
50
+ bert_layer: str = "prenorm",
51
+ decoder_bias: bool = True,
52
+ embed_dropout_prob: float = 0.0,
53
+ embed_norm: bool = True,
54
+ final_norm: bool = False,
55
+ embedding_layer: str = "absolute_pos",
56
+ encoder_layer: str = "base",
57
+ loss_function: str = "cross_entropy",
58
+ loss_kwargs: dict = {},
59
+ mlp_dropout_prob: float = 0.0,
60
+ mlp_in_bias: bool = False,
61
+ mlp_layer: str = "mlp",
62
+ mlp_out_bias: bool = False,
63
+ norm_kwargs: dict = {},
64
+ normalization: str = "rmsnorm",
65
+ padding: str = "unpadded",
66
+ head_class_act: str = "silu",
67
+ head_class_bias: bool = False,
68
+ head_class_dropout: float = 0.0,
69
+ head_class_norm: str = False,
70
+ head_pred_act: str = "silu",
71
+ head_pred_bias: bool = False,
72
+ head_pred_dropout: float = 0.0,
73
+ head_pred_norm: bool = True,
74
+ pooling_type: str = "cls",
75
+ rotary_emb_dim: int | None = None,
76
+ rotary_emb_base: float = 10000.0,
77
+ rotary_emb_scale_base=None,
78
+ rotary_emb_interleaved: bool = False,
79
+ use_fa2: bool = True,
80
+ use_sdpa_attn_mask: bool = False,
81
+ allow_embedding_resizing: bool = False,
82
+ init_method: str = "default",
83
+ init_std: float = 0.02,
84
+ init_cutoff_factor: float = 2.0,
85
+ init_small_embedding: bool = False,
86
+ initial_attention_layer: str | None = None,
87
+ initial_bert_layer: str | None = None,
88
+ initial_mlp_layer: str | None = None,
89
+ num_initial_layers: int = 1,
90
+ skip_first_prenorm: bool = False,
91
+ deterministic_fa2: bool = False,
92
+ sliding_window: int = -1,
93
+ global_attn_every_n_layers: int = -1,
94
+ local_attn_rotary_emb_base: float = -1,
95
+ local_attn_rotary_emb_dim: int | None = None,
96
+ unpad_embeddings: bool = False,
97
+ pad_logits: bool = False,
98
+ compile_model: bool = False,
99
+ masked_prediction: bool = False,
100
+ casual_mask: bool = False,
101
+ **kwargs,
102
+ ):
103
+ """
104
+ Args:
105
+ attention_layer (str): Attention layer type.
106
+ attention_probs_dropout_prob (float): Dropout probability for attention probabilities.
107
+ attn_out_bias (bool): use bias in attention output projection.
108
+ attn_out_dropout_prob (float): Dropout probability for attention output.
109
+ attn_qkv_bias (bool): use bias for query, key, value linear layer(s).
110
+ bert_layer (str): BERT layer type.
111
+ decoder_bias (bool): use bias in decoder linear layer.
112
+ embed_dropout_prob (float): Dropout probability for embeddings.
113
+ embed_norm (bool): Normalize embedding output.
114
+ final_norm (bool): Add normalization after the final encoder layer and before head.
115
+ embedding_layer (str): Embedding layer type.
116
+ encoder_layer (str): Encoder layer type.
117
+ loss_function (str): Loss function to use.
118
+ loss_kwargs (dict): Keyword arguments for loss function.
119
+ mlp_dropout_prob (float): Dropout probability for MLP layers.
120
+ mlp_in_bias (bool): Use bias in MLP input linear layer.
121
+ mlp_layer (str): MLP layer type.
122
+ mlp_out_bias (bool): Use bias in MLP output linear layer.
123
+ norm_kwargs (dict): Keyword arguments for normalization layers.
124
+ normalization (str): Normalization type.
125
+ padding (str): Whether to use "unpadded" or "padded" inputs. Unpadded works best with `use_fa2=True`.
126
+ head_class_act (str): Activation function for classification head.
127
+ head_class_bias (bool): Use bias in classification head linear layer(s).
128
+ head_class_dropout (float): Dropout probability for classification head.
129
+ head_class_norm (str): Normalization type for classification head.
130
+ head_pred_act (str): Activation function for prediction head.
131
+ head_pred_bias (bool): Use bias in prediction head linear layer(s).
132
+ head_pred_dropout (float): Dropout probability for prediction head.
133
+ head_pred_norm (bool): Normalize prediction head output.
134
+ pooling_type (str): Pooling type.
135
+ rotary_emb_dim (int | None): Rotary embedding dimension.
136
+ rotary_emb_base (float): Rotary embedding base.
137
+ rotary_emb_scale_base (float): Rotary embedding scale base.
138
+ rotary_emb_interleaved (bool): Use interleaved rotary embeddings.
139
+ use_fa2 (bool): Use FlashAttention2. Requires flash_attn package.
140
+ use_sdpa_attn_mask (bool): Pass a mask to SDPA. This will prevent SDPA from using the PyTorch FA2 kernel.
141
+ allow_embedding_resizing (bool): Embeddings will be automatically resized when they are smaller than the tokenizer vocab size.
142
+ init_method (str): Model layers initialization method.
143
+ init_std (float): Standard deviation for initialization. Used for normal and full_megatron init.
144
+ init_cutoff_factor (float): Cutoff factor for initialization. Used for normal and full_megatron init.
145
+ init_small_embedding (bool): Initialize embeddings with RWKV small init.
146
+ initial_attention_layer (str | None): Replace first `num_initial_layers` attention_layer instance with this layer.
147
+ initial_bert_layer (str | None): Replace first `num_initial_layers` bert_layer instance with this layer.
148
+ initial_mlp_layer (str | None): Replace first `num_initial_layers` mlp_layer instance with this layer.
149
+ num_initial_layers (int): Number of initial layers to set via `initial_attention_layer`, `initial_bert_layer`, and `initial_mlp_layer`.
150
+ skip_first_prenorm (bool): Skip pre-normalization for the first bert layer. Requires `embed_norm=True`.
151
+ deterministic_fa2 (bool): Use Flash Attention 2 deterministic mode. This is slower than the default non-deterministic mode.
152
+ sliding_window (int): Use sliding window attention with window size `n`. -1 to disable. The window is split evenly between the left and right context. Only supported with FA2.
153
+ global_attn_every_n_layers (int): Use global attention every `n` layers and sliding window for the rest. -1 to disable.
154
+ local_attn_rotary_emb_base (float): Rotary embedding base for local attention. -1 to disable and use `rotary_emb_base` for all layers.
155
+ local_attn_rotary_emb_dim (int | None): Rotary embedding dimension for local attention. None to disable and use `rotary_emb_dim` for all layers.
156
+ unpad_embeddings (bool): Unpad inputs before the embedding layer.
157
+ pad_logits (bool): Pad logits after calculating the loss.
158
+ compile_model (bool): Compile the subset of the model which can be compiled.
159
+ masked_prediction (bool): Only pass the masked tokens through the final MLM layers.
160
+ casual_mask (bool): Use a causal attention mask. Defaults to False.
161
+ **kwargs: Additional keyword arguments.
162
+ """
163
+ super().__init__(attention_probs_dropout_prob=attention_probs_dropout_prob, **kwargs)
164
+ self.attention_layer = attention_layer
165
+ self.attn_out_bias = attn_out_bias
166
+ self.attn_out_dropout_prob = attn_out_dropout_prob
167
+ self.attn_qkv_bias = attn_qkv_bias
168
+ self.bert_layer = bert_layer
169
+ self.decoder_bias = decoder_bias
170
+ self.embed_dropout_prob = embed_dropout_prob
171
+ self.embed_norm = embed_norm
172
+ self.final_norm = final_norm
173
+ self.embedding_layer = embedding_layer
174
+ self.encoder_layer = encoder_layer
175
+ self.loss_function = loss_function
176
+ self.loss_kwargs = loss_kwargs
177
+ self.mlp_dropout_prob = mlp_dropout_prob
178
+ self.mlp_in_bias = mlp_in_bias
179
+ self.mlp_layer = mlp_layer
180
+ self.mlp_out_bias = mlp_out_bias
181
+ self.norm_kwargs = norm_kwargs
182
+ self.normalization = normalization
183
+ self.padding = padding
184
+ self.head_class_act = head_class_act
185
+ self.head_class_bias = head_class_bias
186
+ self.head_class_dropout = head_class_dropout
187
+ self.head_class_norm = head_class_norm
188
+ self.head_pred_act = head_pred_act
189
+ self.head_pred_bias = head_pred_bias
190
+ self.head_pred_dropout = head_pred_dropout
191
+ self.head_pred_norm = head_pred_norm
192
+ self.pooling_type = pooling_type
193
+ self.rotary_emb_dim = rotary_emb_dim
194
+ self.rotary_emb_base = rotary_emb_base
195
+ self.rotary_emb_scale_base = rotary_emb_scale_base
196
+ self.rotary_emb_interleaved = rotary_emb_interleaved
197
+ self.use_fa2 = use_fa2
198
+ self.use_sdpa_attn_mask = use_sdpa_attn_mask
199
+ self.allow_embedding_resizing = allow_embedding_resizing
200
+ self.init_method = init_method
201
+ self.init_std = init_std
202
+ self.init_cutoff_factor = init_cutoff_factor
203
+ self.init_small_embedding = init_small_embedding
204
+ self.initial_attention_layer = initial_attention_layer
205
+ self.initial_bert_layer = initial_bert_layer
206
+ self.initial_mlp_layer = initial_mlp_layer
207
+ self.num_initial_layers = num_initial_layers
208
+ self.skip_first_prenorm = skip_first_prenorm
209
+ self.deterministic_fa2 = deterministic_fa2
210
+ self.sliding_window = sliding_window
211
+ self.global_attn_every_n_layers = global_attn_every_n_layers
212
+ self.local_attn_rotary_emb_base = local_attn_rotary_emb_base
213
+ self.local_attn_rotary_emb_dim = local_attn_rotary_emb_dim
214
+ self.unpad_embeddings = unpad_embeddings
215
+ self.pad_logits = pad_logits
216
+ self.compile_model = compile_model
217
+ self.masked_prediction = masked_prediction
218
+ self.casual_mask = casual_mask
219
+
220
+ if loss_kwargs.get("return_z_loss", False):
221
+ if loss_function != "fa_cross_entropy":
222
+ raise ValueError("loss_function must be 'fa_cross_entropy' when return_z_loss is True")
223
+ if loss_kwargs.get("lse_square_scale", 0) <= 0:
224
+ raise ValueError(
225
+ "lse_square_scale must be passed to `loss_kwargs` and must be greater than 0 for z_loss"
226
+ )
227
+ if loss_kwargs.get("inplace_backward", False):
228
+ self.loss_kwargs["inplace_backward"] = False
229
+ warnings.warn("`inplace_backward=True` will cause incorrect metrics. Automatically setting to False.")
230
+
231
+ if global_attn_every_n_layers > 0 and (self.num_hidden_layers - 1) % global_attn_every_n_layers != 0:
232
+ raise ValueError(
233
+ f"{global_attn_every_n_layers=} must be a divisor of one less than {self.num_hidden_layers=}"
234
+ )
235
+
236
+ if self.sliding_window != -1:
237
+ if not self.use_fa2:
238
+ raise ValueError("Sliding window attention is only supported with FlashAttention2")
239
+ if self.sliding_window % 2 != 0 and self.sliding_window % 64 != 0:
240
+ raise ValueError(
241
+ f"Sliding window must be an even number and divisible by 64: {self.sliding_window=} {self.sliding_window % 64} {self.sliding_window % 2}"
242
+ )
243
+ else:
244
+ if self.global_attn_every_n_layers != -1:
245
+ raise ValueError("global_attn_every_n_layers must be -1 when sliding_window is disabled")
246
+ if self.local_attn_rotary_emb_base != -1:
247
+ raise ValueError("local_attn_rotary_emb_base must be -1 when sliding_window is disabled")
248
+ if self.local_attn_rotary_emb_dim is not None:
249
+ raise ValueError("local_attn_rotary_emb_dim must be None when sliding_window is disabled")
250
+
251
+ if self.unpad_embeddings and self.padding != "unpadded":
252
+ warnings.warn(
253
+ "`unpad_embeddings=True` requires `padding='unpadded'`. Automatically setting `padding='unpadded'`."
254
+ )
255
+ self.padding = "unpadded"
256
+ if self.pad_logits and not self.unpad_embeddings:
257
+ raise ValueError("`pad_logits=True` requires `unpad_embeddings=True`")
258
+ if self.unpad_embeddings and self.embedding_layer == "absolute_pos":
259
+ raise ValueError(f"{self.unpad_embeddings=} is incompatible with {self.embedding_layer=}")
260
+
261
+
262
+ PADDING = ["unpadded", "padded"]
263
+
264
+
265
+ def maybe_add_padding(config: FlexBertConfig, config_option: str) -> str:
266
+ if config.padding not in PADDING:
267
+ raise ValueError(f"Invalid padding type: {config.padding}, must be one of {PADDING}")
268
+
269
+ if not any(config_option.startswith(pad + "_") for pad in PADDING):
270
+ config_option = f"{config.padding}_{config_option}"
271
+
272
+ return config_option
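A small sketch (not part of the commit) of how `maybe_add_padding` resolves layer names against the configured padding mode; the import path and example values are assumptions for illustration.

```python
# Adjust the import to how the repository is packaged.
from configuration_bert import FlexBertConfig, maybe_add_padding

config = FlexBertConfig(attention_layer="rope", padding="unpadded")

print(maybe_add_padding(config, config.attention_layer))  # -> "unpadded_rope"
print(maybe_add_padding(config, "padded_rope"))           # already prefixed -> "padded_rope"
```

The resolved name is what `get_attention_layer` in `attention.py` looks up in `ATTN2CLS`.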
embeddings.py ADDED
@@ -0,0 +1,218 @@
1
+ # Copyright 2024 **AUTHORS_TODO**
2
+ # License: Apache-2.0
3
+
4
+ # Copyright 2022 MosaicML Examples authors
5
+ # SPDX-License-Identifier: Apache-2.0
6
+
7
+ # Copyright 2023 MosaicML Examples authors
8
+ # SPDX-License-Identifier: Apache-2.0
9
+
10
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
11
+ # Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.
12
+ # Copyright (c) 2023, Tri Dao.
13
+
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ from typing import Optional
18
+
19
+ from .configuration_bert import FlexBertConfig
20
+ from .normalization import get_norm_layer
21
+ from .initialization import ModuleType, init_weights
22
+
23
+
24
+ class BertAlibiEmbeddings(nn.Module):
25
+ """Construct the embeddings for words, ignoring position.
26
+
27
+ There are no positional embeddings since we use ALiBi; only word and
28
+ token_type embeddings are used.
29
+
30
+ This module is modeled after the Hugging Face BERT's
31
+ :class:`~transformers.model.bert.modeling_bert.BertEmbeddings`, but is
32
+ modified as part of Mosaic BERT's ALiBi implementation. The key change is
33
+ that position embeddings are removed. Position information instead comes
34
+ from attention biases that scale linearly with the position distance
35
+ between query and key tokens.
36
+
37
+ This module ignores the `position_ids` input to the `forward` method.
38
+ """
39
+
40
+ def __init__(self, config):
41
+ super().__init__()
42
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
43
+ # ALiBi doesn't use position embeddings
44
+ if getattr(config, "token_type_embeddings", True):
45
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
46
+ self.use_token_type_embeddings = True
47
+ else:
48
+ self.use_token_type_embeddings = False
49
+
50
+ self.LayerNorm = get_norm_layer(config)
51
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
52
+ if self.use_token_type_embeddings:
53
+ self.register_buffer(
54
+ "token_type_ids", torch.zeros(config.max_position_embeddings, dtype=torch.long), persistent=False
55
+ )
56
+
57
+ def forward(
58
+ self,
59
+ input_ids: Optional[torch.LongTensor] = None,
60
+ token_type_ids: Optional[torch.LongTensor] = None,
61
+ position_ids: Optional[torch.LongTensor] = None,
62
+ inputs_embeds: Optional[torch.FloatTensor] = None,
63
+ past_key_values_length: int = 0,
64
+ ) -> torch.Tensor:
65
+ if (input_ids is not None) == (inputs_embeds is not None):
66
+ raise ValueError("Must specify either input_ids or input_embeds!")
67
+ if input_ids is not None:
68
+ input_shape = input_ids.size()
69
+ else:
70
+ assert inputs_embeds is not None # just for type checking
71
+ input_shape = inputs_embeds.size()[:-1]
72
+
73
+ seq_length = input_shape[1]
74
+
75
+ if position_ids is None:
76
+ # great! ALiBi
77
+ pass
78
+
79
+ # Setting the token_type_ids to the registered buffer in constructor
80
+ # where it is all zeros, which usually occurs when it's auto-generated;
81
+ # registered buffer helps users when tracing the model without passing
82
+ # token_type_ids, solves issue #5664
83
+ if self.use_token_type_embeddings and token_type_ids is None:
84
+ if hasattr(self, "token_type_ids"):
85
+ buffered_token_type_ids = self.token_type_ids[:, :seq_length]
86
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
87
+ token_type_ids = buffered_token_type_ids_expanded
88
+ else:
89
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=input_ids.device)
90
+
91
+ if inputs_embeds is None:
92
+ inputs_embeds = self.word_embeddings(input_ids)
93
+
94
+ if self.use_token_type_embeddings:
95
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
96
+ embeddings = inputs_embeds + token_type_embeddings
97
+ else:
98
+ embeddings = inputs_embeds
99
+
100
+ # no position embeddings! ALiBi
101
+ embeddings = self.LayerNorm(embeddings)
102
+ embeddings = self.dropout(embeddings)
103
+ return embeddings
104
+
105
+
106
+ class FlexBertEmbeddingsBase(nn.Module):
107
+ """A FlexBERT embeddings base class for type hints."""
108
+
109
+ def __init__(self, config: FlexBertConfig):
110
+ super().__init__()
111
+ self.config = config
112
+
113
+ def _init_weights(self, reset_params: bool = False):
114
+ raise NotImplementedError("This is a base class and should not be used directly.")
115
+
116
+ def reset_parameters(self):
117
+ self._init_weights(reset_params=True)
118
+
119
+ def forward(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None) -> torch.Tensor:
120
+ raise NotImplementedError("This is a base class and should not be used directly.")
121
+
122
+
123
+ class FlexBertAbsoluteEmbeddings(FlexBertEmbeddingsBase):
124
+ """Construct the embeddings with absolute positional embeddings."""
125
+
126
+ def __init__(self, config: FlexBertConfig):
127
+ super().__init__(config)
128
+ self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
129
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
130
+
131
+ self.norm = get_norm_layer(config) if config.embed_norm else nn.Identity()
132
+ self.drop = nn.Dropout(config.embed_dropout_prob) if config.embed_dropout_prob > 0.0 else nn.Identity()
133
+
134
+ self.register_buffer(
135
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
136
+ )
137
+
138
+ def _init_weights(self, reset_params: bool = False):
139
+ init_weights(self.config, self.tok_embeddings, type_of_module=ModuleType.emb)
140
+ init_weights(self.config, self.position_embeddings, type_of_module=ModuleType.emb)
141
+
142
+ if reset_params:
143
+ if self.config.embed_norm:
144
+ self.norm.reset_parameters() # type: ignore
145
+
146
+ def forward(
147
+ self,
148
+ input_ids: torch.LongTensor,
149
+ position_ids: Optional[torch.LongTensor] = None,
150
+ ) -> torch.Tensor:
151
+ if position_ids is None:
152
+ position_ids = self.position_ids[:, 0 : input_ids.shape[1]]
153
+
154
+ embeddings = self.tok_embeddings(input_ids)
155
+ position_embeddings = self.position_embeddings(position_ids)
156
+
157
+ embeddings = self.norm(embeddings + position_embeddings)
158
+ return self.drop(embeddings)
159
+
160
+
161
+ class FlexBertCompiledSansPositionEmbeddings(FlexBertEmbeddingsBase):
162
+ """Construct the embeddings from token embeddings without any positional embeddings."""
163
+
164
+ def __init__(self, config: FlexBertConfig):
165
+ super().__init__(config)
166
+ self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
167
+
168
+ self.norm = get_norm_layer(config, compiled_norm=config.compile_model) if config.embed_norm else nn.Identity()
169
+ self.drop = nn.Dropout(config.embed_dropout_prob) if config.embed_dropout_prob > 0.0 else nn.Identity()
170
+
171
+ def _init_weights(self, reset_params: bool = False):
172
+ init_weights(self.config, self.tok_embeddings, type_of_module=ModuleType.emb)
173
+
174
+ if reset_params:
175
+ if self.config.embed_norm:
176
+ self.norm.reset_parameters() # type: ignore
177
+
178
+ @torch.compile(dynamic=True)
179
+ def forward(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None) -> torch.Tensor:
180
+ return self.drop(self.norm(self.tok_embeddings(input_ids)))
181
+
182
+
183
+ class FlexBertSansPositionEmbeddings(FlexBertEmbeddingsBase):
184
+ """Construct the embeddings from token embeddings without any positional embeddings."""
185
+
186
+ def __init__(self, config: FlexBertConfig):
187
+ super().__init__(config)
188
+ self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
189
+
190
+ self.norm = get_norm_layer(config) if config.embed_norm else nn.Identity()
191
+ self.drop = nn.Dropout(config.embed_dropout_prob) if config.embed_dropout_prob > 0.0 else nn.Identity()
192
+
193
+ def _init_weights(self, reset_params: bool = False):
194
+ init_weights(self.config, self.tok_embeddings, type_of_module=ModuleType.emb)
195
+
196
+ if reset_params:
197
+ if self.config.embed_norm:
198
+ self.norm.reset_parameters() # type: ignore
199
+
200
+ def forward(self, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None) -> torch.Tensor:
201
+ return self.drop(self.norm(self.tok_embeddings(input_ids)))
202
+
203
+
204
+ EBB2CLS = {
205
+ "absolute_pos": FlexBertAbsoluteEmbeddings,
206
+ "sans_pos": FlexBertSansPositionEmbeddings,
207
+ }
208
+
209
+
210
+ def get_embedding_layer(config: FlexBertConfig) -> FlexBertEmbeddingsBase:
211
+ try:
212
+ if config.compile_model and config.embedding_layer == "sans_pos":
213
+ return FlexBertCompiledSansPositionEmbeddings(config)
214
+ elif config.compile_model:
215
+ raise ValueError(f"{config.compile_model=} only supports sans_pos embeddings.")
216
+ return EBB2CLS[config.embedding_layer](config)
217
+ except KeyError:
218
+ raise ValueError(f"Invalid embeddings layer type: {config.embedding_layer=}, must be one of {EBB2CLS.keys()}.")
initialization.py ADDED
@@ -0,0 +1,551 @@
1
+ # Copyright 2022 MosaicML Examples authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ # Copyright 2023 OLMo Authors
5
+ # License: Apache-2.0
6
+
7
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
8
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
9
+ # License: Apache-2.0
10
+
11
+ import math
12
+ from typing import Optional, Union
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+
17
+ from .utils import StrEnum
18
+
19
+ from .configuration_bert import FlexBertConfig
20
+ from .normalization import RMSNorm
21
+
22
+ __all__ = ["init_weights", "ModuleType", "InitFnType"]
23
+
24
+
25
+ class InitFnType(StrEnum):
26
+ mitchell = "mitchell"
27
+ """
28
+ The strategy suggested to us by Mitchell Wortsman from UW.
29
+ This uses a truncated normal distribution with an adaptive standard deviation that depends
30
+ on the size of the weights as well as the depth of the layer.
31
+ """
32
+
33
+ normal = "normal"
34
+ """
35
+ All weights are initialized from the same normal distribution.
36
+ """
37
+
38
+ default = "default"
39
+ """
40
+ All weights are initialized with the default HuggingFace Bert method. Set init_std=0.02 to match.
41
+ """
42
+
43
+ kaiming_normal = "kaiming_normal"
44
+ """
45
+ All weights are initialized with the Kaiming method from a normal distribution.
46
+ Note this currently won't work with FSDP.
47
+ """
48
+
49
+ fan_in = "fan_in"
50
+ """
51
+ "Fan-in variance scaling", i.e. normal with a standard deviation of ``1/sqrt(d_in)`` where ``d_in``
52
+ is the input dimensionality of the kernel.
53
+ """
54
+
55
+ full_megatron = "full_megatron"
56
+ """
57
+ This is what metaseq calls "full megatron init". It is the init used for Llama 2.
58
+ """
59
+
60
+
61
+ class ModuleType(StrEnum):
62
+ in_module = "in"
63
+ out_module = "out"
64
+ emb = "emb"
65
+ final_out = "final_out"
66
+
67
+
68
+ def init_weights(
69
+ config: FlexBertConfig,
70
+ module: Union[nn.Linear, nn.Embedding],
71
+ layer_dim: Optional[int] = None,
72
+ layer_id: Optional[int] = None,
73
+ std_factor: float = 1.0,
74
+ type_of_module: Optional[ModuleType] = None,
75
+ ) -> None:
76
+ """
77
+ Initialize weights of a linear or embedding module.
78
+
79
+ :param config: The model config.
80
+ :param module: The linear or embedding submodule to initialize.
81
+ :param layer_dim: The effective input dimensionality of the weights. This could be smaller than the actual dimensions
82
+ for fused layers.
83
+ :param layer_id: When set, the standard deviation for the "mitchell" method will be adjusted by
84
+ ``1 / sqrt(2 * (layer_id + 1))``.
85
+ """
86
+ if config.init_method == InitFnType.full_megatron and config.init_small_embedding:
87
+ raise ValueError("Cannot use 'small_embedding_init' with 'full_megatron' init.")
88
+
89
+ layer_dim = layer_dim if layer_dim is not None else config.hidden_size
90
+ if config.init_method == InitFnType.normal:
91
+ std = config.init_std * std_factor
92
+ if config.init_cutoff_factor is not None:
93
+ cutoff_value = config.init_cutoff_factor * std
94
+ nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-cutoff_value, b=cutoff_value)
95
+ else:
96
+ nn.init.normal_(module.weight, mean=0.0, std=std)
97
+ elif config.init_method == InitFnType.mitchell:
98
+ std = std_factor / math.sqrt(layer_dim)
99
+ if layer_id is not None:
100
+ std = std / math.sqrt(2 * (layer_id + 1))
101
+ nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-3 * std, b=3 * std)
102
+ elif config.init_method == InitFnType.kaiming_normal:
103
+ nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
104
+ elif config.init_method == InitFnType.fan_in:
105
+ std = std_factor / math.sqrt(layer_dim)
106
+ nn.init.normal_(module.weight, mean=0.0, std=std)
107
+ elif config.init_method == InitFnType.full_megatron:
108
+ if type_of_module is None:
109
+ raise RuntimeError(f"When using the {InitFnType.full_megatron} init, every module must have a type.")
110
+
111
+ cutoff_factor = config.init_cutoff_factor
112
+ if cutoff_factor is None:
113
+ cutoff_factor = 3
114
+
115
+ if type_of_module == ModuleType.in_module:
116
+ # for att_proj (same as QKV), ff_proj
117
+ std = config.init_std
118
+ elif type_of_module == ModuleType.out_module:
119
+ # for attn_out, ff_out
120
+ std = config.init_std / math.sqrt(2.0 * config.num_hidden_layers)
121
+ elif type_of_module == ModuleType.emb:
122
+ # positional embeddings (wpe)
123
+ # token embeddings (wte)
124
+ std = config.init_std
125
+ elif type_of_module == ModuleType.final_out:
126
+ # final output (ff_out)
127
+ std = config.hidden_size**-0.5
128
+ else:
129
+ raise RuntimeError(f"Unknown module type '{type_of_module}'")
130
+
131
+ nn.init.trunc_normal_(
132
+ module.weight,
133
+ mean=0.0,
134
+ std=std,
135
+ a=-cutoff_factor * std,
136
+ b=cutoff_factor * std,
137
+ )
138
+ elif config.init_method == InitFnType.default:
139
+ # default hugging face bert initialization
140
+ # normalization layers already init to ones and zeros
141
+ if isinstance(module, nn.Linear):
142
+ # Slightly different from the TF version which uses truncated_normal for initialization
143
+ # cf https://github.com/pytorch/pytorch/pull/5617
144
+ module.weight.data.normal_(mean=0.0, std=config.init_std)
145
+ if module.bias is not None:
146
+ module.bias.data.zero_()
147
+ elif isinstance(module, nn.Embedding):
148
+ module.weight.data.normal_(mean=0.0, std=config.init_std)
149
+ if module.padding_idx is not None:
150
+ module.weight.data[module.padding_idx].zero_()
151
+ else:
152
+ raise NotImplementedError(config.init_method)
153
+
154
+ if isinstance(module, nn.Linear):
155
+ if module.bias is not None:
156
+ nn.init.zeros_(module.bias)
157
+
158
+ if config.init_method == InitFnType.normal and getattr(module, "_is_residual", False):
159
+ with torch.no_grad():
160
+ module.weight.div_(math.sqrt(2 * config.num_hidden_layers))
161
+
162
+ if isinstance(module, nn.Embedding) and config.init_small_embedding:
163
+ nn.init.uniform_(module.weight, a=-1e-4, b=1e-4)
164
+
165
+
166
+ class TileMode(StrEnum):
167
+ center_weights = "center_weights"
168
+ tile_weights_from_edge = "tile_weights_from_edge"
169
+ tile_weights_from_middle = "tile_weights_from_middle"
170
+
171
+
172
+ def tile_weight(
173
+ pretrained_weights: torch.Tensor,
174
+ new_weights: torch.Tensor,
175
+ mode: Union[str, TileMode] = TileMode.tile_weights_from_middle,
176
+ ) -> torch.Tensor:
177
+ """
178
+ Tile or center an input tensor to a larger desired size. Works for both 2D and 1D tensors.
179
+
180
+ Args:
181
+ pretrained_weights (torch.Tensor): The input tensor to be tiled or centered (1D or 2D).
182
+ new_weights (torch.Tensor): The tensor with the desired size.
183
+ mode (Union[str, TileMode]): 'center_weights', 'tile_weights_from_edge', or 'tile_weights_from_middle'
184
+
185
+ Returns:
186
+ torch.Tensor: The resulting tensor of the desired size.
187
+ """
188
+ assert pretrained_weights.dim() in (1, 2), "Input tensor must be 1-dimensional or 2-dimensional"
189
+ if isinstance(mode, str):
190
+ mode = TileMode(mode)
191
+
192
+ pretrained_weights = pretrained_weights.clone()
193
+
194
+ if pretrained_weights.dim() == 1:
195
+ return _tile_1d(pretrained_weights, new_weights, mode)
196
+ else:
197
+ return _tile_2d(pretrained_weights, new_weights, mode)
198
+
199
+
200
+ def _tile_1d(pretrained_weights: torch.Tensor, new_weights: torch.Tensor, mode: TileMode) -> torch.Tensor:
201
+ assert pretrained_weights.dim() == 1, "Input tensor must be 1-dimensional"
202
+ input_size = pretrained_weights.shape[0]
203
+ new_size = new_weights.shape[0]
204
+ assert new_size >= input_size, "Desired size must be greater than or equal to input size"
205
+
206
+ if mode == TileMode.center_weights:
207
+ offset = (new_size - input_size) // 2
208
+ new_weights[offset : offset + input_size] = pretrained_weights
209
+ return new_weights.clone()
210
+ elif mode == TileMode.tile_weights_from_edge:
211
+ repeat_count = (new_size + input_size - 1) // input_size
212
+ tiled_tensor = pretrained_weights.repeat(repeat_count)
213
+ return tiled_tensor[:new_size].clone()
214
+ elif mode == TileMode.tile_weights_from_middle:
215
+ # Calculate offsets to center the original tensor
216
+ offset = (new_size - input_size) // 2
217
+
218
+ # Create a new tensor with the desired size
219
+ result = torch.zeros(new_size, dtype=pretrained_weights.dtype, device=pretrained_weights.device)
220
+
221
+ # Place the original tensor in the center
222
+ result[offset : offset + input_size] = pretrained_weights
223
+
224
+ # Tile the left and right sides
225
+ for i in range(offset):
226
+ result[offset - 1 - i] = pretrained_weights[input_size - 1 - (i % input_size)]
227
+ for i in range(offset + input_size, new_size):
228
+ result[i] = pretrained_weights[(i - offset) % input_size]
229
+ return result.clone()
230
+
231
+
232
+ def _tile_2d(pretrained_weights: torch.Tensor, new_weights: torch.Tensor, mode: TileMode) -> torch.Tensor:
233
+ assert pretrained_weights.dim() == 2, "Input tensor must be 2-dimensional"
234
+ input_height, input_width = pretrained_weights.shape
235
+ new_height, new_width = new_weights.shape
236
+ assert new_height >= input_height, "Desired height must be greater than or equal to input height"
237
+ assert new_width >= input_width, "Desired width must be greater than or equal to input width"
238
+
239
+ if mode == TileMode.center_weights:
240
+ height_offset = (new_height - input_height) // 2
241
+ width_offset = (new_width - input_width) // 2
242
+ new_weights[height_offset : height_offset + input_height, width_offset : width_offset + input_width] = pretrained_weights # fmt: skip
243
+ return new_weights.clone()
244
+ elif mode == TileMode.tile_weights_from_edge:
245
+ repeat_height = (new_height + input_height - 1) // input_height
246
+ repeat_width = (new_width + input_width - 1) // input_width
247
+ tiled_tensor = pretrained_weights.repeat(repeat_height, repeat_width)
248
+ return tiled_tensor[:new_height, :new_width].clone()
249
+ elif mode == TileMode.tile_weights_from_middle:
250
+ # Calculate offsets to center the original tensor
251
+ height_offset = (new_height - input_height) // 2
252
+ width_offset = (new_width - input_width) // 2
253
+
254
+ # Create a new tensor with the desired width and input height
255
+ horizontal_tiled = torch.zeros(
256
+ input_height, new_width, dtype=pretrained_weights.dtype, device=pretrained_weights.device
257
+ )
258
+
259
+ # Place the original tensor in the center horizontally
260
+ horizontal_tiled[:, width_offset : width_offset + input_width] = pretrained_weights
261
+
262
+ # Tile the left and right sides
263
+ for i in range(width_offset):
264
+ horizontal_tiled[:, i] = horizontal_tiled[
265
+ :, width_offset + input_width - 1 - (width_offset - i - 1) % input_width
266
+ ]
267
+ for i in range(width_offset + input_width, new_width):
268
+ horizontal_tiled[:, i] = horizontal_tiled[:, width_offset + (i - width_offset) % input_width]
269
+
270
+ # Now tile vertically
271
+ result = torch.zeros(new_height, new_width, dtype=pretrained_weights.dtype, device=pretrained_weights.device)
272
+ result[height_offset : height_offset + input_height, :] = horizontal_tiled
273
+
274
+ # Tile top
275
+ for i in range(height_offset):
276
+ row_to_copy = (input_height - 1) - (i % input_height)
277
+ result[height_offset - 1 - i, :] = horizontal_tiled[row_to_copy, :]
278
+
279
+ # Tile bottom
280
+ for i in range(height_offset + input_height, new_height):
281
+ row_to_copy = (i - height_offset) % input_height
282
+ result[i, :] = horizontal_tiled[row_to_copy, :]
283
+ return result.clone()
284
+
285
+
286
+ def tile_fused_qkv(
287
+ pretrained_qkv_weight: torch.Tensor,
288
+ new_qkv_weight: torch.Tensor,
289
+ mode: Union[str, TileMode] = TileMode.tile_weights_from_middle,
290
+ ):
291
+ """
292
+ Tile the weights of a fused pretrained QKV layer to a new, larger QKV dimension.
293
+
294
+ Args:
295
+ pretrained_qkv_weight (torch.Tensor): The original fused QKV layer
296
+ new_qkv_weight (torch.Tensor): The new fused QKV layer with larger linear_dim
297
+ mode (Union[str, TileMode]): The tiling mode to use
298
+ Returns:
299
+ torch.Tensor: The new fused QKV layer with tiled weights
300
+ """
301
+ # Split QKV, assume new_q, new_k, new_v are the same shape
302
+ pretrained_q, pretrained_k, pretrained_v = pretrained_qkv_weight.chunk(3, dim=0)
303
+ new_q, new_k, new_v = new_qkv_weight.chunk(3, dim=0)
304
+
305
+ # Tile Q, K, V separately
306
+ new_q = tile_weight(pretrained_q, new_q, mode=mode)
307
+ new_k = tile_weight(pretrained_k, new_k, mode=mode)
308
+ new_v = tile_weight(pretrained_v, new_v, mode=mode)
309
+
310
+ # Concatenate tiled Q, K, V
311
+ return torch.cat([new_q, new_k, new_v], dim=0)
312
+
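+ # Worked example (editor's sketch): chunking first means Q, K, and V are each tiled
+ # independently, so rows from one projection never spill into another. With a toy
+ # fused QKV of shape (6, 1) grown to (9, 1) using edge tiling:
+ #
+ # >>> old_qkv = torch.arange(6.).reshape(6, 1)   # rows 0-1 = Q, 2-3 = K, 4-5 = V
+ # >>> tile_fused_qkv(old_qkv, torch.zeros(9, 1), TileMode.tile_weights_from_edge).squeeze(-1)
+ # tensor([0., 1., 0., 2., 3., 2., 4., 5., 4.])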
313
+
314
+ def tile_fused_glu(
315
+ pretrained_glu_weight: torch.Tensor,
316
+ new_glu_weight: torch.Tensor,
317
+ mode: Union[str, TileMode] = TileMode.tile_weights_from_middle,
318
+ ):
319
+ """
320
+ Tile the weights of a fused pretrained GLU layer to a new, larger GLU dimension.
321
+
322
+ Args:
323
+ pretrained_glu_weight (torch.Tensor): The original fused GLU layer
324
+ new_glu_weight (torch.Tensor): The new fused GLU layer with larger linear_dim
325
+ mode (Union[str, TileMode]): The tiling mode to use
326
+ Returns:
327
+ torch.Tensor: The new fused GLU layer with tiled weights
328
+ """
329
+ # Split GLU, assume new_glu_wi, new_glu_wg are the same shape
330
+ pretrained_glu_wi, pretrained_glu_wg = pretrained_glu_weight.chunk(2, dim=0)
331
+ new_glu_wi, new_glu_wg = new_glu_weight.chunk(2, dim=0)
332
+
333
+ # Tile GLU separately
334
+ new_glu_wi = tile_weight(pretrained_glu_wi, new_glu_wi, mode=mode)
335
+ new_glu_wg = tile_weight(pretrained_glu_wg, new_glu_wg, mode=mode)
336
+
337
+ # Concatenate tiled GLU
338
+ return torch.cat([new_glu_wi, new_glu_wg], dim=0)
339
+
340
+
341
+ def tile_fused_qkvff(
342
+ pretrained_qkvff_weight: torch.Tensor,
343
+ new_qkvff_weight: torch.Tensor,
344
+ pretrained_attn_size: int,
345
+ pretrained_mlp_size: int,
346
+ new_attn_size: int,
347
+ new_mlp_size: int,
348
+ is_glu: bool = False,
349
+ mode: Union[str, TileMode] = TileMode.tile_weights_from_middle,
350
+ ):
351
+ """
352
+ Tile the weights of a fused pretrained QKVFF layer to a new, larger QKVFF dimension.
353
+
354
+ Args:
355
+ pretrained_qkvff_weight (torch.Tensor): The original fused QKVFF layer
356
+ new_qkvff_weight (torch.Tensor): The new fused QKVFF layer with larger linear_dim
357
+ pretrained_attn_size (int): The attention size of the pretrained fused QKVFF layer
358
+ pretrained_mlp_size (int): The mlp size of the pretrained fused QKVFF layer
359
+ new_attn_size (int): The attention size of the new fused QKVFF layer
360
+ new_mlp_size (int): The mlp size of the new fused QKVFF layer
361
+ is_glu (bool): Whether the QKVFF layer is a GLU layer
362
+ mode (Union[str, TileMode]): The tiling mode to use
363
+ Returns:
364
+ torch.Tensor: The new fused QKVFF layer with tiled weights
365
+ """
366
+ # Split QKVFF
367
+ pretrained_qkv, pretrained_ff = pretrained_qkvff_weight.split([pretrained_attn_size, pretrained_mlp_size], dim=0)
368
+ new_qkv, new_ff = new_qkvff_weight.split([new_attn_size, new_mlp_size], dim=0)
369
+
370
+ # Tile QKVFF separately
371
+ new_qkv = tile_fused_qkv(pretrained_qkv, new_qkv, mode=mode)
372
+ if is_glu:
373
+ new_ff = tile_fused_glu(pretrained_ff, new_ff, mode=mode)
374
+ else:
375
+ new_ff = tile_weight(pretrained_ff, new_ff, mode=mode)
376
+
377
+ # Concatenate tiled QKVFF
378
+ return torch.cat([new_qkv, new_ff], dim=0)
379
+
380
+
381
+ class TileLinear(StrEnum):
382
+ wqkv = "wqkv"
383
+ glu = "glu"
384
+ wqkvff = "wqkvff"
385
+ default = "default"
386
+
387
+
388
+ def tile_linear(
389
+ pretrained_linear: nn.Linear,
390
+ new_linear: nn.Linear,
391
+ linear_type: Union[str, TileLinear] = TileLinear.default,
392
+ mode: Union[str, TileMode] = TileMode.tile_weights_from_middle,
393
+ pretrained_attn_size: Optional[int] = None,
394
+ pretrained_mlp_size: Optional[int] = None,
395
+ new_attn_size: Optional[int] = None,
396
+ new_mlp_size: Optional[int] = None,
397
+ wqkvff_is_glu: Optional[bool] = None,
398
+ bias_only: Optional[bool] = False,
399
+ ):
400
+ """
401
+ Tile the weights of a linear layer to a new, larger linear dimension.
402
+
403
+ Args:
404
+ pretrained_linear (nn.Linear): The original linear layer
405
+ new_linear (nn.Linear): The new linear layer with larger linear_dim
406
+ linear_type (Union[str, TileLinear]): The type of linear layer to tile
407
+ mode (Union[str, TileMode]): The tiling mode to use
408
+ pretrained_attn_size (int): The attention size of the pretrained linear layer. Only used if linear_type is wqkvff.
409
+ pretrained_mlp_size (int): The mlp size of the pretrained linear layer. Only used if linear_type is wqkvff.
410
+ new_attn_size (int): The attention size of the new linear layer. Only used if linear_type is wqkvff.
411
+ new_mlp_size (int): The mlp size of the new linear layer. Only used if linear_type is wqkvff.
412
+ wqkvff_is_glu (bool): Whether the wqkvff layer is a GLU layer. Only used if linear_type is wqkvff.
413
+ bias_only (bool): Whether to only tile the bias. Only used if tiling weight tied decoder.
414
+ """
415
+ if isinstance(linear_type, str):
416
+ linear_type = TileLinear(linear_type)
417
+ if isinstance(mode, str):
418
+ mode = TileMode(mode)
419
+
420
+ with torch.no_grad():
421
+ if linear_type == TileLinear.wqkv:
422
+ if not bias_only:
423
+ new_linear.weight = nn.Parameter(
424
+ tile_fused_qkv(pretrained_linear.weight, new_linear.weight, mode=mode),
425
+ requires_grad=new_linear.weight.requires_grad,
426
+ )
427
+ if pretrained_linear.bias is not None:
428
+ new_linear.bias = nn.Parameter(
429
+ tile_fused_qkv(pretrained_linear.bias, new_linear.bias, mode=mode),
430
+ requires_grad=new_linear.bias.requires_grad,
431
+ )
432
+ elif linear_type == TileLinear.glu:
433
+ if not bias_only:
434
+ new_linear.weight = nn.Parameter(
435
+ tile_fused_glu(pretrained_linear.weight, new_linear.weight, mode=mode),
436
+ requires_grad=new_linear.weight.requires_grad,
437
+ )
438
+ if pretrained_linear.bias is not None:
439
+ new_linear.bias = nn.Parameter(
440
+ tile_fused_glu(pretrained_linear.bias, new_linear.bias, mode=mode),
441
+ requires_grad=new_linear.bias.requires_grad,
442
+ )
443
+ elif linear_type == TileLinear.wqkvff:
444
+ if not bias_only:
445
+ new_linear.weight = nn.Parameter(
446
+ tile_fused_qkvff(
447
+ pretrained_linear.weight,
448
+ new_linear.weight,
449
+ pretrained_attn_size,
450
+ pretrained_mlp_size,
451
+ new_attn_size,
452
+ new_mlp_size,
453
+ wqkvff_is_glu,
454
+ mode=mode,
455
+ ),
456
+ requires_grad=new_linear.weight.requires_grad,
457
+ )
458
+ if pretrained_linear.bias is not None:
459
+ new_linear.bias = nn.Parameter(
460
+ tile_fused_qkvff(
461
+ pretrained_linear.bias,
462
+ new_linear.bias,
463
+ pretrained_attn_size,
464
+ pretrained_mlp_size,
465
+ new_attn_size,
466
+ new_mlp_size,
467
+ wqkvff_is_glu,
468
+ mode=mode,
469
+ ),
470
+ requires_grad=new_linear.bias.requires_grad,
471
+ )
472
+ else:
473
+ if not bias_only:
474
+ new_linear.weight = nn.Parameter(
475
+ tile_weight(pretrained_linear.weight, new_linear.weight, mode=mode),
476
+ requires_grad=new_linear.weight.requires_grad,
477
+ )
478
+ if pretrained_linear.bias is not None:
479
+ new_linear.bias = nn.Parameter(
480
+ tile_weight(pretrained_linear.bias, new_linear.bias, mode=mode),
481
+ requires_grad=new_linear.bias.requires_grad,
482
+ )
483
+
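+ # Usage sketch (editor's addition, sizes are hypothetical): growing a pretrained
+ # 256 -> 256 projection into a 384 -> 384 layer in place via Phi-style tiling.
+ #
+ # >>> small, large = nn.Linear(256, 256), nn.Linear(384, 384)
+ # >>> tile_linear(small, large, linear_type=TileLinear.default, mode=TileMode.tile_weights_from_middle)
+ # >>> large.weight.shape
+ # torch.Size([384, 384])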
484
+
485
+ def tile_norm(
486
+ pretrained_norm: Union[nn.LayerNorm, RMSNorm, nn.Identity],
487
+ new_norm: Union[nn.LayerNorm, RMSNorm, nn.Identity],
488
+ mode: Union[str, TileMode] = TileMode.tile_weights_from_middle,
489
+ ):
490
+ """
491
+ Tile the weights of a pretrained norm layer to a new, larger layer norm dimension.
492
+
493
+ Args:
494
+ pretrained_norm (Union[nn.LayerNorm, RMSNorm, nn.Identity]): The original norm layer
495
+ new_norm (Union[nn.LayerNorm, RMSNorm, nn.Identity]): The new norm layer with larger layer norm dimension
496
+ mode (Union[str, TileMode]): The Phi-style weight tiling mode to use
497
+ """
498
+ if isinstance(pretrained_norm, nn.Identity):
499
+ return
500
+ if isinstance(mode, str):
501
+ mode = TileMode(mode)
502
+
503
+ with torch.no_grad():
504
+ new_norm.weight.data = nn.Parameter(
505
+ tile_weight(pretrained_norm.weight, new_norm.weight, mode=mode),
506
+ requires_grad=new_norm.weight.requires_grad,
507
+ )
508
+ if hasattr(pretrained_norm, "bias") and pretrained_norm.bias is not None:
509
+ new_norm.bias.data = nn.Parameter(
510
+ tile_weight(pretrained_norm.bias, new_norm.bias, mode=mode),
511
+ requires_grad=new_norm.bias.requires_grad,
512
+ )
513
+
514
+
515
+ def tile_embedding(
516
+ pretrained_embedding: nn.Embedding,
517
+ new_embedding: nn.Embedding,
518
+ mode: Union[str, TileMode] = TileMode.tile_weights_from_middle,
519
+ ) -> nn.Embedding:
520
+ """
521
+ Tile the weights of an embedding layer to a new, larger embedding dimension.
522
+
523
+ Args:
524
+ pretrained_embedding (nn.Embedding): The original embedding layer
525
+ new_embedding (nn.Embedding): The new embedding layer with larger embedding_dim
526
+ mode (Union[str, TileMode]): The Phi-style weight tiling mode to use
527
+
528
+ Returns:
529
+ nn.Embedding: The new embedding layer with tiled weights
530
+ """
531
+ with torch.no_grad():
532
+ # Ensure vocabulary size remains the same
533
+ if pretrained_embedding.num_embeddings != new_embedding.num_embeddings:
534
+ raise ValueError("Vocabulary size (num_embeddings) must remain constant")
535
+
536
+ # Ensure new embedding dimension is larger
537
+ if new_embedding.embedding_dim <= pretrained_embedding.embedding_dim:
538
+ raise ValueError("New embedding_dim must be larger than the old embedding_dim")
539
+
540
+ # Tile the weights
541
+ new_embedding.weight.data = nn.Parameter(
542
+ tile_weight(pretrained_embedding.weight, new_embedding.weight, mode=mode),
543
+ requires_grad=new_embedding.weight.requires_grad,
544
+ )
545
+
546
+ # Handle padding_idx if it exists
547
+ if pretrained_embedding.padding_idx is not None:
548
+ if new_embedding.padding_idx is None:
549
+ new_embedding.padding_idx = pretrained_embedding.padding_idx
550
+ else:
551
+ assert new_embedding.padding_idx == pretrained_embedding.padding_idx, "padding_idx must remain the same"
layers.py ADDED
@@ -0,0 +1,700 @@
1
+ # Copyright 2024 **AUTHORS_TODO**
2
+ # License: Apache-2.0
3
+
4
+ # Copyright 2022 MosaicML Examples authors
5
+ # SPDX-License-Identifier: Apache-2.0
6
+
7
+ # Copyright 2023 MosaicML Examples authors
8
+ # SPDX-License-Identifier: Apache-2.0
9
+
10
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
11
+ # Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.
12
+ # Copyright (c) 2023, Tri Dao.
13
+
14
+
15
+ import copy
16
+ import math
17
+ import warnings
18
+ from typing import Optional, Union, List
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+
23
+ from .bert_padding import unpad_input, pad_input
24
+
25
+ from .activation import get_act_fn
26
+ from .attention import FlexBertAttentionBase, BertAlibiUnpadAttention, get_attention_layer
27
+ from .mlp import FlexBertMLPBase, BertResidualGLU, get_mlp_layer
28
+ from .configuration_bert import FlexBertConfig, maybe_add_padding
29
+ from .normalization import get_norm_layer
30
+ from .initialization import ModuleType, init_weights
31
+
32
+
33
+ class BertAlibiLayer(nn.Module):
34
+ """Composes the Mosaic BERT attention and FFN blocks into a single layer."""
35
+
36
+ def __init__(self, config):
37
+ super().__init__()
38
+ self.attention = BertAlibiUnpadAttention(config)
39
+ self.mlp = BertResidualGLU(config)
40
+
41
+ def forward(
42
+ self,
43
+ hidden_states: torch.Tensor,
44
+ cu_seqlens: torch.Tensor,
45
+ seqlen: int,
46
+ subset_idx: Optional[torch.Tensor] = None,
47
+ indices: Optional[torch.Tensor] = None,
48
+ attn_mask: Optional[torch.Tensor] = None,
49
+ bias: Optional[torch.Tensor] = None,
50
+ slopes: Optional[torch.Tensor] = None,
51
+ ) -> torch.Tensor:
52
+ """Forward pass for a BERT layer, including both attention and MLP.
53
+
54
+ Args:
55
+ hidden_states: (total_nnz, dim)
56
+ cu_seqlens: (batch + 1,)
57
+ seqlen: int
58
+ subset_idx: () set of indices whose values we care about at the end of the layer
59
+ (e.g., the masked tokens, if this is the final layer).
60
+ indices: None or (total_nnz,)
61
+ attn_mask: None or (batch, max_seqlen_in_batch)
62
+ bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
63
+ slopes: None or (batch, heads) or (heads,)
64
+ """
65
+ assert (bias is None) == (slopes is None), f"{bias=}, {slopes=}"
66
+ attention_output = self.attention(
67
+ hidden_states, cu_seqlens, seqlen, subset_idx, indices, attn_mask, bias, slopes
68
+ )
69
+ layer_output = self.mlp(attention_output)
70
+ return layer_output
71
+
72
+
73
+ class BertAlibiEncoder(nn.Module):
74
+ """A stack of BERT layers providing the backbone of Mosaic BERT.
75
+
76
+ This module is modeled after the Hugging Face BERT's :class:`~transformers.model.bert.modeling_bert.BertAlibiEncoder`,
77
+ but with substantial modifications to implement unpadding and ALiBi.
78
+
79
+ Compared to the analogous Hugging Face BERT module, this module handles unpadding to reduce unnecessary computation
80
+ at padded tokens, and pre-computes attention biases to implement ALiBi.
81
+ """
82
+
83
+ def __init__(self, config):
84
+ super().__init__()
85
+ layer = BertAlibiLayer(config)
86
+ self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
87
+
88
+ self.num_attention_heads = config.num_attention_heads
89
+
90
+ # The alibi mask will be dynamically expanded if it is too small for
91
+ # the input the model receives. But it generally helps to initialize it
92
+ # to a reasonably large size to help pre-allocate CUDA memory.
93
+ # The default `alibi_starting_size` is 512.
94
+ self._current_alibi_size = int(config.alibi_starting_size)
95
+ self.alibi = torch.zeros((1, self.num_attention_heads, self._current_alibi_size, self._current_alibi_size))
96
+ self.rebuild_alibi_tensor(size=config.alibi_starting_size)
97
+
98
+ def rebuild_alibi_tensor(self, size: int, device: Optional[Union[torch.device, str]] = None):
99
+ # Alibi
100
+ # Following https://github.com/ofirpress/attention_with_linear_biases/issues/5 (Implementation 1)
101
+ # In the causal case, you can exploit the fact that softmax is invariant to a uniform translation
102
+ # of the logits, which makes the math work out *after* applying causal masking. If no causal masking
103
+ # will be applied, it is necessary to construct the diagonal mask.
104
+ n_heads = self.num_attention_heads
105
+
106
+ def _get_alibi_head_slopes(n_heads: int) -> List[float]:
107
+ def get_slopes_power_of_2(n_heads: int) -> List[float]:
108
+ start = 2 ** (-(2 ** -(math.log2(n_heads) - 3)))
109
+ ratio = start
110
+ return [start * ratio**i for i in range(n_heads)]
111
+
112
+ # In the paper, they only train models that have 2^a heads for some a. This function
113
+ # has some good properties that only occur when the input is a power of 2. To
114
+ # maintain that even when the number of heads is not a power of 2, we use a
115
+ # workaround.
116
+ if math.log2(n_heads).is_integer():
117
+ return get_slopes_power_of_2(n_heads)
118
+
119
+ closest_power_of_2 = 2 ** math.floor(math.log2(n_heads))
120
+ slopes_a = get_slopes_power_of_2(closest_power_of_2)
121
+ slopes_b = _get_alibi_head_slopes(2 * closest_power_of_2)
122
+ slopes_b = slopes_b[0::2][: n_heads - closest_power_of_2]
123
+ return slopes_a + slopes_b
124
+
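+ # Editor's note: for 8 heads this yields the geometric sequence [2^-1, 2^-2, ..., 2^-8];
+ # for a non-power-of-2 count such as 12 heads, the first 8 slopes come from the 8-head
+ # schedule and the remaining 4 are every other slope of the 16-head schedule
+ # (2^-0.5, 2^-1.5, 2^-2.5, 2^-3.5).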
125
+ context_position = torch.arange(size, device=device)[:, None]
126
+ memory_position = torch.arange(size, device=device)[None, :]
127
+ relative_position = torch.abs(memory_position - context_position)
128
+ # [n_heads, max_token_length, max_token_length]
129
+ relative_position = relative_position.unsqueeze(0).expand(n_heads, -1, -1)
130
+ slopes = torch.Tensor(_get_alibi_head_slopes(n_heads)).to(device)
131
+ self.slopes = slopes
132
+ alibi = slopes.unsqueeze(1).unsqueeze(1) * -relative_position
133
+ # [1, n_heads, max_token_length, max_token_length]
134
+ alibi = alibi.unsqueeze(0)
135
+ assert alibi.shape == torch.Size([1, n_heads, size, size])
136
+
137
+ self._current_alibi_size = size
138
+ self.alibi = alibi
139
+
140
+ def forward(
141
+ self,
142
+ hidden_states: torch.Tensor,
143
+ attention_mask: torch.Tensor,
144
+ output_all_encoded_layers: Optional[bool] = True,
145
+ subset_mask: Optional[torch.Tensor] = None,
146
+ ) -> List[torch.Tensor]:
147
+ extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
148
+ extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
149
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
150
+
151
+ attention_mask_bool = attention_mask.bool()
152
+ batch, seqlen = hidden_states.shape[:2]
153
+ # Unpad inputs and mask. It will remove tokens that are padded.
154
+ # Assume ntokens is total number of tokens (padded and non-padded)
155
+ # and ntokens_unpad is total number of non-padded tokens.
156
+ # Then unpadding performs the following compression of the inputs:
157
+ # hidden_states[ntokens,hidden] -> hidden_states[ntokens_unpad,hidden]
158
+ hidden_states, indices, cu_seqlens, _ = unpad_input(hidden_states, attention_mask_bool)
159
+
160
+ # Add alibi matrix to extended_attention_mask
161
+ if self._current_alibi_size < seqlen:
162
+ # Rebuild the alibi tensor when needed
163
+ warnings.warn(f"Increasing alibi size from {self._current_alibi_size} to {seqlen}")
164
+ self.rebuild_alibi_tensor(size=seqlen, device=hidden_states.device)
165
+ elif self.alibi.device != hidden_states.device:
166
+ # Device catch-up
167
+ self.alibi = self.alibi.to(hidden_states.device)
168
+ self.slopes = self.slopes.to(hidden_states.device) # type: ignore
169
+ alibi_bias = self.alibi[:, :, :seqlen, :seqlen]
170
+ attn_bias = extended_attention_mask[:, :, :seqlen, :seqlen]
171
+ alibi_attn_mask = attn_bias + alibi_bias
172
+
173
+ all_encoder_layers = []
174
+ if subset_mask is None:
175
+ for layer_module in self.layer:
176
+ hidden_states = layer_module(
177
+ hidden_states,
178
+ cu_seqlens,
179
+ seqlen,
180
+ None,
181
+ indices,
182
+ attn_mask=attention_mask,
183
+ bias=alibi_attn_mask,
184
+ slopes=self.slopes,
185
+ )
186
+ if output_all_encoded_layers:
187
+ all_encoder_layers.append(hidden_states)
188
+ # Pad inputs and mask. It will insert back zero-padded tokens.
189
+ # Assume ntokens is total number of tokens (padded and non-padded)
190
+ # and ntokens_unpad is total number of non-padded tokens.
191
+ # Then padding performs the following de-compression:
192
+ # hidden_states[ntokens_unpad,hidden] -> hidden_states[ntokens,hidden]
193
+ hidden_states = pad_input(hidden_states, indices, batch, seqlen)
194
+ else:
195
+ for i in range(len(self.layer) - 1):
196
+ layer_module = self.layer[i]
197
+ hidden_states = layer_module(
198
+ hidden_states,
199
+ cu_seqlens,
200
+ seqlen,
201
+ None,
202
+ indices,
203
+ attn_mask=attention_mask,
204
+ bias=alibi_attn_mask,
205
+ slopes=self.slopes,
206
+ )
207
+ if output_all_encoded_layers:
208
+ all_encoder_layers.append(hidden_states)
209
+ subset_idx = torch.nonzero(subset_mask[attention_mask_bool], as_tuple=False).flatten()
210
+ hidden_states = self.layer[-1](
211
+ hidden_states,
212
+ cu_seqlens,
213
+ seqlen,
214
+ subset_idx=subset_idx,
215
+ indices=indices,
216
+ attn_mask=attention_mask,
217
+ bias=alibi_attn_mask,
218
+ slopes=self.slopes,
219
+ )
220
+
221
+ if not output_all_encoded_layers:
222
+ all_encoder_layers.append(hidden_states)
223
+ return all_encoder_layers
224
+
225
+
226
+ class BertPooler(nn.Module):
227
+ def __init__(self, config):
228
+ super(BertPooler, self).__init__()
229
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
230
+ self.activation = nn.Tanh()
231
+
232
+ def forward(self, hidden_states: torch.Tensor, pool: Optional[bool] = True) -> torch.Tensor:
233
+ # We "pool" the model by simply taking the hidden state corresponding
234
+ # to the first token.
235
+ first_token_tensor = hidden_states[:, 0] if pool else hidden_states
236
+ pooled_output = self.dense(first_token_tensor)
237
+ pooled_output = self.activation(pooled_output)
238
+ return pooled_output
239
+
240
+
241
+ class BertPredictionHeadTransform(nn.Module):
242
+ def __init__(self, config):
243
+ super().__init__()
244
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
245
+ if isinstance(config.hidden_act, str):
246
+ self.transform_act_fn = get_act_fn(config.head_pred_act)
247
+ else:
248
+ self.transform_act_fn = config.hidden_act
249
+ self.LayerNorm = get_norm_layer(config)
250
+
251
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
252
+ hidden_states = self.dense(hidden_states)
253
+ hidden_states = self.transform_act_fn(hidden_states)
254
+ hidden_states = self.LayerNorm(hidden_states)
255
+ return hidden_states
256
+
257
+
258
+ class FlexBertLayerBase(nn.Module):
259
+ """A FlexBERT Layer base class for type hints."""
260
+
261
+ attn: FlexBertAttentionBase
262
+ mlp: FlexBertMLPBase
263
+
264
+ def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
265
+ super().__init__()
266
+ self.config = config
267
+ self.layer_id = layer_id
268
+
269
+ def _init_weights(self, reset_params: bool = False):
270
+ if hasattr(self, "attn"):
271
+ self.attn._init_weights(reset_params)
272
+ if hasattr(self, "mlp"):
273
+ self.mlp._init_weights(reset_params)
274
+
275
+ def reset_parameters(self):
276
+ self._init_weights(reset_params=True)
277
+
278
+ def forward(self, hidden_states: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
279
+ raise NotImplementedError("This is a base class and should not be used directly.")
280
+
281
+
282
+ class FlexBertCompileUnpadPreNormLayer(FlexBertLayerBase):
283
+ """Composes the FlexBERT attention and MLP blocks into a single layer using pre-normalization."""
284
+
285
+ def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
286
+ super().__init__(config=config, layer_id=layer_id)
287
+ if config.skip_first_prenorm and config.embed_norm and layer_id == 0:
288
+ self.attn_norm = nn.Identity()
289
+ else:
290
+ self.attn_norm = get_norm_layer(config)
291
+ self.attn = get_attention_layer(config, layer_id=layer_id)
292
+ self.mlp_norm = get_norm_layer(config, compiled_norm=config.compile_model)
293
+ self.mlp = get_mlp_layer(config, layer_id=layer_id)
294
+ self.compile_model = config.compile_model
295
+
296
+ def _init_weights(self, reset_params: bool = False):
297
+ super()._init_weights(reset_params)
298
+ if reset_params:
299
+ self.attn_norm.reset_parameters()
300
+ self.mlp_norm.reset_parameters()
301
+
302
+ @torch.compile(dynamic=True)
303
+ def compiled_mlp(self, hidden_states: torch.Tensor) -> torch.Tensor:
304
+ return self.mlp(self.mlp_norm(hidden_states))
305
+
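+ # Editor's note: dynamic=True is used here because the unpadded token count (total_nnz)
+ # varies from batch to batch, so compiling with static shapes would trigger frequent
+ # recompilation.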
306
+ def forward(
307
+ self,
308
+ hidden_states: torch.Tensor,
309
+ cu_seqlens: torch.Tensor,
310
+ max_seqlen: int,
311
+ indices: Optional[torch.Tensor] = None,
312
+ attn_mask: Optional[torch.Tensor] = None,
313
+ ) -> torch.Tensor:
314
+ """Forward pass for a BERT layer, including both attention and MLP.
315
+
316
+ Args:
317
+ hidden_states: (total_nnz, dim)
318
+ cu_seqlens: (batch + 1,)
319
+ max_seqlen: int
320
+ indices: None or (total_nnz,)
321
+ attn_mask: None or (batch, max_seqlen)
322
+ """
323
+ attn_out = hidden_states + self.attn(self.attn_norm(hidden_states), cu_seqlens, max_seqlen, indices, attn_mask)
324
+ return attn_out + self.compiled_mlp(attn_out)
325
+
326
+
327
+ class FlexBertUnpadPreNormLayer(FlexBertLayerBase):
328
+ """Composes the FlexBERT attention and MLP blocks into a single layer using pre-normalization."""
329
+
330
+ def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
331
+ super().__init__(config=config, layer_id=layer_id)
332
+ if config.skip_first_prenorm and config.embed_norm and layer_id == 0:
333
+ self.attn_norm = nn.Identity()
334
+ else:
335
+ self.attn_norm = get_norm_layer(config)
336
+ self.attn = get_attention_layer(config, layer_id=layer_id)
337
+ self.mlp_norm = get_norm_layer(config)
338
+ self.mlp = get_mlp_layer(config, layer_id=layer_id)
339
+
340
+ def _init_weights(self, reset_params: bool = False):
341
+ super()._init_weights(reset_params)
342
+ if reset_params:
343
+ self.attn_norm.reset_parameters()
344
+ self.mlp_norm.reset_parameters()
345
+
346
+ def forward(
347
+ self,
348
+ hidden_states: torch.Tensor,
349
+ cu_seqlens: torch.Tensor,
350
+ max_seqlen: int,
351
+ indices: Optional[torch.Tensor] = None,
352
+ attn_mask: Optional[torch.Tensor] = None,
353
+ ) -> torch.Tensor:
354
+ """Forward pass for a BERT layer, including both attention and MLP.
355
+
356
+ Args:
357
+ hidden_states: (total_nnz, dim)
358
+ cu_seqlens: (batch + 1,)
359
+ max_seqlen: int
360
+ indices: None or (total_nnz,)
361
+ attn_mask: None or (batch, max_seqlen)
362
+ """
363
+ attn_out = hidden_states + self.attn(self.attn_norm(hidden_states), cu_seqlens, max_seqlen, indices, attn_mask)
364
+ return attn_out + self.mlp(self.mlp_norm(attn_out))
365
+
366
+
367
+ class FlexBertUnpadParallelPreNormLayer(FlexBertLayerBase):
368
+ """Composes the FlexBERT parallel attention and MLP blocks into a single layer using pre-normalization."""
369
+
370
+ def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
371
+ super().__init__(config=config, layer_id=layer_id)
372
+ self.attn_size = config.hidden_size * 3
373
+ self.mlp_size = config.intermediate_size * 2
374
+ # Compute QKV and FF outputs at once
375
+ self.Wqkvff = nn.Linear(config.hidden_size, self.attn_size + self.mlp_size, bias=config.attn_qkv_bias)
376
+ if config.skip_first_prenorm and config.embed_norm and layer_id == 0:
377
+ self.norm = nn.Identity()
378
+ else:
379
+ self.norm = get_norm_layer(config)
380
+ self.attn = get_attention_layer(config, layer_id=layer_id)
381
+ self.mlp = get_mlp_layer(config, layer_id=layer_id)
382
+
383
+ def _init_weights(self, reset_params: bool = False):
384
+ super()._init_weights(reset_params)
385
+ if reset_params and hasattr(self.norm, "reset_parameters"):
386
+ self.norm.reset_parameters()
387
+
388
+ init_weights(
389
+ self.config,
390
+ self.Wqkvff,
391
+ layer_dim=self.config.hidden_size,
392
+ layer_id=None,
393
+ type_of_module=ModuleType.in_module,
394
+ )
395
+
396
+ def forward(
397
+ self,
398
+ hidden_states: torch.Tensor,
399
+ cu_seqlens: torch.Tensor,
400
+ max_seqlen: int,
401
+ indices: Optional[torch.Tensor] = None,
402
+ attn_mask: Optional[torch.Tensor] = None,
403
+ ) -> torch.Tensor:
404
+ """Forward pass for a BERT layer, including both attention and MLP.
405
+
406
+ Args:
407
+ hidden_states: (total_nnz, dim)
408
+ attn_mask: None or (batch, max_seqlen)
409
+ """
410
+ # Compute QKV and FF outputs at once and split them
411
+ qkv, intermediate_ff = self.Wqkvff(self.norm(hidden_states)).split([self.attn_size, self.mlp_size], dim=1)
412
+ return hidden_states + self.attn(qkv, cu_seqlens, max_seqlen, indices, attn_mask) + self.mlp(intermediate_ff)
413
+
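+ # Shape sketch (editor's note, sizes hypothetical): with hidden_size=768 and
+ # intermediate_size=1152, Wqkvff maps 768 -> 3*768 + 2*1152 = 4608 features, split into
+ # a 2304-wide fused QKV consumed by the attention block and a 2304-wide gated input
+ # consumed by the parallel GLU MLP.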
414
+
415
+ class FlexBertPaddedPreNormLayer(FlexBertLayerBase):
416
+ """Composes the FlexBERT attention and MLP blocks into a single layer using pre-normalization."""
417
+
418
+ def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
419
+ super().__init__(config=config, layer_id=layer_id)
420
+ if config.skip_first_prenorm and config.embed_norm and layer_id == 0:
421
+ self.attn_norm = nn.Identity()
422
+ else:
423
+ self.attn_norm = get_norm_layer(config)
424
+ self.attn = get_attention_layer(config, layer_id=layer_id)
425
+ self.mlp_norm = get_norm_layer(config)
426
+ self.mlp = get_mlp_layer(config, layer_id=layer_id)
427
+
428
+ def _init_weights(self, reset_params: bool = False):
429
+ super()._init_weights(reset_params)
430
+ if reset_params:
431
+ self.attn_norm.reset_parameters()
432
+ self.mlp_norm.reset_parameters()
433
+
434
+ def forward(
435
+ self,
436
+ hidden_states: torch.Tensor,
437
+ attn_mask: Optional[torch.Tensor] = None,
438
+ ) -> torch.Tensor:
439
+ """Forward pass for a BERT layer, including both attention and MLP.
440
+
441
+ Args:
442
+ hidden_states: (batch, max_seqlen, dim)
443
+ attn_mask: None or (batch, max_seqlen)
444
+ """
445
+ attn_out = hidden_states + self.attn(self.attn_norm(hidden_states), attn_mask)
446
+ return attn_out + self.mlp(self.mlp_norm(attn_out))
447
+
448
+
449
+ class FlexBertPaddedParallelPreNormLayer(FlexBertLayerBase):
450
+ """Composes the FlexBERT attention and MLP blocks into a single layer using pre-normalization."""
451
+
452
+ def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
453
+ super().__init__(config=config, layer_id=layer_id)
454
+ self.attn_size = config.hidden_size * 3
455
+ self.mlp_size = config.intermediate_size * 2
456
+ # Compute QKV and FF outputs at once
457
+ self.Wqkvff = nn.Linear(config.hidden_size, self.attn_size + self.mlp_size, bias=config.attn_qkv_bias)
458
+ if config.skip_first_prenorm and config.embed_norm and layer_id == 0:
459
+ self.norm = nn.Identity()
460
+ else:
461
+ self.norm = get_norm_layer(config)
462
+ self.attn = get_attention_layer(config, layer_id=layer_id)
463
+ self.mlp = get_mlp_layer(config, layer_id=layer_id)
464
+
465
+ def _init_weights(self, reset_params: bool = False):
466
+ super()._init_weights(reset_params)
467
+ if reset_params and hasattr(self.norm, "reset_parameters"):
468
+ self.norm.reset_parameters()
469
+
470
+ init_weights(
471
+ self.config,
472
+ self.Wqkvff,
473
+ layer_dim=self.config.hidden_size,
474
+ layer_id=None,
475
+ type_of_module=ModuleType.in_module,
476
+ )
477
+
478
+ def forward(
479
+ self,
480
+ hidden_states: torch.Tensor,
481
+ attn_mask: Optional[torch.Tensor] = None,
482
+ ) -> torch.Tensor:
483
+ """Forward pass for a BERT layer, including both attention and MLP.
484
+
485
+ Args:
486
+ hidden_states: (batch, max_seqlen, dim)
487
+ attn_mask: None or (batch, max_seqlen)
488
+ """
489
+ # Compute QKV and FF outputs at once and split them
490
+ qkv, intermediate_ff = self.Wqkvff(self.norm(hidden_states)).split([self.attn_size, self.mlp_size], dim=2)
491
+ return hidden_states + self.attn(qkv, attn_mask) + self.mlp(intermediate_ff)
492
+
493
+
494
+ class FlexBertUnpadPostNormLayer(FlexBertLayerBase):
495
+ """Composes the FlexBERT attention and MLP blocks into a single layer using post-normalization."""
496
+
497
+ def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
498
+ super().__init__(config=config, layer_id=layer_id)
499
+ self.attn = get_attention_layer(config, layer_id=layer_id)
500
+ self.attn_norm = get_norm_layer(config)
501
+ self.mlp = get_mlp_layer(config, layer_id=layer_id)
502
+ self.mlp_norm = get_norm_layer(config)
503
+
504
+ def _init_weights(self, reset_params: bool = False):
505
+ super()._init_weights(reset_params)
506
+ if reset_params:
507
+ self.attn_norm.reset_parameters()
508
+ self.mlp_norm.reset_parameters()
509
+
510
+ def forward(
511
+ self,
512
+ hidden_states: torch.Tensor,
513
+ cu_seqlens: torch.Tensor,
514
+ max_seqlen: int,
515
+ indices: Optional[torch.Tensor] = None,
516
+ attn_mask: Optional[torch.Tensor] = None,
517
+ ) -> torch.Tensor:
518
+ """Forward pass for a BERT layer, including both attention and MLP.
519
+
520
+ Args:
521
+ hidden_states: (total_nnz, dim)
522
+ cu_seqlens: (batch + 1,)
523
+ max_seqlen: int
524
+ indices: None or (total_nnz,)
525
+ attn_mask: None or (batch, max_seqlen)
526
+ """
527
+ attn_out = self.attn_norm(hidden_states + self.attn(hidden_states, cu_seqlens, max_seqlen, indices, attn_mask))
528
+ return self.mlp_norm(attn_out + self.mlp(attn_out))
529
+
530
+
531
+ class FlexBertPaddedPostNormLayer(FlexBertLayerBase):
532
+ """Composes the FlexBERT attention and MLP blocks into a single layer using post-normalization."""
533
+
534
+ def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
535
+ super().__init__(config=config, layer_id=layer_id)
536
+ self.attn = get_attention_layer(config, layer_id=layer_id)
537
+ self.attn_norm = get_norm_layer(config)
538
+ self.mlp = get_mlp_layer(config, layer_id=layer_id)
539
+ self.mlp_norm = get_norm_layer(config)
540
+
541
+ def _init_weights(self, reset_params: bool = False):
542
+ super()._init_weights(reset_params)
543
+ if reset_params:
544
+ self.attn_norm.reset_parameters()
+ self.mlp_norm.reset_parameters()
545
+
546
+ def forward(
547
+ self,
548
+ hidden_states: torch.Tensor,
549
+ attn_mask: Optional[torch.Tensor] = None,
550
+ ) -> torch.Tensor:
551
+ """Forward pass for a BERT layer, including both attention and MLP.
552
+
553
+ Args:
554
+ hidden_states: (batch, max_seqlen, dim)
555
+ attn_mask: None or (batch, max_seqlen)
556
+ """
557
+ attn_out = self.attn_norm(hidden_states + self.attn(hidden_states, attn_mask))
558
+ return self.mlp_norm(attn_out + self.mlp(attn_out))
559
+
560
+
561
+ LAYER2CLS = {
562
+ "unpadded_prenorm": FlexBertUnpadPreNormLayer,
563
+ "unpadded_compile_prenorm": FlexBertCompileUnpadPreNormLayer,
564
+ "unpadded_parallel_prenorm": FlexBertUnpadParallelPreNormLayer,
565
+ "unpadded_postnorm": FlexBertUnpadPostNormLayer,
566
+ "padded_prenorm": FlexBertPaddedPreNormLayer,
567
+ "padded_parallel_prenorm": FlexBertPaddedParallelPreNormLayer,
568
+ "padded_postnorm": FlexBertPaddedPostNormLayer,
569
+ }
570
+
571
+
572
+ def get_bert_layer(config: FlexBertConfig, layer_id: Optional[int] = None) -> FlexBertLayerBase:
573
+ try:
574
+ bert_layer = (
575
+ config.initial_bert_layer
576
+ if layer_id < config.num_initial_layers and getattr(config, "initial_bert_layer", None) is not None
577
+ else config.bert_layer
578
+ )
579
+ bert_layer = maybe_add_padding(config, bert_layer)
580
+ if config.compile_model and bert_layer == "unpadded_prenorm":
581
+ bert_layer = "unpadded_compile_prenorm"
582
+ return LAYER2CLS[bert_layer](config, layer_id=layer_id)
583
+ except KeyError:
584
+ if layer_id < config.num_initial_layers and getattr(config, "initial_bert_layer", None) is not None:
585
+ raise ValueError(
586
+ f"Invalid BERT layer type: {config.initial_bert_layer=}, must be one of {LAYER2CLS.keys()}."
587
+ f"{config.padding=} will be automatically prepended to `config.bert_layer` if unspecified."
588
+ )
589
+ else:
590
+ raise ValueError(
591
+ f"Invalid BERT layer type: {config.bert_layer=}, must be one of {LAYER2CLS.keys()}. "
592
+ f"{config.padding=} will be automatically prepended to `config.bert_layer` if unspecified."
593
+ )
594
+
595
+
596
+ class FlexBertEncoderBase(nn.Module):
597
+ """A FlexBERT base class for type hints."""
598
+
599
+ layers: nn.ModuleList
600
+
601
+ def _init_weights(self, reset_params: bool = False):
602
+ if hasattr(self, "layers"):
603
+ for layer in self.layers:
604
+ layer._init_weights(reset_params=reset_params)
605
+
606
+ def reset_parameters(self):
607
+ self._init_weights(reset_params=True)
608
+
609
+ def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
610
+ raise NotImplementedError("This is a base class and should not be used directly.")
611
+
612
+
613
+ class FlexBertUnpadEncoder(FlexBertEncoderBase):
614
+ """A stack of BERT layers providing the backbone of FlexBERT.
615
+
616
+ This module is modeled after the Hugging Face BERT's :class:`~transformers.model.bert.modeling_bert.BertAlibiEncoder`,
617
+ but with substantial modifications to implement unpadding and ALiBi.
618
+
619
+ Compared to the analogous Hugging Face BERT module, this module handles unpadding to reduce unnecessary computation
621
+ at padded tokens.
621
+ """
622
+
623
+ def __init__(self, config: FlexBertConfig):
624
+ super().__init__()
625
+ self.layers = nn.ModuleList([get_bert_layer(config, layer_id=i) for i in range(config.num_hidden_layers)])
626
+ self.num_attention_heads = config.num_attention_heads
627
+
628
+ def forward(
629
+ self,
630
+ hidden_states: torch.Tensor,
631
+ attention_mask: torch.Tensor,
632
+ indices: Optional[torch.Tensor] = None,
633
+ cu_seqlens: Optional[torch.Tensor] = None,
634
+ max_seqlen: Optional[int] = None,
635
+ ) -> torch.Tensor:
636
+ if indices is None and cu_seqlens is None and max_seqlen is None:
637
+ attention_mask_bool = attention_mask.bool()
638
+ batch, seqlen = hidden_states.shape[:2]
639
+ hidden_states, indices, cu_seqlens, max_seqlen = unpad_input(
640
+ hidden_states, attention_mask_bool
641
+ )
642
+
643
+ for layer_module in self.layers:
644
+ hidden_states = layer_module(
645
+ hidden_states,
646
+ cu_seqlens,
647
+ max_seqlen,
648
+ indices,
649
+ attn_mask=attention_mask,
650
+ )
651
+
652
+ return pad_input(hidden_states, indices, batch, seqlen)
653
+ else:
654
+ for layer_module in self.layers:
655
+ hidden_states = layer_module(
656
+ hidden_states,
657
+ cu_seqlens,
658
+ max_seqlen,
659
+ indices,
660
+ attn_mask=attention_mask,
661
+ )
662
+ return hidden_states
663
+
664
+
665
+ class FlexBertPaddedEncoder(FlexBertEncoderBase):
666
+ """A stack of BERT layers providing the backbone of FlexBERT.
667
+
668
+ This module is modeled after the Hugging Face BERT's :class:`~transformers.model.bert.modeling_bert.BertAlibiEncoder`,
669
+ but operating on padded inputs rather than unpadded ones.
670
+
671
+ Unlike :class:`FlexBertUnpadEncoder`, this module keeps the padded (batch, max_seqlen, dim) layout and relies on the
672
+ attention mask passed to each layer instead of unpadding the inputs.
673
+ """
674
+
675
+ def __init__(self, config: FlexBertConfig):
676
+ super().__init__()
677
+ self.layers = nn.ModuleList([get_bert_layer(config, layer_id=i) for i in range(config.num_hidden_layers)])
678
+ self.num_attention_heads = config.num_attention_heads
679
+
680
+ def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> torch.Tensor:
681
+ for layer_module in self.layers:
682
+ hidden_states = layer_module(hidden_states, attn_mask=attention_mask)
683
+
684
+ return hidden_states
685
+
686
+
687
+ ENC2CLS = {
688
+ "unpadded_base": FlexBertUnpadEncoder,
689
+ "padded_base": FlexBertPaddedEncoder,
690
+ }
691
+
692
+
693
+ def get_encoder_layer(config: FlexBertConfig) -> FlexBertEncoderBase:
694
+ try:
695
+ return ENC2CLS[maybe_add_padding(config, config.encoder_layer)](config)
696
+ except KeyError:
697
+ raise ValueError(
698
+ f"Invalid encoder layer type: {config.encoder_layer=}, must be one of {ENC2CLS.keys()}. "
699
+ f"{config.padding=} will be automatically prepended to `config.encoder_layer` if unspecified."
700
+ )
mlp.py ADDED
@@ -0,0 +1,214 @@
1
+ # Copyright 2024 **AUTHORS_TODO**
2
+ # License: Apache-2.0
3
+
4
+ # Copyright 2022 MosaicML Examples authors
5
+ # SPDX-License-Identifier: Apache-2.0
6
+
7
+ # Copyright 2023 MosaicML Examples authors
8
+ # SPDX-License-Identifier: Apache-2.0
9
+
10
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
11
+ # Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.
12
+ # Copyright (c) 2023, Tri Dao.
13
+
14
+ from typing import Optional
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+
19
+ from .configuration_bert import FlexBertConfig
20
+ from .activation import get_act_fn
21
+ from .normalization import get_norm_layer
22
+ from .initialization import ModuleType, init_weights
23
+
24
+
25
+ class BertResidualGLU(nn.Module):
26
+ """Applies the FFN at the end of each Mosaic BERT layer.
27
+
28
+ Compared to the default BERT architecture, this block replaces :class:`~transformers.model.bert.modeling_bert.BertIntermediate`
29
+ and :class:`~transformers.model.bert.modeling_bert.SelfOutput` with a single module that has similar functionality, but
30
+ introduces Gated Linear Units.
31
+
32
+ Note: Mosaic BERT adds parameters in order to implement Gated Linear Units. To keep parameter count consistent with that of a
33
+ standard Hugging Face BERT, scale down `config.intermediate_size` by 2/3. For example, a Mosaic BERT constructed with
34
+ `config.intermediate_size=2048` will have the same parameter footprint as its Hugging Face BERT counterpart constructed
35
+ with the `config.intermediate_size=3072`.
36
+ However, in most cases it will not be necessary to adjust `config.intermediate_size` since, despite the increased
37
+ parameter size, Mosaic BERT typically offers a net higher throughput than a Hugging Face BERT built from the same `config`.
38
+ """
39
+
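+ # Editor's note on the 2/3 rule above (hypothetical BERT-base sizes): a standard FFN uses
+ # roughly 2 * hidden * intermediate parameters (768 * 3072 in and out, ~4.7M), while this
+ # gated block uses roughly 3 * hidden * intermediate; shrinking intermediate_size to
+ # 2048 (= 3072 * 2/3) keeps the total at ~4.7M.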
40
+ def __init__(
41
+ self,
42
+ config,
43
+ ):
44
+ super().__init__()
45
+ self.config = config
46
+ self.gated_layers = nn.Linear(config.hidden_size, config.intermediate_size * 2, bias=False)
47
+ self.act = get_act_fn(config.hidden_act)
48
+ self.wo = nn.Linear(config.intermediate_size, config.hidden_size)
49
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
50
+ self.layernorm = get_norm_layer(config)
51
+
52
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
53
+ """Compute new hidden states from current hidden states.
54
+
55
+ Args:
56
+ hidden_states (torch.Tensor): The (unpadded) hidden states from
57
+ the attention layer [nnz, dim].
58
+ """
59
+ residual_connection = hidden_states
60
+ # compute the activation
61
+ hidden_states = self.gated_layers(hidden_states)
62
+ gated = hidden_states[:, : self.config.intermediate_size]
63
+ non_gated = hidden_states[:, self.config.intermediate_size :]
64
+ hidden_states = self.act(gated) * non_gated
65
+ hidden_states = self.dropout(hidden_states)
66
+ # multiply by the second matrix
67
+ hidden_states = self.wo(hidden_states)
68
+ # add the residual connection and post-LN
69
+ hidden_states = self.layernorm(hidden_states + residual_connection)
70
+ return hidden_states
71
+
72
+
73
+ class FlexBertMLPBase(nn.Module):
74
+ """A FlexBERT MLP base class for type hints."""
75
+
76
+ def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
77
+ super().__init__()
78
+ self.config = config
79
+ self.layer_id = layer_id
80
+
81
+ def _init_weights(self, reset_params: bool = False):
82
+ raise NotImplementedError("This is a base class and should not be used directly.")
83
+
84
+ def reset_parameters(self):
85
+ self._init_weights(reset_params=True)
86
+
87
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
88
+ raise NotImplementedError("This is a base class and should not be used directly.")
89
+
90
+
91
+ class FlexBertMLP(FlexBertMLPBase):
92
+ """Applies the MLP at the end of each FlexBERT layer.
93
+
94
+ Compared to the default BERT architecture, this block replaces :class:`~transformers.model.bert.modeling_bert.BertIntermediate`
95
+ and :class:`~transformers.model.bert.modeling_bert.SelfOutput` with a single module that has similar functionality.
96
+ """
97
+
98
+ def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
99
+ super().__init__(config=config, layer_id=layer_id)
100
+ self.Wi = nn.Linear(config.hidden_size, config.intermediate_size, bias=config.mlp_in_bias)
101
+ self.act = get_act_fn(config.hidden_act)
102
+ self.drop = nn.Dropout(config.mlp_dropout_prob) if config.mlp_dropout_prob > 0.0 else nn.Identity()
103
+ self.Wo = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.mlp_out_bias)
104
+
105
+ def _init_weights(self, reset_params: bool = False):
106
+ init_weights(
107
+ self.config,
108
+ self.Wi,
109
+ layer_dim=self.config.hidden_size,
110
+ layer_id=None,
111
+ type_of_module=ModuleType.in_module,
112
+ )
113
+ init_weights(
114
+ self.config,
115
+ self.Wo,
116
+ layer_dim=self.config.intermediate_size,
117
+ layer_id=self.layer_id,
118
+ type_of_module=ModuleType.out_module,
119
+ )
120
+
121
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
122
+ """Compute new hidden states from current hidden states.
123
+
124
+ Args:
125
+ hidden_states (torch.Tensor): The (unpadded) hidden states from
126
+ the attention layer [nnz, dim].
127
+ """
128
+ return self.Wo(self.drop(self.act(self.Wi(hidden_states))))
129
+
130
+
131
+ class FlexBertGLU(FlexBertMLPBase):
132
+ """Applies the GLU at the end of each FlexBERT layer.
133
+
134
+ Compared to the default BERT architecture, this block replaces :class:`~transformers.model.bert.modeling_bert.BertIntermediate`
135
+ and :class:`~transformers.model.bert.modeling_bert.SelfOutput` with a single module that has similar functionality.
136
+ """
137
+
138
+ def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
139
+ super().__init__(config=config, layer_id=layer_id)
140
+ self.Wi = nn.Linear(config.hidden_size, int(config.intermediate_size) * 2, bias=config.mlp_in_bias)
141
+ self.act = get_act_fn(config.hidden_act)
142
+ self.drop = nn.Dropout(config.mlp_dropout_prob) if config.mlp_dropout_prob > 0.0 else nn.Identity()
143
+ self.Wo = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.mlp_out_bias)
144
+
145
+ def _init_weights(self, reset_params: bool = False):
146
+ init_weights(
147
+ self.config,
148
+ self.Wi,
149
+ layer_dim=self.config.hidden_size,
150
+ layer_id=None,
151
+ type_of_module=ModuleType.in_module,
152
+ )
153
+ init_weights(
154
+ self.config,
155
+ self.Wo,
156
+ layer_dim=self.config.intermediate_size,
157
+ layer_id=self.layer_id,
158
+ type_of_module=ModuleType.out_module,
159
+ )
160
+
161
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
162
+ input, gate = self.Wi(hidden_states).chunk(2, dim=-1)
163
+ return self.Wo(self.drop(self.act(input) * gate))
164
+
165
+
166
+ class FlexBertParallelGLU(FlexBertMLPBase):
167
+ """Applies the GLU at the end of each FlexBERT layer using intermediate_ff computed in parallel of the attention.
168
+
169
+ Compared to the default BERT architecture, this block replaces :class:`~transformers.model.bert.modeling_bert.BertIntermediate`
170
+ and :class:`~transformers.model.bert.modeling_bert.SelfOutput` with a single module that has similar functionality.
171
+ """
172
+
173
+ def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
174
+ super().__init__(config=config, layer_id=layer_id)
175
+ self.act = get_act_fn(config.hidden_act)
176
+ self.drop = nn.Dropout(config.mlp_dropout_prob) if config.mlp_dropout_prob > 0.0 else nn.Identity()
177
+ self.Wo = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.mlp_out_bias)
178
+
179
+ def _init_weights(self, reset_params: bool = False):
180
+ init_weights(
181
+ self.config,
182
+ self.Wo,
183
+ layer_dim=self.config.intermediate_size,
184
+ layer_id=self.layer_id,
185
+ type_of_module=ModuleType.out_module,
186
+ )
187
+
188
+ def forward(self, intermediate_ff: torch.Tensor) -> torch.Tensor:
189
+ input, gate = intermediate_ff.chunk(2, dim=-1)
190
+ return self.Wo(self.drop(self.act(input) * gate))
191
+
192
+
193
+ MLP2CLS = {
194
+ "mlp": FlexBertMLP,
195
+ "glu": FlexBertGLU,
196
+ "parallel_glu": FlexBertParallelGLU,
197
+ }
198
+
199
+
200
+ def get_mlp_layer(config: FlexBertConfig, layer_id: Optional[int] = None) -> FlexBertMLPBase:
201
+ try:
202
+ mlp_layer = (
203
+ config.initial_mlp_layer
204
+ if layer_id < config.num_initial_layers and getattr(config, "initial_mlp_layer", None) is not None
205
+ else config.mlp_layer
206
+ )
207
+ return MLP2CLS[mlp_layer](config, layer_id=layer_id)
208
+ except KeyError as e:
209
+ if layer_id < config.num_initial_layers and getattr(config, "initial_mlp_layer", None) is not None:
210
+ raise ValueError(
211
+ f"Invalid MLP layer type: {config.initial_mlp_layer=}, must be one of {MLP2CLS.keys()}. {e}"
212
+ )
213
+ else:
214
+ raise ValueError(f"Invalid MLP layer type: {config.mlp_layer=}, must be one of {MLP2CLS.keys()}. {e}")
modeling_flexbert.py ADDED
@@ -0,0 +1,1920 @@
1
+ # Copyright 2024 **AUTHORS_TODO**
2
+ # License: Apache-2.0
3
+
4
+ # RMSNorm Implementation: Copyright Meta (from their Llama RMSNorm implementation)
5
+ # License: LLAMA 2 COMMUNITY LICENSE AGREEMENT
6
+
7
+ # Copyright 2022 Jonas Geiping
8
+ # License: MIT
9
+
10
+ # Copyright 2022 MosaicML Examples authors
11
+ # SPDX-License-Identifier: Apache-2.0
12
+
13
+ # Copyright 2023 MosaicML Examples authors
14
+ # SPDX-License-Identifier: Apache-2.0
15
+
16
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
17
+ # Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.
18
+ # Copyright (c) 2023, Tri Dao.
19
+
20
+ """Implements Mosaic BERT, with an eye towards the Hugging Face API.
21
+
22
+ Mosaic BERT improves performance over Hugging Face BERT through the following:
23
+
24
+ 1. ALiBi. This architectural change removes positional embeddings and instead encodes positional
25
+ information through attention biases based on query-key position distance. It improves the effectiveness
26
+ of training with shorter sequence lengths by enabling extrapolation to longer sequences.
27
+
28
+ 2. Gated Linear Units (GLU). This architectural change replaces the FFN component of the BERT layer
29
+ to improve overall expressiveness, providing better convergence properties.
30
+
31
+ 3. Flash Attention. MosaicBERT's self-attention layer makes use of Flash Attention, which dramatically
32
+ improves the speed of self-attention. Our implementation uses a bleeding-edge version of Flash Attention that
33
+ supports attention biases, which allows us to use Flash Attention with ALiBi.
34
+
35
+ 4. Unpadding. Padding is often used to simplify batching across sequences of different lengths. Standard BERT
36
+ implementations waste computation on padded tokens. MosaicBERT internally unpads to reduce unnecessary computation
37
+ and improve speed. It does this without changing how the user interfaces with the model, thereby
38
+ preserving the simple API of standard implementations.
39
+
40
+
41
+ Currently, MosaicBERT is available for masked language modeling :class:`BertForMaskedLM` and sequence
42
+ classification :class:`BertForSequenceClassification`. We aim to expand this catalogue in future releases.
43
+
44
+ See :file:`./mosaic_bert.py` for utilities to simplify working with MosaicBERT in Composer, and for example usage
45
+ of the core Mosaic BERT classes.
46
+ """
47
+
48
+ import logging
49
+ import os
50
+ import sys
51
+ import warnings
52
+ from dataclasses import dataclass
53
+ from typing import List, Optional, Tuple, Union
54
+
55
+ # Add folder root to path to allow us to use relative imports regardless of what directory the script is run from
56
+ sys.path.append(os.path.dirname(os.path.realpath(__file__)))
57
+
58
+ import torch
59
+ import torch.nn as nn
60
+ from einops import rearrange
61
+ from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
62
+ from transformers.modeling_outputs import (
63
+ MaskedLMOutput,
64
+ ModelOutput,
65
+ MultipleChoiceModelOutput,
66
+ SequenceClassifierOutput,
67
+ CausalLMOutput,
68
+ )
69
+ from transformers.models.bert.modeling_bert import BertPreTrainedModel
70
+
71
+ from bert_padding import index_put_first_axis
72
+
73
+ from src.bert_layers.activation import get_act_fn
74
+ from src.bert_layers.attention import (
75
+ FlexBertPaddedAttention,
76
+ FlexBertPaddedParallelAttention,
77
+ FlexBertPaddedRopeAttention,
78
+ FlexBertPaddedRopeParallelAttention,
79
+ FlexBertUnpadAttention,
80
+ FlexBertUnpadParallelAttention,
81
+ FlexBertUnpadRopeAttention,
82
+ FlexBertUnpadRopeParallelAttention,
83
+ )
84
+ from src.bert_layers.configuration_bert import FlexBertConfig
85
+ from src.bert_layers.embeddings import (
86
+ BertAlibiEmbeddings,
87
+ FlexBertAbsoluteEmbeddings,
88
+ FlexBertCompiledSansPositionEmbeddings,
89
+ FlexBertSansPositionEmbeddings,
90
+ get_embedding_layer,
91
+ )
92
+ from src.bert_layers.initialization import (
93
+ ModuleType,
94
+ TileLinear,
95
+ TileMode,
96
+ init_weights,
97
+ tile_embedding,
98
+ tile_linear,
99
+ tile_norm,
100
+ )
101
+ from src.bert_layers.layers import (
102
+ BertAlibiEncoder,
103
+ BertPooler,
104
+ BertPredictionHeadTransform,
105
+ FlexBertCompileUnpadPreNormLayer,
106
+ FlexBertPaddedEncoder,
107
+ FlexBertPaddedParallelPreNormLayer,
108
+ FlexBertPaddedPostNormLayer,
109
+ FlexBertPaddedPreNormLayer,
110
+ FlexBertUnpadEncoder,
111
+ FlexBertUnpadParallelPreNormLayer,
112
+ FlexBertUnpadPostNormLayer,
113
+ FlexBertUnpadPreNormLayer,
114
+ get_encoder_layer,
115
+ )
116
+ from src.bert_layers.loss import get_loss_fn
117
+ from src.bert_layers.mlp import FlexBertGLU, FlexBertMLP, FlexBertParallelGLU
118
+ from src.bert_layers.normalization import get_norm_layer
119
+ from src.bert_layers.padding import pad_input, unpad_input
120
+
121
+ logger = logging.getLogger(__name__)
122
+
123
+
124
+ def _count_parameters(model: nn.Module, trainable: bool = True) -> int:
125
+ if trainable:
126
+ return sum(p.numel() for p in model.parameters() if p.requires_grad)
127
+ else:
128
+ return sum(p.numel() for p in model.parameters())
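A tiny usage sketch of the helper above, assuming `_count_parameters` is in scope: freezing a parameter changes only the trainable count.

```python
import torch.nn as nn

layer = nn.Linear(4, 2)            # 4*2 weights + 2 biases = 10 parameters
layer.bias.requires_grad = False   # freeze the bias
print(_count_parameters(layer, trainable=True))   # 8
print(_count_parameters(layer, trainable=False))  # 10
```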
129
+
130
+
131
+ class BertModel(BertPreTrainedModel):
132
+ """Overall BERT model.
133
+
134
+ Args:
135
+ config: a BertConfig class instance with the configuration to build a new model
136
+
137
+ Inputs:
138
+ `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
139
+ with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts
140
+ `extract_features.py`, `run_classifier.py` and `run_squad.py`)
141
+ `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
142
+ types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
143
+ a `sentence B` token (see BERT paper for more details).
144
+ `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
145
+ selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
146
+ input sequence length in the current batch. It's the mask that we typically use for attention when
147
+ a batch has varying length sentences.
148
+ `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`.
149
+
150
+ Outputs: Tuple of (encoded_layers, pooled_output)
151
+ `encoded_layers`: controlled by `output_all_encoded_layers` argument:
152
+ - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end
153
+ of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each
154
+ encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
155
+ - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
156
+ to the last attention block of shape [batch_size, sequence_length, hidden_size],
157
+ `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
158
+ classifier pretrained on top of the hidden state associated to the first character of the
159
+ input (`CLS`) to train on the Next-Sentence task (see BERT's paper).
160
+
161
+ Example usage:
162
+ ```python
163
+ # Already been converted into WordPiece token ids
164
+ input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
165
+ input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
166
+ token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
167
+ config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
168
+ num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
169
+ model = BertModel(config=config)
170
+ all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
171
+ ```
172
+ """
173
+
174
+ def __init__(
175
+ self,
176
+ config,
177
+ add_pooling_layer: bool = True,
178
+ ):
179
+ super(BertModel, self).__init__(config)
180
+ self.embeddings = BertAlibiEmbeddings(config)
181
+ self.encoder = BertAlibiEncoder(config)
182
+ self.pooler = BertPooler(config) if add_pooling_layer else None
183
+ self.post_init()
184
+
185
+ def get_input_embeddings(self):
186
+ return self.embeddings.word_embeddings
187
+
188
+ def set_input_embeddings(self, value):
189
+ self.embeddings.word_embeddings = value
190
+
191
+ def forward(
192
+ self,
193
+ input_ids: torch.Tensor,
194
+ token_type_ids: Optional[torch.Tensor] = None,
195
+ attention_mask: Optional[torch.Tensor] = None,
196
+ position_ids: Optional[torch.Tensor] = None,
197
+ output_all_encoded_layers: Optional[bool] = False,
198
+ masked_tokens_mask: Optional[torch.Tensor] = None,
199
+ **kwargs,
200
+ ) -> Tuple[Union[List[torch.Tensor], torch.Tensor], Optional[torch.Tensor]]:
201
+ if attention_mask is None:
202
+ attention_mask = torch.ones_like(input_ids)
203
+ if token_type_ids is None:
204
+ token_type_ids = torch.zeros_like(input_ids)
205
+
206
+ embedding_output = self.embeddings(input_ids, token_type_ids, position_ids)
207
+
208
+ subset_mask = []
209
+ first_col_mask = []
210
+
211
+ if masked_tokens_mask is None:
212
+ subset_mask = None
213
+ else:
214
+ first_col_mask = torch.zeros_like(masked_tokens_mask)
215
+ first_col_mask[:, 0] = True
216
+ subset_mask = masked_tokens_mask | first_col_mask
217
+
218
+ encoder_outputs = self.encoder(
219
+ embedding_output,
220
+ attention_mask,
221
+ output_all_encoded_layers=output_all_encoded_layers,
222
+ subset_mask=subset_mask,
223
+ )
224
+
225
+ if masked_tokens_mask is None:
226
+ sequence_output = encoder_outputs[-1]
227
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
228
+ else:
229
+ # TD [2022-03-01]: the indexing here is very tricky.
230
+ attention_mask_bool = attention_mask.bool()
231
+ subset_idx = subset_mask[attention_mask_bool] # type: ignore
232
+ sequence_output = encoder_outputs[-1][masked_tokens_mask[attention_mask_bool][subset_idx]]
233
+ if self.pooler is not None:
234
+ pool_input = encoder_outputs[-1][first_col_mask[attention_mask_bool][subset_idx]]
235
+ pooled_output = self.pooler(pool_input, pool=False)
236
+ else:
237
+ pooled_output = None
238
+
239
+ if not output_all_encoded_layers:
240
+ encoder_outputs = sequence_output
241
+
242
+ if self.pooler is not None:
243
+ return encoder_outputs, pooled_output
244
+
245
+ return encoder_outputs, None
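A hedged sketch of the subset-mask construction in the forward pass above, with toy tensors: predictions are kept only at masked positions plus the first ([CLS]) column of each sequence.

```python
import torch

labels = torch.tensor([[-100, 5, -100, 7], [-100, -100, 9, -100]])
masked_tokens_mask = labels > 0                        # positions carrying an MLM label
first_col_mask = torch.zeros_like(masked_tokens_mask)
first_col_mask[:, 0] = True                            # always keep [CLS] for pooling
subset_mask = masked_tokens_mask | first_col_mask
print(subset_mask)
```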
246
+
247
+
248
+ ###################
249
+ # Bert Heads
250
+ ###################
251
+ class BertLMPredictionHead(nn.Module):
252
+ def __init__(self, config, bert_model_embedding_weights):
253
+ super().__init__()
254
+ self.transform = BertPredictionHeadTransform(config)
255
+ # The output weights are the same as the input embeddings, but there is
256
+ # an output-only bias for each token.
257
+ self.decoder = nn.Linear(bert_model_embedding_weights.size(1), bert_model_embedding_weights.size(0))
258
+ self.decoder.weight = bert_model_embedding_weights
259
+
260
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
261
+ hidden_states = self.transform(hidden_states)
262
+ hidden_states = self.decoder(hidden_states)
263
+ return hidden_states
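A minimal sketch of the weight tying done in the prediction head above (sizes are illustrative): the decoder reuses the token-embedding matrix, so only the per-token output bias adds new parameters.

```python
import torch.nn as nn

emb = nn.Embedding(30522, 768)
decoder = nn.Linear(768, 30522)   # bias=True gives the output-only bias
decoder.weight = emb.weight       # tie: same Parameter object, no copy
assert decoder.weight.data_ptr() == emb.weight.data_ptr()
```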
264
+
265
+
266
+ class BertOnlyMLMHead(nn.Module):
267
+ def __init__(self, config, bert_model_embedding_weights):
268
+ super().__init__()
269
+ self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)
270
+
271
+ def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
272
+ prediction_scores = self.predictions(sequence_output)
273
+ return prediction_scores
274
+
275
+
276
+ class BertOnlyNSPHead(nn.Module):
277
+ def __init__(self, config):
278
+ super().__init__()
279
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
280
+
281
+ def forward(self, pooled_output: torch.Tensor) -> torch.Tensor:
282
+ seq_relationship_score = self.seq_relationship(pooled_output)
283
+ return seq_relationship_score
284
+
285
+
286
+ #####################
287
+ # Various Bert models
288
+ #####################
289
+
290
+
291
+ class BertForPreTraining(BertPreTrainedModel):
292
+ # TBD: Coming in Future Commit
293
+ pass
294
+
295
+
296
+ class BertLMHeadModel(BertPreTrainedModel):
297
+ # TBD: Coming in Future Commit
298
+ pass
299
+
300
+
301
+ class BertForMaskedLM(BertPreTrainedModel):
302
+ def __init__(self, config):
303
+ super().__init__(config)
304
+
305
+ if config.is_decoder:
306
+ warnings.warn(
307
+ "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for "
308
+ "bi-directional self-attention."
309
+ )
310
+
311
+ self.bert = BertModel(config, add_pooling_layer=False)
312
+ self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight)
313
+
314
+ # Initialize weights and apply final processing
315
+ self.post_init()
316
+
317
+ @classmethod
318
+ def from_composer(
319
+ cls,
320
+ pretrained_checkpoint,
321
+ state_dict=None,
322
+ cache_dir=None,
323
+ from_tf=False,
324
+ config=None,
325
+ *inputs,
326
+ **kwargs,
327
+ ):
328
+ """Load from pre-trained."""
329
+ model = cls(config, *inputs, **kwargs)
330
+ if from_tf:
331
+ raise ValueError("Mosaic BERT does not support loading TensorFlow weights.")
332
+
333
+ state_dict = torch.load(pretrained_checkpoint)
334
+ # If the state_dict was saved after wrapping with `composer.HuggingFaceModel`, it takes on the `model` prefix
335
+ consume_prefix_in_state_dict_if_present(state_dict, prefix="model.")
336
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
337
+
338
+ if len(missing_keys) > 0:
339
+ logger.warning(f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}")
340
+ if len(unexpected_keys) > 0:
341
+ logger.warning(f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}")
342
+
343
+ return model
344
+
345
+ def get_output_embeddings(self):
346
+ return self.cls.predictions.decoder
347
+
348
+ def set_output_embeddings(self, new_embeddings):
349
+ self.cls.predictions.decoder = new_embeddings
350
+
351
+ def forward(
352
+ self,
353
+ input_ids: Optional[torch.Tensor] = None,
354
+ attention_mask: Optional[torch.Tensor] = None,
355
+ token_type_ids: Optional[torch.Tensor] = None,
356
+ position_ids: Optional[torch.Tensor] = None,
357
+ head_mask: Optional[torch.Tensor] = None,
358
+ inputs_embeds: Optional[torch.Tensor] = None,
359
+ encoder_hidden_states: Optional[torch.Tensor] = None,
360
+ encoder_attention_mask: Optional[torch.Tensor] = None,
361
+ labels: Optional[torch.Tensor] = None,
362
+ output_attentions: Optional[bool] = None,
363
+ output_hidden_states: Optional[bool] = None,
364
+ return_dict: Optional[bool] = None,
365
+ ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
366
+ # labels should be a `torch.LongTensor` of shape
367
+ # `(batch_size, sequence_length)`. These are used for computing the
368
+ # masked language modeling loss.
369
+ #
370
+ # Indices should be in `[-100, 0, ..., config.vocab_size]` (see
371
+ # `input_ids` docstring) Tokens with indices set to `-100` are ignored
372
+ # (masked), the loss is only computed for the tokens with labels in `[0,
373
+ # ..., config.vocab_size]`
374
+ #
375
+ # Prediction scores are only computed for masked tokens and the (bs,
376
+ # seqlen) dimensions are flattened
377
+ if (input_ids is not None) == (inputs_embeds is not None):
378
+ raise ValueError("Must specify either input_ids or input_embeds!")
379
+
380
+ if labels is None:
381
+ masked_tokens_mask = None
382
+ else:
383
+ masked_tokens_mask = labels > 0
384
+
385
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
386
+
387
+ outputs = self.bert(
388
+ input_ids,
389
+ attention_mask=attention_mask,
390
+ token_type_ids=token_type_ids,
391
+ position_ids=position_ids,
392
+ head_mask=head_mask,
393
+ inputs_embeds=inputs_embeds,
394
+ encoder_hidden_states=encoder_hidden_states,
395
+ encoder_attention_mask=encoder_attention_mask,
396
+ output_attentions=output_attentions,
397
+ output_hidden_states=output_hidden_states,
398
+ return_dict=return_dict,
399
+ masked_tokens_mask=masked_tokens_mask,
400
+ )
401
+
402
+ sequence_output = outputs[0]
403
+ prediction_scores = self.cls(sequence_output)
404
+
405
+ loss = None
406
+ if labels is not None:
407
+ # Compute loss
408
+ loss_fct = nn.CrossEntropyLoss()
409
+ masked_token_idx = torch.nonzero(labels.flatten() > 0, as_tuple=False).flatten()
410
+ loss = loss_fct(prediction_scores, labels.flatten()[masked_token_idx])
411
+
412
+ assert input_ids is not None, "Coding error; please open an issue"
413
+ batch, seqlen = input_ids.shape[:2]
414
+ prediction_scores = rearrange(
415
+ index_put_first_axis(prediction_scores, masked_token_idx, batch * seqlen),
416
+ "(b s) d -> b s d",
417
+ b=batch,
418
+ )
419
+
420
+ if not return_dict:
421
+ output = (prediction_scores,) + outputs[2:]
422
+ return ((loss,) + output) if loss is not None else output
423
+
424
+ return MaskedLMOutput(
425
+ loss=loss,
426
+ logits=prediction_scores,
427
+ hidden_states=None,
428
+ attentions=None,
429
+ )
430
+
431
+ def prepare_inputs_for_generation(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **model_kwargs):
432
+ input_shape = input_ids.shape
433
+ effective_batch_size = input_shape[0]
434
+
435
+ # add a dummy token
436
+ if self.config.pad_token_id is None:
437
+ raise ValueError("The PAD token should be defined for generation")
438
+
439
+ attention_mask = torch.cat(
440
+ [attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))],
441
+ dim=-1,
442
+ )
443
+ dummy_token = torch.full(
444
+ (effective_batch_size, 1),
445
+ self.config.pad_token_id,
446
+ dtype=torch.long,
447
+ device=input_ids.device,
448
+ )
449
+ input_ids = torch.cat([input_ids, dummy_token], dim=1)
450
+
451
+ return {"input_ids": input_ids, "attention_mask": attention_mask}
452
+
453
+
454
+ class BertForNextSentencePrediction(BertPreTrainedModel):
455
+ # TBD: Push in future commit
456
+ pass
457
+
458
+
459
+ class BertForSequenceClassification(BertPreTrainedModel):
460
+ """Bert Model transformer with a sequence classification/regression head.
461
+
462
+ This head is just a linear layer on top of the pooled output. Used for,
463
+ e.g., GLUE tasks.
464
+ """
465
+
466
+ def __init__(self, config):
467
+ super().__init__(config)
468
+ self.num_labels = config.num_labels
469
+ self.config = config
470
+
471
+ self.bert = BertModel(config)
472
+ classifier_dropout = (
473
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
474
+ )
475
+ self.dropout = nn.Dropout(classifier_dropout)
476
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
477
+
478
+ # Initialize weights and apply final processing
479
+ self.post_init()
480
+
481
+ @classmethod
482
+ def from_composer(
483
+ cls,
484
+ pretrained_checkpoint,
485
+ state_dict=None,
486
+ cache_dir=None,
487
+ from_tf=False,
488
+ config=None,
489
+ *inputs,
490
+ **kwargs,
491
+ ):
492
+ """Load from pre-trained."""
493
+ model = cls(config, *inputs, **kwargs)
494
+ if from_tf:
495
+ raise ValueError("Mosaic BERT does not support loading TensorFlow weights.")
496
+
497
+ state_dict = torch.load(pretrained_checkpoint)
498
+ # If the state_dict was saved after wrapping with `composer.HuggingFaceModel`, it takes on the `model` prefix
499
+ consume_prefix_in_state_dict_if_present(state_dict, prefix="model.")
500
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
501
+
502
+ if len(missing_keys) > 0:
503
+ logger.warning(f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}")
504
+ if len(unexpected_keys) > 0:
505
+ logger.warning(f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}")
506
+
507
+ return model
508
+
509
+ def forward(
510
+ self,
511
+ input_ids: Optional[torch.Tensor] = None,
512
+ attention_mask: Optional[torch.Tensor] = None,
513
+ token_type_ids: Optional[torch.Tensor] = None,
514
+ position_ids: Optional[torch.Tensor] = None,
515
+ head_mask: Optional[torch.Tensor] = None,
516
+ inputs_embeds: Optional[torch.Tensor] = None,
517
+ labels: Optional[torch.Tensor] = None,
518
+ output_attentions: Optional[bool] = None,
519
+ output_hidden_states: Optional[bool] = None,
520
+ return_dict: Optional[bool] = None,
521
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
522
+ # labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
523
+ # Labels for computing the sequence classification/regression loss.
524
+ # Indices should be in `[0, ..., config.num_labels - 1]`.
525
+ # If `config.num_labels == 1` a regression loss is computed
526
+ # (mean-square loss). If `config.num_labels > 1` a classification loss
527
+ # is computed (cross-entropy).
528
+
529
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
530
+
531
+ outputs = self.bert(
532
+ input_ids,
533
+ attention_mask=attention_mask,
534
+ token_type_ids=token_type_ids,
535
+ position_ids=position_ids,
536
+ head_mask=head_mask,
537
+ inputs_embeds=inputs_embeds,
538
+ output_attentions=output_attentions,
539
+ output_hidden_states=output_hidden_states,
540
+ return_dict=return_dict,
541
+ )
542
+
543
+ pooled_output = outputs[1]
544
+
545
+ pooled_output = self.dropout(pooled_output)
546
+ logits = self.classifier(pooled_output)
547
+
548
+ loss = None
549
+ if labels is not None:
550
+ # Compute loss
551
+ if self.config.problem_type is None:
552
+ if self.num_labels == 1:
553
+ self.config.problem_type = "regression"
554
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
555
+ self.config.problem_type = "single_label_classification"
556
+ else:
557
+ self.config.problem_type = "multi_label_classification"
558
+
559
+ if self.config.problem_type == "regression":
560
+ loss_fct = nn.MSELoss()
561
+ if self.num_labels == 1:
562
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
563
+ else:
564
+ loss = loss_fct(logits, labels)
565
+ elif self.config.problem_type == "single_label_classification":
566
+ loss_fct = nn.CrossEntropyLoss()
567
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
568
+ elif self.config.problem_type == "multi_label_classification":
569
+ loss_fct = nn.BCEWithLogitsLoss()
570
+ loss = loss_fct(logits, labels)
571
+
572
+ if not return_dict:
573
+ output = (logits,) + outputs[2:]
574
+ return ((loss,) + output) if loss is not None else output
575
+
576
+ return SequenceClassifierOutput(
577
+ loss=loss,
578
+ logits=logits,
579
+ hidden_states=None,
580
+ attentions=None,
581
+ )
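A hedged sketch of the `problem_type` dispatch used above, with toy tensors for each of the three loss branches:

```python
import torch
import torch.nn as nn

num_labels = 3
logits = torch.randn(4, num_labels)

# single_label_classification: integer class ids -> cross entropy
labels = torch.tensor([0, 2, 1, 2])
ce = nn.CrossEntropyLoss()(logits.view(-1, num_labels), labels.view(-1))

# regression (num_labels == 1): float targets -> MSE
reg_logits = torch.randn(4, 1)
mse = nn.MSELoss()(reg_logits.squeeze(), torch.randn(4))

# multi_label_classification: float multi-hot targets -> BCE with logits
multi_hot = torch.randint(0, 2, (4, num_labels)).float()
bce = nn.BCEWithLogitsLoss()(logits, multi_hot)
```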
582
+
583
+
584
+ class BertForMultipleChoice(BertPreTrainedModel):
585
+ """
586
+ Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
587
+ softmax) e.g. for RocStories/SWAG tasks.
588
+ """
589
+
590
+ def __init__(self, config):
591
+ super().__init__(config)
592
+ self.num_labels = config.num_labels
593
+ self.config = config
594
+
595
+ self.bert = BertModel(config)
596
+ classifier_dropout = (
597
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
598
+ )
599
+ self.dropout = nn.Dropout(classifier_dropout)
600
+
601
+ # In multiple choice tasks, all choices are submitted in a batch, and
602
+ # we compute a logit for each option independently. The logits are then
603
+ # normalized in the forward pass to get a probability distribution over
604
+ # the choices.
605
+ self.classifier = nn.Linear(config.hidden_size, 1)
606
+
607
+ # Initialize weights and apply final processing
608
+ self.post_init()
609
+
610
+ @classmethod
611
+ def from_composer(
612
+ cls,
613
+ pretrained_checkpoint,
614
+ state_dict=None,
615
+ cache_dir=None,
616
+ from_tf=False,
617
+ config=None,
618
+ *inputs,
619
+ **kwargs,
620
+ ):
621
+ """Load from pre-trained."""
622
+ model = cls(config, *inputs, **kwargs)
623
+ if from_tf:
624
+ raise ValueError("Mosaic BERT does not support loading TensorFlow weights.")
625
+
626
+ state_dict = torch.load(pretrained_checkpoint)
627
+ # If the state_dict was saved after wrapping with `composer.HuggingFaceModel`, it takes on the `model` prefix
628
+ consume_prefix_in_state_dict_if_present(state_dict, prefix="model.")
629
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
630
+
631
+ if len(missing_keys) > 0:
632
+ logger.warning(f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}")
633
+ if len(unexpected_keys) > 0:
634
+ logger.warning(f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}")
635
+
636
+ return model
637
+
638
+ def forward(
639
+ self,
640
+ input_ids: Optional[torch.Tensor] = None,
641
+ attention_mask: Optional[torch.Tensor] = None,
642
+ token_type_ids: Optional[torch.Tensor] = None,
643
+ position_ids: Optional[torch.Tensor] = None,
644
+ head_mask: Optional[torch.Tensor] = None,
645
+ inputs_embeds: Optional[torch.Tensor] = None,
646
+ labels: Optional[torch.Tensor] = None,
647
+ output_attentions: Optional[bool] = None,
648
+ output_hidden_states: Optional[bool] = None,
649
+ return_dict: Optional[bool] = None,
650
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
651
+ r"""
652
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
653
+ Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
654
+ num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
655
+ `input_ids` above)
656
+ """
657
+
658
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
659
+ num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
660
+
661
+ input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
662
+ attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
663
+ token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
664
+ position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
665
+ inputs_embeds = (
666
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
667
+ if inputs_embeds is not None
668
+ else None
669
+ )
670
+
671
+ outputs = self.bert(
672
+ input_ids,
673
+ attention_mask=attention_mask,
674
+ token_type_ids=token_type_ids,
675
+ position_ids=position_ids,
676
+ head_mask=head_mask,
677
+ inputs_embeds=inputs_embeds,
678
+ output_attentions=output_attentions,
679
+ output_hidden_states=output_hidden_states,
680
+ return_dict=return_dict,
681
+ )
682
+
683
+ pooled_output = outputs[1]
684
+
685
+ pooled_output = self.dropout(pooled_output)
686
+ logits = self.classifier(pooled_output)
687
+ reshaped_logits = logits.view(-1, num_choices)
688
+
689
+ loss = None
690
+ if labels is not None:
691
+ loss_fct = nn.CrossEntropyLoss()
692
+ loss = loss_fct(reshaped_logits, labels)
693
+
694
+ if not return_dict:
695
+ output = (reshaped_logits,) + outputs[2:]
696
+ return ((loss,) + output) if loss is not None else output
697
+
698
+ return MultipleChoiceModelOutput(
699
+ loss=loss,
700
+ logits=reshaped_logits,
701
+ hidden_states=None,
702
+ attentions=None,
703
+ )
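A hedged sketch of the multiple-choice reshaping used above: choices are flattened into the batch dimension, scored independently with a single-logit classifier, and then viewed back to `(batch, num_choices)` before cross entropy.

```python
import torch

batch, num_choices, seq_len = 2, 4, 16
input_ids = torch.randint(0, 100, (batch, num_choices, seq_len))
flat = input_ids.view(-1, seq_len)                 # (batch * num_choices, seq_len)
per_choice_logit = torch.randn(flat.size(0), 1)    # stand-in for classifier(pooled_output)
reshaped_logits = per_choice_logit.view(-1, num_choices)  # (batch, num_choices)
labels = torch.tensor([1, 3])
loss = torch.nn.functional.cross_entropy(reshaped_logits, labels)
```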
704
+
705
+
706
+ class BertForTokenClassification(BertPreTrainedModel):
707
+ # TBD: Push in future commit
708
+ pass
709
+
710
+
711
+ class BertForQuestionAnswering(BertPreTrainedModel):
712
+ """Bert Model with a span classification head.
713
+
714
+ This is used for extractive question-answering tasks like SQuAD (linear
715
+ layers on top of the hidden states' output to compute `span start logits`
716
+ and `span end logits`).
717
+ """
718
+
719
+ # TBD: Push in future commit
720
+
721
+
722
+ ###################
723
+ # FlexBert Heads
724
+ ###################
725
+
726
+
727
+ class FlexBertPredictionHead(nn.Module):
728
+ def __init__(self, config: FlexBertConfig):
729
+ super().__init__()
730
+ self.config = config
731
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size, config.head_pred_bias)
732
+ self.act = get_act_fn(config.head_pred_act) if config.head_pred_act else nn.Identity()
733
+ self.norm = (
734
+ get_norm_layer(config, compiled_norm=config.compile_model) if config.head_pred_norm else nn.Identity()
735
+ )
736
+
737
+ def _init_weights(self, reset_params: bool = False):
738
+ if reset_params:
739
+ self.norm.reset_parameters()
740
+ init_weights(self.config, self.dense, layer_dim=self.config.hidden_size, type_of_module=ModuleType.in_module)
741
+
742
+ def reset_parameters(self):
743
+ self._init_weights(reset_params=True)
744
+
745
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
746
+ return self.norm(self.act(self.dense(hidden_states)))
747
+
748
+
749
+ class FlexBertPoolingHead(nn.Module):
750
+ def __init__(self, config: FlexBertConfig):
751
+ super().__init__()
752
+ self.config = config
753
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size, config.head_class_bias)
754
+ self.act = get_act_fn(config.head_class_act) if config.head_class_act else nn.Identity()
755
+ self.norm = get_norm_layer(config) if config.head_class_norm else nn.Identity()
756
+ self.drop = torch.nn.Dropout(config.head_class_dropout) if config.head_class_dropout > 0 else nn.Identity()
757
+ self.pooling_type = config.pooling_type
758
+
759
+ def forward(self, hidden_states: torch.Tensor, pool: Optional[bool] = True) -> torch.Tensor:
760
+ if pool:
761
+ if self.pooling_type == "cls":
762
+ output = hidden_states[:, 0]
763
+ elif self.pooling_type == "mean":
764
+ output = hidden_states.mean(dim=1)
765
+ elif self.pooling_type == "max":
766
+ output = hidden_states.max(dim=1)[0]
767
+ else:
768
+ output = hidden_states
769
+
770
+ return self.drop(self.norm(self.act(self.dense(output))))
771
+
772
+ def _init_weights(self, reset_params: bool = False):
773
+ init_weights(self.config, self.dense, self.config.hidden_size, type_of_module=ModuleType.out_module)
774
+ if reset_params and hasattr(self.norm, "reset_parameters"):
775
+ self.norm.reset_parameters()
776
+
777
+ def reset_parameters(self):
778
+ self._init_weights(reset_params=True)
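A quick sketch of the three `pooling_type` options handled above, on a dense toy tensor (masking of padded positions is ignored here):

```python
import torch

hidden_states = torch.randn(2, 8, 768)      # (batch, seq_len, hidden)
cls_pooled = hidden_states[:, 0]            # "cls": first token
mean_pooled = hidden_states.mean(dim=1)     # "mean": average over the sequence
max_pooled = hidden_states.max(dim=1)[0]    # "max": element-wise max over the sequence
assert cls_pooled.shape == mean_pooled.shape == max_pooled.shape == (2, 768)
```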
779
+
780
+
781
+ ###################
782
+ # FlexBert Models
783
+ ###################
784
+
785
+
786
+ @dataclass
787
+ class MaskedLMOutput(ModelOutput):
788
+ """
789
+ Base class for masked language models outputs.
790
+
791
+ Args:
792
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
793
+ Masked language modeling (MLM) loss.
794
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
795
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
796
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
797
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
798
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
799
+
800
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
801
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
802
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
803
+ sequence_length)`.
804
+
805
+ Attention weights after the attention softmax, used to compute the weighted average in the self-attention
806
+ heads.
807
+ """
808
+
809
+ loss: Optional[torch.FloatTensor] = None
810
+ logits: torch.FloatTensor = None
811
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
812
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
813
+ indices: Optional[torch.LongTensor] = None
814
+ cu_seqlens: Optional[torch.LongTensor] = None
815
+ max_seqlen: Optional[int] = None
816
+ batch_size: Optional[int] = None
817
+ seq_len: Optional[int] = None
818
+ labels: Optional[torch.LongTensor] = None
819
+
820
+
821
+ @dataclass
822
+ class MaskedLMOutputZLoss(ModelOutput):
823
+ """
824
+ Base class for masked language models outputs.
825
+
826
+ Args:
827
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
828
+ Masked language modeling (MLM) loss.
829
+ ce_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
830
+ Cross entropy loss.
831
+ z_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
832
+ Z loss.
833
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
834
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
835
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
836
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
837
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
838
+
839
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
840
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
841
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
842
+ sequence_length)`.
843
+
844
+ Attention weights after the attention softmax, used to compute the weighted average in the self-attention
845
+ heads.
846
+ indices (`torch.LongTensor` of shape `(batch_size,)`):
847
+ Indices of the tokens to be masked.
848
+ """
849
+
850
+ loss: Optional[torch.FloatTensor] = None
851
+ ce_loss: Optional[torch.FloatTensor] = None
852
+ z_loss: Optional[torch.FloatTensor] = None
853
+ logits: torch.FloatTensor = None
854
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
855
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
856
+ indices: Optional[torch.LongTensor] = None
857
+ cu_seqlens: Optional[torch.LongTensor] = None
858
+ max_seqlen: Optional[int] = None
859
+ batch_size: Optional[int] = None
860
+ seq_len: Optional[int] = None
861
+ labels: Optional[torch.LongTensor] = None
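For intuition about the `ce_loss`/`z_loss` fields, here is an illustrative z-loss regularizer sketch. The actual loss lives in `src.bert_layers.loss`; the coefficient and function name below are assumptions for illustration only.

```python
import torch
import torch.nn.functional as F

def ce_with_z_loss(logits: torch.Tensor, labels: torch.Tensor, z_coef: float = 1e-4):
    ce = F.cross_entropy(logits, labels)
    # Penalize large log-partition values to keep logits from drifting.
    z = torch.logsumexp(logits, dim=-1)
    z_loss = z_coef * (z ** 2).mean()
    return ce + z_loss, z_loss

logits = torch.randn(8, 30522)
labels = torch.randint(0, 30522, (8,))
loss, z_loss = ce_with_z_loss(logits, labels)
ce_loss = loss.detach().clone() - z_loss   # mirrors how the MLM model reports ce_loss
```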
862
+
863
+
864
+ class FlexBertPreTrainedModel(BertPreTrainedModel):
865
+ """
866
+ An abstract class to handle custom weights initialization of modules
867
+ """
868
+
869
+ def _init_module_weights(self, module: nn.Module):
870
+ """
871
+ Custom weight init of modules using src.bert_layers.initialization.init_weights
872
+ Currently only supports init of embedding modules
873
+ """
874
+ assert isinstance(module, nn.Module)
875
+ if isinstance(module, nn.Embedding):
876
+ init_weights(self.config, module, type_of_module=ModuleType.emb)
877
+ else:
878
+ raise NotImplementedError("Custom weight init for the given module is not supported")
879
+
880
+
881
+ class FlexBertModel(FlexBertPreTrainedModel):
882
+ """Overall BERT model.
883
+
884
+ Args:
885
+ config: a BertConfig class instance with the configuration to build a new model
886
+
887
+ Inputs:
888
+ `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
889
+ with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts
890
+ `extract_features.py`, `run_classifier.py` and `run_squad.py`)
891
+ `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
892
+ types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
893
+ a `sentence B` token (see BERT paper for more details).
894
+ `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
895
+ selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
896
+ input sequence length in the current batch. It's the mask that we typically use for attention when
897
+ a batch has varying length sentences.
898
+ `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`.
899
+
900
+ Outputs: Tuple of (encoded_layers, pooled_output)
901
+ `encoded_layers`: controlled by `output_all_encoded_layers` argument:
902
+ - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end
903
+ of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each
904
+ encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
905
+ - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
906
+ to the last attention block of shape [batch_size, sequence_length, hidden_size],
907
+ `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
908
+ classifier pretrained on top of the hidden state associated to the first character of the
909
+ input (`CLS`) to train on the Next-Sentence task (see BERT's paper).
910
+
911
+ Example usage:
912
+ ```python
913
+ # Already been converted into WordPiece token ids
914
+ input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
915
+ input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
916
+ token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
917
+ config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
918
+ num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
919
+ model = BertModel(config=config)
920
+ all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
921
+ ```
922
+ """
923
+
924
+ def __init__(self, config: FlexBertConfig):
925
+ super().__init__(config)
926
+ self.embeddings = get_embedding_layer(config)
927
+ self.encoder = get_encoder_layer(config)
928
+ if config.final_norm:
929
+ # if we use prenorm attention we need to add a final norm
930
+ self.final_norm = get_norm_layer(config)
931
+ else:
932
+ self.final_norm = None
933
+ self.unpad_embeddings = config.unpad_embeddings
934
+
935
+ def post_init(self):
936
+ self._init_weights(reset_params=False)
937
+ self._backward_compatibility_gradient_checkpointing()
938
+
939
+ def get_input_embeddings(self):
940
+ return self.embeddings.tok_embeddings
941
+
942
+ def set_input_embeddings(self, value):
943
+ self.embeddings.tok_embeddings = value
944
+
945
+ def forward(
946
+ self,
947
+ input_ids: torch.Tensor,
948
+ attention_mask: Optional[torch.Tensor] = None,
949
+ position_ids: Optional[torch.Tensor] = None,
950
+ indices: Optional[torch.Tensor] = None,
951
+ cu_seqlens: Optional[torch.Tensor] = None,
952
+ max_seqlen: Optional[int] = None,
953
+ **kwargs,
954
+ ) -> Tuple[Union[List[torch.Tensor], torch.Tensor], Optional[torch.Tensor]]:
955
+ if attention_mask is None:
956
+ attention_mask = torch.ones_like(input_ids)
957
+
958
+ embedding_output = self.embeddings(input_ids, position_ids)
959
+
960
+ encoder_outputs = self.encoder(
961
+ hidden_states=embedding_output,
962
+ attention_mask=attention_mask,
963
+ indices=indices,
964
+ cu_seqlens=cu_seqlens,
965
+ max_seqlen=max_seqlen,
966
+ )
967
+
968
+ if self.final_norm is not None:
969
+ encoder_outputs = self.final_norm(encoder_outputs)
970
+ return encoder_outputs
971
+
972
+ def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None):
973
+ assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified"
974
+ if module:
975
+ self._init_module_weights(module)
976
+ else:
977
+ assert isinstance(reset_params, bool)
978
+ self.embeddings._init_weights(reset_params=reset_params)
979
+ self.encoder._init_weights(reset_params=reset_params)
980
+
981
+ if reset_params and self.config.final_norm:
982
+ self.final_norm.reset_parameters()
983
+
984
+ def reset_parameters(self):
985
+ self._init_weights(reset_params=True)
986
+
987
+ def get_number_parameters(self, count_embeddings: bool = True, trainable: bool = True) -> int:
988
+ """Returns the number of parameters in the model.
989
+
990
+ Args:
991
+ count_embeddings: count the parameters in the embeddings layer, excluding position embeddings.
992
+ trainable: only count trainable parameters.
993
+ """
994
+ params = sum([_count_parameters(layer, trainable) for layer in self.encoder.layers])
995
+ if count_embeddings:
996
+ params += _count_parameters(self.embeddings, trainable)
997
+ if hasattr(self.embeddings, "position_embeddings"):
998
+ params -= _count_parameters(self.embeddings.position_embeddings, trainable)
999
+ return params
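A toy sketch of why `final_norm` exists: with pre-norm blocks the residual stream is only normalized inside each block, so the last hidden states leave the stack un-normalized and receive one extra norm at the end. This is an illustrative stand-in, not the FlexBert block itself.

```python
import torch
import torch.nn as nn

hidden = torch.randn(2, 8, 768)
block_norm = nn.LayerNorm(768)
ffn = nn.Linear(768, 768)
for _ in range(2):                      # two toy pre-norm blocks
    hidden = hidden + ffn(block_norm(hidden))
final_norm = nn.LayerNorm(768)
hidden = final_norm(hidden)             # analogous to self.final_norm(encoder_outputs)
```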
1000
+
1001
+
1002
+ class FlexBertForMaskedLM(FlexBertPreTrainedModel):
1003
+ def __init__(self, config: FlexBertConfig):
1004
+ super().__init__(config)
1005
+ self.bert = FlexBertModel(config)
1006
+ self.head = FlexBertPredictionHead(config)
1007
+
1008
+ if config.tie_word_embeddings:
1009
+ decoder_weights = self.bert.embeddings.tok_embeddings.weight
1010
+ else:
1011
+ decoder_weights = nn.Linear(config.hidden_size, config.vocab_size, bias=False).weight
1012
+ self.decoder = nn.Linear(decoder_weights.size(1), decoder_weights.size(0), bias=config.decoder_bias)
1013
+ self.decoder.weight = decoder_weights
1014
+
1015
+ self.loss_fn = nn.CrossEntropyLoss() if not hasattr(config, "loss_function") else get_loss_fn(config)
1016
+ self.fa_ce = getattr(config, "loss_function", "cross_entropy") == "fa_cross_entropy"
1017
+ self.return_z_loss = config.loss_kwargs.get("return_z_loss", False)
1018
+ self.unpad_embeddings = config.unpad_embeddings
1019
+ self.pad_logits = config.pad_logits
1020
+ self.compile_model = config.compile_model
1021
+ self.masked_prediction = config.masked_prediction
1022
+
1023
+ # Initialize weights and apply final processing
1024
+ self._init_weights(reset_params=False)
1025
+
1026
+ def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None):
1027
+ assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified"
1028
+ if module:
1029
+ self._init_module_weights(module)
1030
+ else:
1031
+ assert isinstance(reset_params, bool)
1032
+ self.bert._init_weights(reset_params=reset_params)
1033
+ self.head._init_weights(reset_params=reset_params)
1034
+
1035
+ # Output weights.
1036
+ if not self.config.tie_word_embeddings:
1037
+ init_weights(self.config, self.decoder, self.config.hidden_size, type_of_module=ModuleType.final_out)
1038
+
1039
+ @classmethod
1040
+ def from_composer(
1041
+ cls,
1042
+ pretrained_checkpoint,
1043
+ state_dict=None,
1044
+ cache_dir=None,
1045
+ from_tf=False,
1046
+ config=None,
1047
+ *inputs,
1048
+ **kwargs,
1049
+ ):
1050
+ """Load from pre-trained."""
1051
+ model = cls(config, *inputs, **kwargs)
1052
+ if from_tf:
1053
+ raise ValueError("FlexBERT does not support loading TensorFlow weights.")
1054
+
1055
+ state_dict = torch.load(pretrained_checkpoint)
1056
+ # If the state_dict was saved after wrapping with `composer.HuggingFaceModel`, it takes on the `model` prefix
1057
+ consume_prefix_in_state_dict_if_present(state_dict, prefix="model.")
1058
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
1059
+
1060
+ if len(missing_keys) > 0:
1061
+ logger.warning(f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}")
1062
+ if len(unexpected_keys) > 0:
1063
+ logger.warning(f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}")
1064
+
1065
+ return model
1066
+
1067
+ def get_output_embeddings(self):
1068
+ return self.decoder
1069
+
1070
+ def set_output_embeddings(self, new_embeddings):
1071
+ self.decoder = new_embeddings
1072
+
1073
+ @torch.no_grad()
1074
+ def unpad_inputs(
1075
+ self, input_ids: torch.Tensor, attention_mask: torch.Tensor, position_ids: torch.Tensor, labels: torch.Tensor
1076
+ ):
1077
+ return unpad_input(input_ids, attention_mask, position_ids, labels)
1078
+
1079
+ @torch.no_grad()
1080
+ def pad_inputs(
1081
+ self,
1082
+ inputs: torch.Tensor,
1083
+ indices: torch.Tensor,
1084
+ batch_size: int,
1085
+ seqlen: int,
1086
+ labels: Optional[torch.Tensor] = None,
1087
+ ignore_index: int = -100,
1088
+ ):
1089
+ return pad_input(
1090
+ inputs=inputs, indices=indices, batch=batch_size, seqlen=seqlen, labels=labels, ignore_index=ignore_index
1091
+ )
1092
+
1093
+ @torch.compile(dynamic=True)
1094
+ def compiled_head(self, output: torch.Tensor) -> torch.Tensor:
1095
+ return self.decoder(self.head(output))
1096
+
1097
+ def forward(
1098
+ self,
1099
+ input_ids: Optional[torch.Tensor],
1100
+ attention_mask: Optional[torch.Tensor] = None,
1101
+ position_ids: Optional[torch.Tensor] = None,
1102
+ labels: Optional[torch.Tensor] = None,
1103
+ return_dict: Optional[bool] = None,
1104
+ indices: Optional[torch.Tensor] = None,
1105
+ cu_seqlens: Optional[torch.Tensor] = None,
1106
+ max_seqlen: Optional[int] = None,
1107
+ batch_size: Optional[int] = None,
1108
+ seq_len: Optional[int] = None,
1109
+ **kwargs,
1110
+ ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
1111
+ # labels should be a `torch.LongTensor` of shape
1112
+ # `(batch_size, sequence_length)`. These are used for computing the
1113
+ # masked language modeling loss.
1114
+ #
1115
+ # Indices should be in `[-100, 0, ..., config.vocab_size]` (see
1116
+ # `input_ids` docstring) Tokens with indices set to `-100` are ignored
1117
+ # (masked), the loss is only computed for the tokens with labels in `[0,
1118
+ # ..., config.vocab_size]`
1119
+ #
1120
+ # Prediction scores are only computed for masked tokens and the (bs,
1121
+ # seqlen) dimensions are flattened
1122
+
1123
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1124
+
1125
+ if self.unpad_embeddings and (indices is None and cu_seqlens is None and max_seqlen is None):
1126
+ batch_size, seq_len = input_ids.shape[:2]
1127
+ input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = self.unpad_inputs(
1128
+ input_ids, attention_mask, position_ids, labels
1129
+ )
1130
+
1131
+
1132
+ output = self.bert(
1133
+ input_ids,
1134
+ attention_mask=attention_mask,
1135
+ position_ids=position_ids,
1136
+ indices=indices,
1137
+ cu_seqlens=cu_seqlens,
1138
+ max_seqlen=max_seqlen,
1139
+ )
1140
+
1141
+ if self.masked_prediction and labels is not None:
1142
+ # flatten labels and output first
1143
+ labels = labels.view(-1)
1144
+ output = output.view(labels.shape[0], -1)
1145
+
1146
+ # then filter out the non-masked tokens
1147
+ mask_tokens = labels != self.loss_fn.ignore_index
1148
+ output = output[mask_tokens]
1149
+ labels = labels[mask_tokens]
1150
+
1151
+ if self.compile_model:
1152
+ logits = self.compiled_head(output)
1153
+ else:
1154
+ logits = self.decoder(self.head(output))
1155
+
1156
+ loss = None
1157
+ if labels is not None:
1158
+ if not self.masked_prediction:
1159
+ labels = labels.view(-1)
1160
+ logits = logits.view(labels.shape[0], -1)
1161
+
1162
+ if self.return_z_loss:
1163
+ loss, z_loss = self.loss_fn(logits, labels)
1164
+ if self.pad_logits:
1165
+ return MaskedLMOutputZLoss(
1166
+ loss=loss,
1167
+ ce_loss=loss.detach().clone() - z_loss,
1168
+ z_loss=z_loss,
1169
+ logits=self.pad_inputs(logits, indices, batch_size, seq_len)[0],
1170
+ hidden_states=None,
1171
+ attentions=None,
1172
+ )
1173
+ else:
1174
+ return MaskedLMOutputZLoss(
1175
+ loss=loss,
1176
+ ce_loss=loss.detach().clone() - z_loss,
1177
+ z_loss=z_loss,
1178
+ logits=logits,
1179
+ hidden_states=None,
1180
+ attentions=None,
1181
+ indices=indices,
1182
+ cu_seqlens=cu_seqlens,
1183
+ max_seqlen=max_seqlen,
1184
+ batch_size=batch_size,
1185
+ seq_len=seq_len,
1186
+ labels=labels,
1187
+ )
1188
+ else:
1189
+ loss = self.loss_fn(logits, labels)
1190
+
1191
+ if self.pad_logits:
1192
+ return MaskedLMOutput(
1193
+ loss=loss,
1194
+ logits=self.pad_inputs(logits, indices, batch_size, seq_len)[0],
1195
+ hidden_states=None,
1196
+ attentions=None,
1197
+ )
1198
+ else:
1199
+ return MaskedLMOutput(
1200
+ loss=loss,
1201
+ logits=logits,
1202
+ hidden_states=None,
1203
+ attentions=None,
1204
+ indices=indices,
1205
+ cu_seqlens=cu_seqlens,
1206
+ max_seqlen=max_seqlen,
1207
+ batch_size=batch_size,
1208
+ seq_len=seq_len,
1209
+ labels=labels,
1210
+ )
1211
+
1212
+ def prepare_inputs_for_generation(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **model_kwargs):
1213
+ input_shape = input_ids.shape
1214
+ effective_batch_size = input_shape[0]
1215
+
1216
+ # add a dummy token
1217
+ if self.config.pad_token_id is None:
1218
+ raise ValueError("The PAD token should be defined for generation")
1219
+
1220
+ attention_mask = torch.cat(
1221
+ [attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))],
1222
+ dim=-1,
1223
+ )
1224
+ dummy_token = torch.full(
1225
+ (effective_batch_size, 1),
1226
+ self.config.pad_token_id,
1227
+ dtype=torch.long,
1228
+ device=input_ids.device,
1229
+ )
1230
+ input_ids = torch.cat([input_ids, dummy_token], dim=1)
1231
+
1232
+ return {"input_ids": input_ids, "attention_mask": attention_mask}
1233
+
1234
+ def get_number_parameters(
1235
+ self, count_embeddings: bool = True, count_decoder: bool = False, trainable: bool = True
1236
+ ) -> int:
1237
+ """Returns the number of parameters in the model.
1238
+
1239
+ Args:
1240
+ count_embeddings: count the parameters in the embeddings layer, excluding position embeddings.
1241
+ count_decoder: count the parameters in the decoder layer if weights are not tied.
1242
+ trainable: only count trainable parameters.
1243
+ """
1244
+ params = self.bert.get_number_parameters(count_embeddings, trainable)
1245
+ params += _count_parameters(self.head, trainable)
1246
+ if count_decoder and not self.config.tie_word_embeddings:
1247
+ params += _count_parameters(self.decoder, trainable)
1248
+ return params
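A hedged sketch of the unpadding metadata (`indices`, `cu_seqlens`, `max_seqlen`) consumed by the forward pass above. The real helpers are `unpad_input`/`pad_input` in `padding.py`; the names and plain-PyTorch construction below are illustrative.

```python
import torch

attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)            # tokens per sequence
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens, 0, dtype=torch.int32), (1, 0))
max_seqlen = int(seqlens.max())

# A padded (batch, seq_len) tensor becomes a packed (total_tokens,) tensor:
input_ids = torch.tensor([[5, 6, 7, 0], [8, 9, 0, 0]])
packed = input_ids.flatten()[indices]                              # tensor([5, 6, 7, 8, 9])
```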
1249
+
1250
+
1251
+ class FlexBertForSequenceClassification(FlexBertPreTrainedModel):
1252
+ """Bert Model transformer with a sequence classification/regression head.
1253
+
1254
+ This head is just a linear layer on top of the pooled output. Used for,
1255
+ e.g., GLUE tasks.
1256
+ """
1257
+
1258
+ def __init__(self, config: FlexBertConfig):
1259
+ super().__init__(config)
1260
+ self.num_labels = config.num_labels
1261
+ self.config = config
1262
+
1263
+ self.bert = FlexBertModel(config)
1264
+ self.head = FlexBertPoolingHead(config)
1265
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1266
+
1267
+ # Initialize weights and apply final processing
1268
+ self._init_weights(reset_params=False)
1269
+
1270
+ def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None):
1271
+ assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified"
1272
+ if module:
1273
+ self._init_module_weights(module)
1274
+ else:
1275
+ assert isinstance(reset_params, bool)
1276
+ self.bert._init_weights(reset_params=reset_params)
1277
+ self.head._init_weights(reset_params=reset_params)
1278
+ init_weights(self.config, self.classifier, self.config.hidden_size, type_of_module=ModuleType.final_out)
1279
+
1280
+ @classmethod
1281
+ def from_composer(
1282
+ cls,
1283
+ pretrained_checkpoint,
1284
+ state_dict=None,
1285
+ cache_dir=None,
1286
+ from_tf=False,
1287
+ config=None,
1288
+ *inputs,
1289
+ **kwargs,
1290
+ ):
1291
+ """Load from pre-trained."""
1292
+ model = cls(config, *inputs, **kwargs)
1293
+ if from_tf:
1294
+ raise ValueError("Mosaic BERT does not support loading TensorFlow weights.")
1295
+
1296
+ state_dict = torch.load(pretrained_checkpoint)
1297
+ # If the state_dict was saved after wrapping with `composer.HuggingFaceModel`, it takes on the `model` prefix
1298
+ consume_prefix_in_state_dict_if_present(state_dict, prefix="model.")
1299
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
1300
+
1301
+ if len(missing_keys) > 0:
1302
+ logger.warning(f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}")
1303
+ if len(unexpected_keys) > 0:
1304
+ logger.warning(f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}")
1305
+
1306
+ return model
1307
+
1308
+ def forward(
1309
+ self,
1310
+ input_ids: Optional[torch.Tensor] = None,
1311
+ attention_mask: Optional[torch.Tensor] = None,
1312
+ position_ids: Optional[torch.Tensor] = None,
1313
+ labels: Optional[torch.Tensor] = None,
1314
+ return_dict: Optional[bool] = None,
1315
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
1316
+ # labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1317
+ # Labels for computing the sequence classification/regression loss.
1318
+ # Indices should be in `[0, ..., config.num_labels - 1]`.
1319
+ # If `config.num_labels == 1` a regression loss is computed
1320
+ # (mean-square loss). If `config.num_labels > 1` a classification loss
1321
+ # is computed (cross-entropy).
1322
+
1323
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1324
+
1325
+ output = self.bert(
1326
+ input_ids,
1327
+ attention_mask=attention_mask,
1328
+ position_ids=position_ids,
1329
+ )
1330
+
1331
+ pooled_output = self.head(output)
1332
+ logits = self.classifier(pooled_output)
1333
+
1334
+ loss = None
1335
+ if labels is not None:
1336
+ # Compute loss
1337
+ if self.config.problem_type is None:
1338
+ if self.num_labels == 1:
1339
+ self.config.problem_type = "regression"
1340
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1341
+ self.config.problem_type = "single_label_classification"
1342
+ else:
1343
+ self.config.problem_type = "multi_label_classification"
1344
+
1345
+ if self.config.problem_type == "regression":
1346
+ loss_fct = nn.MSELoss()
1347
+ if self.num_labels == 1:
1348
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
1349
+ else:
1350
+ loss = loss_fct(logits, labels)
1351
+ elif self.config.problem_type == "single_label_classification":
1352
+ loss_fct = nn.CrossEntropyLoss()
1353
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1354
+ elif self.config.problem_type == "multi_label_classification":
1355
+ loss_fct = nn.BCEWithLogitsLoss()
1356
+ loss = loss_fct(logits, labels)
1357
+
1358
+ if not return_dict:
1359
+ output = (logits, output)  # the encoder returns a single tensor rather than a tuple
1360
+ return ((loss,) + output) if loss is not None else output
1361
+
1362
+ return SequenceClassifierOutput(
1363
+ loss=loss,
1364
+ logits=logits,
1365
+ hidden_states=None,
1366
+ attentions=None,
1367
+ )
1368
+
1369
+ def get_number_parameters(self, count_embeddings: bool = True, trainable: bool = True) -> int:
1370
+ """Returns the number of parameters in the model.
1371
+
1372
+ Args:
1373
+ count_embeddings: count the parameters in the embeddings layer, excluding position embeddings.
1374
+ trainable: only count trainable parameters.
1375
+ """
1376
+ params = self.bert.get_number_parameters(count_embeddings, trainable)
1377
+ params += _count_parameters(self.head, trainable)
1378
+ params += _count_parameters(self.classifier, trainable)
1379
+ return params
1380
+
1381
+
1382
+ class FlexBertForMultipleChoice(FlexBertPreTrainedModel):
1383
+ """
1384
+ Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
1385
+ softmax) e.g. for RocStories/SWAG tasks.
1386
+ """
1387
+
1388
+ def __init__(self, config: FlexBertConfig):
1389
+ super().__init__(config)
1390
+ self.num_labels = config.num_labels
1391
+ self.config = config
1392
+
1393
+ self.bert = FlexBertModel(config)
1394
+ self.head = FlexBertPoolingHead(config)
1395
+
1396
+ # In multiple choice tasks, all choices are submitted in a single batch, and we
+ # compute a logit for each choice independently. In the forward pass the logits are
+ # reshaped to (batch_size, num_choices) and a cross-entropy loss is applied over
+ # the choices.
1400
+ self.classifier = nn.Linear(config.hidden_size, 1)
1401
+
1402
+ # Initialize weights and apply final processing
1403
+ self._init_weights(reset_params=False)
1404
+
1405
+ def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None):
1406
+ assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified"
1407
+ if module:
1408
+ self._init_module_weights(module)
1409
+ else:
1410
+ assert isinstance(reset_params, bool)
1411
+ self.bert._init_weights(reset_params=reset_params)
1412
+ self.head._init_weights(reset_params=reset_params)
1413
+ init_weights(self.config, self.classifier, self.config.hidden_size, type_of_module=ModuleType.final_out)
1414
+
1415
+ @classmethod
1416
+ def from_composer(
1417
+ cls,
1418
+ pretrained_checkpoint,
1419
+ state_dict=None,
1420
+ cache_dir=None,
1421
+ from_tf=False,
1422
+ config=None,
1423
+ *inputs,
1424
+ **kwargs,
1425
+ ):
1426
+ """Load from pre-trained."""
1427
+ model = cls(config, *inputs, **kwargs)
1428
+ if from_tf:
1429
+ raise ValueError("Mosaic BERT does not support loading TensorFlow weights.")
1430
+
1431
+ state_dict = torch.load(pretrained_checkpoint)
1432
+ # If the state_dict was saved after wrapping with `composer.HuggingFaceModel`, it takes on the `model` prefix
1433
+ consume_prefix_in_state_dict_if_present(state_dict, prefix="model.")
1434
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
1435
+
1436
+ if len(missing_keys) > 0:
1437
+ logger.warning(f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}")
1438
+ if len(unexpected_keys) > 0:
1439
+ logger.warning(f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}")
1440
+
1441
+ return model
1442
+
1443
+ def forward(
1444
+ self,
1445
+ input_ids: Optional[torch.Tensor] = None,
1446
+ attention_mask: Optional[torch.Tensor] = None,
1447
+ position_ids: Optional[torch.Tensor] = None,
1448
+ labels: Optional[torch.Tensor] = None,
1449
+ return_dict: Optional[bool] = None,
1450
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
1451
+ # labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ # Labels for computing the multiple choice classification loss.
+ # Indices should be in `[0, ..., num_choices - 1]`, where `num_choices` is the size
+ # of the second dimension of the input tensors. A cross-entropy loss is computed
+ # over the choices.
1457
+
1458
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1459
+ num_choices = input_ids.shape[1]
1460
+
1461
+ input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
1462
+ attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
1463
+ position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
1464
+
1465
+ output = self.bert(
1466
+ input_ids,
1467
+ attention_mask=attention_mask,
1468
+ position_ids=position_ids,
1469
+ )
1470
+
1471
+ pooled_output = self.head(output)
1472
+ logits = self.classifier(pooled_output)
1473
+ reshaped_logits = logits.view(-1, num_choices)
1474
+
1475
+ loss = None
1476
+ if labels is not None:
1477
+ loss_fct = nn.CrossEntropyLoss()
1478
+ loss = loss_fct(reshaped_logits, labels)
1479
+
1480
+ if not return_dict:
1481
+ output = (reshaped_logits,) + output
1482
+ return ((loss,) + output) if loss is not None else output
1483
+
1484
+ return MultipleChoiceModelOutput(
1485
+ loss=loss,
1486
+ logits=reshaped_logits,
1487
+ hidden_states=None,
1488
+ attentions=None,
1489
+ )
1490
+
1491
+ def get_number_parameters(self, count_embeddings: bool = True, trainable: bool = True) -> int:
1492
+ """Returns the number of parameters in the model.
1493
+
1494
+ Args:
1495
+ count_embeddings: count the parameters in the embeddings layer, excluding position embeddings.
1496
+ trainable: only count trainable parameters.
1497
+ """
1498
+ params = self.bert.get_number_parameters(count_embeddings, trainable)
1499
+ params += _count_parameters(self.head, trainable)
1500
+ params += _count_parameters(self.classifier, trainable)
1501
+ return params
1502
+
1503
+
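# Illustrative sketch (not part of the module): the tensor reshaping the multiple choice
# head above relies on. Each candidate sequence is scored independently and the per-choice
# scores are regrouped per example; the encoder is replaced by a dummy pooled output here.
import torch
import torch.nn as nn

batch_size, num_choices, seq_len, hidden = 2, 4, 16, 8
input_ids = torch.randint(0, 100, (batch_size, num_choices, seq_len))

# flatten choices into the batch dimension, as in forward(): (bs * num_choices, seq_len)
flat_ids = input_ids.view(-1, input_ids.size(-1))

pooled = torch.randn(flat_ids.size(0), hidden)   # stand-in for the pooling head output
classifier = nn.Linear(hidden, 1)                # one logit per (example, choice)
logits = classifier(pooled)                      # (bs * num_choices, 1)
reshaped_logits = logits.view(-1, num_choices)   # (bs, num_choices)

labels = torch.tensor([1, 3])                    # index of the correct choice per example
loss = nn.CrossEntropyLoss()(reshaped_logits, labels)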
1504
+ class FlexBertForCasualLM(FlexBertPreTrainedModel):
1505
+ """Bert Model transformer with a LM head.
1506
+
1507
+ This head is just a standard LM head module. Used for causal language modeling tasks.
1508
+ """
1509
+
1510
+ def __init__(self, config: FlexBertConfig):
1511
+ super().__init__(config)
1512
+ self.bert = FlexBertModel(config)
1513
+ self.lm_head = FlexBertPredictionHead(config)
1514
+
1515
+ if config.tie_word_embeddings:
1516
+ decoder_weights = self.bert.embeddings.tok_embeddings.weight
1517
+ else:
1518
+ decoder_weights = nn.Linear(config.hidden_size, config.vocab_size, bias=False).weight
1519
+ self.decoder = nn.Linear(decoder_weights.size(1), decoder_weights.size(0), bias=config.decoder_bias)
1520
+ self.decoder.weight = decoder_weights
1521
+
1522
+ self.loss_fn = nn.CrossEntropyLoss() if not hasattr(config, "loss_function") else get_loss_fn(config)
1523
+ self.fa_ce = getattr(config, "loss_function", "cross_entropy") == "fa_cross_entropy"
1524
+ self.return_z_loss = config.loss_kwargs.get("return_z_loss", False)
1525
+ self.unpad_embeddings = config.unpad_embeddings
1526
+ self.pad_logits = config.pad_logits
1527
+ self.compile_model = config.compile_model
1528
+ # self.masked_prediction = config.masked_prediction
1529
+
1530
+ # Initialize weights and apply final processing
1531
+ self._init_weights(reset_params=False)
1532
+
1533
+ def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None):
1534
+ assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified"
1535
+ if module:
1536
+ self._init_module_weights(module)
1537
+ else:
1538
+ assert isinstance(reset_params, bool)
1539
+ self.bert._init_weights(reset_params=reset_params)
1540
+ self.lm_head._init_weights(reset_params=reset_params)
1541
+
1542
+ if not self.config.tie_word_embeddings:
1543
+ init_weights(self.config, self.decoder, self.config.hidden_size, type_of_module=ModuleType.final_out)
1544
+
1545
+ @classmethod
1546
+ def from_composer(
1547
+ cls,
1548
+ pretrained_checkpoint,
1549
+ state_dict=None,
1550
+ cache_dir=None,
1551
+ from_tf=False,
1552
+ config=None,
1553
+ *inputs,
1554
+ **kwargs,
1555
+ ):
1556
+ """Load from pre-trained."""
1557
+ model = cls(config, *inputs, **kwargs)
1558
+ if from_tf:
1559
+ raise ValueError("Mosaic BERT does not support loading TensorFlow weights.")
1560
+
1561
+ state_dict = torch.load(pretrained_checkpoint)
1562
+ # If the state_dict was saved after wrapping with `composer.HuggingFaceModel`, it takes on the `model` prefix
1563
+ consume_prefix_in_state_dict_if_present(state_dict, prefix="model.")
1564
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
1565
+
1566
+ if len(missing_keys) > 0:
1567
+ logger.warning(f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}")
1568
+ if len(unexpected_keys) > 0:
1569
+ logger.warning(f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}")
1570
+
1571
+ return model
1572
+
1573
+
1574
+ def get_output_embeddings(self):
1575
+ return self.decoder
1576
+
1577
+ def set_output_embeddings(self, new_embeddings):
1578
+ self.decoder = new_embeddings
1579
+
1580
+ @torch.no_grad()
1581
+ def unpad_inputs(
1582
+ self, input_ids: torch.Tensor, attention_mask: torch.Tensor, position_ids: torch.Tensor, labels: torch.Tensor
1583
+ ):
1584
+ return unpad_input(input_ids, attention_mask, position_ids, labels)
1585
+
1586
+ @torch.no_grad()
1587
+ def pad_inputs(
1588
+ self,
1589
+ inputs: torch.Tensor,
1590
+ indices: torch.Tensor,
1591
+ batch_size: int,
1592
+ seqlen: int,
1593
+ labels: Optional[torch.Tensor] = None,
1594
+ ignore_index: int = -100,
1595
+ ):
1596
+ return pad_input(
1597
+ inputs=inputs, indices=indices, batch=batch_size, seqlen=seqlen, labels=labels, ignore_index=ignore_index
1598
+ )
1599
+
1600
+ @torch.compile(dynamic=True)
1601
+ def compiled_lm_head(self, output: torch.Tensor) -> torch.Tensor:
1602
+ return self.decoder(self.lm_head(output))
1603
+
1604
+ def forward(
1605
+ self,
1606
+ input_ids: Optional[torch.Tensor],
1607
+ attention_mask: Optional[torch.Tensor] = None,
1608
+ position_ids: Optional[torch.Tensor] = None,
1609
+ labels: Optional[torch.Tensor] = None,
1610
+ return_dict: Optional[bool] = None,
1611
+ indices: Optional[torch.Tensor] = None,
1612
+ cu_seqlens: Optional[torch.Tensor] = None,
1613
+ max_seqlen: Optional[int] = None,
1614
+ batch_size: Optional[int] = None,
1615
+ seq_len: Optional[int] = None,
1616
+ **kwargs,
1617
+ ) -> Union[Tuple[torch.Tensor], CausalLMOutput]:
1618
+ # labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ # Used to trigger the causal language modeling loss. The loss targets are built by
+ # shifting `input_ids` one position to the left (next-token prediction), so indices
+ # are in `[-100, 0, ..., config.vocab_size]`; positions set to `-100` (e.g. the last
+ # token of each packed sequence) are ignored when computing the loss.
1629
+
1630
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1631
+
1632
+ if self.unpad_embeddings and (indices is None and cu_seqlens is None and max_seqlen is None):
1633
+ batch_size, seq_len = input_ids.shape[:2]
1634
+ input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = self.unpad_inputs(
1635
+ input_ids, attention_mask, position_ids, labels
1636
+ )
1637
+
1638
+ hidden_states = self.bert(
1639
+ input_ids,
1640
+ attention_mask=None,
1641
+ position_ids=position_ids,
1642
+ indices=indices,
1643
+ cu_seqlens=cu_seqlens,
1644
+ max_seqlen=max_seqlen,
1645
+ )
1646
+
1647
+ if self.compile_model:
1648
+ logits = self.compiled_lm_head(hidden_states)
1649
+ else:
1650
+ logits = self.lm_head(hidden_states)
1651
+
1652
+ loss = None
1653
+ if labels is not None:
1654
+ if indices is not None:
1655
+ # Unpadded case: shift within each sequence using input_ids
1656
+ # Initialize shifted labels from input_ids
1657
+ shift_labels = torch.full_like(input_ids, -100)
1658
+
1659
+ # For each sequence, shift the input_ids to create labels
1660
+ for i in range(len(cu_seqlens) - 1):
1661
+ start = cu_seqlens[i]
1662
+ end = cu_seqlens[i + 1]
1663
+ # Input: [A, B, C, D] -> Labels: [B, C, D, -100]
1664
+ shift_labels[start:end-1] = input_ids[start+1:end]
1665
+
1666
+ # Debug prints (commented out): inspect the shift alignment
+ # print(f"input_ids slice: {input_ids[:20]}") # Show first 20 tokens
+ # print(f"shift_labels slice: {shift_labels[:20]}") # Show first 20 tokens
+ # print(f"First sequence length: {cu_seqlens[1] - cu_seqlens[0]}")
1674
+
1675
+ else:
1676
+ # Padded case: simple shift
1677
+ shift_labels = input_ids[..., 1:].contiguous()
1678
+ logits = logits[..., :-1, :].contiguous()
1679
+
1680
+ # For both cases, we'll use the shifted input_ids as our labels
1681
+ labels = shift_labels
1682
+
1683
+ # Flatten the tokens
1684
+ loss = self.loss_fn(
1685
+ logits.view(-1, logits.size(-1)),
1686
+ shift_labels.view(-1)
1687
+ )
1688
+
1689
+ if self.pad_logits:
1690
+ return CausalLMOutput(
1691
+ loss=loss,
1692
+ logits=self.pad_inputs(logits, indices, batch_size, seq_len)[0],
1693
+ hidden_states=None,
1694
+ attentions=None,
1695
+ )
1696
+ else:
1697
+ return CausalLMOutput(
1698
+ loss=loss,
1699
+ logits=logits,
1700
+ hidden_states=hidden_states,
1701
+ attentions=None,
1702
+ )
1703
+
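# Illustrative sketch (not part of the module): next-token label shifting for packed
# (unpadded) sequences, mirroring the loop in the forward pass above with toy token ids.
import torch

# two packed sequences: [5, 6, 7, 8] and [9, 10, 11]
input_ids = torch.tensor([5, 6, 7, 8, 9, 10, 11])
cu_seqlens = torch.tensor([0, 4, 7])

shift_labels = torch.full_like(input_ids, -100)
for i in range(len(cu_seqlens) - 1):
    start, end = cu_seqlens[i], cu_seqlens[i + 1]
    # Input: [A, B, C, D] -> Labels: [B, C, D, -100]
    shift_labels[start:end - 1] = input_ids[start + 1:end]

print(shift_labels.tolist())  # [6, 7, 8, -100, 10, 11, -100]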
1704
+ def prepare_inputs_for_generation(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **model_kwargs):
1705
+ input_shape = input_ids.shape
1706
+ effective_batch_size = input_shape[0]
1707
+
1708
+ # add a dummy token
1709
+ if self.config.pad_token_id is None:
1710
+ raise ValueError("The PAD token should be defined for generation")
1711
+
1712
+ attention_mask = torch.cat(
1713
+ [attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))],
1714
+ dim=-1,
1715
+ )
1716
+ dummy_token = torch.full(
1717
+ (effective_batch_size, 1),
1718
+ self.config.pad_token_id,
1719
+ dtype=torch.long,
1720
+ device=input_ids.device,
1721
+ )
1722
+ input_ids = torch.cat([input_ids, dummy_token], dim=1)
1723
+
1724
+ return {"input_ids": input_ids, "attention_mask": attention_mask}
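# Illustrative sketch (not part of the module): what prepare_inputs_for_generation above
# does to a batch -- it appends one PAD token and a masked-out attention slot to each row.
# The pad id below is an assumption; the real value comes from config.pad_token_id.
import torch

pad_token_id = 0
input_ids = torch.tensor([[101, 7592, 102]])
attention_mask = torch.tensor([[1, 1, 1]])

attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((1, 1))], dim=-1)
input_ids = torch.cat([input_ids, torch.full((1, 1), pad_token_id, dtype=torch.long)], dim=1)
print(input_ids.tolist())       # [[101, 7592, 102, 0]]
print(attention_mask.tolist())  # [[1, 1, 1, 0]]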
1725
+
1726
+ def get_number_parameters(self, count_embeddings: bool = True, trainable: bool = True) -> int:
1727
+ """Returns the number of parameters in the model.
1728
+
1729
+ Args:
1730
+ count_embeddings: count the parameters in the embeddings layer, excluding position embeddings.
1731
+ trainable: only count trainable parameters.
1732
+ """
1733
+ params = self.bert.get_number_parameters(count_embeddings, trainable)
1734
+ params += _count_parameters(self.lm_head, trainable)
1735
+ return params
1736
+
1737
+
1738
+ def init_model_from_pretrained(
1739
+ pretrained_model: FlexBertModel,
1740
+ new_model: FlexBertModel,
1741
+ mode: Union[str, TileMode] = TileMode.tile_weights_from_middle,
1742
+ ):
1743
+ """
1744
+ Initialize the new model from the pretrained model.
1745
+
1746
+ This function uses Gopher layer scaling and Phi-style weight tiling as selected by `mode`.
1747
+ The new model must have the same or more layers and the same or larger dimensions than the pretrained model.
1748
+
1749
+ Args:
1750
+ pretrained_model (FlexBertModel): The smaller, pre-trained model
1751
+ new_model (FlexBertModel): The larger model to be initialized
1752
+ mode (Union[str, TileMode]): The Phi-style weight tiling mode to use
1753
+
1754
+ This function assumes that the new_model has more layers and a larger hidden size
1755
+ than the pretrained_model, but the same vocabulary size.
1756
+ """
1757
+
1758
+ # Tile embeddings
1759
+ assert isinstance(
1760
+ new_model.embeddings, type(pretrained_model.embeddings)
1761
+ ), f"Pretrained and new_model layers must be the same type, got {type(new_model.embeddings)} and {type(pretrained_model.embeddings)}"
1762
+ assert isinstance(
1763
+ new_model.embeddings,
1764
+ (FlexBertAbsoluteEmbeddings, FlexBertSansPositionEmbeddings, FlexBertCompiledSansPositionEmbeddings),
1765
+ ), f"Unsupported embedding layer type: {type(new_model.embeddings)}"
1766
+
1767
+ tile_embedding(pretrained_model.embeddings.tok_embeddings, new_model.embeddings.tok_embeddings, mode=mode)
1768
+ if isinstance(pretrained_model.embeddings, FlexBertAbsoluteEmbeddings):
1769
+ tile_embedding(pretrained_model.embeddings.pos_embeddings, new_model.embeddings.pos_embeddings, mode=mode)
1770
+
1771
+ if hasattr(pretrained_model.embeddings, "norm"):
1772
+ tile_norm(pretrained_model.embeddings.norm, new_model.embeddings.norm, mode=mode)
1773
+
1774
+ # Tile encoder layers
1775
+ assert isinstance(
1776
+ pretrained_model.encoder, (FlexBertUnpadEncoder, FlexBertPaddedEncoder)
1777
+ ), f"Unsupported encoder layer type: {type(pretrained_model.encoder)}"
1778
+ assert isinstance(
1779
+ new_model.encoder, type(pretrained_model.encoder)
1780
+ ), f"Pretrained and new_model encoder layers must be the same type, got {type(new_model.encoder)} and {type(pretrained_model.encoder)}"
1781
+
1782
+ # Calculate the layer mapping
1783
+ pretrained_layers = len(pretrained_model.encoder.layers)
1784
+ new_layers = len(new_model.encoder.layers)
1785
+ layer_mapping = [round(i * pretrained_layers / new_layers) for i in range(new_layers)]
1786
+
1787
+ # Initialize layers
1788
+ for new_model_idx, pretrained_idx in enumerate(layer_mapping):
1789
+ new_model_layer = new_model.encoder.layers[new_model_idx]
1790
+ pretrained_layer = pretrained_model.encoder.layers[pretrained_idx]
1791
+
1792
+ # first tile the PreNorm/PostNorm layers
1793
+ assert isinstance(
1794
+ new_model_layer, type(pretrained_layer)
1795
+ ), f"Pretrained and new_model prenorm/postnorm layers must be the same type, got {type(new_model_layer)} and {type(pretrained_layer)}"
1796
+ assert isinstance(
1797
+ new_model_layer,
1798
+ (
1799
+ FlexBertUnpadPreNormLayer,
1800
+ FlexBertCompileUnpadPreNormLayer,
1801
+ FlexBertUnpadParallelPreNormLayer,
1802
+ FlexBertUnpadPostNormLayer,
1803
+ FlexBertPaddedPreNormLayer,
1804
+ FlexBertPaddedParallelPreNormLayer,
1805
+ FlexBertPaddedPostNormLayer,
1806
+ ),
1807
+ ), f"Unsupported prenorm/postnorm layer type: {type(new_model_layer)}"
1808
+
1809
+ # First tile the normalization layers
1810
+ if hasattr(pretrained_layer, "attn_norm"):
1811
+ tile_norm(pretrained_layer.attn_norm, new_model_layer.attn_norm, mode=mode)
1812
+ if hasattr(pretrained_layer, "norm"):
1813
+ tile_norm(pretrained_layer.norm, new_model_layer.norm, mode=mode)
1814
+ if hasattr(pretrained_layer, "mlp_norm"):
1815
+ tile_norm(pretrained_layer.mlp_norm, new_model_layer.mlp_norm, mode=mode)
1816
+
1817
+ # Then tile the attention & mlp layers
1818
+ assert isinstance(
1819
+ new_model_layer.attn, type(pretrained_layer.attn)
1820
+ ), f"Pretrained and new_model attention layers must be the same type, got {type(new_model_layer.attn)} and {type(pretrained_layer.attn)}"
1821
+
1822
+ # first try the parallel attention layers
1823
+ if isinstance(pretrained_layer, (FlexBertUnpadParallelPreNormLayer, FlexBertPaddedParallelPreNormLayer)):
1824
+ assert isinstance(
1825
+ pretrained_layer.attn,
1826
+ (
1827
+ FlexBertUnpadParallelAttention,
1828
+ FlexBertPaddedParallelAttention,
1829
+ FlexBertUnpadRopeParallelAttention,
1830
+ FlexBertPaddedRopeParallelAttention,
1831
+ ),
1832
+ ), f"Parallel prenorm layer must have parallel attention layer: {type(pretrained_layer.attn)}"
1833
+ if not isinstance(pretrained_layer.mlp, (FlexBertParallelGLU)):
1834
+ raise ValueError(f"Parallel prenorm layer must have parallel MLP layer: {type(pretrained_layer.mlp)}")
1835
+ tile_linear(
1836
+ pretrained_layer.Wqkvff,
1837
+ new_model_layer.Wqkvff,
1838
+ linear_type=TileLinear.wqkvff,
1839
+ mode=mode,
1840
+ pretrained_attn_size=pretrained_layer.attn_size,
1841
+ pretrained_mlp_size=pretrained_layer.mlp_size,
1842
+ new_attn_size=new_model_layer.attn_size,
1843
+ new_mlp_size=new_model_layer.mlp_size,
1844
+ wqkvff_is_glu=True,
1845
+ )
1846
+
1847
+ # then try the fused attention layers
1848
+ elif isinstance(
1849
+ pretrained_layer.attn,
1850
+ (
1851
+ FlexBertUnpadAttention,
1852
+ FlexBertPaddedAttention,
1853
+ FlexBertUnpadRopeAttention,
1854
+ FlexBertPaddedRopeAttention,
1855
+ ),
1856
+ ):
1857
+ tile_linear(pretrained_layer.attn.Wqkv, new_model_layer.attn.Wqkv, linear_type=TileLinear.wqkv, mode=mode)
1858
+ else:
1859
+ raise ValueError(f"Unsupported attention layer type: {type(pretrained_layer.attn)}")
1860
+
1861
+ # finally, tile the attention output layer
1862
+ tile_linear(pretrained_layer.attn.Wo, new_model_layer.attn.Wo, linear_type=TileLinear.default, mode=mode)
1863
+
1864
+ # tile the mlp layer if the model is not using parallel attention layers
1865
+ if not isinstance(pretrained_layer.mlp, (FlexBertMLP, FlexBertGLU, FlexBertParallelGLU)):
1866
+ raise ValueError(f"Unsupported MLP layer type: {type(pretrained_layer.mlp)}")
1867
+ assert isinstance(
1868
+ new_model_layer.mlp, type(pretrained_layer.mlp)
1869
+ ), f"Pretrained and new_model mlp layers must be the same type, got {type(new_model_layer.mlp)} and {type(pretrained_layer.mlp)}"
1870
+
1871
+ # already tiled the parallel glu layer if it exists, so only need to handle mlp & glu Wi
1872
+ if isinstance(pretrained_layer.mlp, FlexBertGLU):
1873
+ tile_linear(pretrained_layer.mlp.Wi, new_model_layer.mlp.Wi, linear_type=TileLinear.glu, mode=mode)
1874
+ elif isinstance(pretrained_layer.mlp, FlexBertMLP):
1875
+ tile_linear(pretrained_layer.mlp.Wi, new_model_layer.mlp.Wi, linear_type=TileLinear.default, mode=mode)
1876
+ # tile the output for both ParallelGLU and MLP/GLU
1877
+ tile_linear(pretrained_layer.mlp.Wo, new_model_layer.mlp.Wo, linear_type=TileLinear.default, mode=mode)
1878
+
1879
+
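# Illustrative sketch (not part of the module): the depth mapping used above when growing a
# model, e.g. initializing a hypothetical 16-layer model from a 12-layer checkpoint. Some
# pretrained layers seed more than one new layer.
pretrained_layers, new_layers = 12, 16
layer_mapping = [round(i * pretrained_layers / new_layers) for i in range(new_layers)]
print(layer_mapping)  # [0, 1, 2, 2, 3, 4, 4, 5, 6, 7, 8, 8, 9, 10, 10, 11]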
1880
+ def init_mlm_model_from_pretrained(
1881
+ config: FlexBertConfig,
1882
+ pretrained_model: FlexBertForMaskedLM,
1883
+ new_model: FlexBertForMaskedLM,
1884
+ mode: Union[str, TileMode] = TileMode.tile_weights_from_middle,
1885
+ ):
1886
+ """
1887
+ Initialize the new model from the pretrained model.
1888
+
1889
+ This function uses Gopher layer scaling and Phi-style weight tiling as selected by `mode`.
1890
+ The new model must have the same or more layers and the same or larger dimensions than the pretrained model.
1891
+
1892
+ Args:
1893
+ config (FlexBertConfig): The configuration of the new_model
1894
+ pretrained_model (FlexBertForMaskedLM): The smaller, pre-trained model
1895
+ new_model (FlexBertForMaskedLM): The larger model to be initialized from the pretrained model
1896
+ mode (Union[str, TileMode]): The Phi-style weight tiling mode to use
1897
+
1898
+ This function assumes that the new_model has more layers and a larger hidden size
1899
+ than the pretrained_model, but the same vocabulary size.
1900
+ """
1901
+ init_model_from_pretrained(pretrained_model.bert, new_model.bert, mode=mode)
1902
+
1903
+ # TODO: uncomment this when the repo is turned into a pip installable package
1904
+ # if not isinstance(pretrained_model.head, FlexBertPredictionHead):
1905
+ # raise ValueError(f"Pretrained model must have a prediction head: {type(pretrained_model.head)}")
1906
+ # if not isinstance(new_model.head, FlexBertPredictionHead):
1907
+ # raise ValueError(f"New model must have a prediction head: {type(new_model.head)}")
1908
+
1909
+ # tile the prediction head
1910
+ tile_linear(pretrained_model.head.dense, new_model.head.dense, linear_type=TileLinear.default, mode=mode)
1911
+ tile_norm(pretrained_model.head.norm, new_model.head.norm, mode=mode)
1912
+
1913
+ # setup weight tying
1914
+ if config.tie_word_embeddings:
1915
+ new_model.decoder.weight = new_model.bert.embeddings.tok_embeddings.weight
1916
+ tile_linear(
1917
+ pretrained_model.decoder, new_model.decoder, linear_type=TileLinear.default, mode=mode, bias_only=True
1918
+ )
1919
+ else:
1920
+ tile_linear(pretrained_model.decoder, new_model.decoder, linear_type=TileLinear.default, mode=mode)
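# Illustrative sketch (not part of the module): with tie_word_embeddings=True the decoder
# shares its weight tensor with the token embeddings, which is why the tiling above only
# needs to copy the decoder bias. A minimal demonstration with plain torch modules:
import torch.nn as nn

vocab, hidden = 10, 4
tok_embeddings = nn.Embedding(vocab, hidden)
decoder = nn.Linear(hidden, vocab, bias=True)
decoder.weight = tok_embeddings.weight  # tie: both modules now share one Parameter

assert decoder.weight.data_ptr() == tok_embeddings.weight.data_ptr()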
normalization.py ADDED
@@ -0,0 +1,116 @@
1
+ # Copyright 2024 **AUTHORS_TODO**
2
+ # License: Apache-2.0
3
+
4
+ # RMSNorm Implementation: Copyright Meta (from their Llama RMSNorm implementation)
5
+ # License: LLAMA 2 COMMUNITY LICENSE AGREEMENT
6
+
7
+
8
+ import inspect
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import init
12
+
13
+ from .configuration_bert import FlexBertConfig
14
+
15
+ try:
16
+ from flash_attn.ops.triton.layer_norm import RMSNorm as TritonRMSNorm
17
+ from flash_attn.ops.triton.layer_norm import layer_norm_fn
18
+
19
+ except ImportError:
20
+ TritonRMSNorm = None
21
+ layer_norm_fn = None
22
+
23
+
24
+ class RMSNorm(nn.Module):
25
+ """Llama2 RMSNorm implementation"""
26
+
27
+ def __init__(self, dim: int, eps: float = 1e-5):
28
+ """
29
+ Initialize the RMSNorm normalization layer.
30
+
31
+ Args:
32
+ dim (int): The dimension of the input tensor.
33
+ eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-5.
34
+
35
+ Attributes:
36
+ eps (float): A small value added to the denominator for numerical stability.
37
+ weight (nn.Parameter): Learnable scaling parameter.
38
+
39
+ """
40
+ super().__init__()
41
+ self.eps = eps
42
+ self.weight = nn.Parameter(torch.ones(dim))
43
+
44
+ def _norm(self, x):
45
+ """
46
+ Apply the RMSNorm normalization to the input tensor.
47
+
48
+ Args:
49
+ x (torch.Tensor): The input tensor.
50
+
51
+ Returns:
52
+ torch.Tensor: The normalized tensor.
53
+
54
+ """
55
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
56
+
57
+ def forward(self, x):
58
+ """
59
+ Forward pass through the RMSNorm layer.
60
+
61
+ Args:
62
+ x (torch.Tensor): The input tensor.
63
+
64
+ Returns:
65
+ torch.Tensor: The output tensor after applying RMSNorm.
66
+
67
+ """
68
+ output = self._norm(x.float()).type_as(x)
69
+ return output * self.weight
70
+
71
+ def reset_parameters(self):
72
+ init.ones_(self.weight)
73
+
74
+
75
+ if layer_norm_fn is not None:
76
+
77
+ class TritonLayerNorm(nn.LayerNorm):
78
+ def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
79
+ return layer_norm_fn(
80
+ x,
81
+ self.weight,
82
+ self.bias,
83
+ residual=residual,
84
+ eps=self.eps,
85
+ prenorm=prenorm,
86
+ residual_in_fp32=residual_in_fp32,
87
+ )
88
+ else:
89
+ TritonLayerNorm = None
90
+
91
+ NORM2CLS = {
92
+ "layernorm": nn.LayerNorm,
93
+ "triton_layernorm": TritonLayerNorm if TritonLayerNorm is not None else nn.LayerNorm,
94
+ "rmsnorm": RMSNorm,
95
+ "triton_rmsnorm": TritonRMSNorm if TritonRMSNorm is not None else RMSNorm,
96
+ }
97
+
98
+
99
+ def get_norm_layer(config: FlexBertConfig, compiled_norm: bool = False) -> nn.Module:
100
+ try:
101
+ if compiled_norm:
102
+ # Use non-Triton norms when compiling
103
+ if config.normalization.startswith("triton_"):
104
+ norm = config.normalization.replace("triton_", "")
105
+ else:
106
+ norm = config.normalization
107
+ else:
108
+ norm = config.normalization
109
+ signature = inspect.signature(NORM2CLS[norm])
110
+ if hasattr(config, "norm_kwargs"):
111
+ norm_kwargs = {k: v for k, v in config.norm_kwargs.items() if k in signature.parameters}
112
+ else:
113
+ norm_kwargs = {}
114
+ return NORM2CLS[norm](config.hidden_size, **norm_kwargs)
115
+ except KeyError:
116
+ raise ValueError(f"Invalid normalization layer type: {config.normalization}, must be one of {NORM2CLS.keys()}.")
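# Illustrative sketch (not part of the module): RMSNorm computes
#   y = x / sqrt(mean(x^2, dim=-1) + eps) * weight
# i.e. LayerNorm without mean-centering and without a bias. A quick numerical check
# against the RMSNorm class defined in this file:
import torch

dim, eps = 8, 1e-5
x = torch.randn(2, 4, dim)

norm = RMSNorm(dim, eps=eps)
manual = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * norm.weight
assert torch.allclose(norm(x), manual, atol=1e-6)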
options.py ADDED
@@ -0,0 +1,32 @@
1
+ from .normalization import NORM2CLS
2
+ from .embeddings import EBB2CLS
3
+ from .activation import ACT2CLS
4
+ from .attention import ATTN2CLS
5
+ from .mlp import MLP2CLS
6
+ from .layers import LAYER2CLS
7
+
8
+
9
+ def print_layer_options():
10
+ print("Activation options:")
11
+ for option in ACT2CLS:
12
+ print(f" {option}")
13
+
14
+ print("\nAttention Layer options:")
15
+ for option in ATTN2CLS:
16
+ print(f" {option}")
17
+
18
+ print("\nEmbedding Layer options:")
19
+ for option in EBB2CLS:
20
+ print(f" {option}")
21
+
22
+ print("\nBert Layer options:")
23
+ for option in LAYER2CLS:
24
+ print(f" {option}")
25
+
26
+ print("\nMLP Layer options:")
27
+ for option in MLP2CLS:
28
+ print(f" {option}")
29
+
30
+ print("\nNormalization options:")
31
+ for option in NORM2CLS:
32
+ print(f" {option}")
padding.py ADDED
@@ -0,0 +1,87 @@
1
+ import torch
2
+ from torch import Tensor
3
+ from typing import Optional, Tuple
4
+ import torch.nn.functional as F
5
+
6
+
7
+ def unpad_input(
8
+ inputs: Tensor,
9
+ attention_mask: Tensor,
10
+ position_ids: Optional[Tensor] = None,
11
+ labels: Optional[Tensor] = None,
12
+ ) -> Tuple[Tensor, Tensor, Tensor, int, Optional[Tensor], Optional[Tensor]]:
13
+ """
14
+ Remove padding from input sequences.
15
+
16
+ Args:
17
+ inputs: (batch, seqlen, ...) or (batch, seqlen)
18
+ attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
19
+ position_ids: (batch, seqlen), int, position ids
20
+ labels: (batch, seqlen), int, labels
21
+
22
+ Returns:
23
+ unpadded_inputs: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
24
+ indices: (total_nnz)
25
+ cu_seqlens: (batch + 1), the cumulative sequence lengths
26
+ max_seqlen_in_batch: int
27
+ unpadded_position_ids: (total_nnz) or None
28
+ unpadded_labels: (total_nnz) or None
29
+ """
30
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
31
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
32
+ max_seqlen_in_batch = int(seqlens_in_batch.max().item())
33
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
34
+
35
+ if inputs.dim() == 2:
36
+ unpadded_inputs = inputs.flatten()[indices]
37
+ else:
38
+ batch, seqlen, *rest = inputs.shape
39
+ shape = batch * seqlen
40
+ unpadded_inputs = inputs.view(shape, *rest)[indices]
41
+
42
+ unpadded_position_ids = position_ids.flatten()[indices] if position_ids is not None else None
43
+ unpadded_labels = labels.flatten()[indices] if labels is not None else None
44
+
45
+ return unpadded_inputs, indices, cu_seqlens, max_seqlen_in_batch, unpadded_position_ids, unpadded_labels
46
+
47
+
48
+ def pad_input(
49
+ inputs: Tensor,
50
+ indices: Tensor,
51
+ batch: int,
52
+ seqlen: int,
53
+ labels: Optional[Tensor] = None,
54
+ ignore_index: int = -100,
55
+ ) -> Tuple[Tensor, Optional[Tensor]]:
56
+ """
57
+ Add padding to sequences.
58
+
59
+ Args:
60
+ inputs: (total_nnz, ...) or (total_nnz,), where total_nnz = number of tokens selected in attention_mask.
61
+ indices: (total_nnz)
62
+ batch: int, batch size
63
+ seqlen: int, max sequence length
64
+ labels: (total_nnz) or None
+ ignore_index: int, value used to fill the padded label positions (default -100)
66
+
67
+ Returns:
68
+ padded_inputs: (batch, seqlen, ...) or (batch, seqlen)
69
+ padded_labels: (batch, seqlen) or None
70
+ """
71
+ if inputs.dim() == 1:
72
+ output = torch.zeros(batch * seqlen, dtype=inputs.dtype, device=inputs.device)
73
+ output[indices] = inputs
74
+ padded_inputs = output.view(batch, seqlen)
75
+ else:
76
+ _, *rest = inputs.shape
77
+ output = torch.zeros(batch * seqlen, *rest, dtype=inputs.dtype, device=inputs.device)
78
+ output[indices] = inputs
79
+ padded_inputs = output.view(batch, seqlen, *rest)
80
+
81
+ padded_labels = None
82
+ if labels is not None:
83
+ padded_labels = torch.full((batch * seqlen,), fill_value=ignore_index, dtype=labels.dtype, device=labels.device)
84
+ padded_labels[indices] = labels
85
+ padded_labels = padded_labels.view(batch, seqlen)
86
+
87
+ return padded_inputs, padded_labels
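# Illustrative sketch (not part of the module): round-tripping a padded batch through
# unpad_input / pad_input defined above. Two sequences of lengths 3 and 2, padded to 4.
import torch

input_ids = torch.tensor([[11, 12, 13, 0],
                          [21, 22, 0, 0]])
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])

unpadded, indices, cu_seqlens, max_seqlen, _, _ = unpad_input(input_ids, attention_mask)
print(unpadded.tolist())    # [11, 12, 13, 21, 22]
print(cu_seqlens.tolist())  # [0, 3, 5]
print(max_seqlen)           # 3

repadded, _ = pad_input(unpadded, indices, batch=2, seqlen=4)
assert torch.equal(repadded, input_ids * attention_mask)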
rotary.py ADDED
@@ -0,0 +1,297 @@
1
+ # Copyright 2024 **AUTHORS_TODO**
2
+ # License: Apache-2.0
3
+
4
+ # Copyright (c) 2023, Tri Dao.
5
+ # License: Apache-2.0
6
+
7
+ import torch
8
+ from einops import rearrange
9
+ from flash_attn.ops.triton.rotary import apply_rotary
10
+
11
+ from typing import Optional, Tuple, Union
12
+
13
+
14
+ class ApplyRotaryEmbUnpad(torch.autograd.Function):
15
+ @staticmethod
16
+ def forward(
17
+ ctx,
18
+ qkv,
19
+ cos,
20
+ sin,
21
+ interleaved=False,
22
+ seqlen_offsets: Union[int, torch.Tensor] = 0,
23
+ cu_seqlens: Optional[torch.Tensor] = None,
24
+ max_seqlen: Optional[int] = None,
25
+ ):
26
+ # (total_nnz, 3, nheads, headdim)
27
+ total_nnz, three, nheads, headdim = qkv.shape
28
+ assert three == 3
29
+ if qkv.is_contiguous():
30
+ # Call 1 kernel instead of 2 kernels
31
+ # We need qkv to be contiguous so that when we reshape to combine (3, nheads)
32
+ # dimensions, we get the same tensor
33
+ # qk = rearrange(qkv[:, :2], "b_s t h d -> b_s (t h) d")
34
+ qk = qkv[:, :2].view(total_nnz, -1, headdim)
35
+ apply_rotary(
36
+ qk,
37
+ cos,
38
+ sin,
39
+ seqlen_offsets=seqlen_offsets,
40
+ cu_seqlens=cu_seqlens,
41
+ max_seqlen=max_seqlen,
42
+ interleaved=interleaved,
43
+ inplace=True,
44
+ )
45
+ else:
46
+ q, k = qkv[:, 0, :, :], qkv[:, 1, :, :]
47
+ apply_rotary(
48
+ q,
49
+ cos,
50
+ sin,
51
+ seqlen_offsets=seqlen_offsets,
52
+ cu_seqlens=cu_seqlens,
53
+ max_seqlen=max_seqlen,
54
+ interleaved=interleaved,
55
+ inplace=True,
56
+ )
57
+ apply_rotary(
58
+ k,
59
+ cos,
60
+ sin,
61
+ seqlen_offsets=seqlen_offsets,
62
+ cu_seqlens=cu_seqlens,
63
+ max_seqlen=max_seqlen,
64
+ interleaved=interleaved,
65
+ inplace=True,
66
+ )
67
+
68
+ if isinstance(seqlen_offsets, int):
69
+ ctx.save_for_backward(cos, sin, cu_seqlens)
70
+ ctx.seqlen_offsets = seqlen_offsets
71
+ else:
72
+ ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)
73
+ ctx.seqlen_offsets = None
74
+ ctx.interleaved = interleaved
75
+ ctx.max_seqlen = max_seqlen
76
+ return qkv
77
+
78
+ @staticmethod
79
+ def backward(ctx, do):
80
+ seqlen_offsets = ctx.seqlen_offsets
81
+ if seqlen_offsets is None:
82
+ cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors
83
+ else:
84
+ cos, sin, cu_seqlens = ctx.saved_tensors
85
+ if do.is_contiguous():
86
+ total_nnz, three, nheads, headdim = do.shape
87
+ # Call 1 kernel instead of 2 kernels
88
+ # We need dqkv to be contiguous so that when we reshape to combine (3, nheads)
89
+ # dimensions, we get the same tensor
90
+ dqk = do[:, :2].view(total_nnz, -1, headdim)
91
+ apply_rotary(
92
+ dqk,
93
+ cos,
94
+ sin,
95
+ seqlen_offsets=seqlen_offsets,
96
+ cu_seqlens=cu_seqlens,
97
+ max_seqlen=ctx.max_seqlen,
98
+ interleaved=ctx.interleaved,
99
+ inplace=True,
100
+ conjugate=True,
101
+ )
102
+ else:
103
+ dq, dk = do[:, 0, :, :], do[:, 1, :, :]
104
+ apply_rotary(
105
+ dq,
106
+ cos,
107
+ sin,
108
+ seqlen_offsets=seqlen_offsets,
109
+ cu_seqlens=cu_seqlens,
110
+ max_seqlen=ctx.max_seqlen,
111
+ interleaved=ctx.interleaved,
112
+ inplace=True,
113
+ conjugate=True,
114
+ )
115
+ apply_rotary(
116
+ dk,
117
+ cos,
118
+ sin,
119
+ seqlen_offsets=seqlen_offsets,
120
+ cu_seqlens=cu_seqlens,
121
+ max_seqlen=ctx.max_seqlen,
122
+ interleaved=ctx.interleaved,
123
+ inplace=True,
124
+ conjugate=True,
125
+ )
126
+
127
+ return do, None, None, None, None, None, None
128
+
129
+
130
+ def apply_rotary_emb_unpad(
131
+ qkv,
132
+ cos,
133
+ sin,
134
+ interleaved=False,
135
+ seqlen_offsets: Union[int, torch.Tensor] = 0,
136
+ cu_seqlens: Optional[torch.Tensor] = None,
137
+ max_seqlen: Optional[int] = None,
138
+ ):
139
+ """
140
+ Arguments:
141
+ qkv: (total_nnz, 3, nheads, headdim) - input tensor for packed QKV.
142
+ cos, sin: (seqlen_rotary, rotary_dim / 2)
143
+ interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
144
+ of 1st half and 2nd half (GPT-NeoX style). The rotary embedding is always applied in-place to `qkv`.
146
+ seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
147
+ Most commonly used in inference when we have KV cache.
148
+ cu_seqlens: (batch + 1,) or None
149
+ max_seqlen: int
150
+ Return:
151
+ qkv: (total_nnz, 3, nheads, headdim), with the rotary embedding applied in-place to the
+ first `rotary_dim` of the head dimension of q and k (`rotary_dim` must be <= headdim).
154
+ """
155
+ return ApplyRotaryEmbUnpad.apply(qkv, cos, sin, interleaved, seqlen_offsets, cu_seqlens, max_seqlen)
156
+
157
+
158
+ class UnpaddedRotaryEmbedding(torch.nn.Module):
159
+ """
160
+ The rotary position embeddings applied directly to unpadded sequences.
161
+ """
162
+
163
+ def __init__(
164
+ self,
165
+ dim: int,
166
+ base: float = 10000.0,
167
+ interleaved: bool = False,
168
+ max_seqlen: Optional[int] = None,
169
+ scale_base: Optional[float] = None,
170
+ pos_idx_in_fp32: bool = True,
171
+ device: Optional[torch.device] = None,
172
+ dtype: Optional[torch.dtype] = None,
173
+ ):
174
+ """
175
+ interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
176
+ of 1st half and 2nd half (GPT-NeoX style).
177
+ pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
178
+ otherwise they might be in lower precision.
179
+ This option was added because previously (before 2023-07-02), when we construct
180
+ the position indices, we use the dtype of self.inv_freq. In most cases this would
181
+ be fp32, but if the model is trained in pure bf16 (not mixed precision), then
182
+ self.inv_freq would be bf16, and the position indices are also in bf16.
183
+ Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
184
+ embeddings for some positions will coincide.
185
+ To maintain compatibility with models previously trained in pure bf16,
186
+ we add this option.
187
+ max_seqlen: if max_seqlen, device, and dtype are provided, we precompute the cos_sin_cache
188
+ up to max_seqlen. If the max_seqlen, device, or dtype during training/inference differ,
189
+ the cos_sin_cache will be recomputed during the forward pass.
190
+ """
191
+ super().__init__()
192
+ self.dim = dim
193
+ self.base = float(base)
194
+ self.pos_idx_in_fp32 = pos_idx_in_fp32
195
+ # Generate and save the inverse frequency buffer (non trainable)
196
+ inv_freq = self._compute_inv_freq(device)
197
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
198
+ self.interleaved = interleaved
199
+ self.scale_base = scale_base
200
+ scale = (
201
+ (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
202
+ if scale_base is not None
203
+ else None
204
+ )
205
+ self.register_buffer("scale", scale, persistent=False)
206
+
207
+ self._seq_len_cached = 0
208
+ self._cos_cached = None
209
+ self._sin_cached = None
210
+ self._cos_k_cached = None
211
+ self._sin_k_cached = None
212
+
213
+ if max_seqlen is not None and device is not None and dtype is not None:
214
+ self._update_cos_sin_cache(max_seqlen, device=device, dtype=dtype)
215
+
216
+ def _compute_inv_freq(self, device=None):
217
+ return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
218
+
219
+ def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
220
+ # Reset the tables if the sequence length has changed,
221
+ # if we're on a new device (possibly due to tracing for instance),
222
+ # or if we're switching from inference mode to training
223
+ if (
224
+ seqlen > self._seq_len_cached
225
+ or self._cos_cached is None
226
+ or self._cos_cached.device != device
227
+ or self._cos_cached.dtype != dtype
228
+ or (self.training and self._cos_cached.is_inference())
229
+ ):
230
+ self._seq_len_cached = seqlen
231
+ # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
232
+ # And the output of arange can be quite large, so bf16 would lose a lot of precision.
233
+ # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
234
+ if self.pos_idx_in_fp32:
235
+ t = torch.arange(seqlen, device=device, dtype=torch.float32)
236
+ # We want fp32 here as well since inv_freq will be multiplied with t, and the output
237
+ # will be large. Having it in bf16 will lose a lot of precision and cause the
238
+ # cos & sin output to change significantly.
239
+ # We want to recompute self.inv_freq if it was not loaded in fp32
240
+ if self.inv_freq.dtype != torch.float32:
241
+ inv_freq = self._compute_inv_freq(device=device)
242
+ else:
243
+ inv_freq = self.inv_freq
244
+ else:
245
+ t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
246
+ inv_freq = self.inv_freq
247
+ # Don't do einsum, it converts fp32 to fp16 under AMP
248
+ # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
249
+ freqs = torch.outer(t, inv_freq)
250
+ if self.scale is None:
251
+ self._cos_cached = torch.cos(freqs).to(dtype)
252
+ self._sin_cached = torch.sin(freqs).to(dtype)
253
+ else:
254
+ power = (
255
+ torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
256
+ ) / self.scale_base
257
+ scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
258
+ # We want the multiplication by scale to happen in fp32
259
+ self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
260
+ self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
261
+ self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
262
+ self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
263
+
264
+ def forward(
265
+ self,
266
+ qkv: torch.Tensor,
267
+ cu_seqlens: torch.Tensor,
268
+ max_seqlen: Optional[int] = None,
269
+ seqlen_offset: Union[int, torch.Tensor] = 0,
270
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
271
+ """
272
+ qkv: (total_nnz, 3, nheads, headdim)
273
+ cu_seqlens: (batch + 1,) cumulative sequence lengths
274
+ max_seqlen: int max seq length in the batch
275
+ seqlen_offset: (batch_size,) or int. Each sequence in x is shifted by this amount.
276
+ Most commonly used in inference when we have KV cache.
277
+ If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one
278
+ should pass in max_seqlen, which will update the cos / sin cache up to that length.
279
+ Apply rotary embedding *inplace* to qkv.
280
+ """
281
+ if max_seqlen is not None:
282
+ self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
283
+
284
+ qkv = apply_rotary_emb_unpad(
285
+ qkv,
286
+ self._cos_cached,
287
+ self._sin_cached,
288
+ interleaved=self.interleaved,
289
+ seqlen_offsets=seqlen_offset,
290
+ cu_seqlens=cu_seqlens,
291
+ max_seqlen=max_seqlen,
292
+ )
293
+
294
+ return qkv
295
+
296
+ def extra_repr(self) -> str:
297
+ return f"dim={self.dim}, base={self.base}, scale_base={self.scale_base}"
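# Illustrative sketch (not part of the module): the cos/sin cache built by
# _update_cos_sin_cache above, computed by hand for a small head dimension.
import torch

dim, base, seqlen = 8, 10000.0, 6
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))  # (dim/2,)
t = torch.arange(seqlen, dtype=torch.float32)                                    # positions
freqs = torch.outer(t, inv_freq)                                                 # (seqlen, dim/2)
cos_cached, sin_cached = torch.cos(freqs), torch.sin(freqs)

print(inv_freq.tolist())  # [1.0, 0.1, 0.01, 0.001]: rotation speed per frequency band
print(cos_cached.shape)   # torch.Size([6, 4])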
utils.py ADDED
@@ -0,0 +1,38 @@
1
+ # Copyright 2020 Optuna, Hugging Face
2
+ # License: Apache-2.0
3
+
4
+ # Copyright 2023 OLMo Authors
5
+ # License: Apache-2.0
6
+
7
+ import functools
8
+ import logging
9
+ from enum import Enum
10
+
11
+
12
+ @functools.lru_cache(None)
13
+ def warning_once(self, *args, **kwargs):
14
+ """
15
+ This method is identical to `logger.warning()`, but will emit the warning with the same message only once
16
+
17
+ Note: The cache is for the function arguments, so 2 different callers using the same arguments will hit the cache.
18
+ The assumption here is that all warning messages are unique across the code. If they aren't then need to switch to
19
+ another type of cache that includes the caller frame information in the hashing function.
20
+ """
21
+ self.warning(*args, **kwargs)
22
+
23
+
24
+ logging.Logger.warning_once = warning_once
25
+ logging.Logger.warn_once = warning_once
26
+
27
+
28
+ class StrEnum(str, Enum):
29
+ """
30
+ This is equivalent to Python's :class:`enum.StrEnum` since version 3.11.
31
+ We include this here for compatibility with older version of Python.
32
+ """
33
+
34
+ def __str__(self) -> str:
35
+ return self.value
36
+
37
+ def __repr__(self) -> str:
38
+ return f"'{str(self)}'"
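# Illustrative sketch (not part of the module): how the two utilities above are used,
# assuming this utils module has been imported so the Logger patch is applied. The
# NormType enum below is a hypothetical example, not something defined in the repo.
import logging

logger = logging.getLogger(__name__)
logger.warning_once("flash_attn not found, falling back to PyTorch attention")  # logged once
logger.warning_once("flash_attn not found, falling back to PyTorch attention")  # cache hit, not logged again

class NormType(StrEnum):
    layernorm = "layernorm"
    rmsnorm = "rmsnorm"

print(NormType.rmsnorm)        # rmsnorm  (plain string via __str__)
print(repr(NormType.rmsnorm))  # 'rmsnorm'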