hemanham committed on
Commit 0d80022
1 Parent(s): 4e4be00

remove imports and use of flash attention


- to ensure transformers's dynamic_module_utils does not complain when flash attention is not installed

Files changed (2)
  1. modeling_minicpmv.py +2 -6
  2. modeling_navit_siglip.py +5 -205
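
For context on the commit description: transformers' dynamic_module_utils statically scans a remote-code file for import statements and then verifies that every detected package is importable, so a flash_attn import is collected even when it sits behind an "if is_flash_attn_2_available():" guard. A minimal sketch of that check, reusing the file name from this commit (the example output is an assumption, not actual output):

# Sketch only: get_imports() and check_imports() are helpers in
# transformers.dynamic_module_utils; the printed value is illustrative.
from transformers.dynamic_module_utils import check_imports, get_imports

# get_imports() scans the source text for import statements, so the guarded
# `from flash_attn import ...` lines are still collected.
print(get_imports("modeling_navit_siglip.py"))  # e.g. ['torch', 'flash_attn', ...]

# check_imports() raises ImportError for any collected package that is not
# installed -- the complaint this commit removes the cause of.
check_imports("modeling_navit_siglip.py")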
modeling_minicpmv.py CHANGED
@@ -32,12 +32,8 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
         self.terminators = ['<|im_end|>', '<|endoftext|>']
 
     def init_vision_module(self):
-        # same as HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit add tgt_sizes
-        if self.config._attn_implementation == 'flash_attention_2':
-            self.config.vision_config._attn_implementation = 'flash_attention_2'
-        else:
-            # not suport sdpa
-            self.config.vision_config._attn_implementation = 'eager'
+        # not suport sdpa
+        self.config.vision_config._attn_implementation = 'eager'
         model = SiglipVisionTransformer(self.config.vision_config)
         if self.config.drop_vision_last_layer:
            model.encoder.layers = model.encoder.layers[:-1]
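
The practical effect of the change above is that the vision tower is always built with the eager SiglipAttention path, so instantiating the model no longer requires flash_attn to be installed. A hedged loading sketch; the model id below is a placeholder assumption, not taken from this commit:

# Hypothetical usage; "openbmb/MiniCPM-V-2" is a placeholder checkpoint id.
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "openbmb/MiniCPM-V-2",       # placeholder model id (assumption)
    trust_remote_code=True,      # fetches modeling_minicpmv.py / modeling_navit_siglip.py
    torch_dtype=torch.bfloat16,
)
# After this commit, init_vision_module() always sets
# config.vision_config._attn_implementation = 'eager'.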
modeling_navit_siglip.py CHANGED
@@ -38,7 +38,6 @@ from transformers.utils import (
     ModelOutput,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
-    is_flash_attn_2_available,
     logging,
     replace_return_docstrings,
 )
@@ -142,10 +141,6 @@ SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [
     # See all SigLIP models at https://huggingface.co/models?filter=siglip
 ]
 
-if is_flash_attn_2_available():
-    from flash_attn import flash_attn_func, flash_attn_varlen_func
-    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
-
 
 # Copied from transformers.models.llama.modeling_llama._get_unpad_data
 def _get_unpad_data(attention_mask):
@@ -430,193 +425,6 @@ class SiglipAttention(nn.Module):
         return attn_output, attn_weights
 
 
-class SiglipFlashAttention2(SiglipAttention):
-    """
-    Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays
-    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
-    flash attention and deal with padding tokens in case the input contains any of them.
-    """
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.is_causal = False  # Hack to make sure we don't use a causal mask
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.LongTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
-        output_attentions: bool = False,
-        use_cache: bool = False,
-        **kwargs,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        output_attentions = False
-
-        bsz, q_len, _ = hidden_states.size()
-
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
-
-        # Flash attention requires the input to have the shape
-        # batch_size x seq_length x head_dim x hidden_dim
-        # therefore we just need to keep the original shape
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-
-        kv_seq_len = key_states.shape[-2]
-        if past_key_value is not None:
-            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-        # cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
-
-        # if past_key_value is not None:
-        #     cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
-        #     key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
-
-        # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
-        # to be able to avoid many of these transpose/reshape/view.
-        query_states = query_states.transpose(1, 2)
-        key_states = key_states.transpose(1, 2)
-        value_states = value_states.transpose(1, 2)
-
-        dropout_rate = self.dropout if self.training else 0.0
-
-        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
-        # therefore the input hidden states gets silently casted in float32. Hence, we need
-        # cast them back in the correct dtype just to be sure everything works as expected.
-        # This might slowdown training & inference so it is recommended to not cast the LayerNorms
-        # in fp32. (LlamaRMSNorm handles it correctly)
-
-        input_dtype = query_states.dtype
-        if input_dtype == torch.float32:
-            if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
-            # Handle the case where the model is quantized
-            elif hasattr(self.config, "_pre_quantization_dtype"):
-                target_dtype = self.config._pre_quantization_dtype
-            else:
-                target_dtype = self.q_proj.weight.dtype
-
-            logger.warning_once(
-                "The input hidden states seems to be silently casted in float32, this might be related to the fact"
-                " you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
-                f" {target_dtype}."
-            )
-
-            query_states = query_states.to(target_dtype)
-            key_states = key_states.to(target_dtype)
-            value_states = value_states.to(target_dtype)
-
-        attn_output = self._flash_attention_forward(
-            query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
-        )
-
-        attn_output = attn_output.reshape(bsz, q_len, self.embed_dim).contiguous()
-        attn_output = self.out_proj(attn_output)
-
-        if not output_attentions:
-            attn_weights = None
-
-        return attn_output, attn_weights
-
-    def _flash_attention_forward(
-        self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
-    ):
-        """
-        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
-        first unpad the input, then computes the attention scores and pad the final attention scores.
-        Args:
-            query_states (`torch.Tensor`):
-                Input query states to be passed to Flash Attention API
-            key_states (`torch.Tensor`):
-                Input key states to be passed to Flash Attention API
-            value_states (`torch.Tensor`):
-                Input value states to be passed to Flash Attention API
-            attention_mask (`torch.Tensor`):
-                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
-                position of padding tokens and 1 for the position of non-padding tokens.
-            dropout (`int`, *optional*):
-                Attention dropout
-            softmax_scale (`float`, *optional*):
-                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
-        """
-
-        # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
-        causal = self.is_causal and query_length != 1
-
-        # Contains at least one padding token in the sequence
-        if attention_mask is not None:
-            batch_size = query_states.shape[0]
-            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
-                query_states, key_states, value_states, attention_mask, query_length
-            )
-
-            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
-            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
-
-            attn_output_unpad = flash_attn_varlen_func(
-                query_states,
-                key_states,
-                value_states,
-                cu_seqlens_q=cu_seqlens_q,
-                cu_seqlens_k=cu_seqlens_k,
-                max_seqlen_q=max_seqlen_in_batch_q,
-                max_seqlen_k=max_seqlen_in_batch_k,
-                dropout_p=dropout,
-                softmax_scale=softmax_scale,
-                causal=causal,
-            )
-
-            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
-        else:
-            attn_output = flash_attn_func(
-                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
-            )
-
-        return attn_output
-
-    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
-        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
-        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
-
-        key_layer = index_first_axis(
-            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
-        )
-        value_layer = index_first_axis(
-            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
-        )
-        if query_length == kv_seq_len:
-            query_layer = index_first_axis(
-                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
-            )
-            cu_seqlens_q = cu_seqlens_k
-            max_seqlen_in_batch_q = max_seqlen_in_batch_k
-            indices_q = indices_k
-        elif query_length == 1:
-            max_seqlen_in_batch_q = 1
-            cu_seqlens_q = torch.arange(
-                batch_size + 1, dtype=torch.int32, device=query_layer.device
-            )  # There is a memcpy here, that is very bad.
-            indices_q = cu_seqlens_q[:-1]
-            query_layer = query_layer.squeeze(1)
-        else:
-            # The -q_len: slice assumes left padding.
-            attention_mask = attention_mask[:, -query_length:]
-            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
-
-        return (
-            query_layer,
-            key_layer,
-            value_layer,
-            indices_q,
-            (cu_seqlens_q, cu_seqlens_k),
-            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
-        )
-
-
 # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip
 class SiglipMLP(nn.Module):
     def __init__(self, config):
@@ -638,12 +446,8 @@ class SiglipEncoderLayer(nn.Module):
     def __init__(self, config: SiglipVisionConfig):
         super().__init__()
         self.embed_dim = config.hidden_size
-        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
-        self.self_attn = (
-            SiglipAttention(config)
-            if not self._use_flash_attention_2
-            else SiglipFlashAttention2(config)
-        )
+        self._use_flash_attention_2 = False
+        self.self_attn = SiglipAttention(config)
         self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
         self.mlp = SiglipMLP(config)
         self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
@@ -860,7 +664,7 @@ class SiglipVisionTransformer(SiglipPreTrainedModel):
         self.embeddings = SiglipVisionEmbeddings(config)
         self.encoder = SiglipEncoder(config)
         self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
-        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+        self._use_flash_attention_2 = False
 
         # Initialize weights and apply final processing
         self.post_init()
@@ -909,11 +713,7 @@ class SiglipVisionTransformer(SiglipPreTrainedModel):
         if not torch.any(~patch_attention_mask):
             attention_mask=None
         else:
-            attention_mask = (
-                _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype)
-                if not self._use_flash_attention_2
-                else patch_attention_mask
-            )
+            attention_mask = _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype)
 
         encoder_outputs = self.encoder(
             inputs_embeds=hidden_states,
@@ -934,4 +734,4 @@ class SiglipVisionTransformer(SiglipPreTrainedModel):
             pooler_output=None,
             hidden_states=encoder_outputs.hidden_states,
             attentions=encoder_outputs.attentions,
-        )
+        )