kz919 committed
Commit d760a3b
1 Parent(s): 6917dd3

Update modeling_sliding_llama.py

Files changed (1)
  1. modeling_sliding_llama.py +21 -12
modeling_sliding_llama.py CHANGED
@@ -54,15 +54,22 @@ logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "LlamaConfig"
 
-def attn_causal(b, h, q_idx, kv_idx):
+def attn_causal(b, h, q_idx, kv_idx, q_len):
     causal_mask = q_idx >= kv_idx
-    return causal_mask
+    if q_len > 1:
+        return causal_mask
+    else:
+        return q_idx >= -1
 
-def stream_attn_causal(b, h, q_idx, kv_idx, attn_sink_size, sliding_window):
-    causal_mask = q_idx >= kv_idx
-    window_mask = q_idx - kv_idx <= sliding_window
+def stream_attn_causal(b, h, q_idx, kv_idx, attn_sink_size, sliding_window, q_len, kv_len):
+    causal_mask = q_idx >= kv_idx
     sink_mask = kv_idx < attn_sink_size
-    return causal_mask & (window_mask | sink_mask)
+    if q_len > 1:
+        window_mask = q_idx - kv_idx <= sliding_window
+        return causal_mask & (window_mask | sink_mask)
+    else:
+        window_mask = kv_len - kv_idx <= sliding_window
+        return window_mask | sink_mask
 
 
 class LlamaRMSNorm(nn.Module):
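The new q_len/kv_len arguments let the same predicates cover both prefill and single-token decode, where q_idx is always 0 and the query/key offset no longer reflects recency. Below is a minimal sketch (not part of the commit) that evaluates the updated predicates on dense index grids, assuming PyTorch is available; the sequence lengths, sink size, and window size are made-up illustrative values.

import torch


def attn_causal(b, h, q_idx, kv_idx, q_len):
    causal_mask = q_idx >= kv_idx
    if q_len > 1:
        return causal_mask
    else:
        # Single-token decode: the lone query row may attend to every cached key.
        return q_idx >= -1


def stream_attn_causal(b, h, q_idx, kv_idx, attn_sink_size, sliding_window, q_len, kv_len):
    causal_mask = q_idx >= kv_idx
    sink_mask = kv_idx < attn_sink_size
    if q_len > 1:
        # Prefill: causal attention restricted to a sliding window plus the sink tokens.
        window_mask = q_idx - kv_idx <= sliding_window
        return causal_mask & (window_mask | sink_mask)
    else:
        # Decode: keep only the most recent `sliding_window` keys plus the sink tokens.
        window_mask = kv_len - kv_idx <= sliding_window
        return window_mask | sink_mask


# Prefill: 8 queries over 8 keys, sink size 2, window 3 (illustrative values).
q_len, kv_len = 8, 8
q_idx = torch.arange(q_len).view(-1, 1)   # broadcasts against kv_idx
kv_idx = torch.arange(kv_len).view(1, -1)
print(stream_attn_causal(None, None, q_idx, kv_idx, 2, 3, q_len, kv_len).int())

# Decode: one query over a 10-entry KV cache keeps keys {0, 1} (sink) and {7, 8, 9} (window).
kv_idx = torch.arange(10).view(1, -1)
print(stream_attn_causal(None, None, torch.tensor([[0]]), kv_idx, 2, 3, 1, 10).int())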
@@ -343,17 +350,19 @@ class LlamaStreamingFlexAttention(LlamaAttention):
 
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
+
         sliding_window_size = self.config.sliding_windows[self.layer_idx]
         if sliding_window_size > 0:
             block_mask = create_block_mask(
-                lambda b, h, q_idx, kv_idx: stream_attn_causal(b, h, q_idx, kv_idx, 4, sliding_window_size),
-                B=None, H=None, Q_LEN=query_states.shape[-2], KV_LEN=key_states.shape[-2]
+                lambda b, h, q_idx, kv_idx: stream_attn_causal(b, h, q_idx, kv_idx, 4, sliding_window_size, q_len, key_states.shape[-2]),
+                B=None, H=None, Q_LEN=q_len, KV_LEN=key_states.shape[-2]
             )
         else:
             block_mask = create_block_mask(
-                lambda b, h, q_idx, kv_idx: attn_causal(b, h, q_idx, kv_idx),
-                B=None, H=None, Q_LEN=query_states.shape[-2], KV_LEN=key_states.shape[-2]
+                lambda b, h, q_idx, kv_idx: attn_causal(b, h, q_idx, kv_idx, q_len),
+                B=None, H=None, Q_LEN=q_len, KV_LEN=key_states.shape[-2]
             )
+
         attn_output = flex_attention(query_states, key_states, value_states, block_mask=block_mask)
 
         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
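With Q_LEN=q_len and the lengths threaded into the predicate, the same call site serves prefill and decode: during decode q_len is 1 while KV_LEN follows the cache length. Below is a standalone sketch (not part of the commit) of this construction for the dense-causal branch, assuming a recent PyTorch (>= 2.5) where create_block_mask and flex_attention are available from torch.nn.attention.flex_attention; the batch, head, and sequence sizes and the device selection are illustrative only.

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention


def attn_causal(b, h, q_idx, kv_idx, q_len):
    # Same predicate as in the hunk above (dense-causal branch).
    causal_mask = q_idx >= kv_idx
    if q_len > 1:
        return causal_mask
    else:
        return q_idx >= -1


device = "cuda" if torch.cuda.is_available() else "cpu"
bsz, num_heads, head_dim = 1, 8, 64

# Prefill: 128 queries against 128 keys.
q_len = 128
query = torch.randn(bsz, num_heads, q_len, head_dim, device=device)
key = torch.randn(bsz, num_heads, q_len, head_dim, device=device)
value = torch.randn(bsz, num_heads, q_len, head_dim, device=device)
block_mask = create_block_mask(
    lambda b, h, q_idx, kv_idx: attn_causal(b, h, q_idx, kv_idx, q_len),
    B=None, H=None, Q_LEN=q_len, KV_LEN=key.shape[-2], device=device,
)
out = flex_attention(query, key, value, block_mask=block_mask)
print(out.shape)  # torch.Size([1, 8, 128, 64])

# Decode: a single query against a 256-entry cache; q_len == 1 makes the
# predicate admit every cached key.
q_len, kv_len = 1, 256
query = torch.randn(bsz, num_heads, q_len, head_dim, device=device)
key = torch.randn(bsz, num_heads, kv_len, head_dim, device=device)
value = torch.randn(bsz, num_heads, kv_len, head_dim, device=device)
block_mask = create_block_mask(
    lambda b, h, q_idx, kv_idx: attn_causal(b, h, q_idx, kv_idx, q_len),
    B=None, H=None, Q_LEN=q_len, KV_LEN=key.shape[-2], device=device,
)
out = flex_attention(query, key, value, block_mask=block_mask)
print(out.shape)  # torch.Size([1, 8, 1, 64])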
@@ -363,9 +372,9 @@ class LlamaStreamingFlexAttention(LlamaAttention):
             )
 
         attn_output = attn_output.transpose(1, 2).contiguous()
-
+
         attn_output = attn_output.reshape(bsz, q_len, -1)
-
+
         if self.config.pretraining_tp > 1:
             attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
             o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
 