|
import torch.nn as nn |
|
import torch |
|
from .general import bbox_iou |
|
from .postprocess import build_targets |
|
from lib.core.evaluate import SegmentationMetric |
|
|
|
class MultiHeadLoss(nn.Module):
    """
    Collect all the losses we need and combine them into one weighted sum.

    Six terms are produced by `_forward_impl`: detection class, objectness
    and box losses, drivable-area segmentation BCE, lane-line segmentation
    BCE and a lane-line IoU loss.
    """

    def __init__(self, losses, cfg, lambdas=None):
        """
        Inputs:
        - losses: (list)[nn.Module, nn.Module, ...]; `_forward_impl` unpacks
          exactly three criteria in the order (BCEcls, BCEobj, BCEseg)
        - cfg: config object (its LOSS.* gains and TRAIN.* flags are read
          in `_forward_impl`)
        - lambdas: (list) weight for each loss term, indexed as
          [cls, obj, box, da_seg, ll_seg, ll_iou]; when empty or None every
          weight defaults to 1.0 (len(losses) + 3 entries)

        Raises:
        - ValueError: if any weight fails the `>= 0.0` check
        """
        super().__init__()

        if not lambdas:
            lambdas = [1.0] * (len(losses) + 3)
        # Explicit check instead of `assert`, which is removed under `python -O`.
        if not all(lam >= 0.0 for lam in lambdas):
            raise ValueError(
                "all loss weights in `lambdas` must be >= 0.0, got {!r}".format(lambdas)
            )

        self.losses = nn.ModuleList(losses)
        self.lambdas = lambdas
        self.cfg = cfg

    def forward(self, head_fields, head_targets, shapes, model):
        """
        Inputs:
        - head_fields: (list) output from each task head
        - head_targets: (list) ground-truth for each task head
        - shapes: per-image shape records; only shapes[0][1][1] is read
          (as a (pad_w, pad_h) pair) by `_forward_impl`
        - model: object whose `gr` and `nc` attributes are read; it is also
          forwarded to `build_targets`

        Returns:
        - total_loss: sum of all the loss
        - head_losses: (tuple) contain all loss[loss1, loss2, ...]
        """
        return self._forward_impl(head_fields, head_targets, shapes, model)

    def _forward_impl(self, predictions, targets, shapes, model):
        """
        Compute every loss term, scale it and add the terms together.

        Args:
            predictions: predicts of [[det_head1, det_head2, det_head3], drive_area_seg_head, lane_line_seg_head]
            targets: gts [det_targets, segment_targets, lane_targets]
            shapes: per-image shape records; shapes[0][1][1] supplies (pad_w, pad_h)
            model: provides `gr` and `nc`; also passed to `build_targets`

        Returns:
            total_loss: sum of all the loss
            head_losses: tuple of Python numbers
                (lbox, lobj, lcls, lseg_da, lseg_ll, liou_ll, total_loss)
        """
        cfg = self.cfg
        device = targets[0].device
        lcls = torch.zeros(1, device=device)
        lbox = torch.zeros(1, device=device)
        lobj = torch.zeros(1, device=device)
        tcls, tbox, indices, anchors = build_targets(cfg, predictions[0], targets[0], model)

        # eps=0.0 means no label smoothing: cp is 1.0 and cn is 0.0.
        cp, cn = smooth_BCE(eps=0.0)

        BCEcls, BCEobj, BCEseg = self.losses

        no = len(predictions[0])  # number of detection outputs
        # Per-output weight applied to the objectness loss.
        balance = [4.0, 1.0, 0.4] if no == 3 else [4.0, 1.0, 0.4, 0.1]

        for i, pi in enumerate(predictions[0]):
            # presumably (image, anchor, grid-y, grid-x) indices -- confirm in build_targets
            b, a, gj, gi = indices[i]
            tobj = torch.zeros_like(pi[..., 0], device=device)

            n = b.shape[0]  # number of matched targets for this output
            if n:
                ps = pi[b, a, gj, gi]  # predictions selected by the target indices

                # Box regression: decode xy / wh, then penalise 1 - CIoU.
                pxy = ps[:, :2].sigmoid() * 2. - 0.5
                pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]
                pbox = torch.cat((pxy, pwh), 1).to(device)
                iou = bbox_iou(pbox.T, tbox[i], x1y1x2y2=False, CIoU=True)
                lbox += (1.0 - iou).mean()

                # Objectness target blends 1.0 with the detached IoU by model.gr.
                tobj[b, a, gj, gi] = (1.0 - model.gr) + model.gr * iou.detach().clamp(0).type(tobj.dtype)

                # Classification loss is only computed when there is more than one class.
                if model.nc > 1:
                    t = torch.full_like(ps[:, 5:], cn, device=device)
                    t[range(n), tcls[i]] = cp
                    lcls += BCEcls(ps[:, 5:], t)
            # Objectness loss is accumulated for every output, matched or not.
            lobj += BCEobj(pi[..., 4], tobj) * balance[i]

        # Segmentation BCE on the flattened logits / targets.
        drive_area_seg_predicts = predictions[1].view(-1)
        drive_area_seg_targets = targets[1].view(-1)
        lseg_da = BCEseg(drive_area_seg_predicts, drive_area_seg_targets)

        lane_line_seg_predicts = predictions[2].view(-1)
        lane_line_seg_targets = targets[2].view(-1)
        lseg_ll = BCEseg(lane_line_seg_predicts, lane_line_seg_targets)

        # Lane-line IoU loss on the arg-max maps, cropped by the padding.
        # NOTE(review): the padding comes from shapes[0] only and is applied to
        # the whole batch; the metric is evaluated on CPU copies, so this term
        # presumably carries no gradient -- confirm against SegmentationMetric.
        metric = SegmentationMetric(2)
        nb, _, height, width = targets[1].shape
        pad_w, pad_h = shapes[0][1][1]
        pad_w = int(pad_w)
        pad_h = int(pad_h)
        _, lane_line_pred = torch.max(predictions[2], 1)
        _, lane_line_gt = torch.max(targets[2], 1)
        lane_line_pred = lane_line_pred[:, pad_h:height - pad_h, pad_w:width - pad_w]
        lane_line_gt = lane_line_gt[:, pad_h:height - pad_h, pad_w:width - pad_w]
        metric.reset()
        metric.addBatch(lane_line_pred.cpu(), lane_line_gt.cpu())
        IoU = metric.IntersectionOverUnion()
        liou_ll = 1 - IoU

        # Scale each term by its config gain and its lambda weight.
        s = 3 / no
        lcls *= cfg.LOSS.CLS_GAIN * s * self.lambdas[0]
        lobj *= cfg.LOSS.OBJ_GAIN * s * (1.4 if no == 4 else 1.) * self.lambdas[1]
        lbox *= cfg.LOSS.BOX_GAIN * s * self.lambdas[2]

        lseg_da *= cfg.LOSS.DA_SEG_GAIN * self.lambdas[3]
        lseg_ll *= cfg.LOSS.LL_SEG_GAIN * self.lambdas[4]
        liou_ll *= cfg.LOSS.LL_IOU_GAIN * self.lambdas[5]

        # Zero out the terms excluded by the active training mode.
        if cfg.TRAIN.DET_ONLY or cfg.TRAIN.ENC_DET_ONLY:
            lseg_da = 0 * lseg_da
            lseg_ll = 0 * lseg_ll
            liou_ll = 0 * liou_ll

        if cfg.TRAIN.SEG_ONLY or cfg.TRAIN.ENC_SEG_ONLY:
            lcls = 0 * lcls
            lobj = 0 * lobj
            lbox = 0 * lbox

        if cfg.TRAIN.LANE_ONLY:
            lcls = 0 * lcls
            lobj = 0 * lobj
            lbox = 0 * lbox
            lseg_da = 0 * lseg_da

        if cfg.TRAIN.DRIVABLE_ONLY:
            lcls = 0 * lcls
            lobj = 0 * lobj
            lbox = 0 * lbox
            lseg_ll = 0 * lseg_ll
            liou_ll = 0 * liou_ll

        loss = lbox + lobj + lcls + lseg_da + lseg_ll + liou_ll

        return loss, (lbox.item(), lobj.item(), lcls.item(), lseg_da.item(), lseg_ll.item(), liou_ll.item(), loss.item())
|
|
|
|
|
def get_loss(cfg, device):
    """
    Build the MultiHeadLoss used for training.

    Inputs:
    -cfg: configuration; reads LOSS.CLS_POS_WEIGHT, LOSS.OBJ_POS_WEIGHT,
          LOSS.SEG_POS_WEIGHT, LOSS.FL_GAMMA and LOSS.MULTI_HEAD_LAMBDA
    -device: cpu or gpu device the criteria are moved to

    Returns:
    -loss: (MultiHeadLoss)
    """
    loss_cfg = cfg.LOSS

    def _weighted_bce(pos_weight):
        # BCE-with-logits criterion with a single positive-class weight.
        criterion = nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([pos_weight]))
        return criterion.to(device)

    cls_criterion = _weighted_bce(loss_cfg.CLS_POS_WEIGHT)
    obj_criterion = _weighted_bce(loss_cfg.OBJ_POS_WEIGHT)
    seg_criterion = _weighted_bce(loss_cfg.SEG_POS_WEIGHT)

    # A positive gamma wraps the two detection criteria in FocalLoss;
    # the segmentation criterion is never wrapped.
    focal_gamma = loss_cfg.FL_GAMMA
    if focal_gamma > 0:
        cls_criterion = FocalLoss(cls_criterion, focal_gamma)
        obj_criterion = FocalLoss(obj_criterion, focal_gamma)

    return MultiHeadLoss(
        [cls_criterion, obj_criterion, seg_criterion],
        cfg=cfg,
        lambdas=loss_cfg.MULTI_HEAD_LAMBDA,
    )
|
|
|
|
|
|
|
|
|
|
|
def smooth_BCE(eps=0.1):
    """
    Return the (positive, negative) target values for label-smoothed BCE.

    Inputs:
    - eps: smoothing amount; eps=0 yields the hard targets (1.0, 0.0)

    Returns:
    - (tuple) (1.0 - 0.5 * eps, 0.5 * eps)
    """
    negative_target = 0.5 * eps
    positive_target = 1.0 - negative_target
    return positive_target, negative_target
|
|
|
|
|
class FocalLoss(nn.Module):
    """
    Wrap a criterion and re-weight its element-wise loss focal-loss style.

    The wrapped criterion's `reduction` is remembered on this module and the
    criterion itself is switched to 'none', so the per-element values can be
    weighted before the remembered reduction is applied.
    """

    def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
        """
        Inputs:
        - loss_fcn: criterion module with a `reduction` attribute; it is
          modified in place (its reduction becomes 'none')
        - gamma: exponent applied to (1 - p_t)
        - alpha: weight for targets equal to 1; (1 - alpha) for targets equal to 0
        """
        super().__init__()
        requested_reduction = loss_fcn.reduction
        loss_fcn.reduction = 'none'

        self.loss_fcn = loss_fcn
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = requested_reduction

    def forward(self, pred, true):
        """
        Inputs:
        - pred: logits (passed through sigmoid here)
        - true: targets, same shape as pred

        Returns:
        - the weighted loss, reduced by 'mean' or 'sum' when that was the
          wrapped criterion's original reduction, otherwise element-wise
        """
        weighted = self.loss_fcn(pred, true)

        prob = torch.sigmoid(pred)
        # p_t is the probability assigned to the true label.
        p_t = true * prob + (1 - true) * (1 - prob)
        alpha_weight = true * self.alpha + (1 - true) * (1 - self.alpha)
        focal_weight = (1.0 - p_t) ** self.gamma
        weighted *= alpha_weight * focal_weight

        if self.reduction == 'mean':
            return weighted.mean()
        if self.reduction == 'sum':
            return weighted.sum()
        return weighted
|
|