liminghong committed
Commit e228d76
1 Parent(s): ff974d2

convert bias type

Files changed (1)
bert_layers.py +3 -3
bert_layers.py CHANGED
@@ -171,18 +171,18 @@ class BertUnpadSelfAttention(nn.Module):
                                                                  3)  # b s h d
         else:
             # Triton implementation only supports 0 attention dropout
+            bias_dtype = bias.dtype
+            bias = bias.to(torch.float16)
             convert_dtype = qkv.dtype not in [torch.float16, torch.bfloat16]
             if convert_dtype:
                 # Triton implementation only supports fp16 and bf16
                 orig_dtype = qkv.dtype
                 qkv = qkv.to(torch.float16)
-                bias_dtype = bias.dtype
-                bias = bias.to(torch.float16)
                 attention = flash_attn_qkvpacked_func(qkv, bias)
                 attention = attention.to(orig_dtype)
-                bias = bias.to(bias_dtype)
             else:
                 attention = flash_attn_qkvpacked_func(qkv, bias)
+            bias = bias.to(bias_dtype)
 
         # attn_mask is 1 for attend and 0 for don't
         attention = unpad_input_only(attention, torch.squeeze(attn_mask) == 1)
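
After the change, the bias is cast to fp16 before the fused Triton attention call and rebound to its original dtype afterwards, rather than only when qkv itself needs converting. Below is a minimal, self-contained sketch of that dtype round-trip pattern, not the repository's code: `run_fused_attention` and `attend` are hypothetical stand-ins for `flash_attn_qkvpacked_func` and the surrounding forward logic.

```python
import torch


def run_fused_attention(qkv: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for a fused kernel that only accepts fp16/bf16 inputs.
    assert qkv.dtype in (torch.float16, torch.bfloat16)
    assert bias.dtype in (torch.float16, torch.bfloat16)
    return qkv[:, :, 0]  # placeholder: return the "query" slice unchanged


def attend(qkv: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    # Always cast the bias for the kernel, remembering its dtype so it can be restored.
    bias_dtype = bias.dtype
    bias = bias.to(torch.float16)
    convert_dtype = qkv.dtype not in [torch.float16, torch.bfloat16]
    if convert_dtype:
        # The kernel only supports fp16 and bf16 activations.
        orig_dtype = qkv.dtype
        qkv = qkv.to(torch.float16)
        attention = run_fused_attention(qkv, bias)
        attention = attention.to(orig_dtype)
    else:
        attention = run_fused_attention(qkv, bias)
    bias = bias.to(bias_dtype)  # rebind the bias to its incoming dtype, mirroring the patch
    return attention


qkv = torch.randn(2, 8, 3, 12, 64)   # b s three h d, fp32 by default
bias = torch.zeros(2, 12, 8, 8)      # b h s s attention bias, fp32
print(attend(qkv, bias).dtype)       # torch.float32: the caller-visible dtype is preserved
```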