remove imports and use of flash attention
- to ensure transformers' dynamic_module_utils does not complain when flash attention is not installed (a sketch of that check follows the file list below)
- modeling_minicpmv.py +2 -6
- modeling_navit_siglip.py +5 -205
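
Background, not part of the commit: when a checkpoint is loaded with trust_remote_code=True, transformers' dynamic_module_utils parses the downloaded modeling files for imports and refuses to build the module if any imported package is missing, even when the import only runs under an is_flash_attn_2_available() guard. A minimal sketch of that check, assuming the remote-code file has been saved locally as modeling_navit_siglip.py:

from transformers.dynamic_module_utils import get_imports

# get_imports() scans the file for `import x` / `from x import y` statements
# (only try/except-wrapped imports are ignored), so before this commit the
# returned list still contained "flash_attn" and loading failed without it.
print(get_imports("modeling_navit_siglip.py"))
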
modeling_minicpmv.py
CHANGED
@@ -32,12 +32,8 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
         self.terminators = ['<|im_end|>', '<|endoftext|>']
 
     def init_vision_module(self):
-        #
-        if self.config._attn_implementation == 'flash_attention_2':
-            self.config.vision_config._attn_implementation = 'flash_attention_2'
-        else:
-            # not suport sdpa
-            self.config.vision_config._attn_implementation = 'eager'
+        # not suport sdpa
+        self.config.vision_config._attn_implementation = 'eager'
         model = SiglipVisionTransformer(self.config.vision_config)
         if self.config.drop_vision_last_layer:
             model.encoder.layers = model.encoder.layers[:-1]
modeling_navit_siglip.py
CHANGED
@@ -38,7 +38,6 @@ from transformers.utils import (
     ModelOutput,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
-    is_flash_attn_2_available,
     logging,
     replace_return_docstrings,
 )
@@ -142,10 +141,6 @@ SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [
     # See all SigLIP models at https://huggingface.co/models?filter=siglip
 ]
 
-if is_flash_attn_2_available():
-    from flash_attn import flash_attn_func, flash_attn_varlen_func
-    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
-
 
 # Copied from transformers.models.llama.modeling_llama._get_unpad_data
 def _get_unpad_data(attention_mask):
@@ -430,193 +425,6 @@ class SiglipAttention(nn.Module):
         return attn_output, attn_weights
 
 
-class SiglipFlashAttention2(SiglipAttention):
-    """
-    Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays
-    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
-    flash attention and deal with padding tokens in case the input contains any of them.
-    """
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.is_causal = False  # Hack to make sure we don't use a causal mask
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.LongTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
-        output_attentions: bool = False,
-        use_cache: bool = False,
-        **kwargs,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        output_attentions = False
-
-        bsz, q_len, _ = hidden_states.size()
-
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
-
-        # Flash attention requires the input to have the shape
-        # batch_size x seq_length x head_dim x hidden_dim
-        # therefore we just need to keep the original shape
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-
-        kv_seq_len = key_states.shape[-2]
-        if past_key_value is not None:
-            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-        # cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
-
-        # if past_key_value is not None:
-        #     cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
-        #     key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
-
-        # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
-        # to be able to avoid many of these transpose/reshape/view.
-        query_states = query_states.transpose(1, 2)
-        key_states = key_states.transpose(1, 2)
-        value_states = value_states.transpose(1, 2)
-
-        dropout_rate = self.dropout if self.training else 0.0
-
-        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
-        # therefore the input hidden states gets silently casted in float32. Hence, we need
-        # cast them back in the correct dtype just to be sure everything works as expected.
-        # This might slowdown training & inference so it is recommended to not cast the LayerNorms
-        # in fp32. (LlamaRMSNorm handles it correctly)
-
-        input_dtype = query_states.dtype
-        if input_dtype == torch.float32:
-            if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
-            # Handle the case where the model is quantized
-            elif hasattr(self.config, "_pre_quantization_dtype"):
-                target_dtype = self.config._pre_quantization_dtype
-            else:
-                target_dtype = self.q_proj.weight.dtype
-
-            logger.warning_once(
-                "The input hidden states seems to be silently casted in float32, this might be related to the fact"
-                " you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
-                f" {target_dtype}."
-            )
-
-            query_states = query_states.to(target_dtype)
-            key_states = key_states.to(target_dtype)
-            value_states = value_states.to(target_dtype)
-
-        attn_output = self._flash_attention_forward(
-            query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
-        )
-
-        attn_output = attn_output.reshape(bsz, q_len, self.embed_dim).contiguous()
-        attn_output = self.out_proj(attn_output)
-
-        if not output_attentions:
-            attn_weights = None
-
-        return attn_output, attn_weights
-
-    def _flash_attention_forward(
-        self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
-    ):
-        """
-        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
-        first unpad the input, then computes the attention scores and pad the final attention scores.
-        Args:
-            query_states (`torch.Tensor`):
-                Input query states to be passed to Flash Attention API
-            key_states (`torch.Tensor`):
-                Input key states to be passed to Flash Attention API
-            value_states (`torch.Tensor`):
-                Input value states to be passed to Flash Attention API
-            attention_mask (`torch.Tensor`):
-                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
-                position of padding tokens and 1 for the position of non-padding tokens.
-            dropout (`int`, *optional*):
-                Attention dropout
-            softmax_scale (`float`, *optional*):
-                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
-        """
-
-        # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
-        causal = self.is_causal and query_length != 1
-
-        # Contains at least one padding token in the sequence
-        if attention_mask is not None:
-            batch_size = query_states.shape[0]
-            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
-                query_states, key_states, value_states, attention_mask, query_length
-            )
-
-            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
-            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
-
-            attn_output_unpad = flash_attn_varlen_func(
-                query_states,
-                key_states,
-                value_states,
-                cu_seqlens_q=cu_seqlens_q,
-                cu_seqlens_k=cu_seqlens_k,
-                max_seqlen_q=max_seqlen_in_batch_q,
-                max_seqlen_k=max_seqlen_in_batch_k,
-                dropout_p=dropout,
-                softmax_scale=softmax_scale,
-                causal=causal,
-            )
-
-            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
-        else:
-            attn_output = flash_attn_func(
-                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
-            )
-
-        return attn_output
-
-    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
-        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
-        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
-
-        key_layer = index_first_axis(
-            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
-        )
-        value_layer = index_first_axis(
-            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
-        )
-        if query_length == kv_seq_len:
-            query_layer = index_first_axis(
-                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
-            )
-            cu_seqlens_q = cu_seqlens_k
-            max_seqlen_in_batch_q = max_seqlen_in_batch_k
-            indices_q = indices_k
-        elif query_length == 1:
-            max_seqlen_in_batch_q = 1
-            cu_seqlens_q = torch.arange(
-                batch_size + 1, dtype=torch.int32, device=query_layer.device
-            )  # There is a memcpy here, that is very bad.
-            indices_q = cu_seqlens_q[:-1]
-            query_layer = query_layer.squeeze(1)
-        else:
-            # The -q_len: slice assumes left padding.
-            attention_mask = attention_mask[:, -query_length:]
-            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
-
-        return (
-            query_layer,
-            key_layer,
-            value_layer,
-            indices_q,
-            (cu_seqlens_q, cu_seqlens_k),
-            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
-        )
-
-
 # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip
 class SiglipMLP(nn.Module):
     def __init__(self, config):
@@ -638,12 +446,8 @@ class SiglipEncoderLayer(nn.Module):
     def __init__(self, config: SiglipVisionConfig):
         super().__init__()
         self.embed_dim = config.hidden_size
-        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
-        self.self_attn = (
-            SiglipAttention(config)
-            if not self._use_flash_attention_2
-            else SiglipFlashAttention2(config)
-        )
+        self._use_flash_attention_2 = False
+        self.self_attn = SiglipAttention(config)
         self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
         self.mlp = SiglipMLP(config)
         self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
@@ -860,7 +664,7 @@ class SiglipVisionTransformer(SiglipPreTrainedModel):
         self.embeddings = SiglipVisionEmbeddings(config)
         self.encoder = SiglipEncoder(config)
         self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
-        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+        self._use_flash_attention_2 = False
 
         # Initialize weights and apply final processing
         self.post_init()
@@ -909,11 +713,7 @@ class SiglipVisionTransformer(SiglipPreTrainedModel):
         if not torch.any(~patch_attention_mask):
             attention_mask=None
         else:
-            attention_mask = (
-                _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype)
-                if not self._use_flash_attention_2
-                else patch_attention_mask
-            )
+            attention_mask = _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype)
 
         encoder_outputs = self.encoder(
             inputs_embeds=hidden_states,
@@ -934,4 +734,4 @@ class SiglipVisionTransformer(SiglipPreTrainedModel):
             pooler_output=None,
             hidden_states=encoder_outputs.hidden_states,
             attentions=encoder_outputs.attentions,
-        )
+        )
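
With flash_attn no longer imported anywhere in the custom code, the checkpoint loads on machines without flash-attn installed; the SigLIP vision tower simply uses the eager SiglipAttention path. A hedged usage sketch (the repo id is an assumption for illustration, substitute the actual checkpoint):

from transformers import AutoModel

# Loading the remote code no longer requires flash-attn to be installed.
model = AutoModel.from_pretrained(
    "openbmb/MiniCPM-V-2_6",  # assumed repo id, replace with the real one
    trust_remote_code=True,
    torch_dtype="auto",
)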