import math |
|
import random |
|
from abc import ABC, abstractmethod |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
|
|
CONTROLLER_HEART_BEAT_EXPIRATION = 30 |
|
WORKER_HEART_BEAT_INTERVAL = 15 |
|
|
|
LOGDIR = "." |
|
|
|
|
|
IGNORE_INDEX = -100 |
|
IMAGE_TOKEN_INDEX = -200 |
|
DEFAULT_IMAGE_TOKEN = "<image>" |
|
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>" |
|
DEFAULT_IM_START_TOKEN = "<im_start>" |
|
DEFAULT_IM_END_TOKEN = "<im_end>" |
|
IMAGE_PLACEHOLDER = "<image-placeholder>" |
|
|
|
from .multimodal_encoder_builder import build_vision_tower_aux_list |
|
from .multimodal_projector_builder import build_vision_projector |
|
from .vision_sampler import VisionTokenSampler |
|
|
|
IS_XLA_AVAILABLE = False |
|
|
|
|
|
class CambrianMetaModel: |
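    """
    Mixin for the language-model backbone that owns the multimodal components:
    the auxiliary vision towers, the projector(s) that map vision features into
    the LM hidden space, and (for the "sva" connector) the query-based vision
    token samplers. Everything is built from the config when
    `mm_vision_tower_aux_list` is present.
    """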
|
|
|
def __init__(self, config): |
|
super(CambrianMetaModel, self).__init__(config) |
|
|
|
if hasattr(config, "mm_vision_tower_aux_list"): |
|
|
|
projector_type = getattr(config, "mm_projector_type", "linear") |
|
if projector_type == "sva": |
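                # "sva" connector (Cambrian-style spatial vision aggregator):
                # one projection MLP per auxiliary tower, one VisionTokenSampler per
                # query group, extra sampler layers built when `connector_only` is
                # False, plus learnable vision-query / image-newline / frame-position
                # parameters.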
|
|
|
vision_hidden_size = config.vision_hidden_size |
|
num_query_group = config.num_query_group |
|
query_num_list = config.query_num_list |
|
connector_only = config.connector_only |
|
connector_depth = config.connector_depth |
|
self.vision_tower_aux_list = build_vision_tower_aux_list( |
|
config, delay_load=True |
|
) |
|
self.mm_projector = nn.Sequential( |
|
nn.Linear(vision_hidden_size * num_query_group, config.hidden_size), |
|
nn.GELU(), |
|
nn.Linear(config.hidden_size, config.hidden_size), |
|
) |
|
|
|
image_token_len = config.image_token_len |
|
vision_tower_aux_token_len_list = ( |
|
self.config.mm_vision_tower_aux_token_len_list |
|
) |
|
cross_att_token_len_list = [ |
|
int(vision_tower_aux_token_len**0.5) // int(image_token_len**0.5) |
|
for vision_tower_aux_token_len in vision_tower_aux_token_len_list |
|
] |
|
|
|
for aux_i, vision_tower_aux in enumerate(self.vision_tower_aux_list): |
|
setattr( |
|
self, |
|
"mm_projector_aux_{}".format(aux_i), |
|
nn.Sequential( |
|
nn.Linear(vision_tower_aux.hidden_size, vision_hidden_size), |
|
nn.GELU(), |
|
nn.Linear(vision_hidden_size, vision_hidden_size), |
|
nn.LayerNorm(vision_hidden_size), |
|
), |
|
) |
|
|
|
for query_group_i in range(num_query_group): |
|
cross_att_token_len_list = [ |
|
int(vision_tower_aux_token_len**0.5) |
|
// int(query_num_list[query_group_i] ** 0.5) |
|
for vision_tower_aux_token_len in vision_tower_aux_token_len_list |
|
] |
|
setattr( |
|
self, |
|
"vision_sampler_{}".format(query_group_i), |
|
VisionTokenSampler( |
|
vision_hidden_size, |
|
vision_hidden_size, |
|
[vision_hidden_size] * len(self.vision_tower_aux_list), |
|
cross_att_token_len_list, |
|
vision_hidden_size, |
|
connector_depth, |
|
), |
|
) |
|
|
|
if not connector_only: |
|
                    num_of_vision_sampler_layers = config.num_of_vision_sampler_layers
|
cross_att_token_len_list = [ |
|
int(vision_tower_aux_token_len**0.5) |
|
// int(image_token_len**0.5) |
|
for vision_tower_aux_token_len in vision_tower_aux_token_len_list |
|
] |
|
self.vision_sampler_layers = nn.ModuleList( |
|
[ |
|
VisionTokenSampler( |
|
config.hidden_size, |
|
vision_hidden_size, |
|
[vision_hidden_size] * len(self.vision_tower_aux_list), |
|
cross_att_token_len_list, |
|
vision_hidden_size, |
|
1, |
|
) |
|
for layer_idx in range(0, num_of_vision_sampler_layers) |
|
] |
|
) |
|
|
|
self.vision_query = nn.Parameter( |
|
torch.randn((num_query_group, vision_hidden_size), dtype=self.dtype) |
|
) |
|
|
|
self.image_newline = nn.Parameter( |
|
torch.empty(config.hidden_size, dtype=self.dtype) |
|
) |
|
|
|
self.frame_pos = torch.stack( |
|
[ |
|
1 |
|
/ torch.pow( |
|
torch.tensor(10000), |
|
torch.tensor(2 * (hid_j // 2) / config.hidden_size), |
|
) |
|
for hid_j in range(config.hidden_size) |
|
] |
|
) |
|
|
|
else: |
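                # Plain connector: concatenate the auxiliary towers' features
                # (mm_hidden_size is their summed width) and project them with a
                # single projector built from the config.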
|
self.vision_tower_aux_list = build_vision_tower_aux_list( |
|
config, delay_load=True |
|
) |
|
config.mm_hidden_size = sum( |
|
[ |
|
vision_tower_aux.hidden_size |
|
for vision_tower_aux in self.vision_tower_aux_list |
|
] |
|
) |
|
self.mm_projector = build_vision_projector(config) |
|
self.image_newline = nn.Parameter( |
|
torch.empty(config.hidden_size, dtype=self.dtype) |
|
) |
|
|
|
def get_frame_pos(self, time_range): |
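        """
        Build sinusoidal temporal position embeddings for the given frame time
        indices: even feature dims use sin and odd dims use cos of the scaled
        inverse frequencies stored in `self.frame_pos`. Returns a tensor of
        shape (num_frames, 1, hidden_size).
        """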
|
frame_pos = self.frame_pos.reshape(1, -1) * time_range.reshape(-1, 1).to( |
|
self.frame_pos.device |
|
) |
|
frame_pos[:, 0::2] = torch.sin(frame_pos[:, 0::2]) |
|
        # cos over the odd dims (standard sinusoidal position encoding)
        frame_pos[:, 1::2] = torch.cos(frame_pos[:, 1::2])
|
frame_pos = frame_pos.unsqueeze(1) |
|
return frame_pos |
|
|
|
def get_vision_tower_aux_list(self): |
|
vision_tower_aux_list = getattr(self, "vision_tower_aux_list", None) |
|
return vision_tower_aux_list |
|
|
|
def initialize_vision_modules(self, model_args, fsdp=None): |
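        """
        Copy the multimodal settings from `model_args` onto the config, build (or
        reload) the auxiliary vision towers, the projector(s), the per-group vision
        samplers and the learnable query / image-newline parameters, and optionally
        initialize them from a pretrained adapter checkpoint
        (`pretrain_mm_mlp_adapter`).
        """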
|
|
|
num_query_group = model_args.num_query_group |
|
query_num_list = model_args.query_num_list |
|
vision_hidden_size = model_args.vision_hidden_size |
|
vision_tower_aux_list = model_args.vision_tower_aux_list |
|
vision_tower_aux_token_len_list = model_args.vision_tower_aux_token_len_list |
|
image_token_len = model_args.image_token_len |
|
mm_vision_select_layer = model_args.mm_vision_select_layer |
|
mm_vision_select_feature = model_args.mm_vision_select_feature |
|
pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter |
|
connector_only = model_args.connector_only |
|
connector_depth = model_args.connector_depth |
|
|
|
|
|
self.config.image_token_len = image_token_len |
|
self.config.num_query_group = num_query_group |
|
self.config.query_num_list = query_num_list |
|
assert num_query_group == len(query_num_list) |
|
self.config.connector_depth = connector_depth |
|
self.config.mm_vision_tower_aux_list = vision_tower_aux_list |
|
self.config.mm_vision_tower_aux_token_len_list = vision_tower_aux_token_len_list |
|
self.config.connector_only = connector_only |
|
self.config.highres_connect = model_args.highres_connect |
|
self.config.highres = model_args.highres |
|
self.config.frame_pos = model_args.frame_pos |
|
self.config.lowres_token = model_args.lowres_token |
|
self.config.connect_layer = model_args.connect_layer |
|
self.config.dino_threshold = getattr(model_args, "dino_threshold", 0.83) |
|
self.config.drop_threshold = getattr(model_args, "drop_threshold", 0.6) |
|
self.config.is_image_newline = getattr(model_args, "is_image_newline", True) |
|
|
|
if self.get_vision_tower_aux_list() is None: |
|
vision_tower_aux_list = build_vision_tower_aux_list(model_args) |
|
if model_args.unfreeze_mm_vision_tower: |
|
self.vision_tower_aux_list = nn.ModuleList(vision_tower_aux_list) |
|
else: |
|
self.vision_tower_aux_list = vision_tower_aux_list |
|
else: |
|
vision_tower_aux_list = self.vision_tower_aux_list |
|
for vision_tower_aux in vision_tower_aux_list: |
|
vision_tower_aux.load_model() |
|
|
|
self.config.use_mm_proj = True |
|
self.config.mm_projector_type = getattr( |
|
model_args, "mm_projector_type", "linear" |
|
) |
|
self.config.vision_hidden_size = vision_hidden_size |
|
self.config.mm_vision_select_layer = mm_vision_select_layer |
|
self.config.mm_vision_select_feature = mm_vision_select_feature |
|
|
|
if getattr(self, "mm_projector", None) is None: |
|
|
|
if self.config.mm_projector_type == "sva": |
|
self.mm_projector = nn.Sequential( |
|
nn.Linear( |
|
vision_hidden_size * num_query_group, self.config.hidden_size |
|
), |
|
nn.GELU(), |
|
nn.Linear(self.config.hidden_size, self.config.hidden_size), |
|
) |
|
for aux_i, vision_tower_aux in enumerate(vision_tower_aux_list): |
|
setattr( |
|
self, |
|
"mm_projector_aux_{}".format(aux_i), |
|
nn.Sequential( |
|
nn.Linear(vision_tower_aux.hidden_size, vision_hidden_size), |
|
nn.GELU(), |
|
nn.Linear(vision_hidden_size, vision_hidden_size), |
|
nn.LayerNorm(vision_hidden_size), |
|
), |
|
) |
|
|
|
|
|
for query_group_i in range(num_query_group): |
|
cross_att_token_len_list = [ |
|
int(vision_tower_aux_token_len**0.5) |
|
// int(query_num_list[query_group_i] ** 0.5) |
|
for vision_tower_aux_token_len in vision_tower_aux_token_len_list |
|
] |
|
setattr( |
|
self, |
|
"vision_sampler_{}".format(query_group_i), |
|
VisionTokenSampler( |
|
vision_hidden_size, |
|
vision_hidden_size, |
|
[vision_hidden_size] * len(vision_tower_aux_list), |
|
cross_att_token_len_list, |
|
vision_hidden_size, |
|
connector_depth, |
|
), |
|
) |
|
|
|
|
|
if not connector_only: |
|
num_of_vision_sampler_layers = ( |
|
self.config.num_of_vision_sampler_layers |
|
) = model_args.num_of_vision_sampler_layers |
|
self.config.start_of_vision_sampler_layers = ( |
|
model_args.start_of_vision_sampler_layers |
|
) |
|
self.config.stride_of_vision_sampler_layers = ( |
|
model_args.stride_of_vision_sampler_layers |
|
) |
|
cross_att_token_len_list = [ |
|
int(vision_tower_aux_token_len**0.5) |
|
// int(image_token_len**0.5) |
|
for vision_tower_aux_token_len in vision_tower_aux_token_len_list |
|
] |
|
self.vision_sampler_layers = nn.ModuleList( |
|
[ |
|
VisionTokenSampler( |
|
self.config.hidden_size, |
|
vision_hidden_size, |
|
[vision_hidden_size] * len(vision_tower_aux_list), |
|
cross_att_token_len_list, |
|
vision_hidden_size, |
|
1, |
|
) |
|
for layer_idx in range(0, num_of_vision_sampler_layers) |
|
] |
|
) |
|
vision_embed_std = 1 / torch.sqrt( |
|
torch.tensor(vision_hidden_size, dtype=self.dtype) |
|
) |
|
self.vision_query = nn.Parameter( |
|
torch.randn((num_query_group, vision_hidden_size), dtype=self.dtype) |
|
* vision_embed_std |
|
) |
|
|
|
embed_std = 1 / torch.sqrt( |
|
torch.tensor(self.config.hidden_size, dtype=self.dtype) |
|
) |
|
self.image_newline = nn.Parameter( |
|
torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std |
|
) |
|
|
|
else: |
|
self.config.mm_hidden_size = sum( |
|
[ |
|
vision_tower_aux.hidden_size |
|
for vision_tower_aux in vision_tower_aux_list |
|
] |
|
) |
|
self.mm_projector = build_vision_projector(self.config) |
|
embed_std = 1 / torch.sqrt( |
|
torch.tensor(self.config.hidden_size, dtype=self.dtype) |
|
) |
|
self.image_newline = nn.Parameter( |
|
torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std |
|
) |
|
else: |
|
|
|
for p in self.mm_projector.parameters(): |
|
p.requires_grad = True |
|
|
|
if pretrain_mm_mlp_adapter is not None: |
|
mm_projector_weights = torch.load( |
|
pretrain_mm_mlp_adapter, map_location="cpu" |
|
) |
|
|
|
def get_w(weights, keyword): |
|
return { |
|
k.split(keyword + ".")[1]: v |
|
for k, v in weights.items() |
|
if keyword + "." in k |
|
} |
|
|
|
self.mm_projector.load_state_dict( |
|
get_w(mm_projector_weights, "mm_projector"), strict=True |
|
) |
|
|
|
if self.config.mm_projector_type == "sva": |
|
for aux_i in range(len(vision_tower_aux_list)): |
|
getattr(self, "mm_projector_aux_{}".format(aux_i)).load_state_dict( |
|
get_w( |
|
mm_projector_weights, "mm_projector_aux_{}".format(aux_i) |
|
), |
|
strict=True, |
|
) |
|
|
|
for query_group_i in range(num_query_group): |
|
getattr( |
|
self, "vision_sampler_{}".format(query_group_i) |
|
).load_state_dict( |
|
get_w( |
|
mm_projector_weights, |
|
"vision_sampler_{}".format(query_group_i), |
|
), |
|
strict=True, |
|
) |
|
|
|
if not connector_only: |
|
self.vision_sampler_layers.load_state_dict( |
|
get_w(mm_projector_weights, "vision_sampler_layers"), |
|
strict=True, |
|
) |
|
self.vision_query.data = mm_projector_weights["model.vision_query"] |
|
self.image_newline.data = mm_projector_weights["model.image_newline"] |
|
|
|
|
|
def unmask_attention_mask(mask, original_size): |
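    """
    Zero out the positions of `mask` that correspond to the letterbox padding
    added when the image was resized to a square, given `original_size` as
    (width, height). Height and width are expected at dims 1 and 2 of `mask`.
    """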
|
original_w, original_h = original_size |
|
cur_h, cur_w = mask.shape[1:3] |
|
|
|
original_aspect_ratio = original_w / original_h |
|
current_aspect_ratio = cur_w / cur_h |
|
|
|
if original_aspect_ratio > current_aspect_ratio: |
|
scale_factor = cur_w / original_w |
|
new_height = int(original_h * scale_factor) |
|
padding = (cur_h - new_height) // 2 |
|
if padding > 0: |
|
mask[:, :padding, :] = 0 |
|
mask[:, -padding:, :] = 0 |
|
return mask |
|
else: |
|
scale_factor = cur_h / original_h |
|
new_width = int(original_w * scale_factor) |
|
padding = (cur_w - new_width) // 2 |
|
if padding > 0: |
|
mask[:, :, :padding] = 0 |
|
mask[:, :, -padding:] = 0 |
|
return mask |
|
|
|
|
|
def unpad_image(tensor, original_size): |
|
""" |
|
Unpads a PyTorch tensor of a padded and resized image. |
|
|
|
Args: |
|
        tensor (torch.Tensor): The image tensor; height and width are expected at dims 1 and 2 (e.g. CxHxW).

        original_size (tuple): The original size of the image (width, height).
|
|
|
Returns: |
|
torch.Tensor: The unpadded image tensor. |
|
""" |
|
original_width, original_height = original_size |
|
current_height, current_width = tensor.shape[1:3] |
|
|
|
original_aspect_ratio = original_width / original_height |
|
current_aspect_ratio = current_width / current_height |
|
|
|
if original_aspect_ratio > current_aspect_ratio: |
|
scale_factor = current_width / original_width |
|
new_height = int(original_height * scale_factor) |
|
padding = (current_height - new_height) // 2 |
|
unpadded_tensor = tensor[:, padding : current_height - padding, :] |
|
|
|
|
|
else: |
|
scale_factor = current_height / original_height |
|
new_width = int(original_width * scale_factor) |
|
padding = (current_width - new_width) // 2 |
|
unpadded_tensor = tensor[:, :, padding : current_width - padding] |
|
|
|
|
|
|
|
return unpadded_tensor |
|
|
|
|
|
class CambrianMetaForCausalLM(ABC): |
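    """
    Abstract mixin that provides the multimodal plumbing for the concrete
    Cambrian causal-LM classes: encoding images with the auxiliary towers,
    selecting video frames, and splicing visual tokens into the text embedding
    sequence. Subclasses must implement `get_model()`.
    """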
|
|
|
@abstractmethod |
|
def get_model(self): |
|
pass |
|
|
|
def get_vision_tower_aux_list(self): |
|
return self.get_model().get_vision_tower_aux_list() |
|
|
|
def rearrange_vision_tower_features_train( |
|
self, |
|
vision_tower_aux_feature_list, |
|
vision_tower_aux_attention_masks_list, |
|
query_side_len, |
|
): |
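        """
        Rearrange each auxiliary tower's flattened (B, N, C) token grid into local
        windows so that every one of the `query_side_len`**2 query positions sees
        its own `reduce_factor`**2 neighborhood: features come back with shape
        (B * query_side_len**2, reduce_factor**2, C), with matching attention masks.
        """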
|
vision_tower_aux_feature_rearranged_list = [] |
|
vision_tower_aux_attention_masks_rearranged_list = [] |
|
bs = vision_tower_aux_feature_list[0].shape[0] |
|
for vision_tower_aux_feature, vision_tower_aux_attention_masks in zip( |
|
vision_tower_aux_feature_list, vision_tower_aux_attention_masks_list |
|
): |
|
aux_height = aux_width = int(vision_tower_aux_feature.shape[1] ** 0.5) |
|
assert (aux_height // query_side_len) * query_side_len == aux_height |
|
|
|
reduce_factor = aux_height // query_side_len |
|
vision_tower_aux_feature_rearranged = vision_tower_aux_feature.view( |
|
bs, query_side_len, reduce_factor, query_side_len, reduce_factor, -1 |
|
) |
|
vision_tower_aux_feature_rearranged = ( |
|
vision_tower_aux_feature_rearranged.permute(0, 1, 3, 2, 4, 5) |
|
.contiguous() |
|
.flatten(0, 2) |
|
.flatten(1, 2) |
|
) |
|
|
|
vision_tower_aux_attention_masks_rearranged = ( |
|
vision_tower_aux_attention_masks.view( |
|
bs * query_side_len * query_side_len, reduce_factor * reduce_factor |
|
) |
|
) |
|
|
|
vision_tower_aux_feature_rearranged_list.append( |
|
vision_tower_aux_feature_rearranged |
|
) |
|
vision_tower_aux_attention_masks_rearranged_list.append( |
|
vision_tower_aux_attention_masks_rearranged |
|
) |
|
return ( |
|
vision_tower_aux_feature_rearranged_list, |
|
vision_tower_aux_attention_masks_rearranged_list, |
|
) |
|
|
|
def rearrange_vision_tower_features_inference( |
|
self, vision_tower_aux_feature_list, query_side_len, image_sizes, unpad=False |
|
): |
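        """
        Same windowing as the training variant, but done per sample so that the
        attention masks can exclude (and, with `unpad=True`, the features can drop)
        the regions introduced by letterbox padding, based on each image's original
        (width, height). Rows that end up fully masked are reset to all-True.
        """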
|
vision_tower_aux_feature_rearranged_list = [] |
|
vision_tower_aux_attention_masks_rearranged_list = [] |
|
bs = vision_tower_aux_feature_list[0].shape[0] |
|
for vision_tower_aux_feature in vision_tower_aux_feature_list: |
|
aux_height = aux_width = int(vision_tower_aux_feature.shape[1] ** 0.5) |
|
assert (aux_height // query_side_len) * query_side_len == aux_height |
|
|
|
reduce_factor = aux_height // query_side_len |
|
|
|
vision_tower_aux_feature_rearranged = [] |
|
vision_tower_aux_attention_masks_rearranged = [] |
|
for batch_i in range(bs): |
|
image_size = image_sizes[batch_i] |
|
cur_vision_tower_aux_feature = vision_tower_aux_feature[batch_i] |
|
|
|
cur_vision_tower_aux_attention_masks_rearranged = torch.ones( |
|
(1, aux_height, aux_width), |
|
dtype=torch.bool, |
|
device=cur_vision_tower_aux_feature.device, |
|
) |
|
cur_vision_tower_aux_feature_rearranged = ( |
|
cur_vision_tower_aux_feature.view( |
|
1, |
|
query_side_len, |
|
reduce_factor, |
|
query_side_len, |
|
reduce_factor, |
|
-1, |
|
) |
|
) |
|
cur_vision_tower_aux_feature_rearranged = ( |
|
cur_vision_tower_aux_feature_rearranged.permute( |
|
0, 1, 3, 2, 4, 5 |
|
).contiguous() |
|
) |
|
if unpad: |
|
cur_vision_tower_aux_feature_rearranged = unpad_image( |
|
cur_vision_tower_aux_feature_rearranged, image_size |
|
) |
|
cur_vision_tower_aux_feature_rearranged = ( |
|
cur_vision_tower_aux_feature_rearranged.flatten(0, 2).flatten(1, 2) |
|
) |
|
|
|
cur_vision_tower_aux_attention_masks_rearranged = unmask_attention_mask( |
|
cur_vision_tower_aux_attention_masks_rearranged, image_size |
|
) |
|
cur_vision_tower_aux_attention_masks_rearranged = ( |
|
cur_vision_tower_aux_attention_masks_rearranged.view( |
|
1, query_side_len, reduce_factor, query_side_len, reduce_factor |
|
) |
|
.permute(0, 1, 3, 2, 4) |
|
.contiguous() |
|
) |
|
if unpad: |
|
cur_vision_tower_aux_attention_masks_rearranged = unpad_image( |
|
cur_vision_tower_aux_attention_masks_rearranged, image_size |
|
) |
|
cur_vision_tower_aux_attention_masks_rearranged = ( |
|
cur_vision_tower_aux_attention_masks_rearranged.flatten( |
|
0, 2 |
|
).flatten(1, 2) |
|
) |
|
|
|
cur_vision_tower_aux_attention_masks_rearranged[ |
|
cur_vision_tower_aux_attention_masks_rearranged.sum(-1) == 0 |
|
] = True |
|
|
|
vision_tower_aux_feature_rearranged.append( |
|
cur_vision_tower_aux_feature_rearranged |
|
) |
|
vision_tower_aux_attention_masks_rearranged.append( |
|
cur_vision_tower_aux_attention_masks_rearranged |
|
) |
|
|
|
vision_tower_aux_feature_rearranged = torch.cat( |
|
vision_tower_aux_feature_rearranged, 0 |
|
) |
|
vision_tower_aux_attention_masks_rearranged = torch.cat( |
|
vision_tower_aux_attention_masks_rearranged, 0 |
|
) |
|
|
|
vision_tower_aux_feature_rearranged_list.append( |
|
vision_tower_aux_feature_rearranged |
|
) |
|
vision_tower_aux_attention_masks_rearranged_list.append( |
|
vision_tower_aux_attention_masks_rearranged |
|
) |
|
|
|
return ( |
|
vision_tower_aux_feature_rearranged_list, |
|
vision_tower_aux_attention_masks_rearranged_list, |
|
) |
|
|
|
def encode_images(self, image_aux_list, encode_type=None): |
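        """
        Run the images through the auxiliary vision towers in chunks of 64 to bound
        peak memory. `encode_type` selects a single tower: "siglip" uses the first
        tower in the list and "dino" the last one (the list is assumed to be ordered
        [SigLIP, DINO]); otherwise every tower is applied to its own input.
        """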
|
vision_tower_aux_list = self.get_model().get_vision_tower_aux_list() |
|
image_aux_features_list = [] |
|
chunk_size = 64 |
|
if encode_type == "dino": |
|
image_aux = image_aux_list[-1] |
|
vision_tower_aux = vision_tower_aux_list[-1] |
|
if image_aux.shape[0] > chunk_size: |
|
image_aux_features_chunks = [] |
|
for start_idx in range(0, image_aux.shape[0], chunk_size): |
|
end_idx = min(start_idx + chunk_size, image_aux.shape[0]) |
|
chunk = image_aux[start_idx:end_idx] |
|
image_aux_features_chunk = vision_tower_aux(chunk) |
|
image_aux_features_chunks.append(image_aux_features_chunk) |
|
image_aux_features = torch.cat(image_aux_features_chunks, dim=0) |
|
else: |
|
image_aux_features = vision_tower_aux(image_aux) |
|
return image_aux_features |
|
elif encode_type == "siglip": |
|
image_aux = image_aux_list[0] |
|
vision_tower_aux = vision_tower_aux_list[0] |
|
if image_aux.shape[0] > chunk_size: |
|
image_aux_features_chunks = [] |
|
for start_idx in range(0, image_aux.shape[0], chunk_size): |
|
end_idx = min(start_idx + chunk_size, image_aux.shape[0]) |
|
chunk = image_aux[start_idx:end_idx] |
|
image_aux_features_chunk = vision_tower_aux(chunk) |
|
image_aux_features_chunks.append(image_aux_features_chunk) |
|
image_aux_features = torch.cat(image_aux_features_chunks, dim=0) |
|
else: |
|
image_aux_features = vision_tower_aux(image_aux) |
|
return image_aux_features |
|
else: |
|
for image_aux, vision_tower_aux in zip( |
|
image_aux_list, vision_tower_aux_list |
|
): |
|
if image_aux.shape[0] > chunk_size: |
|
image_aux_features_chunks = [] |
|
for start_idx in range(0, image_aux.shape[0], chunk_size): |
|
end_idx = min(start_idx + chunk_size, image_aux.shape[0]) |
|
chunk = image_aux[start_idx:end_idx] |
|
image_aux_features_chunk = vision_tower_aux(chunk) |
|
image_aux_features_chunks.append(image_aux_features_chunk) |
|
image_aux_features = torch.cat(image_aux_features_chunks, dim=0) |
|
else: |
|
image_aux_features = vision_tower_aux(image_aux) |
|
image_aux_features_list.append(image_aux_features) |
|
return image_aux_features_list |
|
|
|
def select_frame( |
|
self, |
|
feature_list, |
|
split_sizes, |
|
input_ids, |
|
new_image_aux_list, |
|
image_sizes, |
|
window_size=16, |
|
threshold=0.83, |
|
): |
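        """
        Select video frames per sample using the DINO features: estimate how many
        frames fit in the remaining token budget (tokenizer_model_max_length minus
        the text length, located via a tokenizer-specific sentinel id), and if there
        are too many, keep only frames whose mean similarity to the other frames in
        a `window_size` window is below `threshold`, capped at 400 frames by uniform
        subsampling. Returns the kept features, the new per-video frame counts, the
        filtered image tensors for both towers, and the kept frame indices.
        """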
|
dino_features_batch = torch.split(feature_list, split_sizes, dim=0) |
|
new_image_aux_batch_0 = torch.split(new_image_aux_list[0], split_sizes, dim=0) |
|
new_image_aux_batch_1 = torch.split(new_image_aux_list[1], split_sizes, dim=0) |
|
new_split_sizes = [] |
|
selected_frames_all_0 = [] |
|
selected_frames_all_1 = [] |
|
selected_frames_feature_all = [] |
|
selected_frame_indices_all = [] |
|
for i_batch, frame_features in enumerate(dino_features_batch): |
|
try: |
|
if "llama" in self.get_model().config.model_type: |
|
text_len = torch.where(input_ids[i_batch] == 128002)[-1][0] |
|
else: |
|
text_len = torch.where(input_ids[i_batch] == 151643)[-1][0] |
|
except: |
|
text_len = len(input_ids[i_batch]) |
|
original_width, original_height = image_sizes[i_batch] |
|
if getattr(self.get_model().config, "highres", False): |
|
token_per_frame = self.get_model().config.lowres_token ** 2 |
|
else: |
|
token_per_frame = self.get_model().config.image_token_len |
|
|
|
max_num_frames = max( |
|
1, |
|
( |
|
self.get_model().config.tokenizer_model_max_length |
|
- text_len |
|
- getattr(self.get_model().config, "inference_max_length", 16) |
|
) |
|
// token_per_frame, |
|
) |
|
if len(frame_features) < max_num_frames: |
|
selected_frames_all_0.append(new_image_aux_batch_0[i_batch]) |
|
selected_frames_all_1.append(new_image_aux_batch_1[i_batch]) |
|
selected_frames_feature_all.append(frame_features) |
|
new_split_sizes.append(len(frame_features)) |
|
selected_frame_indices_all.append(torch.arange(len(frame_features))) |
|
continue |
|
|
|
num_segments = len(frame_features) // window_size |
|
if num_segments == 0: |
|
query_feature = frame_features.flatten(1, 2) |
|
query_feature = query_feature / torch.norm( |
|
(query_feature), dim=1, keepdim=True |
|
) |
|
similarities = torch.mean(query_feature @ query_feature.T, dim=1) |
|
similarities[len(frame_features) // 2] = 0 |
|
indices = torch.where(similarities < threshold)[0] |
|
selected_frame_indices_all.append(indices) |
|
selected_frames_all_0.append(new_image_aux_batch_0[i_batch][indices]) |
|
selected_frames_all_1.append(new_image_aux_batch_1[i_batch][indices]) |
|
selected_frames_feature_all.append(frame_features[indices]) |
|
new_split_sizes.append(len(indices)) |
|
continue |
|
segments_frames_0 = [] |
|
segments_frames_1 = [] |
|
segments_features = [] |
|
for start_idx in range(0, len(frame_features), window_size): |
|
end_idx = min(start_idx + window_size, len(frame_features)) |
|
segments_frames_0.append( |
|
new_image_aux_batch_0[i_batch][start_idx:end_idx] |
|
) |
|
segments_frames_1.append( |
|
new_image_aux_batch_1[i_batch][start_idx:end_idx] |
|
) |
|
segments_features.append(frame_features[start_idx:end_idx]) |
|
selected_frames_0 = [] |
|
selected_frames_1 = [] |
|
selected_features = [] |
|
selected_frame_indices = [] |
|
for i, segment in enumerate(segments_features): |
|
query_feature = segment.flatten(1, 2) |
|
query_feature = query_feature / torch.norm( |
|
(query_feature), dim=1, keepdim=True |
|
) |
|
similarities = torch.mean(query_feature @ query_feature.T, dim=1) |
|
similarities[len(segment) // 2] = 0 |
|
indices = torch.where(similarities < threshold)[0] |
|
selected_frames_0.append(segments_frames_0[i][indices]) |
|
selected_frames_1.append(segments_frames_1[i][indices]) |
|
selected_features.append(segment[indices]) |
|
selected_frame_indices.extend(indices + i * window_size) |
|
selected_frames_0 = torch.cat(selected_frames_0, dim=0) |
|
selected_frames_1 = torch.cat(selected_frames_1, dim=0) |
|
selected_features = torch.cat(selected_features, dim=0) |
|
selected_frame_indices = torch.tensor(selected_frame_indices) |
|
|
|
max_num_frames = 400 |
|
if len(selected_frames_0) > max_num_frames: |
|
interval = len(selected_frames_0) / float(max_num_frames) |
|
indices = [int(interval * i) for i in range(max_num_frames)] |
|
new_split_sizes.append(len(indices)) |
|
selected_frames_all_0.append(selected_frames_0[indices]) |
|
selected_frames_all_1.append(selected_frames_1[indices]) |
|
selected_frames_feature_all.append(selected_features[indices]) |
|
selected_frame_indices = selected_frame_indices[indices] |
|
else: |
|
new_split_sizes.append(len(selected_frames_0)) |
|
selected_frames_all_0.append(selected_frames_0) |
|
selected_frames_all_1.append(selected_frames_1) |
|
selected_frames_feature_all.append(selected_features) |
|
selected_frame_indices_all.append(selected_frame_indices) |
|
selected_frames_all_0 = torch.cat(selected_frames_all_0, dim=0) |
|
selected_frames_all_1 = torch.cat(selected_frames_all_1, dim=0) |
|
selected_frames_feature_all = torch.cat(selected_frames_feature_all, dim=0) |
|
return ( |
|
selected_frames_feature_all, |
|
new_split_sizes, |
|
[selected_frames_all_0, selected_frames_all_1], |
|
selected_frame_indices_all, |
|
) |
|
|
|
def prepare_inputs_labels_for_multimodal( |
|
self, |
|
input_ids, |
|
position_ids, |
|
attention_mask, |
|
past_key_values, |
|
labels, |
|
images, |
|
image_aux_attention_masks_list=None, |
|
image_sizes=None, |
|
): |
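        """
        Replace IMAGE_TOKEN_INDEX placeholders in `input_ids` with projected visual
        features and return padded input embeddings, labels, attention masks and
        position ids, together with the auxiliary features needed by the SVA layers.
        For videos this also performs frame selection, optional mixed-resolution
        features, and token-budget reduction.
        """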
|
|
|
vision_tower_aux_list = self.get_model().get_vision_tower_aux_list() |
|
if vision_tower_aux_list is None or images is None or input_ids.shape[1] == 1: |
|
return ( |
|
input_ids, |
|
position_ids, |
|
attention_mask, |
|
past_key_values, |
|
None, |
|
labels, |
|
None, |
|
None, |
|
None, |
|
None, |
|
) |
|
|
|
image_aux_list = images |
|
|
|
split_sizes = None |
|
|
|
if type(image_aux_list[0]) is list or image_aux_list[0].ndim == 5: |
|
split_sizes_ori = [ |
|
1 if image.ndim == 3 else image.shape[0] for image in image_aux_list[0] |
|
] |
|
new_image_aux_list = [] |
|
for image_aux in image_aux_list: |
|
if type(image_aux) is list: |
|
image_aux = [ |
|
x.unsqueeze(0) if x.ndim == 3 else x for x in image_aux |
|
] |
|
concat_image_aux = torch.cat([image for image in image_aux], dim=0) |
|
new_image_aux_list.append(concat_image_aux) |
|
image_aux_features_dino = self.encode_images( |
|
new_image_aux_list, encode_type="dino" |
|
) |
|
|
|
( |
|
image_aux_features_dino, |
|
split_sizes, |
|
new_image_aux_list, |
|
selected_frame_indices_all, |
|
) = self.select_frame( |
|
image_aux_features_dino, |
|
split_sizes_ori, |
|
input_ids, |
|
new_image_aux_list, |
|
image_sizes, |
|
threshold=getattr(self.get_model().config, "dino_threshold", 0.83), |
|
) |
|
|
|
image_aux_features_siglip = self.encode_images( |
|
new_image_aux_list, encode_type="siglip" |
|
) |
|
image_aux_features_list = [ |
|
image_aux_features_siglip, |
|
image_aux_features_dino, |
|
] |
|
|
|
bs = image_aux_features_list[0].shape[0] |
|
dtype = new_image_aux_list[0].dtype |
|
|
|
frame_sizes = [] |
|
for i in range(len(image_sizes)): |
|
for j in range(split_sizes[i]): |
|
frame_sizes.append(image_sizes[i]) |
|
image_sizes = frame_sizes |
|
else: |
|
image_aux_features_list = self.encode_images(image_aux_list) |
|
bs = image_aux_list[0].shape[0] |
|
dtype = image_aux_list[0].dtype |
|
|
|
image_token_len = self.get_model().config.image_token_len |
|
query_num_list = self.get_model().config.query_num_list |
|
|
|
final_height = final_width = int(image_token_len**0.5) |
|
|
|
final_image_features_list = [] |
|
final_image_features_down_list = [] |
|
|
|
|
|
vision_tower_aux_feature_list_final = None |
|
vision_tower_aux_attention_masks_list_final = None |
|
global_context_feature_final = None |
|
|
|
if self.get_model().config.mm_projector_type == "sva": |
|
vision_tower_aux_feature_list = [] |
|
vision_tower_aux_attention_masks_list = [] |
|
|
|
for aux_i in range(len(vision_tower_aux_list)): |
|
image_aux_features = image_aux_features_list[aux_i] |
|
|
|
image_aux_features = getattr( |
|
self.get_model(), "mm_projector_aux_{}".format(aux_i) |
|
)(image_aux_features).to(dtype) |
|
if aux_i == 0: |
|
global_context_feature = image_aux_features.mean(1).view( |
|
bs, 1, 1, -1 |
|
) |
|
|
|
vision_tower_aux_feature_list.append(image_aux_features) |
|
input_mix_res = True |
|
input_high_res = True |
|
|
|
for query_group_i, query_num in enumerate(query_num_list): |
|
query_features_i = ( |
|
self.get_model() |
|
.vision_query[query_group_i, :] |
|
.view(1, 1, 1, -1) |
|
.expand(bs, query_num, -1, -1) |
|
) |
|
global_context_feature_i = global_context_feature.expand( |
|
-1, query_num, 1, -1 |
|
).flatten(0, 1) |
|
query_side_len = int(query_num**0.5) |
|
if IS_XLA_AVAILABLE: |
|
( |
|
vision_tower_aux_feature_list_i, |
|
vision_tower_aux_attention_masks_list_i, |
|
) = self.rearrange_vision_tower_features_train( |
|
vision_tower_aux_feature_list, |
|
image_aux_attention_masks_list, |
|
query_side_len, |
|
) |
|
else: |
|
( |
|
vision_tower_aux_feature_list_i, |
|
vision_tower_aux_attention_masks_list_i, |
|
) = self.rearrange_vision_tower_features_inference( |
|
vision_tower_aux_feature_list, query_side_len, image_sizes |
|
) |
|
|
|
query_features_i = getattr( |
|
self.get_model(), "vision_sampler_{}".format(query_group_i) |
|
)( |
|
query_features_i.flatten(0, 1), |
|
global_context_feature_i, |
|
*vision_tower_aux_feature_list_i, |
|
*vision_tower_aux_attention_masks_list_i, |
|
) |
|
query_features_i = query_features_i.view(bs, query_num, -1) |
|
|
|
if split_sizes is not None: |
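                    # Video input: compare the number of kept frames against the
                    # token budget left after the text, both at full resolution and
                    # at lowres_token resolution, to decide between high-res,
                    # mixed-res, and 8x8 features.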
|
try: |
|
if "llama" in self.get_model().config.model_type: |
|
text_len = torch.where(input_ids[0] == 128002)[-1][0] |
|
else: |
|
text_len = torch.where(input_ids[0] == 151643)[-1][0] |
|
except: |
|
text_len = len(input_ids[0]) |
|
max_visual_len = ( |
|
self.get_model().config.tokenizer_model_max_length |
|
- text_len |
|
- getattr(self.get_model().config, "inference_max_length", 16) |
|
) |
|
max_num_frames = max( |
|
1, |
|
math.floor(max_visual_len // (final_height * final_width)), |
|
) |
|
max_num_frames_low = max( |
|
1, |
|
math.floor( |
|
max_visual_len |
|
// (self.get_model().config.lowres_token ** 2) |
|
), |
|
) |
|
if split_sizes[0] < max_num_frames: |
|
input_mix_res = False |
|
elif split_sizes[0] > max_num_frames_low: |
|
input_mix_res = False |
|
input_high_res = False |
|
|
|
|
|
|
|
if (getattr(self.config, "highres", False)) and input_mix_res: |
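                    # Keep an extra lowres_token x lowres_token copy of the query
                    # features for frames that will stay at low resolution.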
|
_query_features_i = ( |
|
query_features_i.permute(0, 2, 1) |
|
.contiguous() |
|
.view(bs, -1, query_side_len, query_side_len) |
|
) |
|
_query_features_i = F.interpolate( |
|
_query_features_i.float(), |
|
size=( |
|
self.get_model().config.lowres_token, |
|
self.get_model().config.lowres_token, |
|
), |
|
mode="bilinear", |
|
align_corners=False, |
|
).to(dtype=query_features_i.dtype) |
|
_query_features_i = ( |
|
_query_features_i.permute(0, 2, 3, 1).contiguous().flatten(1, 2) |
|
) |
|
final_image_features_down_list.append(_query_features_i) |
|
|
|
|
|
if query_side_len != final_height: |
|
query_features_i = ( |
|
query_features_i.permute(0, 2, 1) |
|
.contiguous() |
|
.view(bs, -1, query_side_len, query_side_len) |
|
) |
|
if input_high_res: |
|
query_features_i = F.interpolate( |
|
query_features_i.float(), |
|
size=(final_height, final_width), |
|
mode="bilinear", |
|
align_corners=False, |
|
).to(dtype=query_features_i.dtype) |
|
else: |
|
query_features_i = F.interpolate( |
|
query_features_i.float(), |
|
size=(8, 8), |
|
mode="bilinear", |
|
align_corners=False, |
|
).to(dtype=query_features_i.dtype) |
|
query_features_i = ( |
|
query_features_i.permute(0, 2, 3, 1).contiguous().flatten(1, 2) |
|
) |
|
final_image_features_list.append(query_features_i) |
|
|
|
if IS_XLA_AVAILABLE: |
|
( |
|
vision_tower_aux_feature_list_final, |
|
vision_tower_aux_attention_masks_list_final, |
|
) = self.rearrange_vision_tower_features_train( |
|
vision_tower_aux_feature_list, |
|
image_aux_attention_masks_list, |
|
final_height, |
|
) |
|
global_context_feature_final = global_context_feature.expand( |
|
-1, final_height * final_width, 1, -1 |
|
).flatten(0, 1) |
|
else: |
|
final_image_features_list = image_aux_features_list |
|
|
|
image_features = torch.cat(final_image_features_list, -1) |
|
image_features = self.get_model().mm_projector(image_features).to(dtype) |
|
|
|
if (getattr(self.config, "highres", False)) and input_mix_res: |
|
image_features_down = torch.cat(final_image_features_down_list, -1) |
|
image_features_down = ( |
|
self.get_model().mm_projector(image_features_down).to(dtype) |
|
) |
|
|
|
if IS_XLA_AVAILABLE: |
|
image_features = image_features.view( |
|
image_features.shape[0], final_height, final_width, -1 |
|
) |
|
image_features = torch.cat( |
|
( |
|
image_features, |
|
self.model.image_newline[None, None, None, :].expand( |
|
image_features.shape[0], final_height, 1, -1 |
|
), |
|
), |
|
dim=2, |
|
) |
|
image_features = image_features.flatten(1, 2) |
|
final_size = [(final_height, final_width)] * bs |
|
|
|
else: |
|
image_features = image_features.view(bs, final_height, final_width, -1) |
|
if (getattr(self.config, "highres", False)) and input_mix_res: |
|
image_features_down = image_features_down.view( |
|
bs, |
|
self.get_model().config.lowres_token, |
|
self.get_model().config.lowres_token, |
|
-1, |
|
) |
|
image_features_unpadded = [] |
|
image_features_downsample = [] |
|
final_size = [] |
|
if self.get_model().config.mm_projector_type == "sva": |
|
( |
|
vision_tower_aux_feature_list_final, |
|
vision_tower_aux_attention_masks_list_final, |
|
) = self.rearrange_vision_tower_features_inference( |
|
vision_tower_aux_feature_list, final_height, image_sizes, unpad=True |
|
) |
|
global_context_feature_final = [] |
|
for batch_i in range(bs): |
|
cur_image_feature = image_features[batch_i] |
|
image_size = image_sizes[batch_i] |
|
|
|
cur_image_feature = unpad_image( |
|
cur_image_feature.unsqueeze(0), image_size |
|
) |
|
|
|
cur_h, cur_w = cur_image_feature.shape[1:3] |
|
try: |
|
cur_image_feature = cur_image_feature.view(1, cur_h, cur_w, -1) |
|
final_size.append((cur_h, cur_w)) |
|
except: |
|
|
|
cur_image_feature = image_features[batch_i].unsqueeze(0) |
|
image_size = image_sizes[batch_i] |
|
cur_h, cur_w = cur_image_feature.shape[1:3] |
|
cur_image_feature = cur_image_feature.view(1, cur_h, cur_w, -1) |
|
final_size.append((cur_h, cur_w)) |
|
|
|
if (getattr(self.config, "highres", False)) and input_mix_res: |
|
cur_image_feature_down = unpad_image( |
|
image_features_down[batch_i].unsqueeze(0), |
|
( |
|
int( |
|
image_size[0] |
|
/ ( |
|
image_token_len**0.5 |
|
/ self.get_model().config.lowres_token |
|
) |
|
), |
|
int( |
|
image_size[1] |
|
/ ( |
|
image_token_len**0.5 |
|
/ self.get_model().config.lowres_token |
|
) |
|
), |
|
), |
|
) |
|
_cur_h, _cur_w = cur_image_feature_down.shape[1:3] |
|
|
|
try: |
|
cur_image_feature_down = cur_image_feature_down.view( |
|
1, _cur_h, _cur_w, -1 |
|
) |
|
except: |
|
print("invalid after unpad", flush=True) |
|
cur_image_feature_down = image_features_down[batch_i].unsqueeze( |
|
0 |
|
) |
|
_cur_h, _cur_w = cur_image_feature_down.shape[1:3] |
|
cur_image_feature_down = cur_image_feature_down.view( |
|
1, _cur_h, _cur_w, -1 |
|
) |
|
|
|
cur_image_feature_down = torch.cat( |
|
( |
|
cur_image_feature_down, |
|
self.model.image_newline.view(1, 1, 1, -1) |
|
.expand(1, _cur_h, 1, -1) |
|
.to(cur_image_feature_down.device), |
|
), |
|
dim=2, |
|
).flatten(1, 2) |
|
|
|
if split_sizes is None and getattr(self.config, "frame_pos", False): |
|
frame_pos = ( |
|
self.get_model() |
|
.get_frame_pos(torch.arange(1)) |
|
.to(cur_image_feature_down.device) |
|
.to(cur_image_feature_down.dtype) |
|
) |
|
cur_image_feature_down += frame_pos |
|
|
|
image_features_downsample.append(cur_image_feature_down.squeeze(0)) |
|
|
|
cur_image_feature = torch.cat( |
|
( |
|
cur_image_feature, |
|
self.model.image_newline.view(1, 1, 1, -1) |
|
.expand(1, cur_h, 1, -1) |
|
.to(cur_image_feature.device), |
|
), |
|
dim=2, |
|
) |
|
|
|
if split_sizes is None and getattr(self.config, "frame_pos", False): |
|
frame_pos = ( |
|
self.get_model() |
|
.get_frame_pos(torch.arange(1)) |
|
.to(cur_image_feature.device) |
|
.to(cur_image_feature.dtype) |
|
) |
|
cur_image_feature += frame_pos |
|
|
|
cur_image_feature = cur_image_feature.flatten(1, 2) |
|
image_features_unpadded.append(cur_image_feature.squeeze(0)) |
|
|
|
if self.get_model().config.mm_projector_type == "sva": |
|
cur_global_context_feature = global_context_feature[batch_i].expand( |
|
cur_h * cur_w, 1, -1 |
|
) |
|
global_context_feature_final.append(cur_global_context_feature) |
|
if self.get_model().config.mm_projector_type == "sva": |
|
global_context_feature_final = torch.cat( |
|
global_context_feature_final, 0 |
|
) |
|
|
|
if (getattr(self.config, "highres", False)) and input_mix_res: |
|
image_features = image_features_downsample |
|
else: |
|
image_features = image_features_unpadded |
|
|
|
|
|
if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr( |
|
self.config, "mm_use_im_start_end", False |
|
): |
|
raise NotImplementedError |
|
|
|
split_image_features_unpadded = None |
|
frame_split_sizes = None |
|
|
|
if split_sizes is not None: |
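            # Regroup the per-frame features back into one sequence per video,
            # optionally adding sinusoidal frame-position embeddings for the kept
            # frame indices (and keeping the unpadded high-res copies for mixing).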
|
split_image_features = [] |
|
split_image_features_unpadded = ( |
|
[] |
|
if (getattr(self.config, "highres", False)) and input_mix_res |
|
else None |
|
) |
|
start_idx = 0 |
|
for split_batch_idx, split_size in enumerate(split_sizes): |
|
if isinstance(image_features[start_idx : start_idx + split_size], list): |
|
if getattr(self.config, "frame_pos", False): |
|
frame_feature = torch.cat( |
|
image_features[start_idx : start_idx + split_size], dim=0 |
|
).reshape(split_size, -1, image_features[0].shape[-1]) |
|
frame_pos = ( |
|
self.get_model() |
|
.get_frame_pos(selected_frame_indices_all[split_batch_idx]) |
|
.to(frame_feature.device) |
|
.to(frame_feature.dtype) |
|
) |
|
frame_feature += frame_pos |
|
split_image_features.append( |
|
frame_feature.reshape(-1, image_features[0].shape[-1]) |
|
) |
|
else: |
|
split_image_features.append( |
|
torch.cat( |
|
image_features[start_idx : start_idx + split_size], |
|
dim=0, |
|
) |
|
) |
|
if (getattr(self.config, "highres", False)) and input_mix_res: |
|
if getattr(self.config, "frame_pos", False): |
|
frame_feature = torch.cat( |
|
image_features_unpadded[ |
|
start_idx : start_idx + split_size |
|
], |
|
dim=0, |
|
).reshape(split_size, -1, image_features[0].shape[-1]) |
|
frame_pos = ( |
|
self.get_model() |
|
.get_frame_pos( |
|
selected_frame_indices_all[split_batch_idx] |
|
) |
|
.to(frame_feature.device) |
|
.to(frame_feature.dtype) |
|
) |
|
frame_feature += frame_pos |
|
split_image_features_unpadded.append( |
|
frame_feature.reshape(-1, image_features[0].shape[-1]) |
|
) |
|
else: |
|
split_image_features_unpadded.append( |
|
torch.cat( |
|
image_features_unpadded[ |
|
start_idx : start_idx + split_size |
|
], |
|
dim=0, |
|
) |
|
) |
|
else: |
|
if getattr(self.config, "frame_pos", False): |
|
frame_feature = image_features[ |
|
start_idx : start_idx + split_size |
|
].reshape(split_size, -1, image_features[0].shape[-1]) |
|
frame_pos = ( |
|
self.get_model() |
|
.get_frame_pos(selected_frame_indices_all[split_batch_idx]) |
|
.to(frame_feature.device) |
|
.to(frame_feature.dtype) |
|
) |
|
frame_feature += frame_pos |
|
split_image_features.append( |
|
frame_feature.reshape(-1, image_features[0].shape[-1]) |
|
) |
|
else: |
|
split_image_features.append( |
|
image_features[start_idx : start_idx + split_size] |
|
) |
|
if (getattr(self.config, "highres", False)) and input_mix_res: |
|
if getattr(self.config, "frame_pos", False): |
|
frame_feature = image_features_unpadded[ |
|
start_idx : start_idx + split_size |
|
] |
|
frame_pos = ( |
|
self.get_model() |
|
.get_frame_pos( |
|
selected_frame_indices_all[split_batch_idx] |
|
) |
|
.to(frame_feature.device) |
|
.to(frame_feature.dtype) |
|
) |
|
frame_feature += frame_pos |
|
split_image_features_unpadded.append( |
|
frame_feature.reshape(-1, image_features[0].shape[-1]) |
|
) |
|
else: |
|
split_image_features_unpadded.append( |
|
image_features_unpadded[ |
|
start_idx : start_idx + split_size |
|
] |
|
) |
|
start_idx += split_size |
|
image_features = split_image_features |
|
frame_split_sizes = split_sizes |
|
|
|
_labels = labels |
|
_position_ids = position_ids |
|
_attention_mask = attention_mask |
|
if attention_mask is None: |
|
attention_mask = torch.ones_like(input_ids, dtype=torch.bool) |
|
else: |
|
attention_mask = attention_mask.bool() |
|
if position_ids is None: |
|
position_ids = torch.arange( |
|
0, input_ids.shape[1], dtype=torch.long, device=input_ids.device |
|
) |
|
if labels is None: |
|
labels = torch.full_like(input_ids, IGNORE_INDEX) |
|
|
|
|
|
_input_ids = input_ids |
|
|
|
attention_mask = attention_mask | (input_ids == IMAGE_TOKEN_INDEX) |
|
|
|
input_ids = [ |
|
cur_input_ids[cur_attention_mask] |
|
for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask) |
|
] |
|
labels = [ |
|
cur_labels[cur_attention_mask] |
|
for cur_labels, cur_attention_mask in zip(labels, attention_mask) |
|
] |
|
|
|
new_input_embeds = [] |
|
new_labels = [] |
|
image_token_indices_batch = [] |
|
cur_image_idx = 0 |
|
for batch_idx, cur_input_ids in enumerate(input_ids): |
|
num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum() |
|
if num_images == 0: |
|
cur_image_features = image_features[cur_image_idx] |
|
cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids) |
|
cur_input_embeds = torch.cat( |
|
[cur_input_embeds_1, cur_image_features[0:0]], dim=0 |
|
) |
|
new_input_embeds.append(cur_input_embeds) |
|
new_labels.append(labels[batch_idx]) |
|
cur_image_idx += 1 |
|
continue |
|
|
|
image_token_indices = ( |
|
[-1] |
|
+ torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() |
|
+ [cur_input_ids.shape[0]] |
|
) |
|
image_token_indices_batch.append( |
|
torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist()[0] |
|
) |
|
cur_input_ids_noim = [] |
|
cur_labels = labels[batch_idx] |
|
cur_labels_noim = [] |
|
for i in range(len(image_token_indices) - 1): |
|
cur_input_ids_noim.append( |
|
cur_input_ids[ |
|
image_token_indices[i] + 1 : image_token_indices[i + 1] |
|
] |
|
) |
|
cur_labels_noim.append( |
|
cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]] |
|
) |
|
split_sizes = [x.shape[0] for x in cur_labels_noim] |
|
cur_input_embeds = self.get_model().embed_tokens( |
|
torch.cat(cur_input_ids_noim) |
|
) |
|
cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0) |
|
cur_new_input_embeds = [] |
|
cur_new_labels = [] |
|
|
|
text_len = sum([x.shape[0] for x in cur_input_embeds_no_im]) |
|
visual_len = len(image_features[cur_image_idx]) |
|
max_visual_len = ( |
|
self.get_model().config.tokenizer_model_max_length |
|
- getattr(self.get_model().config, "inference_max_length", 16) |
|
- text_len |
|
) |
|
mix_token = False |
|
|
|
|
|
if ( |
|
input_mix_res |
|
and ( |
|
self.get_model().config.image_token_len |
|
> getattr(self.get_model().config, "lowres_token", 8) ** 2 |
|
) |
|
and frame_split_sizes is not None |
|
and getattr(self.config, "highres", False) |
|
): |
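                # Spare token budget: upgrade the frames that score highest against
                # the trailing text embeddings (mean dot product) to their
                # high-resolution (unpadded) counterparts.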
|
if max_visual_len > visual_len: |
|
visual_emb = image_features[cur_image_idx] |
|
text_emb = cur_input_embeds_no_im[-1] |
|
highres_num = math.floor( |
|
(max_visual_len - visual_len) |
|
/ ( |
|
split_image_features_unpadded[cur_image_idx].shape[0] |
|
// frame_split_sizes[cur_image_idx] |
|
- visual_emb.shape[0] // frame_split_sizes[cur_image_idx] |
|
) |
|
) |
|
if highres_num >= 1: |
|
mix_token = True |
|
sim = torch.matmul(visual_emb, text_emb.transpose(0, 1)).mean( |
|
dim=-1 |
|
) |
|
sim_frame = sim.reshape( |
|
frame_split_sizes[cur_image_idx], -1 |
|
).mean(dim=-1) |
|
highres_num = min(highres_num, sim_frame.shape[0]) |
|
top_values, top_indices = torch.topk(sim_frame, highres_num) |
|
if len(top_indices) > 0: |
|
sorted_indices = torch.sort(top_indices)[1] |
|
top_indices = top_indices[sorted_indices] |
|
visual_emb_frame = image_features[cur_image_idx].reshape( |
|
frame_split_sizes[cur_image_idx], |
|
-1, |
|
image_features[cur_image_idx].shape[-1], |
|
) |
|
visual_emb_frame_highres = split_image_features_unpadded[ |
|
cur_image_idx |
|
].reshape( |
|
frame_split_sizes[cur_image_idx], |
|
-1, |
|
split_image_features_unpadded[cur_image_idx].shape[-1], |
|
) |
|
current_point = 0 |
|
mix_visual_emb_frame = [] |
|
for frame_i in range(len(visual_emb_frame)): |
|
if current_point > len(top_indices) - 1: |
|
mix_visual_emb_frame.append( |
|
visual_emb_frame[frame_i] |
|
) |
|
continue |
|
if frame_i == top_indices[current_point]: |
|
mix_visual_emb_frame.append( |
|
visual_emb_frame_highres[frame_i] |
|
) |
|
current_point += 1 |
|
else: |
|
mix_visual_emb_frame.append( |
|
visual_emb_frame[frame_i] |
|
) |
|
image_features[cur_image_idx] = torch.cat( |
|
mix_visual_emb_frame, dim=0 |
|
) |
|
|
|
|
|
if ( |
|
max_visual_len < visual_len |
|
and frame_split_sizes is not None |
|
and not mix_token |
|
): |
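                # Over budget: within chunks of 8 frames, drop tokens whose cosine
                # similarity to the chunk's first frame exceeds drop_threshold,
                # then force-truncate evenly if the sequence is still too long.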
|
visual_emb_frame = image_features[cur_image_idx].reshape( |
|
frame_split_sizes[cur_image_idx], |
|
-1, |
|
image_features[cur_image_idx].shape[-1], |
|
) |
|
|
|
sim = F.cosine_similarity( |
|
visual_emb_frame[:-1], |
|
visual_emb_frame[1:], |
|
dim=-1, |
|
) |
|
|
|
new_visual_emb_frames = [] |
|
for start_idx in range(0, len(visual_emb_frame), 8): |
|
end_idx = min(start_idx + 8, len(visual_emb_frame)) |
|
chunk_feature = visual_emb_frame[start_idx:end_idx] |
|
if len(chunk_feature) == 1: |
|
new_visual_emb_frames.append(chunk_feature[0]) |
|
continue |
|
sim = F.cosine_similarity( |
|
chunk_feature[0] |
|
.unsqueeze(0) |
|
.repeat_interleave(len(chunk_feature[1:]), dim=0), |
|
chunk_feature[1:], |
|
dim=-1, |
|
) |
|
new_visual_emb_frame = torch.cat( |
|
[ |
|
chunk_feature[0], |
|
chunk_feature[1:].flatten(0, 1)[ |
|
sim.flatten(0, 1) |
|
< getattr( |
|
self.get_model().config, "drop_threshold", 0.7 |
|
) |
|
], |
|
], |
|
dim=0, |
|
) |
|
new_visual_emb_frames.append(new_visual_emb_frame) |
|
|
|
reduced_visual_len = sum([x.shape[0] for x in new_visual_emb_frames]) |
|
|
|
if reduced_visual_len > max_visual_len: |
|
force_remove = math.ceil( |
|
(reduced_visual_len - max_visual_len) |
|
/ len(new_visual_emb_frames) |
|
) |
|
for chunk_i in range(len(new_visual_emb_frames)): |
|
new_visual_emb_frames[chunk_i] = new_visual_emb_frames[chunk_i][ |
|
:-force_remove |
|
] |
|
new_visual_emb_frames = torch.cat(new_visual_emb_frames, dim=0) |
|
else: |
|
new_visual_emb_frames = torch.cat(new_visual_emb_frames, dim=0) |
|
|
|
image_features[cur_image_idx] = new_visual_emb_frames[:max_visual_len] |
|
|
|
for i in range(num_images + 1): |
|
cur_new_input_embeds.append(cur_input_embeds_no_im[i]) |
|
cur_new_labels.append(cur_labels_noim[i]) |
|
if i < num_images: |
|
cur_image_features = image_features[cur_image_idx] |
|
cur_image_idx += 1 |
|
cur_new_input_embeds.append(cur_image_features) |
|
cur_new_labels.append( |
|
torch.full( |
|
(cur_image_features.shape[0],), |
|
IGNORE_INDEX, |
|
device=cur_labels.device, |
|
dtype=cur_labels.dtype, |
|
) |
|
) |
|
|
|
cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds] |
|
|
|
cur_new_input_embeds = torch.cat(cur_new_input_embeds) |
|
cur_new_labels = torch.cat(cur_new_labels) |
|
|
|
new_input_embeds.append(cur_new_input_embeds) |
|
new_labels.append(cur_new_labels) |
|
|
|
|
|
tokenizer_model_max_length = getattr( |
|
self.config, "tokenizer_model_max_length", None |
|
) |
|
if tokenizer_model_max_length is not None: |
|
new_input_embeds = [ |
|
x[:tokenizer_model_max_length] for x in new_input_embeds |
|
] |
|
new_labels = [x[:tokenizer_model_max_length] for x in new_labels] |
|
|
|
|
|
max_len = max(x.shape[0] for x in new_input_embeds) |
|
batch_size = len(new_input_embeds) |
|
|
|
new_input_embeds_padded = [] |
|
new_labels_padded = torch.full( |
|
(batch_size, max_len), |
|
IGNORE_INDEX, |
|
dtype=new_labels[0].dtype, |
|
device=new_labels[0].device, |
|
) |
|
attention_mask = torch.zeros( |
|
(batch_size, max_len), |
|
dtype=attention_mask.dtype, |
|
device=attention_mask.device, |
|
) |
|
position_ids = torch.zeros( |
|
(batch_size, max_len), |
|
dtype=position_ids.dtype, |
|
device=position_ids.device, |
|
) |
|
|
|
for i, (cur_new_embed, cur_new_labels) in enumerate( |
|
zip(new_input_embeds, new_labels) |
|
): |
|
cur_len = cur_new_embed.shape[0] |
|
if getattr(self.config, "tokenizer_padding_side", "right") == "left": |
|
new_input_embeds_padded.append( |
|
torch.cat( |
|
( |
|
torch.zeros( |
|
(max_len - cur_len, cur_new_embed.shape[1]), |
|
dtype=cur_new_embed.dtype, |
|
device=cur_new_embed.device, |
|
), |
|
cur_new_embed, |
|
), |
|
dim=0, |
|
) |
|
) |
|
if cur_len > 0: |
|
new_labels_padded[i, -cur_len:] = cur_new_labels |
|
attention_mask[i, -cur_len:] = True |
|
position_ids[i, -cur_len:] = torch.arange( |
|
0, |
|
cur_len, |
|
dtype=position_ids.dtype, |
|
device=position_ids.device, |
|
) |
|
else: |
|
new_input_embeds_padded.append( |
|
torch.cat( |
|
( |
|
cur_new_embed, |
|
torch.zeros( |
|
(max_len - cur_len, cur_new_embed.shape[1]), |
|
dtype=cur_new_embed.dtype, |
|
device=cur_new_embed.device, |
|
), |
|
), |
|
dim=0, |
|
) |
|
) |
|
if cur_len > 0: |
|
new_labels_padded[i, :cur_len] = cur_new_labels |
|
attention_mask[i, :cur_len] = True |
|
position_ids[i, :cur_len] = torch.arange( |
|
0, |
|
cur_len, |
|
dtype=position_ids.dtype, |
|
device=position_ids.device, |
|
) |
|
|
|
new_input_embeds = torch.stack(new_input_embeds_padded, dim=0) |
|
|
|
if _labels is None: |
|
new_labels = None |
|
else: |
|
new_labels = new_labels_padded |
|
|
|
if _attention_mask is None: |
|
attention_mask = None |
|
else: |
|
attention_mask = attention_mask.to(dtype=_attention_mask.dtype) |
|
|
|
if _position_ids is None: |
|
position_ids = None |
|
|
|
return ( |
|
None, |
|
position_ids, |
|
attention_mask, |
|
past_key_values, |
|
new_input_embeds, |
|
new_labels, |
|
vision_tower_aux_feature_list_final, |
|
vision_tower_aux_attention_masks_list_final, |
|
final_size, |
|
global_context_feature_final, |
|
) |
|
|
|
def initialize_vision_tokenizer(self, model_args, tokenizer): |
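        """
        Add the image special tokens requested by `model_args` to the tokenizer,
        resize the embedding matrices, initialize the new rows with the mean of the
        existing embeddings, optionally load them from a pretrained adapter
        checkpoint, and set `requires_grad` according to the adapter-tuning flags.
        """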
|
if model_args.mm_use_im_patch_token: |
|
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) |
|
self.resize_token_embeddings(len(tokenizer)) |
|
|
|
if model_args.mm_use_im_start_end: |
|
num_new_tokens = tokenizer.add_tokens( |
|
[DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True |
|
) |
|
self.resize_token_embeddings(len(tokenizer)) |
|
|
|
if num_new_tokens > 0: |
|
input_embeddings = self.get_input_embeddings().weight.data |
|
output_embeddings = self.get_output_embeddings().weight.data |
|
|
|
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( |
|
dim=0, keepdim=True |
|
) |
|
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( |
|
dim=0, keepdim=True |
|
) |
|
|
|
input_embeddings[-num_new_tokens:] = input_embeddings_avg |
|
output_embeddings[-num_new_tokens:] = output_embeddings_avg |
|
|
|
if model_args.tune_mm_mlp_adapter: |
|
for p in self.get_input_embeddings().parameters(): |
|
p.requires_grad = True |
|
for p in self.get_output_embeddings().parameters(): |
|
p.requires_grad = False |
|
|
|
if model_args.pretrain_mm_mlp_adapter: |
|
mm_projector_weights = torch.load( |
|
model_args.pretrain_mm_mlp_adapter, map_location="cpu" |
|
) |
|
embed_tokens_weight = mm_projector_weights["model.embed_tokens.weight"] |
|
assert num_new_tokens == 2 |
|
if input_embeddings.shape == embed_tokens_weight.shape: |
|
input_embeddings[-num_new_tokens:] = embed_tokens_weight[ |
|
-num_new_tokens: |
|
] |
|
elif embed_tokens_weight.shape[0] == num_new_tokens: |
|
input_embeddings[-num_new_tokens:] = embed_tokens_weight |
|
else: |
|
raise ValueError( |
|
                        f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Number of new tokens: {num_new_tokens}."
|
) |
|
elif model_args.mm_use_im_patch_token: |
|
if model_args.tune_mm_mlp_adapter: |
|
for p in self.get_input_embeddings().parameters(): |
|
p.requires_grad = False |
|
for p in self.get_output_embeddings().parameters(): |
|
p.requires_grad = False |