from .base import BasePredictor
from .brs import InputBRSPredictor, FeatureBRSPredictor, HRNetFeatureBRSPredictor
from .brs_functors import InputOptimizer, ScaleBiasOptimizer
from isegm.inference.transforms import ZoomIn
from isegm.model.is_hrnet_model import HRNetModel


def get_predictor(net, brs_mode, device, gra=None, sam_type=None,
                  prob_thresh=0.49,
                  with_flip=True, zoom_in_params=dict(),
                  predictor_params=None,
                  brs_opt_func_params=None,
                  lbfgs_params=None):
    """Build the predictor object that corresponds to ``brs_mode``.

    Args:
        net: Model handed to the predictor. A ``list``/``tuple`` is only
            accepted together with ``brs_mode == 'NoBRS'``.
        brs_mode: One of ``'NoBRS'``, ``'f-BRS-A'``, ``'f-BRS-B'``,
            ``'f-BRS-C'``, ``'RGB-BRS'`` or ``'DistMap-BRS'``.
        device: Forwarded unchanged to the predictor constructor.
        gra: Forwarded to ``BasePredictor`` (``'NoBRS'`` mode only).
        sam_type: Forwarded to ``BasePredictor`` (``'NoBRS'`` mode only).
        prob_thresh: Forwarded to the BRS optimizer functor.
        with_flip: Forwarded to both the predictor and the optimizer functor.
        zoom_in_params: Keyword arguments for ``ZoomIn``. Pass ``None`` to
            build the predictor without a zoom-in transform.
        predictor_params: Optional overrides applied on top of the
            per-mode predictor keyword defaults.
        brs_opt_func_params: Optional extra keyword arguments for the
            optimizer functor.
        lbfgs_params: Optional overrides for the default optimizer
            parameters (``m``, ``factr``, ``pgtol``, ``maxfun``).

    Returns:
        The constructed predictor instance.

    Raises:
        AssertionError: ``net`` is a list/tuple and ``brs_mode`` is not
            ``'NoBRS'``.
        KeyError: ``brs_mode`` starts with ``'f-BRS'`` but is not one of the
            three known f-BRS variants.
        NotImplementedError: ``brs_mode`` matches none of the supported modes.
    """
    optimizer_cfg = {
        'm': 20,
        'factr': 0,
        'pgtol': 1e-8,
        'maxfun': 20,
    }
    predictor_kwargs = {
        'optimize_after_n_clicks': 1
    }

    def _with_overrides(**mode_defaults):
        # Mode-specific defaults go in first so that caller-supplied
        # ``predictor_params`` always win.
        predictor_kwargs.update(mode_defaults)
        if predictor_params is not None:
            predictor_kwargs.update(predictor_params)
        return predictor_kwargs

    # NOTE: the ``dict()`` default of ``zoom_in_params`` is only unpacked
    # here, never mutated; ``None`` explicitly disables the zoom-in transform.
    zoom_in = None if zoom_in_params is None else ZoomIn(**zoom_in_params)

    if lbfgs_params is not None:
        optimizer_cfg.update(lbfgs_params)
    # 'maxiter' is always derived from the (possibly overridden) 'maxfun'.
    optimizer_cfg['maxiter'] = 2 * optimizer_cfg['maxfun']

    functor_kwargs = dict() if brs_opt_func_params is None else brs_opt_func_params

    assert not isinstance(net, (list, tuple)) or brs_mode == 'NoBRS', \
        "Multi-stage models support only NoBRS mode."

    if brs_mode == 'NoBRS':
        return BasePredictor(net, device, gra=gra, sam_type=sam_type,
                             zoom_in=zoom_in, with_flip=with_flip,
                             **_with_overrides())

    if brs_mode.startswith('f-BRS'):
        ctor_kwargs = _with_overrides(net_clicks_limit=8)

        insertion_points = {
            'f-BRS-A': 'after_c4',
            'f-BRS-B': 'after_aspp',
            'f-BRS-C': 'after_deeplab'
        }
        insertion_mode = insertion_points[brs_mode]

        scale_bias_functor = ScaleBiasOptimizer(prob_thresh=prob_thresh,
                                                with_flip=with_flip,
                                                optimizer_params=optimizer_cfg,
                                                **functor_kwargs)

        predictor_cls = FeatureBRSPredictor
        if isinstance(net, HRNetModel):
            # HRNet models use their own predictor class and their own
            # insertion-mode labels.
            predictor_cls = HRNetFeatureBRSPredictor
            hrnet_labels = {'after_c4': 'A', 'after_aspp': 'A', 'after_deeplab': 'C'}
            insertion_mode = hrnet_labels[insertion_mode]

        return predictor_cls(net, device,
                             opt_functor=scale_bias_functor,
                             with_flip=with_flip,
                             insertion_mode=insertion_mode,
                             zoom_in=zoom_in,
                             **ctor_kwargs)

    if brs_mode == 'RGB-BRS' or brs_mode == 'DistMap-BRS':
        target = 'dmaps' if brs_mode == 'DistMap-BRS' else 'rgb'
        ctor_kwargs = _with_overrides(net_clicks_limit=5)

        input_functor = InputOptimizer(prob_thresh=prob_thresh,
                                       with_flip=with_flip,
                                       optimizer_params=optimizer_cfg,
                                       **functor_kwargs)

        return InputBRSPredictor(net, device,
                                 optimize_target=target,
                                 opt_functor=input_functor,
                                 with_flip=with_flip,
                                 zoom_in=zoom_in,
                                 **ctor_kwargs)

    raise NotImplementedError