Add QFocalLoss() (#1482)
* Update loss.py
Implement the quality focal loss, which is a more general case of focal loss.
More detail in https://arxiv.org/abs/2006.04388
In the objectness loss (or the classification loss with label smoothing), the targets are no longer strictly 0 or 1 (they can be, e.g., 0.7); in this case the normal focal loss does not work accurately.
Quality focal loss behaves the same as focal loss when the target equals 0 or 1, and works accurately when targets are in (0, 1).
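For intuition (notation here is assumed, not from the PR text): with predicted probability p and target y, the existing FocalLoss scales BCE by a modulating factor (1 - p_t)^gamma, where p_t = y*p + (1 - y)*(1 - p) is only meaningful for hard targets y in {0, 1}. QFocalLoss replaces this factor with |y - p|^gamma, which coincides with the focal factor when y = 0 or y = 1 and goes to zero as p approaches a soft target y.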
example:

targets:
tensor([[0.6225, 0.0000, 0.0000],
        [0.9000, 0.0000, 0.0000],
        [1.0000, 0.0000, 0.0000]])

pred_prob:
tensor([[0.6225, 0.2689, 0.1192],
        [0.7773, 0.5000, 0.2227],
        [0.8176, 0.8808, 0.1978]])

focal_loss:
tensor([[0.0937, 0.0328, 0.0039],
        [0.0166, 0.1838, 0.0199],
        [0.0039, 1.3186, 0.0145]])

qfocal_loss:
tensor([[7.5373e-08, 3.2768e-02, 3.9179e-03],
        [4.8601e-03, 1.8380e-01, 1.9857e-02],
        [3.9233e-03, 1.3186e+00, 1.4545e-02]])
We can see that targets[0][0] = 0.6225 is almost the same as pred_prob[0][0] = 0.6225, while targets[1][0] = 0.9 is greater than pred_prob[1][0] = 0.7773 by 0.1227.
However, focal_loss[0][0] = 0.0937 is larger than focal_loss[1][0] = 0.0166, which goes against the purpose of focal loss.
Quality focal loss correctly handles the case of targets not equal to 0 or 1: qfocal_loss[0][0] = 7.5373e-08 is far smaller than qfocal_loss[1][0] = 4.8601e-03.
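A minimal sketch to reproduce the numbers above (assumed, not part of the PR; it mirrors the repo's existing FocalLoss and the QFocalLoss added below, with gamma=1.5 and alpha=0.25):

import torch
import torch.nn as nn

targets = torch.tensor([[0.6225, 0.0000, 0.0000],
                        [0.9000, 0.0000, 0.0000],
                        [1.0000, 0.0000, 0.0000]])
pred_prob = torch.tensor([[0.6225, 0.2689, 0.1192],
                          [0.7773, 0.5000, 0.2227],
                          [0.8176, 0.8808, 0.1978]])
pred = torch.log(pred_prob / (1 - pred_prob))  # logits such that sigmoid(pred) == pred_prob

bce = nn.BCEWithLogitsLoss(reduction='none')(pred, targets)  # element-wise BCE
alpha_factor = targets * 0.25 + (1 - targets) * (1 - 0.25)

# focal loss: the (1 - p_t)^gamma modulating factor assumes hard 0/1 targets
p_t = targets * pred_prob + (1 - targets) * (1 - pred_prob)
focal_loss = bce * alpha_factor * (1.0 - p_t) ** 1.5

# quality focal loss: |target - prob|^gamma also handles soft targets
qfocal_loss = bce * alpha_factor * torch.abs(targets - pred_prob) ** 1.5

print(focal_loss)   # matches the focal_loss tensor above
print(qfocal_loss)  # matches the qfocal_loss tensor above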
* Update loss.py
utils/loss.py +26 -0
@@ -57,6 +57,32 @@ class FocalLoss(nn.Module):
             return loss.sum()
         else:  # 'none'
             return loss
+
+
+class QFocalLoss(nn.Module):
+    # Wraps Quality focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)
+    def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
+        super(QFocalLoss, self).__init__()
+        self.loss_fcn = loss_fcn  # must be nn.BCEWithLogitsLoss()
+        self.gamma = gamma
+        self.alpha = alpha
+        self.reduction = loss_fcn.reduction
+        self.loss_fcn.reduction = 'none'  # required to apply FL to each element
+
+    def forward(self, pred, true):
+        loss = self.loss_fcn(pred, true)
+
+        pred_prob = torch.sigmoid(pred)  # prob from logits
+        alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)
+        modulating_factor = torch.abs(true - pred_prob) ** self.gamma
+        loss *= alpha_factor * modulating_factor
+
+        if self.reduction == 'mean':
+            return loss.mean()
+        elif self.reduction == 'sum':
+            return loss.sum()
+        else:  # 'none'
+            return loss
 
 
 def compute_loss(p, targets, model):  # predictions, targets, model
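For completeness, a minimal usage sketch (assumed, not part of the diff; shapes and variable names are illustrative). Like FocalLoss, QFocalLoss wraps an existing nn.BCEWithLogitsLoss and inherits its reduction:

import torch
import torch.nn as nn
# assumes QFocalLoss is defined as in the diff above, e.g. in utils/loss.py

criterion = QFocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5, alpha=0.25)
pred = torch.randn(8, 3)      # raw logits
true = torch.rand(8, 3)       # soft targets in [0, 1], e.g. 0.7 from label smoothing
loss = criterion(pred, true)  # scalar, since BCEWithLogitsLoss defaults to reduction='mean'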