liminghong committed
Commit e228d76 • 1 Parent(s): ff974d2
convert bias type
bert_layers.py CHANGED (+3 -3)
@@ -171,18 +171,18 @@ class BertUnpadSelfAttention(nn.Module):
                                               3)  # b s h d
         else:
             # Triton implementation only supports 0 attention dropout
+            bias_dtype = bias.dtype
+            bias = bias.to(torch.float16)
             convert_dtype = qkv.dtype not in [torch.float16, torch.bfloat16]
             if convert_dtype:
                 # Triton implementation only supports fp16 and bf16
                 orig_dtype = qkv.dtype
                 qkv = qkv.to(torch.float16)
-                bias_dtype = bias.dtype
-                bias = bias.to(torch.float16)
                 attention = flash_attn_qkvpacked_func(qkv, bias)
                 attention = attention.to(orig_dtype)
-                bias = bias.to(bias_dtype)
             else:
                 attention = flash_attn_qkvpacked_func(qkv, bias)
+            bias = bias.to(bias_dtype)

         # attn_mask is 1 for attend and 0 for don't
         attention = unpad_input_only(attention, torch.squeeze(attn_mask) == 1)
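For context, the effect of the change is that the bias is now cast to fp16 before every call into the Triton kernel, not only when qkv itself needs conversion, and cast back to its original dtype afterwards. Below is a minimal, self-contained sketch of that pattern; the wrapper name flash_attention_with_dtype_guard and the attn_fn stand-in for the repo's flash_attn_qkvpacked_func are illustrative assumptions, not code from the repository.

import torch


def flash_attention_with_dtype_guard(qkv, bias, attn_fn):
    # Sketch only: attn_fn stands in for flash_attn_qkvpacked_func, i.e. any
    # attention kernel that accepts only fp16/bf16 inputs.

    # Cast the bias to fp16 unconditionally and remember its dtype, as in the commit.
    bias_dtype = bias.dtype
    bias = bias.to(torch.float16)

    convert_dtype = qkv.dtype not in [torch.float16, torch.bfloat16]
    if convert_dtype:
        # The kernel only supports fp16/bf16: cast qkv down for the call,
        # then cast the result back to the caller's dtype.
        orig_dtype = qkv.dtype
        qkv = qkv.to(torch.float16)
        attention = attn_fn(qkv, bias)
        attention = attention.to(orig_dtype)
    else:
        attention = attn_fn(qkv, bias)

    # Cast the local bias back to its original dtype, mirroring the moved line.
    bias = bias.to(bias_dtype)
    return attention

Hoisting the bias cast out of the convert_dtype branch presumably matters when the bias is kept in fp32 while qkv already arrives in fp16 or bf16; in the pre-change code that combination would reach the kernel with an unconverted bias.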