File size: 3,332 Bytes
3bbb319 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 |
import pytest
import torch
from mmdet.models.backbones.pvt import (PVTEncoderLayer,
PyramidVisionTransformer,
PyramidVisionTransformerV2)
def test_pvt_block():
    """Check PVTEncoderLayer sub-module config and forward output shape."""
    dims, heads, hidden = 64, 4, 256
    layer = PVTEncoderLayer(
        embed_dims=dims, num_heads=heads, feedforward_channels=hidden)

    # Constructor arguments must reach the attention / FFN sub-modules.
    assert layer.ffn.embed_dims == dims
    assert layer.attn.num_heads == heads
    assert layer.ffn.feedforward_channels == hidden

    # Feed a flattened 56x56 token sequence together with its spatial size;
    # the layer must return a tensor of the same shape.
    side = 56
    tokens = torch.randn(1, side * side, dims)
    result = layer(tokens, (side, side))
    assert result.shape == torch.Size([1, side * side, dims])
def test_pvt():
    """Test PVT backbone argument checks and per-stage output shapes."""
    # Output channels of the four stages produced by the default config.
    channels = (64, 128, 320, 512)

    def _assert_stage_shapes(outs, sizes):
        """Assert ``outs`` has one (1, C, H, W) map per stage of ``sizes``."""
        # Without the length check, zip() would silently ignore extra or
        # missing stages.
        assert len(outs) == len(channels)
        for out, c, (h, w) in zip(outs, channels, sizes):
            assert out.shape == (1, c, h, w)

    with pytest.raises(TypeError):
        # Pretrained arg must be str or None.
        PyramidVisionTransformer(pretrained=123)

    # test pretrained image size: a 3-element size is rejected
    with pytest.raises(AssertionError):
        PyramidVisionTransformer(pretrain_img_size=(224, 224, 224))

    # Test absolute position embedding. The forward result was previously
    # unchecked; verify the stage shapes for a 224x224 input.
    temp = torch.randn((1, 3, 224, 224))
    model = PyramidVisionTransformer(
        pretrain_img_size=224, use_abs_pos_embed=True)
    model.init_weights()
    outs = model(temp)
    _assert_stage_shapes(outs, [(56, 56), (28, 28), (14, 14), (7, 7)])

    # Test normal inference (input evenly divisible by 32)
    temp = torch.randn((1, 3, 32, 32))
    model = PyramidVisionTransformer()
    outs = model(temp)
    _assert_stage_shapes(outs, [(8, 8), (4, 4), (2, 2), (1, 1)])

    # Test odd square input size: expected sizes are floored
    temp = torch.randn((1, 3, 33, 33))
    model = PyramidVisionTransformer()
    outs = model(temp)
    _assert_stage_shapes(outs, [(8, 8), (4, 4), (2, 2), (1, 1)])

    # Test non-square input size not divisible by 32
    temp = torch.randn((1, 3, 112, 137))
    model = PyramidVisionTransformer()
    outs = model(temp)
    _assert_stage_shapes(outs, [(28, 34), (14, 17), (7, 8), (3, 4)])
def test_pvtv2():
    """Exercise PVTv2 argument validation and per-stage output shapes."""
    # A ``pretrained`` value that is neither str nor None is rejected.
    with pytest.raises(TypeError):
        PyramidVisionTransformerV2(pretrained=123)
    # A 3-element ``pretrain_img_size`` is rejected.
    with pytest.raises(AssertionError):
        PyramidVisionTransformerV2(pretrain_img_size=(224, 224, 224))

    channels = (64, 128, 320, 512)
    # (input H, input W) -> expected (H, W) of each of the four stages.
    # Covers a size divisible by 32, an odd square size and a non-square one.
    cases = [
        ((32, 32), [(8, 8), (4, 4), (2, 2), (1, 1)]),
        ((31, 31), [(8, 8), (4, 4), (2, 2), (1, 1)]),
        ((112, 137), [(28, 35), (14, 18), (7, 9), (4, 5)]),
    ]
    for (in_h, in_w), stage_sizes in cases:
        image = torch.randn((1, 3, in_h, in_w))
        backbone = PyramidVisionTransformerV2()
        features = backbone(image)
        for idx, (c, (h, w)) in enumerate(zip(channels, stage_sizes)):
            assert features[idx].shape == (1, c, h, w)
|