fix: use attention dropout with torch SDPA implementation
modeling_bert.py (+2 -1)
@@ -356,7 +356,8 @@ class JinaBertSelfAttention(nn.Module):
        if self.attn_implementation == 'torch' and scaled_dot_product_attention is not None:
            b, _, s, _ = query_layer.shape
            new_bias = attention_mask + bias
-           attn = scaled_dot_product_attention(query_layer, key_layer, value_layer, new_bias)
+           dropout_p = self.dropout.p if self.training else 0.0
+           attn = scaled_dot_product_attention(query_layer, key_layer, value_layer, new_bias, dropout_p=dropout_p)
            attn = attn.permute(0, 2, 1, 3).contiguous()
            return (attn.view(b, s, self.all_head_size),)
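For context, a minimal standalone sketch of the pattern this commit adopts: the attention-dropout probability is passed to torch.nn.functional.scaled_dot_product_attention only while the module is in training mode, so evaluation stays deterministic. The module below (ToySelfAttention, its sizes, and the bias argument) is hypothetical and only illustrates the call pattern, not the repository's actual class.

# Minimal illustrative sketch (not the repository's code): forward the attention
# dropout probability to torch's scaled_dot_product_attention only during training,
# mirroring the dropout_p handling added in this commit.
import torch
from torch.nn.functional import scaled_dot_product_attention

class ToySelfAttention(torch.nn.Module):  # hypothetical module, for illustration only
    def __init__(self, hidden_size: int = 64, num_heads: int = 4, attn_dropout: float = 0.1):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.qkv = torch.nn.Linear(hidden_size, 3 * hidden_size)
        self.dropout = torch.nn.Dropout(attn_dropout)  # plays the role of self.dropout in the diff

    def forward(self, hidden_states: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
        b, s, _ = hidden_states.shape
        q, k, v = self.qkv(hidden_states).chunk(3, dim=-1)
        # reshape to (batch, heads, seq_len, head_dim) as SDPA expects
        q, k, v = (t.view(b, s, self.num_heads, self.head_dim).transpose(1, 2) for t in (q, k, v))
        # key point of the fix: apply attention dropout inside SDPA only in training mode
        dropout_p = self.dropout.p if self.training else 0.0
        attn = scaled_dot_product_attention(q, k, v, attn_mask=bias, dropout_p=dropout_p)
        # merge heads back to (batch, seq_len, hidden_size)
        return attn.transpose(1, 2).contiguous().view(b, s, self.num_heads * self.head_dim)

# usage: bias must be broadcastable to (batch, heads, seq_len, seq_len)
layer = ToySelfAttention()
x = torch.randn(2, 8, 64)
bias = torch.zeros(2, 1, 8, 8)
out_train = layer.train()(x, bias)  # dropout_p = 0.1 inside SDPA
out_eval = layer.eval()(x, bias)    # dropout_p = 0.0, deterministic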