# Copyright (c) OpenMMLab. All rights reserved. import pytest import torch from mmdet.models.backbones.hourglass import HourglassNet def test_hourglass_backbone(): with pytest.raises(AssertionError): # HourglassNet's num_stacks should larger than 0 HourglassNet(num_stacks=0) with pytest.raises(AssertionError): # len(stage_channels) should equal len(stage_blocks) HourglassNet( stage_channels=[256, 256, 384, 384, 384], stage_blocks=[2, 2, 2, 2, 2, 4]) with pytest.raises(AssertionError): # len(stage_channels) should lagrer than downsample_times HourglassNet( downsample_times=5, stage_channels=[256, 256, 384, 384, 384], stage_blocks=[2, 2, 2, 2, 2]) # Test HourglassNet-52 model = HourglassNet( num_stacks=1, stage_channels=(64, 64, 96, 96, 96, 128), feat_channel=64) model.train() imgs = torch.randn(1, 3, 256, 256) feat = model(imgs) assert len(feat) == 1 assert feat[0].shape == torch.Size([1, 64, 64, 64]) # Test HourglassNet-104 model = HourglassNet( num_stacks=2, stage_channels=(64, 64, 96, 96, 96, 128), feat_channel=64) model.train() imgs = torch.randn(1, 3, 256, 256) feat = model(imgs) assert len(feat) == 2 assert feat[0].shape == torch.Size([1, 64, 64, 64]) assert feat[1].shape == torch.Size([1, 64, 64, 64])