|
from unittest.mock import patch |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
from mmdet.models.utils import AdaptiveAvgPool2d, adaptive_avg_pool2d |
|
|
|
# Version string that the tests below patch into ``torch.__version__``.
# On parrots builds the real value is kept; otherwise torch 1.7 is faked.
# NOTE(review): presumably '1.7' is chosen so the mmdet wrappers take their
# legacy empty-batch code path. This only has an effect if the wrappers read
# ``torch.__version__`` at call time rather than at import time — confirm
# against mmdet.models.utils.
torch_version = 'parrots' if torch.__version__ == 'parrots' else '1.7'
|
|
|
|
|
@patch('torch.__version__', torch_version)
def test_adaptive_avg_pool2d():
    """Check the functional ``adaptive_avg_pool2d`` wrapper.

    Runs with ``torch.__version__`` patched to ``torch_version``. Both output
    size forms, a ``(h, w)`` tuple and a single ``int``, are checked on:

    * an input with an empty batch dimension, where only the output shape
      is asserted;
    * a regular input, where the result must also match
      ``torch.nn.functional.adaptive_avg_pool2d`` exactly.
    """
    size_variants = ((2, 2), 2)

    # Empty batch: the batch and channel dims are kept and the spatial
    # dims are pooled to 2x2.
    x_empty = torch.randn(0, 3, 4, 5)
    for out_size in size_variants:
        pooled = adaptive_avg_pool2d(x_empty, out_size)
        assert pooled.shape == (0, 3, 2, 2)

    # Non-empty batch: the wrapper must agree bit-for-bit with PyTorch.
    x_normal = torch.randn(3, 3, 4, 5)
    for out_size in size_variants:
        pooled = adaptive_avg_pool2d(x_normal, out_size)
        expected = F.adaptive_avg_pool2d(x_normal, out_size)
        assert pooled.shape == (3, 3, 2, 2)
        assert torch.equal(pooled, expected)
|
|
|
|
|
@patch('torch.__version__', torch_version)
def test_AdaptiveAvgPool2d():
    """Check the ``AdaptiveAvgPool2d`` module wrapper.

    Runs with ``torch.__version__`` patched to ``torch_version``. Each output
    size form is exercised: a ``(h, w)`` tuple, a single ``int``, and tuples
    with ``None`` in either position (meaning "keep that input dimension").
    The wrapper is checked on an input with an empty batch dimension (shape
    only) and on a regular input, where it must match
    ``torch.nn.AdaptiveAvgPool2d`` exactly.
    """
    # (output_size given to the module, expected (h, w) for a 4x5 input)
    cases = (
        ((2, 2), (2, 2)),
        (2, (2, 2)),
        ((None, 2), (4, 2)),
        ((2, None), (2, 5)),
    )

    # Empty batch: only the output shape can be verified.
    x_empty = torch.randn(0, 3, 4, 5)
    for out_size, (exp_h, exp_w) in cases:
        module = AdaptiveAvgPool2d(out_size)
        pooled = module(x_empty)
        assert pooled.shape == (0, 3, exp_h, exp_w)

    # Non-empty batch: compare against the stock PyTorch module.
    x_normal = torch.randn(3, 3, 4, 5)
    for out_size, (exp_h, exp_w) in cases:
        module = AdaptiveAvgPool2d(out_size)
        reference = nn.AdaptiveAvgPool2d(out_size)
        pooled = module(x_normal)
        expected = reference(x_normal)
        assert pooled.shape == (3, 3, exp_h, exp_w)
        assert torch.equal(pooled, expected)
|
|