camenduru's picture
thanks to show ❤
3bbb319
raw
history blame contribute delete
No virus
1.44 kB
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmdet.models.utils import (LearnedPositionalEncoding,
SinePositionalEncoding)
def test_sine_positional_encoding(num_feats=16, batch_size=2):
    """Exercise ``SinePositionalEncoding`` construction and forward pass.

    The test checks three things:

    * constructing the module with a tuple ``scale`` is expected to raise
      ``AssertionError``;
    * a default-constructed module reports ``normalize`` as falsy and maps a
      ``(batch_size, h, w)`` integer mask to an output of shape
      ``(batch_size, num_feats * 2, h, w)``;
    * a module built with ``normalize=True`` reports ``normalize`` as truthy
      and produces an output of the same shape.

    Args:
        num_feats: Feature count passed to the encoding module.
        batch_size: Number of masks in the random input batch.
    """
    # An invalid ``scale`` type (a tuple here) must be rejected.
    with pytest.raises(AssertionError):
        SinePositionalEncoding(num_feats, scale=(3., ), normalize=True)

    height, width = 10, 6
    expected_shape = (batch_size, num_feats * 2, height, width)

    # Default configuration: normalization disabled.
    encoder = SinePositionalEncoding(num_feats)
    random_bits = torch.rand(batch_size, height, width) > 0.5
    mask = random_bits.to(torch.int)
    assert not encoder.normalize
    encoding = encoder(mask)
    assert encoding.shape == expected_shape

    # Same mask, this time with normalization switched on.
    encoder = SinePositionalEncoding(num_feats, normalize=True)
    assert encoder.normalize
    encoding = encoder(mask)
    assert encoding.shape == expected_shape
def test_learned_positional_encoding(num_feats=16,
                                     row_num_embed=10,
                                     col_num_embed=10,
                                     batch_size=2):
    """Exercise ``LearnedPositionalEncoding`` construction and forward pass.

    Verifies that the row and column embedding weights have shapes
    ``(row_num_embed, num_feats)`` and ``(col_num_embed, num_feats)``, and
    that a random boolean mask of shape ``(batch_size, h, w)`` yields an
    output of shape ``(batch_size, num_feats * 2, h, w)``.

    Args:
        num_feats: Feature count passed to the encoding module.
        row_num_embed: Number of row embeddings passed to the module.
        col_num_embed: Number of column embeddings passed to the module.
        batch_size: Number of masks in the random input batch.
    """
    encoder = LearnedPositionalEncoding(num_feats, row_num_embed,
                                        col_num_embed)

    # Check the row table first, then the column table.
    embedding_specs = (
        (encoder.row_embed, row_num_embed),
        (encoder.col_embed, col_num_embed),
    )
    for embedding, num_embed in embedding_specs:
        assert embedding.weight.shape == (num_embed, num_feats)

    height, width = 10, 6
    mask = torch.rand(batch_size, height, width) > 0.5
    encoding = encoder(mask)
    assert encoding.shape == (batch_size, num_feats * 2, height, width)