import torch import torch.nn as nn import torch.nn.functional as F from .matcher import SimOTA from utils.box_ops import get_ious from utils.distributed_utils import get_world_size, is_dist_avail_and_initialized class SigmoidFocalLoss(object): def __init__(self, alpha=0.25, gamma=2.0, reduction='none'): self.alpha = alpha self.gamma = gamma self.reduction = reduction def __call__(self, logits, targets): p = torch.sigmoid(logits) ce_loss = F.binary_cross_entropy_with_logits(input=logits, target=targets, reduction="none") p_t = p * targets + (1.0 - p) * (1.0 - targets) loss = ce_loss * ((1.0 - p_t) ** self.gamma) if self.alpha >= 0: alpha_t = self.alpha * targets + (1.0 - self.alpha) * (1.0 - targets) loss = alpha_t * loss if self.reduction == "mean": loss = loss.mean() elif self.reduction == "sum": loss = loss.sum() return loss class Criterion(object): def __init__(self, args, img_size, num_classes=3, multi_hot=False): self.num_classes = num_classes self.img_size = img_size self.loss_conf_weight = args.loss_conf_weight self.loss_cls_weight = args.loss_cls_weight self.loss_reg_weight = args.loss_reg_weight self.focal_loss = args.focal_loss self.multi_hot = multi_hot # loss self.obj_lossf = nn.BCEWithLogitsLoss(reduction='none') self.cls_lossf = nn.BCEWithLogitsLoss(reduction='none') # matcher self.matcher = SimOTA( num_classes=num_classes, center_sampling_radius=args.center_sampling_radius, topk_candidate=args.topk_candicate ) def __call__(self, outputs, targets): """ outputs['pred_conf']: List(Tensor) [B, M, 1] outputs['pred_cls']: List(Tensor) [B, M, C] outputs['pred_box']: List(Tensor) [B, M, 4] outputs['strides']: List(Int) [8, 16, 32] output stride targets: (List) [dict{'boxes': [...], 'labels': [...], 'orig_size': ...}, ...] """ bs = outputs['pred_cls'][0].shape[0] device = outputs['pred_cls'][0].device fpn_strides = outputs['strides'] anchors = outputs['anchors'] # preds: [B, M, C] conf_preds = torch.cat(outputs['pred_conf'], dim=1) cls_preds = torch.cat(outputs['pred_cls'], dim=1) box_preds = torch.cat(outputs['pred_box'], dim=1) # label assignment cls_targets = [] box_targets = [] conf_targets = [] fg_masks = [] for batch_idx in range(bs): tgt_labels = targets[batch_idx]["labels"].to(device) tgt_bboxes = targets[batch_idx]["boxes"].to(device) # denormalize tgt_bbox tgt_bboxes *= self.img_size # check target if len(tgt_labels) == 0 or tgt_bboxes.max().item() == 0.: num_anchors = sum([ab.shape[0] for ab in anchors]) # There is no valid gt cls_target = conf_preds.new_zeros((0, self.num_classes)) box_target = conf_preds.new_zeros((0, 4)) conf_target = conf_preds.new_zeros((num_anchors, 1)) fg_mask = conf_preds.new_zeros(num_anchors).bool() else: ( gt_matched_classes, fg_mask, pred_ious_this_matching, matched_gt_inds, num_fg_img, ) = self.matcher( fpn_strides = fpn_strides, anchors = anchors, pred_conf = conf_preds[batch_idx], pred_cls = cls_preds[batch_idx], pred_box = box_preds[batch_idx], tgt_labels = tgt_labels, tgt_bboxes = tgt_bboxes, ) conf_target = fg_mask.unsqueeze(-1) box_target = tgt_bboxes[matched_gt_inds] if self.multi_hot: cls_target = gt_matched_classes.float() else: cls_target = F.one_hot(gt_matched_classes.long(), self.num_classes) cls_target = cls_target * pred_ious_this_matching.unsqueeze(-1) cls_targets.append(cls_target) box_targets.append(box_target) conf_targets.append(conf_target) fg_masks.append(fg_mask) cls_targets = torch.cat(cls_targets, 0) box_targets = torch.cat(box_targets, 0) conf_targets = torch.cat(conf_targets, 0) fg_masks = torch.cat(fg_masks, 0) num_foregrounds = fg_masks.sum() if is_dist_avail_and_initialized(): torch.distributed.all_reduce(num_foregrounds) num_foregrounds = (num_foregrounds / get_world_size()).clamp(1.0) # conf loss loss_conf = self.obj_lossf(conf_preds.view(-1, 1), conf_targets.float()) loss_conf = loss_conf.sum() / num_foregrounds # cls loss matched_cls_preds = cls_preds.view(-1, self.num_classes)[fg_masks] loss_cls = self.cls_lossf(matched_cls_preds, cls_targets) loss_cls = loss_cls.sum() / num_foregrounds # box loss matched_box_preds = box_preds.view(-1, 4)[fg_masks] ious = get_ious(matched_box_preds, box_targets, box_mode="xyxy", iou_type='giou') loss_box = (1.0 - ious).sum() / num_foregrounds # total loss losses = self.loss_conf_weight * loss_conf + \ self.loss_cls_weight * loss_cls + \ self.loss_reg_weight * loss_box loss_dict = dict( loss_conf = loss_conf, loss_cls = loss_cls, loss_box = loss_box, losses = losses ) return loss_dict def build_criterion(args, img_size, num_classes, multi_hot=False): criterion = Criterion(args, img_size, num_classes, multi_hot) return criterion