Porting v2 models to flash attention (#15)
Browse files- Added GLUMLP, changed config accordingly, added code to convert state_dict (0211324e8c38d72ef847d46db9a9f389c864a5de)
- fixed GLU implementation, added conversion of layer norms (9587227ceebcbf4e7335c0938838e9a2eb0b5d6b)
Co-authored-by: Markus Krimmel <[email protected]>
- configuration_bert.py +3 -3
- convert_v2_weights.py +144 -0
- mlp.py +41 -0
- modeling_bert.py +18 -5
configuration_bert.py
CHANGED
@@ -75,7 +75,7 @@ class JinaBertConfig(PretrainedConfig):
|
|
75 |
pad_token_id=0,
|
76 |
window_size=(-1, -1),
|
77 |
dense_seq_output=False,
|
78 |
-
|
79 |
mlp_checkpoint_lvl=0,
|
80 |
last_layer_subset=False,
|
81 |
fused_dropout_add_ln=False,
|
@@ -92,7 +92,7 @@ class JinaBertConfig(PretrainedConfig):
|
|
92 |
assert 'max_position_embeddings' not in kwargs
|
93 |
super().__init__(pad_token_id=pad_token_id, **kwargs)
|
94 |
|
95 |
-
if fused_mlp and hidden_act not in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]:
|
96 |
raise ValueError('Fused MLP only supports approximate gelu')
|
97 |
|
98 |
self.vocab_size = vocab_size
|
@@ -108,7 +108,7 @@ class JinaBertConfig(PretrainedConfig):
|
|
108 |
self.layer_norm_eps = layer_norm_eps
|
109 |
self.window_size = window_size
|
110 |
self.dense_seq_output = dense_seq_output
|
111 |
-
self.
|
112 |
self.mlp_checkpoint_lvl = mlp_checkpoint_lvl
|
113 |
self.last_layer_subset = last_layer_subset
|
114 |
self.fused_dropout_add_ln = fused_dropout_add_ln
|
|
|
75 |
pad_token_id=0,
|
76 |
window_size=(-1, -1),
|
77 |
dense_seq_output=False,
|
78 |
+
mlp_type='mlp',
|
79 |
mlp_checkpoint_lvl=0,
|
80 |
last_layer_subset=False,
|
81 |
fused_dropout_add_ln=False,
|
|
|
92 |
assert 'max_position_embeddings' not in kwargs
|
93 |
super().__init__(pad_token_id=pad_token_id, **kwargs)
|
94 |
|
95 |
+
if mlp_type == 'fused_mlp' and hidden_act not in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]:
|
96 |
raise ValueError('Fused MLP only supports approximate gelu')
|
97 |
|
98 |
self.vocab_size = vocab_size
|
|
|
108 |
self.layer_norm_eps = layer_norm_eps
|
109 |
self.window_size = window_size
|
110 |
self.dense_seq_output = dense_seq_output
|
111 |
+
self.mlp_type= mlp_type
|
112 |
self.mlp_checkpoint_lvl = mlp_checkpoint_lvl
|
113 |
self.last_layer_subset = last_layer_subset
|
114 |
self.fused_dropout_add_ln = fused_dropout_add_ln
|
convert_v2_weights.py
ADDED
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
from collections import OrderedDict
|
3 |
+
from transformers import AutoModel, AutoTokenizer
|
4 |
+
from .configuration_bert import JinaBertConfig
|
5 |
+
import torch
|
6 |
+
from .modeling_bert import BertModel
|
7 |
+
|
8 |
+
def remap_state_dict(state_dict, config: JinaBertConfig):
|
9 |
+
"""
|
10 |
+
Map the state_dict of a Huggingface BERT model to be flash_attn compatible.
|
11 |
+
"""
|
12 |
+
|
13 |
+
# LayerNorm
|
14 |
+
def key_mapping_ln_gamma_beta(key):
|
15 |
+
key = re.sub(r"LayerNorm.gamma$", "LayerNorm.weight", key)
|
16 |
+
key = re.sub(r"LayerNorm.beta$", "LayerNorm.bias", key)
|
17 |
+
return key
|
18 |
+
|
19 |
+
state_dict = OrderedDict((key_mapping_ln_gamma_beta(k), v) for k, v in state_dict.items())
|
20 |
+
|
21 |
+
# Layers
|
22 |
+
def key_mapping_layers(key):
|
23 |
+
return re.sub(r"^encoder.layer.", "encoder.layers.", key)
|
24 |
+
|
25 |
+
state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
|
26 |
+
|
27 |
+
# LayerNorm
|
28 |
+
def key_mapping_ln(key):
|
29 |
+
key = re.sub(r"^embeddings.LayerNorm.", "emb_ln.", key)
|
30 |
+
key = re.sub(
|
31 |
+
r"^encoder.layers.(\d+).attention.output.LayerNorm.(weight|bias)",
|
32 |
+
r"encoder.layers.\1.norm1.\2",
|
33 |
+
key,
|
34 |
+
)
|
35 |
+
key = re.sub(
|
36 |
+
r"^encoder.layers.(\d+).output.LayerNorm.(weight|bias)",
|
37 |
+
r"encoder.layers.\1.norm2.\2",
|
38 |
+
key,
|
39 |
+
)
|
40 |
+
key = re.sub(
|
41 |
+
r"^cls.predictions.transform.LayerNorm.(weight|bias)",
|
42 |
+
r"cls.predictions.transform.layer_norm.\1",
|
43 |
+
key,
|
44 |
+
)
|
45 |
+
return key
|
46 |
+
|
47 |
+
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
|
48 |
+
|
49 |
+
# MLP
|
50 |
+
def key_mapping_mlp(key):
|
51 |
+
key = re.sub(
|
52 |
+
r"^encoder.layers.(\d+).intermediate.dense.(weight|bias)",
|
53 |
+
r"encoder.layers.\1.mlp.fc1.\2",
|
54 |
+
key,
|
55 |
+
)
|
56 |
+
key = re.sub(
|
57 |
+
r"^encoder.layers.(\d+).output.dense.(weight|bias)",
|
58 |
+
r"encoder.layers.\1.mlp.fc2.\2",
|
59 |
+
key,
|
60 |
+
)
|
61 |
+
return key
|
62 |
+
|
63 |
+
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
|
64 |
+
|
65 |
+
# Attention
|
66 |
+
last_layer_subset = getattr(config, "last_layer_subset", False)
|
67 |
+
for d in range(config.num_hidden_layers):
|
68 |
+
Wq = state_dict.pop(f"encoder.layers.{d}.attention.self.query.weight")
|
69 |
+
Wk = state_dict.pop(f"encoder.layers.{d}.attention.self.key.weight")
|
70 |
+
Wv = state_dict.pop(f"encoder.layers.{d}.attention.self.value.weight")
|
71 |
+
bq = state_dict.pop(f"encoder.layers.{d}.attention.self.query.bias")
|
72 |
+
bk = state_dict.pop(f"encoder.layers.{d}.attention.self.key.bias")
|
73 |
+
bv = state_dict.pop(f"encoder.layers.{d}.attention.self.value.bias")
|
74 |
+
if not (last_layer_subset and d == config.num_hidden_layers - 1):
|
75 |
+
state_dict[f"encoder.layers.{d}.mixer.Wqkv.weight"] = torch.cat(
|
76 |
+
[Wq, Wk, Wv], dim=0
|
77 |
+
)
|
78 |
+
state_dict[f"encoder.layers.{d}.mixer.Wqkv.bias"] = torch.cat([bq, bk, bv], dim=0)
|
79 |
+
else:
|
80 |
+
state_dict[f"encoder.layers.{d}.mixer.Wq.weight"] = Wq
|
81 |
+
state_dict[f"encoder.layers.{d}.mixer.Wkv.weight"] = torch.cat([Wk, Wv], dim=0)
|
82 |
+
state_dict[f"encoder.layers.{d}.mixer.Wq.bias"] = bq
|
83 |
+
state_dict[f"encoder.layers.{d}.mixer.Wkv.bias"] = torch.cat([bk, bv], dim=0)
|
84 |
+
|
85 |
+
def key_mapping_attn(key):
|
86 |
+
return re.sub(
|
87 |
+
r"^encoder.layers.(\d+).attention.output.dense.(weight|bias)",
|
88 |
+
r"encoder.layers.\1.mixer.out_proj.\2",
|
89 |
+
key,
|
90 |
+
)
|
91 |
+
|
92 |
+
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
|
93 |
+
|
94 |
+
def key_mapping_decoder_bias(key):
|
95 |
+
return re.sub(r"^cls.predictions.bias", "cls.predictions.decoder.bias", key)
|
96 |
+
|
97 |
+
state_dict = OrderedDict((key_mapping_decoder_bias(k), v) for k, v in state_dict.items())
|
98 |
+
|
99 |
+
# Word embedding
|
100 |
+
pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
|
101 |
+
if pad_vocab_size_multiple > 1:
|
102 |
+
word_embeddings = state_dict["embeddings.word_embeddings.weight"]
|
103 |
+
state_dict["embeddings.word_embeddings.weight"] = F.pad(
|
104 |
+
word_embeddings, (0, 0, 0, config.vocab_size - word_embeddings.shape[0])
|
105 |
+
)
|
106 |
+
decoder_weight = state_dict["cls.predictions.decoder.weight"]
|
107 |
+
state_dict["cls.predictions.decoder.weight"] = F.pad(
|
108 |
+
decoder_weight, (0, 0, 0, config.vocab_size - decoder_weight.shape[0])
|
109 |
+
)
|
110 |
+
# If the vocab was padded, we want to set the decoder bias for those padded indices to be
|
111 |
+
# strongly negative (i.e. the decoder shouldn't predict those indices).
|
112 |
+
# TD [2022-05-09]: I don't think it affects the MLPerf training.
|
113 |
+
decoder_bias = state_dict["cls.predictions.decoder.bias"]
|
114 |
+
state_dict["cls.predictions.decoder.bias"] = F.pad(
|
115 |
+
decoder_bias, (0, config.vocab_size - decoder_bias.shape[0]), value=-100.0
|
116 |
+
)
|
117 |
+
|
118 |
+
# LayerNorm
|
119 |
+
def key_mapping_layernorm(key):
|
120 |
+
return re.sub(r'^encoder.layers.(\d+).mlp.layernorm.(weight|bias)', r"encoder.layers.\1.norm2.\2", key)
|
121 |
+
|
122 |
+
state_dict = OrderedDict((key_mapping_layernorm(k), v) for k, v in state_dict.items())
|
123 |
+
|
124 |
+
return state_dict
|
125 |
+
|
126 |
+
|
127 |
+
v2_model = AutoModel.from_pretrained('jinaai/jina-embeddings-v2-base-en', trust_remote_code=True)
|
128 |
+
config = JinaBertConfig(vocab_size=30528, use_qk_norm=False, mlp_type='glu', hidden_act='gelu')
|
129 |
+
state_dict = v2_model.state_dict()
|
130 |
+
new_state_dict = remap_state_dict(state_dict, config)
|
131 |
+
flash_model = BertModel(config)
|
132 |
+
flash_model.load_state_dict(new_state_dict)
|
133 |
+
|
134 |
+
tokenizer = AutoTokenizer.from_pretrained('jinaai/jina-embeddings-v2-base-en')
|
135 |
+
inp = tokenizer.batch_encode_plus(['Hello world', 'How is the weather today?', 'It is raining a lot in Berlin'], return_tensors='pt', padding=True).to('cuda')
|
136 |
+
v2_model.eval()
|
137 |
+
flash_model.eval()
|
138 |
+
v2_model = v2_model.to('cuda', torch.float16)
|
139 |
+
flash_model = flash_model.to('cuda', torch.float16)
|
140 |
+
output_v2 = v2_model(**inp)
|
141 |
+
output_flash = flash_model(**inp)
|
142 |
+
x = output_v2.last_hidden_state
|
143 |
+
y = output_flash.last_hidden_state
|
144 |
+
print(torch.abs(x - y))
|
mlp.py
CHANGED
@@ -27,6 +27,47 @@ except ImportError:
|
|
27 |
FusedMLP, ParallelFusedMLP = None, None
|
28 |
|
29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
class Mlp(nn.Module):
|
31 |
def __init__(
|
32 |
self,
|
|
|
27 |
FusedMLP, ParallelFusedMLP = None, None
|
28 |
|
29 |
|
30 |
+
class GLUMLP(nn.Module):
|
31 |
+
def __init__(
|
32 |
+
self,
|
33 |
+
in_features,
|
34 |
+
hidden_features,
|
35 |
+
activation,
|
36 |
+
return_residual=False,
|
37 |
+
hidden_dropout_prob=0.1
|
38 |
+
):
|
39 |
+
super().__init__()
|
40 |
+
self.hidden_features = hidden_features
|
41 |
+
self.gated_layers = nn.Linear(
|
42 |
+
in_features, hidden_features * 2, bias=False
|
43 |
+
)
|
44 |
+
if activation == 'relu':
|
45 |
+
self.act = nn.ReLU()
|
46 |
+
elif activation == 'gelu':
|
47 |
+
self.act = nn.GELU()
|
48 |
+
else:
|
49 |
+
raise ValueError(
|
50 |
+
f"activation {activation} not supported"
|
51 |
+
)
|
52 |
+
self.wo = nn.Linear(hidden_features, in_features)
|
53 |
+
self.dropout = nn.Dropout(hidden_dropout_prob)
|
54 |
+
self.return_residual = return_residual
|
55 |
+
#self.layernorm = nn.LayerNorm(in_features, eps=layer_norm_eps)
|
56 |
+
|
57 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
58 |
+
residual_connection = hidden_states
|
59 |
+
# compute the activation
|
60 |
+
hidden_states = self.gated_layers(hidden_states)
|
61 |
+
gated = hidden_states[:, : self.hidden_features]
|
62 |
+
non_gated = hidden_states[:, self.hidden_features :]
|
63 |
+
hidden_states = self.act(gated) * non_gated
|
64 |
+
hidden_states = self.dropout(hidden_states)
|
65 |
+
# multiply by the second matrix
|
66 |
+
hidden_states = self.wo(hidden_states)
|
67 |
+
# add the residual connection and post-LN
|
68 |
+
# hidden_states = self.layernorm(hidden_states + residual_connection)
|
69 |
+
return hidden_states if not self.return_residual else (hidden_states, residual_connection)
|
70 |
+
|
71 |
class Mlp(nn.Module):
|
72 |
def __init__(
|
73 |
self,
|
modeling_bert.py
CHANGED
@@ -39,7 +39,7 @@ from .bert_padding import (
|
|
39 |
from .block import Block
|
40 |
from .embedding import BertEmbeddings
|
41 |
from .mha import MHA
|
42 |
-
from .mlp import FusedMLP, Mlp
|
43 |
|
44 |
try:
|
45 |
from flash_attn.ops.fused_dense import FusedDense
|
@@ -89,12 +89,15 @@ def create_mixer_cls(config, cross_attn=False, return_residual=False):
|
|
89 |
|
90 |
def create_mlp_cls(config, layer_idx=None, return_residual=False):
|
91 |
inner_dim = config.intermediate_size
|
92 |
-
|
93 |
-
|
|
|
94 |
assert config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"], (
|
95 |
"fused_mlp only " "supports approximate gelu"
|
96 |
)
|
97 |
-
if
|
|
|
|
|
98 |
approximate = (
|
99 |
"tanh"
|
100 |
if config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
|
@@ -106,7 +109,15 @@ def create_mlp_cls(config, layer_idx=None, return_residual=False):
|
|
106 |
activation=partial(F.gelu, approximate=approximate),
|
107 |
return_residual=return_residual,
|
108 |
)
|
109 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
110 |
if FusedMLP is None:
|
111 |
raise ImportError("fused_dense is not installed")
|
112 |
mlp_checkpoint_lvl = getattr(config, "mlp_checkpoint_lvl", 0)
|
@@ -120,6 +131,8 @@ def create_mlp_cls(config, layer_idx=None, return_residual=False):
|
|
120 |
checkpoint_lvl=mlp_checkpoint_lvl,
|
121 |
return_residual=return_residual,
|
122 |
)
|
|
|
|
|
123 |
return mlp_cls
|
124 |
|
125 |
|
|
|
39 |
from .block import Block
|
40 |
from .embedding import BertEmbeddings
|
41 |
from .mha import MHA
|
42 |
+
from .mlp import FusedMLP, Mlp, GLUMLP
|
43 |
|
44 |
try:
|
45 |
from flash_attn.ops.fused_dense import FusedDense
|
|
|
89 |
|
90 |
def create_mlp_cls(config, layer_idx=None, return_residual=False):
|
91 |
inner_dim = config.intermediate_size
|
92 |
+
mlp_type = config.mlp_type
|
93 |
+
assert mlp_type in ('mlp', 'fused_mlp', 'glu')
|
94 |
+
if mlp_type == 'fused_mlp':
|
95 |
assert config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"], (
|
96 |
"fused_mlp only " "supports approximate gelu"
|
97 |
)
|
98 |
+
if mlp_type == 'glu':
|
99 |
+
assert config.hidden_act in ('relu', 'gelu')
|
100 |
+
if mlp_type == 'mlp':
|
101 |
approximate = (
|
102 |
"tanh"
|
103 |
if config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
|
|
|
109 |
activation=partial(F.gelu, approximate=approximate),
|
110 |
return_residual=return_residual,
|
111 |
)
|
112 |
+
elif mlp_type == 'glu':
|
113 |
+
mlp_cls = partial(
|
114 |
+
GLUMLP,
|
115 |
+
hidden_features=inner_dim,
|
116 |
+
activation=config.hidden_act,
|
117 |
+
hidden_dropout_prob=config.hidden_dropout_prob,
|
118 |
+
return_residual=return_residual,
|
119 |
+
)
|
120 |
+
elif mlp_type == 'fused_mlp':
|
121 |
if FusedMLP is None:
|
122 |
raise ImportError("fused_dense is not installed")
|
123 |
mlp_checkpoint_lvl = getattr(config, "mlp_checkpoint_lvl", 0)
|
|
|
131 |
checkpoint_lvl=mlp_checkpoint_lvl,
|
132 |
return_residual=return_residual,
|
133 |
)
|
134 |
+
else:
|
135 |
+
raise NotImplementedError
|
136 |
return mlp_cls
|
137 |
|
138 |
|