|
|
|
import pytest |
|
import torch |
|
|
|
from mmdet.models.backbones import ResNeXt |
|
from mmdet.models.backbones.resnext import Bottleneck as BottleneckX |
|
from .utils import is_block |
|
|
|
|
|
def test_renext_bottleneck():
    """Exercise the ResNeXt ``Bottleneck`` block (imported as ``BottleneckX``).

    Checks, in order: rejection of an unsupported ``style``, the attributes
    of the grouped ``conv2`` layer, the conflict between ``dcn`` and
    ``conv_cfg``, and the output shape of a forward pass both without and
    with plugins.
    """
    # style='tensorflow' is expected to be rejected with an AssertionError.
    with pytest.raises(AssertionError):
        BottleneckX(64, 64, groups=32, base_width=4, style='tensorflow')

    # Structure: the stride and groups given to the block must show up on
    # conv2, and conv2 must have 128 output channels for this configuration.
    block = BottleneckX(
        64, 64, groups=32, base_width=4, stride=2, style='pytorch')
    assert block.conv2.stride == (2, 2)
    assert block.conv2.groups == 32
    assert block.conv2.out_channels == 128

    # DCN: supplying both dcn and conv_cfg is expected to raise an
    # AssertionError ...
    dcn = dict(type='DCN', deform_groups=1, fallback_on_stride=False)
    with pytest.raises(AssertionError):
        BottleneckX(
            64,
            64,
            groups=32,
            base_width=4,
            dcn=dcn,
            conv_cfg=dict(type='Conv'))
    # ... while dcn on its own must construct without error. This call sits
    # outside the ``raises`` block: inside it, it would never be reached.
    BottleneckX(64, 64, dcn=dcn)

    # Forward pass: the output must have the same shape as the input.
    block = BottleneckX(64, 16, groups=32, base_width=4)
    x = torch.randn(1, 64, 56, 56)
    x_out = block(x)
    assert x_out.shape == torch.Size([1, 64, 56, 56])

    # Forward pass with a GeneralizedAttention plugin positioned after conv2;
    # the output shape must still match the input.
    plugins = [
        dict(
            cfg=dict(
                type='GeneralizedAttention',
                spatial_range=-1,
                num_heads=8,
                attention_type='0010',
                kv_stride=2),
            stages=(False, False, True, True),
            position='after_conv2')
    ]
    block = BottleneckX(64, 16, groups=32, base_width=4, plugins=plugins)
    x = torch.randn(1, 64, 56, 56)
    x_out = block(x)
    assert x_out.shape == torch.Size([1, 64, 56, 56])
|
|
|
|
|
def test_resnext_backbone():
    """Check ``ResNeXt`` construction and the feature shapes of a forward pass.

    Verifies that ``depth=18`` is rejected, that every block reported by
    ``is_block`` uses the requested group count on ``conv2``, and that a
    1x3x32x32 input produces four feature maps with the asserted shapes.
    """
    # depth=18 is expected to be rejected with a KeyError.
    with pytest.raises(KeyError):
        ResNeXt(depth=18)

    # Build with groups=32, base_width=4 and confirm the group count reaches
    # conv2 of every module that is_block() accepts.
    model = ResNeXt(depth=50, groups=32, base_width=4)
    for m in model.modules():
        if is_block(m):
            assert m.conv2.groups == 32
    model.train()

    # Forward a small random batch and check the four returned feature maps.
    imgs = torch.randn(1, 3, 32, 32)
    feat = model(imgs)
    assert len(feat) == 4
    assert feat[0].shape == torch.Size([1, 256, 8, 8])
    assert feat[1].shape == torch.Size([1, 512, 4, 4])
    assert feat[2].shape == torch.Size([1, 1024, 2, 2])
    assert feat[3].shape == torch.Size([1, 2048, 1, 1])
|
|
|
|
|
# Table of RegNet test cases. Each entry is a 3-tuple:
#   (architecture name, dict of architecture parameters, list of four ints).
# NOTE(review): the trailing list is presumably the expected per-stage output
# channel counts -- confirm against the test that consumes this table.
regnet_test_data = [
    ('regnetx_400mf',
     dict(w0=24, wa=24.48, wm=2.54, group_w=16, depth=22,
          bot_mul=1.0), [32, 64, 160, 384]),
    ('regnetx_800mf',
     dict(w0=56, wa=35.73, wm=2.28, group_w=16, depth=16,
          bot_mul=1.0), [64, 128, 288, 672]),
    ('regnetx_1.6gf',
     dict(w0=80, wa=34.01, wm=2.25, group_w=24, depth=18,
          bot_mul=1.0), [72, 168, 408, 912]),
    ('regnetx_3.2gf',
     dict(w0=88, wa=26.31, wm=2.25, group_w=48, depth=25,
          bot_mul=1.0), [96, 192, 432, 1008]),
    ('regnetx_4.0gf',
     dict(w0=96, wa=38.65, wm=2.43, group_w=40, depth=23,
          bot_mul=1.0), [80, 240, 560, 1360]),
    ('regnetx_6.4gf',
     dict(w0=184, wa=60.83, wm=2.07, group_w=56, depth=17,
          bot_mul=1.0), [168, 392, 784, 1624]),
    ('regnetx_8.0gf',
     dict(w0=80, wa=49.56, wm=2.88, group_w=120, depth=23,
          bot_mul=1.0), [80, 240, 720, 1920]),
    ('regnetx_12gf',
     dict(w0=168, wa=73.36, wm=2.37, group_w=112, depth=19,
          bot_mul=1.0), [224, 448, 896, 2240]),
]
|
|