ClownRat's picture
v1 demo.
000c55e
# Copyright 2024 Alibaba DAMO Academy
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
import einops
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.regnet import RegStage
from timm.models.layers import LayerNorm2d
from transformers import TRANSFORMERS_CACHE
def parse_snapshot_folder(repo_id, cache_dir=None, repo_type="model", revision="main"):
    """Build the path of a repo's snapshot folder inside the local hub cache.

    The returned path has the form
    ``<cache_dir>/<repo_type>s--<repo_id with "/" replaced by "--">/snapshots/<revision>``.
    Nothing is downloaded and the folder is not checked for existence.

    Args:
        repo_id: repo identifier, e.g. ``"org/name"``.
        cache_dir: cache root; ``TRANSFORMERS_CACHE`` is used when None.
        repo_type: repo kind, used as the cache folder prefix (``"model"`` -> ``models--``).
        revision: ref name (or commit id) to look up; defaults to ``"main"``.

    Returns:
        The snapshot folder path for the resolved revision.
    """
    # 1. locate the cache folder of the repo
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    object_id = repo_id.replace("/", "--")
    repo_cache = os.path.join(cache_dir, f"{repo_type}s--{object_id}")

    # 2. resolve refs (for instance to convert main to the associated commit sha)
    revision_file = os.path.join(repo_cache, "refs", revision)
    if os.path.isfile(revision_file):
        with open(revision_file) as f:
            # Strip surrounding whitespace so a trailing newline in the ref
            # file does not end up inside the snapshot path.
            resolved = f.read().strip()
        if resolved:
            revision = resolved

    # 3. build the snapshot folder path
    return os.path.join(repo_cache, "snapshots", revision)
def load_mm_projector(model_path, cache_dir=None, token=None):
    """Load the ``mm_projector.bin`` weights of a model and cast them to float16.

    ``model_path`` is first treated as a local directory. If it does not hold
    ``mm_projector.bin``, it is treated as a hub repo id: the local cache is
    checked and, when the file is missing there, the repo is downloaded.

    Args:
        model_path: local directory or hub repo id.
        cache_dir: cache root forwarded to the cache lookup and the download.
        token: access token forwarded to ``snapshot_download``.

    Returns:
        Dict mapping each key in the checkpoint to its value cast to
        ``torch.float16`` (loaded on CPU).
    """
    weights_name = 'mm_projector.bin'
    if os.path.exists(os.path.join(model_path, weights_name)):
        folder = model_path
    else:
        folder = parse_snapshot_folder(model_path, cache_dir=cache_dir, repo_type="model")
        if not os.path.exists(os.path.join(folder, weights_name)):
            # downloading from remote repo
            from huggingface_hub import snapshot_download
            # Use the path returned by the download: the cache refs may not
            # have existed when `folder` was computed above, in which case it
            # points at `snapshots/main` rather than the real commit folder.
            folder = snapshot_download(repo_id=model_path, cache_dir=cache_dir, token=token)

    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from trusted sources.
    mm_projector_weights = torch.load(os.path.join(folder, weights_name), map_location='cpu')
    return {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
class IdentityMap(nn.Module):
    """Parameter-free projector that hands its input back unchanged."""

    def __init__(self):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        """Return ``x`` itself; any extra arguments are accepted and ignored."""
        return x

    @property
    def config(self):
        """Config dict tagging this projector as the ``identity`` type."""
        projector_config = {"mm_projector_type": 'identity'}
        return projector_config
class SimpleResBlock(nn.Module):
    """LayerNorm followed by a Linear-GELU-Linear MLP with a skip connection.

    The skip path carries the layer-normalised input, not the raw input.
    """

    def __init__(self, channels):
        """Create the norm and the two-layer MLP, all of width ``channels``."""
        super().__init__()
        self.pre_norm = nn.LayerNorm(channels)
        mlp_layers = [
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels),
        ]
        self.proj = nn.Sequential(*mlp_layers)

    def forward(self, x):
        """Return ``norm(x) + proj(norm(x))``."""
        normed = self.pre_norm(x)
        projected = self.proj(normed)
        return normed + projected
def build_vision_projector(config, delay_load=False, **kwargs):
    """Create the vision-to-language projector named by ``config.mm_projector_type``.

    Args:
        config: object providing ``mm_hidden_size`` and ``hidden_size``; its
            ``mm_projector_type`` attribute defaults to ``'linear'`` when absent.
        delay_load: accepted but not used here.
        **kwargs: accepted but not used here.

    Returns:
        The projector module for the requested type.

    Raises:
        ValueError: if the projector type is not recognised.
    """
    projector_type = getattr(config, 'mm_projector_type', 'linear')

    # "mlp<N>x_gelu": a Linear followed by N-1 (GELU, Linear) pairs.
    mlp_spec = re.match(r'^mlp(\d+)x_gelu$', projector_type)
    if mlp_spec is not None:
        num_linear = int(mlp_spec.group(1))
        layers = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
        for _ in range(num_linear - 1):
            layers.extend((nn.GELU(), nn.Linear(config.hidden_size, config.hidden_size)))
        return nn.Sequential(*layers)

    if projector_type == "linear":
        # NOTE: for both linear and mlp2x_gelu projector type, mean pooling is adopted to aggregate video features
        return nn.Linear(config.mm_hidden_size, config.hidden_size)
    if projector_type == "stc_connector":
        return STCConnector(config)
    if projector_type == "stp_connector":
        return STPConnector(config)
    if projector_type == "stc_connector_v35":
        return STCConnectorV35(config)
    if projector_type == "spatial_conv":
        return SpatialConv(config)
    if projector_type == "spatial_pool":
        return SpatialPool(config)
    if projector_type == 'identity':
        return IdentityMap()

    raise ValueError(f'Unknown projector type: {projector_type}')
def build_mlp(depth, hidden_size, output_hidden_size):
    """Stack ``Linear`` layers separated by ``GELU`` activations.

    Args:
        depth: number of Linear layers; values below 2 yield a single Linear.
        hidden_size: input feature size of the first Linear.
        output_hidden_size: output size of every Linear (and input size of
            all but the first).

    Returns:
        ``nn.Sequential`` of the layers.
    """
    layers = [nn.Linear(hidden_size, output_hidden_size)]
    for _ in range(depth - 1):
        layers.extend((nn.GELU(), nn.Linear(output_hidden_size, output_hidden_size)))
    return nn.Sequential(*layers)
class STCConnector(nn.Module):
    """Connector made of per-frame stage, 3D downsampler, per-frame stage, MLP.

    Pipeline: ``s1`` (RegStage, applied per frame) -> ``sampler`` (strided
    Conv3d + SiLU) -> ``s2`` (RegStage, applied per frame) -> ``readout`` (MLP).
    """

    def __init__(self, config, downsample=(2, 2, 2), depth=4, mlp_depth=2):
        """Temporal Convolutional Vision-Language Connector.

        Args:
            config: config object; ``mm_hidden_size`` and ``hidden_size`` are read.
            downsample: (temporal, height, width) downsample rate, used as both
                kernel size and stride of the sampler.
            depth: depth of the spatial interaction blocks; 0 replaces both
                stages with ``nn.Identity``.
            mlp_depth: depth of the vision-language projector layers.
        """
        super().__init__()
        in_width = config.mm_hidden_size
        width = config.hidden_size
        out_width = config.hidden_size
        self.encoder_hidden_size = in_width
        self.hidden_size = width
        self.output_hidden_size = out_width
        # TODO: make these as config arguments
        self.depth = depth
        self.mlp_depth = mlp_depth
        self.downsample = downsample

        # Sub-modules are created in the order s1, sampler, s2, readout.
        self.s1 = self._make_stage(depth, in_width, width)
        conv = nn.Conv3d(
            in_channels=width,
            out_channels=width,
            kernel_size=downsample,
            stride=downsample,
            padding=1,
            bias=True,
        )
        self.sampler = nn.Sequential(conv, nn.SiLU())
        self.s2 = self._make_stage(depth, width, width)
        self.readout = build_mlp(mlp_depth, width, out_width)

    @staticmethod
    def _make_stage(depth, in_chs, out_chs):
        """Return a stride-1 RegStage of ``depth`` blocks, or Identity when ``depth`` is 0."""
        if depth == 0:
            return nn.Identity()
        return RegStage(
            depth=depth,
            in_chs=in_chs,
            out_chs=out_chs,
            stride=1,
            dilation=1,
            act_layer=nn.SiLU,
            norm_layer=LayerNorm2d,
        )

    def forward(self, x):
        """Aggregate tokens on the temporal and spatial dimensions.

        Args:
            x: input tokens [b, t, h, w, d] / [b, t, l, d]; in the 4-D case
                ``l`` is split into a square ``h x w`` grid.

        Returns:
            aggregated tokens [b, l, d]
        """
        num_frames = x.size(1)
        if x.ndim == 4:
            side = int(x.size(2) ** 0.5)
            x = einops.rearrange(x, "b t (h w) d -> b d t h w", h=side, w=side)
        elif x.ndim == 5:
            x = einops.rearrange(x, "b t h w d -> b d t h w")

        # 1. first stage, run on each frame independently
        frames = einops.rearrange(x, "b d t h w -> (b t) d h w")
        frames = self.s1(frames)
        volume = einops.rearrange(frames, "(b t) d h w -> b d t h w", t=num_frames)

        # 2. downsampler over (t, h, w)
        volume = self.sampler(volume)
        reduced_t = volume.size(2)

        # 3. second stage, again per frame, then flatten to a token sequence
        frames = einops.rearrange(volume, "b d t h w -> (b t) d h w")
        frames = self.s2(frames)
        tokens = einops.rearrange(frames, "(b t) d h w -> b (t h w) d", t=reduced_t)
        return self.readout(tokens)
class STPConnector(STCConnector):
    """STCConnector whose sampler is average pooling instead of a convolution."""

    def __init__(self, config, downsample=(2, 2, 2), depth=4, mlp_depth=2):
        """Build the parent connector, then swap in an AvgPool3d + SiLU sampler.

        Args:
            config: config object forwarded to ``STCConnector``.
            downsample: (temporal, height, width) pooling window.
            depth: depth forwarded to ``STCConnector``.
            mlp_depth: MLP depth forwarded to ``STCConnector``.
        """
        super().__init__(
            config=config,
            downsample=downsample,
            depth=depth,
            mlp_depth=mlp_depth,
        )
        # Overrides the convolutional sampler created by the parent constructor.
        pool = nn.AvgPool3d(downsample)
        self.sampler = nn.Sequential(pool, nn.SiLU())
class STCConnectorV35(STCConnector):
    """STCConnector whose Conv3d sampler uses ``padding=0`` instead of ``padding=1``."""

    def __init__(self, config, downsample=(2, 2, 2), depth=4, mlp_depth=2):
        """Build the parent connector, then replace its sampler.

        Args:
            config: config object forwarded to ``STCConnector``.
            downsample: (temporal, height, width) kernel size and stride.
            depth: depth forwarded to ``STCConnector``.
            mlp_depth: MLP depth forwarded to ``STCConnector``.
        """
        super().__init__(
            config=config,
            downsample=downsample,
            depth=depth,
            mlp_depth=mlp_depth,
        )
        # Same convolution as the parent's sampler, but without padding.
        unpadded_conv = nn.Conv3d(
            in_channels=self.hidden_size,
            out_channels=self.hidden_size,
            kernel_size=downsample,
            stride=downsample,
            padding=0,
            bias=True,
        )
        self.sampler = nn.Sequential(unpadded_conv, nn.SiLU())
class SpatialConv(STCConnector):
    """STCConnector preset: ``downsample=(1, 2, 2)`` and ``depth=0`` by default.

    With ``depth=0`` the parent uses ``nn.Identity`` for both stages.
    """

    def __init__(self, config, downsample=(1, 2, 2), depth=0, mlp_depth=2):
        """Forward all arguments unchanged to ``STCConnector``."""
        super().__init__(
            config=config,
            downsample=downsample,
            depth=depth,
            mlp_depth=mlp_depth,
        )
class SpatialPool(STPConnector):
    """STPConnector preset: ``downsample=(1, 2, 2)`` and ``depth=0`` by default."""

    def __init__(self, config, downsample=(1, 2, 2), depth=0, mlp_depth=2):
        """Forward all arguments unchanged to ``STPConnector``."""
        super().__init__(
            config=config,
            downsample=downsample,
            depth=depth,
            mlp_depth=mlp_depth,
        )