import numpy as np
import torch
import torch.nn as nn
from safetensors import safe_open
from timm.models.layers import to_2tuple
from timm.models.vision_transformer import Block

# Taken and adapted from Prithvi's `geospatial_fm.py` to avoid MMCV/MMSegmentation dependencies.


def _convTranspose2dOutput(
    input_size: int,
    stride: int,
    padding: int,
    dilation: int,
    kernel_size: int,
    output_padding: int,
):
    """
    Calculate the output size of a ConvTranspose2d.
    Taken from: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
    """
    return (
        (input_size - 1) * stride
        - 2 * padding
        + dilation * (kernel_size - 1)
        + output_padding
        + 1
    )
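
# Sanity check (illustrative addition, not from the original module): with the
# neck defaults used later in this file (kernel_size=2, stride=2, padding=0,
# dilation=1, output_padding=0), each transposed convolution exactly doubles
# the spatial size.
assert _convTranspose2dOutput(14, 2, 0, 1, 2, 0) == 28
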
def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: torch.Tensor):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float32)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb
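
# Shape check (illustrative addition): an 8-dim embedding over 4 positions
# yields a (4, 8) array. A NumPy array also works for `pos`, since only
# `.reshape` is required of it (and that is how the 3D helper below calls it).
assert get_1d_sincos_pos_embed_from_grid(8, np.arange(4)).shape == (4, 8)
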
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# Position embedding utils
# --------------------------------------------------------
def get_3d_sincos_pos_embed(embed_dim: int, grid_size: tuple, cls_token: bool = False):
    """
    grid_size: 3d tuple of grid size: t, h, w
    return:
    pos_embed: L, D
    """
    assert embed_dim % 16 == 0

    t_size, h_size, w_size = grid_size

    # Each spatial axis gets 6/16 of the embedding width; time gets 4/16.
    w_embed_dim = embed_dim // 16 * 6
    h_embed_dim = embed_dim // 16 * 6
    t_embed_dim = embed_dim // 16 * 4

    w_pos_embed = get_1d_sincos_pos_embed_from_grid(w_embed_dim, np.arange(w_size))
    h_pos_embed = get_1d_sincos_pos_embed_from_grid(h_embed_dim, np.arange(h_size))
    t_pos_embed = get_1d_sincos_pos_embed_from_grid(t_embed_dim, np.arange(t_size))

    w_pos_embed = np.tile(w_pos_embed, (t_size * h_size, 1))
    h_pos_embed = np.tile(np.repeat(h_pos_embed, w_size, axis=0), (t_size, 1))
    t_pos_embed = np.repeat(t_pos_embed, h_size * w_size, axis=0)

    pos_embed = np.concatenate((w_pos_embed, h_pos_embed, t_pos_embed), axis=1)

    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed
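
# Shape check (illustrative addition): a (t=2, h=2, w=2) grid with a 16-dim
# embedding yields 2*2*2 = 8 position vectors; a cls token prepends one
# all-zeros row.
assert get_3d_sincos_pos_embed(16, (2, 2, 2)).shape == (8, 16)
assert get_3d_sincos_pos_embed(16, (2, 2, 2), cls_token=True).shape == (9, 16)
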
class Norm2d(nn.Module):
    def __init__(self, embed_dim: int):
        super().__init__()
        self.ln = nn.LayerNorm(embed_dim, eps=1e-6)

    def forward(self, x):
        # LayerNorm expects channels last: NCHW -> NHWC -> normalize -> NCHW.
        x = x.permute(0, 2, 3, 1)
        x = self.ln(x)
        x = x.permute(0, 3, 1, 2).contiguous()
        return x
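
# Usage sketch (illustrative addition): shape is preserved; only the channel
# statistics are normalized at each spatial location.
#   Norm2d(64)(torch.zeros(2, 64, 8, 8)).shape == (2, 64, 8, 8)
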
class PatchEmbed(nn.Module):
    """Frames of 2D Images to Patch Embedding
    The 3D version of timm.models.vision_transformer.PatchEmbed
    """

    def __init__(
        self,
        img_size: int = 224,
        patch_size: int = 16,
        num_frames: int = 3,
        tubelet_size: int = 1,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: nn.Module = None,
        flatten: bool = True,
        bias: bool = True,
    ):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_frames = num_frames
        self.tubelet_size = tubelet_size
        self.grid_size = (
            num_frames // tubelet_size,
            img_size[0] // patch_size[0],
            img_size[1] // patch_size[1],
        )
        self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
        self.flatten = flatten

        self.proj = nn.Conv3d(
            in_chans,
            embed_dim,
            kernel_size=(tubelet_size, patch_size[0], patch_size[1]),
            stride=(tubelet_size, patch_size[0], patch_size[1]),
            bias=bias,
        )
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        B, C, T, H, W = x.shape
        assert (
            H == self.img_size[0]
        ), f"Input image height ({H}) doesn't match model ({self.img_size[0]})."
        assert (
            W == self.img_size[1]
        ), f"Input image width ({W}) doesn't match model ({self.img_size[1]})."
        x = self.proj(x)
        Hp, Wp = x.shape[3], x.shape[4]
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # B,C,T,H,W -> B,C,L -> B,L,C
        x = self.norm(x)
        return x, Hp, Wp
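
# Usage sketch (illustrative addition): with the defaults (16x16 patches,
# 3 frames, tubelet size 1), a (B, 3, 3, 224, 224) input becomes
# 3 * 14 * 14 = 588 tokens of width 768:
#   pe = PatchEmbed()
#   tokens, Hp, Wp = pe(torch.zeros(2, 3, 3, 224, 224))
#   tokens.shape == (2, 588, 768) and (Hp, Wp) == (14, 14)
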
class ConvTransformerTokensToEmbeddingNeck(nn.Module):
    """
    Neck that transforms the token-based output of a transformer into a single
    embedding suitable for processing with standard layers.
    Performs four ConvTranspose2d operations on the rearranged input with
    kernel_size=2 and stride=2, upscaling the patch grid 16x.
    """
    def __init__(
        self,
        embed_dim: int,
        output_embed_dim: int,
        # num_frames: int = 1,
        Hp: int = 14,
        Wp: int = 14,
        drop_cls_token: bool = True,
    ):
        """
        Args:
            embed_dim (int): Input embedding dimension
            output_embed_dim (int): Output embedding dimension
            Hp (int, optional): Height (in patches) of embedding to be upscaled. Defaults to 14.
            Wp (int, optional): Width (in patches) of embedding to be upscaled. Defaults to 14.
            drop_cls_token (bool, optional): Whether there is a cls_token, which should be dropped. This assumes the cls token is the first token. Defaults to True.
        """
        super().__init__()
        self.drop_cls_token = drop_cls_token
        self.Hp = Hp
        self.Wp = Wp
        self.H_out = Hp
        self.W_out = Wp
        # self.num_frames = num_frames

        kernel_size = 2
        stride = 2
        dilation = 1
        padding = 0
        output_padding = 0
        for _ in range(4):
            self.H_out = _convTranspose2dOutput(
                self.H_out, stride, padding, dilation, kernel_size, output_padding
            )
            self.W_out = _convTranspose2dOutput(
                self.W_out, stride, padding, dilation, kernel_size, output_padding
            )

        self.embed_dim = embed_dim
        self.output_embed_dim = output_embed_dim
        self.fpn1 = nn.Sequential(
            nn.ConvTranspose2d(
                self.embed_dim,
                self.output_embed_dim,
                kernel_size=kernel_size,
                stride=stride,
                dilation=dilation,
                padding=padding,
                output_padding=output_padding,
            ),
            Norm2d(self.output_embed_dim),
            nn.GELU(),
            nn.ConvTranspose2d(
                self.output_embed_dim,
                self.output_embed_dim,
                kernel_size=kernel_size,
                stride=stride,
                dilation=dilation,
                padding=padding,
                output_padding=output_padding,
            ),
        )
        self.fpn2 = nn.Sequential(
            nn.ConvTranspose2d(
                self.output_embed_dim,
                self.output_embed_dim,
                kernel_size=kernel_size,
                stride=stride,
                dilation=dilation,
                padding=padding,
                output_padding=output_padding,
            ),
            Norm2d(self.output_embed_dim),
            nn.GELU(),
            nn.ConvTranspose2d(
                self.output_embed_dim,
                self.output_embed_dim,
                kernel_size=kernel_size,
                stride=stride,
                dilation=dilation,
                padding=padding,
                output_padding=output_padding,
            ),
        )

    def forward(self, x):
        x = x[0]
        if self.drop_cls_token:
            x = x[:, 1:, :]
        x = x.permute(0, 2, 1).reshape(x.shape[0], -1, self.Hp, self.Wp)
        x = self.fpn1(x)
        x = self.fpn2(x)
        x = x.reshape((-1, self.output_embed_dim, self.H_out, self.W_out))
        out = tuple([x])
        return out
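
# Usage sketch (illustrative addition): 196 patch tokens plus a cls token at
# Hp = Wp = 14 are upscaled 16x back to 224x224. Note the tuple-in/tuple-out
# convention of forward():
#   neck = ConvTransformerTokensToEmbeddingNeck(embed_dim=768, output_embed_dim=256)
#   (feat,) = neck((torch.zeros(2, 197, 768),))
#   feat.shape == (2, 256, 224, 224)
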
class ConvTransformerTokensToEmbeddingBottleneckNeck(nn.Module):
    """
    Neck that transforms the token-based output of a transformer into a single
    embedding suitable for processing with standard layers.
    Performs ConvTranspose2d operations with 1x1 bottleneck layers to reduce
    channels (and hence parameters) during upscaling.
    """
    def __init__(
        self,
        embed_dim: int,
        output_embed_dim: int,
        Hp: int = 14,
        Wp: int = 14,
        drop_cls_token: bool = True,
        bottleneck_reduction_factor: int = 4,
    ):
        """
        Args:
            embed_dim (int): Input embedding dimension
            output_embed_dim (int): Output embedding dimension
            Hp (int, optional): Height (in patches) of embedding to be upscaled. Defaults to 14.
            Wp (int, optional): Width (in patches) of embedding to be upscaled. Defaults to 14.
            drop_cls_token (bool, optional): Whether there is a cls_token, which should be dropped. Defaults to True.
            bottleneck_reduction_factor (int, optional): Factor by which to reduce channels in bottleneck layers. Defaults to 4.
        """
        super().__init__()
        self.drop_cls_token = drop_cls_token
        self.Hp = Hp
        self.Wp = Wp
        self.H_out = Hp
        self.W_out = Wp

        kernel_size = 2
        stride = 2
        dilation = 1
        padding = 0
        output_padding = 0
        for _ in range(4):
            self.H_out = _convTranspose2dOutput(
                self.H_out, stride, padding, dilation, kernel_size, output_padding
            )
            self.W_out = _convTranspose2dOutput(
                self.W_out, stride, padding, dilation, kernel_size, output_padding
            )

        self.embed_dim = embed_dim
        self.output_embed_dim = output_embed_dim
        bottleneck_dim = self.embed_dim // bottleneck_reduction_factor

        self.fpn1 = nn.Sequential(
            nn.Conv2d(
                self.embed_dim,
                bottleneck_dim,
                kernel_size=1,
            ),
            Norm2d(bottleneck_dim),
            nn.GELU(),
            nn.ConvTranspose2d(
                bottleneck_dim,
                bottleneck_dim,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                output_padding=output_padding,
            ),
            Norm2d(bottleneck_dim),
            nn.GELU(),
            nn.ConvTranspose2d(
                bottleneck_dim,
                bottleneck_dim,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                output_padding=output_padding,
            ),
            Norm2d(bottleneck_dim),
            nn.GELU(),
            nn.Conv2d(
                bottleneck_dim,
                self.output_embed_dim,
                kernel_size=1,
            ),
            Norm2d(self.output_embed_dim),
            nn.GELU(),
        )
        self.fpn2 = nn.Sequential(
            nn.Conv2d(
                self.output_embed_dim,
                bottleneck_dim,
                kernel_size=1,
            ),
            Norm2d(bottleneck_dim),
            nn.GELU(),
            nn.ConvTranspose2d(
                bottleneck_dim,
                bottleneck_dim,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                output_padding=output_padding,
            ),
            Norm2d(bottleneck_dim),
            nn.GELU(),
            nn.ConvTranspose2d(
                bottleneck_dim,
                bottleneck_dim,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                output_padding=output_padding,
            ),
            Norm2d(bottleneck_dim),
            nn.GELU(),
            nn.Conv2d(
                bottleneck_dim,
                self.output_embed_dim,
                kernel_size=1,
            ),
            Norm2d(self.output_embed_dim),
            nn.GELU(),
        )

    def forward(self, x):
        x = x[0]
        if self.drop_cls_token:
            x = x[:, 1:, :]
        x = x.permute(0, 2, 1).reshape(x.shape[0], -1, self.Hp, self.Wp)
        x = self.fpn1(x)
        x = self.fpn2(x)
        x = x.reshape((-1, self.output_embed_dim, self.H_out, self.W_out))
        out = tuple([x])
        return out
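
# Usage sketch (illustrative addition): a drop-in alternative to the neck
# above; with embed_dim=768 and the default reduction factor, upscaling runs
# at 768 // 4 = 192 channels instead of the full output width:
#   neck = ConvTransformerTokensToEmbeddingBottleneckNeck(embed_dim=768, output_embed_dim=256)
#   (feat,) = neck((torch.zeros(2, 197, 768),))
#   feat.shape == (2, 256, 224, 224)
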
class TemporalViTEncoder(nn.Module):
    """Encoder from a ViT with the capability to take in temporal input.

    This class defines an encoder taken from a ViT architecture.
    """
    def __init__(
        self,
        img_size: int = 224,
        patch_size: int = 16,
        num_frames: int = 1,
        tubelet_size: int = 1,
        in_chans: int = 3,
        embed_dim: int = 1024,
        depth: int = 24,
        num_heads: int = 16,
        mlp_ratio: float = 4.0,
        norm_layer: nn.Module = nn.LayerNorm,
        norm_pix_loss: bool = False,
        pretrained: str = None,
        debug: bool = False,
    ):
        """
        Args:
            img_size (int, optional): Input image size. Defaults to 224.
            patch_size (int, optional): Patch size to be used by the transformer. Defaults to 16.
            num_frames (int, optional): Number of frames (temporal dimension) to be input to the encoder. Defaults to 1.
            tubelet_size (int, optional): Tubelet size used in patch embedding. Defaults to 1.
            in_chans (int, optional): Number of input channels. Defaults to 3.
            embed_dim (int, optional): Embedding dimension. Defaults to 1024.
            depth (int, optional): Encoder depth. Defaults to 24.
            num_heads (int, optional): Number of heads used in the encoder blocks. Defaults to 16.
            mlp_ratio (float, optional): Ratio to be used for the size of the MLP in encoder blocks. Defaults to 4.0.
            norm_layer (nn.Module, optional): Norm layer to be used. Defaults to nn.LayerNorm.
            norm_pix_loss (bool, optional): Whether to use Norm Pix Loss. Defaults to False.
            pretrained (str, optional): Path to pretrained encoder weights. Defaults to None.
            debug (bool, optional): Whether to print tensor shapes during the forward pass. Defaults to False.
        """
        super().__init__()

        # --------------------------------------------------------------------------
        # MAE encoder specifics
        self.embed_dim = embed_dim
        self.patch_embed = PatchEmbed(
            img_size, patch_size, num_frames, tubelet_size, in_chans, embed_dim
        )
        num_patches = self.patch_embed.num_patches
        self.num_frames = num_frames

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False
        )  # fixed sin-cos embedding

        self.blocks = nn.ModuleList(
            [
                Block(
                    embed_dim,
                    num_heads,
                    mlp_ratio,
                    qkv_bias=True,
                    norm_layer=norm_layer,
                )
                for _ in range(depth)
            ]
        )
        self.norm = norm_layer(embed_dim)

        self.norm_pix_loss = norm_pix_loss
        self.pretrained = pretrained
        self.debug = debug

        self.initialize_weights()
    def initialize_weights(self):
        # initialize (and freeze) pos_embed by sin-cos embedding
        pos_embed = get_3d_sincos_pos_embed(
            self.pos_embed.shape[-1], self.patch_embed.grid_size, cls_token=True
        )
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # TODO: FIX huggingface config
        # load pretrained weights
        # if self.pretrained:
        #     if self.pretrained.endswith('.safetensors'):
        #         self._load_safetensors_weights()
        #     elif self.pretrained == 'huggingface':
        #         print("TemporalViTEncoder | Using HuggingFace pretrained weights.")
        #     else:
        #         self._load_pt_weights()
        # else:
        #     self.apply(self._init_weights)
    def _load_safetensors_weights(self):
        with safe_open(self.pretrained, framework='pt', device='cpu') as f:
            # safe_open handles expose keys()/get_tensor(), not items();
            # with framework='pt', get_tensor() already returns torch tensors.
            checkpoint_state_dict = {k: f.get_tensor(k) for k in f.keys()}
        missing_keys, unexpected_keys = self.load_state_dict(checkpoint_state_dict, strict=False)
        if missing_keys:
            print("TemporalViTEncoder | Warning: Missing keys in the state dict:", missing_keys)
        if unexpected_keys:
            print("TemporalViTEncoder | Warning: Unexpected keys in the state dict:", unexpected_keys)
        print(f"TemporalViTEncoder | Loaded pretrained weights from '{self.pretrained}' (safetensors).")
    def _load_pt_weights(self):
        checkpoint = torch.load(self.pretrained, map_location='cpu')
        checkpoint_state_dict = checkpoint.get('state_dict', checkpoint)
        missing_keys, unexpected_keys = self.load_state_dict(checkpoint_state_dict, strict=False)
        if missing_keys:
            print("TemporalViTEncoder | Warning: Missing keys in the state dict:", missing_keys)
        if unexpected_keys:
            print("TemporalViTEncoder | Warning: Unexpected keys in the state dict:", unexpected_keys)
        print(f"TemporalViTEncoder | Loaded pretrained weights from '{self.pretrained}' (pt file).")
    def _init_weights(self, m):
        print("TemporalViTEncoder | Newly Initializing weights...")
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
    def forward(self, x):
        if self.debug:
            print('TemporalViTEncoder IN:', x.shape)

        # embed patches
        x, _, _ = self.patch_embed(x)
        if self.debug:
            print('TemporalViTEncoder EMBED:', x.shape)

        # add pos embed w/o cls token
        x = x + self.pos_embed[:, 1:, :]

        # append cls token
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        if self.debug:
            print('TemporalViTEncoder OUT:', x.shape)

        return tuple([x])
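

if __name__ == "__main__":
    # Smoke test (illustrative addition; the hyperparameters here are
    # arbitrary assumptions, not the original training configuration):
    # encode a single-frame 6-band image, then upscale the tokens back to
    # pixel resolution with the plain neck.
    encoder = TemporalViTEncoder(
        in_chans=6, embed_dim=768, depth=2, num_heads=12, debug=True
    )
    neck = ConvTransformerTokensToEmbeddingNeck(embed_dim=768, output_embed_dim=256)
    tokens = encoder(torch.zeros(1, 6, 1, 224, 224))  # tuple of (1, 197, 768)
    (features,) = neck(tokens)
    print("features:", features.shape)  # (1, 256, 224, 224)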