# opencompass/multimodal/models/llama_adapter_v2_multimodal/llama_adapter.py
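"""LLaMA-Adapter V2 multimodal model wrapper used by OpenCompass.

This module defines ``LLaMA_adapter`` (the CLIP + LLaMA adapter network)
and ``LLaMA_adapter_v2`` (a thin wrapper that resolves checkpoints and
builds the model for multimodal evaluation).
"""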
import json
import os
import os.path as osp
import sys
from pathlib import Path

import clip
import mmengine
import torch
import torch.nn as nn
from mmengine.device import get_device
from timm.models.vision_transformer import Block

from opencompass.registry import MM_MODELS

def load_package():
    """Load required packages from llama_adapter_v2_multimodal7b."""
    current_file_path = os.path.abspath(__file__)
    current_folder_path = os.path.dirname(current_file_path)
    sys.path.append(os.path.join(current_folder_path, 'LLaMA-Adapter'))  # noqa
    from llama_adapter_v2_multimodal7b.llama.llama import (ModelArgs,
                                                           Transformer)
    from llama_adapter_v2_multimodal7b.llama.tokenizer import Tokenizer
    from llama_adapter_v2_multimodal7b.llama.utils import sample_top_p
    sys.path.pop(-1)
    return ModelArgs, Transformer, Tokenizer, sample_top_p


ModelArgs, Transformer, Tokenizer, sample_top_p = load_package()


class LLaMA_adapter(nn.Module):
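    """LLaMA-Adapter V2 multimodal model.

    Encodes images with a CLIP visual backbone, turns the patch features
    into a fixed number of visual query tokens, and injects those queries
    (together with learned adapter prompts) into the last ``query_layer``
    transformer layers of LLaMA.
    """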

    def __init__(self,
                 llama_ckpt_dir,
                 llama_tokenizer,
                 max_seq_len=512,
                 max_batch_size=1,
                 clip_model='ViT-L/14',
                 v_embed_dim=768,
                 v_depth=8,
                 v_num_heads=16,
                 v_mlp_ratio=4.0,
                 query_len=10,
                 query_layer=31,
                 w_bias=False,
                 w_lora=False,
                 lora_rank=16,
                 prompt_constructor=None,
                 post_processor=None):
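        """Build the adapter components.

        The constructor assembles, in order: the CLIP image encoder and its
        projector, the learnable visual query tokens and their transformer
        blocks, the per-layer adapter queries, the LLaMA tokenizer, and the
        LLaMA transformer whose weights are loaded from ``llama_ckpt_dir``.
        """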
        super().__init__()

        self.device = get_device()

        # load llama configs
        with open(os.path.join(llama_ckpt_dir, 'params.json'), 'r') as f:
            params = json.loads(f.read())
        model_args = ModelArgs(max_seq_len=max_seq_len,
                               max_batch_size=max_batch_size,
                               **params)

        # 1. clip and clip projector
        self.clip, self.clip_transform = clip.load(clip_model)

        clip_dim = self.clip.visual.proj.shape[1]
        self.clip_proj = nn.Linear(clip_dim, v_embed_dim)
        self.clip_proj_norm = nn.LayerNorm(v_embed_dim)

        self.query_len = query_len
        self.query_layer = query_layer

        # 2. visual query, blocks and projector
        self.visual_query = nn.Embedding(query_len, v_embed_dim)
        self.visual_blocks = nn.ModuleList([
            Block(v_embed_dim, v_num_heads, v_mlp_ratio, qkv_bias=True)
            for _ in range(v_depth)
        ])
        self.visual_proj = nn.Linear(v_embed_dim, model_args.dim)
        self.visual_proj_norm = nn.LayerNorm(model_args.dim)

        # 3. adapter query
        self.adapter_query = nn.Embedding(query_len * query_layer,
                                          model_args.dim)

        # 4. tokenizer
        self.tokenizer = Tokenizer(model_path=llama_tokenizer)

        # 5. llama
        model_args.vocab_size = self.tokenizer.n_words
        model_args.w_bias = w_bias
        model_args.w_lora = w_lora
        model_args.lora_rank = lora_rank
        torch.set_default_tensor_type(torch.cuda.HalfTensor)
        self.llama = Transformer(model_args)
        torch.set_default_tensor_type(torch.FloatTensor)

        ckpts = sorted(Path(llama_ckpt_dir).glob('*.pth'))
        for ckpt in ckpts:
            ckpt = torch.load(ckpt, map_location='cpu')
            self.llama.load_state_dict(ckpt, strict=False)

        self.prompt_constructor = mmengine.registry.build_from_cfg(
            prompt_constructor, MM_MODELS)
        if post_processor is not None:
            self.post_processor = mmengine.registry.build_from_cfg(
                post_processor, MM_MODELS)

    def clip_encode_image(self, x):
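        """Extract patch-level features with the CLIP visual encoder.

        This mirrors the CLIP visual forward pass but keeps every spatial
        token (not only the class token) so the adapter can attend to the
        full image grid.
        """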
        # modified from CLIP
        x = self.clip.visual.conv1(x)  # shape = [*, width, grid, grid]
        # shape = [*, width, grid ** 2]
        x = x.reshape(x.shape[0], x.shape[1], -1)
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        x = torch.cat([
            self.clip.visual.class_embedding.to(x.dtype) + torch.zeros(
                x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x
        ],
                      dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.clip.visual.positional_embedding.to(x.dtype)
        x = self.clip.visual.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.clip.visual.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        # preserve all spatial tokens
        x = self.clip.visual.ln_post(x[:, :, :])

        if self.clip.visual.proj is not None:
            x = x @ self.clip.visual.proj

        return x

    def forward_visual(self, imgs):
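        """Turn CLIP image features into visual query tokens.

        The projected CLIP features are concatenated behind the learnable
        visual queries, refined by the visual transformer blocks, and the
        first ``query_len`` tokens are projected to the LLaMA hidden size.
        """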
        clip_feats = self.clip_encode_image(imgs)
        clip_feats = self.clip_proj_norm(self.clip_proj(clip_feats.float()))

        visual_query = self.visual_query.weight.unsqueeze(0).repeat(
            len(imgs), 1, 1)
        visual_query = torch.cat([visual_query, clip_feats], dim=1)
        for block in self.visual_blocks:
            visual_query = block(visual_query)

        visual_query = visual_query[:, :self.query_len, :]
        visual_query = self.visual_proj(visual_query)
        visual_query = self.visual_proj_norm(visual_query)

        return visual_query

    def forward(self, visual_query, tokens, start_pos: int):
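        """Run one decoding step of LLaMA with visual conditioning.

        ``tokens`` holds the token ids for positions
        ``[start_pos, start_pos + seqlen)``; the visual query is added to
        the adapter prompts of the last ``query_layer`` layers. Returns the
        logits of the final position.
        """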
        _bsz, seqlen = tokens.shape
        h = self.llama.tok_embeddings(tokens)
        freqs_cis = self.llama.freqs_cis.to(h.device)
        freqs_cis = freqs_cis[start_pos:start_pos + seqlen]

        mask = torch.full((1, 1, seqlen, seqlen),
                          float('-inf'),
                          device=h.device)
        mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)

        for layer in self.llama.layers[:-1 * self.query_layer]:
            h = layer(h, start_pos, freqs_cis, mask)

        adapter = self.adapter_query.weight.reshape(self.query_layer,
                                                    self.query_len,
                                                    -1).unsqueeze(1)
        adapter_index = 0
        for layer in self.llama.layers[-1 * self.query_layer:]:
            dynamic_adapter = adapter[adapter_index].repeat(_bsz, 1, 1)
            dynamic_adapter = dynamic_adapter + visual_query
            h = layer(h, start_pos, freqs_cis, mask, dynamic_adapter)
            adapter_index = adapter_index + 1

        h = self.llama.norm(h)
        output = self.llama.output(h[:, -1, :])

        return output.float()

    def pack_inputs(self, batch):
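        """Stack the batched images and move them to the current device."""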
        images = [image.unsqueeze(0) for image in batch['inputs']]
        data_samples = [data_sample for data_sample in batch['data_samples']]
        images = torch.cat(images, dim=0).to(get_device())
        inputs = {'image': images, 'data_samples': data_samples}
        return inputs

    def generate(self, batch):
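        """Generate an answer for the (single-sample) batch.

        Performs top-p sampling (greedy when ``temperature`` is 0) over the
        LLaMA decoder conditioned on the visual query, then post-processes
        the decoded text and writes it to ``data_sample.pred_answer``.
        """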
        max_gen_len = 256
        temperature = 0.1
        top_p = 0.75

        inputs = self.pack_inputs(batch)
        inputs = self.prompt_constructor(inputs)
        image = inputs['image']
        prompts = inputs['prompt']
        data_samples = inputs['data_samples']
        data_sample = data_samples[0]

        imgs = image
        bsz = len(imgs)
        params = self.llama.params

        with torch.cuda.amp.autocast():
            visual_query = self.forward_visual(imgs)

        if isinstance(prompts[0], str):
            prompts = [
                self.tokenizer.encode(x, bos=True, eos=False) for x in prompts
            ]

        min_prompt_size = min([len(t) for t in prompts])
        max_prompt_size = max([len(t) for t in prompts])

        total_len = min(params.max_seq_len, max_gen_len + max_prompt_size)

        tokens = torch.full((bsz, total_len),
                            self.tokenizer.pad_id).cuda().long()

        for k, t in enumerate(prompts):
            if len(t) <= total_len:
                tokens[k, :len(t)] = torch.tensor(t).cuda().long()
            else:
                tokens[k, :total_len] = torch.tensor(
                    t[:total_len]).cuda().long()

        input_text_mask = tokens != self.tokenizer.pad_id
        start_pos = min_prompt_size
        prev_pos = 0
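        # Autoregressive decoding: feed the tokens generated so far, sample
        # the next token, and keep prompt tokens untouched via the mask.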
        for cur_pos in range(start_pos, total_len):
            with torch.cuda.amp.autocast():
                logits = self.forward(visual_query,
                                      tokens[:, prev_pos:cur_pos], prev_pos)
            if temperature > 0:
                probs = torch.softmax(logits / temperature, dim=-1)
                next_token = sample_top_p(probs, top_p)
            else:
                next_token = torch.argmax(logits, dim=-1)
            next_token = next_token.reshape(-1)

            next_token = torch.where(input_text_mask[:, cur_pos],
                                     tokens[:, cur_pos], next_token)
            tokens[:, cur_pos] = next_token
            # trick: early stop if bsz==1
            if bsz == 1 and next_token[0] == self.tokenizer.eos_id:
                break
            prev_pos = cur_pos

        decoded = []
        for i, t in enumerate(tokens.tolist()):
            # cut to max gen len
            t = t[len(prompts[i]):len(prompts[i]) + max_gen_len]
            # cut to eos tok if any
            try:
                t = t[:t.index(self.tokenizer.eos_id)]
            except ValueError:
                pass
            decoded.append(self.tokenizer.decode(t))

        output_text = self.post_processor(decoded[0])
        data_sample.pred_answer = output_text
        return data_sample


class LLaMA_adapter_v2(nn.Module):
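    """OpenCompass wrapper around the released LLaMA-Adapter V2 checkpoints.

    Resolves the LLaMA weights, tokenizer and adapter checkpoint from
    ``llama_dir`` and the cloned LLaMA-Adapter repository, builds a
    ``LLaMA_adapter`` instance and loads the adapter weights into it.
    """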

    def __init__(self,
                 llama_dir,
                 prompt_constructor: dict,
                 post_processor: dict,
                 model_path: str = 'llama_adapter_v2_multimodal7b',
                 name: str = 'LORA-BIAS-7B',
                 mode: str = 'generation',
                 device='cuda' if torch.cuda.is_available() else 'cpu',
                 download_root='ckpts'):
        super().__init__()

        assert name in ['LORA-BIAS-7B', 'BIAS-7B', 'CAPTION-7B']

        # BIAS-7B or https://xxx/sha256_BIAS-7B.pth -> 7B
        llama_type = name.split('.')[0].split('-')[-1]
        llama_ckpt_dir = os.path.join(llama_dir, llama_type)
        llama_tokenizer_path = os.path.join(llama_dir, 'tokenizer.model')

        # load llama_adapter weights and model_cfg
        print(f'Loading LLaMA-Adapter from {llama_dir}')

        current_file_path = os.path.abspath(__file__)
        current_folder_path = os.path.dirname(current_file_path)
        model_path = osp.join(current_folder_path, 'LLaMA-Adapter', model_path)
        ckpt_root = osp.join(model_path, download_root)
        ckpt_map = {
            'LORA-BIAS-7B':
            '1bcbffc43484332672092e0024a8699a6eb5f558161aebf98a7c6b1db67224d1_LORA-BIAS-7B.pth',  # noqa: E501
            'BIAS-7B':
            '7fa55208379faf2dd862565284101b0e4a2a72114d6490a95e432cf9d9b6c813_BIAS-7B.pth',  # noqa: E501
            'CAPTION-7B':
            '5088aeb63a89746b90bcfd5cb819e1c7411b2771b267c6d131ce73e250a8abf0_CAPTION-7B.pth'  # noqa: E501
        }
        ckpt = torch.load(osp.join(ckpt_root, ckpt_map[name]),
                          map_location='cpu')
        model_cfg = ckpt.get('config', {})

        self.model = LLaMA_adapter(
            llama_ckpt_dir,
            llama_tokenizer_path,
            max_seq_len=512,
            max_batch_size=1,
            clip_model='ViT-L/14',
            v_embed_dim=768,
            v_depth=8,
            v_num_heads=16,
            v_mlp_ratio=4.0,
            query_len=10,
            query_layer=31,
            w_bias=model_cfg.get('w_bias', False),
            w_lora=model_cfg.get('w_lora', False),
            lora_rank=model_cfg.get('lora_rank', 16),
            prompt_constructor=prompt_constructor,
            post_processor=post_processor,
        )

        self.model.load_state_dict(ckpt['model'], strict=False)
        self.mode = mode

    def forward(self, batch):
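        """Dispatch the batch; only the 'generation' mode is implemented."""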
        if self.mode == 'generation':
            return self.model.generate(batch)