|
import copy

from mmdet.models.detectors import SingleStageDetector
from mmdet.registry import MODELS
from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig

from .mask2former_vid import Mask2formerVideo
|
|
|
|
|
@MODELS.register_module()
class YOSOVideoSam(Mask2formerVideo):
    """Segmentor assembled from a backbone, an optional neck and two heads.

    The class is registered in ``MODELS`` and builds all of its sub-modules
    through ``MODELS.build``. Everything not defined here (for example
    ``extract_feat``) comes from ``Mask2formerVideo``.
    """

    # NOTE(review): consumed outside this block; meaning not visible here.
    OVERLAPPING = None

    def __init__(self,
                 backbone: ConfigType,
                 neck: OptConfigType = None,
                 panoptic_head: OptConfigType = None,
                 panoptic_fusion_head: OptConfigType = None,
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 data_preprocessor: OptConfigType = None,
                 inference_sam: bool = False,
                 init_cfg: OptMultiConfig = None
                 ):
        """Build the sub-modules and cache the class counts.

        Args:
            backbone: Config passed to ``MODELS.build`` for the backbone.
            neck: Optional neck config; ``self.neck`` is only set when this
                is not ``None``.
            panoptic_head: Config of the panoptic head. Required in practice
                despite the ``None`` default: ``None`` fails with an
                ``AttributeError`` when the config is updated.
            panoptic_fusion_head: Config of the fusion head. Required in
                practice for the same reason as ``panoptic_head``.
            train_cfg: Stored on the instance and injected into the panoptic
                head config.
            test_cfg: Stored on the instance and injected into both head
                configs.
            data_preprocessor: Forwarded to the parent initializer.
            inference_sam: Flag stored as ``self.inference_sam``.
            init_cfg: Forwarded to the parent initializer.
        """
        # Start the MRO lookup *after* SingleStageDetector, so that class's
        # own __init__ is skipped and every sub-module is built here instead.
        super(SingleStageDetector, self).__init__(
            data_preprocessor=data_preprocessor, init_cfg=init_cfg)

        self.backbone = MODELS.build(backbone)
        if neck is not None:
            self.neck = MODELS.build(neck)

        # copy.deepcopy works for plain dicts as well as config objects (the
        # former `.deepcopy()` method only exists on the latter); copying
        # keeps the caller's config untouched by the updates below.
        head_cfg = copy.deepcopy(panoptic_head)
        head_cfg.update(train_cfg=train_cfg, test_cfg=test_cfg)
        self.panoptic_head = MODELS.build(head_cfg)

        fusion_cfg = copy.deepcopy(panoptic_fusion_head)
        fusion_cfg.update(test_cfg=test_cfg)
        self.panoptic_fusion_head = MODELS.build(fusion_cfg)

        # Mirror the class counts exposed by the built panoptic head.
        self.num_things_classes = self.panoptic_head.num_things_classes
        self.num_stuff_classes = self.panoptic_head.num_stuff_classes
        self.num_classes = self.panoptic_head.num_classes

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        # NOTE(review): fixed coefficients; their consumers are not in this
        # block, so their meaning cannot be confirmed from here.
        self.alpha = 0.4
        self.beta = 0.8

        self.inference_sam = inference_sam

    def predict_with_point(self, x, batch_data_samples):
        """Run the panoptic head and return fused or raw predictions.

        Args:
            x: Input handed to ``self.extract_feat``.
            batch_data_samples: Indexable collection of data samples; the
                first element decides which output form is returned.

        Returns:
            A 2-tuple. If ``'gt_instances_collected'`` is absent from the
            first data sample: ``(results_list, None)``, where
            ``results_list`` comes from the fusion head called with
            ``rescale=False``. Otherwise: ``(mask_pred, mask_cls)`` as NumPy
            arrays moved to the CPU. Note the order: mask predictions first,
            class results second.
        """
        feats = self.extract_feat(x)
        mask_cls_results, mask_pred_results, iou_results = \
            self.panoptic_head.predict(feats, batch_data_samples)

        if 'gt_instances_collected' not in batch_data_samples[0]:
            results_list = self.panoptic_fusion_head.predict(
                mask_cls_results,
                mask_pred_results,
                batch_data_samples,
                iou_results=iou_results,
                rescale=False
            )
            return results_list, None

        # Collected instances present: skip fusion, return raw head outputs.
        return mask_pred_results.cpu().numpy(), mask_cls_results.cpu().numpy()
|
|