Changes in modelling_RW.py to handle past_key_values for faster generation

#85
by puru22 - opened
Files changed (1)
  1. modelling_RW.py +69 -30
modelling_RW.py CHANGED
@@ -87,10 +87,18 @@ class RotaryEmbedding(torch.nn.Module):
 
         return self.cos_cached, self.sin_cached
 
-    def forward(self, q, k):
-        batch, seq_len, head_dim = q.shape
+    def forward(self, q, k, past_seq_length=None):
+        if past_seq_length == None :
+            batch, seq_len, head_dim = q.shape
+        else :
+            # print("past_seq_length", past_seq_length)
+            batch, input_seq_len, head_dim = q.shape
+            seq_len = past_seq_length + input_seq_len
         cos, sin = self.cos_sin(seq_len, q.device, q.dtype)
-        return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
+        if past_seq_length != None :
+            return (q * cos[:, past_seq_length:, :]) + (rotate_half(q) * sin[:, past_seq_length:, :]), (k * cos[:, past_seq_length:, :]) + (rotate_half(k) * sin[:, past_seq_length:, :])
+        else :
+            return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
 
 
 def _make_causal_mask(
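When a key/value cache is present, only the newly fed positions need rotary angles, and those angles are offset by the number of cached tokens: the hunk above computes cos/sin for the full length and then slices off the cached prefix. A minimal sketch of that offset logic, with illustrative shapes rather than the PR's exact code:

import torch

# With past_seq_length tokens already cached, the new token sits at absolute
# position past_seq_length, so the cos table is built for the full length and
# then sliced to the new positions only.
past_seq_length, input_seq_len, head_dim = 8, 1, 64
q = torch.randn(1, input_seq_len, head_dim)
seq_len = past_seq_length + input_seq_len
positions = torch.arange(seq_len, dtype=torch.float32)
cos = torch.cos(positions)[None, :, None].expand(1, seq_len, head_dim)
cos_new = cos[:, past_seq_length:, :]   # angles for the new token(s) only
assert cos_new.shape == q.shape          # so (q * cos_new) multiplies element-wise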
@@ -100,10 +108,10 @@ def _make_causal_mask(
     mask = torch.empty((target_length, target_length + past_key_values_length), dtype=torch.bool, device=device)
     # ONNX doesn't support `torch.Tensor.triu` properly, thus we use this workaround
     seq_ids = torch.arange(target_length, device=device)
-    mask[:, past_key_values_length:] = seq_ids[:, None] < seq_ids[None, :]
+    mask[:, past_key_values_length:] = seq_ids[:, None] >= seq_ids[None, :]
 
     if past_key_values_length > 0:
-        mask[:, :past_key_values_length] = False
+        mask[:, :past_key_values_length] = True
 
     expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length)
     return expanded_mask
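The mask convention flips here: previously True marked positions to be masked out and cached positions were set to False, whereas now True marks positions a query may attend to, which is what F.scaled_dot_product_attention expects from a boolean attn_mask (see the attention changes below). A quick check of the new mask for 3 query tokens with 2 cached tokens:

import torch

target_length, past = 3, 2
mask = torch.empty((target_length, target_length + past), dtype=torch.bool)
seq_ids = torch.arange(target_length)
mask[:, past:] = seq_ids[:, None] >= seq_ids[None, :]   # lower-triangular block for new tokens
mask[:, :past] = True                                    # every query may see the cached tokens
print(mask.int())
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1]])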
@@ -248,6 +256,7 @@ class Attention(nn.Module):
         head_mask: Optional[torch.Tensor] = None,
         use_cache: bool = False,
         output_attentions: bool = False,
+        layer_number = None
     ):
         fused_qkv = self.query_key_value(hidden_states)  # [batch_size, seq_length, 3 x hidden_size]
 
@@ -264,31 +273,43 @@ class Attention(nn.Module):
         )
         value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
 
-        query_layer, key_layer = self.maybe_rotary(query_layer, key_layer)
+        if layer_past is not None :
+            past_key, past_value = layer_past
+            past_kv_length = past_key.shape[2]
+            query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)
+        else :
+            query_layer, key_layer = self.maybe_rotary(query_layer, key_layer)
 
         if layer_past is not None:
             past_key, past_value = layer_past
             # concatenate along seq_length dimension:
             #  - key: [batch_size * self.num_heads, head_dim, kv_length]
             #  - value: [batch_size * self.num_heads, kv_length, head_dim]
+            past_key = past_key.permute(0, 2, 1)
             key_layer = torch.cat((past_key, key_layer), dim=1)
             value_layer = torch.cat((past_value, value_layer), dim=1)
 
         _, kv_length, _ = key_layer.shape
 
         if use_cache is True:
-            present = (key_layer, value_layer)
+            key_layer_permute = key_layer.permute(0, 2, 1)
+            present = (key_layer_permute, value_layer)
         else:
             present = None
-
+
         if alibi is None:
             query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
             key_layer_ = key_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
             value_layer_ = value_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
-
-            attn_output = F.scaled_dot_product_attention(
-                query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
-            )
+
+            if attention_mask is not None :
+                attn_output = F.scaled_dot_product_attention(
+                    query_layer_, key_layer_, value_layer_, attention_mask, 0.0, is_causal=False
+                )
+            else :
+                attn_output = F.scaled_dot_product_attention(
+                    query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
+                )
 
             x = attn_output.view(batch_size, self.num_heads, q_length, self.head_dim)
             x = x.permute(0, 2, 1, 3)
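Per the comment in the hunk, the cache stores the key transposed as [batch_size * num_heads, head_dim, kv_length] while the value stays [batch_size * num_heads, kv_length, head_dim]; that is why the key is permuted back before concatenation and permuted again before being stored in `present`. A shape-only sketch of that round trip, with illustrative sizes rather than the PR's code:

import torch

bh, head_dim = 2 * 71, 64            # batch_size * num_heads (illustrative)
past_kv_length, q_length = 8, 1

past_key = torch.randn(bh, head_dim, past_kv_length)     # cached key, transposed layout
past_value = torch.randn(bh, past_kv_length, head_dim)   # cached value
key_layer = torch.randn(bh, q_length, head_dim)           # key for the new token(s)
value_layer = torch.randn(bh, q_length, head_dim)

past_key = past_key.permute(0, 2, 1)                      # back to [bh, kv_length, head_dim]
key_layer = torch.cat((past_key, key_layer), dim=1)       # concat along the sequence dimension
value_layer = torch.cat((past_value, value_layer), dim=1)

present = (key_layer.permute(0, 2, 1), value_layer)       # re-transpose the key before caching
assert present[0].shape == (bh, head_dim, past_kv_length + q_length)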
@@ -385,6 +406,7 @@ class DecoderLayer(nn.Module):
         head_mask: Optional[torch.Tensor] = None,
         use_cache: bool = False,
         output_attentions: bool = False,
+        layer_number = None
     ):
 
         ln_attn = self.ln_attn(hidden_states)
@@ -401,6 +423,7 @@ class DecoderLayer(nn.Module):
             head_mask=head_mask,
             use_cache=use_cache,
             output_attentions=output_attentions,
+            layer_number=layer_number
         )
 
         attention_output = attn_outputs[0]
@@ -528,10 +551,10 @@ class RWModel(RWPreTrainedModel):
         device = attention_mask.device
         _, src_length = input_shape
 
-        if src_length > 1:
-            combined_attention_mask = _make_causal_mask(
-                input_shape, device=device, past_key_values_length=past_key_values_length
-            )
+        # if src_length > 1:
+        combined_attention_mask = _make_causal_mask(
+            input_shape, device=device, past_key_values_length=past_key_values_length
+        )
 
         # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
         expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
@@ -651,15 +674,28 @@ class RWModel(RWPreTrainedModel):
                     head_mask[i],
                 )
             else:
-                outputs = block(
-                    hidden_states,
-                    layer_past=layer_past,
-                    attention_mask=causal_mask,
-                    head_mask=head_mask[i],
-                    use_cache=use_cache,
-                    output_attentions=output_attentions,
-                    alibi=alibi,
-                )
+                if i==0 :
+                    outputs = block(
+                        hidden_states,
+                        layer_past=layer_past,
+                        attention_mask=causal_mask,
+                        head_mask=head_mask[i],
+                        use_cache=use_cache,
+                        output_attentions=output_attentions,
+                        alibi=alibi,
+                        layer_number=0
+                    )
+                else :
+                    outputs = block(
+                        hidden_states,
+                        layer_past=layer_past,
+                        attention_mask=causal_mask,
+                        head_mask=head_mask[i],
+                        use_cache=use_cache,
+                        output_attentions=output_attentions,
+                        alibi=alibi,
+                    )
+
 
             hidden_states = outputs[0]
             if use_cache is True:
@@ -710,16 +746,19 @@ class RWForCausalLM(RWPreTrainedModel):
         **kwargs,
     ) -> dict:
         # only last token for input_ids if past is not None
-        if past:
+        if kwargs.get("past_key_values", None) :
             input_ids = input_ids[:, -1].unsqueeze(-1)
-
+            past_key_values = kwargs["past_key_values"]
         # the cache may be in the stardard format (e.g. in contrastive search), convert to our's format if needed
-        if past[0][0].shape[0] == input_ids.shape[0]:
-            past = self._convert_to_rw_cache(past)
+        # if kwargs["past_key_values"][0][0].shape[0] == input_ids.shape[0]:
+        #     past_key_values = self._convert_to_rw_cache(kwargs["past_key_values"])
+        #     past_key_values = kwargs["past_key_values"]
+        else :
+            past_key_values = None
 
         return {
             "input_ids": input_ids,
-            "past_key_values": past,
+            "past_key_values": past_key_values,
             "use_cache": kwargs.get("use_cache"),
             "attention_mask": attention_mask,
         }
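With prepare_inputs_for_generation forwarding past_key_values, the standard generate() loop only re-encodes the newest token at each step. A hedged usage sketch follows; the checkpoint name is an assumption, and any repository that ships this modelling_RW.py (loaded with trust_remote_code=True) should behave the same way:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "tiiuae/falcon-7b"   # assumption: substitute the repo this PR targets
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)

# Cached generation: generate() feeds past_key_values back through
# prepare_inputs_for_generation, so after the prompt pass each step only
# processes the last token.
inputs = tokenizer("The RefinedWeb dataset is", return_tensors="pt")
out_ids = model.generate(**inputs, max_new_tokens=20, use_cache=True)
print(tokenizer.decode(out_ids[0], skip_special_tokens=True))

# The same round trip done by hand: the first call returns the cache, the
# second call passes only the newly chosen token plus that cache.
with torch.no_grad():
    first = model(**inputs, use_cache=True)
    next_token = first.logits[:, -1:].argmax(-1)           # greedy pick, shape [1, 1]
    second = model(input_ids=next_token, past_key_values=first.past_key_values, use_cache=True)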
 