# Hugging Face Space snapshot (Space status: Paused; commit 6e14436;
# file size 7,333 bytes; the viewer's 1-173 line-number gutter is omitted).
# Copyright (c) Facebook, Inc. and its affiliates.
import inspect
import logging
import numpy as np
from typing import Dict, List, Optional, Tuple
import torch
from torch import nn
from detectron2.config import configurable
from detectron2.layers import ShapeSpec, nonzero_tuple
from detectron2.structures import Boxes, ImageList, Instances, pairwise_iou
from detectron2.utils.events import get_event_storage
from detectron2.utils.registry import Registry
from detectron2.modeling.box_regression import Box2BoxTransform
from detectron2.modeling.roi_heads.fast_rcnn import fast_rcnn_inference
from detectron2.modeling.roi_heads.roi_heads import ROI_HEADS_REGISTRY, Res5ROIHeads
from detectron2.modeling.roi_heads.cascade_rcnn import CascadeROIHeads, _ScaleGradient
from detectron2.modeling.roi_heads.box_head import build_box_head
from .detic_fast_rcnn import DeticFastRCNNOutputLayers
from ..debug import debug_second_stage
from torch.cuda.amp import autocast
@ROI_HEADS_REGISTRY.register()
class CustomRes5ROIHeads(Res5ROIHeads):
    """Res5 ROI heads driven by Detic's ``DeticFastRCNNOutputLayers``.

    Extends detectron2's ``Res5ROIHeads`` with:

    * image-label training: when ``ann_type != 'box'`` the first
      ``WS_NUM_PROPS`` proposals per image are used without box matching and
      the predictor's ``image_label_losses`` replaces the box losses;
    * optional attachment of pooled ROI features to each proposal
      (``p.feat``) when ``ADD_FEATURE_TO_PROP`` is set;
    * optional debug visualisation through ``debug_second_stage``.
    """

    @configurable
    def __init__(self, **kwargs):
        """Build the heads from the parent's kwargs plus the full config.

        ``kwargs['cfg']`` is injected by :meth:`from_config`. It is popped
        before the remaining kwargs are forwarded to ``Res5ROIHeads``.
        """
        cfg = kwargs.pop('cfg')
        super().__init__(**kwargs)
        # Width of the features fed to the predictor: res2 channels scaled
        # by 2**3 (presumably one doubling per stage from res2 to res5).
        stage_channel_factor = 2 ** 3
        out_channels = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS * stage_channel_factor

        self.with_image_labels = cfg.WITH_IMAGE_LABELS
        self.ws_num_props = cfg.MODEL.ROI_BOX_HEAD.WS_NUM_PROPS
        self.add_image_box = cfg.MODEL.ROI_BOX_HEAD.ADD_IMAGE_BOX
        self.add_feature_to_prop = cfg.MODEL.ROI_BOX_HEAD.ADD_FEATURE_TO_PROP
        self.image_box_size = cfg.MODEL.ROI_BOX_HEAD.IMAGE_BOX_SIZE
        # Overrides any box_predictor set by super().__init__. forward()
        # averages ROI features over dims 2 and 3 first, hence 1x1 input.
        self.box_predictor = DeticFastRCNNOutputLayers(
            cfg, ShapeSpec(channels=out_channels, height=1, width=1)
        )

        self.save_debug = cfg.SAVE_DEBUG
        self.save_debug_path = cfg.SAVE_DEBUG_PATH
        if self.save_debug:
            # Only needed to visualise images, i.e. to undo normalisation.
            self.debug_show_name = cfg.DEBUG_SHOW_NAME
            self.vis_thresh = cfg.VIS_THRESH
            device = torch.device(cfg.MODEL.DEVICE)
            self.pixel_mean = torch.Tensor(cfg.MODEL.PIXEL_MEAN).to(device).view(3, 1, 1)
            self.pixel_std = torch.Tensor(cfg.MODEL.PIXEL_STD).to(device).view(3, 1, 1)
            self.bgr = (cfg.INPUT.FORMAT == 'BGR')

    @classmethod
    def from_config(cls, cfg, input_shape):
        """Return the parent's constructor kwargs plus the full ``cfg``."""
        ret = super().from_config(cfg, input_shape)
        ret['cfg'] = cfg
        return ret

    def _denormalize_images(self, images):
        """Return copies of ``images`` mapped back via ``x * std + mean``.

        Only valid when ``save_debug`` is set (that is when ``pixel_mean``
        and ``pixel_std`` exist). Assumes each image is a normalized
        (3, H, W) tensor -- TODO confirm.
        """
        return [x.clone() * self.pixel_std + self.pixel_mean for x in images]

    def forward(self, images, features, proposals, targets=None,
                ann_type='box', classifier_info=(None, None, None)):
        """Run the box head, with debug output and image-label training.

        Args:
            images: input images. Only used when ``save_debug`` is set;
                otherwise the reference is dropped immediately.
            features: mapping of feature maps, indexed by ``self.in_features``.
            proposals: per-image ``Instances`` carrying ``proposal_boxes``.
            targets: per-image ground truth, used in training. For
                non-``'box'`` batches each target must have
                ``_pos_category_ids``.
            ann_type: ``'box'`` selects box matching and box losses. Any
                other value selects ``get_top_proposals`` and
                ``image_label_losses``.
            classifier_info: forwarded to the box predictor; shared across
                the batch.

        Returns:
            ``(proposals, losses)`` in training mode, otherwise
            ``(pred_instances, {})``.
        """
        if not self.save_debug:
            del images  # release early; only the debug path needs it

        if self.training:
            if ann_type == 'box':
                proposals = self.label_and_sample_proposals(
                    proposals, targets)
            else:
                proposals = self.get_top_proposals(proposals)

        proposal_boxes = [x.proposal_boxes for x in proposals]
        box_features = self._shared_roi_transform(
            [features[f] for f in self.in_features], proposal_boxes
        )
        # Average over dims 2 and 3 once; reused by the predictor and by the
        # per-proposal feature attachment below.
        pooled_features = box_features.mean(dim=[2, 3])
        predictions = self.box_predictor(
            pooled_features,
            classifier_info=classifier_info)

        if self.add_feature_to_prop:
            feats_per_image = pooled_features.split(
                [len(p) for p in proposals], dim=0)
            for feat, p in zip(feats_per_image, proposals):
                p.feat = feat

        if self.training:
            del features
            if ann_type != 'box':
                image_labels = [x._pos_category_ids for x in targets]
                losses = self.box_predictor.image_label_losses(
                    predictions, proposals, image_labels,
                    classifier_info=classifier_info,
                    ann_type=ann_type)
            else:
                losses = self.box_predictor.losses(
                    (predictions[0], predictions[1]), proposals)
                if self.with_image_labels:
                    # Report a zero 'image_loss' on box batches, presumably
                    # so the loss keys match image-label batches -- confirm
                    # against image_label_losses.
                    assert 'image_loss' not in losses
                    losses['image_loss'] = predictions[0].new_zeros([1])[0]
            if self.save_debug:
                if ann_type == 'box':
                    # Box batches have no image-level labels to display;
                    # otherwise image_labels was already built above.
                    image_labels = [[] for _ in targets]
                debug_second_stage(
                    self._denormalize_images(images),
                    targets, proposals=proposals,
                    save_debug=self.save_debug,
                    debug_show_name=self.debug_show_name,
                    vis_thresh=self.vis_thresh,
                    image_labels=image_labels,
                    save_debug_path=self.save_debug_path,
                    bgr=self.bgr)
            return proposals, losses

        pred_instances, _ = self.box_predictor.inference(predictions, proposals)
        pred_instances = self.forward_with_given_boxes(features, pred_instances)
        if self.save_debug:
            debug_second_stage(
                self._denormalize_images(images),
                pred_instances, proposals=proposals,
                save_debug=self.save_debug,
                debug_show_name=self.debug_show_name,
                vis_thresh=self.vis_thresh,
                save_debug_path=self.save_debug_path,
                bgr=self.bgr)
        return pred_instances, {}

    def get_top_proposals(self, proposals):
        """Prepare proposals for an image-label (non-box) training batch.

        Clips each image's proposal boxes to the image in place, keeps the
        first ``self.ws_num_props`` proposals per image (presumably already
        ranked by objectness), and detaches their boxes from autograd. If
        ``self.add_image_box`` is set, an image-level box is appended.

        Returns:
            A new list with one ``Instances`` per image.
        """
        for p in proposals:
            p.proposal_boxes.clip(p.image_size)

        top_proposals = []
        for p in proposals:
            p = p[:self.ws_num_props]
            p.proposal_boxes.tensor = p.proposal_boxes.tensor.detach()
            if self.add_image_box:
                p = self._add_image_box(p)
            top_proposals.append(p)
        return top_proposals

    def _add_image_box(self, p, use_score=False):
        """Return ``p`` concatenated with one extra, image-level proposal.

        The extra box is centred and spans a fraction ``self.image_box_size``
        of the image width and height, or the whole image when that fraction
        is >= 1. Its ``objectness_logits`` is 1. With ``use_score`` it also
        gets ``scores`` = 1 and ``pred_classes`` = 0.
        """
        image_box = Instances(p.image_size)
        h, w = p.image_size
        if self.image_box_size < 1.0:
            margin = (1. - self.image_box_size) / 2.
            coords = [w * margin, h * margin,
                      w * (1. - margin), h * (1. - margin)]
        else:
            coords = [0, 0, w, h]
        image_box.proposal_boxes = Boxes(
            p.proposal_boxes.tensor.new_tensor(coords).view(1, 4))

        logits = p.objectness_logits
        if use_score:
            image_box.scores = logits.new_ones(1)
            image_box.pred_classes = logits.new_zeros(1, dtype=torch.long)
        image_box.objectness_logits = logits.new_ones(1)
        return Instances.cat([p, image_box])