Update modeling_sliding_llama.py
modeling_sliding_llama.py  CHANGED  (+21 -12)
```diff
@@ -54,15 +54,22 @@ logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "LlamaConfig"
 
-def attn_causal(b, h, q_idx, kv_idx):
+def attn_causal(b, h, q_idx, kv_idx, q_len):
     causal_mask = q_idx >= kv_idx
-    return causal_mask
+    if q_len > 1:
+        return causal_mask
+    else:
+        return q_idx >= -1
 
-def stream_attn_causal(b, h, q_idx, kv_idx, attn_sink_size, sliding_window):
-    causal_mask = q_idx >= kv_idx
-    window_mask = q_idx - kv_idx <= sliding_window
+def stream_attn_causal(b, h, q_idx, kv_idx, attn_sink_size, sliding_window, q_len, kv_len):
+    causal_mask = q_idx >= kv_idx
     sink_mask = kv_idx < attn_sink_size
-    return causal_mask & (window_mask | sink_mask)
+    if q_len > 1:
+        window_mask = q_idx - kv_idx <= sliding_window
+        return causal_mask & (window_mask | sink_mask)
+    else:
+        window_mask = kv_len - kv_idx <= sliding_window
+        return window_mask | sink_mask
 
 
 class LlamaRMSNorm(nn.Module):
```
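The new `stream_attn_causal` keeps the usual sink-plus-sliding-window rule during prefill, but when a single token is being decoded (`q_len == 1`) it measures the window from the end of the KV cache instead of from `q_idx`, which is always 0 in that case. The sketch below is not part of the patch; it just materializes that logic with plain tensor broadcasting on toy sizes (12 cached tokens, 2 sinks, a window of 4 are arbitrary stand-ins) so the two branches can be compared.

```python
# Illustration only: evaluate the patched stream_attn_causal logic densely
# with broadcasting, instead of going through create_block_mask.
import torch

def stream_attn_causal(b, h, q_idx, kv_idx, attn_sink_size, sliding_window, q_len, kv_len):
    causal_mask = q_idx >= kv_idx
    sink_mask = kv_idx < attn_sink_size
    if q_len > 1:
        # Prefill: causal AND (inside the sliding window OR an attention sink).
        window_mask = q_idx - kv_idx <= sliding_window
        return causal_mask & (window_mask | sink_mask)
    else:
        # Decode: q_idx is 0, so measure the window back from the end of the cache.
        window_mask = kv_len - kv_idx <= sliding_window
        return window_mask | sink_mask

kv_len, sink, window = 12, 2, 4  # toy values, not the model's configuration
kv_idx = torch.arange(kv_len).view(1, -1)

# Prefill: a square (q_len, kv_len) mask.
q_idx = torch.arange(kv_len).view(-1, 1)
print(stream_attn_causal(None, None, q_idx, kv_idx, sink, window, kv_len, kv_len).int())

# Decode: one query row that sees the sinks plus the last `window` cache slots.
q_idx = torch.zeros(1, 1, dtype=torch.long)
print(stream_attn_causal(None, None, q_idx, kv_idx, sink, window, 1, kv_len).int())
```

If the old mask were evaluated with `q_idx = 0` for the new token, `causal_mask` would admit only cache position 0, so the decode branch is what lets a single query still attend to the sinks and the trailing window.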
```diff
@@ -343,17 +350,19 @@ class LlamaStreamingFlexAttention(LlamaAttention):
 
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
+
         sliding_window_size = self.config.sliding_windows[self.layer_idx]
         if sliding_window_size > 0:
             block_mask = create_block_mask(
-                lambda b, h, q_idx, kv_idx: stream_attn_causal(b, h, q_idx, kv_idx, 4, sliding_window_size),
-                B=None, H=None, Q_LEN=
+                lambda b, h, q_idx, kv_idx: stream_attn_causal(b, h, q_idx, kv_idx, 4, sliding_window_size, q_len, key_states.shape[-2]),
+                B=None, H=None, Q_LEN=q_len, KV_LEN=key_states.shape[-2]
             )
         else:
             block_mask = create_block_mask(
-                lambda b, h, q_idx, kv_idx: attn_causal(b, h, q_idx, kv_idx),
-                B=None, H=None, Q_LEN=
+                lambda b, h, q_idx, kv_idx: attn_causal(b, h, q_idx, kv_idx, q_len),
+                B=None, H=None, Q_LEN=q_len, KV_LEN=key_states.shape[-2]
             )
+
         attn_output = flex_attention(query_states, key_states, value_states, block_mask=block_mask)
 
         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
```
```diff
@@ -363,9 +372,9 @@ class LlamaStreamingFlexAttention(LlamaAttention):
             )
 
         attn_output = attn_output.transpose(1, 2).contiguous()
-
+
         attn_output = attn_output.reshape(bsz, q_len, -1)
-
+
         if self.config.pretraining_tp > 1:
             attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
             o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
```
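For the block-mask construction itself, here is a minimal sketch (not part of the patch) of the call pattern the patched forward now uses, with explicit `Q_LEN` and `KV_LEN`. It assumes a PyTorch build that ships `torch.nn.attention.flex_attention` (2.5 or newer); the toy lengths and the `device="cpu"` argument are stand-ins for illustration, while the model builds the mask on its own device and then passes it to `flex_attention` as in the hunk above.

```python
# Sketch only: the else-branch (no sliding window for this layer) of the patched
# forward, using the q_len-aware attn_causal helper and explicit Q_LEN / KV_LEN.
import torch
from torch.nn.attention.flex_attention import create_block_mask

def attn_causal(b, h, q_idx, kv_idx, q_len):
    # Same logic as the patched helper: causal during prefill,
    # attend to the whole cache when decoding a single token.
    causal_mask = q_idx >= kv_idx
    if q_len > 1:
        return causal_mask
    else:
        return q_idx >= -1

q_len, kv_len = 8, 8  # prefill-shaped stand-ins for the runtime values
block_mask = create_block_mask(
    lambda b, h, q_idx, kv_idx: attn_causal(b, h, q_idx, kv_idx, q_len),
    B=None, H=None, Q_LEN=q_len, KV_LEN=kv_len, device="cpu",
)
print(block_mask)  # BlockMask recording which KV blocks each query block may attend to
# The forward pass then calls, as in the hunk above:
#   flex_attention(query_states, key_states, value_states, block_mask=block_mask)
```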