Spaces:
Runtime error
Runtime error
File size: 4,198 Bytes
2366e36 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmocr.models.common.losses import DiceLoss
from mmocr.models.textrecog.losses import (ABILoss, CELoss, CTCLoss, SARLoss,
TFLoss)
def test_ctc_loss():
with pytest.raises(AssertionError):
CTCLoss(flatten='flatten')
with pytest.raises(AssertionError):
CTCLoss(blank=None)
with pytest.raises(AssertionError):
CTCLoss(reduction=1)
with pytest.raises(AssertionError):
CTCLoss(zero_infinity='zero')
# test CTCLoss
ctc_loss = CTCLoss()
outputs = torch.zeros(2, 40, 37)
targets_dict = {
'flatten_targets': torch.IntTensor([1, 2, 3, 4, 5]),
'target_lengths': torch.LongTensor([2, 3])
}
losses = ctc_loss(outputs, targets_dict)
assert isinstance(losses, dict)
assert 'loss_ctc' in losses
assert torch.allclose(losses['loss_ctc'],
torch.tensor(losses['loss_ctc'].item()).float())
def test_ce_loss():
with pytest.raises(AssertionError):
CELoss(ignore_index='ignore')
with pytest.raises(AssertionError):
CELoss(reduction=1)
with pytest.raises(AssertionError):
CELoss(reduction='avg')
ce_loss = CELoss(ignore_index=0)
outputs = torch.rand(1, 10, 37)
targets_dict = {
'padded_targets': torch.LongTensor([[1, 2, 3, 4, 0, 0, 0, 0, 0, 0]])
}
losses = ce_loss(outputs, targets_dict)
assert isinstance(losses, dict)
assert 'loss_ce' in losses
assert losses['loss_ce'].size(1) == 10
ce_loss = CELoss(ignore_first_char=True)
outputs = torch.rand(1, 10, 37)
targets_dict = {
'padded_targets': torch.LongTensor([[1, 2, 3, 4, 0, 0, 0, 0, 0, 0]])
}
new_output, new_target = ce_loss.format(outputs, targets_dict)
assert new_output.shape == torch.Size([1, 37, 9])
assert new_target.shape == torch.Size([1, 9])
def test_sar_loss():
outputs = torch.rand(1, 10, 37)
targets_dict = {
'padded_targets': torch.LongTensor([[1, 2, 3, 4, 0, 0, 0, 0, 0, 0]])
}
sar_loss = SARLoss()
new_output, new_target = sar_loss.format(outputs, targets_dict)
assert new_output.shape == torch.Size([1, 37, 9])
assert new_target.shape == torch.Size([1, 9])
def test_tf_loss():
with pytest.raises(AssertionError):
TFLoss(flatten=1.0)
outputs = torch.rand(1, 10, 37)
targets_dict = {
'padded_targets': torch.LongTensor([[1, 2, 3, 4, 0, 0, 0, 0, 0, 0]])
}
tf_loss = TFLoss(flatten=False)
new_output, new_target = tf_loss.format(outputs, targets_dict)
assert new_output.shape == torch.Size([1, 37, 9])
assert new_target.shape == torch.Size([1, 9])
def test_dice_loss():
with pytest.raises(AssertionError):
DiceLoss(eps='1')
dice_loss = DiceLoss()
pred = torch.rand(1, 1, 32, 32)
gt = torch.rand(1, 1, 32, 32)
loss = dice_loss(pred, gt, None)
assert isinstance(loss, torch.Tensor)
mask = torch.rand(1, 1, 1, 1)
loss = dice_loss(pred, gt, mask)
assert isinstance(loss, torch.Tensor)
def test_abi_loss():
loss = ABILoss(num_classes=90)
outputs = dict(
out_enc=dict(logits=torch.randn(2, 10, 90)),
out_decs=[
dict(logits=torch.randn(2, 10, 90)),
dict(logits=torch.randn(2, 10, 90))
],
out_fusers=[
dict(logits=torch.randn(2, 10, 90)),
dict(logits=torch.randn(2, 10, 90))
])
targets_dict = {
'padded_targets': torch.LongTensor([[1, 2, 3, 4, 0, 0, 0, 0, 0, 0]]),
'targets':
[torch.LongTensor([1, 2, 3, 4]),
torch.LongTensor([1, 2, 3])]
}
result = loss(outputs, targets_dict)
assert isinstance(result, dict)
assert isinstance(result['loss_visual'], torch.Tensor)
assert isinstance(result['loss_lang'], torch.Tensor)
assert isinstance(result['loss_fusion'], torch.Tensor)
outputs.pop('out_enc')
loss(outputs, targets_dict)
outputs.pop('out_decs')
loss(outputs, targets_dict)
outputs.pop('out_fusers')
with pytest.raises(AssertionError):
loss(outputs, targets_dict)
|