import pytest
import torch
from mmcv.utils import ConfigDict

from mmdet.models.utils.transformer import (AdaptivePadding,
                                            DetrTransformerDecoder,
                                            DetrTransformerEncoder, PatchEmbed,
                                            PatchMerging, Transformer)


def test_adaptive_padding():
    """Check the padded output size of ``AdaptivePadding``.

    For each of the ``'same'`` and ``'corner'`` padding modes the layer is
    built from a ``(kernel_size, stride, dilation)`` triple, applied to a
    random ``(1, 1, H, W)`` tensor, and the spatial size of the result is
    compared with the expected one. Finally, an integer ``padding`` argument
    is expected to trip an assertion when the layer is constructed.
    """
    # (kernel_size, stride, dilation, input (H, W), expected output (H, W))
    cases = (
        (16, 16, 1, (15, 17), (16, 32)),
        (16, 16, 1, (16, 17), (16, 32)),
        ((2, 2), (2, 2), (1, 1), (11, 13), (12, 14)),
        # The input is expected to come back with its size unchanged.
        ((2, 2), (10, 10), (1, 1), (10, 13), (10, 13)),
        ((11, 11), (10, 10), (1, 1), (11, 13), (21, 21)),
    )

    for padding in ('same', 'corner'):
        for kernel_size, stride, dilation, in_hw, out_hw in cases:
            adap_pad = AdaptivePadding(
                kernel_size=kernel_size,
                stride=stride,
                dilation=dilation,
                padding=padding)
            padded = adap_pad(torch.rand(1, 1, *in_hw))
            assert (padded.shape[2], padded.shape[3]) == out_hw

        # A (4, 5) kernel with dilation 2 spans (7, 9) positions, so it has
        # to pad the very same tensor exactly like a dense (7, 9) kernel.
        feat = torch.rand(1, 1, 11, 13)
        stride = (3, 4)
        dilated_pad = AdaptivePadding(
            kernel_size=(4, 5),
            stride=stride,
            dilation=(2, 2),
            padding=padding)
        dilation_out = dilated_pad(feat)
        assert (dilation_out.shape[2], dilation_out.shape[3]) == (16, 21)

        dense_pad = AdaptivePadding(
            kernel_size=(7, 9),
            stride=stride,
            dilation=(1, 1),
            padding=padding)
        kernel79_out = dense_pad(feat)
        assert (kernel79_out.shape[2], kernel79_out.shape[3]) == (16, 21)
        assert kernel79_out.shape == dilation_out.shape

    # An integer is not an accepted padding mode; construction must fail.
    # The arguments are spelled out instead of relying on loop leftovers.
    with pytest.raises(AssertionError):
        AdaptivePadding(
            kernel_size=(7, 9), stride=(3, 4), dilation=(1, 1), padding=1)


def test_patch_embed():
    """Check output shapes and reported sizes of ``PatchEmbed``.

    ``PatchEmbed`` is called on ``(B, C, H, W)`` tensors and returns a pair
    ``(tokens, out_size)``. The test checks that ``tokens`` has shape
    ``(B, L, embed_dims)``, that ``out_size`` is the expected ``(out_h,
    out_w)`` and that ``L == out_h * out_w``. It covers integer padding with
    and without dilation / normalisation, the ``init_out_size`` attribute
    derived from ``input_size``, and the ``'same'`` / ``'corner'`` padding
    modes.
    """
    B = 2
    C = 3
    embed_dims = 10

    # Kernel 3, stride 1, no dilation, no norm on a 3 x 4 input.
    H, W = 3, 4
    patch_embed_1 = PatchEmbed(
        in_channels=C,
        embed_dims=embed_dims,
        kernel_size=3,
        stride=1,
        padding=0,
        dilation=1,
        norm_cfg=None)
    x1, shape = patch_embed_1(torch.rand(B, C, H, W))
    assert x1.shape == (2, 2, 10)
    assert shape == (1, 2)
    assert shape[0] * shape[1] == x1.shape[1]

    # Same checks with dilation 2 on a 10 x 10 input.
    H, W = 10, 10
    kernel_size = 5
    stride = 2
    dilation = 2
    patch_embed_2 = PatchEmbed(
        in_channels=C,
        embed_dims=embed_dims,
        kernel_size=kernel_size,
        stride=stride,
        padding=0,
        dilation=dilation,
        norm_cfg=None,
    )
    x2, shape = patch_embed_2(torch.rand(B, C, H, W))
    assert x2.shape == (2, 1, 10)
    assert shape == (1, 1)
    assert shape[0] * shape[1] == x2.shape[1]

    # Add a LayerNorm and an explicit ``input_size``.
    input_size = (10, 10)
    patch_embed_3 = PatchEmbed(
        in_channels=C,
        embed_dims=embed_dims,
        kernel_size=kernel_size,
        stride=stride,
        padding=0,
        dilation=dilation,
        norm_cfg=dict(type='LN'),
        input_size=input_size)
    x3, shape = patch_embed_3(torch.rand(B, C, H, W))
    assert x3.shape == (2, 1, 10)
    assert shape == (1, 1)
    assert shape[0] * shape[1] == x3.shape[1]

    # ``init_out_size`` must follow the usual zero-padding output-size
    # arithmetic. Each axis is checked against its OWN input extent (the
    # previous version compared the width against ``input_size[0]``, which
    # only worked because the input happened to be square).
    for axis in (0, 1):
        expected = (input_size[axis] - dilation *
                    (kernel_size - 1) - 1) // stride + 1
        assert patch_embed_3.init_out_size[axis] == expected

    # When ``input_size`` equals the real input size, the size returned by
    # the forward pass has to agree with ``init_out_size``.
    H, W = 11, 12
    input_size = (H, W)
    patch_embed_4 = PatchEmbed(
        in_channels=C,
        embed_dims=embed_dims,
        kernel_size=kernel_size,
        stride=stride,
        padding=0,
        dilation=dilation,
        norm_cfg=dict(type='LN'),
        input_size=input_size)
    _, shape = patch_embed_4(torch.rand(B, C, H, W))
    assert shape == patch_embed_4.init_out_size

    # Adaptive padding modes.
    # (input (H, W), kernel_size, stride, expected L, expected out_size)
    adaptive_cases = (
        ((5, 5), (5, 5), (1, 1), 25, (5, 5)),
        ((5, 5), (5, 5), (5, 5), 1, (1, 1)),
        ((6, 5), (5, 5), (5, 5), 2, (2, 1)),
        ((6, 5), (6, 2), (6, 2), 3, (1, 3)),
    )
    in_c = 2
    embed_dims = 3
    for padding in ('same', 'corner'):
        for in_hw, kernel_size, stride, seq_len, expected in adaptive_cases:
            patch_embed = PatchEmbed(
                in_channels=in_c,
                embed_dims=embed_dims,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=1,
                bias=False)
            x_out, out_size = patch_embed(torch.rand(B, in_c, *in_hw))
            assert x_out.size() == (B, seq_len, embed_dims)
            assert out_size == expected
            assert x_out.size(1) == out_size[0] * out_size[1]


def test_patch_merging():
    """Check output shapes and reported sizes of ``PatchMerging``.

    ``PatchMerging`` is called as ``layer(x, input_size)`` with ``x`` of
    shape ``(B, H * W, in_channels)`` and returns ``(tokens, out_size)``.
    The test checks that ``tokens`` has shape ``(B, L, out_channels)``, that
    ``out_size`` is the expected ``(out_h, out_w)`` and that
    ``L == out_h * out_w``, for integer padding as well as the ``'same'`` /
    ``'corner'`` padding modes.
    """
    input_size = (10, 10)
    seq_len = input_size[0] * input_size[1]

    # Integer padding.
    # (in_c, out_c, kernel_size, stride, padding, dilation,
    #  expected token shape, expected out_size)
    int_padding_cases = (
        (3, 4, 3, 3, 1, 1, (1, 16, 4), (4, 4)),
        (4, 5, 6, 3, 2, 2, (1, 4, 5), (2, 2)),
    )
    for (in_c, out_c, kernel_size, stride, padding, dilation, out_shape,
         expected) in int_padding_cases:
        patch_merge = PatchMerging(
            in_channels=in_c,
            out_channels=out_c,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=False)
        x_out, out_size = patch_merge(
            torch.rand(1, seq_len, in_c), input_size)
        assert x_out.size() == out_shape
        assert out_size == expected
        assert x_out.size(1) == out_size[0] * out_size[1]

    # Adaptive padding modes.
    # (input (H, W), kernel_size, stride, expected L, expected out_size)
    adaptive_cases = (
        ((5, 5), (5, 5), (1, 1), 25, (5, 5)),
        ((5, 5), (5, 5), (5, 5), 1, (1, 1)),
        ((6, 5), (5, 5), (5, 5), 2, (2, 1)),
        ((6, 5), (6, 2), (6, 2), 3, (1, 3)),
    )
    in_c = 2
    out_c = 3
    B = 2
    for padding in ('same', 'corner'):
        for in_hw, kernel_size, stride, out_len, expected in adaptive_cases:
            patch_merge = PatchMerging(
                in_channels=in_c,
                out_channels=out_c,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=1,
                bias=False)
            x = torch.rand(B, in_hw[0] * in_hw[1], in_c)
            x_out, out_size = patch_merge(x, in_hw)
            assert x_out.size() == (B, out_len, out_c)
            assert out_size == expected
            assert x_out.size(1) == out_size[0] * out_size[1]


def test_detr_transformer_dencoder_encoder_layer():
    """Check construction and config validation of the DETR en-/decoder.

    Three things are verified:

    * a decoder whose layers use a ``'norm'``-first operation order is built
      with ``pre_norm`` set on its layers and has ``num_layers`` layers;
    * a decoder given five layer configs together with ``num_layers=6``
      raises ``AssertionError``;
    * an encoder given the seven-step operation order below raises
      ``AssertionError``.
    """

    def layer_cfg(operation_order):
        """Return the shared layer config with the given operation order."""
        return dict(
            type='DetrTransformerDecoderLayer',
            attn_cfgs=dict(
                type='MultiheadAttention',
                embed_dims=256,
                num_heads=8,
                dropout=0.1),
            feedforward_channels=2048,
            ffn_dropout=0.1,
            operation_order=operation_order)

    config = ConfigDict(
        dict(
            return_intermediate=True,
            num_layers=6,
            transformerlayers=layer_cfg((
                'norm',
                'self_attn',
                'norm',
                'cross_attn',
                'norm',
                'ffn',
            ))))
    # Build the decoder once and run every check against that instance.
    decoder = DetrTransformerDecoder(**config)
    assert decoder.layers[0].pre_norm
    assert len(decoder.layers) == 6

    # The config is created outside ``pytest.raises`` so that only an
    # assertion raised by the decoder constructor can satisfy the check.
    config = ConfigDict(
        dict(
            return_intermediate=True,
            num_layers=6,
            transformerlayers=[
                layer_cfg(('self_attn', 'norm', 'cross_attn', 'norm', 'ffn',
                           'norm'))
            ] * 5))
    with pytest.raises(AssertionError):
        DetrTransformerDecoder(**config)

    config = ConfigDict(
        dict(
            num_layers=6,
            transformerlayers=layer_cfg(
                ('norm', 'self_attn', 'norm', 'cross_attn', 'norm', 'ffn',
                 'norm'))))
    with pytest.raises(AssertionError):
        DetrTransformerEncoder(**config)


def test_transformer():
    """Smoke-test building a full ``Transformer`` and initialising it.

    The test passes when ``Transformer`` can be constructed from a config
    holding a six-layer encoder and a six-layer decoder and
    ``init_weights()`` runs without raising; no outputs are inspected.
    """
    encoder_cfg = dict(
        type='DetrTransformerEncoder',
        num_layers=6,
        transformerlayers=dict(
            type='BaseTransformerLayer',
            # The encoder passes its attention config as a one-item list.
            attn_cfgs=[
                dict(
                    type='MultiheadAttention',
                    embed_dims=256,
                    num_heads=8,
                    dropout=0.1)
            ],
            feedforward_channels=2048,
            ffn_dropout=0.1,
            operation_order=('self_attn', 'norm', 'ffn', 'norm')))

    decoder_cfg = dict(
        type='DetrTransformerDecoder',
        return_intermediate=True,
        num_layers=6,
        transformerlayers=dict(
            type='DetrTransformerDecoderLayer',
            # The decoder passes a single attention config as a plain dict.
            attn_cfgs=dict(
                type='MultiheadAttention',
                embed_dims=256,
                num_heads=8,
                dropout=0.1),
            feedforward_channels=2048,
            ffn_dropout=0.1,
            operation_order=('self_attn', 'norm', 'cross_attn', 'norm',
                             'ffn', 'norm')),
    )

    config = ConfigDict(dict(encoder=encoder_cfg, decoder=decoder_cfg))
    transformer = Transformer(**config)
    transformer.init_weights()