|
|
|
import pytest |
|
import torch |
|
from mmcv.cnn import is_norm |
|
from torch.nn.modules import GroupNorm |
|
|
|
from mmdet.models.utils import InvertedResidual, SELayer |
|
|
|
|
|
def test_inverted_residual():
    """Exercise InvertedResidual argument validation and forward behaviour.

    Covers: rejected constructor arguments, stride-1 (residual) and stride-2
    (downsampling) paths, the optional SE layer, the no-expand-conv variant,
    custom norm/activation configs, and gradient checkpointing.
    """
    # stride=3 is rejected; only stride=1 and stride=2 are exercised below.
    with pytest.raises(AssertionError):
        InvertedResidual(16, 16, 32, stride=3)

    # se_cfg must be a dict (a list is rejected; dict(channels=...) works).
    with pytest.raises(AssertionError):
        InvertedResidual(16, 16, 32, se_cfg=list())

    # Without the expand conv, construction fails when in_channels (16)
    # differs from mid_channels (32); presumably they must be equal, as in
    # the passing (32, 16, 32) case further down.
    with pytest.raises(AssertionError):
        InvertedResidual(16, 16, 32, with_expand_conv=False)

    # stride=1 with in == out channels: residual shortcut, no SE layer,
    # spatial size preserved.
    block = InvertedResidual(16, 16, 32, stride=1)
    x = torch.randn(1, 16, 56, 56)
    x_out = block(x)
    assert getattr(block, 'se', None) is None
    assert block.with_res_shortcut
    assert x_out.shape == torch.Size((1, 16, 56, 56))

    # stride=2: no residual shortcut and spatial size is halved.
    block = InvertedResidual(16, 16, 32, stride=2)
    x = torch.randn(1, 16, 56, 56)
    x_out = block(x)
    assert not block.with_res_shortcut
    assert x_out.shape == torch.Size((1, 16, 28, 28))

    # se_cfg given: an SELayer is attached as `block.se`.
    se_cfg = dict(channels=32)
    block = InvertedResidual(16, 16, 32, stride=1, se_cfg=se_cfg)
    x = torch.randn(1, 16, 56, 56)
    x_out = block(x)
    assert isinstance(block.se, SELayer)
    assert x_out.shape == torch.Size((1, 16, 56, 56))

    # with_expand_conv=False: no `expand_conv` submodule; in == mid channels.
    block = InvertedResidual(32, 16, 32, with_expand_conv=False)
    x = torch.randn(1, 32, 56, 56)
    x_out = block(x)
    assert getattr(block, 'expand_conv', None) is None
    assert x_out.shape == torch.Size((1, 16, 56, 56))

    # Custom norm_cfg: every norm layer in the block must be a GroupNorm.
    block = InvertedResidual(
        16, 16, 32, norm_cfg=dict(type='GN', num_groups=2))
    x = torch.randn(1, 16, 56, 56)
    x_out = block(x)
    for m in block.modules():
        if is_norm(m):
            assert isinstance(m, GroupNorm)
    assert x_out.shape == torch.Size((1, 16, 56, 56))

    # Custom act_cfg: the block builds and runs with an HSigmoid activation.
    block = InvertedResidual(16, 16, 32, act_cfg=dict(type='HSigmoid'))
    x = torch.randn(1, 16, 56, 56)
    x_out = block(x)
    assert x_out.shape == torch.Size((1, 16, 56, 56))

    # Gradient checkpointing. The input must require grad: mmdet's
    # InvertedResidual only routes through torch.utils.checkpoint when
    # `with_cp and x.requires_grad` (NOTE(review): confirm against the
    # current implementation). Without it this case silently tests the
    # plain forward path.
    block = InvertedResidual(16, 16, 32, with_cp=True)
    x = torch.randn(1, 16, 56, 56, requires_grad=True)
    x_out = block(x)
    assert block.with_cp
    assert x_out.shape == torch.Size((1, 16, 56, 56))
    # Backprop through the checkpointed graph to make sure recomputation works.
    x_out.sum().backward()
    assert x.grad is not None
    assert x.grad.shape == x.shape
|
|