File size: 8,705 Bytes
3bbb319 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 |
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmcv.utils import digit_version
from mmdet.models.losses import (BalancedL1Loss, CrossEntropyLoss, DiceLoss,
DistributionFocalLoss, FocalLoss,
GaussianFocalLoss,
KnowledgeDistillationKLDivLoss, L1Loss,
MSELoss, QualityFocalLoss, SeesawLoss,
SmoothL1Loss, VarifocalLoss)
from mmdet.models.losses.ghm_loss import GHMC, GHMR
from mmdet.models.losses.iou_loss import (BoundedIoULoss, CIoULoss, DIoULoss,
GIoULoss, IoULoss)
@pytest.mark.parametrize(
    'loss_class', [IoULoss, BoundedIoULoss, GIoULoss, DIoULoss, CIoULoss])
def test_iou_type_loss_zeros_weight(loss_class):
    """An all-zero per-box weight must make every IoU-style loss zero."""
    num_boxes = 10
    boxes_pred = torch.rand((num_boxes, 4))
    boxes_gt = torch.rand((num_boxes, 4))
    zero_weight = torch.zeros(num_boxes)
    criterion = loss_class()
    assert criterion(boxes_pred, boxes_gt, zero_weight) == 0.
@pytest.mark.parametrize('loss_class', [
    BalancedL1Loss, BoundedIoULoss, CIoULoss, CrossEntropyLoss, DIoULoss,
    FocalLoss, DistributionFocalLoss, MSELoss, SeesawLoss, GaussianFocalLoss,
    GIoULoss, IoULoss, L1Loss, QualityFocalLoss, VarifocalLoss, GHMR, GHMC,
    SmoothL1Loss, KnowledgeDistillationKLDivLoss, DiceLoss
])
def test_loss_with_reduction_override(loss_class):
    """Every loss must reject an invalid ``reduction_override``.

    The accepted values are None, 'none', 'mean' and 'sum'; passing ``True``
    is expected to raise an AssertionError.
    """
    pred = torch.rand((10, 4))
    # Must be a tensor (no trailing comma -> not a 1-tuple) so that the
    # invalid ``reduction_override`` is the only bad argument in the call.
    target = torch.rand((10, 4))
    weight = None
    # reduction_override must be one of [None, 'none', 'mean', 'sum'];
    # any other value should be rejected.
    reduction_override = True
    with pytest.raises(AssertionError):
        loss_class()(
            pred, target, weight, reduction_override=reduction_override)
@pytest.mark.parametrize('loss_class', [
    IoULoss, BoundedIoULoss, GIoULoss, DIoULoss, CIoULoss, MSELoss, L1Loss,
    SmoothL1Loss, BalancedL1Loss
])
@pytest.mark.parametrize('input_shape', [(10, 4), (0, 4)])
def test_regression_losses(loss_class, input_shape):
    """Smoke-test forward variants of the regression losses.

    Each call must return a tensor, for both a populated and an empty
    (zero-row) input. Combining ``avg_factor`` with ``reduction='sum'`` must
    raise a ValueError.
    """
    pred = torch.rand(input_shape)
    target = torch.rand(input_shape)
    weight = torch.rand(input_shape)

    # Test loss forward
    loss = loss_class()(pred, target)
    assert isinstance(loss, torch.Tensor)

    # Test loss forward with weight
    loss = loss_class()(pred, target, weight)
    assert isinstance(loss, torch.Tensor)

    # Test loss forward with reduction_override
    loss = loss_class()(pred, target, reduction_override='mean')
    assert isinstance(loss, torch.Tensor)

    # Test loss forward with avg_factor
    loss = loss_class()(pred, target, avg_factor=10)
    assert isinstance(loss, torch.Tensor)

    # avg_factor is only meaningful with reduction None, 'none' or 'mean'.
    with pytest.raises(ValueError):
        loss_class()(pred, target, avg_factor=10, reduction_override='sum')

    # Test loss forward with avg_factor and each compatible reduction.
    # Reassign ``loss`` every iteration so the assertion checks this call's
    # result rather than a value left over from an earlier call.
    for reduction_override in [None, 'none', 'mean']:
        loss = loss_class()(
            pred, target, avg_factor=10, reduction_override=reduction_override)
        assert isinstance(loss, torch.Tensor)
@pytest.mark.parametrize('loss_class', [FocalLoss, CrossEntropyLoss])
@pytest.mark.parametrize('input_shape', [(10, 5), (0, 5)])
def test_classification_losses(loss_class, input_shape):
    """Smoke-test forward variants of the classification losses.

    ``pred`` holds one score per class (5 classes) and ``target`` one integer
    label per sample. Empty inputs are skipped on PyTorch < 1.5.0. Combining
    ``avg_factor`` with ``reduction='sum'`` must raise a ValueError.
    """
    if input_shape[0] == 0 and digit_version(
            torch.__version__) < digit_version('1.5.0'):
        pytest.skip(
            f'CELoss in PyTorch {torch.__version__} does not support empty '
            'tensor.')

    pred = torch.rand(input_shape)
    target = torch.randint(0, 5, (input_shape[0], ))

    # Test loss forward
    loss = loss_class()(pred, target)
    assert isinstance(loss, torch.Tensor)

    # Test loss forward with reduction_override
    loss = loss_class()(pred, target, reduction_override='mean')
    assert isinstance(loss, torch.Tensor)

    # Test loss forward with avg_factor
    loss = loss_class()(pred, target, avg_factor=10)
    assert isinstance(loss, torch.Tensor)

    # avg_factor is only meaningful with reduction None, 'none' or 'mean'.
    with pytest.raises(ValueError):
        loss_class()(pred, target, avg_factor=10, reduction_override='sum')

    # Test loss forward with avg_factor and each compatible reduction.
    # Reassign ``loss`` every iteration so the assertion checks this call's
    # result rather than a value left over from an earlier call.
    for reduction_override in [None, 'none', 'mean']:
        loss = loss_class()(
            pred, target, avg_factor=10, reduction_override=reduction_override)
        assert isinstance(loss, torch.Tensor)
@pytest.mark.parametrize('loss_class', [GHMR])
@pytest.mark.parametrize('input_shape', [(10, 4), (0, 4)])
def test_GHMR_loss(loss_class, input_shape):
    """GHMR called with an explicit weight must return a tensor.

    Runs for both a populated and an empty (zero-row) input.
    """
    pred, target, label_weight = [
        torch.rand(input_shape) for _ in range(3)
    ]
    result = loss_class()(pred, target, label_weight)
    assert isinstance(result, torch.Tensor)
@pytest.mark.parametrize('use_sigmoid', [True, False])
@pytest.mark.parametrize('reduction', ['sum', 'mean', None])
@pytest.mark.parametrize('avg_non_ignore', [True, False])
def test_loss_with_ignore_index(use_sigmoid, reduction, avg_non_ignore):
    """Check how CrossEntropyLoss handles ``ignore_index``.

    The test compares three things:

    * the loss with ``ignore_index`` set at construction time;
    * the loss with the same ``ignore_index`` passed to ``forward``;
    * when ``avg_non_ignore`` is set, the loss computed on inputs from which
      the ignored rows were removed by hand.

    It also checks that a target made only of ignored labels gives a zero
    loss.
    """
    loss_fn = CrossEntropyLoss(
        use_sigmoid=use_sigmoid,
        use_mask=False,
        ignore_index=255,
        avg_non_ignore=avg_non_ignore)
    pred = torch.rand((10, 5))
    target = torch.randint(0, 5, (10, ))
    # Mark up to two random samples (indices may repeat) as ignored.
    ignored_indices = torch.randint(0, 10, (2, ), dtype=torch.long)
    target[ignored_indices] = 255

    # Test loss forward with the ignore_index given to the constructor
    loss_with_ignore = loss_fn(pred, target, reduction_override=reduction)
    assert isinstance(loss_with_ignore, torch.Tensor)

    # Test loss forward with the ignore_index given to forward
    loss_with_forward_ignore = loss_fn(
        pred, target, ignore_index=255, reduction_override=reduction)
    assert isinstance(loss_with_forward_ignore, torch.Tensor)
    # Both calls use ignore_index=255 on identical inputs, so they must agree
    # regardless of avg_non_ignore.
    assert torch.allclose(loss_with_ignore, loss_with_forward_ignore)

    # Verify correctness against manually removing the ignored samples
    if avg_non_ignore:
        keep = target != 255
        loss = loss_fn(
            pred[keep], target[keep], reduction_override=reduction)
        assert torch.allclose(loss, loss_with_ignore)
        assert torch.allclose(loss, loss_with_forward_ignore)

    # Test a target where every label is ignored
    pred = torch.rand((10, 5))
    target = torch.ones((10, ), dtype=torch.long) * 255
    loss = loss_fn(pred, target, reduction_override=reduction)
    assert loss == 0
@pytest.mark.parametrize('naive_dice', [True, False])
def test_dice_loss(naive_dice):
    """Exercise the DiceLoss forward options and its argument validation.

    Covers:

    * plain, weighted, ``reduction_override`` and ``avg_factor`` calls;
    * the ValueError when ``avg_factor`` is combined with ``reduction='sum'``;
    * the NotImplementedError for ``activate=True`` with
      ``use_sigmoid=False``;
    * AssertionErrors for weights with the wrong rank or length.
    """
    pred = torch.rand((10, 4, 4))
    target = torch.rand((10, 4, 4))
    # One weight per sample: its length must equal len(pred), as the
    # AssertionError checks at the end of this test require.
    weight = torch.rand((10, ))

    # Test loss forward
    loss = DiceLoss(naive_dice=naive_dice)(pred, target)
    assert isinstance(loss, torch.Tensor)

    # Test loss forward with weight
    loss = DiceLoss(naive_dice=naive_dice)(pred, target, weight)
    assert isinstance(loss, torch.Tensor)

    # Test loss forward with reduction_override
    loss = DiceLoss(naive_dice=naive_dice)(
        pred, target, reduction_override='mean')
    assert isinstance(loss, torch.Tensor)

    # Test loss forward with avg_factor
    loss = DiceLoss(naive_dice=naive_dice)(pred, target, avg_factor=10)
    assert isinstance(loss, torch.Tensor)

    # avg_factor is only meaningful with reduction None, 'none' or 'mean'.
    with pytest.raises(ValueError):
        DiceLoss(naive_dice=naive_dice)(
            pred, target, avg_factor=10, reduction_override='sum')

    # Test loss forward with avg_factor and each compatible reduction.
    # Reassign ``loss`` every iteration so the assertion checks this call's
    # result rather than a value left over from an earlier call.
    for reduction_override in [None, 'none', 'mean']:
        loss = DiceLoss(naive_dice=naive_dice)(
            pred, target, avg_factor=10, reduction_override=reduction_override)
        assert isinstance(loss, torch.Tensor)

    # activate=True together with use_sigmoid=False is expected to raise.
    with pytest.raises(NotImplementedError):
        DiceLoss(use_sigmoid=False, activate=True, naive_dice=naive_dice)(
            pred, target)

    # weight.ndim != loss.ndim must be rejected
    bad_ndim_weight = torch.rand((2, 8))
    with pytest.raises(AssertionError):
        DiceLoss(naive_dice=naive_dice)(pred, target, bad_ndim_weight)

    # len(weight) != len(pred) must be rejected
    bad_len_weight = torch.rand((8, ))
    with pytest.raises(AssertionError):
        DiceLoss(naive_dice=naive_dice)(pred, target, bad_len_weight)
|