File size: 5,872 Bytes
a175ed9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 |
# Copyright (c) OpenMMLab. All rights reserved.
import torch
torch.manual_seed(1024)
import torch.nn as nn
from transformers import PreTrainedModel
from .configuration_qformer import QformerConfig
from .qformer_src import BertConfig, BertLMHeadModel
from transformers import BertTokenizerFast as BertTokenizer
from .configuration_projector import ProjectorConfig
from .modeling_projector import ProjectorModel
from .fuse_modules import BiAttentionBlock
import torch.nn.functional as F
from transformers.activations import ACT2FN
class LayerNorm(nn.LayerNorm):
    """Thin ``nn.LayerNorm`` subclass used by the visual connector.

    The name suggests fp16 handling, but this forward performs no dtype casting.
    The input goes straight through ``nn.LayerNorm.forward``, and the result has
    whatever dtype that call produces. An fp32-upcast variant (normalize in
    float32, cast back to the input dtype) was previously present but disabled.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        normalized = super().forward(x)
        return normalized
class QformerAttnModel(PreTrainedModel):
    """Visual connector that fuses a Q-Former branch with an MLP projector branch.

    ``forward(x_)`` runs the visual features through two branches:

    * Q-Former branch: ``ln_vision`` -> ``Qformer.bert`` (learned
      ``query_tokens`` cross-attending to the normalized features) ->
      ``llm_proj`` (projects to ``config.llm_hidden_size``).
    * MLP branch: ``connector`` (a ``ProjectorModel``) applied to the raw input.

    The MLP features are then refined residually by ``fuse``
    (``BiAttentionBlock``) using the Q-Former features. A residual bottleneck
    FFN (``llm_hidden_size -> llm_hidden_size // 4 -> llm_hidden_size``)
    refines them once more.
    """

    _auto_class = 'AutoModel'
    config_class = QformerConfig
    base_model_prefix = 'model'
    supports_gradient_checkpointing = False

    def __init__(self, config) -> None:
        """Build all sub-modules from ``config``.

        Reads these attributes from ``config``:

        * ``visual_hidden_size`` and ``num_query_token``.
        * ``bert``: a BERT name/path passed to ``from_pretrained`` for the
          Q-Former config, weights and tokenizer.
        * ``llm_hidden_size`` and ``cross_attention_freq``.
        * ``qformer_pth``: an optional checkpoint path.
        * ``bias``: whether ``llm_proj`` has a bias.
        """
        super().__init__(config)
        self.gradient_checkpointing = False
        vision_width = config.visual_hidden_size
        num_query_token = config.num_query_token
        bert = config.bert
        llm_hidden_size = config.llm_hidden_size
        cross_attention_freq = config.cross_attention_freq
        qformer_pth = config.qformer_pth

        # Q-Former: a BERT LM-head model whose layers cross-attend to features
        # of width `vision_width`.
        encoder_config = BertConfig.from_pretrained(bert)
        encoder_config.encoder_width = vision_width
        encoder_config.add_cross_attention = True
        encoder_config.cross_attention_freq = cross_attention_freq
        encoder_config.query_length = num_query_token
        encoder_config.num_hidden_layers = 12
        Qformer = BertLMHeadModel.from_pretrained(bert, config=encoder_config)

        # Hard-wired off. NOTE(review): enabling it sets word_embeddings to None,
        # which will likely break the resize_token_embeddings() call below.
        remove_text = False
        if remove_text:
            # Remove the Q-Former's text component.
            Qformer.cls = None
            Qformer.bert.embeddings.word_embeddings = None
            Qformer.bert.embeddings.position_embeddings = None
            for layer in Qformer.bert.encoder.layer:
                layer.output = None
                layer.intermediate = None

        # Learned query embeddings, shape (1, num_query_token, hidden_size);
        # forward() expands them along the batch dimension.
        query_tokens = nn.Parameter(
            torch.zeros(1, num_query_token, encoder_config.hidden_size)
        )
        query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
        self.Qformer = Qformer
        self.query_tokens = query_tokens
        self.llm_proj = nn.Linear(
            encoder_config.hidden_size, llm_hidden_size, bias=config.bias
        )
        self.ln_vision = LayerNorm(encoder_config.encoder_width)
        # Not used by forward(). NOTE(review): presumably kept for
        # checkpoint/state_dict compatibility; confirm before removing.
        self.ln_llava = LayerNorm(encoder_config.encoder_width)

        # The tokenizer is only needed for its vocabulary size after adding the
        # "[DEC]" BOS token, so the Q-Former embedding table matches it.
        tokenizer = BertTokenizer.from_pretrained(bert, truncation_side='right')
        tokenizer.add_special_tokens({"bos_token": "[DEC]"})
        self.Qformer.resize_token_embeddings(len(tokenizer))

        if qformer_pth is not None:
            # SECURITY NOTE: torch.load unpickles arbitrary Python objects;
            # qformer_pth must point to a trusted checkpoint.
            checkpoint = torch.load(qformer_pth, map_location='cpu')
            # Accept either a training checkpoint holding the weights under
            # 'model' (the original expectation) or a bare state dict.
            if isinstance(checkpoint, dict) and 'model' in checkpoint:
                pretrained_state_dict = checkpoint['model']
            else:
                pretrained_state_dict = checkpoint
            print(f'Load Qformer from {qformer_pth}')
            # strict=False: the checkpoint is loaded before connector/fuse/ffn
            # exist, and may not cover every parameter. Report the mismatch
            # instead of silently discarding it.
            load_result = self.load_state_dict(pretrained_state_dict, strict=False)
            print(
                f'Missing keys: {len(load_result.missing_keys)}, '
                f'unexpected keys: {len(load_result.unexpected_keys)}'
            )
            print('Done.')

        # MLP branch applied to the raw (un-normalized) visual features.
        projector_config = ProjectorConfig(
            visual_hidden_size=config.visual_hidden_size,
            llm_hidden_size=config.llm_hidden_size,
            projector_depth=2,
        )
        self.connector = ProjectorModel(projector_config)

        # Cross-branch fusion block; both inputs have width llm_hidden_size.
        d_model = config.llm_hidden_size
        dim_feedforward = 1024
        nhead = 8
        fusion_dropout = 0.0
        fusion_droppath = 0.1
        self.fuse = BiAttentionBlock(
            v_dim=d_model,
            l_dim=d_model,
            embed_dim=dim_feedforward,
            num_heads=nhead,
            dropout=fusion_dropout,
            drop_path=fusion_droppath,
        )

        # Bias-free bottleneck FFN used as a residual refinement in forward().
        modules = [
            nn.Linear(config.llm_hidden_size, config.llm_hidden_size // 4, bias=False),
            ACT2FN['gelu'],
            nn.Linear(config.llm_hidden_size // 4, config.llm_hidden_size, bias=False),
        ]
        self.ffn = nn.Sequential(*modules)

    @staticmethod
    def _make_outputs_require_grad(module, inputs, output):
        """Forward hook: mark every floating-point tensor output as requiring grad.

        The hook handles three output forms:

        * A bare tensor.
        * A tuple or list of tensors (non-tensor entries such as ``None`` are
          skipped).
        * A dict-like output such as a transformers ``ModelOutput``.

        Non-floating tensors are skipped because autograd rejects
        ``requires_grad`` on them.
        """
        if isinstance(output, torch.Tensor):
            candidates = (output,)
        elif isinstance(output, dict):
            candidates = tuple(output.values())
        elif isinstance(output, (tuple, list)):
            candidates = tuple(output)
        else:
            return
        for item in candidates:
            if isinstance(item, torch.Tensor) and item.is_floating_point():
                item.requires_grad_(True)

    def enable_input_require_grads(self):
        """Register hooks forcing sub-module outputs to require grad.

        This overrides ``PreTrainedModel.enable_input_require_grads``. Hooks are
        added on every call, so calling it twice registers duplicates.

        NOTE(review): ``forward`` calls ``self.Qformer.bert`` directly rather
        than ``self.Qformer``. The hook on ``self.Qformer`` therefore only fires
        if the Q-Former is invoked as a whole somewhere else.
        """
        for submodule in (
            self.Qformer,
            self.llm_proj,
            self.ln_vision,
            self.connector,
            self.ffn,
            self.fuse,
        ):
            submodule.register_forward_hook(self._make_outputs_require_grad)

    def _set_gradient_checkpointing(self, module, value=False):
        """Toggle ``gradient_checkpointing`` on ``ProjectorModel`` sub-modules.

        Other modules are left untouched. The previous implementation called
        ``exit()`` here, which terminated the whole process.

        NOTE(review): with ``supports_gradient_checkpointing = False``,
        transformers is not expected to call this at all; confirm before
        relying on it.
        """
        if isinstance(module, ProjectorModel):
            module.gradient_checkpointing = value

    def forward(self, x_):
        """Fuse Q-Former and MLP features of the visual input ``x_``.

        ``x_`` must have batch as dim 0 and ``visual_hidden_size`` as its last
        dim. NOTE(review): it is presumably (batch, num_patches,
        visual_hidden_size); confirm with the caller.

        Returns a tensor with the shape of ``self.connector(x_)``.
        """
        if self.gradient_checkpointing and self.training:
            print('Gradient checkpointing is not supported')

        # Q-Former branch: learned queries cross-attend to normalized features.
        x = self.ln_vision(x_)
        query_tokens = self.query_tokens.expand(x.shape[0], -1, -1)
        query_output = self.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=x,
            return_dict=True,
        )
        q_feat = self.llm_proj(query_output.last_hidden_state)

        # MLP branch on the raw features, then two residual refinements.
        mlp_feat = self.connector(x_)
        mlp_feat = mlp_feat + self.fuse(mlp_feat, q_feat)
        return mlp_feat + self.ffn(mlp_feat)
|