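# NOTE(review): the following stray text preceded the module and is not
# Python; it looks like web-page residue (avatar caption, commit message,
# commit hash) -- confirm and drop: "SakuraD's picture" / "update" / "cdfecf8"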
import torch
from mmcv.runner import force_fp32
from mmdet.core import (bbox2distance, bbox_overlaps, distance2bbox,
multi_apply, reduce_mean)
from ..builder import HEADS, build_loss
from .gfl_head import GFLHead
@HEADS.register_module()
class LDHead(GFLHead):
    """Localization distillation head.

    It utilizes the learned bbox distributions to transfer the localization
    dark knowledge from teacher to student. Original paper: `Localization
    Distillation for Object Detection. <https://arxiv.org/abs/2102.12252>`_

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        loss_ld (dict): Config of Localization Distillation Loss (LD),
            T is the temperature for distillation.
        **kwargs: Forwarded unchanged to ``GFLHead.__init__``.
    """

    # NOTE(review): ``loss_ld`` uses a dict literal as its default, so the
    # same dict object is shared by every call that omits the argument. This
    # class only hands it to ``build_loss``; it never mutates it here.
    def __init__(self,
                 num_classes,
                 in_channels,
                 loss_ld=dict(
                     type='LocalizationDistillationLoss',
                     loss_weight=0.25,
                     T=10),
                 **kwargs):
        super(LDHead, self).__init__(num_classes, in_channels, **kwargs)
        self.loss_ld = build_loss(loss_ld)

    def loss_single(self, anchors, cls_score, bbox_pred, labels, label_weights,
                    bbox_targets, stride, soft_targets, num_total_samples):
        """Compute loss of a single scale level.

        Args:
            anchors (Tensor): Box reference for each scale level with shape
                (N, num_total_anchors, 4).
            cls_score (Tensor): Cls and quality joint scores for each scale
                level has shape (N, num_classes, H, W).
            bbox_pred (Tensor): Box distribution logits for each scale
                level with shape (N, 4*(n+1), H, W), n is max value of
                integral set.
            labels (Tensor): Labels of each anchors with shape
                (N, num_total_anchors).
            label_weights (Tensor): Label weights of each anchor with shape
                (N, num_total_anchors)
            bbox_targets (Tensor): BBox regression targets of each anchor
                with shape (N, num_total_anchors, 4).
            stride (tuple): Stride in this scale level.
            soft_targets (Tensor): Distillation targets for this scale level.
                They are permuted and flattened exactly like ``bbox_pred``,
                so they must share its (N, 4*(n+1), H, W) layout.
            num_total_samples (int): Number of positive samples that is
                reduced over all GPUs.

        Returns:
            tuple[Tensor]: ``(loss_cls, loss_bbox, loss_dfl, loss_ld,
            weight_sum)`` where ``weight_sum`` is the sum of the per-positive
            weight targets (a zero tensor when there are no positives).
        """
        assert stride[0] == stride[1], 'h stride is not equal to w stride!'
        num_bins = self.reg_max + 1

        # Flatten everything to one row per anchor location.
        anchors = anchors.reshape(-1, 4)
        cls_score = cls_score.permute(0, 2, 3,
                                      1).reshape(-1, self.cls_out_channels)
        bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4 * num_bins)
        soft_targets = soft_targets.permute(0, 2, 3,
                                            1).reshape(-1, 4 * num_bins)
        bbox_targets = bbox_targets.reshape(-1, 4)
        labels = labels.reshape(-1)
        label_weights = label_weights.reshape(-1)

        # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
        bg_class_ind = self.num_classes
        pos_inds = ((labels >= 0)
                    & (labels < bg_class_ind)).nonzero().squeeze(1)
        # Quality (IoU) target per anchor; stays 0 for non-positive anchors.
        score = label_weights.new_zeros(labels.shape)

        if len(pos_inds) > 0:
            pos_bbox_targets = bbox_targets[pos_inds]
            pos_bbox_pred = bbox_pred[pos_inds]
            pos_anchors = anchors[pos_inds]
            # Centers and targets are divided by the stride so that they are
            # expressed in the same units as the decoded predictions.
            pos_anchor_centers = self.anchor_center(pos_anchors) / stride[0]

            # Each positive is weighted by its highest (detached) class
            # probability.
            weight_targets = cls_score.detach().sigmoid()
            weight_targets = weight_targets.max(dim=1)[0][pos_inds]

            pos_bbox_pred_corners = self.integral(pos_bbox_pred)
            pos_decode_bbox_pred = distance2bbox(pos_anchor_centers,
                                                 pos_bbox_pred_corners)
            pos_decode_bbox_targets = pos_bbox_targets / stride[0]
            score[pos_inds] = bbox_overlaps(
                pos_decode_bbox_pred.detach(),
                pos_decode_bbox_targets,
                is_aligned=True)

            # One row of ``num_bins`` logits per box side.
            pred_corners = pos_bbox_pred.reshape(-1, num_bins)
            pos_soft_targets = soft_targets[pos_inds]
            soft_corners = pos_soft_targets.reshape(-1, num_bins)
            target_corners = bbox2distance(pos_anchor_centers,
                                           pos_decode_bbox_targets,
                                           self.reg_max).reshape(-1)
            # Repeat each positive's weight for its four box sides.
            side_weights = weight_targets[:, None].expand(-1, 4).reshape(-1)

            # regression loss
            loss_bbox = self.loss_bbox(
                pos_decode_bbox_pred,
                pos_decode_bbox_targets,
                weight=weight_targets,
                avg_factor=1.0)

            # dfl loss
            loss_dfl = self.loss_dfl(
                pred_corners,
                target_corners,
                weight=side_weights,
                avg_factor=4.0)

            # ld loss
            loss_ld = self.loss_ld(
                pred_corners,
                soft_corners,
                weight=side_weights,
                avg_factor=4.0)
        else:
            # No positives: emit zero losses that are still connected to
            # ``bbox_pred`` in the autograd graph.
            loss_ld = bbox_pred.sum() * 0
            loss_bbox = bbox_pred.sum() * 0
            loss_dfl = bbox_pred.sum() * 0
            weight_targets = bbox_pred.new_tensor(0)

        # cls (qfl) loss
        loss_cls = self.loss_cls(
            cls_score, (labels, score),
            weight=label_weights,
            avg_factor=num_total_samples)

        return loss_cls, loss_bbox, loss_dfl, loss_ld, weight_targets.sum()

    def forward_train(self,
                      x,
                      out_teacher,
                      img_metas,
                      gt_bboxes,
                      gt_labels=None,
                      gt_bboxes_ignore=None,
                      proposal_cfg=None,
                      **kwargs):
        """Run the head on ``x`` and compute the training losses.

        Args:
            x (list[Tensor]): Features from FPN.
            out_teacher (tuple): Outputs of the teacher; only
                ``out_teacher[1]`` is used, as the soft target passed to
                :meth:`loss`.
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            gt_bboxes (Tensor): Ground truth bboxes of the image,
                shape (num_gts, 4).
            gt_labels (Tensor): Ground truth labels of each box,
                shape (num_gts,).
            gt_bboxes_ignore (Tensor): Ground truth bboxes to be
                ignored, shape (num_ignored_gts, 4).
            proposal_cfg (mmcv.Config): Test / postprocessing configuration.
                When None, no proposals are generated and only the losses
                are returned.

        Returns:
            dict[str, Tensor] | tuple[dict, list]: The loss components when
            ``proposal_cfg`` is None, otherwise ``(losses, proposal_list)``.

            - losses (dict[str, Tensor]): A dictionary of loss components.
            - proposal_list (list[Tensor]): Proposals of each image.
        """
        outs = self(x)
        soft_target = out_teacher[1]
        # Always pass ``gt_labels`` in its own slot, even when it is None.
        # Omitting it shifts ``soft_target`` / ``img_metas`` one position to
        # the left in ``loss()`` and leaves the required ``img_metas``
        # argument unfilled, which raises TypeError.
        # NOTE(review): a None value is forwarded to ``self.get_targets`` as
        # ``gt_labels_list``; confirm the parent implementation accepts it.
        loss_inputs = outs + (gt_bboxes, gt_labels, soft_target, img_metas)
        losses = self.loss(*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
        if proposal_cfg is None:
            return losses
        proposal_list = self.get_bboxes(*outs, img_metas, cfg=proposal_cfg)
        return losses, proposal_list

    @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
    def loss(self,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             gt_labels,
             soft_target,
             img_metas,
             gt_bboxes_ignore=None):
        """Compute losses of the head.

        Args:
            cls_scores (list[Tensor]): Cls and quality scores for each scale
                level has shape (N, num_classes, H, W).
            bbox_preds (list[Tensor]): Box distribution logits for each scale
                level with shape (N, 4*(n+1), H, W), n is max value of
                integral set.
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): class indices corresponding to each box
            soft_target (list[Tensor]): Per-level distillation targets; one
                entry is handed to :meth:`loss_single` per scale level
                alongside the matching ``bbox_preds`` entry.
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            gt_bboxes_ignore (list[Tensor] | None): specify which bounding
                boxes can be ignored when computing the loss.

        Returns:
            dict[str, Tensor] | None: A dictionary of loss components, or
            None when ``self.get_targets`` returns None.
        """
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == self.anchor_generator.num_levels

        device = cls_scores[0].device
        anchor_list, valid_flag_list = self.get_anchors(
            featmap_sizes, img_metas, device=device)
        label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1

        cls_reg_targets = self.get_targets(
            anchor_list,
            valid_flag_list,
            gt_bboxes,
            img_metas,
            gt_bboxes_ignore_list=gt_bboxes_ignore,
            gt_labels_list=gt_labels,
            label_channels=label_channels)
        if cls_reg_targets is None:
            return None

        (anchor_list, labels_list, label_weights_list, bbox_targets_list,
         bbox_weights_list, num_total_pos, num_total_neg) = cls_reg_targets

        # Average the positive count across processes and keep it >= 1 so
        # the classification loss is never divided by zero.
        num_total_samples = reduce_mean(
            torch.tensor(num_total_pos, dtype=torch.float,
                         device=device)).item()
        num_total_samples = max(num_total_samples, 1.0)

        losses_cls, losses_bbox, losses_dfl, losses_ld, \
            avg_factor = multi_apply(
                self.loss_single,
                anchor_list,
                cls_scores,
                bbox_preds,
                labels_list,
                label_weights_list,
                bbox_targets_list,
                self.anchor_generator.strides,
                soft_target,
                num_total_samples=num_total_samples)

        # Normalise the box losses by the summed weight targets of all
        # levels; the epsilon guards against an all-negative batch.
        avg_factor = sum(avg_factor) + 1e-6
        avg_factor = reduce_mean(avg_factor).item()
        losses_bbox = [x / avg_factor for x in losses_bbox]
        losses_dfl = [x / avg_factor for x in losses_dfl]
        # NOTE(review): ``losses_ld`` is returned without the ``avg_factor``
        # normalisation applied to ``losses_bbox`` / ``losses_dfl`` --
        # confirm this is intended.
        return dict(
            loss_cls=losses_cls,
            loss_bbox=losses_bbox,
            loss_dfl=losses_dfl,
            loss_ld=losses_ld)