from torch import nn | |
import torch | |
# Average-pooling demo.
#
# nn.AvgPool2d(kernel_size) defaults to stride=kernel_size, so a 5x5 window
# applied to a 5x5 feature map produces a single 1x1 output per channel:
# this is global average pooling, numerically equal to
# features.mean(dim=(-2, -1)).
#
# For reference only (not used below): a square 3x3 window with stride 2 is
# nn.AvgPool2d(3, stride=2); a non-square window is written with a tuple,
# e.g. nn.AvgPool2d((3, 2), stride=(2, 1)).
m = nn.AvgPool2d(5)

# Input laid out as (batch, channels, height, width) = (32, 256, 5, 5).
# Named `features` rather than `input` to avoid shadowing the builtin.
features = torch.randn(32, 256, 5, 5)
output = m(features)  # -> (32, 256, 1, 1)

# Drop the two trailing singleton spatial dims -> (32, 256).
output = output.squeeze(-1).squeeze(-1)
print(output.shape)