from unittest.mock import patch

import torch
import torch.nn as nn
import torch.nn.functional as F

from mmdet.models.utils import AdaptiveAvgPool2d, adaptive_avg_pool2d

# Value that the tests below patch into ``torch.__version__``: '1.7' unless
# torch already reports itself as 'parrots', which is kept as-is.
# NOTE(review): presumably this steers the wrappers onto their old-PyTorch
# code path -- confirm against mmdet.models.utils.
torch_version = '1.7' if torch.__version__ != 'parrots' else 'parrots'


@patch('torch.__version__', torch_version)
def test_adaptive_avg_pool2d():
    """Check the functional ``adaptive_avg_pool2d`` wrapper.

    For a zero-sized batch the wrapper must return a tensor of the expected
    shape; for a regular batch it must additionally equal the output of
    ``F.adaptive_avg_pool2d``. Both accepted ``output_size`` forms are
    covered: ``tuple[int, int]`` and ``int``.
    """
    output_sizes = ((2, 2), 2)

    # Empty batch dimension: only the output shape can be checked.
    empty_batch = torch.randn(0, 3, 4, 5)
    for output_size in output_sizes:
        pooled = adaptive_avg_pool2d(empty_batch, output_size)
        assert pooled.shape == (0, 3, 2, 2)

    # Normal batch dimension: compare against the torch reference op.
    full_batch = torch.randn(3, 3, 4, 5)
    for output_size in output_sizes:
        pooled = adaptive_avg_pool2d(full_batch, output_size)
        expected = F.adaptive_avg_pool2d(full_batch, output_size)
        assert pooled.shape == (3, 3, 2, 2)
        assert torch.equal(pooled, expected)


@patch('torch.__version__', torch_version)
def test_AdaptiveAvgPool2d():
    """Check the ``AdaptiveAvgPool2d`` module wrapper.

    For a zero-sized batch the wrapper must return a tensor of the expected
    shape; for a regular batch it must additionally equal the output of
    ``nn.AdaptiveAvgPool2d``. Four ``output_size`` forms are covered.
    """
    # (output_size, resulting spatial shape) for a 4 x 5 input:
    #   1. tuple[int, int]
    #   2. int
    #   3. tuple[None, int] -- height is kept
    #   4. tuple[int, None] -- width is kept
    cases = (
        ((2, 2), (2, 2)),
        (2, (2, 2)),
        ((None, 2), (4, 2)),
        ((2, None), (2, 5)),
    )

    # Empty batch dimension: only the output shape can be checked.
    empty_batch = torch.randn(0, 3, 4, 5)
    for output_size, spatial_shape in cases:
        wrapper = AdaptiveAvgPool2d(output_size)
        pooled = wrapper(empty_batch)
        assert pooled.shape == (0, 3) + spatial_shape

    # Normal batch dimension: compare against the torch reference module.
    full_batch = torch.randn(3, 3, 4, 5)
    for output_size, spatial_shape in cases:
        wrapper = AdaptiveAvgPool2d(output_size)
        reference = nn.AdaptiveAvgPool2d(output_size)
        pooled = wrapper(full_batch)
        expected = reference(full_batch)
        assert pooled.shape == (3, 3) + spatial_shape
        assert torch.equal(pooled, expected)