replace 1e4 mask
- README.md (+1, -0)
- modeling_lsg_camembert.py (+11, -7)
README.md
CHANGED
@@ -46,6 +46,7 @@ You can change various parameters like :
 * local block size (block_size=128)
 * sparse block size (sparse_block_size=128)
 * sparsity factor (sparsity_factor=2)
+* mask_first_token (mask first token since it is redundant with the first global token)
 * see config.json file
 
 Default parameters work well in practice. If you are short on memory, reduce block sizes, increase sparsity factor and remove dropout in the attention score matrix.
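For context on these parameters, a minimal loading sketch: the checkpoint id below is a placeholder, and passing the values as `from_pretrained` kwargs (with `trust_remote_code=True` for the custom LSG code) is assumed from how LSG model cards are usually documented, not stated in this diff.

```python
from transformers import AutoModel, AutoTokenizer

# "ccdv/lsg-camembert-base-4096" is an illustrative checkpoint id.
model = AutoModel.from_pretrained(
    "ccdv/lsg-camembert-base-4096",
    trust_remote_code=True,   # loads the custom modeling_lsg_camembert.py
    block_size=128,           # local block size
    sparse_block_size=128,    # sparse block size
    sparsity_factor=2,        # sparsity factor
    mask_first_token=True,    # mask the first token (redundant with the first global token)
)
tokenizer = AutoTokenizer.from_pretrained("ccdv/lsg-camembert-base-4096")
```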
modeling_lsg_camembert.py
CHANGED
@@ -182,7 +182,11 @@ class CausalAttentionProduct(nn.Module):
 
         # Add causal mask
         causal_shape = (self.block_size, self.block_size) if causal_shape is None else causal_shape
-        causal_mask = torch.tril(
+        causal_mask = torch.tril(
+            torch.ones(*causal_shape, device=attention_mask.device, dtype=attention_scores.dtype),
+            diagonal=-1
+            )
+        causal_mask = causal_mask.T * torch.finfo(attention_scores.dtype).min
         attention_scores[..., -causal_shape[0]:, -causal_shape[1]:] = causal_mask
 
         del attention_mask
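A standalone sketch of what the rewritten block computes (only `torch` is assumed; shapes are illustrative): `torch.tril(..., diagonal=-1)` keeps ones strictly below the diagonal, the transpose moves them strictly above it (the future positions), and scaling by `torch.finfo(dtype).min` uses the most negative value the score dtype can hold instead of a hard-coded -1e4.

```python
import torch

block_size = 4
dtype = torch.float16  # finfo.min adapts to the score dtype (-65504 here) instead of a fixed -1e4

# Ones strictly below the diagonal, transposed onto the future positions a causal block must hide.
causal_mask = torch.tril(torch.ones(block_size, block_size, dtype=dtype), diagonal=-1)
causal_mask = causal_mask.T * torch.finfo(dtype).min

attention_scores = torch.zeros(1, 1, block_size, block_size, dtype=dtype)
attention_scores[..., -block_size:, -block_size:] = causal_mask
probs = torch.softmax(attention_scores.float(), dim=-1)
print(probs[0, 0])  # row i spreads its probability only over positions <= i
```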
@@ -300,7 +304,7 @@ class LSGAttentionProduct(nn.Module):
 
         # Pad before block reshaping
         if is_attn_mask:
-            pad_value =
+            pad_value = torch.finfo(hidden_states.dtype).min
             hidden_states = hidden_states.transpose(-1, -2)
         else:
             pad_value = 0
@@ -333,7 +337,7 @@ class LSGAttentionProduct(nn.Module):
 
         # Pad before block reshaping
         if is_attn_mask:
-            pad_value =
+            pad_value = torch.finfo(hidden_states.dtype).min
             hidden_states = hidden_states.transpose(-1, -2)
         else:
             pad_value = 0
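Both `pad_value` hunks serve the same purpose: when the tensor being padded before block reshaping is an additive attention mask, the padded positions must stay fully masked, so the fill value now follows the mask's own dtype. A minimal sketch (the block-reshaping machinery of `LSGAttentionProduct` is omitted):

```python
import torch
import torch.nn.functional as F

mask = torch.zeros(1, 1, 1, 10, dtype=torch.float16)  # additive mask: 0 = token is visible
pad_value = torch.finfo(mask.dtype).min               # most negative float16 value, -65504

# Pad the sequence length from 10 up to a block multiple of 16.
padded = F.pad(mask, (0, 6), value=pad_value)
print(padded.shape)            # torch.Size([1, 1, 1, 16])
print(padded[..., -1].item())  # -65504.0: padded slots behave as fully masked
```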
@@ -557,7 +561,7 @@ class LSGSelfAttention(BaseSelfAttention):
         keys = keys.sum(dim=-2) / (mask + 1e-6)
         values = values.sum(dim=-2) / (mask + 1e-6)
 
-        mask =
+        mask = (1. - mask.clamp(0, 1)) * torch.finfo(mask.dtype).min
         return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)
 
     def get_sparse_tokens_with_stride(self, keys, values, mask):
@@ -622,7 +626,7 @@ class LSGSelfAttention(BaseSelfAttention):
         keys /= mask + 1e-8
         values /= mask + 1e-8
 
-        mask =
+        mask = (1. - mask.clamp(0, 1)) * torch.finfo(mask.dtype).min
 
         return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.transpose(-1, -2).reshape(n, h, 1, -1)
 
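In both sparse-token hunks, `mask` at this point appears to hold, per pooled token, the sum of the original mask (how much real content went into the pooled key/value), and the new line turns that count into an additive mask: populated tokens get 0, empty ones get `finfo.min`. A toy version:

```python
import torch

# Per pooled sparse token: how much real content was summed into it (0 = only padding).
counts = torch.tensor([[3., 2., 0., 1.]], dtype=torch.float16)

additive_mask = (1. - counts.clamp(0, 1)) * torch.finfo(counts.dtype).min
print(additive_mask)  # [0, 0, -65504, 0]: the empty token is excluded from attention
```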
@@ -988,7 +992,7 @@ class LSGCamembertModel(LSGCamembertPreTrainedModel, RobertaModel):
         n, t = inputs_.size()[:2]
 
         if attention_mask is None:
-            attention_mask = torch.ones(n, t, device=inputs_.device)
+            attention_mask = torch.ones(n, t, device=inputs_.device, dtype=inputs_.dtype)
         if self.mask_first_token:
             attention_mask[:,0] = 0
 
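A toy illustration of the default-mask branch above; `inputs_` is just a stand-in tensor here. The point is that the freshly created mask now matches the input dtype, and that `mask_first_token` zeroes position 0, which the README change describes as redundant with the first global token.

```python
import torch

n, t = 2, 8
inputs_ = torch.zeros(n, t, dtype=torch.float16)  # stand-in for the model inputs

attention_mask = torch.ones(n, t, device=inputs_.device, dtype=inputs_.dtype)
mask_first_token = True
if mask_first_token:
    attention_mask[:, 0] = 0  # hide the first token from local/sparse attention

print(attention_mask[0])  # tensor([0., 1., 1., 1., 1., 1., 1., 1.], dtype=torch.float16)
```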
@@ -1069,7 +1073,7 @@ class LSGCamembertModel(LSGCamembertPreTrainedModel, RobertaModel):
         )
 
         extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
-        extended_attention_mask = (1.0 - extended_attention_mask) *
+        extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(extended_attention_mask.dtype).min
 
         return extended_attention_mask
 
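The final hunk is the change the commit title refers to: the 0/1 padding mask is turned into an additive mask using the dtype's own minimum rather than the previous 1e4-based constant. A minimal sketch of the pattern, mirroring the usual Hugging Face convention rather than the exact surrounding model code:

```python
import torch

attention_mask = torch.tensor([[1., 1., 1., 0., 0.]])  # 1 = real token, 0 = padding
extended = attention_mask[:, None, None, :]             # broadcast over heads and query positions
extended = extended.to(dtype=torch.float16)             # fp16 compatibility
extended = (1.0 - extended) * torch.finfo(extended.dtype).min

print(extended)  # 0 for real tokens, -65504 for padded ones
```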