camenduru's picture
thanks to show ❤
3bbb319
raw
history blame contribute delete
No virus
4.81 kB
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import mmcv
import pytest
import torch
from mmdet import digit_version
from mmdet.models.necks import FPN, YOLOV3Neck
from .utils import ort_validate
# Skip every test in this module when the installed torch is too old for
# the ort backend.
# NOTE(review): the comparison also skips torch 1.5.0 itself, while the
# message says "below 1.5.0" - confirm which of the two is intended.
if digit_version('1.5.0') >= digit_version(torch.__version__):
    pytest.skip(
        'ort backend does not support version below 1.5.0',
        allow_module_level=True)
# Maps each FPN test-step name to the integer index that fpn_neck_config()
# uses to select a model variant. A name's index is its position in the
# tuple below.
fpn_test_step_names = {
    name: index
    for index, name in enumerate((
        'fpn_normal',
        'fpn_wo_extra_convs',
        'fpn_lateral_bns',
        'fpn_bilinear_upsample',
        'fpn_scale_factor',
        'fpn_extra_convs_inputs',
        'fpn_extra_convs_laterals',
        'fpn_extra_convs_outputs',
    ))
}

# Same idea for yolo_neck_config(); only one step is defined.
yolo_test_step_names = {
    name: index for index, name in enumerate(('yolo_normal', ))
}

# Directory next to this file that holds the serialized test inputs.
data_path = osp.join(osp.dirname(__file__), 'data')
def fpn_neck_config(test_step_name):
    """Build the FPN variant and random inputs for one test step.

    Args:
        test_step_name (str): A key of ``fpn_test_step_names``.

    Returns:
        tuple: ``(fpn_model, feats)``. ``fpn_model`` is the ``FPN`` instance
        configured for the step; ``feats`` is a list of four random tensors
        with shapes ``(1, 8, 64, 64)``, ``(1, 16, 32, 32)``,
        ``(1, 32, 16, 16)`` and ``(1, 64, 8, 8)``.

    Raises:
        KeyError: If ``test_step_name`` is not in ``fpn_test_step_names``.
        ValueError: If the step index has no FPN configuration below.
    """
    s = 64
    in_channels = [8, 16, 32, 64]
    feat_sizes = [s // 2**i for i in range(4)]  # [64, 32, 16, 8]
    out_channels = 8
    # Inputs are created before the model, as in the original ordering.
    feats = [
        torch.rand(1, channels, size, size)
        for channels, size in zip(in_channels, feat_sizes)
    ]

    # Only the keyword arguments that differ between steps are listed here;
    # in_channels, out_channels and num_outs are shared by every variant.
    variant_kwargs = {
        0: dict(add_extra_convs=True),
        1: dict(add_extra_convs=False),
        2: dict(
            add_extra_convs=True,
            no_norm_on_lateral=False,
            norm_cfg=dict(type='BN', requires_grad=True)),
        3: dict(
            add_extra_convs=True,
            upsample_cfg=dict(mode='bilinear', align_corners=True)),
        4: dict(add_extra_convs=True, upsample_cfg=dict(scale_factor=2)),
        5: dict(add_extra_convs='on_input'),
        6: dict(add_extra_convs='on_lateral'),
        7: dict(add_extra_convs='on_output'),
    }

    step = fpn_test_step_names[test_step_name]
    if step not in variant_kwargs:
        # Fail with a clear message instead of an UnboundLocalError when a
        # step is registered in fpn_test_step_names but not configured here.
        raise ValueError(
            f'no FPN configuration for test step {test_step_name!r} '
            f'(index {step})')

    fpn_model = FPN(
        in_channels=in_channels,
        out_channels=out_channels,
        num_outs=5,
        **variant_kwargs[step])
    return fpn_model, feats
def yolo_neck_config(test_step_name):
    """Build the YOLOV3 neck and load its stored inputs for one test step.

    Args:
        test_step_name (str): A key of ``yolo_test_step_names``.

    Returns:
        tuple: ``(yolo_model, feats)``. ``yolo_model`` is the configured
        ``YOLOV3Neck``; ``feats`` is whatever ``mmcv.load`` reads from
        ``yolov3_neck.pkl`` under ``data_path``.

    Raises:
        KeyError: If ``test_step_name`` is not in ``yolo_test_step_names``.
        ValueError: If the step index has no YOLOV3Neck configuration below.
    """
    in_channels = [16, 8, 4]
    out_channels = [8, 4, 2]
    # The data of yolov3_neck.pkl contains a list of
    # torch.Tensor, where each torch.Tensor is generated by
    # torch.rand and each tensor size is:
    # (1, 4, 64, 64), (1, 8, 32, 32), (1, 16, 16, 16).
    yolov3_neck_data = 'yolov3_neck.pkl'
    feats = mmcv.load(osp.join(data_path, yolov3_neck_data))

    step = yolo_test_step_names[test_step_name]
    if step != 0:
        # Fail with a clear message instead of an UnboundLocalError when a
        # step is registered in yolo_test_step_names but not configured here.
        raise ValueError(
            f'no YOLOV3Neck configuration for test step {test_step_name!r} '
            f'(index {step})')

    yolo_model = YOLOV3Neck(
        in_channels=in_channels, out_channels=out_channels, num_scales=3)
    return yolo_model, feats
def test_fpn_normal():
    """Run ``ort_validate`` on the ``'fpn_normal'`` FPN configuration."""
    model, feats = fpn_neck_config('fpn_normal')
    ort_validate(model, feats)
def test_fpn_wo_extra_convs():
    """Run ``ort_validate`` on the ``'fpn_wo_extra_convs'`` configuration."""
    model, feats = fpn_neck_config('fpn_wo_extra_convs')
    ort_validate(model, feats)
def test_fpn_lateral_bns():
    """Run ``ort_validate`` on the ``'fpn_lateral_bns'`` configuration."""
    model, feats = fpn_neck_config('fpn_lateral_bns')
    ort_validate(model, feats)
def test_fpn_bilinear_upsample():
    """Run ``ort_validate`` on the ``'fpn_bilinear_upsample'`` configuration.
    """
    model, feats = fpn_neck_config('fpn_bilinear_upsample')
    ort_validate(model, feats)
def test_fpn_scale_factor():
    """Run ``ort_validate`` on the ``'fpn_scale_factor'`` configuration."""
    model, feats = fpn_neck_config('fpn_scale_factor')
    ort_validate(model, feats)
def test_fpn_extra_convs_inputs():
    """Run ``ort_validate`` on the ``'fpn_extra_convs_inputs'`` configuration.
    """
    model, feats = fpn_neck_config('fpn_extra_convs_inputs')
    ort_validate(model, feats)
def test_fpn_extra_convs_laterals():
    """Run ``ort_validate`` on the ``'fpn_extra_convs_laterals'``
    configuration."""
    model, feats = fpn_neck_config('fpn_extra_convs_laterals')
    ort_validate(model, feats)
def test_fpn_extra_convs_outputs():
    """Run ``ort_validate`` on the ``'fpn_extra_convs_outputs'``
    configuration."""
    model, feats = fpn_neck_config('fpn_extra_convs_outputs')
    ort_validate(model, feats)
def test_yolo_normal():
    """Run ``ort_validate`` on the ``'yolo_normal'`` YOLOV3 neck
    configuration."""
    model, feats = yolo_neck_config('yolo_normal')
    ort_validate(model, feats)