# Copyright (c) 2023-present, BAAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################
"""Image encoder."""
import torch
from torch import nn
from tokenize_anything import layers
def space_to_depth(input, block_size):
"""Rearrange blocks of spatial data into depth."""
if input.dim() == 3:
hXw, c = input.size()[1:]
h = w = int(hXw**0.5)
else:
h, w, c = input.size()[1:]
h1, w1 = h // block_size, w // block_size
c1 = (block_size**2) * c
input = input.reshape((-1, h1, block_size, w1, block_size, c))
    return input.permute(0, 1, 3, 2, 4, 5).reshape((-1, h1, w1, c1))


def depth_to_space(input, block_size):
"""Rearrange blocks of depth data into spatial."""
h1, w1, c1 = input.size()[1:]
h, w = h1 * block_size, w1 * block_size
c = c1 // (block_size**2)
input = input.reshape((-1, h1, w1, block_size, block_size, c))
    return input.permute(0, 1, 3, 2, 4, 5).reshape((-1, h, w, c))


class MLP(nn.Module):
    """Two-layer MLP."""

    def __init__(self, dim, mlp_ratio=4):
super(MLP, self).__init__()
self.fc1 = nn.Linear(dim, int(dim * mlp_ratio))
self.fc2 = nn.Linear(int(dim * mlp_ratio), dim)
        self.activation = nn.GELU()

    def forward(self, x):
        return self.fc2(self.activation(self.fc1(x)))


class Attention(nn.Module):
    """Multihead attention."""

    def __init__(self, dim, num_heads, qkv_bias=True):
super(Attention, self).__init__()
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.proj = nn.Linear(dim, dim)
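        # Placeholder; ImageEncoderViT replaces this with a RelPosEmbed module.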
        self.rel_pos_embed = nn.Identity()

    def forward(self, x):
qkv_shape = (-1, x.size(1), 3, self.num_heads, self.head_dim)
qkv = self.qkv(x).reshape(qkv_shape).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(dim=0)
attn = q @ k.transpose(-2, -1).mul(self.scale)
attn = self.rel_pos_embed(attn)
o = nn.functional.softmax(attn, dim=-1) @ v
        return self.proj(o.transpose(1, 2).flatten(2))


class Block(nn.Module):
    """Transformer block."""

    def __init__(self, dim, num_heads, mlp_ratio=4, qkv_bias=True):
super(Block, self).__init__()
self.norm1 = nn.LayerNorm(dim)
self.attn = Attention(dim, num_heads, qkv_bias=qkv_bias)
self.norm2 = nn.LayerNorm(dim)
self.mlp = MLP(dim, mlp_ratio=mlp_ratio)
        self.drop_path = layers.DropPath(0.1, inplace=True)

    def forward(self, x):
x = self.drop_path(self.attn(self.norm1(x))).add_(x)
        return self.drop_path(self.mlp(self.norm2(x))).add_(x)


class Bottleneck(nn.Module):
    """The bottleneck block."""

    def __init__(self, dim, expansion=2, width=None):
super(Bottleneck, self).__init__()
width = width or dim // expansion
self.conv1 = nn.Conv2d(dim, width, 1, bias=False)
self.norm1 = nn.SyncBatchNorm(width)
self.conv2 = nn.Conv2d(width, width, 3, padding=1, bias=False)
self.norm2 = nn.SyncBatchNorm(width)
self.conv3 = nn.Conv2d(width, dim, 1, bias=False)
self.norm3 = nn.SyncBatchNorm(dim)
        self.activation = nn.GELU()

    def forward(self, x):
shortcut = x
x = self.activation(self.norm1(self.conv1(x)))
x = self.activation(self.norm2(self.conv2(x)))
        return self.norm3(self.conv3(x)).add_(shortcut)


class PatchEmbed(nn.Module):
    """Patch embedding layer."""

    def __init__(self, dim=768, patch_size=16, bias=True):
super(PatchEmbed, self).__init__()
        self.proj = nn.Conv2d(3, dim, patch_size, patch_size, bias=bias)

    def forward(self, x):
        return self.proj(x).flatten(2).transpose(1, 2)


class PosEmbed(nn.Module):
    """Position embedding layer."""

    def __init__(self, dim, num_patches):
super(PosEmbed, self).__init__()
self.dim = dim
self.num_patches = num_patches
self.weight = nn.Parameter(torch.zeros(num_patches, dim))
        nn.init.normal_(self.weight, std=0.02)

    def forward(self, x):
        return x.add_(self.weight)


class RelPosEmbed(nn.Module):
    """Relative position embedding layer."""

    def __init__(self, num_heads, size):
super(RelPosEmbed, self).__init__()
self.register_buffer("index", self.get_index(size))
        self.weight = nn.Parameter(torch.zeros(num_heads, (2 * size - 1) ** 2))

    @staticmethod
def get_index(size):
"""Return the relative index."""
grid = torch.arange(size)
grid = torch.stack(torch.meshgrid(grid, grid, indexing="ij")).reshape((2, -1))
coords = grid[:, :, None] - grid[:, None, :] + (size - 1)
coords[0] *= 2 * size - 1
        return coords.sum(0)

    def get_bias(self):
        """Return the relative position bias."""
        return self.weight[:, self.index]

    def forward(self, x):
        return x.add_(self.get_bias())


class SimpleFeaturePyramid(nn.Module):
    """Module to create pyramid features."""

    def __init__(self, embed_dim, out_dim, patch_size=16, min_lvl=4, max_lvl=4):
super(SimpleFeaturePyramid, self).__init__()
self.min_lvl, self.max_lvl = min_lvl, max_lvl
self.input_conv = nn.ModuleList()
self.lateral_conv = nn.ModuleList()
self.output_conv = nn.ModuleList()
patch_lvl = dict((2**i, i) for i in range(6))[patch_size]
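        # Build one input branch per level: transposed convolutions upsample levels
        # finer than the patch stride, max pooling downsamples coarser ones.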
for lvl in [min(i + 2, self.max_lvl) for i in range(4)]:
if lvl == patch_lvl or lvl < self.min_lvl:
self.input_conv += [nn.Identity()]
elif lvl < patch_lvl:
stride, layers = 2 ** (patch_lvl - lvl), []
while stride > 1:
layers += [nn.ConvTranspose2d(embed_dim, embed_dim, 2, 2)]
layers += [nn.SyncBatchNorm(embed_dim), nn.GELU()] if stride > 2 else []
stride /= 2
self.input_conv.append(nn.Sequential(*layers))
elif lvl > patch_lvl:
stride = 2 ** (lvl - patch_lvl)
self.input_conv += [nn.MaxPool2d(stride, stride)]
for _ in range(min_lvl, max_lvl + 1):
self.lateral_conv.append(
nn.Sequential(
nn.Conv2d(embed_dim, out_dim, kernel_size=1, bias=False),
nn.SyncBatchNorm(out_dim),
)
)
self.output_conv.append(
nn.Sequential(
nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1, bias=False),
nn.SyncBatchNorm(out_dim),
)
            )

    def forward(self, inputs):
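        # Pad to four levels by repeating the last map, then keep [min_lvl, max_lvl].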
inputs = inputs + [inputs[-1]] * (4 - len(inputs))
inputs = [conv(x) for conv, x in zip(self.input_conv, inputs)]
features = inputs[self.min_lvl - 1 : self.max_lvl]
laterals = [conv(x) for conv, x in zip(self.lateral_conv, features)]
        return [conv(x) for conv, x in zip(self.output_conv, laterals)]


class ImageEncoderViT(nn.Module):
    """ViT image encoder."""

    def __init__(
self,
depth,
embed_dim,
num_heads,
mlp_ratio=4,
patch_size=16,
window_size=16,
image_size=1024,
out_dim=256,
):
super(ImageEncoderViT, self).__init__()
self.embed_dim = embed_dim
self.image_size = image_size
self.window_size = window_size or image_size // patch_size
self.patch_embed = PatchEmbed(embed_dim, patch_size)
self.pos_embed = PosEmbed(embed_dim, (image_size // patch_size) ** 2)
self.blocks = nn.ModuleList(Block(embed_dim, num_heads, mlp_ratio) for _ in range(depth))
for blk in self.blocks:
blk.attn.rel_pos_embed = RelPosEmbed(num_heads, self.window_size)
self.norm = nn.LayerNorm(embed_dim)
self.cross_conv = nn.ModuleList(Bottleneck(embed_dim) for _ in range(4))
self.neck = SimpleFeaturePyramid(embed_dim, out_dim, patch_size)
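        # Indices of blocks whose outputs pass through a cross-window Bottleneck.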
        self.cross_indices = list(range(depth // 4 - 1, depth, depth // 4))

    def forward(self, x):
x = self.patch_embed(x)
x = self.pos_embed(x)
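        # Partition the patch grid into non-overlapping windows for window attention.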
x = space_to_depth(x, self.window_size)
wmsa_shape = (-1,) + x.shape[1:]
msa_shape = (-1, self.window_size**2, self.embed_dim)
x = x.reshape(msa_shape)
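        # Windows attend independently; at the cross indices (and after the final
        # block) tokens are folded back to a spatial map for convolutional mixing.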
for i, blk in enumerate(self.blocks):
x = blk(x)
if i in self.cross_indices or i == len(self.blocks) - 1:
x = self.norm(x) if i == len(self.blocks) - 1 else x
x = depth_to_space(x.reshape(wmsa_shape), self.window_size)
x = x.permute(0, 3, 1, 2).contiguous()
if i in self.cross_indices:
x = self.cross_conv[self.cross_indices.index(i)](x)
if i in self.cross_indices and i < len(self.blocks) - 1:
x = x.permute(0, 2, 3, 1)
x = space_to_depth(x, self.window_size).reshape(msa_shape)
return self.neck([x])
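

# Minimal usage sketch: the hyper-parameters below are an assumed ViT-B style
# configuration, not values defined in this file. The encoder is kept in eval
# mode so that SyncBatchNorm uses running statistics and needs no distributed
# process group, and tensors are placed on GPU since SyncBatchNorm expects
# CUDA inputs.
if __name__ == "__main__":
    encoder = ImageEncoderViT(depth=12, embed_dim=768, num_heads=12).cuda().eval()
    with torch.no_grad():
        feats = encoder(torch.zeros(1, 3, 1024, 1024, device="cuda"))
    # With the default 1024x1024 input and 16x16 patches, ``feats`` holds one
    # pyramid level of shape (1, 256, 64, 64).
    print([tuple(f.shape) for f in feats])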