teowu committed on
Commit 58b10f4
1 Parent(s): c6e7a7a

Upload modeling_llama2.py with huggingface_hub
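
For reference, a minimal sketch of how a single file like this is typically pushed with the `huggingface_hub` client (the repo id and local path below are placeholders, not values taken from this commit):

from huggingface_hub import HfApi

api = HfApi()
# Push one local file into an existing Hub repo; the commit message mirrors the one above.
api.upload_file(
    path_or_fileobj="modeling_llama2.py",   # local file to upload
    path_in_repo="modeling_llama2.py",      # destination path inside the repo
    repo_id="user/model-repo",              # placeholder repo id
    repo_type="model",
    commit_message="Upload modeling_llama2.py with huggingface_hub",
)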

Files changed (1)
  1. modeling_llama2.py +334 -6
modeling_llama2.py CHANGED
@@ -18,6 +18,7 @@ sys.path.insert(0, dir_path)
 
 import transformers
 from transformers.models.llama.modeling_llama import *
+from transformers.models.llama.modeling_llama import *
 from transformers.configuration_utils import PretrainedConfig
 from transformers.utils import logging
 
@@ -54,9 +55,18 @@ class MultiwayNetwork(nn.Module):
 class LlamaAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
-    def __init__(self, config: LlamaConfig):
+    def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
         super().__init__()
         self.config = config
+        self.layer_idx = layer_idx
+        if layer_idx is None:
+            logger.warning_once(
+                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
+                "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
+                "when creating this class."
+            )
+
+        self.attention_dropout = config.attention_dropout
         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
         self.head_dim = self.hidden_size // self.num_heads
@@ -64,6 +74,7 @@ class LlamaAttention(nn.Module):
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.max_position_embeddings = config.max_position_embeddings
         self.rope_theta = config.rope_theta
+        self.is_causal = True
 
         if (self.head_dim * self.num_heads) != self.hidden_size:
             raise ValueError(
@@ -182,14 +193,314 @@ class LlamaAttention(nn.Module):
             attn_weights = None
 
         return attn_output, attn_weights, past_key_value
+
+
+class LlamaFlashAttention2(LlamaAttention):
+    """
+    Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays
+    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+    flash attention and deal with padding tokens in case the input contains any of them.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        modality_indicators: torch.Tensor,
+        attention_mask: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        # LlamaFlashAttention2 attention does not support output_attentions
+        if "padding_mask" in kwargs:
+            warnings.warn(
+                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
+            )
+
+            # overwrite attention_mask with padding_mask
+            attention_mask = kwargs.pop("padding_mask")
+
+        output_attentions = False
+
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states, modality_indicators)
+        value_states = self.v_proj(hidden_states, modality_indicators)
+
+        # Flash attention requires the input to have the shape
+        # batch_size x seq_length x head_dim x hidden_dim
+        # therefore we just need to keep the original shape
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+        if past_key_value is not None:
+            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+        # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+        # to be able to avoid many of these transpose/reshape/view.
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+
+        dropout_rate = self.attention_dropout if self.training else 0.0
+
+        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+        # therefore the input hidden states gets silently casted in float32. Hence, we need
+        # cast them back in the correct dtype just to be sure everything works as expected.
+        # This might slowdown training & inference so it is recommended to not cast the LayerNorms
+        # in fp32. (LlamaRMSNorm handles it correctly)
+
+        input_dtype = query_states.dtype
+        if input_dtype == torch.float32:
+            if torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
+            # Handle the case where the model is quantized
+            elif hasattr(self.config, "_pre_quantization_dtype"):
+                target_dtype = self.config._pre_quantization_dtype
+            else:
+                target_dtype = self.q_proj.weight.dtype
+
+            logger.warning_once(
+                f"The input hidden states seems to be silently casted in float32, this might be related to"
+                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+                f" {target_dtype}."
+            )
+
+            query_states = query_states.to(target_dtype)
+            key_states = key_states.to(target_dtype)
+            value_states = value_states.to(target_dtype)
+
+        attn_output = self._flash_attention_forward(
+            query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
+        )
+
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+        attn_output = self.o_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
+
+    def _flash_attention_forward(
+        self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
+    ):
+        """
+        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
+        first unpad the input, then computes the attention scores and pad the final attention scores.
+
+        Args:
+            query_states (`torch.Tensor`):
+                Input query states to be passed to Flash Attention API
+            key_states (`torch.Tensor`):
+                Input key states to be passed to Flash Attention API
+            value_states (`torch.Tensor`):
+                Input value states to be passed to Flash Attention API
+            attention_mask (`torch.Tensor`):
+                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
+                position of padding tokens and 1 for the position of non-padding tokens.
+            dropout (`int`, *optional*):
+                Attention dropout
+            softmax_scale (`float`, *optional*):
+                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
+        """
+        if not self._flash_attn_uses_top_left_mask:
+            causal = self.is_causal
+        else:
+            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
+            causal = self.is_causal and query_length != 1
+
+        # Contains at least one padding token in the sequence
+        if attention_mask is not None:
+            batch_size = query_states.shape[0]
+            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
+                query_states, key_states, value_states, attention_mask, query_length
+            )
+
+            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+
+            attn_output_unpad = flash_attn_varlen_func(
+                query_states,
+                key_states,
+                value_states,
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k=cu_seqlens_k,
+                max_seqlen_q=max_seqlen_in_batch_q,
+                max_seqlen_k=max_seqlen_in_batch_k,
+                dropout_p=dropout,
+                softmax_scale=softmax_scale,
+                causal=causal,
+            )
+
+            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
+        else:
+            attn_output = flash_attn_func(
+                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
+            )
+
+        return attn_output
+
+    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
+        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
+        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
+
+        key_layer = index_first_axis(
+            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
+        )
+        value_layer = index_first_axis(
+            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
+        )
+        if query_length == kv_seq_len:
+            query_layer = index_first_axis(
+                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
+            )
+            cu_seqlens_q = cu_seqlens_k
+            max_seqlen_in_batch_q = max_seqlen_in_batch_k
+            indices_q = indices_k
+        elif query_length == 1:
+            max_seqlen_in_batch_q = 1
+            cu_seqlens_q = torch.arange(
+                batch_size + 1, dtype=torch.int32, device=query_layer.device
+            )  # There is a memcpy here, that is very bad.
+            indices_q = cu_seqlens_q[:-1]
+            query_layer = query_layer.squeeze(1)
+        else:
+            # The -q_len: slice assumes left padding.
+            attention_mask = attention_mask[:, -query_length:]
+            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
+
+        return (
+            query_layer,
+            key_layer,
+            value_layer,
+            indices_q,
+            (cu_seqlens_q, cu_seqlens_k),
+            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
+        )
+
+
+class LlamaSdpaAttention(LlamaAttention):
+    """
+    Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+    `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
+    SDPA API.
+    """
+
+    # Adapted from LlamaAttention.forward
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        modality_indicators: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        if output_attentions:
+            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+            logger.warning_once(
+                "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+            )
+            return super().forward(
+                hidden_states=hidden_states,
+                modality_indicators=modality_indicators,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+            )
+
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states, modality_indicators)
+        value_states = self.v_proj(hidden_states, modality_indicators)
+
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+        if past_key_value is not None:
+            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+                )
+
+        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+        # Reference: https://github.com/pytorch/pytorch/issues/112577.
+        if query_states.device.type == "cuda" and attention_mask is not None:
+            query_states = query_states.contiguous()
+            key_states = key_states.contiguous()
+            value_states = value_states.contiguous()
+
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query_states,
+            key_states,
+            value_states,
+            attn_mask=attention_mask,
+            dropout_p=self.attention_dropout if self.training else 0.0,
+            # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+            is_causal=self.is_causal and attention_mask is None and q_len > 1,
+        )
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+        attn_output = self.o_proj(attn_output)
+
+        return attn_output, None, past_key_value
+
 
 
+LLAMA_ATTENTION_CLASSES = {
+    "eager": LlamaAttention,
+    "flash_attention_2": LlamaFlashAttention2,
+    "sdpa": LlamaSdpaAttention,
+}
 
 class LlamaDecoderLayer(nn.Module):
-    def __init__(self, config: LlamaConfig):
+    def __init__(self, config: LlamaConfig, layer_idx):
         super().__init__()
         self.hidden_size = config.hidden_size
         self.self_attn = LlamaAttention(config=config)
+        self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
         self.mlp = LlamaMLP(config)
         self.input_layernorm = MultiwayNetwork(module_provider=partial(
             LlamaRMSNorm, hidden_size=config.hidden_size, eps=config.rms_norm_eps
@@ -285,7 +596,7 @@ def model_forward(
         batch_size, seq_length, _ = inputs_embeds.shape
     else:
         raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
-
+
     seq_length_with_past = seq_length
     past_key_values_length = 0
 
@@ -309,9 +620,24 @@ def model_forward(
         attention_mask = torch.ones(
             (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
         )
-    attention_mask = self._prepare_decoder_attention_mask(
-        attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
-    )
+
+    if self._use_flash_attention_2:
+        # 2d mask is passed through the layers
+        attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+    elif self._use_sdpa and not output_attentions:
+        # output_attentions=True can not be supported when using SDPA, and we fall back on
+        # the manual implementation that requires a 4D causal mask in all cases.
+        attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+            attention_mask,
+            (batch_size, seq_length),
+            inputs_embeds,
+            past_key_values_length,
+        )
+    else:
+        # 4d mask is passed through the layers
+        attention_mask = _prepare_4d_causal_attention_mask(
+            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+        )
 
     hidden_states = inputs_embeds
 
@@ -482,6 +808,8 @@ def causal_model_forward(
 def replace_llama_modality_adaptive():
     transformers.models.llama.configuration_llama.LlamaConfig = LlamaConfig
     transformers.models.llama.modeling_llama.LlamaAttention = LlamaAttention
+    transformers.models.llama.modeling_llama.LlamaFlashAttention2 = LlamaFlashAttention2
+    transformers.models.llama.modeling_llama.LlamaSdpaAttention = LlamaSdpaAttention
     transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
     transformers.models.llama.modeling_llama.LlamaModel.forward = model_forward
     transformers.models.llama.modeling_llama.LlamaForCausalLM.forward = causal_model_forward
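
How the patched classes are intended to be activated is only implied by `replace_llama_modality_adaptive()` above; the sketch below is an assumption-laden illustration (placeholder checkpoint path, and the `attn_implementation` argument assumes transformers >= 4.36), not something this commit ships:

import torch
from transformers import AutoModelForCausalLM

# Assumption: modeling_llama2 is importable from the repo that contains this file.
from modeling_llama2 import replace_llama_modality_adaptive

# Monkey-patch the upstream Llama classes before the model is built, so that
# LlamaDecoderLayer selects its attention module via LLAMA_ATTENTION_CLASSES.
replace_llama_modality_adaptive()

model = AutoModelForCausalLM.from_pretrained(
    "path/to/checkpoint",          # placeholder
    torch_dtype=torch.float16,
    attn_implementation="sdpa",    # or "eager" / "flash_attention_2"
)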
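
The `_upad_input` helper above consumes the triple returned by the upstream `_get_unpad_data` (flat token indices, cumulative sequence lengths, and the longest sequence in the batch) before calling `flash_attn_varlen_func`. As a standalone illustration of that metadata, here is a small pure-PyTorch sketch; the function name is chosen for this example and is not part of the file:

import torch
import torch.nn.functional as F

def unpad_metadata(attention_mask: torch.Tensor):
    # attention_mask: (batch_size, seq_len), 1 for real tokens, 0 for padding.
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)             # tokens per sequence
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()  # flat positions of real tokens
    max_seqlen_in_batch = int(seqlens_in_batch.max())
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))  # prefix sums starting at 0
    return indices, cu_seqlens, max_seqlen_in_batch

mask = torch.tensor([[1, 1, 1, 0],
                     [1, 1, 0, 0]])
indices, cu_seqlens, max_len = unpad_metadata(mask)
# indices -> tensor([0, 1, 2, 4, 5]); cu_seqlens -> tensor([0, 3, 5]); max_len -> 3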