# Copyright (c) OpenMMLab. All rights reserved.
import torch

from ..builder import BBOX_SAMPLERS
from .base_sampler import BaseSampler


@BBOX_SAMPLERS.register_module()
class RandomSampler(BaseSampler):
    """Sampler that picks positive and negative indices uniformly at random.

    Args:
        num (int): Number of samples.
        pos_fraction (float): Fraction of positive samples.
        neg_pos_ub (int, optional): Upper bound number of negative and
            positive samples; forwarded unchanged to ``BaseSampler``.
            Defaults to -1.
        add_gt_as_proposals (bool, optional): Whether to add ground truth
            boxes as proposals. Defaults to True.
        **kwargs: Only the optional ``rng`` entry is read here; it is handed
            to ``demodata.ensure_rng`` and the result is stored as
            ``self.rng``.
    """

    def __init__(self,
                 num,
                 pos_fraction,
                 neg_pos_ub=-1,
                 add_gt_as_proposals=True,
                 **kwargs):
        # Imported lazily, inside the constructor, exactly as the module
        # has always done.
        from mmdet.core.bbox import demodata

        super().__init__(num, pos_fraction, neg_pos_ub, add_gt_as_proposals)
        # NOTE: none of the methods defined in this class read ``self.rng``;
        # ``random_choice`` draws its permutation from ``torch.randperm``.
        self.rng = demodata.ensure_rng(kwargs.get('rng'))

    def random_choice(self, gallery, num):
        """Draw ``num`` distinct elements from ``gallery`` at random.

        The container type of the result follows the input: a Tensor
        ``gallery`` yields a Tensor, while an ndarray or list ``gallery``
        yields an ndarray.

        Args:
            gallery (Tensor | ndarray | list): Pool of indices to draw from.
            num (int): Number of elements to draw.

        Returns:
            Tensor or ndarray: The sampled indices.

        Raises:
            AssertionError: If ``gallery`` holds fewer than ``num`` elements.
        """
        assert len(gallery) >= num

        tensor_input = isinstance(gallery, torch.Tensor)
        if not tensor_input:
            # Non-tensor pools are converted to a long tensor, placed on the
            # current CUDA device when CUDA is available and on CPU otherwise.
            target = (
                torch.cuda.current_device()
                if torch.cuda.is_available() else 'cpu')
            gallery = torch.tensor(gallery, dtype=torch.long, device=target)

        # This is a temporary fix. We can revert the following code
        # when PyTorch fixes the abnormal return of torch.randperm.
        # See: https://github.com/open-mmlab/mmdetection/pull/5014
        shuffled = torch.randperm(gallery.numel())
        perm = shuffled[:num].to(device=gallery.device)
        picked = gallery[perm]

        if tensor_input:
            return picked
        return picked.cpu().numpy()

    def _sample_pos(self, assign_result, num_expected, **kwargs):
        """Randomly sample some positive samples.

        Positive entries are those with ``assign_result.gt_inds > 0``. When
        there are no more of them than ``num_expected`` they are all
        returned; otherwise ``num_expected`` of them are chosen through
        :meth:`random_choice`.
        """
        candidates = torch.nonzero(assign_result.gt_inds > 0, as_tuple=False)
        # Drop the trailing index axis only when something was found.
        if candidates.numel() != 0:
            candidates = candidates.squeeze(1)

        if candidates.numel() <= num_expected:
            return candidates
        return self.random_choice(candidates, num_expected)

    def _sample_neg(self, assign_result, num_expected, **kwargs):
        """Randomly sample some negative samples.

        Negative entries are those with ``assign_result.gt_inds == 0``. When
        there are no more of them than ``num_expected`` they are all
        returned; otherwise ``num_expected`` of them are chosen through
        :meth:`random_choice`.
        """
        candidates = torch.nonzero(assign_result.gt_inds == 0, as_tuple=False)
        # Drop the trailing index axis only when something was found.
        if candidates.numel() != 0:
            candidates = candidates.squeeze(1)

        if len(candidates) <= num_expected:
            return candidates
        return self.random_choice(candidates, num_expected)