Spaces:
Sleeping
Sleeping
# Copyright (c) 2023-present, BAAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################
"""Image encoder."""
import torch | |
from torch import nn | |
from tokenize_anything import layers | |
def space_to_depth(input, block_size):
    """Fold ``block_size x block_size`` spatial blocks into the last axis.

    Args:
        input: Tensor laid out as ``(N, H, W, C)``, or ``(N, H*W, C)``.
            In the 3-D case the spatial grid is taken to be square and its
            side is recovered from the token count.
        block_size: Side length of each spatial block.

    Returns:
        Tensor of shape ``(N', H // block_size, W // block_size,
        block_size**2 * C)``.
    """
    trailing = tuple(input.size()[1:])
    if input.dim() == 3:
        # Flattened tokens: derive the side of the (assumed square) grid.
        num_tokens, channels = trailing
        height = width = int(num_tokens**0.5)
    else:
        height, width, channels = trailing
    rows = height // block_size
    cols = width // block_size
    # Split each spatial axis into (block index, offset inside the block).
    blocked = input.reshape((-1, rows, block_size, cols, block_size, channels))
    # Bring both block indices forward, leaving the offsets next to channels.
    blocked = blocked.permute(0, 1, 3, 2, 4, 5)
    return blocked.reshape((-1, rows, cols, (block_size**2) * channels))
def depth_to_space(input, block_size):
    """Unfold the last axis back into ``block_size x block_size`` spatial blocks.

    Args:
        input: Tensor laid out as ``(N, H1, W1, C1)``.
        block_size: Side length of each spatial block.

    Returns:
        Tensor of shape ``(N', H1 * block_size, W1 * block_size,
        C1 // block_size**2)``.
    """
    rows, cols, depth = input.size()[1:]
    channels = depth // (block_size**2)
    # Expose the two in-block offsets that are packed into the last axis.
    blocks = input.reshape((-1, rows, cols, block_size, block_size, channels))
    # Interleave each offset with its block index: (row, dy, col, dx).
    interleaved = blocks.permute(0, 1, 3, 2, 4, 5)
    out_shape = (-1, rows * block_size, cols * block_size, channels)
    return interleaved.reshape(out_shape)
class MLP(nn.Module):
    """Two-layer perceptron: linear -> GELU -> linear.

    Args:
        dim: Input and output feature size.
        mlp_ratio: Multiplier giving the hidden size ``int(dim * mlp_ratio)``.
    """

    def __init__(self, dim, mlp_ratio=4):
        super().__init__()
        hidden_dim = int(dim * mlp_ratio)
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, dim)
        self.activation = nn.GELU()

    def forward(self, x):
        """Apply the expand / activate / project sequence to ``x``."""
        hidden = self.activation(self.fc1(x))
        return self.fc2(hidden)
class Attention(nn.Module):
    """Multi-head self-attention with a pluggable bias on the logits.

    Args:
        dim: Feature size of the input tokens.
        num_heads: Number of heads; each head has ``dim // num_heads`` features.
        qkv_bias: Whether the fused query/key/value projection has a bias.
    """

    def __init__(self, dim, num_heads, qkv_bias=True):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)
        # Identity placeholder; may be replaced by a module that biases logits.
        self.rel_pos_embed = nn.Identity()

    def forward(self, x):
        """Attend over axis 1 of ``x`` and project the merged heads."""
        num_tokens = x.size(1)
        fused = self.qkv(x).reshape(
            (-1, num_tokens, 3, self.num_heads, self.head_dim)
        )
        # Reorder to (3, batch, heads, tokens, head_dim) and split q / k / v.
        q, k, v = fused.permute(2, 0, 3, 1, 4).unbind(dim=0)
        scaled_keys = k.transpose(-2, -1).mul(self.scale)
        logits = self.rel_pos_embed(q @ scaled_keys)
        weights = nn.functional.softmax(logits, dim=-1)
        # Merge heads back into the feature axis before the output projection.
        merged = (weights @ v).transpose(1, 2).flatten(2)
        return self.proj(merged)
class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a residual.

    Args:
        dim: Token feature size.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden-size multiplier passed to ``MLP``.
        qkv_bias: Whether the attention q/k/v projection has a bias.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4, qkv_bias=True):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, num_heads, qkv_bias=qkv_bias)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, mlp_ratio=mlp_ratio)
        # NOTE(review): DropPath comes from the project `layers` module;
        # presumably stochastic depth with rate 0.1 -- confirm there.
        self.drop_path = layers.DropPath(0.1, inplace=True)

    def forward(self, x):
        """Run both residual branches and return the updated tokens."""
        attn_branch = self.drop_path(self.attn(self.norm1(x)))
        # The residual is accumulated in place into the branch output.
        x = attn_branch.add_(x)
        mlp_branch = self.drop_path(self.mlp(self.norm2(x)))
        return mlp_branch.add_(x)
class Bottleneck(nn.Module):
    """Residual bottleneck: 1x1 reduce, 3x3, 1x1 restore, plus a shortcut.

    Args:
        dim: Number of input and output channels.
        expansion: Divisor giving the inner width when ``width`` is not set.
        width: Explicit inner channel count; a falsy value selects
            ``dim // expansion``.
    """

    def __init__(self, dim, expansion=2, width=None):
        super().__init__()
        if not width:
            width = dim // expansion
        self.conv1 = nn.Conv2d(dim, width, 1, bias=False)
        self.norm1 = nn.SyncBatchNorm(width)
        self.conv2 = nn.Conv2d(width, width, 3, padding=1, bias=False)
        self.norm2 = nn.SyncBatchNorm(width)
        self.conv3 = nn.Conv2d(width, dim, 1, bias=False)
        self.norm3 = nn.SyncBatchNorm(dim)
        self.activation = nn.GELU()

    def forward(self, x):
        """Return the block output with the input added back in."""
        shortcut = x
        stages = ((self.conv1, self.norm1), (self.conv2, self.norm2))
        for conv, norm in stages:
            x = self.activation(norm(conv(x)))
        # No activation after the last norm; the shortcut is added in place.
        restored = self.norm3(self.conv3(x))
        return restored.add_(shortcut)
class PatchEmbed(nn.Module):
    """Embed non-overlapping image patches with a strided convolution.

    Args:
        dim: Embedding size produced for each patch.
        patch_size: Kernel size and stride of the projection.
        bias: Whether the projection has a bias term.
    """

    def __init__(self, dim=768, patch_size=16, bias=True):
        super().__init__()
        # The input is fixed to 3 channels.
        self.proj = nn.Conv2d(
            3, dim, kernel_size=patch_size, stride=patch_size, bias=bias
        )

    def forward(self, x):
        """Project ``x`` and return tokens with the feature axis last."""
        feature_map = self.proj(x)
        tokens = feature_map.flatten(2)
        return tokens.transpose(1, 2)
class PosEmbed(nn.Module):
    """Learned absolute position embedding added to the tokens.

    Args:
        dim: Feature size of each position vector.
        num_patches: Number of positions in the embedding table.
    """

    def __init__(self, dim, num_patches):
        super().__init__()
        self.dim, self.num_patches = dim, num_patches
        table = torch.zeros(num_patches, dim)
        self.weight = nn.Parameter(table)
        nn.init.normal_(self.weight, std=0.02)

    def forward(self, x):
        """Add the embedding table to ``x`` in place and return ``x``."""
        x.add_(self.weight)
        return x
class RelPosEmbed(nn.Module):
    """Relative position embedding layer.

    Holds one learnable bias per head for every relative 2-D offset inside
    a ``size x size`` window and adds the gathered bias to its input.

    Args:
        num_heads: Number of rows in the bias table (one per head).
        size: Side length of the square window.
    """

    def __init__(self, num_heads, size):
        super(RelPosEmbed, self).__init__()
        self.register_buffer("index", self.get_index(size))
        self.weight = nn.Parameter(torch.zeros(num_heads, (2 * size - 1) ** 2))

    # ``get_index`` takes no ``self``; without ``@staticmethod`` the call
    # ``self.get_index(size)`` in ``__init__`` passes two arguments and fails.
    @staticmethod
    def get_index(size):
        """Return the relative index.

        The result is a ``(size**2, size**2)`` integer tensor whose entry
        ``[i, j]`` is the flat bias-table index, in
        ``[0, (2 * size - 1) ** 2)``, for the offset from position ``j`` to
        position ``i`` (positions enumerated row-major).
        """
        grid = torch.arange(size)
        grid = torch.stack(torch.meshgrid(grid, grid, indexing="ij")).reshape((2, -1))
        # Pairwise row/column offsets, shifted so the smallest offset is 0.
        coords = grid[:, :, None] - grid[:, None, :] + (size - 1)
        # Row offsets become the major digit of the flattened 2-D offset.
        coords[0] *= 2 * size - 1
        return coords.sum(0)

    def get_bias(self):
        """Gather the bias as a ``(num_heads, size**2, size**2)`` tensor."""
        return self.weight[:, self.index]

    def forward(self, x):
        """Add the gathered bias to ``x`` in place and return ``x``."""
        return x.add_(self.get_bias())
class SimpleFeaturePyramid(nn.Module):
    """Build pyramid features from single-scale feature maps.

    Four input branches are created, for levels ``min(i + 2, max_lvl)`` with
    ``i`` in ``0..3``. A branch is an identity when its level equals the
    patch level or lies below ``min_lvl``; otherwise it resamples by powers
    of two (transposed convolutions upward, max pooling downward). One
    lateral 1x1 conv and one output 3x3 conv are created per level in
    ``min_lvl..max_lvl``.

    Args:
        embed_dim: Channel count of the incoming feature maps.
        out_dim: Channel count of the produced pyramid features.
        patch_size: Patch size; must be one of ``1, 2, 4, 8, 16, 32``
            (any other value raises ``KeyError``).
        min_lvl: Lowest pyramid level to produce.
        max_lvl: Highest pyramid level to produce.
    """

    def __init__(self, embed_dim, out_dim, patch_size=16, min_lvl=4, max_lvl=4):
        super().__init__()
        self.min_lvl, self.max_lvl = min_lvl, max_lvl
        self.input_conv = nn.ModuleList()
        self.lateral_conv = nn.ModuleList()
        self.output_conv = nn.ModuleList()
        # log2 of the patch size, looked up among the supported powers of two.
        patch_lvl = {2**exp: exp for exp in range(6)}[patch_size]
        for i in range(4):
            lvl = min(i + 2, self.max_lvl)
            if lvl == patch_lvl or lvl < self.min_lvl:
                self.input_conv.append(nn.Identity())
            elif lvl < patch_lvl:
                # One 2x transposed conv per halving of the stride, with
                # norm + activation between (not after) consecutive convs.
                num_steps = patch_lvl - lvl
                upsample = []
                for step in range(num_steps):
                    upsample.append(nn.ConvTranspose2d(embed_dim, embed_dim, 2, 2))
                    if step < num_steps - 1:
                        upsample.append(nn.SyncBatchNorm(embed_dim))
                        upsample.append(nn.GELU())
                self.input_conv.append(nn.Sequential(*upsample))
            else:
                # lvl > patch_lvl: downsample by pooling.
                factor = 2 ** (lvl - patch_lvl)
                self.input_conv.append(nn.MaxPool2d(factor, factor))
        for _ in range(min_lvl, max_lvl + 1):
            lateral = nn.Sequential(
                nn.Conv2d(embed_dim, out_dim, kernel_size=1, bias=False),
                nn.SyncBatchNorm(out_dim),
            )
            self.lateral_conv.append(lateral)
            output = nn.Sequential(
                nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1, bias=False),
                nn.SyncBatchNorm(out_dim),
            )
            self.output_conv.append(output)

    def forward(self, inputs):
        """Return the list of pyramid features computed from ``inputs``.

        ``inputs`` is a list of feature maps; when it holds fewer than four
        entries the last one is repeated to fill all four input branches.
        """
        padded = inputs + [inputs[-1]] * (4 - len(inputs))
        resampled = []
        for branch, feature in zip(self.input_conv, padded):
            resampled.append(branch(feature))
        # NOTE(review): the slice start is min_lvl - 1 while branch i is
        # built for level i + 2 -- confirm the intended level/index mapping.
        selected = resampled[self.min_lvl - 1 : self.max_lvl]
        laterals = []
        for lateral, feature in zip(self.lateral_conv, selected):
            laterals.append(lateral(feature))
        outputs = []
        for output, feature in zip(self.output_conv, laterals):
            outputs.append(output(feature))
        return outputs
class ImageEncoderViT(nn.Module):
    """ViT image encoder.

    Tokens are processed by transformer blocks in windowed form. After the
    blocks listed in ``cross_indices`` the tokens are laid back out as a
    feature map and passed through a convolutional ``Bottleneck``. The final
    feature map goes through the ``SimpleFeaturePyramid`` neck.

    Args:
        depth: Number of transformer blocks.
        embed_dim: Token feature size.
        num_heads: Number of attention heads per block.
        mlp_ratio: Hidden-size multiplier of each block's MLP.
        patch_size: Patch size of the patch embedding.
        window_size: Side length of an attention window; a falsy value
            selects ``image_size // patch_size``.
        image_size: Input image size used to size the position embedding.
        out_dim: Channel count of the neck output.
    """

    def __init__(
        self,
        depth,
        embed_dim,
        num_heads,
        mlp_ratio=4,
        patch_size=16,
        window_size=16,
        image_size=1024,
        out_dim=256,
    ):
        super().__init__()
        grid_size = image_size // patch_size
        self.embed_dim = embed_dim
        self.image_size = image_size
        self.window_size = window_size or grid_size
        self.patch_embed = PatchEmbed(embed_dim, patch_size)
        self.pos_embed = PosEmbed(embed_dim, grid_size**2)
        blocks = []
        for _ in range(depth):
            blk = Block(embed_dim, num_heads, mlp_ratio)
            # Swap the identity placeholder for a window-sized relative bias.
            blk.attn.rel_pos_embed = RelPosEmbed(num_heads, self.window_size)
            blocks.append(blk)
        self.blocks = nn.ModuleList(blocks)
        self.norm = nn.LayerNorm(embed_dim)
        self.cross_conv = nn.ModuleList(Bottleneck(embed_dim) for _ in range(4))
        self.neck = SimpleFeaturePyramid(embed_dim, out_dim, patch_size)
        # Every (depth // 4)-th block is followed by a cross-window conv.
        interval = depth // 4
        self.cross_indices = list(range(interval - 1, depth, interval))

    def forward(self, x):
        """Encode image batch ``x`` and return the neck's list of features."""
        tokens = self.pos_embed(self.patch_embed(x))
        windows = space_to_depth(tokens, self.window_size)
        # Shape of the window grid, kept to undo the partition later.
        wmsa_shape = (-1,) + windows.shape[1:]
        # Per-window token layout fed to the attention blocks.
        msa_shape = (-1, self.window_size**2, self.embed_dim)
        x = windows.reshape(msa_shape)
        last = len(self.blocks) - 1
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            is_cross = i in self.cross_indices
            is_last = i == last
            if not (is_cross or is_last):
                continue
            if is_last:
                x = self.norm(x)
            # Merge windows back into one map and move channels to axis 1.
            x = depth_to_space(x.reshape(wmsa_shape), self.window_size)
            x = x.permute(0, 3, 1, 2).contiguous()
            if is_cross:
                x = self.cross_conv[self.cross_indices.index(i)](x)
                if not is_last:
                    # More blocks follow: return to the windowed token layout.
                    x = x.permute(0, 2, 3, 1)
                    x = space_to_depth(x, self.window_size).reshape(msa_shape)
        return self.neck([x])