camenduru's picture
thanks to show ❤
3bbb319
raw
history blame contribute delete
No virus
2.93 kB
from unittest.mock import patch
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmdet.models.utils import AdaptiveAvgPool2d, adaptive_avg_pool2d
# Version string that the tests in this module patch into
# ``torch.__version__``: a 'parrots' build keeps its own marker, every other
# build is presented as '1.7'.
torch_version = 'parrots' if torch.__version__ == 'parrots' else '1.7'
@patch('torch.__version__', torch_version)
def test_adaptive_avg_pool2d():
    """Check the functional ``adaptive_avg_pool2d`` wrapper.

    Runs with ``torch.__version__`` patched to ``torch_version``. For an
    empty batch only the output shape is asserted; for a non-empty batch the
    output must additionally be equal to ``F.adaptive_avg_pool2d``.
    """
    # The two forms of ``output_size`` exercised here:
    # 1. tuple[int, int]
    # 2. int
    output_sizes = ((2, 2), 2)

    # Test the empty batch dimension
    x_empty = torch.randn(0, 3, 4, 5)
    for output_size in output_sizes:
        wrapper_out = adaptive_avg_pool2d(x_empty, output_size)
        assert wrapper_out.shape == (0, 3, 2, 2)

    # Test the normal (non-empty) batch dimension against the torch op
    x_normal = torch.randn(3, 3, 4, 5)
    for output_size in output_sizes:
        wrapper_out = adaptive_avg_pool2d(x_normal, output_size)
        ref_out = F.adaptive_avg_pool2d(x_normal, output_size)
        assert wrapper_out.shape == (3, 3, 2, 2)
        assert torch.equal(wrapper_out, ref_out)
@patch('torch.__version__', torch_version)
def test_AdaptiveAvgPool2d():
    """Check the ``AdaptiveAvgPool2d`` module wrapper.

    Runs with ``torch.__version__`` patched to ``torch_version``. For an
    empty batch only the output shape is asserted; for a non-empty batch the
    output must additionally be equal to ``nn.AdaptiveAvgPool2d`` built with
    the same ``output_size``.
    """
    # Each case is (output_size, expected (H, W) of the output). The inputs
    # below are 4x5 spatially, and the expected sizes show that a ``None``
    # entry keeps the corresponding input size.
    cases = (
        ((2, 2), (2, 2)),  # 1. tuple[int, int]
        (2, (2, 2)),  # 2. int
        ((None, 2), (4, 2)),  # 3. tuple[None, int]
        ((2, None), (2, 5)),  # 4. tuple[int, None]
    )

    # Test the empty batch dimension
    x_empty = torch.randn(0, 3, 4, 5)
    for output_size, expected_hw in cases:
        wrapper = AdaptiveAvgPool2d(output_size)
        wrapper_out = wrapper(x_empty)
        assert wrapper_out.shape == (0, 3, *expected_hw)

    # Test the normal batch dimension against the torch module
    x_normal = torch.randn(3, 3, 4, 5)
    for output_size, expected_hw in cases:
        wrapper = AdaptiveAvgPool2d(output_size)
        ref = nn.AdaptiveAvgPool2d(output_size)
        wrapper_out = wrapper(x_normal)
        ref_out = ref(x_normal)
        assert wrapper_out.shape == (3, 3, *expected_hw)
        assert torch.equal(wrapper_out, ref_out)