videollm-online / models /vision_live.py
chenjoya's picture
Upload 9 files
7d1b5a5 verified
raw
history blame
No virus
3.16 kB
import math, torch
from functools import partial
from torch import nn, Tensor
from torchvision.transforms.functional import normalize
from transformers import AutoModel
from transformers.utils.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from .configuration_live import LiveConfigMixin
def _siglip_vision_encode(
    vision_model: nn.Module,
    frames: Tensor,
    frame_token_cls: bool,
    frame_token_pooled: tuple,
    mean=[0.5, 0.5, 0.5],
    std=[0.5, 0.5, 0.5],
    rescale_factor=0.00392156862745098,
    **kwargs,
) -> Tensor:
    """Encode frames with a SigLIP vision tower into per-frame tokens.

    The frames are multiplied by ``rescale_factor`` (default 1/255),
    normalized with ``mean``/``std`` and passed through ``vision_model``
    under CUDA autocast.

    Args:
        vision_model: Module called as ``vision_model(frames)``; its output
            must expose ``last_hidden_state`` and ``pooler_output``.
        frames: Batch of frames. Presumably ``(batch, 3, height, width)``
            with raw 0-255 pixel values -- TODO confirm against the caller.
        frame_token_cls: If true, include ``pooler_output`` as one leading
            token per frame.
        frame_token_pooled: Output size handed to ``adaptive_avg_pool2d``
            for the spatial tokens; a falsy value disables spatial tokens.
        mean: Per-channel mean for normalization (never mutated here).
        std: Per-channel std for normalization (never mutated here).
        rescale_factor: Multiplier applied to ``frames`` before normalizing.
        **kwargs: Accepted and ignored.

    Returns:
        The selected tokens along dim 1: the pooler token first (when
        requested), followed by the pooled spatial tokens (when requested).

    Raises:
        ValueError: If neither token kind is requested, or if the number of
            tokens in ``last_hidden_state`` is not a perfect square.
    """
    # Previously this case ran the full forward pass and then crashed with an
    # UnboundLocalError in the final torch.cat; reject it up front instead.
    if not frame_token_cls and not frame_token_pooled:
        raise ValueError(
            'At least one of frame_token_cls / frame_token_pooled must be set'
        )

    frames = normalize(frames * rescale_factor, mean=mean, std=std)
    with torch.cuda.amp.autocast():
        vision_outputs = vision_model(frames)
        last_hidden_state = vision_outputs.last_hidden_state

    tokens = []
    if frame_token_cls:
        # Insert a token axis so the result can be concatenated on dim 1.
        tokens.append(vision_outputs.pooler_output[:, None])
    if frame_token_pooled:
        batch_size, num_tokens, hidden_dim = last_hidden_state.shape
        side = math.isqrt(num_tokens)
        if side * side != num_tokens:
            raise ValueError(
                f'Cannot arrange {num_tokens} tokens into a square grid'
            )
        # (B, N, D) -> (B, D, side, side) so the tokens can be pooled in 2D.
        grid = last_hidden_state.reshape(
            batch_size, side, side, hidden_dim
        ).permute(0, 3, 1, 2)
        pooled = torch.nn.functional.adaptive_avg_pool2d(grid, frame_token_pooled)
        # (B, D, h, w) -> (B, h * w, D)
        tokens.append(pooled.flatten(2, 3).permute(0, 2, 1))

    if len(tokens) == 1:
        return tokens[0]
    return torch.cat(tokens, dim=1)
def _clip_vision_encode(
    vision_model: nn.Module,
    frames: Tensor,
    frame_token_cls: bool,
    frame_token_pooled: tuple,
    mean=OPENAI_CLIP_MEAN,
    std=OPENAI_CLIP_STD,
    rescale_factor=0.00392156862745098,
    **kwargs,
) -> Tensor:
    """Encode frames with a CLIP vision tower into per-frame tokens.

    The frames are multiplied by ``rescale_factor`` (default 1/255),
    normalized with ``mean``/``std`` and passed through ``vision_model``
    under CUDA autocast. Token 0 of ``last_hidden_state`` is used as the
    class token and the remaining tokens as the spatial patch grid.

    Args:
        vision_model: Module called as ``vision_model(frames)``; its output
            must expose ``last_hidden_state``.
        frames: Batch of frames. Presumably ``(batch, 3, height, width)``
            with raw 0-255 pixel values -- TODO confirm against the caller.
        frame_token_cls: If true, include the class token as one leading
            token per frame.
        frame_token_pooled: Output size handed to ``adaptive_avg_pool2d``
            for the spatial tokens; a falsy value disables spatial tokens.
        mean: Per-channel mean for normalization.
        std: Per-channel std for normalization.
        rescale_factor: Multiplier applied to ``frames`` before normalizing.
        **kwargs: Accepted and ignored.

    Returns:
        The selected tokens along dim 1: the class token first (when
        requested), followed by the pooled spatial tokens (when requested).
        The token axis is always kept, so a class-token-only result has a
        size-1 dim 1, matching the SigLIP encoder's ``[:, None]`` layout.

    Raises:
        ValueError: If neither token kind is requested, or if the number of
            patch tokens is not a perfect square.
    """
    # Previously this case ran the full forward pass and then crashed with an
    # UnboundLocalError in the final torch.cat; reject it up front instead.
    if not frame_token_cls and not frame_token_pooled:
        raise ValueError(
            'At least one of frame_token_cls / frame_token_pooled must be set'
        )

    frames = normalize(frames * rescale_factor, mean=mean, std=std)
    with torch.cuda.amp.autocast():
        vision_outputs = vision_model(frames)
        last_hidden_state = vision_outputs.last_hidden_state

    tokens = []
    if frame_token_cls:
        # Slice with [:, :1] rather than [:, 0] so the token axis survives;
        # a rank-2 class token cannot be concatenated with the rank-3
        # spatial tokens on dim 1.
        tokens.append(last_hidden_state[:, :1])
    if frame_token_pooled:
        patch_tokens = last_hidden_state[:, 1:]
        batch_size, num_patches, hidden_dim = patch_tokens.shape
        # Derive the grid side from the patch tokens only; including the
        # class token in the count only worked through float flooring.
        side = math.isqrt(num_patches)
        if side * side != num_patches:
            raise ValueError(
                f'Cannot arrange {num_patches} patch tokens into a square grid'
            )
        # (B, N, D) -> (B, D, side, side) so the tokens can be pooled in 2D.
        grid = patch_tokens.reshape(
            batch_size, side, side, hidden_dim
        ).permute(0, 3, 1, 2)
        pooled = torch.nn.functional.adaptive_avg_pool2d(grid, frame_token_pooled)
        # (B, D, h, w) -> (B, h * w, D)
        tokens.append(pooled.flatten(2, 3).permute(0, 2, 1))

    if len(tokens) == 1:
        return tokens[0]
    return torch.cat(tokens, dim=1)
def build_live_vision(config: LiveConfigMixin):
    """Build the vision tower and its matching frame-encoding function.

    Args:
        config: Configuration providing ``vision_pretrained`` (checkpoint
            name), ``frame_token_cls`` and ``frame_token_pooled``.

    Returns:
        A ``(vision_model, vision_encode)`` tuple. ``vision_model`` is the
        ``.vision_model`` attribute of the loaded checkpoint, and
        ``vision_encode`` is the encoder for that model family with the
        token-selection options from ``config`` already bound, so it is
        called as ``vision_encode(vision_model, frames)``.

    Raises:
        ValueError: If ``config.vision_pretrained`` is not one of the
            verified checkpoints.
    """
    siglip_checkpoints = ('google/siglip-large-patch16-384',)
    clip_checkpoints = (
        'laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90k',
        'openai/clip-vit-large-patch14-336',
    )

    # Pick the encoder before loading any weights so an unsupported name
    # fails immediately instead of after fetching a checkpoint.
    if config.vision_pretrained in siglip_checkpoints:
        encode_fn = _siglip_vision_encode
    elif config.vision_pretrained in clip_checkpoints:
        encode_fn = _clip_vision_encode
    else:
        raise ValueError(f'Unverified vision_pretrained: {config.vision_pretrained}')

    model = AutoModel.from_pretrained(config.vision_pretrained).vision_model
    # Bind the options by keyword for both families. Passing ``config``
    # positionally would occupy the encoder's ``vision_model`` parameter and
    # leave the required token flags unset.
    vision_encode = partial(
        encode_fn,
        frame_token_cls=config.frame_token_cls,
        frame_token_pooled=config.frame_token_pooled,
    )
    return model, vision_encode