glenn-jocher
commited on
Commit
•
6f5d6fc
1
Parent(s):
c09964c
Robust objectness loss balancing (#2256)
Browse files- utils/loss.py +2 -2
utils/loss.py
CHANGED
@@ -105,8 +105,8 @@ class ComputeLoss:
|
|
105 |
BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)
|
106 |
|
107 |
det = model.module.model[-1] if is_parallel(model) else model.model[-1] # Detect() module
|
108 |
-
self.balance = {3: [4.0, 1.0, 0.4]
|
109 |
-
self.ssi = (det.stride
|
110 |
self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, model.gr, h, autobalance
|
111 |
for k in 'na', 'nc', 'nl', 'anchors':
|
112 |
setattr(self, k, getattr(det, k))
|
|
|
105 |
BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)
|
106 |
|
107 |
det = model.module.model[-1] if is_parallel(model) else model.model[-1] # Detect() module
|
108 |
+
self.balance = {3: [4.0, 1.0, 0.4]}.get(det.nl, [4.0, 1.0, 0.25, 0.06, .02]) # P3-P7
|
109 |
+
self.ssi = list(det.stride).index(16) if autobalance else 0 # stride 16 index
|
110 |
self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, model.gr, h, autobalance
|
111 |
for k in 'na', 'nc', 'nl', 'anchors':
|
112 |
setattr(self, k, getattr(det, k))
|