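# File-viewer residue captured together with this source, kept as a comment so
# the module stays importable:
#   camenduru's picture | thanks to show ❤ | 3bbb319 | raw |
#   history blame contribute delete | No virus | 3.33 kB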
import pytest
import torch
from mmdet.models.backbones.pvt import (PVTEncoderLayer,
PyramidVisionTransformer,
PyramidVisionTransformerV2)
def test_pvt_block():
    """Check PVTEncoderLayer hyper-parameter wiring and forward output shape."""
    dims, heads, ffn_channels = 64, 4, 256
    height = width = 56

    layer = PVTEncoderLayer(
        embed_dims=dims, num_heads=heads, feedforward_channels=ffn_channels)

    # The sub-modules must carry the values passed to the constructor.
    assert layer.ffn.embed_dims == dims
    assert layer.attn.num_heads == heads
    assert layer.ffn.feedforward_channels == ffn_channels

    # Input is built as (1, H * W, embed_dims) and the spatial size (H, W)
    # is passed alongside; the layer must return the same shape.
    tokens = torch.randn(1, height * width, dims)
    result = layer(tokens, (height, width))
    assert result.shape == torch.Size([1, height * width, dims])
def test_pvt():
    """Test PVT backbone.

    Covers constructor argument validation, the absolute position embedding
    path, and the per-stage output shapes for several input sizes, including
    sizes that are not multiples of the overall downsampling factor.
    """
    # Pretrained arg must be str or None.
    with pytest.raises(TypeError):
        PyramidVisionTransformer(pretrained=123)

    # A pretrain_img_size with three entries must be rejected.
    with pytest.raises(AssertionError):
        PyramidVisionTransformer(pretrain_img_size=(224, 224, 224))

    # Output channels of the four stages, as expected by every shape check
    # in this test.
    channels = (64, 128, 320, 512)

    # Test absolute position embedding. Check the stage shapes so this path
    # is actually verified rather than only executed.
    temp = torch.randn((1, 3, 224, 224))
    model = PyramidVisionTransformer(
        pretrain_img_size=224, use_abs_pos_embed=True)
    model.init_weights()
    outs = model(temp)
    for idx, (ch, size) in enumerate(zip(channels, (56, 28, 14, 7))):
        assert outs[idx].shape == (1, ch, size, size)

    # Each case: (input H, input W) -> expected (H, W) of the four stage
    # outputs.
    cases = [
        # Normal inference.
        ((32, 32), [(8, 8), (4, 4), (2, 2), (1, 1)]),
        # Abnormal size: square input, not a multiple of 32.
        ((33, 33), [(8, 8), (4, 4), (2, 2), (1, 1)]),
        # Abnormal size: non-square input; sizes round down at each stage.
        ((112, 137), [(28, 34), (14, 17), (7, 8), (3, 4)]),
    ]
    for (in_h, in_w), stage_sizes in cases:
        temp = torch.randn((1, 3, in_h, in_w))
        model = PyramidVisionTransformer()
        outs = model(temp)
        for idx, (ch, (out_h, out_w)) in enumerate(zip(channels, stage_sizes)):
            assert outs[idx].shape == (1, ch, out_h, out_w)
def test_pvtv2():
    """Test PVTv2 backbone: argument validation and per-stage output shapes."""
    # A non-str, non-None ``pretrained`` value must raise TypeError.
    with pytest.raises(TypeError):
        PyramidVisionTransformerV2(pretrained=123)

    # A three-element ``pretrain_img_size`` must fail the size assertion.
    with pytest.raises(AssertionError):
        PyramidVisionTransformerV2(pretrain_img_size=(224, 224, 224))

    stage_channels = (64, 128, 320, 512)
    # (input height, input width) paired with the expected (H, W) of each of
    # the four stage outputs.
    scenarios = (
        ((32, 32), ((8, 8), (4, 4), (2, 2), (1, 1))),       # normal size
        ((31, 31), ((8, 8), (4, 4), (2, 2), (1, 1))),       # abnormal size
        ((112, 137), ((28, 35), (14, 18), (7, 9), (4, 5))),  # abnormal size
    )
    for (img_h, img_w), expected_sizes in scenarios:
        image = torch.randn((1, 3, img_h, img_w))
        backbone = PyramidVisionTransformerV2()
        features = backbone(image)
        for stage, channel in enumerate(stage_channels):
            exp_h, exp_w = expected_sizes[stage]
            assert features[stage].shape == (1, channel, exp_h, exp_w)