from safetensors import safe_open
import torch
import torch.nn as nn
import numpy as np
from timm.models.layers import to_2tuple
from timm.models.vision_transformer import Block
# Taken and adapted from Prithvi's `geospatial_fm.py` to avoid MMCV/MMSegmentation dependencies.
def _convTranspose2dOutput(
input_size: int,
stride: int,
padding: int,
dilation: int,
kernel_size: int,
output_padding: int,
):
"""
Calculate the output size of a ConvTranspose2d.
Taken from: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
"""
return (
(input_size - 1) * stride
- 2 * padding
+ dilation * (kernel_size - 1)
+ output_padding
+ 1
)
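# Worked example for _convTranspose2dOutput (illustrative, not part of the original
# module): with the kernel_size=2, stride=2, padding=0, dilation=1, output_padding=0
# settings used by the necks below, each ConvTranspose2d doubles the spatial size, e.g.
#   _convTranspose2dOutput(14, 2, 0, 1, 2, 0) == (14 - 1) * 2 - 0 + 1 * (2 - 1) + 0 + 1 == 28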
def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: np.ndarray):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,)
out: (M, D)
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float32)
omega /= embed_dim / 2.0
omega = 1.0 / 10000**omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb
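# Shape example for get_1d_sincos_pos_embed_from_grid (illustrative): with
# embed_dim=8 and pos=np.arange(3), the sin and cos halves are concatenated per
# position, giving an array of shape (3, 8).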
def get_3d_sincos_pos_embed(embed_dim: int, grid_size: tuple, cls_token: bool = False):
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# Position embedding utils
# --------------------------------------------------------
"""
grid_size: 3d tuple of grid size: t, h, w
return:
pos_embed: L, D
"""
assert embed_dim % 16 == 0
t_size, h_size, w_size = grid_size
w_embed_dim = embed_dim // 16 * 6
h_embed_dim = embed_dim // 16 * 6
t_embed_dim = embed_dim // 16 * 4
w_pos_embed = get_1d_sincos_pos_embed_from_grid(w_embed_dim, np.arange(w_size))
h_pos_embed = get_1d_sincos_pos_embed_from_grid(h_embed_dim, np.arange(h_size))
t_pos_embed = get_1d_sincos_pos_embed_from_grid(t_embed_dim, np.arange(t_size))
w_pos_embed = np.tile(w_pos_embed, (t_size * h_size, 1))
h_pos_embed = np.tile(np.repeat(h_pos_embed, w_size, axis=0), (t_size, 1))
t_pos_embed = np.repeat(t_pos_embed, h_size * w_size, axis=0)
pos_embed = np.concatenate((w_pos_embed, h_pos_embed, t_pos_embed), axis=1)
if cls_token:
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
return pos_embed
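# Shape example for get_3d_sincos_pos_embed (illustrative):
#   get_3d_sincos_pos_embed(768, (1, 14, 14), cls_token=True)
# returns an array of shape (1 + 1*14*14, 768) = (197, 768), matching the
# pos_embed buffer that TemporalViTEncoder below creates for a 224x224 input.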
class Norm2d(nn.Module):
def __init__(self, embed_dim: int):
super().__init__()
self.ln = nn.LayerNorm(embed_dim, eps=1e-6)
def forward(self, x):
x = x.permute(0, 2, 3, 1)
x = self.ln(x)
x = x.permute(0, 3, 1, 2).contiguous()
return x
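# Usage note for Norm2d (illustrative): it applies LayerNorm over the channel
# dimension of an NCHW tensor while preserving its shape, e.g. Norm2d(256) maps
# a (B, 256, 28, 28) tensor to a (B, 256, 28, 28) tensor normalized per spatial
# location.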
class PatchEmbed(nn.Module):
"""Frames of 2D Images to Patch Embedding
The 3D version of timm.models.vision_transformer.PatchEmbed
"""
def __init__(
self,
img_size: int = 224,
patch_size: int = 16,
num_frames: int = 3,
tubelet_size: int = 1,
in_chans: int = 3,
embed_dim: int = 768,
norm_layer: nn.Module = None,
flatten: bool = True,
bias: bool = True,
):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.num_frames = num_frames
self.tubelet_size = tubelet_size
self.grid_size = (
num_frames // tubelet_size,
img_size[0] // patch_size[0],
img_size[1] // patch_size[1],
)
self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
self.flatten = flatten
self.proj = nn.Conv3d(
in_chans,
embed_dim,
kernel_size=(tubelet_size, patch_size[0], patch_size[1]),
stride=(tubelet_size, patch_size[0], patch_size[1]),
bias=bias,
)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
B, C, T, H, W = x.shape
assert (
H == self.img_size[0]
), f"Input image height ({H}) doesn't match model ({self.img_size[0]})."
assert (
W == self.img_size[1]
), f"Input image width ({W}) doesn't match model ({self.img_size[1]})."
x = self.proj(x)
Hp, Wp = x.shape[3], x.shape[4]
if self.flatten:
x = x.flatten(2).transpose(1, 2) # B,C,T,H,W -> B,C,L -> B,L,C
x = self.norm(x)
return x, Hp, Wp
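# Shape example for PatchEmbed (illustrative): PatchEmbed(img_size=224,
# patch_size=16, num_frames=1, in_chans=6) maps an input of shape
# (B, 6, 1, 224, 224) to tokens of shape (B, 196, 768) and returns Hp = Wp = 14.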
class ConvTransformerTokensToEmbeddingNeck(nn.Module):
"""
    Neck that transforms the token-based output of the transformer into a single embedding suitable for processing with standard layers.
    Performs four ConvTranspose2d operations (kernel_size=2, stride=2) on the rearranged input, upscaling the spatial resolution by a factor of 16.
"""
def __init__(
self,
embed_dim: int,
output_embed_dim: int,
# num_frames: int = 1,
Hp: int = 14,
Wp: int = 14,
drop_cls_token: bool = True,
):
"""
Args:
embed_dim (int): Input embedding dimension
output_embed_dim (int): Output embedding dimension
Hp (int, optional): Height (in patches) of embedding to be upscaled. Defaults to 14.
Wp (int, optional): Width (in patches) of embedding to be upscaled. Defaults to 14.
drop_cls_token (bool, optional): Whether there is a cls_token, which should be dropped. This assumes the cls token is the first token. Defaults to True.
"""
super().__init__()
self.drop_cls_token = drop_cls_token
self.Hp = Hp
self.Wp = Wp
self.H_out = Hp
self.W_out = Wp
# self.num_frames = num_frames
kernel_size = 2
stride = 2
dilation = 1
padding = 0
output_padding = 0
for _ in range(4):
self.H_out = _convTranspose2dOutput(
self.H_out, stride, padding, dilation, kernel_size, output_padding
)
self.W_out = _convTranspose2dOutput(
self.W_out, stride, padding, dilation, kernel_size, output_padding
)
self.embed_dim = embed_dim
self.output_embed_dim = output_embed_dim
self.fpn1 = nn.Sequential(
nn.ConvTranspose2d(
self.embed_dim,
self.output_embed_dim,
kernel_size=kernel_size,
stride=stride,
dilation=dilation,
padding=padding,
output_padding=output_padding,
),
Norm2d(self.output_embed_dim),
nn.GELU(),
nn.ConvTranspose2d(
self.output_embed_dim,
self.output_embed_dim,
kernel_size=kernel_size,
stride=stride,
dilation=dilation,
padding=padding,
output_padding=output_padding,
),
)
self.fpn2 = nn.Sequential(
nn.ConvTranspose2d(
self.output_embed_dim,
self.output_embed_dim,
kernel_size=kernel_size,
stride=stride,
dilation=dilation,
padding=padding,
output_padding=output_padding,
),
Norm2d(self.output_embed_dim),
nn.GELU(),
nn.ConvTranspose2d(
self.output_embed_dim,
self.output_embed_dim,
kernel_size=kernel_size,
stride=stride,
dilation=dilation,
padding=padding,
output_padding=output_padding,
),
)
def forward(self, x):
x = x[0]
if self.drop_cls_token:
x = x[:, 1:, :]
x = x.permute(0, 2, 1).reshape(x.shape[0], -1, self.Hp, self.Wp)
x = self.fpn1(x)
x = self.fpn2(x)
x = x.reshape((-1, self.output_embed_dim, self.H_out, self.W_out))
out = tuple([x])
return out
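# Shape example for ConvTransformerTokensToEmbeddingNeck (illustrative): with
# embed_dim=768, output_embed_dim=256 and the default Hp=Wp=14, the neck takes
# the encoder output tuple holding a (B, 197, 768) token tensor, drops the cls
# token, reshapes to (B, 768, 14, 14), and upsamples through four stride-2
# transposed convolutions to (B, 256, 224, 224).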
class ConvTransformerTokensToEmbeddingBottleneckNeck(nn.Module):
"""
    Neck that transforms the token-based output of the transformer into a single embedding suitable for processing with standard layers.
    Performs ConvTranspose2d operations with 1x1 bottleneck layers that reduce the number of channels during upsampling.
"""
def __init__(
self,
embed_dim: int,
output_embed_dim: int,
Hp: int = 14,
Wp: int = 14,
drop_cls_token: bool = True,
bottleneck_reduction_factor: int = 4,
):
"""
Args:
embed_dim (int): Input embedding dimension
output_embed_dim (int): Output embedding dimension
Hp (int, optional): Height (in patches) of embedding to be upscaled. Defaults to 14.
Wp (int, optional): Width (in patches) of embedding to be upscaled. Defaults to 14.
drop_cls_token (bool, optional): Whether there is a cls_token, which should be dropped. Defaults to True.
            bottleneck_reduction_factor (int, optional): Factor by which the channel count is reduced in the bottleneck layers. Defaults to 4.
"""
super().__init__()
self.drop_cls_token = drop_cls_token
self.Hp = Hp
self.Wp = Wp
self.H_out = Hp
self.W_out = Wp
kernel_size = 2
stride = 2
dilation = 1
padding = 0
output_padding = 0
for _ in range(4):
self.H_out = _convTranspose2dOutput(
self.H_out, stride, padding, dilation, kernel_size, output_padding
)
self.W_out = _convTranspose2dOutput(
self.W_out, stride, padding, dilation, kernel_size, output_padding
)
self.embed_dim = embed_dim
self.output_embed_dim = output_embed_dim
bottleneck_dim = self.embed_dim // bottleneck_reduction_factor
self.fpn1 = nn.Sequential(
nn.Conv2d(
self.embed_dim,
bottleneck_dim,
kernel_size=1
),
Norm2d(bottleneck_dim),
nn.GELU(),
nn.ConvTranspose2d(
bottleneck_dim,
bottleneck_dim,
kernel_size=kernel_size,
stride=stride,
padding=padding,
output_padding=output_padding
),
Norm2d(bottleneck_dim),
nn.GELU(),
nn.ConvTranspose2d(
bottleneck_dim,
bottleneck_dim,
kernel_size=kernel_size,
stride=stride,
padding=padding,
output_padding=output_padding
),
Norm2d(bottleneck_dim),
nn.GELU(),
nn.Conv2d(
bottleneck_dim,
self.output_embed_dim,
kernel_size=1
),
Norm2d(self.output_embed_dim),
nn.GELU(),
)
self.fpn2 = nn.Sequential(
nn.Conv2d(
self.output_embed_dim,
bottleneck_dim,
kernel_size=1
),
Norm2d(bottleneck_dim),
nn.GELU(),
nn.ConvTranspose2d(
bottleneck_dim,
bottleneck_dim,
kernel_size=kernel_size,
stride=stride,
padding=padding,
output_padding=output_padding
),
Norm2d(bottleneck_dim),
nn.GELU(),
nn.ConvTranspose2d(
bottleneck_dim,
bottleneck_dim,
kernel_size=kernel_size,
stride=stride,
padding=padding,
output_padding=output_padding
),
Norm2d(bottleneck_dim),
nn.GELU(),
nn.Conv2d(
bottleneck_dim,
self.output_embed_dim,
kernel_size=1
),
Norm2d(self.output_embed_dim),
nn.GELU(),
)
def forward(self, x):
x = x[0]
if self.drop_cls_token:
x = x[:, 1:, :]
x = x.permute(0, 2, 1).reshape(x.shape[0], -1, self.Hp, self.Wp)
x = self.fpn1(x)
x = self.fpn2(x)
x = x.reshape((-1, self.output_embed_dim, self.H_out, self.W_out))
out = tuple([x])
return out
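# Channel-flow example for the bottleneck neck (illustrative): with embed_dim=768,
# output_embed_dim=256 and bottleneck_reduction_factor=4, each FPN stage projects
# to 192 channels with a 1x1 conv, runs its two stride-2 transposed convolutions
# at 192 channels, then projects back to 256, so the upsampling operates on 4x
# fewer channels than in the plain neck above.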
class TemporalViTEncoder(nn.Module):
"""Encoder from an ViT with capability to take in temporal input.
This class defines an encoder taken from a ViT architecture.
"""
def __init__(
self,
img_size: int = 224,
patch_size: int = 16,
num_frames: int = 1,
tubelet_size: int = 1,
in_chans: int = 3,
embed_dim: int = 1024,
depth: int = 24,
num_heads: int = 16,
mlp_ratio: float = 4.0,
norm_layer: nn.Module = nn.LayerNorm,
norm_pix_loss: bool = False,
pretrained: str = None,
        debug: bool = False,
):
"""
Args:
img_size (int, optional): Input image size. Defaults to 224.
patch_size (int, optional): Patch size to be used by the transformer. Defaults to 16.
num_frames (int, optional): Number of frames (temporal dimension) to be input to the encoder. Defaults to 1.
tubelet_size (int, optional): Tubelet size used in patch embedding. Defaults to 1.
in_chans (int, optional): Number of input channels. Defaults to 3.
embed_dim (int, optional): Embedding dimension. Defaults to 1024.
depth (int, optional): Encoder depth. Defaults to 24.
num_heads (int, optional): Number of heads used in the encoder blocks. Defaults to 16.
mlp_ratio (float, optional): Ratio to be used for the size of the MLP in encoder blocks. Defaults to 4.0.
norm_layer (nn.Module, optional): Norm layer to be used. Defaults to nn.LayerNorm.
norm_pix_loss (bool, optional): Whether to use Norm Pix Loss. Defaults to False.
            pretrained (str, optional): Path to pretrained encoder weights. Defaults to None.
            debug (bool, optional): If True, print tensor shapes at the start and end of the forward pass. Defaults to False.
        """
super().__init__()
# --------------------------------------------------------------------------
# MAE encoder specifics
self.embed_dim = embed_dim
self.patch_embed = PatchEmbed(
img_size, patch_size, num_frames, tubelet_size, in_chans, embed_dim
)
num_patches = self.patch_embed.num_patches
self.num_frames = num_frames
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(
torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False
) # fixed sin-cos embedding
self.blocks = nn.ModuleList(
[
Block(
embed_dim,
num_heads,
mlp_ratio,
qkv_bias=True,
norm_layer=norm_layer,
)
for _ in range(depth)
]
)
self.norm = norm_layer(embed_dim)
self.norm_pix_loss = norm_pix_loss
self.pretrained = pretrained
self.debug = debug
self.initialize_weights()
def initialize_weights(self):
# initialize (and freeze) pos_embed by sin-cos embedding
pos_embed = get_3d_sincos_pos_embed(
self.pos_embed.shape[-1], self.patch_embed.grid_size, cls_token=True
)
self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
# initialize patch_embed like nn.Linear (instead of nn.Conv2d)
w = self.patch_embed.proj.weight.data
torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        # TODO: FIX huggingface config
# load pretrained weights
# if self.pretrained:
# if self.pretrained.endswith('.safetensors'):
# self._load_safetensors_weights()
# elif self.pretrained == 'huggingface':
# print("TemporalViTEncoder | Using HuggingFace pretrained weights.")
# else:
# self._load_pt_weights()
# else:
# self.apply(self._init_weights)
def _load_safetensors_weights(self):
with safe_open(self.pretrained, framework='pt', device='cpu') as f:
            checkpoint_state_dict = {k: f.get_tensor(k) for k in f.keys()}
missing_keys, unexpected_keys = self.load_state_dict(checkpoint_state_dict, strict=False)
if missing_keys:
print("TemporalViTEncoder | Warning: Missing keys in the state dict:", missing_keys)
if unexpected_keys:
print("TemporalViTEncoder | Warning: Unexpected keys in the state dict:", unexpected_keys)
print(f"TemporalViTEncoder | Loaded pretrained weights from '{self.pretrained}' (safetensors).")
def _load_pt_weights(self):
checkpoint = torch.load(self.pretrained, map_location='cpu')
checkpoint_state_dict = checkpoint.get('state_dict', checkpoint)
missing_keys, unexpected_keys = self.load_state_dict(checkpoint_state_dict, strict=False)
if missing_keys:
print("TemporalViTEncoder | Warning: Missing keys in the state dict:", missing_keys)
if unexpected_keys:
print("TemporalViTEncoder | Warning: Unexpected keys in the state dict:", unexpected_keys)
print(f"TemporalViTEncoder | Loaded pretrained weights from '{self.pretrained}' (pt file).")
def _init_weights(self, m):
print("TemporalViTEncoder | Newly Initializing weights...")
if isinstance(m, nn.Linear):
# we use xavier_uniform following official JAX ViT:
torch.nn.init.xavier_uniform_(m.weight)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, x):
if self.debug:
print('TemporalViTEncoder IN:', x.shape)
# embed patches
x, _, _ = self.patch_embed(x)
if self.debug:
print('TemporalViTEncoder EMBED:', x.shape)
# add pos embed w/o cls token
x = x + self.pos_embed[:, 1:, :]
# append cls token
cls_token = self.cls_token + self.pos_embed[:, :1, :]
cls_tokens = cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
# apply Transformer blocks
for blk in self.blocks:
x = blk(x)
x = self.norm(x)
if self.debug:
print('TemporalViTEncoder OUT:', x.shape)
return tuple([x])
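

if __name__ == "__main__":
    # Smoke test (a sketch, not part of the original demo app): wire a small,
    # randomly initialized encoder to the plain neck and check the output shapes.
    # The sizes below are assumptions chosen to match the module defaults
    # (224x224 input, patch size 16 -> a 14x14 patch grid).
    encoder = TemporalViTEncoder(
        img_size=224,
        patch_size=16,
        num_frames=1,
        tubelet_size=1,
        in_chans=6,
        embed_dim=768,
        depth=2,  # shallow stack, just for the shape check
        num_heads=8,
        debug=True,
    )
    neck = ConvTransformerTokensToEmbeddingNeck(
        embed_dim=768,
        output_embed_dim=256,
        Hp=14,
        Wp=14,
        drop_cls_token=True,
    )
    dummy = torch.randn(1, 6, 1, 224, 224)  # (B, C, T, H, W)
    with torch.no_grad():
        tokens = encoder(dummy)     # tuple holding a (1, 197, 768) tensor
        (features,) = neck(tokens)  # (1, 256, 224, 224)
    print("features:", features.shape)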