Commit 6f3de15 by Jackmin801 (parent: 4f24e0f)

set flash attn as option in config

Files changed:
- configuration_bert.py (+4 -0)
- flash_attn_triton.py (+9 -32)
- modeling_bert.py (+17 -6)
configuration_bert.py CHANGED
@@ -127,6 +127,8 @@ class JinaBertConfig(PretrainedConfig):
         emb_pooler (`str`, *optional*, defaults to `None`):
             The function to use for pooling the last layer embeddings to get the sentence embeddings.
             Should be one of `None`, `"mean"`.
+        with_flash (`bool`, *optional*, defaults to `False`):
+            Whether to use flash attention. Only works for `triton==2.0.0.dev20230208`
 
     Examples:
 
@@ -164,6 +166,7 @@ class JinaBertConfig(PretrainedConfig):
         classifier_dropout=None,
         feed_forward_type="original",
         emb_pooler=None,
+        with_flash=False,
         **kwargs,
     ):
         super().__init__(pad_token_id=pad_token_id, **kwargs)
@@ -185,6 +188,7 @@ class JinaBertConfig(PretrainedConfig):
         self.classifier_dropout = classifier_dropout
         self.feed_forward_type = feed_forward_type
         self.emb_pooler = emb_pooler
+        self.with_flash = with_flash
 
 
 class JinaBertOnnxConfig(OnnxConfig):
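For orientation, a minimal usage sketch of the new option (not part of the commit): the checkpoint name below is a placeholder, and the only thing this commit actually adds is the `with_flash` field itself, which the docstring says requires `triton==2.0.0.dev20230208`.

from transformers import AutoConfig, AutoModel

# Placeholder checkpoint name; any repo that ships this JinaBertConfig would work.
config = AutoConfig.from_pretrained("jinaai/jina-bert-implementation", trust_remote_code=True)
config.with_flash = True  # new flag from this commit; defaults to False
model = AutoModel.from_pretrained(
    "jinaai/jina-bert-implementation", config=config, trust_remote_code=True
)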
flash_attn_triton.py CHANGED
@@ -81,21 +81,11 @@ def _fwd_kernel(
     Lse,
     TMP,  # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
     softmax_scale,
-    stride_qb,
-    stride_qh,
-    stride_qm,
-    stride_kb,
-    stride_kh,
-    stride_kn,
-    stride_vb,
-    stride_vh,
-    stride_vn,
-    stride_bb,
-    stride_bh,
-    stride_bm,
-    stride_ob,
-    stride_oh,
-    stride_om,
+    stride_qb, stride_qh, stride_qm,
+    stride_kb, stride_kh, stride_kn,
+    stride_vb, stride_vh, stride_vn,
+    stride_bb, stride_bh, stride_bm,
+    stride_ob, stride_oh, stride_om,
     nheads,
     seqlen_q,
     seqlen_k,
@@ -316,11 +306,6 @@ def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
     elif bias.shape[2:] == (seqlen_q, seqlen_k):
         bias_type = 'matrix'
     else:
-        print(q.shape)
-        print(k.shape)
-        print(seqlen_q)
-        print(seqlen_k)
-        print(bias.shape)
         raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k)'
                            ' or (seqlen_q, seqlen_k)')
     if bias.shape[:2] == (1, nheads):
@@ -359,19 +344,11 @@ def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
         lse,
         tmp,
         softmax_scale,
-        q.stride(0),
-        q.stride(2),
-        q.stride(1),
-        k.stride(0),
-        k.stride(2),
-        k.stride(1),
-        v.stride(0),
-        v.stride(2),
-        v.stride(1),
+        q.stride(0), q.stride(2), q.stride(1),
+        k.stride(0), k.stride(2), k.stride(1),
+        v.stride(0), v.stride(2), v.stride(1),
         *bias_strides,
-        o.stride(0),
-        o.stride(2),
-        o.stride(1),
+        o.stride(0), o.stride(2), o.stride(1),
         nheads,
         seqlen_q,
         seqlen_k,
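The call-site change above only reflows the arguments; the values are unchanged. The tensors are laid out as (batch, seqlen, nheads, headdim), while the kernel wants the batch, head, and sequence strides in that order, which is why they are passed as `stride(0), stride(2), stride(1)`. A standalone illustration of that mapping (assumed shapes, not from the repo):

import torch

# Same layout _flash_attn_forward assumes: (batch, seqlen, nheads, headdim).
q = torch.empty(2, 128, 12, 64)

# Batch, head, and sequence strides, in the order the Triton kernel receives them.
stride_qb, stride_qh, stride_qm = q.stride(0), q.stride(2), q.stride(1)
print(stride_qb, stride_qh, stride_qm)  # 98304 64 768 for a contiguous tensor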
modeling_bert.py CHANGED
@@ -55,7 +55,10 @@ from transformers.utils import (
     replace_return_docstrings,
 )
 from .configuration_bert import JinaBertConfig
-
+try:
+    from .flash_attn_triton import flash_attn_func
+except Exception:
+    flash_attn_func = None
 
 try:
     from tqdm.autonotebook import trange
@@ -282,7 +285,7 @@ class JinaBertEmbeddings(nn.Module):
 
 
 class JinaBertSelfAttention(nn.Module):
-    def __init__(self, config, position_embedding_type=None):
+    def __init__(self, config: JinaBertConfig, position_embedding_type=None):
         super().__init__()
         if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
             config, "embedding_size"
@@ -291,6 +294,13 @@ class JinaBertSelfAttention(nn.Module):
                 f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                 f"heads ({config.num_attention_heads})"
             )
+
+        self.with_flash = config.with_flash
+        if self.with_flash:
+            if flash_attn_func is None:
+                raise ValueError(
+                    f"flash_attn_func is None, please install flash_attn_triton"
+                )
 
         self.num_attention_heads = config.num_attention_heads
         self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
@@ -334,14 +344,15 @@ class JinaBertSelfAttention(nn.Module):
         output_attentions: Optional[bool] = False,
         bias: Optional[torch.FloatTensor] = None,
     ) -> Tuple[torch.Tensor]:
-        if
+        if self.with_flash:
             b, s, h = hidden_states.shape
             q = self.query(hidden_states)
             k = self.key(hidden_states)
             v = self.value(hidden_states)
-
-
-
+            # B x S x hidden_dim -> B x S x num_heads x head_dim
+            q = q.view(b, s, self.num_attention_heads, self.attention_head_size)
+            k = k.view(b, s, self.num_attention_heads, self.attention_head_size)
+            v = v.view(b, s, self.num_attention_heads, self.attention_head_size)
             attn = flash_attn_func(q, k, v, bias)
             return (attn.view(b, s, h),)
         mixed_query_layer = self.query(hidden_states)
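The new flash branch in `JinaBertSelfAttention.forward` projects the hidden states, views each projection as (b, s, num_heads, head_size), calls `flash_attn_func`, and views the output back to (b, s, h). Below is a standalone shape check of that round trip; torch's `scaled_dot_product_attention` is used only as a stand-in for the Triton kernel (it expects the head dimension before the sequence dimension, hence the transposes), and the attention bias is omitted.

import torch
import torch.nn.functional as F

b, s, num_heads, head_size = 2, 16, 12, 64
h = num_heads * head_size

hidden_states = torch.randn(b, s, h)
query, key, value = (torch.nn.Linear(h, h) for _ in range(3))

# B x S x hidden_dim -> B x S x num_heads x head_dim, as in the new branch.
q = query(hidden_states).view(b, s, num_heads, head_size)
k = key(hidden_states).view(b, s, num_heads, head_size)
v = value(hidden_states).view(b, s, num_heads, head_size)

# Stand-in for flash_attn_func(q, k, v, bias); SDPA wants (b, heads, s, head_dim).
attn = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
).transpose(1, 2)

# reshape instead of view because the transposed output is non-contiguous here;
# the real kernel returns a contiguous (b, s, heads, head_dim) tensor.
out = attn.reshape(b, s, h)
print(out.shape)  # torch.Size([2, 16, 768])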