File size: 2,931 Bytes
3bbb319
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
from unittest.mock import patch

import torch
import torch.nn as nn
import torch.nn.functional as F

from mmdet.models.utils import AdaptiveAvgPool2d, adaptive_avg_pool2d

# Version string that the tests below patch into ``torch.__version__``:
# the 'parrots' marker is kept as-is, anything else is replaced by '1.7'.
torch_version = 'parrots' if torch.__version__ == 'parrots' else '1.7'


@patch('torch.__version__', torch_version)
def test_adaptive_avg_pool2d():
    """Check the functional ``adaptive_avg_pool2d`` wrapper.

    With an empty batch only the output shape is verified; with a
    non-empty batch the result must also equal ``F.adaptive_avg_pool2d``.
    """
    # The output size is given either as a pair of ints or as a single int.
    output_sizes = ((2, 2), 2)

    # Zero-sized batch dimension: only the resulting shape is checked.
    empty_batch = torch.randn(0, 3, 4, 5)
    for size in output_sizes:
        pooled = adaptive_avg_pool2d(empty_batch, size)
        assert pooled.shape == (0, 3, 2, 2)

    # 4-dim input with a batch of 3: compare against the PyTorch functional.
    batch = torch.randn(3, 3, 4, 5)
    for size in output_sizes:
        pooled = adaptive_avg_pool2d(batch, size)
        expected = F.adaptive_avg_pool2d(batch, size)
        assert pooled.shape == (3, 3, 2, 2)
        assert torch.equal(pooled, expected)


@patch('torch.__version__', torch_version)
def test_AdaptiveAvgPool2d():
    """Check the ``AdaptiveAvgPool2d`` module wrapper.

    Four forms of ``output_size`` are exercised. With an empty batch only
    the output shape is verified; with a non-empty batch the result must
    also equal the output of ``nn.AdaptiveAvgPool2d``.
    """
    # (output_size, expected output height/width) for the 4x5 inputs below.
    # In the cases here a ``None`` entry leaves that dimension at its
    # input size.
    cases = (
        ((2, 2), (2, 2)),  # tuple[int, int]
        (2, (2, 2)),  # int
        ((None, 2), (4, 2)),  # tuple[None, int]
        ((2, None), (2, 5)),  # tuple[int, None]
    )

    # Zero-sized batch dimension: only the resulting shape is checked.
    empty_batch = torch.randn(0, 3, 4, 5)
    for output_size, spatial in cases:
        pooled = AdaptiveAvgPool2d(output_size)(empty_batch)
        assert pooled.shape == (0, 3) + spatial

    # Batch of 3: compare against the PyTorch module.
    batch = torch.randn(3, 3, 4, 5)
    for output_size, spatial in cases:
        wrapper = AdaptiveAvgPool2d(output_size)
        reference = nn.AdaptiveAvgPool2d(output_size)
        pooled = wrapper(batch)
        expected = reference(batch)
        assert pooled.shape == (3, 3) + spatial
        assert torch.equal(pooled, expected)