diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..dd98599c484c82ceb0dd29293f9d7c5ba3c912a5 --- /dev/null +++ b/app.py @@ -0,0 +1,10 @@ +from web_app import GraCoWebApplication + + +def main(): + app = GraCoWebApplication() + app.launch() + + +if __name__ == '__main__': + main() diff --git a/isegm/__init__.py b/isegm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/isegm/inference/clicker.py b/isegm/inference/clicker.py new file mode 100644 index 0000000000000000000000000000000000000000..6b739854f41e7dfc2b7fc57bc6777dbba649a2ba --- /dev/null +++ b/isegm/inference/clicker.py @@ -0,0 +1,118 @@ +import numpy as np +from copy import deepcopy +import cv2 + + +class Clicker(object): + def __init__(self, gt_mask=None, init_clicks=None, ignore_label=-1, click_indx_offset=0): + self.click_indx_offset = click_indx_offset + if gt_mask is not None: + self.gt_mask = gt_mask == 1 + self.not_ignore_mask = gt_mask != ignore_label + else: + self.gt_mask = None + + self.reset_clicks() + + if init_clicks is not None: + for click in init_clicks: + self.add_click(click) + + def make_next_click(self, pred_mask): + assert self.gt_mask is not None + click = self._get_next_click(pred_mask) + self.add_click(click) + + def get_clicks(self, clicks_limit=None): + return self.clicks_list[:clicks_limit] + + def _get_next_click(self, pred_mask, padding=True): + fn_mask = np.logical_and(np.logical_and(self.gt_mask, np.logical_not(pred_mask)), self.not_ignore_mask) + fp_mask = np.logical_and(np.logical_and(np.logical_not(self.gt_mask), pred_mask), self.not_ignore_mask) + + if padding: + fn_mask = np.pad(fn_mask, ((1, 1), (1, 1)), 'constant') + fp_mask = np.pad(fp_mask, ((1, 1), (1, 1)), 'constant') + + fn_mask_dt = cv2.distanceTransform(fn_mask.astype(np.uint8), cv2.DIST_L2, 0) + fp_mask_dt = cv2.distanceTransform(fp_mask.astype(np.uint8), cv2.DIST_L2, 0) + + if padding: + fn_mask_dt = fn_mask_dt[1:-1, 1:-1] + fp_mask_dt = fp_mask_dt[1:-1, 1:-1] + + fn_mask_dt = fn_mask_dt * self.not_clicked_map + fp_mask_dt = fp_mask_dt * self.not_clicked_map + + fn_max_dist = np.max(fn_mask_dt) + fp_max_dist = np.max(fp_mask_dt) + + is_positive = fn_max_dist > fp_max_dist + if is_positive: + coords_y, coords_x = np.where(fn_mask_dt == fn_max_dist) # coords is [y, x] + else: + coords_y, coords_x = np.where(fp_mask_dt == fp_max_dist) # coords is [y, x] + + return Click(is_positive=is_positive, coords=(coords_y[0], coords_x[0])) + + def add_click(self, click): + coords = click.coords + + click.indx = self.click_indx_offset + self.num_pos_clicks + self.num_neg_clicks + if click.is_positive: + self.num_pos_clicks += 1 + else: + self.num_neg_clicks += 1 + + self.clicks_list.append(click) + if self.gt_mask is not None: + self.not_clicked_map[coords[0], coords[1]] = False + + def _remove_last_click(self): + click = self.clicks_list.pop() + coords = click.coords + + if click.is_positive: + self.num_pos_clicks -= 1 + else: + self.num_neg_clicks -= 1 + + if self.gt_mask is not None: + self.not_clicked_map[coords[0], coords[1]] = True + + def reset_clicks(self): + if self.gt_mask is not None: + self.not_clicked_map = np.ones_like(self.gt_mask, dtype=bool) + + self.num_pos_clicks = 0 + self.num_neg_clicks = 0 + + self.clicks_list = [] + + def get_state(self): + return deepcopy(self.clicks_list) + + def set_state(self, state): + self.reset_clicks() + for click in state: + self.add_click(click) + + def __len__(self): + 
return len(self.clicks_list) + + +class Click: + def __init__(self, is_positive, coords, indx=None): + self.is_positive = is_positive + self.coords = coords + self.indx = indx + + @property + def coords_and_indx(self): + return (*self.coords, self.indx) + + def copy(self, **kwargs): + self_copy = deepcopy(self) + for k, v in kwargs.items(): + setattr(self_copy, k, v) + return self_copy diff --git a/isegm/inference/evaluation.py b/isegm/inference/evaluation.py new file mode 100644 index 0000000000000000000000000000000000000000..dfe9585f180a2d5a4c58b1d8e03a677f1d839c34 --- /dev/null +++ b/isegm/inference/evaluation.py @@ -0,0 +1,197 @@ +from time import time + +import numpy as np +import torch +import cv2 +from isegm.inference import utils +from isegm.inference.clicker import Click, Clicker + +try: + get_ipython() + from tqdm import tqdm_notebook as tqdm +except NameError: + from tqdm import tqdm + + +def evaluate_dataset(dataset, predictor, sam_type=None, oracle=False, gra_oracle=False, **kwargs): + all_ious = [] + start_time = time() + all_gras = {} + + for index in tqdm(range(len(dataset)), leave=False): + sample = dataset.get_sample(index) + + for object_id in sample.objects_ids: + if gra_oracle: + sample_ious, gra_idx = evaluate_sample_oracle(sample.image, sample.gt_mask(object_id), predictor, + sample_id=index, sam_type=sam_type, oracle=oracle, **kwargs) + all_gras[gra_idx] = all_gras.get(gra_idx, 0) + 1 + else: + _, sample_ious, _ = evaluate_sample(sample.image, sample.gt_mask(object_id), predictor, + sample_id=index, sam_type=sam_type, oracle=oracle, **kwargs) + all_ious.append(sample_ious) + end_time = time() + elapsed_time = end_time - start_time + if len(all_gras) > 0: + print(all_gras) + + return all_ious, elapsed_time + + +def evaluate_sample(image, gt_mask, predictor, max_iou_thr, + pred_thr=0.49, min_clicks=1, max_clicks=20, + sample_id=None, sam_type=False, oracle=False, callback=None): + clicker = Clicker(gt_mask=gt_mask) + pred_mask = np.zeros_like(gt_mask) + ious_list = [] + with torch.no_grad(): + predictor.set_input_image(image) + if sam_type == 'SAM': + for click_indx in range(max_clicks): + clicker.make_next_click(pred_mask) + point_coords, point_labels = get_sam_input(clicker) + if oracle: + ious = [] + pred_masks = [] + pred_probs, _, _ = predictor.predict(point_coords, point_labels, multimask_output=True, return_logits=True) + for idx in range(pred_probs.shape[0]): + pred_masks.append(pred_probs[idx] > predictor.model.mask_threshold) + ious.append(utils.get_iou(gt_mask, pred_masks[-1])) + tgt_idx = np.argmax(np.array(ious)) + iou = ious[tgt_idx] + pred_mask = pred_masks[tgt_idx] + else: + pred_probs, _, _ = predictor.predict(point_coords, point_labels, multimask_output=False, return_logits=True) + pred_probs = pred_probs[0] + pred_mask = pred_probs > predictor.model.mask_threshold + iou = utils.get_iou(gt_mask, pred_mask) + + if callback is not None: + callback(image, gt_mask, pred_probs, sample_id, click_indx, clicker.clicks_list) + + ious_list.append(iou) + if iou >= max_iou_thr and click_indx + 1 >= min_clicks: + break + return clicker.clicks_list, np.array(ious_list, dtype=np.float32), pred_probs + else: + for click_indx in range(max_clicks): + clicker.make_next_click(pred_mask) + pred_probs = predictor.get_prediction(clicker) + pred_mask = pred_probs > pred_thr + iou = utils.get_iou(gt_mask, pred_mask) + + if callback is not None: + callback(image, gt_mask, pred_probs, sample_id, click_indx, clicker.clicks_list) + + ious_list.append(iou) + if iou >= max_iou_thr 
and click_indx + 1 >= min_clicks: + break + return clicker.clicks_list, np.array(ious_list, dtype=np.float32), pred_probs + + +def evaluate_sample_oracle(image, gt_mask, predictor, max_iou_thr, + pred_thr=0.49, min_clicks=1, max_clicks=20, + sample_id=None, sam_type=False, oracle=False, callback=None): + clicker = Clicker(gt_mask=gt_mask) + ious_lists = [] + click_indxs = [] + with torch.no_grad(): + predictor.set_input_image(image) + min_num = 100 + for gra in range(1, 11): + cur_gra = round(gra * 0.1, 1) + ious_list = [] + clicker.reset_clicks() + pred_mask = np.zeros_like(gt_mask) + if sam_type == 'SAM_GraCo': + for click_indx in range(max_clicks): + clicker.make_next_click(pred_mask) + point_coords, point_labels = get_sam_input(clicker) + if oracle: + ious = [] + pred_masks = [] + pred_probs, _, _ = predictor.predict(point_coords, point_labels, gra=cur_gra, multimask_output=True, return_logits=True) + for idx in range(pred_probs.shape[0]): + pred_masks.append(pred_probs[idx] > predictor.model.mask_threshold) + ious.append(utils.get_iou(gt_mask, pred_masks[-1])) + tgt_idx = np.argmax(np.array(ious)) + iou = ious[tgt_idx] + pred_mask = pred_masks[tgt_idx] + else: + pred_probs, _, _ = predictor.predict(point_coords, point_labels, gra=cur_gra, multimask_output=False, return_logits=True) + pred_probs = pred_probs[0] + pred_mask = pred_probs > predictor.model.mask_threshold + iou = utils.get_iou(gt_mask, pred_mask) + + if callback is not None: + callback(image, gt_mask, pred_probs, sample_id, click_indx, clicker.clicks_list) + + ious_list.append(iou) + if iou >= max_iou_thr and click_indx + 1 >= min_clicks: + min_num = min(min_num, click_indx + 1) + break + if min_num <= max_clicks and click_indx + 1 > min_num: + break + else: + predictor.prev_prediction = torch.zeros_like(predictor.original_image[:, :1, :, :]) + for click_indx in range(max_clicks): + clicker.make_next_click(pred_mask) + pred_probs = predictor.get_prediction(clicker, gra=cur_gra) + + pred_mask = pred_probs > pred_thr + iou = utils.get_iou(gt_mask, pred_mask) + + if callback is not None: + callback(image, gt_mask, pred_probs, sample_id, click_indx, clicker.clicks_list) + + ious_list.append(iou) + if iou >= max_iou_thr and click_indx + 1 >= min_clicks: + min_num = min(min_num, click_indx + 1) + break + if min_num <= max_clicks and click_indx + 1 > min_num: + break + ious_lists.append(np.array(ious_list, dtype=np.float32)) + click_indxs.append(click_indx) + click_indxs = np.array(click_indxs) + tgt_idxs = np.squeeze(np.argwhere(click_indxs == np.min(click_indxs)), axis=1) + selected_ious = [ious_lists[i] for i in tgt_idxs] + max_index = np.argmax([ious[0] for ious in selected_ious]) + ious = selected_ious[max_index] + tgt_idx = tgt_idxs[max_index] + + return ious, tgt_idx + + +def get_sam_input(clicker, reverse=True): + clicks_list = clicker.get_clicks() + points_nd = get_points_nd([clicks_list]) + point_length = len(points_nd[0]) // 2 + point_coords = [] + point_labels = [] + for i, point in enumerate(points_nd[0]): + if point[0] == -1: + continue + if i < point_length: + point_labels.append(1) + else: + point_labels.append(0) + if reverse: + point_coords.append([point[1], point[0]]) # for SAM + return np.array(point_coords), np.array(point_labels) + +def get_points_nd(clicks_lists): + total_clicks = [] + num_pos_clicks = [sum(x.is_positive for x in clicks_list) for clicks_list in clicks_lists] + num_neg_clicks = [len(clicks_list) - num_pos for clicks_list, num_pos in zip(clicks_lists, num_pos_clicks)] + num_max_points = 
max(num_pos_clicks + num_neg_clicks) + num_max_points = max(1, num_max_points) + + for clicks_list in clicks_lists: + pos_clicks = [click.coords_and_indx for click in clicks_list if click.is_positive] + pos_clicks = pos_clicks + (num_max_points - len(pos_clicks)) * [(-1, -1, -1)] + + neg_clicks = [click.coords_and_indx for click in clicks_list if not click.is_positive] + neg_clicks = neg_clicks + (num_max_points - len(neg_clicks)) * [(-1, -1, -1)] + total_clicks.append(pos_clicks + neg_clicks) + + return total_clicks diff --git a/isegm/inference/predictors/__init__.py b/isegm/inference/predictors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..38ed0f307a17c8577bee9ba392e2fadaa44a104f --- /dev/null +++ b/isegm/inference/predictors/__init__.py @@ -0,0 +1,99 @@ +from .base import BasePredictor +from .brs import InputBRSPredictor, FeatureBRSPredictor, HRNetFeatureBRSPredictor +from .brs_functors import InputOptimizer, ScaleBiasOptimizer +from isegm.inference.transforms import ZoomIn +from isegm.model.is_hrnet_model import HRNetModel + + +def get_predictor(net, brs_mode, device, + gra=None, sam_type=None, + prob_thresh=0.49, + with_flip=True, + zoom_in_params=dict(), + predictor_params=None, + brs_opt_func_params=None, + lbfgs_params=None): + lbfgs_params_ = { + 'm': 20, + 'factr': 0, + 'pgtol': 1e-8, + 'maxfun': 20, + } + + predictor_params_ = { + 'optimize_after_n_clicks': 1 + } + + if zoom_in_params is not None: + zoom_in = ZoomIn(**zoom_in_params) + else: + zoom_in = None + + if lbfgs_params is not None: + lbfgs_params_.update(lbfgs_params) + lbfgs_params_['maxiter'] = 2 * lbfgs_params_['maxfun'] + + if brs_opt_func_params is None: + brs_opt_func_params = dict() + + if isinstance(net, (list, tuple)): + assert brs_mode == 'NoBRS', "Multi-stage models support only NoBRS mode." 
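+    # Dispatch on brs_mode below: 'NoBRS' -> BasePredictor, 'f-BRS-*' -> feature-level
+    # BRS predictors (HRNet or DeepLab variant) with a ScaleBiasOptimizer, and
+    # 'RGB-BRS'/'DistMap-BRS' -> InputBRSPredictor with an InputOptimizer.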
+ + if brs_mode == 'NoBRS': + if predictor_params is not None: + predictor_params_.update(predictor_params) + predictor = BasePredictor(net, device, gra=gra, sam_type=sam_type, zoom_in=zoom_in, with_flip=with_flip, **predictor_params_) + elif brs_mode.startswith('f-BRS'): + predictor_params_.update({ + 'net_clicks_limit': 8, + }) + if predictor_params is not None: + predictor_params_.update(predictor_params) + + insertion_mode = { + 'f-BRS-A': 'after_c4', + 'f-BRS-B': 'after_aspp', + 'f-BRS-C': 'after_deeplab' + }[brs_mode] + + opt_functor = ScaleBiasOptimizer(prob_thresh=prob_thresh, + with_flip=with_flip, + optimizer_params=lbfgs_params_, + **brs_opt_func_params) + + if isinstance(net, HRNetModel): + FeaturePredictor = HRNetFeatureBRSPredictor + insertion_mode = {'after_c4': 'A', 'after_aspp': 'A', 'after_deeplab': 'C'}[insertion_mode] + else: + FeaturePredictor = FeatureBRSPredictor + + predictor = FeaturePredictor(net, device, + opt_functor=opt_functor, + with_flip=with_flip, + insertion_mode=insertion_mode, + zoom_in=zoom_in, + **predictor_params_) + elif brs_mode == 'RGB-BRS' or brs_mode == 'DistMap-BRS': + use_dmaps = brs_mode == 'DistMap-BRS' + + predictor_params_.update({ + 'net_clicks_limit': 5, + }) + if predictor_params is not None: + predictor_params_.update(predictor_params) + + opt_functor = InputOptimizer(prob_thresh=prob_thresh, + with_flip=with_flip, + optimizer_params=lbfgs_params_, + **brs_opt_func_params) + + predictor = InputBRSPredictor(net, device, + optimize_target='dmaps' if use_dmaps else 'rgb', + opt_functor=opt_functor, + with_flip=with_flip, + zoom_in=zoom_in, + **predictor_params_) + else: + raise NotImplementedError + + return predictor diff --git a/isegm/inference/predictors/base.py b/isegm/inference/predictors/base.py new file mode 100644 index 0000000000000000000000000000000000000000..526c1ae3cb21e50b1b3898235cfc2559276955d1 --- /dev/null +++ b/isegm/inference/predictors/base.py @@ -0,0 +1,191 @@ +import torch +import torch.nn.functional as F +import numpy as np +from torchvision import transforms +from isegm.inference.transforms import AddHorizontalFlip, SigmoidForPred, LimitLongestSide + +class BasePredictor(object): + def __init__(self, model, device, gra=None, sam_type=None, + net_clicks_limit=None, + with_flip=False, + zoom_in=None, + max_size=None, + **kwargs): + self.with_flip = with_flip + self.net_clicks_limit = net_clicks_limit + self.original_image = None + self.device = device + self.gra=gra if gra is not None and gra > 0 else None + self.sam_type = sam_type + self.zoom_in = zoom_in + self.prev_prediction = None + self.model_indx = 0 + self.click_models = None + self.net_state_dict = None + + if isinstance(model, tuple): + self.net, self.click_models = model + else: + self.net = model + + self.to_tensor = transforms.ToTensor() + + self.transforms = [zoom_in] if zoom_in is not None else [] + if max_size is not None: + self.transforms.append(LimitLongestSide(max_size=max_size)) + self.transforms.append(SigmoidForPred()) + if with_flip: + self.transforms.append(AddHorizontalFlip()) + + def set_input_image(self, image): + if not isinstance(image, torch.Tensor): + image_nd = self.to_tensor(image) + else: + image_nd = image + for transform in self.transforms: + transform.reset() + self.original_image = image_nd.to(self.device) + if len(self.original_image.shape) == 3: + self.original_image = self.original_image.unsqueeze(0) + self.prev_prediction = torch.zeros_like(self.original_image[:, :1, :, :]) + + def get_prediction(self, clicker, 
prev_mask=None, gra=None): + clicks_list = clicker.get_clicks() + + if self.click_models is not None: + model_indx = min(clicker.click_indx_offset + len(clicks_list), len(self.click_models)) - 1 + if model_indx != self.model_indx: + self.model_indx = model_indx + self.net = self.click_models[model_indx] + + input_image = self.original_image + if prev_mask is None: + prev_mask = self.prev_prediction + if (hasattr(self.net, 'with_prev_mask') and self.net.with_prev_mask) or self.sam_type is not None: + input_image = torch.cat((input_image, prev_mask), dim=1) + image_nd, clicks_lists, is_image_changed = self.apply_transforms( + input_image, [clicks_list] + ) + pred_logits = self._get_prediction(image_nd, clicks_lists, is_image_changed, gra=gra) + + prediction = F.interpolate(pred_logits, mode='bilinear', align_corners=True, + size=image_nd.size()[2:]) + + for t in reversed(self.transforms): + prediction = t.inv_transform(prediction) + + if self.zoom_in is not None and self.zoom_in.check_possible_recalculation(): + return self.get_prediction(clicker) + + self.prev_prediction = prediction + return prediction.cpu().numpy()[0, 0] + + def _get_prediction(self, image_nd, clicks_lists, is_image_changed, gra=None): + points_nd = self.get_points_nd(clicks_lists) + if gra is None: + gra = self.gra + if self.sam_type == 'SAM': + batched_input = self.get_sam_batched_input(image_nd, points_nd) + batched_output = self.net(batched_input, multimask_output=False, return_logits=True) + return torch.cat([batch['masks'] for batch in batched_output], dim=0) + + if gra is not None: + return self.net(image_nd, points_nd, torch.Tensor([gra]).to(self.device))['instances'] + else: + return self.net(image_nd, points_nd)['instances'] + + + def _batch_infer(self, batch_image_tensor, batch_clickers, prev_mask=None): + if prev_mask is None: + prev_mask = self.prev_prediction + + if hasattr(self.net, 'with_prev_mask') and self.net.with_prev_mask: + input_image = torch.cat((batch_image_tensor, prev_mask), dim=1) + + clicks_lists = [clicker.get_clicks() for clicker in batch_clickers] + image_nd, clicks_lists, is_image_changed = self.apply_transforms( + input_image, clicks_lists + ) + points_nd = self.get_points_nd(clicks_lists) + pred_logits = self.net(image_nd, points_nd)['instances'] + prediction = F.interpolate(pred_logits, mode='bilinear', align_corners=True, + size=image_nd.size()[2:]) + + for t in reversed(self.transforms): + prediction = t.inv_transform(prediction) + + self.prev_prediction = prediction + return prediction.cpu().numpy()[:, 0] + + def _get_transform_states(self): + return [x.get_state() for x in self.transforms] + + def _set_transform_states(self, states): + assert len(states) == len(self.transforms) + for state, transform in zip(states, self.transforms): + transform.set_state(state) + + def apply_transforms(self, image_nd, clicks_lists): + is_image_changed = False + for t in self.transforms: + image_nd, clicks_lists = t.transform(image_nd, clicks_lists) + is_image_changed |= t.image_changed + + return image_nd, clicks_lists, is_image_changed + + def get_points_nd(self, clicks_lists): + total_clicks = [] + num_pos_clicks = [sum(x.is_positive for x in clicks_list) for clicks_list in clicks_lists] + num_neg_clicks = [len(clicks_list) - num_pos for clicks_list, num_pos in zip(clicks_lists, num_pos_clicks)] + num_max_points = max(num_pos_clicks + num_neg_clicks) + if self.net_clicks_limit is not None: + num_max_points = min(self.net_clicks_limit, num_max_points) + num_max_points = max(1, num_max_points) + + 
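+        # Each clicks list is truncated to net_clicks_limit, split into positive and
+        # negative clicks, and both halves are padded to num_max_points with
+        # (-1, -1, -1) sentinels so the points tensor has a fixed shape per batch.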
for clicks_list in clicks_lists: + clicks_list = clicks_list[:self.net_clicks_limit] + pos_clicks = [click.coords_and_indx for click in clicks_list if click.is_positive] + pos_clicks = pos_clicks + (num_max_points - len(pos_clicks)) * [(-1, -1, -1)] + + neg_clicks = [click.coords_and_indx for click in clicks_list if not click.is_positive] + neg_clicks = neg_clicks + (num_max_points - len(neg_clicks)) * [(-1, -1, -1)] + total_clicks.append(pos_clicks + neg_clicks) + + return torch.tensor(total_clicks, device=self.device) + + def get_sam_batched_input(self, image_nd, points_nd): + batched_output = [] + for i in range(image_nd.shape[0]): + image = image_nd[i] + point_length = points_nd[i].shape[0] // 2 + point_coords = [] + point_labels = [] + for i, point in enumerate(points_nd[i]): + point_np = point.cpu().numpy() + if point_np[0] == -1: + continue + if i < point_length: + point_labels.append(1) + else: + point_labels.append(0) + + point_coords.append([point_np[1], point_np[0]]) + res = { + 'image': image[:3, :, :], + 'point_coords': torch.as_tensor(np.array(point_coords), dtype=torch.float, device=self.device)[None, :], + 'point_labels': torch.as_tensor(np.array(point_labels), dtype=torch.float, device=self.device)[None, :], + 'original_size': image.cpu().numpy().shape[1:], + 'mask_inputs': image[3, :, :][None, None, :] + } + batched_output.append(res) + return batched_output + + def get_states(self): + return { + 'transform_states': self._get_transform_states(), + 'prev_prediction': self.prev_prediction.clone() + } + + def set_states(self, states): + self._set_transform_states(states['transform_states']) + self.prev_prediction = states['prev_prediction'] diff --git a/isegm/inference/predictors/brs.py b/isegm/inference/predictors/brs.py new file mode 100644 index 0000000000000000000000000000000000000000..910e3fd52471c39fe56668575765adcc00393d3d --- /dev/null +++ b/isegm/inference/predictors/brs.py @@ -0,0 +1,307 @@ +import torch +import torch.nn.functional as F +import numpy as np +from scipy.optimize import fmin_l_bfgs_b + +from .base import BasePredictor + + +class BRSBasePredictor(BasePredictor): + def __init__(self, model, device, opt_functor, optimize_after_n_clicks=1, **kwargs): + super().__init__(model, device, **kwargs) + self.optimize_after_n_clicks = optimize_after_n_clicks + self.opt_functor = opt_functor + + self.opt_data = None + self.input_data = None + + def set_input_image(self, image): + super().set_input_image(image) + self.opt_data = None + self.input_data = None + + def _get_clicks_maps_nd(self, clicks_lists, image_shape, radius=1): + pos_clicks_map = np.zeros((len(clicks_lists), 1) + image_shape, dtype=np.float32) + neg_clicks_map = np.zeros((len(clicks_lists), 1) + image_shape, dtype=np.float32) + + for list_indx, clicks_list in enumerate(clicks_lists): + for click in clicks_list: + y, x = click.coords + y, x = int(round(y)), int(round(x)) + y1, x1 = y - radius, x - radius + y2, x2 = y + radius + 1, x + radius + 1 + + if click.is_positive: + pos_clicks_map[list_indx, 0, y1:y2, x1:x2] = True + else: + neg_clicks_map[list_indx, 0, y1:y2, x1:x2] = True + + with torch.no_grad(): + pos_clicks_map = torch.from_numpy(pos_clicks_map).to(self.device) + neg_clicks_map = torch.from_numpy(neg_clicks_map).to(self.device) + + return pos_clicks_map, neg_clicks_map + + def get_states(self): + return {'transform_states': self._get_transform_states(), 'opt_data': self.opt_data} + + def set_states(self, states): + self._set_transform_states(states['transform_states']) + self.opt_data = 
states['opt_data'] + + +class FeatureBRSPredictor(BRSBasePredictor): + def __init__(self, model, device, opt_functor, insertion_mode='after_deeplab', **kwargs): + super().__init__(model, device, opt_functor=opt_functor, **kwargs) + self.insertion_mode = insertion_mode + self._c1_features = None + + if self.insertion_mode == 'after_deeplab': + self.num_channels = model.feature_extractor.ch + elif self.insertion_mode == 'after_c4': + self.num_channels = model.feature_extractor.aspp_in_channels + elif self.insertion_mode == 'after_aspp': + self.num_channels = model.feature_extractor.ch + 32 + else: + raise NotImplementedError + + def _get_prediction(self, image_nd, clicks_lists, is_image_changed): + points_nd = self.get_points_nd(clicks_lists) + pos_mask, neg_mask = self._get_clicks_maps_nd(clicks_lists, image_nd.shape[2:]) + + num_clicks = len(clicks_lists[0]) + bs = image_nd.shape[0] // 2 if self.with_flip else image_nd.shape[0] + + if self.opt_data is None or self.opt_data.shape[0] // (2 * self.num_channels) != bs: + self.opt_data = np.zeros((bs * 2 * self.num_channels), dtype=np.float32) + + if num_clicks <= self.net_clicks_limit or is_image_changed or self.input_data is None: + self.input_data = self._get_head_input(image_nd, points_nd) + + def get_prediction_logits(scale, bias): + scale = scale.view(bs, -1, 1, 1) + bias = bias.view(bs, -1, 1, 1) + if self.with_flip: + scale = scale.repeat(2, 1, 1, 1) + bias = bias.repeat(2, 1, 1, 1) + + scaled_backbone_features = self.input_data * scale + scaled_backbone_features = scaled_backbone_features + bias + if self.insertion_mode == 'after_c4': + x = self.net.feature_extractor.aspp(scaled_backbone_features) + x = F.interpolate(x, mode='bilinear', size=self._c1_features.size()[2:], + align_corners=True) + x = torch.cat((x, self._c1_features), dim=1) + scaled_backbone_features = self.net.feature_extractor.head(x) + elif self.insertion_mode == 'after_aspp': + scaled_backbone_features = self.net.feature_extractor.head(scaled_backbone_features) + + pred_logits = self.net.head(scaled_backbone_features) + pred_logits = F.interpolate(pred_logits, size=image_nd.size()[2:], mode='bilinear', + align_corners=True) + return pred_logits + + self.opt_functor.init_click(get_prediction_logits, pos_mask, neg_mask, self.device) + if num_clicks > self.optimize_after_n_clicks: + opt_result = fmin_l_bfgs_b(func=self.opt_functor, x0=self.opt_data, + **self.opt_functor.optimizer_params) + self.opt_data = opt_result[0] + + with torch.no_grad(): + if self.opt_functor.best_prediction is not None: + opt_pred_logits = self.opt_functor.best_prediction + else: + opt_data_nd = torch.from_numpy(self.opt_data).to(self.device) + opt_vars, _ = self.opt_functor.unpack_opt_params(opt_data_nd) + opt_pred_logits = get_prediction_logits(*opt_vars) + + return opt_pred_logits + + def _get_head_input(self, image_nd, points): + with torch.no_grad(): + image_nd, prev_mask = self.net.prepare_input(image_nd) + coord_features = self.net.get_coord_features(image_nd, prev_mask, points) + + if self.net.rgb_conv is not None: + x = self.net.rgb_conv(torch.cat((image_nd, coord_features), dim=1)) + additional_features = None + elif hasattr(self.net, 'maps_transform'): + x = image_nd + additional_features = self.net.maps_transform(coord_features) + + if self.insertion_mode == 'after_c4' or self.insertion_mode == 'after_aspp': + c1, _, c3, c4 = self.net.feature_extractor.backbone(x, additional_features) + c1 = self.net.feature_extractor.skip_project(c1) + + if self.insertion_mode == 'after_aspp': + x = 
self.net.feature_extractor.aspp(c4) + x = F.interpolate(x, size=c1.size()[2:], mode='bilinear', align_corners=True) + x = torch.cat((x, c1), dim=1) + backbone_features = x + else: + backbone_features = c4 + self._c1_features = c1 + else: + backbone_features = self.net.feature_extractor(x, additional_features)[0] + + return backbone_features + + +class HRNetFeatureBRSPredictor(BRSBasePredictor): + def __init__(self, model, device, opt_functor, insertion_mode='A', **kwargs): + super().__init__(model, device, opt_functor=opt_functor, **kwargs) + self.insertion_mode = insertion_mode + self._c1_features = None + + if self.insertion_mode == 'A': + self.num_channels = sum(k * model.feature_extractor.width for k in [1, 2, 4, 8]) + elif self.insertion_mode == 'C': + self.num_channels = 2 * model.feature_extractor.ocr_width + else: + raise NotImplementedError + + def _get_prediction(self, image_nd, clicks_lists, is_image_changed): + points_nd = self.get_points_nd(clicks_lists) + pos_mask, neg_mask = self._get_clicks_maps_nd(clicks_lists, image_nd.shape[2:]) + num_clicks = len(clicks_lists[0]) + bs = image_nd.shape[0] // 2 if self.with_flip else image_nd.shape[0] + + if self.opt_data is None or self.opt_data.shape[0] // (2 * self.num_channels) != bs: + self.opt_data = np.zeros((bs * 2 * self.num_channels), dtype=np.float32) + + if num_clicks <= self.net_clicks_limit or is_image_changed or self.input_data is None: + self.input_data = self._get_head_input(image_nd, points_nd) + + def get_prediction_logits(scale, bias): + scale = scale.view(bs, -1, 1, 1) + bias = bias.view(bs, -1, 1, 1) + if self.with_flip: + scale = scale.repeat(2, 1, 1, 1) + bias = bias.repeat(2, 1, 1, 1) + + scaled_backbone_features = self.input_data * scale + scaled_backbone_features = scaled_backbone_features + bias + if self.insertion_mode == 'A': + if self.net.feature_extractor.ocr_width > 0: + out_aux = self.net.feature_extractor.aux_head(scaled_backbone_features) + feats = self.net.feature_extractor.conv3x3_ocr(scaled_backbone_features) + + context = self.net.feature_extractor.ocr_gather_head(feats, out_aux) + feats = self.net.feature_extractor.ocr_distri_head(feats, context) + else: + feats = scaled_backbone_features + pred_logits = self.net.feature_extractor.cls_head(feats) + elif self.insertion_mode == 'C': + pred_logits = self.net.feature_extractor.cls_head(scaled_backbone_features) + else: + raise NotImplementedError + + pred_logits = F.interpolate(pred_logits, size=image_nd.size()[2:], mode='bilinear', + align_corners=True) + return pred_logits + + self.opt_functor.init_click(get_prediction_logits, pos_mask, neg_mask, self.device) + if num_clicks > self.optimize_after_n_clicks: + opt_result = fmin_l_bfgs_b(func=self.opt_functor, x0=self.opt_data, + **self.opt_functor.optimizer_params) + self.opt_data = opt_result[0] + + with torch.no_grad(): + if self.opt_functor.best_prediction is not None: + opt_pred_logits = self.opt_functor.best_prediction + else: + opt_data_nd = torch.from_numpy(self.opt_data).to(self.device) + opt_vars, _ = self.opt_functor.unpack_opt_params(opt_data_nd) + opt_pred_logits = get_prediction_logits(*opt_vars) + + return opt_pred_logits + + def _get_head_input(self, image_nd, points): + with torch.no_grad(): + image_nd, prev_mask = self.net.prepare_input(image_nd) + coord_features = self.net.get_coord_features(image_nd, prev_mask, points) + + if self.net.rgb_conv is not None: + x = self.net.rgb_conv(torch.cat((image_nd, coord_features), dim=1)) + additional_features = None + elif hasattr(self.net, 
'maps_transform'): + x = image_nd + additional_features = self.net.maps_transform(coord_features) + + feats = self.net.feature_extractor.compute_hrnet_feats(x, additional_features) + + if self.insertion_mode == 'A': + backbone_features = feats + elif self.insertion_mode == 'C': + out_aux = self.net.feature_extractor.aux_head(feats) + feats = self.net.feature_extractor.conv3x3_ocr(feats) + + context = self.net.feature_extractor.ocr_gather_head(feats, out_aux) + backbone_features = self.net.feature_extractor.ocr_distri_head(feats, context) + else: + raise NotImplementedError + + return backbone_features + + +class InputBRSPredictor(BRSBasePredictor): + def __init__(self, model, device, opt_functor, optimize_target='rgb', **kwargs): + super().__init__(model, device, opt_functor=opt_functor, **kwargs) + self.optimize_target = optimize_target + + def _get_prediction(self, image_nd, clicks_lists, is_image_changed): + points_nd = self.get_points_nd(clicks_lists) + pos_mask, neg_mask = self._get_clicks_maps_nd(clicks_lists, image_nd.shape[2:]) + num_clicks = len(clicks_lists[0]) + + if self.opt_data is None or is_image_changed: + if self.optimize_target == 'dmaps': + opt_channels = self.net.coord_feature_ch - 1 if self.net.with_prev_mask else self.net.coord_feature_ch + else: + opt_channels = 3 + bs = image_nd.shape[0] // 2 if self.with_flip else image_nd.shape[0] + self.opt_data = torch.zeros((bs, opt_channels, image_nd.shape[2], image_nd.shape[3]), + device=self.device, dtype=torch.float32) + + def get_prediction_logits(opt_bias): + input_image, prev_mask = self.net.prepare_input(image_nd) + dmaps = self.net.get_coord_features(input_image, prev_mask, points_nd) + + if self.optimize_target == 'rgb': + input_image = input_image + opt_bias + elif self.optimize_target == 'dmaps': + if self.net.with_prev_mask: + dmaps[:, 1:, :, :] = dmaps[:, 1:, :, :] + opt_bias + else: + dmaps = dmaps + opt_bias + + if self.net.rgb_conv is not None: + x = self.net.rgb_conv(torch.cat((input_image, dmaps), dim=1)) + if self.optimize_target == 'all': + x = x + opt_bias + coord_features = None + elif hasattr(self.net, 'maps_transform'): + x = input_image + coord_features = self.net.maps_transform(dmaps) + + pred_logits = self.net.backbone_forward(x, coord_features=coord_features)['instances'] + pred_logits = F.interpolate(pred_logits, size=image_nd.size()[2:], mode='bilinear', align_corners=True) + + return pred_logits + + self.opt_functor.init_click(get_prediction_logits, pos_mask, neg_mask, self.device, + shape=self.opt_data.shape) + if num_clicks > self.optimize_after_n_clicks: + opt_result = fmin_l_bfgs_b(func=self.opt_functor, x0=self.opt_data.cpu().numpy().ravel(), + **self.opt_functor.optimizer_params) + + self.opt_data = torch.from_numpy(opt_result[0]).view(self.opt_data.shape).to(self.device) + + with torch.no_grad(): + if self.opt_functor.best_prediction is not None: + opt_pred_logits = self.opt_functor.best_prediction + else: + opt_vars, _ = self.opt_functor.unpack_opt_params(self.opt_data) + opt_pred_logits = get_prediction_logits(*opt_vars) + + return opt_pred_logits diff --git a/isegm/inference/predictors/brs_functors.py b/isegm/inference/predictors/brs_functors.py new file mode 100644 index 0000000000000000000000000000000000000000..f919e13c6c9edb6a9eb7c4afc37933db7b303c12 --- /dev/null +++ b/isegm/inference/predictors/brs_functors.py @@ -0,0 +1,109 @@ +import torch +import numpy as np + +from isegm.model.metrics import _compute_iou +from .brs_losses import BRSMaskLoss + + +class BaseOptimizer: + def 
__init__(self, optimizer_params, + prob_thresh=0.49, + reg_weight=1e-3, + min_iou_diff=0.01, + brs_loss=BRSMaskLoss(), + with_flip=False, + flip_average=False, + **kwargs): + self.brs_loss = brs_loss + self.optimizer_params = optimizer_params + self.prob_thresh = prob_thresh + self.reg_weight = reg_weight + self.min_iou_diff = min_iou_diff + self.with_flip = with_flip + self.flip_average = flip_average + + self.best_prediction = None + self._get_prediction_logits = None + self._opt_shape = None + self._best_loss = None + self._click_masks = None + self._last_mask = None + self.device = None + + def init_click(self, get_prediction_logits, pos_mask, neg_mask, device, shape=None): + self.best_prediction = None + self._get_prediction_logits = get_prediction_logits + self._click_masks = (pos_mask, neg_mask) + self._opt_shape = shape + self._last_mask = None + self.device = device + + def __call__(self, x): + opt_params = torch.from_numpy(x).float().to(self.device) + opt_params.requires_grad_(True) + + with torch.enable_grad(): + opt_vars, reg_loss = self.unpack_opt_params(opt_params) + result_before_sigmoid = self._get_prediction_logits(*opt_vars) + result = torch.sigmoid(result_before_sigmoid) + + pos_mask, neg_mask = self._click_masks + if self.with_flip and self.flip_average: + result, result_flipped = torch.chunk(result, 2, dim=0) + result = 0.5 * (result + torch.flip(result_flipped, dims=[3])) + pos_mask, neg_mask = pos_mask[:result.shape[0]], neg_mask[:result.shape[0]] + + loss, f_max_pos, f_max_neg = self.brs_loss(result, pos_mask, neg_mask) + loss = loss + reg_loss + + f_val = loss.detach().cpu().numpy() + if self.best_prediction is None or f_val < self._best_loss: + self.best_prediction = result_before_sigmoid.detach() + self._best_loss = f_val + + if f_max_pos < (1 - self.prob_thresh) and f_max_neg < self.prob_thresh: + return [f_val, np.zeros_like(x)] + + current_mask = result > self.prob_thresh + if self._last_mask is not None and self.min_iou_diff > 0: + diff_iou = _compute_iou(current_mask, self._last_mask) + if len(diff_iou) > 0 and diff_iou.mean() > 1 - self.min_iou_diff: + return [f_val, np.zeros_like(x)] + self._last_mask = current_mask + + loss.backward() + f_grad = opt_params.grad.cpu().numpy().ravel().astype(np.float) + + return [f_val, f_grad] + + def unpack_opt_params(self, opt_params): + raise NotImplementedError + + +class InputOptimizer(BaseOptimizer): + def unpack_opt_params(self, opt_params): + opt_params = opt_params.view(self._opt_shape) + if self.with_flip: + opt_params_flipped = torch.flip(opt_params, dims=[3]) + opt_params = torch.cat([opt_params, opt_params_flipped], dim=0) + reg_loss = self.reg_weight * torch.sum(opt_params**2) + + return (opt_params,), reg_loss + + +class ScaleBiasOptimizer(BaseOptimizer): + def __init__(self, *args, scale_act=None, reg_bias_weight=10.0, **kwargs): + super().__init__(*args, **kwargs) + self.scale_act = scale_act + self.reg_bias_weight = reg_bias_weight + + def unpack_opt_params(self, opt_params): + scale, bias = torch.chunk(opt_params, 2, dim=0) + reg_loss = self.reg_weight * (torch.sum(scale**2) + self.reg_bias_weight * torch.sum(bias**2)) + + if self.scale_act == 'tanh': + scale = torch.tanh(scale) + elif self.scale_act == 'sin': + scale = torch.sin(scale) + + return (1 + scale, bias), reg_loss diff --git a/isegm/inference/predictors/brs_losses.py b/isegm/inference/predictors/brs_losses.py new file mode 100644 index 0000000000000000000000000000000000000000..ea98824356cf5a4d09094fb92c13ee8d8dfe15dc --- /dev/null +++ 
b/isegm/inference/predictors/brs_losses.py @@ -0,0 +1,58 @@ +import torch + +from isegm.model.losses import SigmoidBinaryCrossEntropyLoss + + +class BRSMaskLoss(torch.nn.Module): + def __init__(self, eps=1e-5): + super().__init__() + self._eps = eps + + def forward(self, result, pos_mask, neg_mask): + pos_diff = (1 - result) * pos_mask + pos_target = torch.sum(pos_diff ** 2) + pos_target = pos_target / (torch.sum(pos_mask) + self._eps) + + neg_diff = result * neg_mask + neg_target = torch.sum(neg_diff ** 2) + neg_target = neg_target / (torch.sum(neg_mask) + self._eps) + + loss = pos_target + neg_target + + with torch.no_grad(): + f_max_pos = torch.max(torch.abs(pos_diff)).item() + f_max_neg = torch.max(torch.abs(neg_diff)).item() + + return loss, f_max_pos, f_max_neg + + +class OracleMaskLoss(torch.nn.Module): + def __init__(self): + super().__init__() + self.gt_mask = None + self.loss = SigmoidBinaryCrossEntropyLoss(from_sigmoid=True) + self.predictor = None + self.history = [] + + def set_gt_mask(self, gt_mask): + self.gt_mask = gt_mask + self.history = [] + + def forward(self, result, pos_mask, neg_mask): + gt_mask = self.gt_mask.to(result.device) + if self.predictor.object_roi is not None: + r1, r2, c1, c2 = self.predictor.object_roi[:4] + gt_mask = gt_mask[:, :, r1:r2 + 1, c1:c2 + 1] + gt_mask = torch.nn.functional.interpolate(gt_mask, result.size()[2:], mode='bilinear', align_corners=True) + + if result.shape[0] == 2: + gt_mask_flipped = torch.flip(gt_mask, dims=[3]) + gt_mask = torch.cat([gt_mask, gt_mask_flipped], dim=0) + + loss = self.loss(result, gt_mask) + self.history.append(loss.detach().cpu().numpy()[0]) + + if len(self.history) > 5 and abs(self.history[-5] - self.history[-1]) < 1e-5: + return 0, 0, 0 + + return loss, 1.0, 1.0 diff --git a/isegm/inference/transforms/__init__.py b/isegm/inference/transforms/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cbd54e38a2f84b3fef481672a7ceab070eb01b82 --- /dev/null +++ b/isegm/inference/transforms/__init__.py @@ -0,0 +1,5 @@ +from .base import SigmoidForPred +from .flip import AddHorizontalFlip +from .zoom_in import ZoomIn +from .limit_longest_side import LimitLongestSide +from .crops import Crops diff --git a/isegm/inference/transforms/base.py b/isegm/inference/transforms/base.py new file mode 100644 index 0000000000000000000000000000000000000000..eb5a2deb3c44f5aed7530fd1e299fff1273737b8 --- /dev/null +++ b/isegm/inference/transforms/base.py @@ -0,0 +1,38 @@ +import torch + + +class BaseTransform(object): + def __init__(self): + self.image_changed = False + + def transform(self, image_nd, clicks_lists): + raise NotImplementedError + + def inv_transform(self, prob_map): + raise NotImplementedError + + def reset(self): + raise NotImplementedError + + def get_state(self): + raise NotImplementedError + + def set_state(self, state): + raise NotImplementedError + + +class SigmoidForPred(BaseTransform): + def transform(self, image_nd, clicks_lists): + return image_nd, clicks_lists + + def inv_transform(self, prob_map): + return torch.sigmoid(prob_map) + + def reset(self): + pass + + def get_state(self): + return None + + def set_state(self, state): + pass diff --git a/isegm/inference/transforms/crops.py b/isegm/inference/transforms/crops.py new file mode 100644 index 0000000000000000000000000000000000000000..428d977295e2ff973b5aa1bf0a0c955df1235614 --- /dev/null +++ b/isegm/inference/transforms/crops.py @@ -0,0 +1,97 @@ +import math + +import torch +import numpy as np +from typing import List + +from 
isegm.inference.clicker import Click +from .base import BaseTransform + + +class Crops(BaseTransform): + def __init__(self, crop_size=(320, 480), min_overlap=0.2): + super().__init__() + self.crop_height, self.crop_width = crop_size + self.min_overlap = min_overlap + + self.x_offsets = None + self.y_offsets = None + self._counts = None + + def transform(self, image_nd, clicks_lists: List[List[Click]]): + assert image_nd.shape[0] == 1 and len(clicks_lists) == 1 + image_height, image_width = image_nd.shape[2:4] + self._counts = None + + if image_height < self.crop_height or image_width < self.crop_width: + return image_nd, clicks_lists + + self.x_offsets = get_offsets(image_width, self.crop_width, self.min_overlap) + self.y_offsets = get_offsets(image_height, self.crop_height, self.min_overlap) + self._counts = np.zeros((image_height, image_width)) + + image_crops = [] + for dy in self.y_offsets: + for dx in self.x_offsets: + self._counts[dy:dy + self.crop_height, dx:dx + self.crop_width] += 1 + image_crop = image_nd[:, :, dy:dy + self.crop_height, dx:dx + self.crop_width] + image_crops.append(image_crop) + image_crops = torch.cat(image_crops, dim=0) + self._counts = torch.tensor(self._counts, device=image_nd.device, dtype=torch.float32) + + clicks_list = clicks_lists[0] + clicks_lists = [] + for dy in self.y_offsets: + for dx in self.x_offsets: + crop_clicks = [x.copy(coords=(x.coords[0] - dy, x.coords[1] - dx)) for x in clicks_list] + clicks_lists.append(crop_clicks) + + return image_crops, clicks_lists + + def inv_transform(self, prob_map): + if self._counts is None: + return prob_map + + new_prob_map = torch.zeros((1, 1, *self._counts.shape), + dtype=prob_map.dtype, device=prob_map.device) + + crop_indx = 0 + for dy in self.y_offsets: + for dx in self.x_offsets: + new_prob_map[0, 0, dy:dy + self.crop_height, dx:dx + self.crop_width] += prob_map[crop_indx, 0] + crop_indx += 1 + new_prob_map = torch.div(new_prob_map, self._counts) + + return new_prob_map + + def get_state(self): + return self.x_offsets, self.y_offsets, self._counts + + def set_state(self, state): + self.x_offsets, self.y_offsets, self._counts = state + + def reset(self): + self.x_offsets = None + self.y_offsets = None + self._counts = None + + +def get_offsets(length, crop_size, min_overlap_ratio=0.2): + if length == crop_size: + return [0] + + N = (length / crop_size - min_overlap_ratio) / (1 - min_overlap_ratio) + N = math.ceil(N) + + overlap_ratio = (N - length / crop_size) / (N - 1) + overlap_width = int(crop_size * overlap_ratio) + + offsets = [0] + for i in range(1, N): + new_offset = offsets[-1] + crop_size - overlap_width + if new_offset + crop_size > length: + new_offset = length - crop_size + + offsets.append(new_offset) + + return offsets diff --git a/isegm/inference/transforms/flip.py b/isegm/inference/transforms/flip.py new file mode 100644 index 0000000000000000000000000000000000000000..373640ebe153ae8a53c136c72f13e0c14aa788ec --- /dev/null +++ b/isegm/inference/transforms/flip.py @@ -0,0 +1,37 @@ +import torch + +from typing import List +from isegm.inference.clicker import Click +from .base import BaseTransform + + +class AddHorizontalFlip(BaseTransform): + def transform(self, image_nd, clicks_lists: List[List[Click]]): + assert len(image_nd.shape) == 4 + image_nd = torch.cat([image_nd, torch.flip(image_nd, dims=[3])], dim=0) + + image_width = image_nd.shape[3] + clicks_lists_flipped = [] + for clicks_list in clicks_lists: + clicks_list_flipped = [click.copy(coords=(click.coords[0], image_width - 
click.coords[1] - 1)) + for click in clicks_list] + clicks_lists_flipped.append(clicks_list_flipped) + clicks_lists = clicks_lists + clicks_lists_flipped + + return image_nd, clicks_lists + + def inv_transform(self, prob_map): + assert len(prob_map.shape) == 4 and prob_map.shape[0] % 2 == 0 + num_maps = prob_map.shape[0] // 2 + prob_map, prob_map_flipped = prob_map[:num_maps], prob_map[num_maps:] + + return 0.5 * (prob_map + torch.flip(prob_map_flipped, dims=[3])) + + def get_state(self): + return None + + def set_state(self, state): + pass + + def reset(self): + pass diff --git a/isegm/inference/transforms/limit_longest_side.py b/isegm/inference/transforms/limit_longest_side.py new file mode 100644 index 0000000000000000000000000000000000000000..50c5a53d2670df52285621dc0d33e86df520d77c --- /dev/null +++ b/isegm/inference/transforms/limit_longest_side.py @@ -0,0 +1,22 @@ +from .zoom_in import ZoomIn, get_roi_image_nd + + +class LimitLongestSide(ZoomIn): + def __init__(self, max_size=800): + super().__init__(target_size=max_size, skip_clicks=0) + + def transform(self, image_nd, clicks_lists): + assert image_nd.shape[0] == 1 and len(clicks_lists) == 1 + image_max_size = max(image_nd.shape[2:4]) + self.image_changed = False + + if image_max_size <= self.target_size: + return image_nd, clicks_lists + self._input_image = image_nd + + self._object_roi = (0, image_nd.shape[2] - 1, 0, image_nd.shape[3] - 1) + self._roi_image = get_roi_image_nd(image_nd, self._object_roi, self.target_size) + self.image_changed = True + + tclicks_lists = [self._transform_clicks(clicks_lists[0])] + return self._roi_image, tclicks_lists diff --git a/isegm/inference/transforms/zoom_in.py b/isegm/inference/transforms/zoom_in.py new file mode 100644 index 0000000000000000000000000000000000000000..618d7ec3b52def6cc03d9aa221ac479965686302 --- /dev/null +++ b/isegm/inference/transforms/zoom_in.py @@ -0,0 +1,190 @@ +import torch +import numpy as np +from typing import List +from isegm.inference.clicker import Click +from isegm.utils.misc import get_bbox_iou, get_bbox_from_mask, expand_bbox, clamp_bbox +from .base import BaseTransform + + +class ZoomIn(BaseTransform): + def __init__(self, + target_size=400, + skip_clicks=1, + expansion_ratio=1.4, + min_crop_size=200, + recompute_thresh_iou=0.5, + prob_thresh=0.50): + super().__init__() + self.target_size = target_size + self.min_crop_size = min_crop_size + self.skip_clicks = skip_clicks + self.expansion_ratio = expansion_ratio + self.recompute_thresh_iou = recompute_thresh_iou + self.prob_thresh = prob_thresh + + self._input_image_shape = None + self._prev_probs = None + self._object_roi = None + self._roi_image = None + + def transform(self, image_nd, clicks_lists: List[List[Click]]): + transformed_image = [] + transformed_clicks_lists = [] + for bindx in range(len(clicks_lists)): + new_image_nd, new_clicks_lists = self._transform(image_nd[bindx].unsqueeze(0), [clicks_lists[bindx]]) + transformed_image.append(new_image_nd) + transformed_clicks_lists.append(new_clicks_lists[0]) + return torch.cat(transformed_image, dim=0), transformed_clicks_lists + + def _transform(self, image_nd, clicks_lists: List[List[Click]]): + assert image_nd.shape[0] == 1 and len(clicks_lists) == 1 + self.image_changed = False + + clicks_list = clicks_lists[0] + if len(clicks_list) <= self.skip_clicks: + return image_nd, clicks_lists + + self._input_image_shape = image_nd.shape + + current_object_roi = None + if self._prev_probs is not None: + current_pred_mask = (self._prev_probs > 
self.prob_thresh)[0, 0] + if current_pred_mask.sum() > 0: + current_object_roi = get_object_roi(current_pred_mask, clicks_list, + self.expansion_ratio, self.min_crop_size) + + if current_object_roi is None: + if self.skip_clicks >= 0: + return image_nd, clicks_lists + else: + current_object_roi = 0, image_nd.shape[2] - 1, 0, image_nd.shape[3] - 1 + + update_object_roi = False + if self._object_roi is None: + update_object_roi = True + elif not check_object_roi(self._object_roi, clicks_list): + update_object_roi = True + elif get_bbox_iou(current_object_roi, self._object_roi) < self.recompute_thresh_iou: + update_object_roi = True + + if update_object_roi: + self._object_roi = current_object_roi + self.image_changed = True + self._roi_image = get_roi_image_nd(image_nd, self._object_roi, self.target_size) + tclicks_lists = [self._transform_clicks(clicks_list)] + return self._roi_image.to(image_nd.device), tclicks_lists + + def inv_transform(self, prob_map): + new_prob_maps = [] + for bindx in range(prob_map.shape[0]): + new_prob_map = self._inv_transform(prob_map[bindx].unsqueeze(0)) + new_prob_maps.append(new_prob_map) + return torch.cat(new_prob_maps, dim=0) + + def _inv_transform(self, prob_map): + if self._object_roi is None: + self._prev_probs = prob_map.cpu().numpy() + return prob_map + + assert prob_map.shape[0] == 1 + rmin, rmax, cmin, cmax = self._object_roi + prob_map = torch.nn.functional.interpolate(prob_map, size=(rmax - rmin + 1, cmax - cmin + 1), + mode='bilinear', align_corners=True) + + if self._prev_probs is not None: + new_prob_map = torch.zeros(*self._prev_probs.shape, device=prob_map.device, dtype=prob_map.dtype) + new_prob_map[:, :, rmin:rmax + 1, cmin:cmax + 1] = prob_map + else: + new_prob_map = prob_map + + self._prev_probs = new_prob_map.cpu().numpy() + + return new_prob_map + + def check_possible_recalculation(self): + if self._prev_probs is None or self._object_roi is not None or self.skip_clicks > 0: + return False + + pred_mask = (self._prev_probs > self.prob_thresh)[0, 0] + if pred_mask.sum() > 0: + possible_object_roi = get_object_roi(pred_mask, [], + self.expansion_ratio, self.min_crop_size) + image_roi = (0, self._input_image_shape[2] - 1, 0, self._input_image_shape[3] - 1) + if get_bbox_iou(possible_object_roi, image_roi) < 0.50: + return True + return False + + def get_state(self): + roi_image = self._roi_image.cpu() if self._roi_image is not None else None + return self._input_image_shape, self._object_roi, self._prev_probs, roi_image, self.image_changed + + def set_state(self, state): + self._input_image_shape, self._object_roi, self._prev_probs, self._roi_image, self.image_changed = state + + def reset(self): + self._input_image_shape = None + self._object_roi = None + self._prev_probs = None + self._roi_image = None + self.image_changed = False + + def _transform_clicks(self, clicks_list): + if self._object_roi is None: + return clicks_list + + rmin, rmax, cmin, cmax = self._object_roi + crop_height, crop_width = self._roi_image.shape[2:] + + transformed_clicks = [] + for click in clicks_list: + new_r = crop_height * (click.coords[0] - rmin) / (rmax - rmin + 1) + new_c = crop_width * (click.coords[1] - cmin) / (cmax - cmin + 1) + transformed_clicks.append(click.copy(coords=(new_r, new_c))) + return transformed_clicks + + +def get_object_roi(pred_mask, clicks_list, expansion_ratio, min_crop_size): + pred_mask = pred_mask.copy() + + for click in clicks_list: + if click.is_positive: + pred_mask[int(click.coords[0]), int(click.coords[1])] = 1 + + bbox = 
get_bbox_from_mask(pred_mask) + bbox = expand_bbox(bbox, expansion_ratio, min_crop_size) + h, w = pred_mask.shape[0], pred_mask.shape[1] + bbox = clamp_bbox(bbox, 0, h - 1, 0, w - 1) + + return bbox + + +def get_roi_image_nd(image_nd, object_roi, target_size): + rmin, rmax, cmin, cmax = object_roi + + height = rmax - rmin + 1 + width = cmax - cmin + 1 + + if isinstance(target_size, tuple): + new_height, new_width = target_size + else: + scale = target_size / max(height, width) + new_height = int(round(height * scale)) + new_width = int(round(width * scale)) + + with torch.no_grad(): + roi_image_nd = image_nd[:, :, rmin:rmax + 1, cmin:cmax + 1] + roi_image_nd = torch.nn.functional.interpolate(roi_image_nd, size=(new_height, new_width), + mode='bilinear', align_corners=True) + + return roi_image_nd + + +def check_object_roi(object_roi, clicks_list): + for click in clicks_list: + if click.is_positive: + if click.coords[0] < object_roi[0] or click.coords[0] >= object_roi[1]: + return False + if click.coords[1] < object_roi[2] or click.coords[1] >= object_roi[3]: + return False + + return True diff --git a/isegm/inference/utils.py b/isegm/inference/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4246910cb8e7e7eed247f1b41392e9be489f8f73 --- /dev/null +++ b/isegm/inference/utils.py @@ -0,0 +1,149 @@ +from datetime import timedelta +from pathlib import Path +import torch +import numpy as np +from isegm.utils.serialization import load_model + + +def get_time_metrics(all_ious, elapsed_time): + n_images = len(all_ious) + n_clicks = sum(map(len, all_ious)) + + mean_spc = elapsed_time / n_clicks + mean_spi = elapsed_time / n_images + + return mean_spc, mean_spi + + +def load_is_model(checkpoint, device, eval_ritm, lora_checkpoint=None, **kwargs): + if isinstance(checkpoint, (str, Path)): + state_dict = torch.load(checkpoint, map_location='cpu') + else: + state_dict = checkpoint + if isinstance(state_dict, list): + model = load_single_is_model(state_dict[0], device, eval_ritm, **kwargs) + models = [load_single_is_model(x, device, eval_ritm, **kwargs) for x in state_dict] + + return model, models + else: + return load_single_is_model(state_dict, device, eval_ritm, lora_checkpoint=lora_checkpoint, **kwargs) + + +def load_single_is_model(state_dict, device, eval_ritm, lora_checkpoint=None, **kwargs): + if 'config' in state_dict.keys(): + _config = state_dict['config'] + if lora_checkpoint is not None: + lora_state_dict = torch.load(lora_checkpoint, map_location='cpu') + _config = lora_state_dict['config'] + + model = load_model(_config, eval_ritm, **kwargs) + print("Load predictor weights...") + if 'state_dict' in state_dict.keys(): + msg = model.load_state_dict(state_dict['state_dict'], strict=False) + else: + try: + msg = model.load_state_dict(state_dict, strict=False) + except: + current_state_dict = model.state_dict() + + new_state_dict = {} + for k, v in state_dict.items(): + if k in current_state_dict and v.shape == current_state_dict[k].shape: + new_state_dict[k] = v + + msg = model.load_state_dict(new_state_dict, strict=False) + print(msg) + + if lora_checkpoint is not None: + print("Load predictor LoRA weights...") + msg = model.load_state_dict(lora_state_dict['state_dict'], strict=False) + print(msg[1]) + + for param in model.parameters(): + param.requires_grad = False + model.to(device) + model.eval() + + return model + + +def get_iou(gt_mask, pred_mask, ignore_label=-1): + ignore_gt_mask_inv = gt_mask != ignore_label + obj_gt_mask = gt_mask == 1 + + intersection = 
np.logical_and(np.logical_and(pred_mask, obj_gt_mask), ignore_gt_mask_inv).sum() + union = np.logical_and(np.logical_or(pred_mask, obj_gt_mask), ignore_gt_mask_inv).sum() + + return intersection / union + + +def compute_noc_metric(all_ious, iou_thrs, max_clicks=20): + def _get_noc(iou_arr, iou_thr): + vals = iou_arr >= iou_thr + return np.argmax(vals) + 1 if np.any(vals) else max_clicks + + noc_list = [] + noc_list_std = [] + over_max_list = [] + for iou_thr in iou_thrs: + scores_arr = np.array([_get_noc(iou_arr, iou_thr) + for iou_arr in all_ious], dtype=np.int_) + + score = scores_arr.mean() + score_std = scores_arr.std() + over_max = (scores_arr == max_clicks).sum() + + noc_list.append(score) + noc_list_std.append(score_std) + over_max_list.append(over_max) + + return noc_list, noc_list_std, over_max_list + + +def find_checkpoint(weights_folder, checkpoint_name): + weights_folder = Path(weights_folder) + if ':' in checkpoint_name: + model_name, checkpoint_name = checkpoint_name.split(':') + models_candidates = [x for x in weights_folder.glob(f'{model_name}*') if x.is_dir()] + assert len(models_candidates) == 1 + model_folder = models_candidates[0] + else: + model_folder = weights_folder + + if checkpoint_name.endswith('.pth'): + if Path(checkpoint_name).exists(): + checkpoint_path = checkpoint_name + else: + checkpoint_path = weights_folder / checkpoint_name + else: + model_checkpoints = list(model_folder.rglob(f'{checkpoint_name}*.pth')) + assert len(model_checkpoints) == 1 + checkpoint_path = model_checkpoints[0] + + return str(checkpoint_path) + + +def get_results_table(noc_list, over_max_list, brs_type, dataset_name, mean_spc, elapsed_time, iou_first, + n_clicks=20, model_name=None): + table_header = (f'|{"BRS Type":^13}|{"Dataset":^11}|' + f'{"NoC@80%":^9}|{"NoC@85%":^9}|{"NoC@90%":^9}|' + f'{"IoU@1":^9}|' + f'{">="+str(n_clicks)+"@85%":^9}|{">="+str(n_clicks)+"@90%":^9}|' + f'{"SPC,s":^7}|{"Time":^9}|') + row_width = len(table_header) + + header = f'Eval results for model: {model_name}\n' if model_name is not None else '' + header += '-' * row_width + '\n' + header += table_header + '\n' + '-' * row_width + + eval_time = str(timedelta(seconds=int(elapsed_time))) + table_row = f'|{brs_type:^13}|{dataset_name:^11}|' + table_row += f'{noc_list[0]:^9.2f}|' + table_row += f'{noc_list[1]:^9.2f}|' if len(noc_list) > 1 else f'{"?":^9}|' + table_row += f'{noc_list[2]:^9.2f}|' if len(noc_list) > 2 else f'{"?":^9}|' + table_row += f'{iou_first:^9.2f}|' + table_row += f'{over_max_list[1]:^9}|' if len(noc_list) > 1 else f'{"?":^9}|' + table_row += f'{over_max_list[2]:^9}|' if len(noc_list) > 2 else f'{"?":^9}|' + table_row += f'{mean_spc:^7.3f}|{eval_time:^9}|' + + return header, table_row diff --git a/isegm/model/__init__.py b/isegm/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/isegm/model/build_sam.py b/isegm/model/build_sam.py new file mode 100644 index 0000000000000000000000000000000000000000..8427dab16981c6be920098cfac0f55156bc2884d --- /dev/null +++ b/isegm/model/build_sam.py @@ -0,0 +1,145 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
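+# Builders for SAM ViT-H/L/B image encoders (see sam_model_registry below).
+# In 'train' mode _build_sam returns a SAMISWrapper (optional LoRA/GRA, prev-mask
+# input); otherwise it returns a plain Sam model. Checkpoint weights are loaded
+# with shape-matched filtering via load_state_dict(strict=False).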
+ +import torch + +from functools import partial + +from .sam_modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer, SAMISWrapper + + +def build_sam_vit_h(checkpoint=None, enable_lora=False, enable_gra=False, mode='eval', image_size=1024): + return _build_sam( + encoder_embed_dim=1280, + encoder_depth=32, + encoder_num_heads=16, + encoder_global_attn_indexes=[7, 15, 23, 31], + checkpoint=checkpoint, + enable_lora=enable_lora, + enable_gra=enable_gra, + mode=mode, + image_size=image_size, + ) + + +build_sam = build_sam_vit_h + + +def build_sam_vit_l(checkpoint=None, enable_lora=False, enable_gra=False, mode='eval', image_size=1024): + return _build_sam( + encoder_embed_dim=1024, + encoder_depth=24, + encoder_num_heads=16, + encoder_global_attn_indexes=[5, 11, 17, 23], + checkpoint=checkpoint, + enable_lora=enable_lora, + enable_gra=enable_gra, + mode=mode, + image_size=image_size, + ) + + +def build_sam_vit_b(checkpoint=None, enable_lora=False, enable_gra=False, mode='eval', image_size=1024): + return _build_sam( + encoder_embed_dim=768, + encoder_depth=12, + encoder_num_heads=12, + encoder_global_attn_indexes=[2, 5, 8, 11], + checkpoint=checkpoint, + enable_lora=enable_lora, + enable_gra=enable_gra, + mode=mode, + image_size=image_size, + ) + + +sam_model_registry = { + "default": build_sam_vit_h, + "vit_h": build_sam_vit_h, + "vit_l": build_sam_vit_l, + "vit_b": build_sam_vit_b, +} + + +def _build_sam( + encoder_embed_dim, + encoder_depth, + encoder_num_heads, + encoder_global_attn_indexes, + checkpoint=None, + enable_lora=False, + enable_gra=False, + mode='eval', + image_size=1024, +): + prompt_embed_dim = 256 + image_size = image_size + vit_patch_size = 16 + image_embedding_size = image_size // vit_patch_size + + if mode == 'train': + sam = SAMISWrapper( + encoder_embed_dim=encoder_embed_dim, + encoder_depth=encoder_depth, + encoder_num_heads=encoder_num_heads, + encoder_global_attn_indexes=encoder_global_attn_indexes, + enable_lora=enable_lora, + enable_gra=enable_gra, + with_prev_mask=True, + image_size=image_size, + pixel_mean=[123.675, 116.28, 103.53], + pixel_std=[58.395, 57.12, 57.375], + ) + else: + sam = Sam( + image_encoder=ImageEncoderViT( + depth=encoder_depth, + embed_dim=encoder_embed_dim, + img_size=image_size, + mlp_ratio=4, + norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), + num_heads=encoder_num_heads, + patch_size=vit_patch_size, + qkv_bias=True, + use_rel_pos=True, + global_attn_indexes=encoder_global_attn_indexes, + window_size=14, + out_chans=prompt_embed_dim, + ), + prompt_encoder=PromptEncoder( + embed_dim=prompt_embed_dim, + image_embedding_size=(image_embedding_size, image_embedding_size), + input_image_size=(image_size, image_size), + mask_in_chans=16, + ), + mask_decoder=MaskDecoder( + num_multimask_outputs=3, + transformer=TwoWayTransformer( + depth=2, + embedding_dim=prompt_embed_dim, + mlp_dim=2048, + num_heads=8, + ), + transformer_dim=prompt_embed_dim, + iou_head_depth=3, + iou_head_hidden_dim=256, + ), + pixel_mean=[123.675, 116.28, 103.53], + pixel_std=[58.395, 57.12, 57.375], + ) + sam.eval() + if checkpoint is not None: + with open(checkpoint, "rb") as f: + pretrained_dict = torch.load(f) + + model_dict = sam.state_dict() + new_pretrained_dict = {} + for k, v in pretrained_dict.items(): + if k in model_dict and v.shape == model_dict[k].shape: + new_pretrained_dict[k] = v + msg = sam.load_state_dict(new_pretrained_dict, strict=False) + print("SAM load Info: ", msg) + return sam diff --git a/isegm/model/initializer.py 
b/isegm/model/initializer.py new file mode 100644 index 0000000000000000000000000000000000000000..470c7df4659bc1e80ceec80a170b3b2e0302fb84 --- /dev/null +++ b/isegm/model/initializer.py @@ -0,0 +1,105 @@ +import torch +import torch.nn as nn +import numpy as np + + +class Initializer(object): + def __init__(self, local_init=True, gamma=None): + self.local_init = local_init + self.gamma = gamma + + def __call__(self, m): + if getattr(m, '__initialized', False): + return + + if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, + nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d, + nn.GroupNorm, nn.SyncBatchNorm)) or 'BatchNorm' in m.__class__.__name__: + if m.weight is not None: + self._init_gamma(m.weight.data) + if m.bias is not None: + self._init_beta(m.bias.data) + else: + if getattr(m, 'weight', None) is not None: + self._init_weight(m.weight.data) + if getattr(m, 'bias', None) is not None: + self._init_bias(m.bias.data) + + if self.local_init: + object.__setattr__(m, '__initialized', True) + + def _init_weight(self, data): + nn.init.uniform_(data, -0.07, 0.07) + + def _init_bias(self, data): + nn.init.constant_(data, 0) + + def _init_gamma(self, data): + if self.gamma is None: + nn.init.constant_(data, 1.0) + else: + nn.init.normal_(data, 1.0, self.gamma) + + def _init_beta(self, data): + nn.init.constant_(data, 0) + + +class Bilinear(Initializer): + def __init__(self, scale, groups, in_channels, **kwargs): + super().__init__(**kwargs) + self.scale = scale + self.groups = groups + self.in_channels = in_channels + + def _init_weight(self, data): + """Reset the weight and bias.""" + bilinear_kernel = self.get_bilinear_kernel(self.scale) + weight = torch.zeros_like(data) + for i in range(self.in_channels): + if self.groups == 1: + j = i + else: + j = 0 + weight[i, j] = bilinear_kernel + data[:] = weight + + @staticmethod + def get_bilinear_kernel(scale): + """Generate a bilinear upsampling kernel.""" + kernel_size = 2 * scale - scale % 2 + scale = (kernel_size + 1) // 2 + center = scale - 0.5 * (1 + kernel_size % 2) + + og = np.ogrid[:kernel_size, :kernel_size] + kernel = (1 - np.abs(og[0] - center) / scale) * (1 - np.abs(og[1] - center) / scale) + + return torch.tensor(kernel, dtype=torch.float32) + + +class XavierGluon(Initializer): + def __init__(self, rnd_type='uniform', factor_type='avg', magnitude=3, **kwargs): + super().__init__(**kwargs) + + self.rnd_type = rnd_type + self.factor_type = factor_type + self.magnitude = float(magnitude) + + def _init_weight(self, arr): + fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(arr) + + if self.factor_type == 'avg': + factor = (fan_in + fan_out) / 2.0 + elif self.factor_type == 'in': + factor = fan_in + elif self.factor_type == 'out': + factor = fan_out + else: + raise ValueError('Incorrect factor type') + scale = np.sqrt(self.magnitude / factor) + + if self.rnd_type == 'uniform': + nn.init.uniform_(arr, -scale, scale) + elif self.rnd_type == 'gaussian': + nn.init.normal_(arr, 0, scale) + else: + raise ValueError('Unknown random type') diff --git a/isegm/model/is_deeplab_model.py b/isegm/model/is_deeplab_model.py new file mode 100644 index 0000000000000000000000000000000000000000..45fa55364d14d129889fce083a791be1e48a35c9 --- /dev/null +++ b/isegm/model/is_deeplab_model.py @@ -0,0 +1,25 @@ +import torch.nn as nn + +from isegm.utils.serialization import serialize +from .is_model import ISModel +from .modeling.deeplab_v3 import DeepLabV3Plus +from .modeling.basic_blocks import SepConvHead +from isegm.model.modifiers 
import LRMult + + +class DeeplabModel(ISModel): + @serialize + def __init__(self, backbone='resnet50', deeplab_ch=256, aspp_dropout=0.5, + backbone_norm_layer=None, backbone_lr_mult=0.1, norm_layer=nn.BatchNorm2d, **kwargs): + super().__init__(norm_layer=norm_layer, **kwargs) + + self.feature_extractor = DeepLabV3Plus(backbone=backbone, ch=deeplab_ch, project_dropout=aspp_dropout, + norm_layer=norm_layer, backbone_norm_layer=backbone_norm_layer) + self.feature_extractor.backbone.apply(LRMult(backbone_lr_mult)) + self.head = SepConvHead(1, in_channels=deeplab_ch, mid_channels=deeplab_ch // 2, + num_layers=2, norm_layer=norm_layer) + + def backbone_forward(self, image, coord_features=None): + backbone_features = self.feature_extractor(image, coord_features) + + return {'instances': self.head(backbone_features[0])} diff --git a/isegm/model/is_hrformer_model.py b/isegm/model/is_hrformer_model.py new file mode 100644 index 0000000000000000000000000000000000000000..947c3984e388ce9488af5ad882615162b3ddd8ab --- /dev/null +++ b/isegm/model/is_hrformer_model.py @@ -0,0 +1,41 @@ +import torch +import torch.nn as nn + +from collections import OrderedDict + +from isegm.utils.serialization import serialize +from .is_model import ISModel +from isegm.model.modifiers import LRMult +from .modeling.hrformer import HRT_B_OCR_V3 + +class HRFormerModel(ISModel): + @serialize + def __init__( + self, + num_classes=1, + in_ch=6, + backbone_lr_mult=0.1, + **kwargs + ): + + super().__init__(**kwargs) + + self.feature_extractor = HRT_B_OCR_V3(num_classes, in_ch) + self.feature_extractor.apply(LRMult(backbone_lr_mult)) + + def backbone_forward(self, image, coord_features=None): + backbone_features = self.feature_extractor(image) + return {'instances': backbone_features[0], 'instances_aux': backbone_features[1]} + + def init_weight(self, pretrained=None): + if pretrained is not None: + state_dict = torch.load(pretrained)['model'] + state_dict_rename = OrderedDict() + for k, v in state_dict.items(): + state_dict_rename['backbone.' 
+ k] = v + + ori_proj_weight = state_dict_rename['backbone.conv1.weight'] + state_dict_rename['backbone.conv1.weight'] = torch.cat([ori_proj_weight, ori_proj_weight], dim=1) + + self.feature_extractor.load_state_dict(state_dict_rename, False) + print('Successfully loaded pretrained model.') diff --git a/isegm/model/is_hrnet_model.py b/isegm/model/is_hrnet_model.py new file mode 100644 index 0000000000000000000000000000000000000000..df2081eeb0bb11dd8b043816171118206cb6a830 --- /dev/null +++ b/isegm/model/is_hrnet_model.py @@ -0,0 +1,26 @@ +import torch.nn as nn + +from isegm.utils.serialization import serialize +from .is_model import ISModel +from .modeling.hrnet_ocr import HighResolutionNet +from isegm.model.modifiers import LRMult + + +class HRNetModel(ISModel): + @serialize + def __init__(self, width=48, ocr_width=256, small=False, backbone_lr_mult=0.1, + norm_layer=nn.BatchNorm2d, **kwargs): + super().__init__(**kwargs) + + self.feature_extractor = HighResolutionNet(width=width, ocr_width=ocr_width, small=small, + num_classes=1, norm_layer=norm_layer) + self.feature_extractor.apply(LRMult(backbone_lr_mult)) + if ocr_width > 0: + self.feature_extractor.ocr_distri_head.apply(LRMult(1.0)) + self.feature_extractor.ocr_gather_head.apply(LRMult(1.0)) + self.feature_extractor.conv3x3_ocr.apply(LRMult(1.0)) + + def backbone_forward(self, image, coord_features=None): + net_outputs = self.feature_extractor(image, coord_features) + + return {'instances': net_outputs[0], 'instances_aux': net_outputs[1]} diff --git a/isegm/model/is_model.py b/isegm/model/is_model.py new file mode 100644 index 0000000000000000000000000000000000000000..b4af09b719f79ffbcd821c2294ec6e30c9f282ab --- /dev/null +++ b/isegm/model/is_model.py @@ -0,0 +1,114 @@ +import torch +import torch.nn as nn +import numpy as np + +from isegm.model.ops import DistMaps, BatchImageNormalize, ScaleLayer + + +class ISModel(nn.Module): + def __init__(self, with_aux_output=False, norm_radius=5, use_disks=False, cpu_dist_maps=False, + use_rgb_conv=False, use_leaky_relu=False, # the two arguments only used for RITM + with_prev_mask=False, norm_mean_std=([.485, .456, .406], [.229, .224, .225])): + super().__init__() + + self.with_aux_output = with_aux_output + self.with_prev_mask = with_prev_mask + self.normalization = BatchImageNormalize(norm_mean_std[0], norm_mean_std[1]) + + self.coord_feature_ch = 2 + if self.with_prev_mask: + self.coord_feature_ch += 1 + + if use_rgb_conv: + # Only RITM models need to transform the coordinate features, though they don't use + # exact 'rgb_conv'. We keep 'use_rgb_conv' only for compatible issues. 
+ # The simpleclick models use a patch embedding layer instead + mt_layers = [ + nn.Conv2d(in_channels=self.coord_feature_ch, out_channels=16, kernel_size=1), + nn.LeakyReLU(negative_slope=0.2) if use_leaky_relu else nn.ReLU(inplace=True), + nn.Conv2d(in_channels=16, out_channels=64, kernel_size=3, stride=2, padding=1), + ScaleLayer(init_value=0.05, lr_mult=1) + ] + self.maps_transform = nn.Sequential(*mt_layers) + else: + self.maps_transform=nn.Identity() + + self.dist_maps = DistMaps(norm_radius=norm_radius, spatial_scale=1.0, + cpu_mode=cpu_dist_maps, use_disks=use_disks) + + def forward(self, image, points, text=None, gra=None): + image, prev_mask = self.prepare_input(image) + coord_features = self.get_coord_features(image, prev_mask, points) + coord_features = self.maps_transform(coord_features) + + if gra is not None and text is not None: + outputs = self.backbone_forward(image, coord_features, text=text, gra=gra) + elif gra is not None: + outputs = self.backbone_forward(image, coord_features, gra=gra) + elif text is not None: + outputs = self.backbone_forward(image, coord_features, text=text) + else: + outputs = self.backbone_forward(image, coord_features) + + outputs['instances'] = nn.functional.interpolate(outputs['instances'], size=image.size()[2:], + mode='bilinear', align_corners=True) + if self.with_aux_output: + outputs['instances_aux'] = nn.functional.interpolate(outputs['instances_aux'], size=image.size()[2:], + mode='bilinear', align_corners=True) + + return outputs + + def prepare_input(self, image): + prev_mask = None + if self.with_prev_mask: + prev_mask = image[:, 3:, :, :] + image = image[:, :3, :, :] + + image = self.normalization(image) + return image, prev_mask + + def backbone_forward(self, image, coord_features=None): + raise NotImplementedError + + def get_coord_features(self, image, prev_mask, points): + coord_features = self.dist_maps(image, points) + if prev_mask is not None: + coord_features = torch.cat((prev_mask, coord_features), dim=1) + + return coord_features + + +def split_points_by_order(tpoints: torch.Tensor, groups): + points = tpoints.cpu().numpy() + num_groups = len(groups) + bs = points.shape[0] + num_points = points.shape[1] // 2 + + groups = [x if x > 0 else num_points for x in groups] + group_points = [np.full((bs, 2 * x, 3), -1, dtype=np.float32) + for x in groups] + + last_point_indx_group = np.zeros((bs, num_groups, 2), dtype=np.int_) + for group_indx, group_size in enumerate(groups): + last_point_indx_group[:, group_indx, 1] = group_size + + for bindx in range(bs): + for pindx in range(2 * num_points): + point = points[bindx, pindx, :] + group_id = int(point[2]) + if group_id < 0: + continue + + is_negative = int(pindx >= num_points) + if group_id >= num_groups or (group_id == 0 and is_negative): # disable negative first click + group_id = num_groups - 1 + + new_point_indx = last_point_indx_group[bindx, group_id, is_negative] + last_point_indx_group[bindx, group_id, is_negative] += 1 + + group_points[group_id][bindx, new_point_indx, :] = point + + group_points = [torch.tensor(x, dtype=tpoints.dtype, device=tpoints.device) + for x in group_points] + + return group_points diff --git a/isegm/model/is_plainvit_model.py b/isegm/model/is_plainvit_model.py new file mode 100644 index 0000000000000000000000000000000000000000..572e12313f4a2db8533098624d0de29163247bf0 --- /dev/null +++ b/isegm/model/is_plainvit_model.py @@ -0,0 +1,95 @@ +import math +import torch.nn as nn +from isegm.utils.serialization import serialize +from .is_model import ISModel 
+from .modeling.models_vit import VisionTransformer, PatchEmbed +from .modeling.swin_transformer import SwinTransfomerSegHead + + +class SimpleFPN(nn.Module): + def __init__(self, in_dim=768, out_dims=[128, 256, 512, 1024]): + super().__init__() + self.down_4_chan = max(out_dims[0]*2, in_dim // 2) + self.down_4 = nn.Sequential( + nn.ConvTranspose2d(in_dim, self.down_4_chan, 2, stride=2), + nn.GroupNorm(1, self.down_4_chan), + nn.GELU(), + nn.ConvTranspose2d(self.down_4_chan, self.down_4_chan // 2, 2, stride=2), + nn.GroupNorm(1, self.down_4_chan // 2), + nn.Conv2d(self.down_4_chan // 2, out_dims[0], 1), + nn.GroupNorm(1, out_dims[0]), + nn.GELU() + ) + self.down_8_chan = max(out_dims[1], in_dim // 2) + self.down_8 = nn.Sequential( + nn.ConvTranspose2d(in_dim, self.down_8_chan, 2, stride=2), + nn.GroupNorm(1, self.down_8_chan), + nn.Conv2d(self.down_8_chan, out_dims[1], 1), + nn.GroupNorm(1, out_dims[1]), + nn.GELU() + ) + self.down_16 = nn.Sequential( + nn.Conv2d(in_dim, out_dims[2], 1), + nn.GroupNorm(1, out_dims[2]), + nn.GELU() + ) + self.down_32_chan = max(out_dims[3], in_dim * 2) + self.down_32 = nn.Sequential( + nn.Conv2d(in_dim, self.down_32_chan, 2, stride=2), + nn.GroupNorm(1, self.down_32_chan), + nn.Conv2d(self.down_32_chan, out_dims[3], 1), + nn.GroupNorm(1, out_dims[3]), + nn.GELU() + ) + + self.init_weights() + + def init_weights(self): + pass + + def forward(self, x): + x_down_4 = self.down_4(x) + x_down_8 = self.down_8(x) + x_down_16 = self.down_16(x) + x_down_32 = self.down_32(x) + + return [x_down_4, x_down_8, x_down_16, x_down_32] + + +class PlainVitModel(ISModel): + @serialize + def __init__( + self, + backbone_params={}, + neck_params={}, + head_params={}, + random_split=False, + **kwargs + ): + + super().__init__(**kwargs) + self.random_split = random_split + + self.patch_embed_coords = PatchEmbed( + img_size= backbone_params['img_size'], + patch_size=backbone_params['patch_size'], + in_chans=3 if self.with_prev_mask else 2, + embed_dim=backbone_params['embed_dim'], + ) + + self.backbone = VisionTransformer(**backbone_params) + self.neck = SimpleFPN(**neck_params) + self.head = SwinTransfomerSegHead(**head_params) + + def backbone_forward(self, image, coord_features=None, gra=None): + coord_features = self.patch_embed_coords(coord_features) + backbone_features = self.backbone.forward_backbone(image, coord_features, gra=gra, shuffle=self.random_split) + + # Extract 4 stage backbone feature map: 1/4, 1/8, 1/16, 1/32 + B, N, C = backbone_features.shape + grid_size = self.backbone.patch_embed.grid_size + + backbone_features = backbone_features.transpose(-1,-2).view(B, C, grid_size[0], grid_size[1]) + multi_scale_features = self.neck(backbone_features) + + return {'instances': self.head(multi_scale_features), 'instances_aux': None} diff --git a/isegm/model/is_plainvit_model_lora.py b/isegm/model/is_plainvit_model_lora.py new file mode 100644 index 0000000000000000000000000000000000000000..09e4d5025efe4b5f8d79070931fbc23cab186cac --- /dev/null +++ b/isegm/model/is_plainvit_model_lora.py @@ -0,0 +1,95 @@ +import math +import torch.nn as nn +from isegm.utils.serialization import serialize +from .is_model import ISModel +from .modeling.models_vit_lora import VisionTransformer_lora, PatchEmbed +from .modeling.swin_transformer import SwinTransfomerSegHead + + +class SimpleFPN(nn.Module): + def __init__(self, in_dim=768, out_dims=[128, 256, 512, 1024]): + super().__init__() + self.down_4_chan = max(out_dims[0]*2, in_dim // 2) + self.down_4 = nn.Sequential( + 
nn.ConvTranspose2d(in_dim, self.down_4_chan, 2, stride=2), + nn.GroupNorm(1, self.down_4_chan), + nn.GELU(), + nn.ConvTranspose2d(self.down_4_chan, self.down_4_chan // 2, 2, stride=2), + nn.GroupNorm(1, self.down_4_chan // 2), + nn.Conv2d(self.down_4_chan // 2, out_dims[0], 1), + nn.GroupNorm(1, out_dims[0]), + nn.GELU() + ) + self.down_8_chan = max(out_dims[1], in_dim // 2) + self.down_8 = nn.Sequential( + nn.ConvTranspose2d(in_dim, self.down_8_chan, 2, stride=2), + nn.GroupNorm(1, self.down_8_chan), + nn.Conv2d(self.down_8_chan, out_dims[1], 1), + nn.GroupNorm(1, out_dims[1]), + nn.GELU() + ) + self.down_16 = nn.Sequential( + nn.Conv2d(in_dim, out_dims[2], 1), + nn.GroupNorm(1, out_dims[2]), + nn.GELU() + ) + self.down_32_chan = max(out_dims[3], in_dim * 2) + self.down_32 = nn.Sequential( + nn.Conv2d(in_dim, self.down_32_chan, 2, stride=2), + nn.GroupNorm(1, self.down_32_chan), + nn.Conv2d(self.down_32_chan, out_dims[3], 1), + nn.GroupNorm(1, out_dims[3]), + nn.GELU() + ) + + self.init_weights() + + def init_weights(self): + pass + + def forward(self, x): + x_down_4 = self.down_4(x) + x_down_8 = self.down_8(x) + x_down_16 = self.down_16(x) + x_down_32 = self.down_32(x) + + return [x_down_4, x_down_8, x_down_16, x_down_32] + + +class PlainVitModel_lora(ISModel): + @serialize + def __init__( + self, + backbone_params={}, + neck_params={}, + head_params={}, + random_split=False, + **kwargs + ): + + super().__init__(**kwargs) + self.random_split = random_split + + self.patch_embed_coords = PatchEmbed( + img_size= backbone_params['img_size'], + patch_size=backbone_params['patch_size'], + in_chans=3 if self.with_prev_mask else 2, + embed_dim=backbone_params['embed_dim'], + ) + + self.backbone = VisionTransformer_lora(**backbone_params) + self.neck = SimpleFPN(**neck_params) + self.head = SwinTransfomerSegHead(**head_params) + + def backbone_forward(self, image, coord_features=None, gra=None): + coord_features = self.patch_embed_coords(coord_features) + backbone_features = self.backbone.forward_backbone(image, coord_features, gra=gra, shuffle=self.random_split) + + # Extract 4 stage backbone feature map: 1/4, 1/8, 1/16, 1/32 + B, N, C = backbone_features.shape + grid_size = self.backbone.patch_embed.grid_size + + backbone_features = backbone_features.transpose(-1,-2).view(B, C, grid_size[0], grid_size[1]) + multi_scale_features = self.neck(backbone_features) + + return {'instances': self.head(multi_scale_features), 'instances_aux': None} diff --git a/isegm/model/is_segformer_model.py b/isegm/model/is_segformer_model.py new file mode 100644 index 0000000000000000000000000000000000000000..361be37cbfce628eab13a4543ccb9d681dbef0ed --- /dev/null +++ b/isegm/model/is_segformer_model.py @@ -0,0 +1,29 @@ +import torch.nn as nn + +from isegm.utils.serialization import serialize +from .is_model import ISModel +from isegm.model.modifiers import LRMult +from .modeling.segformer import MixVisionTransformer, SegformerHead + + +class SegformerModel(ISModel): + @serialize + def __init__( + self, + backbone_params=None, + decode_head_params=None, + backbone_lr_mult=0.1, + **kwargs + ): + + super().__init__(**kwargs) + + self.feature_extractor = MixVisionTransformer(**backbone_params) + self.feature_extractor.apply(LRMult(backbone_lr_mult)) + + self.head = SegformerHead(**decode_head_params) + + def backbone_forward(self, image, coord_features=None): + backbone_features = self.feature_extractor(image, coord_features) + return {'instances': self.head(backbone_features), 'instances_aux': None} + \ No newline at 
end of file diff --git a/isegm/model/is_swinformer_model.py b/isegm/model/is_swinformer_model.py new file mode 100644 index 0000000000000000000000000000000000000000..011736e1261cb444b105fac6fb139ca3d87b3ad8 --- /dev/null +++ b/isegm/model/is_swinformer_model.py @@ -0,0 +1,21 @@ +from isegm.utils.serialization import serialize +from .is_model import ISModel +from .modeling.swin_transformer import SwinTransformer, SwinTransfomerSegHead + +class SwinformerModel(ISModel): + @serialize + def __init__( + self, + backbone_params={}, + head_params={}, + **kwargs + ): + + super().__init__(**kwargs) + + self.backbone = SwinTransformer(**backbone_params) + self.head = SwinTransfomerSegHead(**head_params) + + def backbone_forward(self, image, coord_features=None): + backbone_features = self.backbone(image, coord_features) + return {'instances': self.head(backbone_features), 'instances_aux': None} \ No newline at end of file diff --git a/isegm/model/is_text_graco_model.py b/isegm/model/is_text_graco_model.py new file mode 100644 index 0000000000000000000000000000000000000000..6fac6e66ce495c0eadada120417cd7a3602f030f --- /dev/null +++ b/isegm/model/is_text_graco_model.py @@ -0,0 +1,63 @@ +import torch.nn as nn +from isegm.utils.serialization import serialize +from .is_model import ISModel +from .is_plainvit_model import SimpleFPN +from .modeling.models_vit import VisionTransformer, PatchEmbed +from .modeling.twoway_transformer import TwoWayTransformer, PositionEmbeddingRandom +from .modeling.swin_transformer import SwinTransfomerSegHead +from .modeling.clip_text_encoding import ClipTextEncoder + + +class TextGraCoModel(ISModel): + @serialize + def __init__( + self, + image_encoder_params={}, + text_encoder_params={}, + cross_encoder_params={}, + neck_params={}, + head_params={}, + random_split=False, + **kwargs + ): + + super().__init__(**kwargs) + self.random_split = random_split + + self.patch_embed_coords = PatchEmbed( + img_size=image_encoder_params['img_size'], + patch_size=image_encoder_params['patch_size'], + in_chans=3 if self.with_prev_mask else 2, + embed_dim=image_encoder_params['embed_dim'], + ) + + self.image_encoder = VisionTransformer(**image_encoder_params) + self.text_encoder = ClipTextEncoder(**text_encoder_params) + self.cross_encoder = TwoWayTransformer(**cross_encoder_params) + + self.pe_layer = PositionEmbeddingRandom(cross_encoder_params["embedding_dim"] // 2) + patch_size = image_encoder_params['patch_size'][0] + self.image_embedding_size = image_encoder_params['img_size'][0] // (patch_size if patch_size > 0 else 1) + + self.neck = SimpleFPN(**neck_params) + self.head = SwinTransfomerSegHead(**head_params) + + def backbone_forward(self, image, coord_features=None, text=None, gra=None): + coord_features = self.patch_embed_coords(coord_features) + backbone_features = self.image_encoder.forward_backbone(image, coord_features, gra=gra, shuffle=self.random_split) + text_features = self.text_encoder(text) + + text_features, backbone_features = self.cross_encoder( + backbone_features, + self.pe_layer((self.image_embedding_size, self.image_embedding_size)).unsqueeze(0), + text_features) + + # Extract 4 stage image_encoder feature map: 1/4, 1/8, 1/16, 1/32 + B, N, C = backbone_features.shape + grid_size = self.image_encoder.patch_embed.grid_size + + backbone_features = backbone_features.transpose(-1,-2).view(B, C, grid_size[0], grid_size[1]) + multi_scale_features = self.neck(backbone_features) + + return {'instances': self.head(multi_scale_features), 'instances_aux': None} + \ No newline 
at end of file diff --git a/isegm/model/losses.py b/isegm/model/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..38c4ee3126c9cbf4c714332acb8d45d094998bc4 --- /dev/null +++ b/isegm/model/losses.py @@ -0,0 +1,195 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from isegm.utils import misc + + +class NormalizedFocalLossSigmoid(nn.Module): + def __init__(self, axis=-1, alpha=0.25, gamma=2, max_mult=-1, eps=1e-12, + from_sigmoid=False, detach_delimeter=True, + batch_axis=0, weight=None, size_average=True, + ignore_label=-1): + super(NormalizedFocalLossSigmoid, self).__init__() + self._axis = axis + self._alpha = alpha + self._gamma = gamma + self._ignore_label = ignore_label + self._weight = weight if weight is not None else 1.0 + self._batch_axis = batch_axis + + self._from_logits = from_sigmoid + self._eps = eps + self._size_average = size_average + self._detach_delimeter = detach_delimeter + self._max_mult = max_mult + self._k_sum = 0 + self._m_max = 0 + + def forward(self, pred, label): + one_hot = label > 0.5 + sample_weight = label != self._ignore_label + + if not self._from_logits: + pred = torch.sigmoid(pred) + + alpha = torch.where(one_hot, self._alpha * sample_weight, (1 - self._alpha) * sample_weight) + pt = torch.where(sample_weight, 1.0 - torch.abs(label - pred), torch.ones_like(pred)) + + beta = (1 - pt) ** self._gamma + + sw_sum = torch.sum(sample_weight, dim=(-2, -1), keepdim=True) + beta_sum = torch.sum(beta, dim=(-2, -1), keepdim=True) + mult = sw_sum / (beta_sum + self._eps) + if self._detach_delimeter: + mult = mult.detach() + beta = beta * mult + if self._max_mult > 0: + beta = torch.clamp_max(beta, self._max_mult) + + with torch.no_grad(): + ignore_area = torch.sum(label == self._ignore_label, dim=tuple(range(1, label.dim()))).cpu().numpy() + sample_mult = torch.mean(mult, dim=tuple(range(1, mult.dim()))).cpu().numpy() + if np.any(ignore_area == 0): + self._k_sum = 0.9 * self._k_sum + 0.1 * sample_mult[ignore_area == 0].mean() + + beta_pmax, _ = torch.flatten(beta, start_dim=1).max(dim=1) + beta_pmax = beta_pmax.mean().item() + self._m_max = 0.8 * self._m_max + 0.2 * beta_pmax + + loss = -alpha * beta * torch.log(torch.min(pt + self._eps, torch.ones(1, dtype=torch.float).to(pt.device))) + loss = self._weight * (loss * sample_weight) + + if self._size_average: + bsum = torch.sum(sample_weight, dim=misc.get_dims_with_exclusion(sample_weight.dim(), self._batch_axis)) + loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)) / (bsum + self._eps) + else: + loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)) + + return loss + + def log_states(self, sw, name, global_step): + sw.add_scalar(tag=name + '_k', value=self._k_sum, global_step=global_step) + sw.add_scalar(tag=name + '_m', value=self._m_max, global_step=global_step) + + +class FocalLoss(nn.Module): + def __init__(self, axis=-1, alpha=0.25, gamma=2, + from_logits=False, batch_axis=0, + weight=None, num_class=None, + eps=1e-9, size_average=True, scale=1.0, + ignore_label=-1): + super(FocalLoss, self).__init__() + self._axis = axis + self._alpha = alpha + self._gamma = gamma + self._ignore_label = ignore_label + self._weight = weight if weight is not None else 1.0 + self._batch_axis = batch_axis + + self._scale = scale + self._num_class = num_class + self._from_logits = from_logits + self._eps = eps + self._size_average = size_average + + def forward(self, pred, label, 
sample_weight=None): + one_hot = label > 0.5 + sample_weight = label != self._ignore_label + + if not self._from_logits: + pred = torch.sigmoid(pred) + + alpha = torch.where(one_hot, self._alpha * sample_weight, (1 - self._alpha) * sample_weight) + pt = torch.where(sample_weight, 1.0 - torch.abs(label - pred), torch.ones_like(pred)) + + beta = (1 - pt) ** self._gamma + + loss = -alpha * beta * torch.log(torch.min(pt + self._eps, torch.ones(1, dtype=torch.float).to(pt.device))) + loss = self._weight * (loss * sample_weight) + + if self._size_average: + tsum = torch.sum(sample_weight, dim=misc.get_dims_with_exclusion(label.dim(), self._batch_axis)) + loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)) / (tsum + self._eps) + else: + loss = torch.sum(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)) + + return self._scale * loss + + +class SoftIoU(nn.Module): + def __init__(self, from_sigmoid=False, ignore_label=-1): + super().__init__() + self._from_sigmoid = from_sigmoid + self._ignore_label = ignore_label + + def forward(self, pred, label): + label = label.view(pred.size()) + sample_weight = label != self._ignore_label + + if not self._from_sigmoid: + pred = torch.sigmoid(pred) + + loss = 1.0 - torch.sum(pred * label * sample_weight, dim=(1, 2, 3)) \ + / (torch.sum(torch.max(pred, label) * sample_weight, dim=(1, 2, 3)) + 1e-8) + + return loss + + +class SigmoidBinaryCrossEntropyLoss(nn.Module): + def __init__(self, from_sigmoid=False, weight=None, batch_axis=0, ignore_label=-1): + super(SigmoidBinaryCrossEntropyLoss, self).__init__() + self._from_sigmoid = from_sigmoid + self._ignore_label = ignore_label + self._weight = weight if weight is not None else 1.0 + self._batch_axis = batch_axis + + def forward(self, pred, label): + label = label.view(pred.size()) + sample_weight = label != self._ignore_label + label = torch.where(sample_weight, label, torch.zeros_like(label)) + + if not self._from_sigmoid: + loss = torch.relu(pred) - pred * label + F.softplus(-torch.abs(pred)) + else: + eps = 1e-12 + loss = -(torch.log(pred + eps) * label + + torch.log(1. - pred + eps) * (1. 
- label)) + + loss = self._weight * (loss * sample_weight) + return torch.mean(loss, dim=misc.get_dims_with_exclusion(loss.dim(), self._batch_axis)) + + +class BinaryDiceLoss(nn.Module): + """ Dice Loss for binary segmentation + """ + + def forward(self, pred, label): + batchsize = pred.size(0) + + # convert probability to binary label using maximum probability + input_pred, input_label = pred.max(1) + input_pred *= input_label.float() + + # convert to floats + input_pred = input_pred.float() + target_label = label.float() + + # convert to 1D + input_pred = input_pred.view(batchsize, -1) + target_label = target_label.view(batchsize, -1) + + # compute dice score + intersect = torch.sum(input_pred * target_label, 1) + input_area = torch.sum(input_pred * input_pred, 1) + target_area = torch.sum(target_label * target_label, 1) + + sum = input_area + target_area + epsilon = torch.tensor(1e-6) + + # batch dice loss and ignore dice loss where target area = 0 + batch_loss = torch.tensor(1.0) - (torch.tensor(2.0) * intersect + epsilon) / (sum + epsilon) + loss = batch_loss.mean() + + return loss \ No newline at end of file diff --git a/isegm/model/metrics.py b/isegm/model/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..a572dcd97ed2dac222fa51a33657aa5b403dbb2a --- /dev/null +++ b/isegm/model/metrics.py @@ -0,0 +1,101 @@ +import torch +import numpy as np + +from isegm.utils import misc + + +class TrainMetric(object): + def __init__(self, pred_outputs, gt_outputs): + self.pred_outputs = pred_outputs + self.gt_outputs = gt_outputs + + def update(self, *args, **kwargs): + raise NotImplementedError + + def get_epoch_value(self): + raise NotImplementedError + + def reset_epoch_stats(self): + raise NotImplementedError + + def log_states(self, sw, tag_prefix, global_step): + pass + + @property + def name(self): + return type(self).__name__ + + +class AdaptiveIoU(TrainMetric): + def __init__(self, init_thresh=0.4, thresh_step=0.025, thresh_beta=0.99, iou_beta=0.9, + ignore_label=-1, from_logits=True, + pred_output='instances', gt_output='instances'): + super().__init__(pred_outputs=(pred_output,), gt_outputs=(gt_output,)) + self._ignore_label = ignore_label + self._from_logits = from_logits + self._iou_thresh = init_thresh + self._thresh_step = thresh_step + self._thresh_beta = thresh_beta + self._iou_beta = iou_beta + self._ema_iou = 0.0 + self._epoch_iou_sum = 0.0 + self._epoch_batch_count = 0 + + def update(self, pred, gt): + gt_mask = gt > 0.5 + if self._from_logits: + pred = torch.sigmoid(pred) + + gt_mask_area = torch.sum(gt_mask, dim=(1, 2)).detach().cpu().numpy() + if np.all(gt_mask_area == 0): + return + + ignore_mask = gt == self._ignore_label + max_iou = _compute_iou(pred > self._iou_thresh, gt_mask, ignore_mask).mean() + best_thresh = self._iou_thresh + for t in [best_thresh - self._thresh_step, best_thresh + self._thresh_step]: + temp_iou = _compute_iou(pred > t, gt_mask, ignore_mask).mean() + if temp_iou > max_iou: + max_iou = temp_iou + best_thresh = t + + self._iou_thresh = self._thresh_beta * self._iou_thresh + (1 - self._thresh_beta) * best_thresh + self._ema_iou = self._iou_beta * self._ema_iou + (1 - self._iou_beta) * max_iou + self._epoch_iou_sum += max_iou + self._epoch_batch_count += 1 + + def get_epoch_value(self): + if self._epoch_batch_count > 0: + return self._epoch_iou_sum / self._epoch_batch_count + else: + return 0.0 + + def reset_epoch_stats(self): + self._epoch_iou_sum = 0.0 + self._epoch_batch_count = 0 + + def log_states(self, sw, tag_prefix, 
global_step): + sw.add_scalar(tag=tag_prefix + '_ema_iou', value=self._ema_iou, global_step=global_step) + sw.add_scalar(tag=tag_prefix + '_iou_thresh', value=self._iou_thresh, global_step=global_step) + + @property + def iou_thresh(self): + return self._iou_thresh + + +def _compute_iou(pred_mask, gt_mask, ignore_mask=None, keep_ignore=False): + if ignore_mask is not None: + pred_mask = torch.where(ignore_mask, torch.zeros_like(pred_mask), pred_mask) + + reduction_dims = misc.get_dims_with_exclusion(gt_mask.dim(), 0) + union = torch.mean((pred_mask | gt_mask).float(), dim=reduction_dims).detach().cpu().numpy() + intersection = torch.mean((pred_mask & gt_mask).float(), dim=reduction_dims).detach().cpu().numpy() + nonzero = union > 0 + + iou = intersection[nonzero] / union[nonzero] + if not keep_ignore: + return iou + else: + result = np.full_like(intersection, -1) + result[nonzero] = iou + return result diff --git a/isegm/model/modeling/__init__.py b/isegm/model/modeling/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/isegm/model/modeling/basic_blocks.py b/isegm/model/modeling/basic_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..13753e85353ed9250aa3888ab2e715350b1b2c50 --- /dev/null +++ b/isegm/model/modeling/basic_blocks.py @@ -0,0 +1,71 @@ +import torch.nn as nn + +from isegm.model import ops + + +class ConvHead(nn.Module): + def __init__(self, out_channels, in_channels=32, num_layers=1, + kernel_size=3, padding=1, + norm_layer=nn.BatchNorm2d): + super(ConvHead, self).__init__() + convhead = [] + + for i in range(num_layers): + convhead.extend([ + nn.Conv2d(in_channels, in_channels, kernel_size, padding=padding), + nn.ReLU(), + norm_layer(in_channels) if norm_layer is not None else nn.Identity() + ]) + convhead.append(nn.Conv2d(in_channels, out_channels, 1, padding=0)) + + self.convhead = nn.Sequential(*convhead) + + def forward(self, *inputs): + return self.convhead(inputs[0]) + + +class SepConvHead(nn.Module): + def __init__(self, num_outputs, in_channels, mid_channels, num_layers=1, + kernel_size=3, padding=1, dropout_ratio=0.0, dropout_indx=0, + norm_layer=nn.BatchNorm2d): + super(SepConvHead, self).__init__() + + sepconvhead = [] + + for i in range(num_layers): + sepconvhead.append( + SeparableConv2d(in_channels=in_channels if i == 0 else mid_channels, + out_channels=mid_channels, + dw_kernel=kernel_size, dw_padding=padding, + norm_layer=norm_layer, activation='relu') + ) + if dropout_ratio > 0 and dropout_indx == i: + sepconvhead.append(nn.Dropout(dropout_ratio)) + + sepconvhead.append( + nn.Conv2d(in_channels=mid_channels, out_channels=num_outputs, kernel_size=1, padding=0) + ) + + self.layers = nn.Sequential(*sepconvhead) + + def forward(self, *inputs): + x = inputs[0] + + return self.layers(x) + + +class SeparableConv2d(nn.Module): + def __init__(self, in_channels, out_channels, dw_kernel, dw_padding, dw_stride=1, + activation=None, use_bias=False, norm_layer=None): + super(SeparableConv2d, self).__init__() + _activation = ops.select_activation_function(activation) + self.body = nn.Sequential( + nn.Conv2d(in_channels, in_channels, kernel_size=dw_kernel, stride=dw_stride, + padding=dw_padding, bias=use_bias, groups=in_channels), + nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=use_bias), + norm_layer(out_channels) if norm_layer is not None else nn.Identity(), + _activation() + ) + + def forward(self, x): + return self.body(x) diff --git 
a/isegm/model/modeling/clip/__init__.py b/isegm/model/modeling/clip/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dcc5619538c0f7c782508bdbd9587259d805e0d9 --- /dev/null +++ b/isegm/model/modeling/clip/__init__.py @@ -0,0 +1 @@ +from .clip import * diff --git a/isegm/model/modeling/clip/clip.py b/isegm/model/modeling/clip/clip.py new file mode 100644 index 0000000000000000000000000000000000000000..f7a5da5e69e0a3b41383734711ccfff1923a9ef9 --- /dev/null +++ b/isegm/model/modeling/clip/clip.py @@ -0,0 +1,245 @@ +import hashlib +import os +import urllib +import warnings +from typing import Any, Union, List +from pkg_resources import packaging + +import torch +from PIL import Image +from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize +from tqdm import tqdm + +from .model import build_model +from .simple_tokenizer import SimpleTokenizer as _Tokenizer + +try: + from torchvision.transforms import InterpolationMode + BICUBIC = InterpolationMode.BICUBIC +except ImportError: + BICUBIC = Image.BICUBIC + + +if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"): + warnings.warn("PyTorch version 1.7.1 or higher is recommended") + + +__all__ = ["available_models", "load", "tokenize"] +_tokenizer = _Tokenizer() + +_MODELS = { + "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", + "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", + "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", + "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", + "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", + "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", + "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", + "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", + "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", +} + + +def _download(url: str, root: str): + os.makedirs(root, exist_ok=True) + filename = os.path.basename(url) + + expected_sha256 = url.split("/")[-2] + download_target = os.path.join(root, filename) + + if os.path.exists(download_target) and not os.path.isfile(download_target): + raise RuntimeError(f"{download_target} exists and is not a regular file") + + if os.path.isfile(download_target): + if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: + return download_target + else: + warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") + + with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: + with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: + while True: + buffer = source.read(8192) + if not buffer: + break + + output.write(buffer) + loop.update(len(buffer)) + + if 
hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: + raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match") + + return download_target + + +def _convert_image_to_rgb(image): + return image.convert("RGB") + + +def _transform(n_px): + return Compose([ + Resize(n_px, interpolation=BICUBIC), + CenterCrop(n_px), + _convert_image_to_rgb, + ToTensor(), + Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), + ]) + + +def available_models() -> List[str]: + """Returns the names of available CLIP models""" + return list(_MODELS.keys()) + + +def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None): + """Load a CLIP model + + Parameters + ---------- + name : str + A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict + + device : Union[str, torch.device] + The device to put the loaded model + + jit : bool + Whether to load the optimized JIT model or more hackable non-JIT model (default). + + download_root: str + path to download the model files; by default, it uses "~/.cache/clip" + + Returns + ------- + model : torch.nn.Module + The CLIP model + + preprocess : Callable[[PIL.Image], torch.Tensor] + A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input + """ + if name in _MODELS: + model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) + elif os.path.isfile(name): + model_path = name + else: + raise RuntimeError(f"Model {name} not found; available models = {available_models()}") + + with open(model_path, 'rb') as opened_file: + try: + # loading JIT archive + model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval() + state_dict = None + except RuntimeError: + # loading saved state dict + if jit: + warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") + jit = False + state_dict = torch.load(opened_file, map_location="cpu") + + if not jit: + model = build_model(state_dict or model.state_dict()).to(device) + if str(device) == "cpu": + model.float() + return model, _transform(model.visual.input_resolution) + + # patch the device names + device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) + device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] + + def _node_get(node: torch._C.Node, key: str): + """Gets attributes of a node which is polymorphic over return type. 
+ + From https://github.com/pytorch/pytorch/pull/82628 + """ + sel = node.kindOf(key) + return getattr(node, sel)(key) + + def patch_device(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("prim::Constant"): + if "value" in node.attributeNames() and str(_node_get(node, "value")).startswith("cuda"): + node.copyAttributes(device_node) + + model.apply(patch_device) + patch_device(model.encode_image) + patch_device(model.encode_text) + + # patch dtype to float32 on CPU + if str(device) == "cpu": + float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) + float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] + float_node = float_input.node() + + def patch_float(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("aten::to"): + inputs = list(node.inputs()) + for i in [1, 2]: # dtype can be the second or third argument to aten::to() + if _node_get(inputs[i].node(), "value") == 5: + inputs[i].node().copyAttributes(float_node) + + model.apply(patch_float) + patch_float(model.encode_image) + patch_float(model.encode_text) + + model.float() + + return model, _transform(model.input_resolution.item()) + + +def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]: + """ + Returns the tokenized representation of given input string(s) + + Parameters + ---------- + texts : Union[str, List[str]] + An input string or a list of input strings to tokenize + + context_length : int + The context length to use; all CLIP models use 77 as the context length + + truncate: bool + Whether to truncate the text in case its encoding is longer than the context length + + Returns + ------- + A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]. + We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long. 
+ """ + if isinstance(texts, str): + texts = [texts] + + sot_token = _tokenizer.encoder["<|startoftext|>"] + eot_token = _tokenizer.encoder["<|endoftext|>"] + all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] + if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"): + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + else: + result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + if truncate: + tokens = tokens[:context_length] + tokens[-1] = eot_token + else: + raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") + result[i, :len(tokens)] = torch.tensor(tokens) + + return result diff --git a/isegm/model/modeling/clip/model.py b/isegm/model/modeling/clip/model.py new file mode 100644 index 0000000000000000000000000000000000000000..c565b386e3aa8f217eb9ae81c07e2869d0127ffd --- /dev/null +++ b/isegm/model/modeling/clip/model.py @@ -0,0 +1,436 @@ +from collections import OrderedDict +from typing import Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.relu1 = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.relu2 = nn.ReLU(inplace=True) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.relu3 = nn.ReLU(inplace=True) + + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential(OrderedDict([ + ("-1", nn.AvgPool2d(stride)), + ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), + ("1", nn.BatchNorm2d(planes * self.expansion)) + ])) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.relu1(self.bn1(self.conv1(x))) + out = self.relu2(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu3(out) + return out + + +class AttentionPool2d(nn.Module): + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x[:1], key=x, value=x, + 
embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False + ) + return x.squeeze(0) + + +class ModifiedResNet(nn.Module): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. + - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): + super().__init__() + self.output_dim = output_dim + self.input_resolution = input_resolution + + # the 3-layer stem + self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(width // 2) + self.relu1 = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(width // 2) + self.relu2 = nn.ReLU(inplace=True) + self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.relu3 = nn.ReLU(inplace=True) + self.avgpool = nn.AvgPool2d(2) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 # the ResNet feature dimension + self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + def stem(x): + x = self.relu1(self.bn1(self.conv1(x))) + x = self.relu2(self.bn2(self.conv2(x))) + x = self.relu3(self.bn3(self.conv3(x))) + x = self.avgpool(x) + return x + + x = x.type(self.conv1.weight.dtype) + x = stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.attnpool(x) + + return x + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +class QuickGELU(nn.Module): + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, d_model * 4)), + ("gelu", QuickGELU()), + ("c_proj", nn.Linear(d_model * 4, d_model)) + ])) + self.ln_2 = 
LayerNorm(d_model) + self.attn_mask = attn_mask + + def attention(self, x: torch.Tensor): + self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None + return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] + + def forward(self, x: torch.Tensor): + x = x + self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) + + def forward(self, x: torch.Tensor): + return self.resblocks(x) + + +class VisionTransformer(nn.Module): + def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): + super().__init__() + self.input_resolution = input_resolution + self.output_dim = output_dim + self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) + + scale = width ** -0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) + self.ln_pre = LayerNorm(width) + + self.transformer = Transformer(width, layers, heads) + + self.ln_post = LayerNorm(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + def forward(self, x: torch.Tensor): + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + x = self.ln_post(x[:, 0, :]) + + if self.proj is not None: + x = x @ self.proj + + return x + + +class CLIP(nn.Module): + def __init__(self, + embed_dim: int, + # vision + image_resolution: int, + vision_layers: Union[Tuple[int, int, int, int], int], + vision_width: int, + vision_patch_size: int, + # text + context_length: int, + vocab_size: int, + transformer_width: int, + transformer_heads: int, + transformer_layers: int + ): + super().__init__() + + self.context_length = context_length + + if isinstance(vision_layers, (tuple, list)): + vision_heads = vision_width * 32 // 64 + self.visual = ModifiedResNet( + layers=vision_layers, + output_dim=embed_dim, + heads=vision_heads, + input_resolution=image_resolution, + width=vision_width + ) + else: + vision_heads = vision_width // 64 + self.visual = VisionTransformer( + input_resolution=image_resolution, + patch_size=vision_patch_size, + width=vision_width, + layers=vision_layers, + heads=vision_heads, + output_dim=embed_dim + ) + + self.transformer = Transformer( + width=transformer_width, + layers=transformer_layers, + heads=transformer_heads, + attn_mask=self.build_attention_mask() + ) + + self.vocab_size = vocab_size + self.token_embedding = nn.Embedding(vocab_size, transformer_width) + self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) + self.ln_final = LayerNorm(transformer_width) + + self.text_projection = 
nn.Parameter(torch.empty(transformer_width, embed_dim)) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + self.initialize_parameters() + + def initialize_parameters(self): + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + + if isinstance(self.visual, ModifiedResNet): + if self.visual.attnpool is not None: + std = self.visual.attnpool.c_proj.in_features ** -0.5 + nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) + + for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: + for name, param in resnet_block.named_parameters(): + if name.endswith("bn3.weight"): + nn.init.zeros_(param) + + proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + @property + def dtype(self): + return self.visual.conv1.weight.dtype + + def encode_image(self, image): + return self.visual(image.type(self.dtype)) + + def encode_text(self, text): + x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding.type(self.dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x).type(self.dtype) + + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + # x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + + return x + + def forward(self, image, text): + image_features = self.encode_image(image) + text_features = self.encode_text(text) + + # normalized features + image_features = image_features / image_features.norm(dim=1, keepdim=True) + text_features = text_features / text_features.norm(dim=1, keepdim=True) + + # cosine similarity as logits + logit_scale = self.logit_scale.exp() + logits_per_image = logit_scale * image_features @ text_features.t() + logits_per_text = logits_per_image.t() + + # shape = [global_batch_size, global_batch_size] + return logits_per_image, logits_per_text + + +def convert_weights(model: nn.Module): + """Convert applicable model parameters to fp16""" + + def _convert_weights_to_fp16(l): + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): + l.weight.data = l.weight.data.half() + if l.bias is not None: + l.bias.data = l.bias.data.half() + + if isinstance(l, nn.MultiheadAttention): + for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: + tensor = 
getattr(l, attr) + if tensor is not None: + tensor.data = tensor.data.half() + + for name in ["text_projection", "proj"]: + if hasattr(l, name): + attr = getattr(l, name) + if attr is not None: + attr.data = attr.data.half() + + model.apply(_convert_weights_to_fp16) + + +def build_model(state_dict: dict): + vit = "visual.proj" in state_dict + + if vit: + vision_width = state_dict["visual.conv1.weight"].shape[0] + vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) + vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] + grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) + image_resolution = vision_patch_size * grid_size + else: + counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] + vision_layers = tuple(counts) + vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] + output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) + vision_patch_size = None + assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] + image_resolution = output_width * 32 + + embed_dim = state_dict["text_projection"].shape[1] + context_length = state_dict["positional_embedding"].shape[0] + vocab_size = state_dict["token_embedding.weight"].shape[0] + transformer_width = state_dict["ln_final.weight"].shape[0] + transformer_heads = transformer_width // 64 + transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks"))) + + model = CLIP( + embed_dim, + image_resolution, vision_layers, vision_width, vision_patch_size, + context_length, vocab_size, transformer_width, transformer_heads, transformer_layers + ) + + for key in ["input_resolution", "context_length", "vocab_size"]: + if key in state_dict: + del state_dict[key] + + convert_weights(model) + model.load_state_dict(state_dict) + return model.eval() diff --git a/isegm/model/modeling/clip/simple_tokenizer.py b/isegm/model/modeling/clip/simple_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..0a66286b7d5019c6e221932a813768038f839c91 --- /dev/null +++ b/isegm/model/modeling/clip/simple_tokenizer.py @@ -0,0 +1,132 @@ +import gzip +import html +import os +from functools import lru_cache + +import ftfy +import regex as re + + +@lru_cache() +def default_bpe(): + return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8+n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. 
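+    For example, ('l', 'o', 'w', 'e', 'r</w>') yields {('l', 'o'), ('o', 'w'), ('w', 'e'), ('e', 'r</w>')}.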
+ Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + def __init__(self, bpe_path: str = default_bpe()): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') + merges = merges[1:49152-256-2+1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v+'</w>' for v in vocab] + for merge in merges: + vocab.append(''.join(merge)) + vocab.extend(['<|startoftext|>', '<|endoftext|>']) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} + self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + ( token[-1] + '</w>',) + pairs = get_pairs(word) + + if not pairs: + return token+'</w>' + + while True: + bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except ValueError: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word)-1 and word[i+1] == second: + new_word.append(first+second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) + bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens + + def decode(self, tokens): + text = ''.join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ') + return text diff --git a/isegm/model/modeling/clip_text_encoding.py b/isegm/model/modeling/clip_text_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..8a51cbc0e3110d1e7c318c3b7db323b08473f41d --- /dev/null +++ b/isegm/model/modeling/clip_text_encoding.py @@ -0,0 +1,29 @@ +import torch +from torch import nn +from .clip import clip + +class ClipTextEncoder(nn.Module): + def __init__(self, clip_enocder_name="ViT-B/32", embedding_dim=512, out_dim=768): + super().__init__() + assert clip_enocder_name in ['RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px'] + self.device = "cuda" if torch.cuda.is_available() else "cpu" + self.model, self.preprocess = clip.load(clip_enocder_name, device=self.device) + + # freeze model + 
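+        # the vendored CLIP weights stay frozen below; only the newly added out_proj is trained
+        # (with the default ViT-B/32 weights, forward() should return per-token features of shape [batch, n_ctx, out_dim])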
for _, param in self.model.named_parameters(): + param.requires_grad = False + self.out_proj = nn.Linear(embedding_dim, out_dim) + nn.init.zeros_(self.out_proj.bias) + + @torch.no_grad() + def forward(self, prompt): + ''' + prompt: text tokens + ''' + text_features = self.model.encode_text(prompt).type(torch.float32) + # norm + # text_features /= text_features.norm(dim=-1, keepdim=True) # [bs, 1024] + # proj + text_features = self.out_proj(text_features) + return text_features + \ No newline at end of file diff --git a/isegm/model/modeling/deeplab_v3.py b/isegm/model/modeling/deeplab_v3.py new file mode 100644 index 0000000000000000000000000000000000000000..8219a4ef18048a0fc79fdf3e5b603af7eac03892 --- /dev/null +++ b/isegm/model/modeling/deeplab_v3.py @@ -0,0 +1,176 @@ +from contextlib import ExitStack + +import torch +from torch import nn +import torch.nn.functional as F + +from .basic_blocks import SeparableConv2d +from .resnet import ResNetBackbone +from isegm.model import ops + + +class DeepLabV3Plus(nn.Module): + def __init__(self, backbone='resnet50', norm_layer=nn.BatchNorm2d, + backbone_norm_layer=None, + ch=256, + project_dropout=0.5, + inference_mode=False, + **kwargs): + super(DeepLabV3Plus, self).__init__() + if backbone_norm_layer is None: + backbone_norm_layer = norm_layer + + self.backbone_name = backbone + self.norm_layer = norm_layer + self.backbone_norm_layer = backbone_norm_layer + self.inference_mode = False + self.ch = ch + self.aspp_in_channels = 2048 + self.skip_project_in_channels = 256 # layer 1 out_channels + + self._kwargs = kwargs + if backbone == 'resnet34': + self.aspp_in_channels = 512 + self.skip_project_in_channels = 64 + + self.backbone = ResNetBackbone(backbone=self.backbone_name, pretrained_base=False, + norm_layer=self.backbone_norm_layer, **kwargs) + + self.head = _DeepLabHead(in_channels=ch + 32, mid_channels=ch, out_channels=ch, + norm_layer=self.norm_layer) + self.skip_project = _SkipProject(self.skip_project_in_channels, 32, norm_layer=self.norm_layer) + self.aspp = _ASPP(in_channels=self.aspp_in_channels, + atrous_rates=[12, 24, 36], + out_channels=ch, + project_dropout=project_dropout, + norm_layer=self.norm_layer) + + if inference_mode: + self.set_prediction_mode() + + def load_pretrained_weights(self): + pretrained = ResNetBackbone(backbone=self.backbone_name, pretrained_base=True, + norm_layer=self.backbone_norm_layer, **self._kwargs) + backbone_state_dict = self.backbone.state_dict() + pretrained_state_dict = pretrained.state_dict() + + backbone_state_dict.update(pretrained_state_dict) + self.backbone.load_state_dict(backbone_state_dict) + + if self.inference_mode: + for param in self.backbone.parameters(): + param.requires_grad = False + + def set_prediction_mode(self): + self.inference_mode = True + self.eval() + + def forward(self, x, additional_features=None): + with ExitStack() as stack: + if self.inference_mode: + stack.enter_context(torch.no_grad()) + + c1, _, c3, c4 = self.backbone(x, additional_features) + c1 = self.skip_project(c1) + + x = self.aspp(c4) + x = F.interpolate(x, c1.size()[2:], mode='bilinear', align_corners=True) + x = torch.cat((x, c1), dim=1) + x = self.head(x) + + return x, + + +class _SkipProject(nn.Module): + def __init__(self, in_channels, out_channels, norm_layer=nn.BatchNorm2d): + super(_SkipProject, self).__init__() + _activation = ops.select_activation_function("relu") + + self.skip_project = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False), + norm_layer(out_channels), + 
_activation() + ) + + def forward(self, x): + return self.skip_project(x) + + +class _DeepLabHead(nn.Module): + def __init__(self, out_channels, in_channels, mid_channels=256, norm_layer=nn.BatchNorm2d): + super(_DeepLabHead, self).__init__() + + self.block = nn.Sequential( + SeparableConv2d(in_channels=in_channels, out_channels=mid_channels, dw_kernel=3, + dw_padding=1, activation='relu', norm_layer=norm_layer), + SeparableConv2d(in_channels=mid_channels, out_channels=mid_channels, dw_kernel=3, + dw_padding=1, activation='relu', norm_layer=norm_layer), + nn.Conv2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=1) + ) + + def forward(self, x): + return self.block(x) + + +class _ASPP(nn.Module): + def __init__(self, in_channels, atrous_rates, out_channels=256, + project_dropout=0.5, norm_layer=nn.BatchNorm2d): + super(_ASPP, self).__init__() + + b0 = nn.Sequential( + nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=False), + norm_layer(out_channels), + nn.ReLU() + ) + + rate1, rate2, rate3 = tuple(atrous_rates) + b1 = _ASPPConv(in_channels, out_channels, rate1, norm_layer) + b2 = _ASPPConv(in_channels, out_channels, rate2, norm_layer) + b3 = _ASPPConv(in_channels, out_channels, rate3, norm_layer) + b4 = _AsppPooling(in_channels, out_channels, norm_layer=norm_layer) + + self.concurent = nn.ModuleList([b0, b1, b2, b3, b4]) + + project = [ + nn.Conv2d(in_channels=5*out_channels, out_channels=out_channels, + kernel_size=1, bias=False), + norm_layer(out_channels), + nn.ReLU() + ] + if project_dropout > 0: + project.append(nn.Dropout(project_dropout)) + self.project = nn.Sequential(*project) + + def forward(self, x): + x = torch.cat([block(x) for block in self.concurent], dim=1) + + return self.project(x) + + +class _AsppPooling(nn.Module): + def __init__(self, in_channels, out_channels, norm_layer): + super(_AsppPooling, self).__init__() + + self.gap = nn.Sequential( + nn.AdaptiveAvgPool2d((1, 1)), + nn.Conv2d(in_channels=in_channels, out_channels=out_channels, + kernel_size=1, bias=False), + norm_layer(out_channels), + nn.ReLU() + ) + + def forward(self, x): + pool = self.gap(x) + return F.interpolate(pool, x.size()[2:], mode='bilinear', align_corners=True) + + +def _ASPPConv(in_channels, out_channels, atrous_rate, norm_layer): + block = nn.Sequential( + nn.Conv2d(in_channels=in_channels, out_channels=out_channels, + kernel_size=3, padding=atrous_rate, + dilation=atrous_rate, bias=False), + norm_layer(out_channels), + nn.ReLU() + ) + + return block diff --git a/isegm/model/modeling/hrformer.py b/isegm/model/modeling/hrformer.py new file mode 100644 index 0000000000000000000000000000000000000000..49c046b1d40112d1dd2b70ed3fa49c69359c8e35 --- /dev/null +++ b/isegm/model/modeling/hrformer.py @@ -0,0 +1,487 @@ +##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +## Created by: RainbowSecret +## Microsoft Research +## yuyua@microsoft.com, furao17@mails.ucas.ac.cn +## Copyright (c) 2021 +## +## This source code is licensed under the MIT-style license found in the +## LICENSE file in the root directory of this source tree +##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import os +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +# from .hrformer_helper.backbone_selector import BackboneSelector +from .hrformer_helper.hrt.module_helper import ModuleHelper +from .hrformer_helper.hrt.modules.spatial_ocr_block import SpatialGather_Module, SpatialOCR_Module + +from 
.hrformer_helper.hrt.logger import Logger as Log +from .hrformer_helper.hrt.hrt_backbone import HRTBackbone, HRTBackbone_v2 + + +class BackboneSelector(object): + def __init__(self, configer): + self.configer = configer + + def get_backbone(self, **params): + backbone = self.configer.get("network", "backbone") + + model = None + # if ( + # "resnet" in backbone or "resnext" in backbone or "resnest" in backbone + # ) and "senet" not in backbone: + # model = ResNetBackbone(self.configer)(**params) + + if "hrt" in backbone: + model = HRTBackbone(self.configer)(**params) + pass + + # elif "hrnet" in backbone: + # model = HRNetBackbone(self.configer)(**params) + + # elif "swin" in backbone: + # model = SwinTransformerBackbone(self.configer)(**params) + + else: + Log.error("Backbone {} is invalid.".format(backbone)) + exit(1) + + return model + + +class HRT_B_OCR_V3(nn.Module): + def __init__(self, num_classes, in_ch=3, backbone='hrt_base', bn_type="torchbn", pretrained=None): + super(HRT_B_OCR_V3, self).__init__() + self.num_classes = num_classes + self.bn_type = bn_type + self.backbone = HRTBackbone_v2(backbone, pretrained, in_ch)() + + in_channels = 1170 + hidden_dim = 512 + group_channel = math.gcd(in_channels, hidden_dim) + self.conv3x3 = nn.Sequential( + nn.Conv2d( + in_channels, + hidden_dim, + kernel_size=7, + stride=1, + padding=3, + groups=group_channel, + ), + ModuleHelper.BNReLU( + hidden_dim, bn_type=self.bn_type + ), + ) + self.ocr_gather_head = SpatialGather_Module(self.num_classes) + self.ocr_distri_head = SpatialOCR_Module( + in_channels=hidden_dim, + key_channels=hidden_dim // 2, + out_channels=hidden_dim, + scale=1, + dropout=0.05, + bn_type=self.bn_type, + ) + self.cls_head = nn.Conv2d( + hidden_dim, self.num_classes, kernel_size=1, stride=1, padding=0, bias=True + ) + self.aux_head = nn.Sequential( + nn.Conv2d( + in_channels, + hidden_dim, + kernel_size=7, + stride=1, + padding=3, + groups=group_channel, + ), + ModuleHelper.BNReLU( + hidden_dim, bn_type=self.bn_type + ), + nn.Conv2d( + hidden_dim, + self.num_classes, + kernel_size=1, + stride=1, + padding=0, + bias=True, + ), + ) + + def forward(self, x_): + x = self.backbone(x_) + _, _, h, w = x[0].size() + + feat1 = x[0] + feat2 = F.interpolate(x[1], size=(h, w), mode="bilinear", align_corners=True) + feat3 = F.interpolate(x[2], size=(h, w), mode="bilinear", align_corners=True) + feat4 = F.interpolate(x[3], size=(h, w), mode="bilinear", align_corners=True) + + feats = torch.cat([feat1, feat2, feat3, feat4], 1) + out_aux = self.aux_head(feats) + + feats = self.conv3x3(feats) + + context = self.ocr_gather_head(feats, out_aux) + feats = self.ocr_distri_head(feats, context) + + out = self.cls_head(feats) + + out_aux = F.interpolate( + out_aux, size=(x_.size(2), x_.size(3)), mode="bilinear", align_corners=True + ) + out = F.interpolate( + out, size=(x_.size(2), x_.size(3)), mode="bilinear", align_corners=True + ) + return out_aux, out + + +class HRT_S_OCR_V2(nn.Module): + def __init__(self, num_classes, backbone='hrt_small', bn_type="torchbn", pretrained=None): + super(HRT_S_OCR_V2, self).__init__() + self.num_classes = num_classes + self.bn_type = bn_type + self.backbone = HRTBackbone_v2(backbone, pretrained)() + + in_channels = 480 + self.conv3x3 = nn.Sequential( + nn.Conv2d(in_channels, 512, kernel_size=3, stride=1, padding=1), + ModuleHelper.BNReLU(512, bn_type=self.bn_type), + ) + self.ocr_gather_head = SpatialGather_Module(self.num_classes) + self.ocr_distri_head = SpatialOCR_Module( + in_channels=512, + key_channels=256, 
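+            # key_channels (256) is half of in_channels (512), mirroring the hidden_dim // 2 choice in the V3 heads of this file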
+ out_channels=512, + scale=1, + dropout=0.05, + bn_type=self.bn_type, + ) + self.cls_head = nn.Conv2d( + 512, self.num_classes, kernel_size=1, stride=1, padding=0, bias=True + ) + self.aux_head = nn.Sequential( + nn.Conv2d(in_channels, 512, kernel_size=3, stride=1, padding=1), + ModuleHelper.BNReLU(512, bn_type=self.bn_type), + nn.Conv2d( + 512, self.num_classes, kernel_size=1, stride=1, padding=0, bias=True + ), + ) + + def forward(self, x_): + x = self.backbone(x_) + _, _, h, w = x[0].size() + + feat1 = x[0] + feat2 = F.interpolate(x[1], size=(h, w), mode="bilinear", align_corners=True) + feat3 = F.interpolate(x[2], size=(h, w), mode="bilinear", align_corners=True) + feat4 = F.interpolate(x[3], size=(h, w), mode="bilinear", align_corners=True) + + feats = torch.cat([feat1, feat2, feat3, feat4], 1) + out_aux = self.aux_head(feats) + + feats = self.conv3x3(feats) + + context = self.ocr_gather_head(feats, out_aux) + feats = self.ocr_distri_head(feats, context) + + out = self.cls_head(feats) + + out_aux = F.interpolate( + out_aux, size=(x_.size(2), x_.size(3)), mode="bilinear", align_corners=True + ) + out = F.interpolate( + out, size=(x_.size(2), x_.size(3)), mode="bilinear", align_corners=True + ) + return out_aux, out + + +class HRT_SMALL_OCR_V2(nn.Module): + def __init__(self, configer): + super(HRT_SMALL_OCR_V2, self).__init__() + self.configer = configer + self.num_classes = self.configer.get("data", "num_classes") + self.backbone = BackboneSelector(configer).get_backbone() + + in_channels = 480 + self.conv3x3 = nn.Sequential( + nn.Conv2d(in_channels, 512, kernel_size=3, stride=1, padding=1), + ModuleHelper.BNReLU(512, bn_type=self.configer.get("network", "bn_type")), + ) + self.ocr_gather_head = SpatialGather_Module(self.num_classes) + self.ocr_distri_head = SpatialOCR_Module( + in_channels=512, + key_channels=256, + out_channels=512, + scale=1, + dropout=0.05, + bn_type=self.configer.get("network", "bn_type"), + ) + self.cls_head = nn.Conv2d( + 512, self.num_classes, kernel_size=1, stride=1, padding=0, bias=True + ) + self.aux_head = nn.Sequential( + nn.Conv2d(in_channels, 512, kernel_size=3, stride=1, padding=1), + ModuleHelper.BNReLU(512, bn_type=self.configer.get("network", "bn_type")), + nn.Conv2d( + 512, self.num_classes, kernel_size=1, stride=1, padding=0, bias=True + ), + ) + + def forward(self, x_): + x = self.backbone(x_) + _, _, h, w = x[0].size() + + feat1 = x[0] + feat2 = F.interpolate(x[1], size=(h, w), mode="bilinear", align_corners=True) + feat3 = F.interpolate(x[2], size=(h, w), mode="bilinear", align_corners=True) + feat4 = F.interpolate(x[3], size=(h, w), mode="bilinear", align_corners=True) + + feats = torch.cat([feat1, feat2, feat3, feat4], 1) + out_aux = self.aux_head(feats) + + feats = self.conv3x3(feats) + + context = self.ocr_gather_head(feats, out_aux) + feats = self.ocr_distri_head(feats, context) + + out = self.cls_head(feats) + + out_aux = F.interpolate( + out_aux, size=(x_.size(2), x_.size(3)), mode="bilinear", align_corners=True + ) + out = F.interpolate( + out, size=(x_.size(2), x_.size(3)), mode="bilinear", align_corners=True + ) + return out_aux, out + + +class HRT_BASE_OCR_V2(nn.Module): + def __init__(self, configer): + super(HRT_BASE_OCR_V2, self).__init__() + self.configer = configer + self.num_classes = self.configer.get("data", "num_classes") + self.backbone = BackboneSelector(configer).get_backbone() + + in_channels = 1170 + self.conv3x3 = nn.Sequential( + nn.Conv2d(in_channels, 512, kernel_size=3, stride=1, padding=1), + 
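+            # in_channels = 1170 = 78 + 156 + 312 + 624, the four HRT-Base branch widths concatenated in forward() (see hrt_config.py)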
ModuleHelper.BNReLU(512, bn_type=self.configer.get("network", "bn_type")), + ) + self.ocr_gather_head = SpatialGather_Module(self.num_classes) + self.ocr_distri_head = SpatialOCR_Module( + in_channels=512, + key_channels=256, + out_channels=512, + scale=1, + dropout=0.05, + bn_type=self.configer.get("network", "bn_type"), + ) + self.cls_head = nn.Conv2d( + 512, self.num_classes, kernel_size=1, stride=1, padding=0, bias=True + ) + self.aux_head = nn.Sequential( + nn.Conv2d(in_channels, 512, kernel_size=3, stride=1, padding=1), + ModuleHelper.BNReLU(512, bn_type=self.configer.get("network", "bn_type")), + nn.Conv2d( + 512, self.num_classes, kernel_size=1, stride=1, padding=0, bias=True + ), + ) + + def forward(self, x_): + x = self.backbone(x_) + _, _, h, w = x[0].size() + + feat1 = x[0] + feat2 = F.interpolate(x[1], size=(h, w), mode="bilinear", align_corners=True) + feat3 = F.interpolate(x[2], size=(h, w), mode="bilinear", align_corners=True) + feat4 = F.interpolate(x[3], size=(h, w), mode="bilinear", align_corners=True) + + feats = torch.cat([feat1, feat2, feat3, feat4], 1) + out_aux = self.aux_head(feats) + + feats = self.conv3x3(feats) + + context = self.ocr_gather_head(feats, out_aux) + feats = self.ocr_distri_head(feats, context) + + out = self.cls_head(feats) + + out_aux = F.interpolate( + out_aux, size=(x_.size(2), x_.size(3)), mode="bilinear", align_corners=True + ) + out = F.interpolate( + out, size=(x_.size(2), x_.size(3)), mode="bilinear", align_corners=True + ) + return out_aux, out + + +class HRT_SMALL_OCR_V3(nn.Module): + def __init__(self, configer): + super(HRT_SMALL_OCR_V3, self).__init__() + self.configer = configer + self.num_classes = self.configer.get("data", "num_classes") + self.backbone = BackboneSelector(configer).get_backbone() + + in_channels = 480 + hidden_dim = 512 + group_channel = math.gcd(in_channels, hidden_dim) + self.conv3x3 = nn.Sequential( + nn.Conv2d( + in_channels, + hidden_dim, + kernel_size=7, + stride=1, + padding=3, + groups=group_channel, + ), + ModuleHelper.BNReLU( + hidden_dim, bn_type=self.configer.get("network", "bn_type") + ), + ) + self.ocr_gather_head = SpatialGather_Module(self.num_classes) + self.ocr_distri_head = SpatialOCR_Module( + in_channels=hidden_dim, + key_channels=hidden_dim // 2, + out_channels=hidden_dim, + scale=1, + dropout=0.05, + bn_type=self.configer.get("network", "bn_type"), + ) + self.cls_head = nn.Conv2d( + hidden_dim, self.num_classes, kernel_size=1, stride=1, padding=0, bias=True + ) + self.aux_head = nn.Sequential( + nn.Conv2d( + in_channels, + hidden_dim, + kernel_size=7, + stride=1, + padding=3, + groups=group_channel, + ), + ModuleHelper.BNReLU( + hidden_dim, bn_type=self.configer.get("network", "bn_type") + ), + nn.Conv2d( + hidden_dim, + self.num_classes, + kernel_size=1, + stride=1, + padding=0, + bias=True, + ), + ) + + def forward(self, x_): + x = self.backbone(x_) + _, _, h, w = x[0].size() + + feat1 = x[0] + feat2 = F.interpolate(x[1], size=(h, w), mode="bilinear", align_corners=True) + feat3 = F.interpolate(x[2], size=(h, w), mode="bilinear", align_corners=True) + feat4 = F.interpolate(x[3], size=(h, w), mode="bilinear", align_corners=True) + + feats = torch.cat([feat1, feat2, feat3, feat4], 1) + out_aux = self.aux_head(feats) + + feats = self.conv3x3(feats) + + context = self.ocr_gather_head(feats, out_aux) + feats = self.ocr_distri_head(feats, context) + + out = self.cls_head(feats) + + out_aux = F.interpolate( + out_aux, size=(x_.size(2), x_.size(3)), mode="bilinear", align_corners=True + ) + out = 
F.interpolate( + out, size=(x_.size(2), x_.size(3)), mode="bilinear", align_corners=True + ) + return out_aux, out + + +class HRT_BASE_OCR_V3(nn.Module): + def __init__(self, configer): + super(HRT_BASE_OCR_V3, self).__init__() + self.configer = configer + self.num_classes = self.configer.get("data", "num_classes") + self.backbone = BackboneSelector(configer).get_backbone() + + in_channels = 1170 + hidden_dim = 512 + group_channel = math.gcd(in_channels, hidden_dim) + self.conv3x3 = nn.Sequential( + nn.Conv2d( + in_channels, + hidden_dim, + kernel_size=7, + stride=1, + padding=3, + groups=group_channel, + ), + ModuleHelper.BNReLU( + hidden_dim, bn_type=self.configer.get("network", "bn_type") + ), + ) + self.ocr_gather_head = SpatialGather_Module(self.num_classes) + self.ocr_distri_head = SpatialOCR_Module( + in_channels=hidden_dim, + key_channels=hidden_dim // 2, + out_channels=hidden_dim, + scale=1, + dropout=0.05, + bn_type=self.configer.get("network", "bn_type"), + ) + self.cls_head = nn.Conv2d( + hidden_dim, self.num_classes, kernel_size=1, stride=1, padding=0, bias=True + ) + self.aux_head = nn.Sequential( + nn.Conv2d( + in_channels, + hidden_dim, + kernel_size=7, + stride=1, + padding=3, + groups=group_channel, + ), + ModuleHelper.BNReLU( + hidden_dim, bn_type=self.configer.get("network", "bn_type") + ), + nn.Conv2d( + hidden_dim, + self.num_classes, + kernel_size=1, + stride=1, + padding=0, + bias=True, + ), + ) + + def forward(self, x_): + x = self.backbone(x_) + _, _, h, w = x[0].size() + + feat1 = x[0] + feat2 = F.interpolate(x[1], size=(h, w), mode="bilinear", align_corners=True) + feat3 = F.interpolate(x[2], size=(h, w), mode="bilinear", align_corners=True) + feat4 = F.interpolate(x[3], size=(h, w), mode="bilinear", align_corners=True) + + feats = torch.cat([feat1, feat2, feat3, feat4], 1) + out_aux = self.aux_head(feats) + + feats = self.conv3x3(feats) + + context = self.ocr_gather_head(feats, out_aux) + feats = self.ocr_distri_head(feats, context) + + out = self.cls_head(feats) + + out_aux = F.interpolate( + out_aux, size=(x_.size(2), x_.size(3)), mode="bilinear", align_corners=True + ) + out = F.interpolate( + out, size=(x_.size(2), x_.size(3)), mode="bilinear", align_corners=True + ) + return out_aux, out \ No newline at end of file diff --git a/isegm/model/modeling/hrformer_helper/__init__.py b/isegm/model/modeling/hrformer_helper/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/isegm/model/modeling/hrformer_helper/backbone_selector.py b/isegm/model/modeling/hrformer_helper/backbone_selector.py new file mode 100644 index 0000000000000000000000000000000000000000..6e0c26f2d673037371d61bc39bd932e742b5ebb6 --- /dev/null +++ b/isegm/model/modeling/hrformer_helper/backbone_selector.py @@ -0,0 +1,54 @@ +##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +## Created by: Donny You, RainbowSecret +## Microsoft Research +## yuyua@microsoft.com +## Copyright (c) 2019 +## +## This source code is licensed under the MIT-style license found in the +## LICENSE file in the root directory of this source tree +##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# from lib.models.backbones.resnet.resnet_backbone import ResNetBackbone +# from lib.models.backbones.hrnet.hrnet_backbone import HRNetBackbone +from .hrt.hrt_backbone import 
HRTBackbone +# from lib.models.backbones.swin.swin_backbone import SwinTransformerBackbone +from .hrt.logger import Logger as Log + + +class BackboneSelector(object): + def __init__(self, configer): + self.configer = configer + + def get_backbone(self, **params): + backbone = self.configer.get("network", "backbone") + + model = None + # if ( + # "resnet" in backbone or "resnext" in backbone or "resnest" in backbone + # ) and "senet" not in backbone: + # model = ResNetBackbone(self.configer)(**params) + + if "hrt" in backbone: + # model = HRTBackbone(self.configer)(**params) + pass + + # elif "hrnet" in backbone: + # model = HRNetBackbone(self.configer)(**params) + + # elif "swin" in backbone: + # model = SwinTransformerBackbone(self.configer)(**params) + + else: + Log.error("Backbone {} is invalid.".format(backbone)) + exit(1) + + return model + +class Test(): + def __init__(): + pass \ No newline at end of file diff --git a/isegm/model/modeling/hrformer_helper/hrt/__init__.py b/isegm/model/modeling/hrformer_helper/hrt/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/isegm/model/modeling/hrformer_helper/hrt/hrt_backbone.py b/isegm/model/modeling/hrformer_helper/hrt/hrt_backbone.py new file mode 100644 index 0000000000000000000000000000000000000000..b2174a9ffe8c150cc1cc0bf039e78a52a8c22f85 --- /dev/null +++ b/isegm/model/modeling/hrformer_helper/hrt/hrt_backbone.py @@ -0,0 +1,661 @@ +import os +import pdb +import argparse +import torch +import logging +import torch.nn as nn +import torch.nn.functional as F + +from .modules.bottleneck_block import Bottleneck, BottleneckDWP +from .modules.transformer_block import GeneralTransformerBlock + +from .module_helper import ModuleHelper +from .logger import Logger as Log + +blocks_dict = { + "BOTTLENECK": Bottleneck, + "TRANSFORMER_BLOCK": GeneralTransformerBlock, +} + + +BN_MOMENTUM = 0.1 + + +class HighResolutionTransformerModule(nn.Module): + def __init__( + self, + num_branches, + blocks, + num_blocks, + num_inchannels, + num_channels, + num_heads, + num_window_sizes, + num_mlp_ratios, + multi_scale_output=True, + drop_path=0.0, + ): + """Based on Local-Attention & FFN-DW-BN + num_heads: the number of head witin each MHSA + num_window_sizes: the window size for the local self-attention + num_halo_sizes: the halo size around the local window + - reference: ``Scaling Local Self-Attention for Parameter Efficient Visual Backbones'' + num_sr_ratios: the spatial reduction ratios of PVT/SRA scheme. 
+ - reference: ``Pyramid Vision Transformer: A Versatile Backbone for Dense Prediction without Convolutions'' + """ + super(HighResolutionTransformerModule, self).__init__() + self._check_branches( + num_branches, blocks, num_blocks, num_inchannels, num_channels + ) + + self.num_inchannels = num_inchannels + self.num_branches = num_branches + + self.multi_scale_output = multi_scale_output + self.branches = self._make_branches( + num_branches, + blocks, + num_blocks, + num_channels, + num_heads, + num_window_sizes, + num_mlp_ratios, + drop_path, + ) + self.fuse_layers = self._make_fuse_layers() + self.relu = nn.ReLU(inplace=True) + + self.num_heads = num_heads + self.num_window_sizes = num_window_sizes + self.num_mlp_ratios = num_mlp_ratios + + def _check_branches( + self, num_branches, blocks, num_blocks, num_inchannels, num_channels + ): + if num_branches != len(num_blocks): + error_msg = "NUM_BRANCHES({}) <> NUM_BLOCKS({})".format( + num_branches, len(num_blocks) + ) + Log.error(error_msg) + raise ValueError(error_msg) + + if num_branches != len(num_channels): + error_msg = "NUM_BRANCHES({}) <> NUM_CHANNELS({})".format( + num_branches, len(num_channels) + ) + Log.error(error_msg) + raise ValueError(error_msg) + + if num_branches != len(num_inchannels): + error_msg = "NUM_BRANCHES({}) <> NUM_INCHANNELS({})".format( + num_branches, len(num_inchannels) + ) + Log.error(error_msg) + raise ValueError(error_msg) + + def _make_one_branch( + self, + branch_index, + block, + num_blocks, + num_channels, + num_heads, + num_window_sizes, + num_mlp_ratios, + drop_paths, + stride=1, + ): + downsample = None + if ( + stride != 1 + or self.num_inchannels[branch_index] + != num_channels[branch_index] * block.expansion + ): + downsample = nn.Sequential( + nn.Conv2d( + self.num_inchannels[branch_index], + num_channels[branch_index] * block.expansion, + kernel_size=1, + stride=stride, + bias=False, + ), + nn.SyncBatchNorm( + num_channels[branch_index] * block.expansion, momentum=BN_MOMENTUM + ), + ) + + layers = [] + layers.append( + block( + self.num_inchannels[branch_index], + num_channels[branch_index], + num_heads=num_heads[branch_index], + window_size=num_window_sizes[branch_index], + mlp_ratio=num_mlp_ratios[branch_index], + drop_path=drop_paths[0], + ) + ) + + self.num_inchannels[branch_index] = num_channels[branch_index] * block.expansion + for i in range(1, num_blocks[branch_index]): + layers.append( + block( + self.num_inchannels[branch_index], + num_channels[branch_index], + num_heads=num_heads[branch_index], + window_size=num_window_sizes[branch_index], + mlp_ratio=num_mlp_ratios[branch_index], + drop_path=drop_paths[i], + ) + ) + return nn.Sequential(*layers) + + def _make_branches( + self, + num_branches, + block, + num_blocks, + num_channels, + num_heads, + num_window_sizes, + num_mlp_ratios, + drop_paths, + ): + branches = [] + + for i in range(num_branches): + branches.append( + self._make_one_branch( + i, + block, + num_blocks, + num_channels, + num_heads, + num_window_sizes, + num_mlp_ratios, + drop_paths=[_ * (2 ** i) for _ in drop_paths] + if os.environ.get("multi_res_drop_path", False) + else drop_paths, + ) + ) + + return nn.ModuleList(branches) + + def _make_fuse_layers(self): + if self.num_branches == 1: + return None + num_branches = self.num_branches + num_inchannels = self.num_inchannels + fuse_layers = [] + for i in range(num_branches if self.multi_scale_output else 1): + fuse_layer = [] + for j in range(num_branches): + if j > i: + fuse_layer.append( + nn.Sequential( + 
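+                            # j > i: branch j has lower resolution, so match channels with a 1x1 conv, then nearest-upsample by 2**(j - i)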
nn.Conv2d( + num_inchannels[j], + num_inchannels[i], + kernel_size=1, + stride=1, + bias=False, + ), + nn.SyncBatchNorm(num_inchannels[i], momentum=BN_MOMENTUM), + nn.Upsample(scale_factor=2 ** (j - i), mode="nearest"), + ) + ) + elif j == i: + fuse_layer.append(None) + else: + conv3x3s = [] + for k in range(i - j): + if k == i - j - 1: + num_outchannels_conv3x3 = num_inchannels[i] + conv3x3s.append( + nn.Sequential( + nn.Conv2d( + num_inchannels[j], + num_inchannels[j], + kernel_size=3, + stride=2, + padding=1, + groups=num_inchannels[j], + bias=False, + ), + nn.SyncBatchNorm( + num_inchannels[j], momentum=BN_MOMENTUM + ), + nn.Conv2d( + num_inchannels[j], + num_outchannels_conv3x3, + kernel_size=1, + stride=1, + bias=False, + ), + nn.SyncBatchNorm( + num_outchannels_conv3x3, momentum=BN_MOMENTUM + ), + ) + ) + else: + num_outchannels_conv3x3 = num_inchannels[j] + conv3x3s.append( + nn.Sequential( + nn.Conv2d( + num_inchannels[j], + num_inchannels[j], + kernel_size=3, + stride=2, + padding=1, + groups=num_inchannels[j], + bias=False, + ), + nn.SyncBatchNorm( + num_inchannels[j], momentum=BN_MOMENTUM + ), + nn.Conv2d( + num_inchannels[j], + num_outchannels_conv3x3, + kernel_size=1, + stride=1, + bias=False, + ), + nn.SyncBatchNorm( + num_outchannels_conv3x3, momentum=BN_MOMENTUM + ), + nn.ReLU(False), + ) + ) + fuse_layer.append(nn.Sequential(*conv3x3s)) + fuse_layers.append(nn.ModuleList(fuse_layer)) + + return nn.ModuleList(fuse_layers) + + def get_num_inchannels(self): + return self.num_inchannels + + def forward(self, x): + if self.num_branches == 1: + return [self.branches[0](x[0])] + + for i in range(self.num_branches): + x[i] = self.branches[i](x[i]) + + x_fuse = [] + for i in range(len(self.fuse_layers)): + y = x[0] if i == 0 else self.fuse_layers[i][0](x[0]) + for j in range(1, self.num_branches): + if i == j: + y = y + x[j] + elif j > i: + width_output = x[i].shape[-1] + height_output = x[i].shape[-2] + y = y + F.interpolate( + self.fuse_layers[i][j](x[j]), + size=[height_output, width_output], + mode="bilinear", + align_corners=True, + ) + else: + y = y + self.fuse_layers[i][j](x[j]) + x_fuse.append(self.relu(y)) + + return x_fuse + + +class HighResolutionTransformer(nn.Module): + def __init__(self, cfg, in_ch=3, **kwargs): + super(HighResolutionTransformer, self).__init__() + + self.conv1 = nn.Conv2d(in_ch, 64, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.SyncBatchNorm(64, momentum=BN_MOMENTUM) + self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False) + self.bn2 = nn.SyncBatchNorm(64, momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + + # stochastic depth + depth_s2 = cfg["STAGE2"]["NUM_BLOCKS"][0] * cfg["STAGE2"]["NUM_MODULES"] + depth_s3 = cfg["STAGE3"]["NUM_BLOCKS"][0] * cfg["STAGE3"]["NUM_MODULES"] + depth_s4 = cfg["STAGE4"]["NUM_BLOCKS"][0] * cfg["STAGE4"]["NUM_MODULES"] + depths = [depth_s2, depth_s3, depth_s4] + drop_path_rate = cfg["DROP_PATH_RATE"] + if os.environ.get("drop_path_rate") is not None: + drop_path_rate = float(os.environ.get("drop_path_rate")) + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] + + self.stage1_cfg = cfg["STAGE1"] + num_channels = self.stage1_cfg["NUM_CHANNELS"][0] + block = blocks_dict[self.stage1_cfg["BLOCK"]] + num_blocks = self.stage1_cfg["NUM_BLOCKS"][0] + self.layer1 = self._make_layer(block, 64, num_channels, num_blocks) + stage1_out_channel = block.expansion * num_channels + + self.stage2_cfg = cfg["STAGE2"] + num_channels = self.stage2_cfg["NUM_CHANNELS"] + 
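+        # stages 2-4 repeat the same pattern: scale NUM_CHANNELS by the block's expansion, build a transition from the previous stage's widths, then stack the multi-branch transformer modules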
block = blocks_dict[self.stage2_cfg["BLOCK"]] + num_channels = [ + num_channels[i] * block.expansion for i in range(len(num_channels)) + ] + self.transition1 = self._make_transition_layer( + [stage1_out_channel], num_channels + ) + self.stage2, pre_stage_channels = self._make_stage( + self.stage2_cfg, num_channels, drop_path=dpr[0:depth_s2] + ) + + self.stage3_cfg = cfg["STAGE3"] + num_channels = self.stage3_cfg["NUM_CHANNELS"] + block = blocks_dict[self.stage3_cfg["BLOCK"]] + num_channels = [ + num_channels[i] * block.expansion for i in range(len(num_channels)) + ] + self.transition2 = self._make_transition_layer(pre_stage_channels, num_channels) + self.stage3, pre_stage_channels = self._make_stage( + self.stage3_cfg, num_channels, drop_path=dpr[depth_s2 : depth_s2 + depth_s3] + ) + + self.stage4_cfg = cfg["STAGE4"] + num_channels = self.stage4_cfg["NUM_CHANNELS"] + block = blocks_dict[self.stage4_cfg["BLOCK"]] + num_channels = [ + num_channels[i] * block.expansion for i in range(len(num_channels)) + ] + self.transition3 = self._make_transition_layer(pre_stage_channels, num_channels) + self.stage4, pre_stage_channels = self._make_stage( + self.stage4_cfg, + num_channels, + multi_scale_output=True, + drop_path=dpr[depth_s2 + depth_s3 :], + ) + + if os.environ.get("keep_imagenet_head"): + ( + self.incre_modules, + self.downsamp_modules, + self.final_layer, + ) = self._make_head(pre_stage_channels) + + def _make_head(self, pre_stage_channels): + head_block = BottleneckDWP + head_channels = [32, 64, 128, 256] + + # Increasing the #channels on each resolution + # from C, 2C, 4C, 8C to 128, 256, 512, 1024 + incre_modules = [] + for i, channels in enumerate(pre_stage_channels): + incre_module = self._make_layer( + head_block, channels, head_channels[i], 1, stride=1 + ) + incre_modules.append(incre_module) + incre_modules = nn.ModuleList(incre_modules) + + # downsampling modules + downsamp_modules = [] + for i in range(len(pre_stage_channels) - 1): + in_channels = head_channels[i] * head_block.expansion + out_channels = head_channels[i + 1] * head_block.expansion + downsamp_module = nn.Sequential( + nn.Conv2d( + in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=1, + groups=in_channels, + ), + nn.SyncBatchNorm(in_channels, momentum=BN_MOMENTUM), + nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1), + nn.SyncBatchNorm(out_channels, momentum=BN_MOMENTUM), + nn.ReLU(inplace=True), + ) + downsamp_modules.append(downsamp_module) + downsamp_modules = nn.ModuleList(downsamp_modules) + + final_layer = nn.Sequential( + nn.Conv2d( + in_channels=head_channels[3] * head_block.expansion, + out_channels=2048, + kernel_size=1, + stride=1, + padding=0, + ), + nn.SyncBatchNorm(2048, momentum=BN_MOMENTUM), + nn.ReLU(inplace=True), + ) + + return incre_modules, downsamp_modules, final_layer + + def _make_transition_layer(self, num_channels_pre_layer, num_channels_cur_layer): + num_branches_cur = len(num_channels_cur_layer) + num_branches_pre = len(num_channels_pre_layer) + + transition_layers = [] + for i in range(num_branches_cur): + if i < num_branches_pre: + if num_channels_cur_layer[i] != num_channels_pre_layer[i]: + transition_layers.append( + nn.Sequential( + nn.Conv2d( + num_channels_pre_layer[i], + num_channels_cur_layer[i], + 3, + 1, + 1, + bias=False, + ), + nn.SyncBatchNorm( + num_channels_cur_layer[i], momentum=BN_MOMENTUM + ), + nn.ReLU(inplace=True), + ) + ) + else: + transition_layers.append(None) + else: + conv3x3s = [] + for j in range(i + 1 - num_branches_pre): + 
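+                    # each extra branch is built by repeatedly stride-2 downsampling the last (lowest-resolution) output of the previous stage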
inchannels = num_channels_pre_layer[-1] + outchannels = ( + num_channels_cur_layer[i] + if j == i - num_branches_pre + else inchannels + ) + conv3x3s.append( + nn.Sequential( + nn.Conv2d(inchannels, outchannels, 3, 2, 1, bias=False), + nn.SyncBatchNorm(outchannels, momentum=BN_MOMENTUM), + nn.ReLU(inplace=True), + ) + ) + transition_layers.append(nn.Sequential(*conv3x3s)) + + return nn.ModuleList(transition_layers) + + def _make_layer( + self, + block, + inplanes, + planes, + blocks, + num_heads=1, + stride=1, + window_size=7, + mlp_ratio=4.0, + ): + downsample = None + if stride != 1 or inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d( + inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False, + ), + nn.SyncBatchNorm(planes * block.expansion, momentum=BN_MOMENTUM), + ) + layers = [] + + if isinstance(block, GeneralTransformerBlock): + layers.append( + block( + inplanes, + planes, + num_heads, + window_size, + mlp_ratio, + ) + ) + else: + layers.append(block(inplanes, planes, stride, downsample)) + + inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(inplanes, planes)) + + return nn.Sequential(*layers) + + def _make_stage( + self, layer_config, num_inchannels, multi_scale_output=True, drop_path=0.0 + ): + num_modules = layer_config["NUM_MODULES"] + num_branches = layer_config["NUM_BRANCHES"] + num_blocks = layer_config["NUM_BLOCKS"] + num_channels = layer_config["NUM_CHANNELS"] + block = blocks_dict[layer_config["BLOCK"]] + num_heads = layer_config["NUM_HEADS"] + num_window_sizes = layer_config["NUM_WINDOW_SIZES"] + num_mlp_ratios = layer_config["NUM_MLP_RATIOS"] + + modules = [] + for i in range(num_modules): + # multi_scale_output is only used last module + if not multi_scale_output and i == num_modules - 1: + reset_multi_scale_output = False + else: + reset_multi_scale_output = True + + modules.append( + HighResolutionTransformerModule( + num_branches, + block, + num_blocks, + num_inchannels, + num_channels, + num_heads, + num_window_sizes, + num_mlp_ratios, + reset_multi_scale_output, + drop_path=drop_path[num_blocks[0] * i : num_blocks[0] * (i + 1)], + ) + ) + num_inchannels = modules[-1].get_num_inchannels() + + return nn.Sequential(*modules), num_inchannels + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.conv2(x) + x = self.bn2(x) + x = self.relu(x) + x = self.layer1(x) + + x_list = [] + for i in range(self.stage2_cfg["NUM_BRANCHES"]): + if self.transition1[i] is not None: + x_list.append(self.transition1[i](x)) + else: + x_list.append(x) + y_list = self.stage2(x_list) + + x_list = [] + for i in range(self.stage3_cfg["NUM_BRANCHES"]): + if self.transition2[i] is not None: + x_list.append(self.transition2[i](y_list[-1])) + else: + x_list.append(y_list[i]) + y_list = self.stage3(x_list) + + x_list = [] + for i in range(self.stage4_cfg["NUM_BRANCHES"]): + if self.transition3[i] is not None: + x_list.append(self.transition3[i](y_list[-1])) + else: + x_list.append(y_list[i]) + y_list = self.stage4(x_list) + + if os.environ.get("keep_imagenet_head"): + x_list = [] + y = self.incre_modules[0](y_list[0]) + x_list.append(y) + for i in range(len(self.downsamp_modules)): + y = self.incre_modules[i + 1](y_list[i + 1]) + self.downsamp_modules[i]( + y + ) + x_list.append(y) + + y = self.final_layer(y) + del x_list[-1] + x_list.append(y) + return x_list + + else: + return y_list + + +class HRTBackbone(object): + def __init__(self, configer): + self.configer 
= configer + + def __call__(self): + arch = self.configer.get("network", "backbone") + from .hrt_config import MODEL_CONFIGS + + if arch in [ + "hrt_small", + "hrt_base", + "hrt_base_win13", + "hrt_base_win15", + ]: + arch_net = HighResolutionTransformer(MODEL_CONFIGS[arch]) + arch_net = ModuleHelper.load_model( + arch_net, + pretrained=self.configer.get("network", "pretrained"), + all_match=False, + network="hrt_window" if "win" in arch else "hrt", + ) + + else: + raise Exception("Architecture undefined!") + + return arch_net + + +class HRTBackbone_v2(object): + def __init__(self, backbone='hrt_small', pretrained=None, in_ch=3): + self.backbone = backbone + self.pretrained = pretrained + self.in_ch = in_ch + + def __call__(self): + from .hrt_config import MODEL_CONFIGS + if self.backbone in [ + "hrt_small", + "hrt_base", + "hrt_base_win13", + "hrt_base_win15", + ]: + arch_net = HighResolutionTransformer(MODEL_CONFIGS[self.backbone], in_ch=self.in_ch) + arch_net = ModuleHelper.load_model( + arch_net, + pretrained=self.pretrained, + all_match=False, + network="hrt_window" if "win" in self.backbone else "hrt", + ) + + else: + raise Exception("ARCHITECTURE UNDEFINED!") + + return arch_net \ No newline at end of file diff --git a/isegm/model/modeling/hrformer_helper/hrt/hrt_config.py b/isegm/model/modeling/hrformer_helper/hrt/hrt_config.py new file mode 100644 index 0000000000000000000000000000000000000000..8300beaae80eaff73c430333ff3c85877f48b14f --- /dev/null +++ b/isegm/model/modeling/hrformer_helper/hrt/hrt_config.py @@ -0,0 +1,123 @@ +# ------------------------------------------------------------------------------ +# Copyright (c) Microsoft +# Licensed under the MIT License. +# Rainbowsecret (yuyua@microsoft.com) +# ------------------------------------------------------------------------------ + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from yacs.config import CfgNode as CN + +# configs for HRT_SMALL +HRT_SMALL = CN() +HRT_SMALL.DROP_PATH_RATE = 0.2 + +HRT_SMALL.STAGE1 = CN() +HRT_SMALL.STAGE1.NUM_MODULES = 1 +HRT_SMALL.STAGE1.NUM_BRANCHES = 1 +HRT_SMALL.STAGE1.NUM_BLOCKS = [2] +HRT_SMALL.STAGE1.NUM_CHANNELS = [64] +HRT_SMALL.STAGE1.NUM_HEADS = [2] +HRT_SMALL.STAGE1.NUM_MLP_RATIOS = [4] +HRT_SMALL.STAGE1.NUM_RESOLUTIONS = [[56, 56]] +HRT_SMALL.STAGE1.BLOCK = "BOTTLENECK" + +HRT_SMALL.STAGE2 = CN() +HRT_SMALL.STAGE2.NUM_MODULES = 1 +HRT_SMALL.STAGE2.NUM_BRANCHES = 2 +HRT_SMALL.STAGE2.NUM_BLOCKS = [2, 2] +HRT_SMALL.STAGE2.NUM_CHANNELS = [32, 64] +HRT_SMALL.STAGE2.NUM_HEADS = [1, 2] +HRT_SMALL.STAGE2.NUM_MLP_RATIOS = [4, 4] +HRT_SMALL.STAGE2.NUM_RESOLUTIONS = [[56, 56], [28, 28]] +HRT_SMALL.STAGE2.NUM_WINDOW_SIZES = [7, 7] +HRT_SMALL.STAGE2.BLOCK = "TRANSFORMER_BLOCK" + +HRT_SMALL.STAGE3 = CN() +HRT_SMALL.STAGE3.NUM_MODULES = 4 +HRT_SMALL.STAGE3.NUM_BRANCHES = 3 +HRT_SMALL.STAGE3.NUM_BLOCKS = [2, 2, 2] +HRT_SMALL.STAGE3.NUM_CHANNELS = [32, 64, 128] +HRT_SMALL.STAGE3.NUM_HEADS = [1, 2, 4] +HRT_SMALL.STAGE3.NUM_MLP_RATIOS = [4, 4, 4] +HRT_SMALL.STAGE3.NUM_RESOLUTIONS = [[56, 56], [28, 28], [14, 14]] +HRT_SMALL.STAGE3.NUM_WINDOW_SIZES = [7, 7, 7] +HRT_SMALL.STAGE3.BLOCK = "TRANSFORMER_BLOCK" + +HRT_SMALL.STAGE4 = CN() +HRT_SMALL.STAGE4.NUM_MODULES = 2 +HRT_SMALL.STAGE4.NUM_BRANCHES = 4 +HRT_SMALL.STAGE4.NUM_BLOCKS = [2, 2, 2, 2] +HRT_SMALL.STAGE4.NUM_CHANNELS = [32, 64, 128, 256] +HRT_SMALL.STAGE4.NUM_HEADS = [1, 2, 4, 8] +HRT_SMALL.STAGE4.NUM_MLP_RATIOS = [4, 4, 4, 4] +HRT_SMALL.STAGE4.NUM_RESOLUTIONS = [[56, 56], [28, 
28], [14, 14], [7, 7]] +HRT_SMALL.STAGE4.NUM_WINDOW_SIZES = [7, 7, 7, 7] +HRT_SMALL.STAGE4.BLOCK = "TRANSFORMER_BLOCK" + +# configs for HRT_BASE +HRT_BASE = CN() +HRT_BASE.DROP_PATH_RATE = 0.2 + +HRT_BASE.STAGE1 = CN() +HRT_BASE.STAGE1.NUM_MODULES = 1 +HRT_BASE.STAGE1.NUM_BRANCHES = 1 +HRT_BASE.STAGE1.NUM_BLOCKS = [2] +HRT_BASE.STAGE1.NUM_CHANNELS = [64] +HRT_BASE.STAGE1.NUM_HEADS = [2] +HRT_BASE.STAGE1.NUM_MLP_RATIOS = [4] +HRT_BASE.STAGE1.NUM_RESOLUTIONS = [[56, 56]] +HRT_BASE.STAGE1.BLOCK = "BOTTLENECK" + +HRT_BASE.STAGE2 = CN() +HRT_BASE.STAGE2.NUM_MODULES = 1 +HRT_BASE.STAGE2.NUM_BRANCHES = 2 +HRT_BASE.STAGE2.NUM_BLOCKS = [2, 2] +HRT_BASE.STAGE2.NUM_CHANNELS = [78, 156] +HRT_BASE.STAGE2.NUM_HEADS = [2, 4] +HRT_BASE.STAGE2.NUM_MLP_RATIOS = [4, 4] +HRT_BASE.STAGE2.NUM_RESOLUTIONS = [[56, 56], [28, 28]] +HRT_BASE.STAGE2.NUM_WINDOW_SIZES = [7, 7] +HRT_BASE.STAGE2.BLOCK = "TRANSFORMER_BLOCK" + +HRT_BASE.STAGE3 = CN() +HRT_BASE.STAGE3.NUM_MODULES = 4 +HRT_BASE.STAGE3.NUM_BRANCHES = 3 +HRT_BASE.STAGE3.NUM_BLOCKS = [2, 2, 2] +HRT_BASE.STAGE3.NUM_CHANNELS = [78, 156, 312] +HRT_BASE.STAGE3.NUM_HEADS = [2, 4, 8] +HRT_BASE.STAGE3.NUM_MLP_RATIOS = [4, 4, 4] +HRT_BASE.STAGE3.NUM_RESOLUTIONS = [[56, 56], [28, 28], [14, 14]] +HRT_BASE.STAGE3.NUM_WINDOW_SIZES = [7, 7, 7] +HRT_BASE.STAGE3.BLOCK = "TRANSFORMER_BLOCK" + +HRT_BASE.STAGE4 = CN() +HRT_BASE.STAGE4.NUM_MODULES = 2 +HRT_BASE.STAGE4.NUM_BRANCHES = 4 +HRT_BASE.STAGE4.NUM_BLOCKS = [2, 2, 2, 2] +HRT_BASE.STAGE4.NUM_CHANNELS = [78, 156, 312, 624] +HRT_BASE.STAGE4.NUM_HEADS = [2, 4, 8, 16] +HRT_BASE.STAGE4.NUM_MLP_RATIOS = [4, 4, 4, 4] +HRT_BASE.STAGE4.NUM_RESOLUTIONS = [[56, 56], [28, 28], [14, 14], [7, 7]] +HRT_BASE.STAGE4.NUM_WINDOW_SIZES = [7, 7, 7, 7] +HRT_BASE.STAGE4.BLOCK = "TRANSFORMER_BLOCK" + +HRT_BASE_WIN_13 = HRT_BASE.clone() +HRT_BASE_WIN_13.STAGE2.NUM_WINDOW_SIZES = [13, 13] +HRT_BASE_WIN_13.STAGE3.NUM_WINDOW_SIZES = [13, 13, 13] +HRT_BASE_WIN_13.STAGE4.NUM_WINDOW_SIZES = [13, 13, 13, 13] + + +HRT_BASE_WIN_15 = HRT_BASE.clone() +HRT_BASE_WIN_15.STAGE2.NUM_WINDOW_SIZES = [15, 15] +HRT_BASE_WIN_15.STAGE3.NUM_WINDOW_SIZES = [15, 15, 15] +HRT_BASE_WIN_15.STAGE4.NUM_WINDOW_SIZES = [15, 15, 15, 15] + +MODEL_CONFIGS = { + "hrt_small": HRT_SMALL, + "hrt_base": HRT_BASE, + "hrt_base_win13": HRT_BASE_WIN_13, + "hrt_base_win15": HRT_BASE_WIN_15, +} \ No newline at end of file diff --git a/isegm/model/modeling/hrformer_helper/hrt/logger.py b/isegm/model/modeling/hrformer_helper/hrt/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..e5bdf642a64fe73f027898b0ebd6fd02729a617d --- /dev/null +++ b/isegm/model/modeling/hrformer_helper/hrt/logger.py @@ -0,0 +1,205 @@ +#!/usr/bin/env python +# -*- coding:utf-8 -*- +# Author: Donny You(youansheng@gmail.com) +# Logging tool implemented with the python Package logging. + + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import logging +import os +import sys + + +DEFAULT_LOGFILE_LEVEL = 'debug' +DEFAULT_STDOUT_LEVEL = 'info' +DEFAULT_LOG_FILE = './default.log' +DEFAULT_LOG_FORMAT = '%(asctime)s %(levelname)-7s %(message)s' + +LOG_LEVEL_DICT = { + 'debug': logging.DEBUG, + 'info': logging.INFO, + 'warning': logging.WARNING, + 'error': logging.ERROR, + 'critical': logging.CRITICAL +} + + +class Logger(object): + """ + Args: + Log level: CRITICAL>ERROR>WARNING>INFO>DEBUG. + Log file: The file that stores the logging info. + rewrite: Clear the log file. + log format: The format of log messages. 
+ stdout level: The log level to print on the screen. + """ + logfile_level = None + log_file = None + log_format = None + rewrite = None + stdout_level = None + logger = None + + _caches = {} + + @staticmethod + def init(logfile_level=DEFAULT_LOGFILE_LEVEL, + log_file=DEFAULT_LOG_FILE, + log_format=DEFAULT_LOG_FORMAT, + rewrite=False, + stdout_level=None): + Logger.logfile_level = logfile_level + Logger.log_file = log_file + Logger.log_format = log_format + Logger.rewrite = rewrite + Logger.stdout_level = stdout_level + + Logger.logger = logging.getLogger() + Logger.logger.handlers = [] + fmt = logging.Formatter(Logger.log_format) + + if Logger.logfile_level is not None: + filemode = 'w' + if not Logger.rewrite: + filemode = 'a' + + dir_name = os.path.dirname(os.path.abspath(Logger.log_file)) + if not os.path.exists(dir_name): + os.makedirs(dir_name) + + if Logger.logfile_level not in LOG_LEVEL_DICT: + print('Invalid logging level: {}'.format(Logger.logfile_level)) + Logger.logfile_level = DEFAULT_LOGFILE_LEVEL + + Logger.logger.setLevel(LOG_LEVEL_DICT[Logger.logfile_level]) + + fh = logging.FileHandler(Logger.log_file, mode=filemode) + fh.setFormatter(fmt) + fh.setLevel(LOG_LEVEL_DICT[Logger.logfile_level]) + + Logger.logger.addHandler(fh) + + if stdout_level is not None: + if Logger.logfile_level is None: + Logger.logger.setLevel(LOG_LEVEL_DICT[Logger.stdout_level]) + + console = logging.StreamHandler() + if Logger.stdout_level not in LOG_LEVEL_DICT: + print('Invalid logging level: {}'.format(Logger.stdout_level)) + return + + console.setLevel(LOG_LEVEL_DICT[Logger.stdout_level]) + console.setFormatter(fmt) + Logger.logger.addHandler(console) + + @staticmethod + def set_log_file(file_path): + Logger.log_file = file_path + Logger.init(log_file=file_path) + + @staticmethod + def set_logfile_level(log_level): + if log_level not in LOG_LEVEL_DICT: + print('Invalid logging level: {}'.format(log_level)) + return + + Logger.init(logfile_level=log_level) + + @staticmethod + def clear_log_file(): + Logger.rewrite = True + Logger.init(rewrite=True) + + @staticmethod + def check_logger(): + if Logger.logger is None: + Logger.init(logfile_level=None, stdout_level=DEFAULT_STDOUT_LEVEL) + + @staticmethod + def set_stdout_level(log_level): + if log_level not in LOG_LEVEL_DICT: + print('Invalid logging level: {}'.format(log_level)) + return + + Logger.init(stdout_level=log_level) + + @staticmethod + def debug(message): + Logger.check_logger() + filename = os.path.basename(sys._getframe().f_back.f_code.co_filename) + lineno = sys._getframe().f_back.f_lineno + prefix = '[{}, {}]'.format(filename,lineno) + Logger.logger.debug('{} {}'.format(prefix, message)) + + @staticmethod + def info(message): + Logger.check_logger() + filename = os.path.basename(sys._getframe().f_back.f_code.co_filename) + lineno = sys._getframe().f_back.f_lineno + prefix = '[{}, {}]'.format(filename,lineno) + Logger.logger.info('{} {}'.format(prefix, message)) + + @staticmethod + def info_once(message): + Logger.check_logger() + filename = os.path.basename(sys._getframe().f_back.f_code.co_filename) + lineno = sys._getframe().f_back.f_lineno + prefix = '[{}, {}]'.format(filename, lineno) + + if Logger._caches.get((prefix, message)) is not None: + return + + Logger.logger.info('{} {}'.format(prefix, message)) + Logger._caches[(prefix, message)] = True + + @staticmethod + def warn(message): + Logger.check_logger() + filename = os.path.basename(sys._getframe().f_back.f_code.co_filename) + lineno = sys._getframe().f_back.f_lineno + prefix 
= '[{}, {}]'.format(filename,lineno) + Logger.logger.warn('{} {}'.format(prefix, message)) + + @staticmethod + def error(message): + Logger.check_logger() + filename = os.path.basename(sys._getframe().f_back.f_code.co_filename) + lineno = sys._getframe().f_back.f_lineno + prefix = '[{}, {}]'.format(filename,lineno) + Logger.logger.error('{} {}'.format(prefix, message)) + + @staticmethod + def critical(message): + Logger.check_logger() + filename = os.path.basename(sys._getframe().f_back.f_code.co_filename) + lineno = sys._getframe().f_back.f_lineno + prefix = '[{}, {}]'.format(filename,lineno) + Logger.logger.critical('{} {}'.format(prefix, message)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--logfile_level', default="debug", type=str, + dest='logfile_level', help='To set the log level to files.') + parser.add_argument('--stdout_level', default=None, type=str, + dest='stdout_level', help='To set the level to print to screen.') + parser.add_argument('--log_file', default="./default.log", type=str, + dest='log_file', help='The path of log files.') + parser.add_argument('--log_format', default="%(asctime)s %(levelname)-7s %(message)s", + type=str, dest='log_format', help='The format of log messages.') + parser.add_argument('--rewrite', default=False, type=bool, + dest='rewrite', help='Clear the log files existed.') + + args = parser.parse_args() + Logger.init(logfile_level=args.logfile_level, stdout_level=args.stdout_level, + log_file=args.log_file, log_format=args.log_format, rewrite=args.rewrite) + + Logger.info("info test.") + Logger.debug("debug test.") + Logger.warn("warn test.") + Logger.error("error test.") + Logger.debug("debug test.") \ No newline at end of file diff --git a/isegm/model/modeling/hrformer_helper/hrt/module_helper.py b/isegm/model/modeling/hrformer_helper/hrt/module_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..4f8460ab3ce2547d634f6dca295ec3865330a165 --- /dev/null +++ b/isegm/model/modeling/hrformer_helper/hrt/module_helper.py @@ -0,0 +1,310 @@ +#!/usr/bin/env python +# -*- coding:utf-8 -*- +# Author: Donny You (youansheng@gmail.com) + + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools +import os +import pdb +import math + +import torch +import torch.nn as nn + +try: + from urllib import urlretrieve +except ImportError: + from urllib.request import urlretrieve + +from .logger import Logger as Log + + +class ModuleHelper(object): + @staticmethod + def BNReLU(num_features, bn_type=None, **kwargs): + if bn_type == "torchbn": + return nn.Sequential(nn.BatchNorm2d(num_features, **kwargs), nn.ReLU()) + elif bn_type == "torchsyncbn": + return nn.Sequential(nn.SyncBatchNorm(num_features, **kwargs), nn.ReLU()) + elif bn_type == "syncbn": + from lib.extensions.syncbn.module import BatchNorm2d + + return nn.Sequential(BatchNorm2d(num_features, **kwargs), nn.ReLU()) + elif bn_type == "sn": + from lib.extensions.switchablenorms.switchable_norm import SwitchNorm2d + + return nn.Sequential(SwitchNorm2d(num_features, **kwargs), nn.ReLU()) + elif bn_type == "gn": + return nn.Sequential( + nn.GroupNorm(num_groups=8, num_channels=num_features, **kwargs), + nn.ReLU(), + ) + elif bn_type == "fn": + Log.error("Not support Filter-Response-Normalization: {}.".format(bn_type)) + exit(1) + elif bn_type == "inplace_abn": + torch_ver = torch.__version__[:3] + # Log.info('Pytorch Version: {}'.format(torch_ver)) + if torch_ver == 
"0.4": + from lib.extensions.inplace_abn.bn import InPlaceABNSync + + return InPlaceABNSync(num_features, **kwargs) + elif torch_ver in ("1.0", "1.1"): + from lib.extensions.inplace_abn_1.bn import InPlaceABNSync + + return InPlaceABNSync(num_features, **kwargs) + elif torch_ver == "1.2": + from inplace_abn import InPlaceABNSync + + return InPlaceABNSync(num_features, **kwargs) + + else: + Log.error("Not support BN type: {}.".format(bn_type)) + exit(1) + + @staticmethod + def BatchNorm2d(bn_type="torch", ret_cls=False): + if bn_type == "torchbn": + return nn.BatchNorm2d + + elif bn_type == "torchsyncbn": + return nn.SyncBatchNorm + + elif bn_type == "syncbn": + from lib.extensions.syncbn.module import BatchNorm2d + + return BatchNorm2d + + elif bn_type == "sn": + from lib.extensions.switchablenorms.switchable_norm import SwitchNorm2d + + return SwitchNorm2d + + elif bn_type == "gn": + return functools.partial(nn.GroupNorm, num_groups=32) + + elif bn_type == "inplace_abn": + torch_ver = torch.__version__[:3] + if torch_ver == "0.4": + from lib.extensions.inplace_abn.bn import InPlaceABNSync + + if ret_cls: + return InPlaceABNSync + return functools.partial(InPlaceABNSync, activation="none") + + elif torch_ver in ("1.0", "1.1"): + from lib.extensions.inplace_abn_1.bn import InPlaceABNSync + + if ret_cls: + return InPlaceABNSync + return functools.partial(InPlaceABNSync, activation="none") + + elif torch_ver == "1.2": + from inplace_abn import InPlaceABNSync + + if ret_cls: + return InPlaceABNSync + return functools.partial(InPlaceABNSync, activation="identity") + + else: + Log.error("Not support BN type: {}.".format(bn_type)) + exit(1) + + @staticmethod + def load_model(model, pretrained=None, all_match=True, network="resnet101"): + if pretrained is None: + return model + + if all_match: + Log.info("Loading pretrained model:{}".format(pretrained)) + pretrained_dict = torch.load(pretrained) + model_dict = model.state_dict() + load_dict = dict() + for k, v in pretrained_dict.items(): + if "resinit.{}".format(k) in model_dict: + load_dict["resinit.{}".format(k)] = v + else: + load_dict[k] = v + model.load_state_dict(load_dict) + + else: + Log.info("Loading pretrained model:{}".format(pretrained)) + pretrained_dict = torch.load(pretrained) + + # settings for "wide_resnet38" or network == "resnet152" + if network == "wide_resnet": + pretrained_dict = pretrained_dict["state_dict"] + + model_dict = model.state_dict() + + if network == "hrnet_plus": + # pretrained_dict['conv1_full_res.weight'] = pretrained_dict['conv1.weight'] + # pretrained_dict['conv2_full_res.weight'] = pretrained_dict['conv2.weight'] + load_dict = { + k: v for k, v in pretrained_dict.items() if k in model_dict.keys() + } + + elif network == "hrt_window": + pretrained_dict = pretrained_dict["model"] + for name, m in model.named_parameters(): + if "relative_position_bias_table" in name and "embed" not in name: + target_size = int(math.sqrt(m.shape[0])) + head_num = m.shape[-1] + ckpt_size = int(math.sqrt(pretrained_dict[name].shape[0])) + if target_size != ckpt_size: + Log.info( + f"Interpolate from size {pretrained_dict[name ].shape} to {m.shape}." 
+ ) + reshape_ckpt = ( + pretrained_dict[name] + .permute(1, 0) + .reshape(1, head_num, ckpt_size, ckpt_size) + ) + inter_ckpt = ( + torch.nn.functional.interpolate( + reshape_ckpt, + size=(target_size, target_size), + mode="bilinear", + ) + .reshape(head_num, -1) + .permute(1, 0) + ) + scale = 1 + inter_ckpt *= scale + pretrained_dict[name] = inter_ckpt + for name, m in list(pretrained_dict.items()): + if "relative_position_index" in name: + Log.info(f"Remove {name}.") + pretrained_dict.pop(name) + load_dict = { + k: v for k, v in pretrained_dict.items() if k in model_dict.keys() + } + Log.info( + "Missing keys: {}".format(list(set(model_dict) - set(load_dict))) + ) + + elif network == "hrt": + pretrained_dict = pretrained_dict["model"] + load_dict = { + k: v for k, v in pretrained_dict.items() if k in model_dict.keys() + } + Log.info( + "Missing keys: {}".format(list(set(model_dict) - set(load_dict))) + ) + + elif network == "swin": + pretrained_dict = pretrained_dict["model"] + # TODO fix the mis-match between the dict keys and the checkpoint keys. + pretrained_dict = { + k.replace(".attn.", ".attn.attn."): v + for k, v in pretrained_dict.items() + } + load_dict = { + k: v for k, v in pretrained_dict.items() if k in model_dict.keys() + } + Log.info( + "Missing keys: {}".format(list(set(model_dict) - set(load_dict))) + ) + + elif network == "hrnet" or network == "xception" or network == "resnest": + load_dict = { + k: v for k, v in pretrained_dict.items() if k in model_dict.keys() + } + Log.info( + "Missing keys: {}".format(list(set(model_dict) - set(load_dict))) + ) + + elif network == "dcnet" or network == "resnext": + load_dict = dict() + for k, v in pretrained_dict.items(): + if "resinit.{}".format(k) in model_dict: + load_dict["resinit.{}".format(k)] = v + else: + if k in model_dict: + load_dict[k] = v + else: + pass + + elif network == "wide_resnet": + load_dict = { + ".".join(k.split(".")[1:]): v + for k, v in pretrained_dict.items() + if ".".join(k.split(".")[1:]) in model_dict + } + else: + load_dict = { + ".".join(k.split(".")[1:]): v + for k, v in pretrained_dict.items() + if ".".join(k.split(".")[1:]) in model_dict + } + + # used to debug + if int(os.environ.get("debug_load_model", 0)): + Log.info("Matched Keys List:") + for key in load_dict.keys(): + Log.info("{}".format(key)) + model_dict.update(load_dict) + model.load_state_dict(model_dict) + + return model + + @staticmethod + def load_url(url, map_location=None): + model_dir = os.path.join("~", ".PyTorchCV", "models") + if not os.path.exists(model_dir): + os.makedirs(model_dir) + + filename = url.split("/")[-1] + cached_file = os.path.join(model_dir, filename) + if not os.path.exists(cached_file): + Log.info('Downloading: "{}" to {}\n'.format(url, cached_file)) + urlretrieve(url, cached_file) + + Log.info("Loading pretrained model:{}".format(cached_file)) + return torch.load(cached_file, map_location=map_location) + + @staticmethod + def constant_init(module, val, bias=0): + nn.init.constant_(module.weight, val) + if hasattr(module, "bias") and module.bias is not None: + nn.init.constant_(module.bias, bias) + + @staticmethod + def xavier_init(module, gain=1, bias=0, distribution="normal"): + assert distribution in ["uniform", "normal"] + if distribution == "uniform": + nn.init.xavier_uniform_(module.weight, gain=gain) + else: + nn.init.xavier_normal_(module.weight, gain=gain) + if hasattr(module, "bias") and module.bias is not None: + nn.init.constant_(module.bias, bias) + + @staticmethod + def normal_init(module, mean=0, 
std=1, bias=0): + nn.init.normal_(module.weight, mean, std) + if hasattr(module, "bias") and module.bias is not None: + nn.init.constant_(module.bias, bias) + + @staticmethod + def uniform_init(module, a=0, b=1, bias=0): + nn.init.uniform_(module.weight, a, b) + if hasattr(module, "bias") and module.bias is not None: + nn.init.constant_(module.bias, bias) + + @staticmethod + def kaiming_init( + module, mode="fan_in", nonlinearity="leaky_relu", bias=0, distribution="normal" + ): + assert distribution in ["uniform", "normal"] + if distribution == "uniform": + nn.init.kaiming_uniform_( + module.weight, mode=mode, nonlinearity=nonlinearity + ) + else: + nn.init.kaiming_normal_(module.weight, mode=mode, nonlinearity=nonlinearity) + if hasattr(module, "bias") and module.bias is not None: + nn.init.constant_(module.bias, bias) \ No newline at end of file diff --git a/isegm/model/modeling/hrformer_helper/hrt/modules/__init__.py b/isegm/model/modeling/hrformer_helper/hrt/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/isegm/model/modeling/hrformer_helper/hrt/modules/bottleneck_block.py b/isegm/model/modeling/hrformer_helper/hrt/modules/bottleneck_block.py new file mode 100644 index 0000000000000000000000000000000000000000..1a22ceafdaf7193bde537ae36edf47f77db868fa --- /dev/null +++ b/isegm/model/modeling/hrformer_helper/hrt/modules/bottleneck_block.py @@ -0,0 +1,128 @@ +import os +import pdb +import logging +import torch.nn as nn +import torch.nn.functional as F +# from torchvision.models.utils import load_state_dict_from_url +# from timm.models.registry import register_model +from functools import partial + +BN_MOMENTUM = 0.1 + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__( + self, + inplanes, + planes, + stride=1, + downsample=None, + mhsa_flag=False, + num_heads=1, + num_halo_block=1, + num_mlp_ratio=4, + num_sr_ratio=1, + num_resolution=None, + with_rpe=False, + with_ffn=True, + ): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.SyncBatchNorm(planes) + self.conv2 = nn.Conv2d( + planes, planes, kernel_size=3, stride=stride, padding=1, bias=False + ) + self.bn2 = nn.SyncBatchNorm(planes) + self.conv3 = nn.Conv2d( + planes, planes * self.expansion, kernel_size=1, bias=False + ) + self.bn3 = nn.SyncBatchNorm(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class BottleneckDWP(nn.Module): + expansion = 4 + + def __init__( + self, + inplanes, + planes, + stride=1, + downsample=None, + mhsa_flag=False, + num_heads=1, + num_halo_block=1, + num_mlp_ratio=4, + num_sr_ratio=1, + num_resolution=None, + with_rpe=False, + with_ffn=True, + ): + super(BottleneckDWP, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.SyncBatchNorm(planes, momentum=BN_MOMENTUM) + self.conv2 = nn.Conv2d( + planes, + planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False, + groups=planes, + ) + self.bn2 = nn.SyncBatchNorm(planes, momentum=BN_MOMENTUM) + self.conv3 
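The weight-initialization helpers above (constant/xavier/normal/uniform/kaiming) wrap torch.nn.init with shared bias handling; a minimal sketch applying them through Module.apply (the tiny network is illustrative):

    import torch.nn as nn
    from isegm.model.modeling.hrformer_helper.hrt.module_helper import ModuleHelper

    def init_weights(m):
        # Kaiming for convolutions, constants for norm layers.
        if isinstance(m, nn.Conv2d):
            ModuleHelper.kaiming_init(m, mode="fan_out", nonlinearity="relu")
        elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
            ModuleHelper.constant_init(m, 1)

    net = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.BatchNorm2d(16), nn.ReLU())
    net.apply(init_weights)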
= nn.Conv2d( + planes, planes * self.expansion, kernel_size=1, bias=False + ) + self.bn3 = nn.SyncBatchNorm(planes * self.expansion, momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out \ No newline at end of file diff --git a/isegm/model/modeling/hrformer_helper/hrt/modules/ffn_block.py b/isegm/model/modeling/hrformer_helper/hrt/modules/ffn_block.py new file mode 100644 index 0000000000000000000000000000000000000000..b55a9ceb88de82f96d3d5c20781ab0bcf41e5b95 --- /dev/null +++ b/isegm/model/modeling/hrformer_helper/hrt/modules/ffn_block.py @@ -0,0 +1,287 @@ +import pdb +import torch +import torch.nn as nn + + +class Mlp(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.0, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x, H, W): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class MlpLight(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.0, + ): + super().__init__() + self.fc1 = nn.Linear(in_features, in_features) + self.act = act_layer() + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + return x + + +class MlpDW(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + dw_act_layer=nn.GELU, + drop=0.0, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1) + self.act1 = act_layer() + self.dw3x3 = nn.Conv2d( + hidden_features, + hidden_features, + kernel_size=3, + stride=1, + groups=hidden_features, + padding=1, + ) + self.act2 = dw_act_layer() + self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1) + self.drop = nn.Dropout(drop) + + def forward(self, x, H, W): + B, N, C = x.shape + + if N == (H * W + 1): + cls_tokens = x[:, 0, :] + x_ = x[:, 1:, :].permute(0, 2, 1).reshape(B, C, H, W) + else: + x_ = x.permute(0, 2, 1).reshape(B, C, H, W) + + x_ = self.fc1(x_) + x_ = self.act1(x_) + x_ = self.dw3x3(x_) + x_ = self.act2(x_) + x_ = self.drop(x_) + x_ = self.fc2(x_) + x_ = self.drop(x_) + x_ = x_.reshape(B, C, -1).permute(0, 2, 1) + + if N == (H * W + 1): + x = torch.cat((cls_tokens.unsqueeze(1), x_), dim=1) + else: + x = x_ + + return x + + +class MlpDWBN(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + dw_act_layer=nn.GELU, + drop=0.0, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1) + self.act1 = act_layer() + self.norm1 = 
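The only structural change from Bottleneck to BottleneckDWP above is the grouped 3x3 convolution (groups=planes, i.e. depthwise) plus the explicit BN momentum; a quick sketch of the parameter difference:

    import torch.nn as nn

    planes = 64
    dw3x3 = nn.Conv2d(planes, planes, 3, padding=1, groups=planes, bias=False)  # BottleneckDWP style
    dense3x3 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)              # Bottleneck style
    print(sum(p.numel() for p in dw3x3.parameters()))     # 576   = 64 * 3 * 3
    print(sum(p.numel() for p in dense3x3.parameters()))  # 36864 = 64 * 64 * 3 * 3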
nn.SyncBatchNorm(hidden_features) + self.dw3x3 = nn.Conv2d( + hidden_features, + hidden_features, + kernel_size=3, + stride=1, + groups=hidden_features, + padding=1, + ) + self.act2 = dw_act_layer() + self.norm2 = nn.SyncBatchNorm(hidden_features) + self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1) + self.act3 = act_layer() + self.norm3 = nn.SyncBatchNorm(out_features) + # self.drop = nn.Dropout(drop, inplace=True) + + def forward(self, x, H, W): + if len(x.shape) == 3: + B, N, C = x.shape + if N == (H * W + 1): + cls_tokens = x[:, 0, :] + x_ = x[:, 1:, :].permute(0, 2, 1).reshape(B, C, H, W) + else: + x_ = x.permute(0, 2, 1).reshape(B, C, H, W) + + x_ = self.fc1(x_) + x_ = self.norm1(x_) + x_ = self.act1(x_) + x_ = self.dw3x3(x_) + x_ = self.norm2(x_) + x_ = self.act2(x_) + # x_ = self.drop(x_) + x_ = self.fc2(x_) + x_ = self.norm3(x_) + x_ = self.act3(x_) + # x_ = self.drop(x_) + x_ = x_.reshape(B, C, -1).permute(0, 2, 1) + if N == (H * W + 1): + x = torch.cat((cls_tokens.unsqueeze(1), x_), dim=1) + else: + x = x_ + return x + + elif len(x.shape) == 4: + x = self.fc1(x) + x = self.norm1(x) + x = self.act1(x) + x = self.dw3x3(x) + x = self.norm2(x) + x = self.act2(x) + # x = self.drop(x) + x = self.fc2(x) + x = self.norm3(x) + x = self.act3(x) + # x = self.drop(x) + return x + + else: + raise RuntimeError("Unsupported input shape: {}".format(x.shape)) + + +class MlpConvBN(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.0, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Sequential( + nn.Conv1d( + in_channels=in_features, + out_channels=hidden_features, + kernel_size=1, + stride=1, + padding=0, + ), + nn.BatchNorm1d(hidden_features), + ) + self.act = act_layer() + self.fc2 = nn.Sequential( + nn.Conv1d( + in_channels=hidden_features, + out_channels=out_features, + kernel_size=1, + stride=1, + padding=0, + ), + nn.BatchNorm1d(out_features), + ) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = x.transpose(1, 2) + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = x.transpose(1, 2) + x = self.drop(x) + return x + + +class MlpWODWBN(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + dw_act_layer=nn.GELU, + drop=0.0, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1) + self.act1 = act_layer() + self.norm1 = nn.SyncBatchNorm(hidden_features) + self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1) + self.act3 = act_layer() + self.norm3 = nn.SyncBatchNorm(out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x, H, W): + if len(x.shape) == 3: + B, N, C = x.shape + if N == (H * W + 1): + cls_tokens = x[:, 0, :] + x_ = x[:, 1:, :].permute(0, 2, 1).reshape(B, C, H, W) + else: + x_ = x.permute(0, 2, 1).reshape(B, C, H, W) + + x_ = self.fc1(x_) + x_ = self.norm1(x_) + x_ = self.act1(x_) + x_ = self.fc2(x_) + x_ = self.norm3(x_) + x_ = self.act3(x_) + x_ = self.drop(x_) + x_ = x_.reshape(B, C, -1).permute(0, 2, 1) + if N == (H * W + 1): + x = torch.cat((cls_tokens.unsqueeze(1), x_), dim=1) + else: + x = x_ + return x + + elif len(x.shape) == 4: + x = self.fc1(x) + x = self.norm1(x) + x = self.act1(x) + x = self.dw3x3(x) + x = self.norm2(x) + x = 
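MlpDW and MlpDWBN above accept transformer tokens of shape (B, N, C) but run convolutions, so they round-trip through a (B, C, H, W) map, splitting off a class token when N == H*W + 1; a small sketch of that reshape:

    import torch

    B, C, H, W = 2, 32, 8, 8
    tokens = torch.randn(B, H * W, C)                    # (B, N, C) token layout
    fmap = tokens.permute(0, 2, 1).reshape(B, C, H, W)   # -> (B, C, H, W) for the depthwise convs
    back = fmap.reshape(B, C, -1).permute(0, 2, 1)       # -> (B, N, C) again
    assert torch.equal(tokens, back)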
self.act2(x) + x = self.drop(x) + x = self.fc2(x) + x = self.norm3(x) + x = self.act3(x) + x = self.drop(x) + return x + + else: + raise RuntimeError("Unsupported input shape: {}".format(x.shape)) \ No newline at end of file diff --git a/isegm/model/modeling/hrformer_helper/hrt/modules/multihead_attention.py b/isegm/model/modeling/hrformer_helper/hrt/modules/multihead_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..159c97d23e7685995c509d07f3987bcca65703b5 --- /dev/null +++ b/isegm/model/modeling/hrformer_helper/hrt/modules/multihead_attention.py @@ -0,0 +1,342 @@ +import copy +import warnings + +import torch +import torch.nn.functional as F +from torch import nn, Tensor +from torch.nn.modules.module import Module +from torch._jit_internal import Optional, Tuple +from torch.nn.functional import linear, pad, softmax, dropout +from torch.overrides import has_torch_function, handle_torch_function + + + +class MultiheadAttention(Module): + bias_k: Optional[torch.Tensor] + bias_v: Optional[torch.Tensor] + + def __init__( + self, + embed_dim, + num_heads, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + kdim=None, + vdim=None, + ): + super(MultiheadAttention, self).__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + + self.k_proj = nn.Linear(self.kdim, embed_dim, bias=bias) + self.v_proj = nn.Linear(self.vdim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim) + + self.in_proj_bias = None + self.in_proj_weight = None + self.bias_k = self.bias_v = None + self.q_proj_weight = None + self.k_proj_weight = None + self.v_proj_weight = None + self.add_zero_attn = add_zero_attn + + def __setstate__(self, state): + # Support loading old MultiheadAttention checkpoints generated by v1.1.0 + if "_qkv_same_embed_dim" not in state: + state["_qkv_same_embed_dim"] = True + + super(MultiheadAttention, self).__setstate__(state) + + def forward( + self, + query, + key, + value, + key_padding_mask=None, + need_weights=False, + attn_mask=None, + residual_attn=None, + ): + if not self._qkv_same_embed_dim: + return self.multi_head_attention_forward( + query, + key, + value, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout, + self.out_proj.weight, + self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + use_separate_proj_weight=True, + q_proj_weight=self.q_proj_weight, + k_proj_weight=self.k_proj_weight, + v_proj_weight=self.v_proj_weight, + out_dim=self.vdim, + residual_attn=residual_attn, + ) + else: + return self.multi_head_attention_forward( + query, + key, + value, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout, + self.out_proj.weight, + self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + out_dim=self.vdim, + residual_attn=residual_attn, + ) + + def 
multi_head_attention_forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + embed_dim_to_check: int, + num_heads: int, + in_proj_weight: Tensor, + in_proj_bias: Tensor, + bias_k: Optional[Tensor], + bias_v: Optional[Tensor], + add_zero_attn: bool, + dropout_p: float, + out_proj_weight: Tensor, + out_proj_bias: Tensor, + training: bool = True, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = False, + attn_mask: Optional[Tensor] = None, + use_separate_proj_weight: bool = False, + q_proj_weight: Optional[Tensor] = None, + k_proj_weight: Optional[Tensor] = None, + v_proj_weight: Optional[Tensor] = None, + static_k: Optional[Tensor] = None, + static_v: Optional[Tensor] = None, + out_dim: Optional[Tensor] = None, + residual_attn: Optional[Tensor] = None, + ) -> Tuple[Tensor, Optional[Tensor]]: + if not torch.jit.is_scripting(): + tens_ops = ( + query, + key, + value, + in_proj_weight, + in_proj_bias, + bias_k, + bias_v, + out_proj_weight, + out_proj_bias, + ) + if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function( + tens_ops + ): + return handle_torch_function( + multi_head_attention_forward, + tens_ops, + query, + key, + value, + embed_dim_to_check, + num_heads, + in_proj_weight, + in_proj_bias, + bias_k, + bias_v, + add_zero_attn, + dropout_p, + out_proj_weight, + out_proj_bias, + training=training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + use_separate_proj_weight=use_separate_proj_weight, + q_proj_weight=q_proj_weight, + k_proj_weight=k_proj_weight, + v_proj_weight=v_proj_weight, + static_k=static_k, + static_v=static_v, + ) + tgt_len, bsz, embed_dim = query.size() + key = query if key is None else key + value = query if value is None else value + + assert embed_dim == embed_dim_to_check + # allow MHA to have different sizes for the feature dimension + assert key.size(0) == value.size(0) and key.size(1) == value.size(1) + + head_dim = embed_dim // num_heads + v_head_dim = out_dim // num_heads + assert ( + head_dim * num_heads == embed_dim + ), "embed_dim must be divisible by num_heads" + scaling = float(head_dim) ** -0.5 + + q = self.q_proj(query) * scaling + k = self.k_proj(key) + v = self.v_proj(value) + + if attn_mask is not None: + assert ( + attn_mask.dtype == torch.float32 + or attn_mask.dtype == torch.float64 + or attn_mask.dtype == torch.float16 + or attn_mask.dtype == torch.uint8 + or attn_mask.dtype == torch.bool + ), "Only float, byte, and bool types are supported for attn_mask, not {}".format( + attn_mask.dtype + ) + if attn_mask.dtype == torch.uint8: + warnings.warn( + "Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead." + ) + attn_mask = attn_mask.to(torch.bool) + + if attn_mask.dim() == 2: + attn_mask = attn_mask.unsqueeze(0) + if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: + raise RuntimeError("The size of the 2D attn_mask is not correct.") + elif attn_mask.dim() == 3: + if list(attn_mask.size()) != [ + bsz * num_heads, + query.size(0), + key.size(0), + ]: + raise RuntimeError("The size of the 3D attn_mask is not correct.") + else: + raise RuntimeError( + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) + ) + + # convert ByteTensor key_padding_mask to bool + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + warnings.warn( + "Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead." 
+ ) + key_padding_mask = key_padding_mask.to(torch.bool) + + q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1) + if k is not None: + k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) + if v is not None: + v = v.contiguous().view(-1, bsz * num_heads, v_head_dim).transpose(0, 1) + + src_len = k.size(1) + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + if add_zero_attn: + src_len += 1 + k = torch.cat( + [ + k, + torch.zeros( + (k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device + ), + ], + dim=1, + ) + v = torch.cat( + [ + v, + torch.zeros( + (v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device + ), + ], + dim=1, + ) + if attn_mask is not None: + attn_mask = pad(attn_mask, (0, 1)) + if key_padding_mask is not None: + key_padding_mask = pad(key_padding_mask, (0, 1)) + + attn_output_weights = torch.bmm(q, k.transpose(1, 2)) + assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len] + + """ + Attention weight for the invalid region is -inf + """ + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_output_weights.masked_fill_(attn_mask, float("-inf")) + else: + attn_output_weights += attn_mask + + if key_padding_mask is not None: + attn_output_weights = attn_output_weights.view( + bsz, num_heads, tgt_len, src_len + ) + attn_output_weights = attn_output_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), + float("-inf"), + ) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, src_len + ) + + if residual_attn is not None: + attn_output_weights = attn_output_weights.view( + bsz, num_heads, tgt_len, src_len + ) + attn_output_weights += residual_attn.unsqueeze(0) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, src_len + ) + + """ + Reweight the attention map before softmax(). 
+ attn_output_weights: (b*n_head, n, hw) + """ + attn_output_weights = softmax(attn_output_weights, dim=-1) + attn_output_weights = dropout( + attn_output_weights, p=dropout_p, training=training + ) + + attn_output = torch.bmm(attn_output_weights, v) + assert list(attn_output.size()) == [bsz * num_heads, tgt_len, v_head_dim] + attn_output = ( + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, out_dim) + ) + attn_output = linear(attn_output, out_proj_weight, out_proj_bias) + + if need_weights: + # average attention weights over heads + attn_output_weights = attn_output_weights.view( + bsz, num_heads, tgt_len, src_len + ) + return attn_output, attn_output_weights.sum(dim=1) / num_heads + else: + return attn_output \ No newline at end of file diff --git a/isegm/model/modeling/hrformer_helper/hrt/modules/multihead_isa_attention.py b/isegm/model/modeling/hrformer_helper/hrt/modules/multihead_isa_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..3cabff339bd16a3cc4722131fe1467e368c9bbe0 --- /dev/null +++ b/isegm/model/modeling/hrformer_helper/hrt/modules/multihead_isa_attention.py @@ -0,0 +1,426 @@ +import copy +import math +import warnings +import torch +from torch import nn, Tensor +import torch.nn.functional as F +from torch.nn.functional import linear, pad, softmax, dropout +from torch._jit_internal import Optional, Tuple +from torch.overrides import has_torch_function, handle_torch_function + +from einops import rearrange +from timm.models.layers import to_2tuple, trunc_normal_ + +from .multihead_attention import MultiheadAttention + + +class MHA_(MultiheadAttention): + """ "Multihead Attention with extra flags on the q/k/v and out projections.""" + + bias_k: Optional[torch.Tensor] + bias_v: Optional[torch.Tensor] + + def __init__(self, *args, rpe=False, window_size=7, **kwargs): + super(MHA_, self).__init__(*args, **kwargs) + + self.rpe = rpe + if rpe: + self.window_size = [window_size] * 2 + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros( + (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), + self.num_heads, + ) + ) # 2*Wh-1 * 2*Ww-1, nH + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = ( + coords_flatten[:, :, None] - coords_flatten[:, None, :] + ) # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute( + 1, 2, 0 + ).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + trunc_normal_(self.relative_position_bias_table, std=0.02) + + def forward( + self, + query, + key, + value, + key_padding_mask=None, + need_weights=False, + attn_mask=None, + do_qkv_proj=True, + do_out_proj=True, + rpe=True, + ): + if not self._qkv_same_embed_dim: + return self.multi_head_attention_forward( + query, + key, + value, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout, + self.out_proj.weight, + self.out_proj.bias, + 
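A CPU smoke test for the base MultiheadAttention above (multihead_attention.py), which keeps separate q/k/v projection layers and follows the (seq_len, batch, embed_dim) layout of nn.MultiheadAttention; the import path matches the file added in this diff:

    import torch
    from isegm.model.modeling.hrformer_helper.hrt.modules.multihead_attention import MultiheadAttention

    mha = MultiheadAttention(embed_dim=64, num_heads=4, dropout=0.0)
    x = torch.randn(49, 2, 64)    # e.g. a 7x7 window of tokens, batch of 2
    out = mha(x, x, x)            # need_weights defaults to False -> only the output tensor
    print(out.shape)              # torch.Size([49, 2, 64])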
training=self.training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + use_separate_proj_weight=True, + q_proj_weight=self.q_proj_weight, + k_proj_weight=self.k_proj_weight, + v_proj_weight=self.v_proj_weight, + out_dim=self.vdim, + do_qkv_proj=do_qkv_proj, + do_out_proj=do_out_proj, + rpe=rpe, + ) + else: + return self.multi_head_attention_forward( + query, + key, + value, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout, + self.out_proj.weight, + self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + out_dim=self.vdim, + do_qkv_proj=do_qkv_proj, + do_out_proj=do_out_proj, + rpe=rpe, + ) + + def multi_head_attention_forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + embed_dim_to_check: int, + num_heads: int, + in_proj_weight: Tensor, + in_proj_bias: Tensor, + bias_k: Optional[Tensor], + bias_v: Optional[Tensor], + add_zero_attn: bool, + dropout_p: float, + out_proj_weight: Tensor, + out_proj_bias: Tensor, + training: bool = True, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = False, + attn_mask: Optional[Tensor] = None, + use_separate_proj_weight: bool = False, + q_proj_weight: Optional[Tensor] = None, + k_proj_weight: Optional[Tensor] = None, + v_proj_weight: Optional[Tensor] = None, + static_k: Optional[Tensor] = None, + static_v: Optional[Tensor] = None, + out_dim: Optional[Tensor] = None, + do_qkv_proj: bool = True, + do_out_proj: bool = True, + rpe=True, + ) -> Tuple[Tensor, Optional[Tensor]]: + if not torch.jit.is_scripting(): + tens_ops = ( + query, + key, + value, + in_proj_weight, + in_proj_bias, + bias_k, + bias_v, + out_proj_weight, + out_proj_bias, + ) + if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function( + tens_ops + ): + return handle_torch_function( + multi_head_attention_forward, + tens_ops, + query, + key, + value, + embed_dim_to_check, + num_heads, + in_proj_weight, + in_proj_bias, + bias_k, + bias_v, + add_zero_attn, + dropout_p, + out_proj_weight, + out_proj_bias, + training=training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + use_separate_proj_weight=use_separate_proj_weight, + q_proj_weight=q_proj_weight, + k_proj_weight=k_proj_weight, + v_proj_weight=v_proj_weight, + static_k=static_k, + static_v=static_v, + ) + tgt_len, bsz, embed_dim = query.size() + key = query if key is None else key + value = query if value is None else value + + assert embed_dim == embed_dim_to_check + # allow MHA to have different sizes for the feature dimension + assert key.size(0) == value.size(0) and key.size(1) == value.size(1) + + head_dim = embed_dim // num_heads + v_head_dim = out_dim // num_heads + assert ( + head_dim * num_heads == embed_dim + ), "embed_dim must be divisible by num_heads" + scaling = float(head_dim) ** -0.5 + + # whether or not use the original query/key/value + q = self.q_proj(query) * scaling if do_qkv_proj else query + k = self.k_proj(key) if do_qkv_proj else key + v = self.v_proj(value) if do_qkv_proj else value + + if attn_mask is not None: + assert ( + attn_mask.dtype == torch.float32 + or attn_mask.dtype == torch.float64 + or attn_mask.dtype == torch.float16 + or attn_mask.dtype == torch.uint8 + or attn_mask.dtype == torch.bool + ), "Only float, byte, and bool types are supported for attn_mask, not {}".format( + attn_mask.dtype + ) + 
if attn_mask.dtype == torch.uint8: + warnings.warn( + "Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead." + ) + attn_mask = attn_mask.to(torch.bool) + + if attn_mask.dim() == 2: + attn_mask = attn_mask.unsqueeze(0) + if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: + raise RuntimeError("The size of the 2D attn_mask is not correct.") + elif attn_mask.dim() == 3: + if list(attn_mask.size()) != [ + bsz * num_heads, + query.size(0), + key.size(0), + ]: + raise RuntimeError("The size of the 3D attn_mask is not correct.") + else: + raise RuntimeError( + "attn_mask's dimension {} is not supported".format(attn_mask.dim()) + ) + + # convert ByteTensor key_padding_mask to bool + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + warnings.warn( + "Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead." + ) + key_padding_mask = key_padding_mask.to(torch.bool) + + q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1) + if k is not None: + k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) + if v is not None: + v = v.contiguous().view(-1, bsz * num_heads, v_head_dim).transpose(0, 1) + + src_len = k.size(1) + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + if add_zero_attn: + src_len += 1 + k = torch.cat( + [ + k, + torch.zeros( + (k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device + ), + ], + dim=1, + ) + v = torch.cat( + [ + v, + torch.zeros( + (v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device + ), + ], + dim=1, + ) + if attn_mask is not None: + attn_mask = pad(attn_mask, (0, 1)) + if key_padding_mask is not None: + key_padding_mask = pad(key_padding_mask, (0, 1)) + + attn_output_weights = torch.bmm(q, k.transpose(1, 2)) + assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len] + + """ + Add relative position embedding + """ + if self.rpe and rpe: + # NOTE: for simplicity, we assume the src_len == tgt_len == window_size**2 here + assert ( + src_len == self.window_size[0] * self.window_size[1] + and tgt_len == self.window_size[0] * self.window_size[1] + ), f"src{src_len}, tgt{tgt_len}, window{self.window_size[0]}" + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1) + ].view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], + -1, + ) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute( + 2, 0, 1 + ).contiguous() # nH, Wh*Ww, Wh*Ww + attn_output_weights = attn_output_weights.view( + bsz, num_heads, tgt_len, src_len + ) + relative_position_bias.unsqueeze(0) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, src_len + ) + + """ + Attention weight for the invalid region is -inf + """ + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_output_weights.masked_fill_(attn_mask, float("-inf")) + else: + attn_output_weights += attn_mask + + if key_padding_mask is not None: + attn_output_weights = attn_output_weights.view( + bsz, num_heads, tgt_len, src_len + ) + attn_output_weights = attn_output_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), + float("-inf"), + ) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, src_len + ) + + """ + Reweight the attention map before softmax(). 
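MHA_ above adds a learned relative position bias gathered through relative_position_index; a tiny sketch that reproduces the index construction for a 2x2 window (indices span 0..(2*Wh-1)*(2*Ww-1)-1 = 0..8):

    import torch

    Wh = Ww = 2
    coords = torch.stack(torch.meshgrid([torch.arange(Wh), torch.arange(Ww)]))  # 2, Wh, Ww
    flat = torch.flatten(coords, 1)                    # 2, Wh*Ww
    rel = flat[:, :, None] - flat[:, None, :]          # 2, Wh*Ww, Wh*Ww
    rel = rel.permute(1, 2, 0).contiguous()
    rel[:, :, 0] += Wh - 1                             # shift to start from 0
    rel[:, :, 1] += Ww - 1
    rel[:, :, 0] *= 2 * Ww - 1
    index = rel.sum(-1)                                # (Wh*Ww, Wh*Ww), indexes the bias table
    print(index)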
+ attn_output_weights: (b*n_head, n, hw) + """ + attn_output_weights = softmax(attn_output_weights, dim=-1) + attn_output_weights = dropout( + attn_output_weights, p=dropout_p, training=training + ) + + attn_output = torch.bmm(attn_output_weights, v) + assert list(attn_output.size()) == [bsz * num_heads, tgt_len, v_head_dim] + attn_output = ( + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, out_dim) + ) + if do_out_proj: + attn_output = linear(attn_output, out_proj_weight, out_proj_bias) + + if need_weights: + # average attention weights over heads + attn_output_weights = attn_output_weights.view( + bsz, num_heads, tgt_len, src_len + ) + return attn_output, q, k, attn_output_weights.sum(dim=1) / num_heads + else: + return attn_output, q, k # additionaly return the query and key + + +class PadBlock(object): + """ "Make the size of feature map divisible by local group size.""" + + def __init__(self, local_group_size=7): + self.lgs = local_group_size + if not isinstance(self.lgs, (tuple, list)): + self.lgs = to_2tuple(self.lgs) + assert len(self.lgs) == 2 + + def pad_if_needed(self, x, size): + n, h, w, c = size + pad_h = math.ceil(h / self.lgs[0]) * self.lgs[0] - h + pad_w = math.ceil(w / self.lgs[1]) * self.lgs[1] - w + if pad_h > 0 or pad_w > 0: # center-pad the feature on H and W axes + return F.pad( + x, + (0, 0, pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2), + ) + return x + + def depad_if_needed(self, x, size): + n, h, w, c = size + pad_h = math.ceil(h / self.lgs[0]) * self.lgs[0] - h + pad_w = math.ceil(w / self.lgs[1]) * self.lgs[1] - w + if pad_h > 0 or pad_w > 0: # remove the center-padding on feature + return x[:, pad_h // 2 : pad_h // 2 + h, pad_w // 2 : pad_w // 2 + w, :] + return x + + +class LocalPermuteModule(object): + """ "Permute the feature map to gather pixels in local groups, and the reverse permutation""" + + def __init__(self, local_group_size=7): + self.lgs = local_group_size + if not isinstance(self.lgs, (tuple, list)): + self.lgs = to_2tuple(self.lgs) + assert len(self.lgs) == 2 + + def permute(self, x, size): + n, h, w, c = size + return rearrange( + x, + "n (qh ph) (qw pw) c -> (ph pw) (n qh qw) c", + n=n, + qh=h // self.lgs[0], + ph=self.lgs[0], + qw=w // self.lgs[0], + pw=self.lgs[0], + c=c, + ) + + def rev_permute(self, x, size): + n, h, w, c = size + return rearrange( + x, + "(ph pw) (n qh qw) c -> n (qh ph) (qw pw) c", + n=n, + qh=h // self.lgs[0], + ph=self.lgs[0], + qw=w // self.lgs[0], + pw=self.lgs[0], + c=c, + ) \ No newline at end of file diff --git a/isegm/model/modeling/hrformer_helper/hrt/modules/multihead_isa_pool_attention.py b/isegm/model/modeling/hrformer_helper/hrt/modules/multihead_isa_pool_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..16269c1e73a240e6e5c7e5d2fe933fa347bce3f7 --- /dev/null +++ b/isegm/model/modeling/hrformer_helper/hrt/modules/multihead_isa_pool_attention.py @@ -0,0 +1,45 @@ +import os +import pdb +import math +import torch +import torch.nn as nn + +from .multihead_isa_attention import MHA_, PadBlock, LocalPermuteModule + +class InterlacedPoolAttention(nn.Module): + r""" interlaced sparse multi-head self attention (ISA) module with relative position bias. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): Window size. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. 
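PadBlock and LocalPermuteModule above implement the local grouping: H and W are center-padded up to a multiple of the window size, each contiguous window is flattened into the sequence axis, and the window index is folded into the batch axis (note the module reuses lgs[0] for both spatial axes, so it effectively assumes a square group). A sketch of the permutation and its inverse:

    import torch
    from einops import rearrange

    n, h, w, c, p = 1, 4, 6, 3, 2        # h and w already divisible by the group size p
    x = torch.randn(n, h, w, c)
    groups = rearrange(x, "n (qh ph) (qw pw) c -> (ph pw) (n qh qw) c", ph=p, pw=p)
    print(groups.shape)                  # (p*p, n * (h//p) * (w//p), c) = (4, 6, 3)
    x_back = rearrange(groups, "(ph pw) (n qh qw) c -> n (qh ph) (qw pw) c",
                       n=n, qh=h // p, ph=p, qw=w // p, pw=p)
    assert torch.equal(x, x_back)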
Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + def __init__(self, embed_dim, num_heads, window_size=7, + rpe=True, **kwargs): + super(InterlacedPoolAttention, self).__init__() + + self.dim = embed_dim + self.num_heads = num_heads + self.window_size = window_size + self.with_rpe = rpe + self.attn = MHA_(embed_dim, num_heads, rpe=rpe, window_size=window_size, **kwargs) + self.pad_helper = PadBlock(window_size) + self.permute_helper = LocalPermuteModule(window_size) + + def forward(self, x, H, W, **kwargs): + B, N, C = x.shape + x = x.view(B, H, W, C) + # attention + # pad + x_pad = self.pad_helper.pad_if_needed(x, x.size()) + # permute + x_permute = self.permute_helper.permute(x_pad, x_pad.size()) + # attention + out, _, _ = self.attn(x_permute, x_permute, x_permute, rpe=self.with_rpe, **kwargs) + # reverse permutation + out = self.permute_helper.rev_permute(out, x_pad.size()) + out = self.pad_helper.depad_if_needed(out, x.size()) + return out.reshape(B, N, C) \ No newline at end of file diff --git a/isegm/model/modeling/hrformer_helper/hrt/modules/spatial_ocr_block.py b/isegm/model/modeling/hrformer_helper/hrt/modules/spatial_ocr_block.py new file mode 100644 index 0000000000000000000000000000000000000000..c36c3f09db4dda65b001854b82491ccb10e59e50 --- /dev/null +++ b/isegm/model/modeling/hrformer_helper/hrt/modules/spatial_ocr_block.py @@ -0,0 +1,819 @@ +##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +## Created by: RainbowSecret +## Microsoft Research +## yuyua@microsoft.com +## Copyright (c) 2019 +## +## This source code is licensed under the MIT-style license found in the +## LICENSE file in the root directory of this source tree +##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import os +import pdb +import math +import torch +from torch import nn +from torch.autograd import Variable +from torch.nn import functional as F + +from ..module_helper import ModuleHelper + + +def label_to_onehot(gt, num_classes, ignore_index=-1): + """ + gt: ground truth with size (N, H, W) + num_classes: the number of classes of different label + """ + N, H, W = gt.size() + x = gt + x[x == ignore_index] = num_classes + # convert label into onehot format + onehot = torch.zeros(N, x.size(1), x.size(2), num_classes + 1).cuda() + onehot = onehot.scatter_(-1, x.unsqueeze(-1), 1) + + return onehot.permute(0, 3, 1, 2) + + +class SpatialGather_Module(nn.Module): + """ + Aggregate the context features according to the initial predicted probability distribution. + Employ the soft-weighted method to aggregate the context. 
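A CPU smoke test for InterlacedPoolAttention above, deliberately using a spatial size that is not a multiple of the window so the pad/de-pad path is exercised; the import path matches the file added in this diff:

    import torch
    from isegm.model.modeling.hrformer_helper.hrt.modules.multihead_isa_pool_attention import InterlacedPoolAttention

    attn = InterlacedPoolAttention(embed_dim=64, num_heads=2, window_size=7, rpe=True, dropout=0.0)
    B, H, W, C = 1, 10, 13, 64
    x = torch.randn(B, H * W, C)
    out = attn(x, H, W)
    print(out.shape)                     # torch.Size([1, 130, 64]) -- same token layout as the input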
+ """ + + def __init__(self, cls_num=0, scale=1, use_gt=False): + super(SpatialGather_Module, self).__init__() + self.cls_num = cls_num + self.scale = scale + self.use_gt = use_gt + self.relu = nn.ReLU(inplace=True) + + def forward(self, feats, probs, gt_probs=None): + if self.use_gt and gt_probs is not None: + gt_probs = label_to_onehot( + gt_probs.squeeze(1).type(torch.cuda.LongTensor), probs.size(1) + ) + batch_size, c, h, w = ( + gt_probs.size(0), + gt_probs.size(1), + gt_probs.size(2), + gt_probs.size(3), + ) + gt_probs = gt_probs.view(batch_size, c, -1) + feats = feats.view(batch_size, feats.size(1), -1) + feats = feats.permute(0, 2, 1) # batch x hw x c + gt_probs = F.normalize(gt_probs, p=1, dim=2) # batch x k x hw + ocr_context = ( + torch.matmul(gt_probs, feats).permute(0, 2, 1).unsqueeze(3) + ) # batch x k x c + return ocr_context + else: + batch_size, c, h, w = ( + probs.size(0), + probs.size(1), + probs.size(2), + probs.size(3), + ) + probs = probs.view(batch_size, c, -1) + feats = feats.view(batch_size, feats.size(1), -1) + feats = feats.permute(0, 2, 1) # batch x hw x c + probs = F.softmax(self.scale * probs, dim=2) # batch x k x hw + ocr_context = ( + torch.matmul(probs, feats).permute(0, 2, 1).unsqueeze(3) + ) # batch x k x c + return ocr_context + + +class PyramidSpatialGather_Module(nn.Module): + """ + Aggregate the context features according to the initial predicted probability distribution. + Employ the soft-weighted method to aggregate the context. + """ + + def __init__(self, cls_num=0, scales=[1, 2, 4]): + super(PyramidSpatialGather_Module, self).__init__() + self.cls_num = cls_num + self.scales = scales + self.relu = nn.ReLU(inplace=True) + + def _compute_single_scale(self, feats, probs, dh, dw): + batch_size, k, h, w = probs.size(0), probs.size(1), probs.size(2), probs.size(3) + c = feats.size(1) + + out_h, out_w = math.ceil(h / dh), math.ceil(w / dw) + pad_h, pad_w = out_h * dh - h, out_w * dw - w + if pad_h > 0 or pad_w > 0: # padding in both left&right sides + feats = F.pad( + feats, (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2) + ) + probs = F.pad( + probs, (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2) + ) + + feats = feats.view(batch_size, c, out_h, dh, out_w, dw).permute( + 0, 3, 5, 1, 2, 4 + ) + feats = feats.contiguous().view(batch_size, dh * dw, c, out_h, out_w) + + probs = probs.view(batch_size, k, out_h, dh, out_w, dw).permute( + 0, 3, 5, 1, 2, 4 + ) + probs = probs.contiguous().view(batch_size, dh * dw, k, out_h, out_w) + + feats = feats.view(batch_size, dh * dw, c, -1) + probs = probs.view(batch_size, dh * dw, k, -1) + feats = feats.permute(0, 1, 3, 2) + + probs = F.softmax(probs, dim=3) # batch x k x hw + cc = torch.matmul(probs, feats).view(batch_size, -1, c) # batch x k x c + + return cc.permute(0, 2, 1).unsqueeze(3) + + def forward(self, feats, probs): + ocr_list = [] + for scale in self.scales: + ocr_tmp = self._compute_single_scale(feats, probs, scale, scale) + ocr_list.append(ocr_tmp) + pyramid_ocr = torch.cat(ocr_list, 2) + return pyramid_ocr + + +class _ObjectAttentionBlock(nn.Module): + """ + The basic implementation for object context block + Input: + N X C X H X W + Parameters: + in_channels : the dimension of the input feature map + key_channels : the dimension after the key/query transform + scale : choose the scale to downsample the input feature maps (save memory cost) + use_gt : whether use the ground truth label map to compute the similarity map + fetch_attention : whether return the estimated 
similarity map + bn_type : specify the bn type + Return: + N X C X H X W + """ + + def __init__( + self, + in_channels, + key_channels, + scale=1, + use_gt=False, + use_bg=False, + fetch_attention=False, + bn_type=None, + ): + super(_ObjectAttentionBlock, self).__init__() + self.scale = scale + self.in_channels = in_channels + self.key_channels = key_channels + self.use_gt = use_gt + self.use_bg = use_bg + self.fetch_attention = fetch_attention + self.pool = nn.MaxPool2d(kernel_size=(scale, scale)) + self.f_pixel = nn.Sequential( + nn.Conv2d( + in_channels=self.in_channels, + out_channels=self.key_channels, + kernel_size=1, + stride=1, + padding=0, + ), + ModuleHelper.BNReLU(self.key_channels, bn_type=bn_type), + nn.Conv2d( + in_channels=self.key_channels, + out_channels=self.key_channels, + kernel_size=1, + stride=1, + padding=0, + ), + ModuleHelper.BNReLU(self.key_channels, bn_type=bn_type), + ) + self.f_object = nn.Sequential( + nn.Conv2d( + in_channels=self.in_channels, + out_channels=self.key_channels, + kernel_size=1, + stride=1, + padding=0, + ), + ModuleHelper.BNReLU(self.key_channels, bn_type=bn_type), + nn.Conv2d( + in_channels=self.key_channels, + out_channels=self.key_channels, + kernel_size=1, + stride=1, + padding=0, + ), + ModuleHelper.BNReLU(self.key_channels, bn_type=bn_type), + ) + self.f_down = nn.Sequential( + nn.Conv2d( + in_channels=self.in_channels, + out_channels=self.key_channels, + kernel_size=1, + stride=1, + padding=0, + ), + ModuleHelper.BNReLU(self.key_channels, bn_type=bn_type), + ) + self.f_up = nn.Sequential( + nn.Conv2d( + in_channels=self.key_channels, + out_channels=self.in_channels, + kernel_size=1, + stride=1, + padding=0, + ), + ModuleHelper.BNReLU(self.in_channels, bn_type=bn_type), + ) + + def forward(self, x, proxy, gt_label=None): + batch_size, h, w = x.size(0), x.size(2), x.size(3) + if self.scale > 1: + x = self.pool(x) + + query = self.f_pixel(x).view(batch_size, self.key_channels, -1) + query = query.permute(0, 2, 1) + key = self.f_object(proxy).view(batch_size, self.key_channels, -1) + value = self.f_down(proxy).view(batch_size, self.key_channels, -1) + value = value.permute(0, 2, 1) + + if self.use_gt and gt_label is not None: + gt_label = label_to_onehot( + gt_label.squeeze(1).type(torch.cuda.LongTensor), proxy.size(2) - 1 + ) + sim_map = ( + gt_label[:, :, :, :].permute(0, 2, 3, 1).view(batch_size, h * w, -1) + ) + if self.use_bg: + bg_sim_map = 1.0 - sim_map + bg_sim_map = F.normalize(bg_sim_map, p=1, dim=-1) + sim_map = F.normalize(sim_map, p=1, dim=-1) + else: + sim_map = torch.matmul(query, key) + sim_map = (self.key_channels ** -0.5) * sim_map + sim_map = F.softmax(sim_map, dim=-1) + + # add bg context ... 
+ context = torch.matmul(sim_map, value) # hw x k x k x c + context = context.permute(0, 2, 1).contiguous() + context = context.view(batch_size, self.key_channels, *x.size()[2:]) + context = self.f_up(context) + if self.scale > 1: + context = F.interpolate( + input=context, size=(h, w), mode="bilinear", align_corners=True + ) + + if self.use_bg: + bg_context = torch.matmul(bg_sim_map, value) + bg_context = bg_context.permute(0, 2, 1).contiguous() + bg_context = bg_context.view(batch_size, self.key_channels, *x.size()[2:]) + bg_context = self.f_up(bg_context) + bg_context = F.interpolate( + input=bg_context, size=(h, w), mode="bilinear", align_corners=True + ) + return context, bg_context + else: + if self.fetch_attention: + return context, sim_map + else: + return context + + +class ObjectAttentionBlock2D(_ObjectAttentionBlock): + def __init__( + self, + in_channels, + key_channels, + scale=1, + use_gt=False, + use_bg=False, + fetch_attention=False, + bn_type=None, + ): + super(ObjectAttentionBlock2D, self).__init__( + in_channels, + key_channels, + scale, + use_gt, + use_bg, + fetch_attention, + bn_type=bn_type, + ) + + +class SpatialOCR_Module(nn.Module): + """ + Implementation of the OCR module: + We aggregate the global object representation to update the representation for each pixel. + + use_gt=True: whether use the ground-truth label to compute the ideal object contextual representations. + use_bg=True: use the ground-truth label to compute the ideal background context to augment the representations. + use_oc=True: use object context or not. + """ + + def __init__( + self, + in_channels, + key_channels, + out_channels, + scale=1, + dropout=0.1, + use_gt=False, + use_bg=False, + use_oc=True, + fetch_attention=False, + bn_type=None, + ): + super(SpatialOCR_Module, self).__init__() + self.use_gt = use_gt + self.use_bg = use_bg + self.use_oc = use_oc + self.fetch_attention = fetch_attention + self.object_context_block = ObjectAttentionBlock2D( + in_channels, key_channels, scale, use_gt, use_bg, fetch_attention, bn_type + ) + if self.use_bg: + if self.use_oc: + _in_channels = 3 * in_channels + else: + _in_channels = 2 * in_channels + else: + _in_channels = 2 * in_channels + + self.conv_bn_dropout = nn.Sequential( + nn.Conv2d(_in_channels, out_channels, kernel_size=1, padding=0), + ModuleHelper.BNReLU(out_channels, bn_type=bn_type), + nn.Dropout2d(dropout), + ) + + def forward(self, feats, proxy_feats, gt_label=None): + if self.use_gt and gt_label is not None: + if self.use_bg: + context, bg_context = self.object_context_block( + feats, proxy_feats, gt_label + ) + else: + context = self.object_context_block(feats, proxy_feats, gt_label) + else: + if self.fetch_attention: + context, sim_map = self.object_context_block(feats, proxy_feats) + else: + context = self.object_context_block(feats, proxy_feats) + + if self.use_bg: + if self.use_oc: + output = self.conv_bn_dropout( + torch.cat([context, bg_context, feats], 1) + ) + else: + output = self.conv_bn_dropout(torch.cat([bg_context, feats], 1)) + else: + output = self.conv_bn_dropout(torch.cat([context, feats], 1)) + + if self.fetch_attention: + return output, sim_map + else: + return output + + +class SpatialOCR_Context(nn.Module): + """ + Implementation of the FastOC module: + We aggregate the global object representation to update the representation for each pixel. 
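The __main__ block at the end of this file benchmarks SpatialGather_Module + SpatialOCR_Module on GPU with sync BN; for a quick CPU check the same pair can be driven with bn_type="torchbn" (a smoke-test sketch, not the configuration used in the repo):

    import torch
    from isegm.model.modeling.hrformer_helper.hrt.modules.spatial_ocr_block import (
        SpatialGather_Module, SpatialOCR_Module,
    )

    gather = SpatialGather_Module(cls_num=19)
    ocr = SpatialOCR_Module(in_channels=512, key_channels=256, out_channels=512,
                            scale=1, dropout=0.05, bn_type="torchbn")

    feats = torch.randn(2, 512, 32, 32)
    probs = torch.randn(2, 19, 32, 32)     # coarse segmentation logits
    context = gather(feats, probs)         # (2, 512, 19, 1)
    out = ocr(feats, context)              # (2, 512, 32, 32)
    print(out.shape)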
+ """ + + def __init__( + self, + in_channels, + key_channels, + scale=1, + dropout=0, + bn_type=None, + ): + super(SpatialOCR_Context, self).__init__() + self.object_context_block = ObjectAttentionBlock2D( + in_channels, key_channels, scale, bn_type=bn_type + ) + + def forward(self, feats, proxy_feats): + context = self.object_context_block(feats, proxy_feats) + return context + + +class SpatialOCR_ASP_Module(nn.Module): + def __init__( + self, + features, + hidden_features=256, + out_features=512, + dilations=(12, 24, 36), + num_classes=19, + bn_type=None, + dropout=0.1, + ): + super(SpatialOCR_ASP_Module, self).__init__() + from lib.models.modules.spatial_ocr_block import SpatialOCR_Context + + self.context = nn.Sequential( + nn.Conv2d( + features, + hidden_features, + kernel_size=3, + padding=1, + dilation=1, + bias=True, + ), + ModuleHelper.BNReLU(hidden_features, bn_type=bn_type), + SpatialOCR_Context( + in_channels=hidden_features, + key_channels=hidden_features // 2, + scale=1, + bn_type=bn_type, + ), + ) + self.conv2 = nn.Sequential( + nn.Conv2d( + features, + hidden_features, + kernel_size=1, + padding=0, + dilation=1, + bias=True, + ), + ModuleHelper.BNReLU(hidden_features, bn_type=bn_type), + ) + self.conv3 = nn.Sequential( + nn.Conv2d( + features, + hidden_features, + kernel_size=3, + padding=dilations[0], + dilation=dilations[0], + bias=True, + ), + ModuleHelper.BNReLU(hidden_features, bn_type=bn_type), + ) + self.conv4 = nn.Sequential( + nn.Conv2d( + features, + hidden_features, + kernel_size=3, + padding=dilations[1], + dilation=dilations[1], + bias=True, + ), + ModuleHelper.BNReLU(hidden_features, bn_type=bn_type), + ) + self.conv5 = nn.Sequential( + nn.Conv2d( + features, + hidden_features, + kernel_size=3, + padding=dilations[2], + dilation=dilations[2], + bias=True, + ), + ModuleHelper.BNReLU(hidden_features, bn_type=bn_type), + ) + self.conv_bn_dropout = nn.Sequential( + nn.Conv2d( + hidden_features * 5, + out_features, + kernel_size=1, + padding=0, + dilation=1, + bias=True, + ), + ModuleHelper.BNReLU(out_features, bn_type=bn_type), + nn.Dropout2d(dropout), + ) + self.object_head = SpatialGather_Module(num_classes) + + def _cat_each(self, feat1, feat2, feat3, feat4, feat5): + assert len(feat1) == len(feat2) + z = [] + for i in range(len(feat1)): + z.append(torch.cat((feat1[i], feat2[i], feat3[i], feat4[i], feat5[i]), 1)) + return z + + def forward(self, x, probs): + if isinstance(x, Variable): + _, _, h, w = x.size() + elif isinstance(x, tuple) or isinstance(x, list): + _, _, h, w = x[0].size() + else: + raise RuntimeError("unknown input type") + + feat1 = self.context[0](x) + feat1 = self.context[1](feat1) + proxy_feats = self.object_head(feat1, probs) + feat1 = self.context[2](feat1, proxy_feats) + feat2 = self.conv2(x) + feat3 = self.conv3(x) + feat4 = self.conv4(x) + feat5 = self.conv5(x) + + if isinstance(x, Variable): + out = torch.cat((feat1, feat2, feat3, feat4, feat5), 1) + elif isinstance(x, tuple) or isinstance(x, list): + out = self._cat_each(feat1, feat2, feat3, feat4, feat5) + else: + raise RuntimeError("unknown input type") + + output = self.conv_bn_dropout(out) + return output + + +class _MultiheadObjectAttentionBlock(nn.Module): + """ + The basic implementation for object context block + Input: + N X C X H X W + Parameters: + in_channels : the dimension of the input feature map + key_channels : the dimension after the key/query transform + scale : choose the scale to downsample the input feature maps (save memory cost) + use_gt : whether use the 
ground truth label map to compute the similarity map + fetch_attention : whether return the estimated similarity map + bn_type : specify the bn type + Return: + N X C X H X W + """ + + def __init__( + self, + in_channels, + key_channels, + num_heads=1, + scale=1, + use_gt=False, + use_bg=False, + fetch_attention=False, + bn_type=None, + ): + super(_MultiheadObjectAttentionBlock, self).__init__() + self.scale = scale + self.in_channels = in_channels + self.key_channels = key_channels + # for simplicity + assert key_channels & num_heads == 0 + assert not use_gt + assert not use_bg + self.num_heads = num_heads + self.use_gt = use_gt + self.use_bg = use_bg + self.fetch_attention = fetch_attention + self.pool = nn.MaxPool2d(kernel_size=(scale, scale)) + self.f_pixel = nn.Sequential( + nn.Conv2d( + in_channels=self.in_channels, + out_channels=self.key_channels, + kernel_size=1, + stride=1, + padding=0, + ), + ModuleHelper.BNReLU(self.key_channels, bn_type=bn_type), + nn.Conv2d( + in_channels=self.key_channels, + out_channels=self.key_channels, + kernel_size=1, + stride=1, + padding=0, + ), + ModuleHelper.BNReLU(self.key_channels, bn_type=bn_type), + ) + self.f_object = nn.Sequential( + nn.Conv2d( + in_channels=self.in_channels, + out_channels=self.key_channels, + kernel_size=1, + stride=1, + padding=0, + ), + ModuleHelper.BNReLU(self.key_channels, bn_type=bn_type), + nn.Conv2d( + in_channels=self.key_channels, + out_channels=self.key_channels, + kernel_size=1, + stride=1, + padding=0, + ), + ModuleHelper.BNReLU(self.key_channels, bn_type=bn_type), + ) + self.f_down = nn.Sequential( + nn.Conv2d( + in_channels=self.in_channels, + out_channels=self.key_channels, + kernel_size=1, + stride=1, + padding=0, + ), + ModuleHelper.BNReLU(self.key_channels, bn_type=bn_type), + ) + self.f_up = nn.Sequential( + nn.Conv2d( + in_channels=self.key_channels, + out_channels=self.in_channels, + kernel_size=1, + stride=1, + padding=0, + ), + ModuleHelper.BNReLU(self.in_channels, bn_type=bn_type), + ) + + def forward(self, x, proxy): + batch_size, h, w = x.size(0), x.size(2), x.size(3) + if self.scale > 1: + x = self.pool(x) + + query = self.f_pixel(x).view( + batch_size, self.num_heads, self.key_channels // self.num_heads, -1 + ) + query = query.permute(0, 1, 3, 2) # (b, nH, N, C//nH) + key = self.f_object(proxy).view( + batch_size, self.num_heads, self.key_channels // self.num_heads, -1 + ) + value = self.f_down(proxy).view( + batch_size, self.num_heads, self.key_channels // self.num_heads, -1 + ) + value = value.permute(0, 1, 3, 2) # (b, nH, N, C//nH) + + # attention map + query = query * (self.key_channels ** -0.5) + sim_map = query @ key # (b, nH, N, N) + sim_map = F.softmax(sim_map, dim=-1) + context = (sim_map @ value).permute(0, 1, 3, 2) + context = context.reshape(batch_size, -1, h, w) + context = self.f_up(context) + if self.scale > 1: + context = F.interpolate( + input=context, size=(h, w), mode="bilinear", align_corners=True + ) + + if self.fetch_attention: + return context, sim_map + else: + return context + + +class MultiheadObjectAttentionBlock2D(_MultiheadObjectAttentionBlock): + def __init__( + self, + in_channels, + key_channels, + num_heads=1, + scale=1, + use_gt=False, + use_bg=False, + fetch_attention=False, + bn_type=None, + ): + super(MultiheadObjectAttentionBlock2D, self).__init__( + in_channels, + key_channels, + num_heads, + scale, + use_gt, + use_bg, + fetch_attention, + bn_type=bn_type, + ) + + +class MultiheadSpatialOCR_Module(nn.Module): + """ + Implementation of the OCR module: + We 
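The forward pass of _MultiheadObjectAttentionBlock above splits key_channels across num_heads, scales the query by key_channels ** -0.5 and softmaxes over the proxy (object) dimension. A shape walk-through with toy sizes (an illustration, not the module itself):

```python
import torch
import torch.nn.functional as F

B, C, nH = 2, 64, 4            # key_channels C split over nH heads
N_pix, N_obj = 32 * 32, 19     # pixel tokens and object (proxy) tokens

q = torch.randn(B, nH, N_pix, C // nH)   # f_pixel output, reshaped per head
k = torch.randn(B, nH, C // nH, N_obj)   # f_object output
v = torch.randn(B, nH, N_obj, C // nH)   # f_down output

attn = F.softmax((q * C ** -0.5) @ k, dim=-1)               # (B, nH, N_pix, N_obj)
ctx = (attn @ v).permute(0, 1, 3, 2).reshape(B, C, 32, 32)  # back to (B, C, H, W)
print(ctx.shape)  # torch.Size([2, 64, 32, 32])
```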
aggregate the global object representation to update the representation for each pixel. + use_gt=True: whether use the ground-truth label to compute the ideal object contextual representations. + use_bg=True: use the ground-truth label to compute the ideal background context to augment the representations. + use_oc=True: use object context or not. + """ + + def __init__( + self, + in_channels, + key_channels, + out_channels, + num_heads=1, + scale=1, + dropout=0.1, + use_gt=False, + use_bg=False, + use_oc=True, + fetch_attention=False, + bn_type=None, + ): + super(MultiheadSpatialOCR_Module, self).__init__() + self.use_gt = use_gt + self.use_bg = use_bg + self.use_oc = use_oc + self.fetch_attention = fetch_attention + self.object_context_block = MultiheadObjectAttentionBlock2D( + in_channels, + key_channels, + num_heads, + scale, + use_gt, + use_bg, + fetch_attention, + bn_type, + ) + if self.use_bg: + if self.use_oc: + _in_channels = 3 * in_channels + else: + _in_channels = 2 * in_channels + else: + _in_channels = 2 * in_channels + + self.conv_bn_dropout = nn.Sequential( + nn.Conv2d(_in_channels, out_channels, kernel_size=1, padding=0), + ModuleHelper.BNReLU(out_channels, bn_type=bn_type), + nn.Dropout2d(dropout), + ) + + def forward(self, feats, proxy_feats, gt_label=None): + if self.use_gt and gt_label is not None: + if self.use_bg: + context, bg_context = self.object_context_block( + feats, proxy_feats, gt_label + ) + else: + context = self.object_context_block(feats, proxy_feats, gt_label) + else: + if self.fetch_attention: + context, sim_map = self.object_context_block(feats, proxy_feats) + else: + context = self.object_context_block(feats, proxy_feats) + + if self.use_bg: + if self.use_oc: + output = self.conv_bn_dropout( + torch.cat([context, bg_context, feats], 1) + ) + else: + output = self.conv_bn_dropout(torch.cat([bg_context, feats], 1)) + else: + output = self.conv_bn_dropout(torch.cat([context, feats], 1)) + + if self.fetch_attention: + return output, sim_map + else: + return output + + +if __name__ == "__main__": + os.environ["CUDA_VISIBLE_DEVICES"] = "0" + probs = torch.randn((1, 19, 128, 128)).cuda() + feats = torch.randn((1, 2048, 128, 128)).cuda() + + conv_3x3 = nn.Sequential( + nn.Conv2d(2048, 512, kernel_size=3, stride=1, padding=1), + ModuleHelper.BNReLU(512, bn_type="torchsyncbn"), + ) + + ocp_gather_infer = SpatialGather_Module(19) + ocp_distr_infer = SpatialOCR_Module( + in_channels=512, + key_channels=256, + out_channels=512, + scale=1, + dropout=0, + bn_type="torchsyncbn", + ) + ocp_gather_infer.eval() + ocp_gather_infer.cuda() + ocp_distr_infer.eval() + ocp_distr_infer.cuda() + conv_3x3.eval() + conv_3x3.cuda() + + def count_parameters(model): + return sum(p.numel() for p in model.parameters() if p.requires_grad) + + avg_time = 0 + avg_mem = 0 + import time + + with torch.no_grad(): + for i in range(100): + start_time = time.time() + feats_ = conv_3x3(feats) + ocp_feats = ocp_gather_infer(feats_, probs) + outputs = ocp_distr_infer(feats_, ocp_feats) + torch.cuda.synchronize() + avg_time += time.time() - start_time + avg_mem += ( + torch.cuda.max_memory_allocated() + - feats.element_size() * feats.nelement() + ) + + print( + "Average Parameters : {}".format( + count_parameters(ocp_distr_infer) + count_parameters(conv_3x3) + ) + ) + print("Average Running Time: {}".format(avg_time / 100)) + print("Average GPU Memory: {:.2f} MB".format(avg_mem / 100 / 2 ** 20)) \ No newline at end of file diff --git 
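The __main__ block above benchmarks the OCR head on GPU. A generic sketch of the same measurement pattern (a hedged illustration, not the script itself): synchronizing before and after the timed loop and resetting the peak-allocator statistic makes the reported time and max_memory_allocated reflect only the measured region.

```python
import time
import torch

def benchmark(fn, *args, iters=100):
    """Average per-iteration wall time (s) and peak GPU memory (MB) of fn(*args)."""
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    start = time.time()
    with torch.no_grad():
        for _ in range(iters):
            fn(*args)
    torch.cuda.synchronize()
    peak_mb = torch.cuda.max_memory_allocated() / 2 ** 20
    return (time.time() - start) / iters, peak_mb
```

It would be called, for example, as `benchmark(lambda t: head(conv(t)), feats)` once the (hypothetically named) modules and input live on the GPU.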
a/isegm/model/modeling/hrformer_helper/hrt/modules/transformer_block.py b/isegm/model/modeling/hrformer_helper/hrt/modules/transformer_block.py new file mode 100644 index 0000000000000000000000000000000000000000..92f258db6a0f8d42a0bcd2b4b9658abddcd44a6a --- /dev/null +++ b/isegm/model/modeling/hrformer_helper/hrt/modules/transformer_block.py @@ -0,0 +1,115 @@ +import os +import pdb +import math +import logging +import torch +import torch.nn as nn +from functools import partial + +from .multihead_isa_pool_attention import InterlacedPoolAttention +from .ffn_block import MlpDWBN + + +BN_MOMENTUM = 0.1 + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + """ + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * ( + x.ndim - 1 + ) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + def extra_repr(self): + # (Optional)Set the extra information about this module. You can test + # it by printing an object of this class. + return "drop_prob={}".format(self.drop_prob) + + +class GeneralTransformerBlock(nn.Module): + expansion = 1 + + def __init__( + self, + inplanes, + planes, + num_heads, + window_size=7, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + ): + super(GeneralTransformerBlock, self).__init__() + self.dim = inplanes + self.out_dim = planes + self.num_heads = num_heads + self.window_size = window_size + self.mlp_ratio = mlp_ratio + self.attn = InterlacedPoolAttention( + self.dim, + num_heads=num_heads, + window_size=window_size, + rpe=True, + dropout=attn_drop, + ) + + self.norm1 = norm_layer(self.dim) + self.norm2 = norm_layer(self.out_dim) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + mlp_hidden_dim = int(self.dim * mlp_ratio) + + self.mlp = MlpDWBN( + in_features=self.dim, + hidden_features=mlp_hidden_dim, + out_features=self.out_dim, + act_layer=act_layer, + dw_act_layer=act_layer, + drop=drop, + ) + + def forward(self, x, mask=None): + B, C, H, W = x.size() + # reshape + x = x.view(B, C, -1).permute(0, 2, 1) + # Attention + x = x + self.drop_path(self.attn(self.norm1(x), H, W)) + # FFN + x = x + self.drop_path(self.mlp(self.norm2(x), H, W)) + # reshape + x = x.permute(0, 2, 1).view(B, C, H, W) + return x + + def extra_repr(self): + # (Optional)Set the extra information about this module. You can test + # it by printing an object of this class. 
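drop_path above implements stochastic depth: during training, whole samples in the batch are zeroed with probability drop_prob and the survivors are rescaled by 1 / keep_prob so the expected activation is unchanged; in eval mode the input passes through untouched. A quick check of that behaviour on a toy tensor:

```python
import torch

keep_prob = 0.8
x = torch.ones(8, 4)                               # 8 samples in the batch
mask = (keep_prob + torch.rand(8, 1)).floor()      # 1 with prob keep_prob, else 0
out = x / keep_prob * mask
print(out[:, 0])  # each row is either 0.0 (dropped) or 1.25 == 1 / keep_prob
```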
+ return "num_heads={}, window_size={}, mlp_ratio={}".format( + self.num_heads, self.window_size, self.mlp_ratio + ) \ No newline at end of file diff --git a/isegm/model/modeling/hrnet_ocr.py b/isegm/model/modeling/hrnet_ocr.py new file mode 100644 index 0000000000000000000000000000000000000000..d386ee0d376df2d498ef3c05f743caaf83374273 --- /dev/null +++ b/isegm/model/modeling/hrnet_ocr.py @@ -0,0 +1,416 @@ +import os +import numpy as np +import torch +import torch.nn as nn +import torch._utils +import torch.nn.functional as F +from .ocr import SpatialOCR_Module, SpatialGather_Module +from .resnetv1b import BasicBlockV1b, BottleneckV1b + +relu_inplace = True + + +class HighResolutionModule(nn.Module): + def __init__(self, num_branches, blocks, num_blocks, num_inchannels, + num_channels, fuse_method,multi_scale_output=True, + norm_layer=nn.BatchNorm2d, align_corners=True): + super(HighResolutionModule, self).__init__() + self._check_branches(num_branches, num_blocks, num_inchannels, num_channels) + + self.num_inchannels = num_inchannels + self.fuse_method = fuse_method + self.num_branches = num_branches + self.norm_layer = norm_layer + self.align_corners = align_corners + + self.multi_scale_output = multi_scale_output + + self.branches = self._make_branches( + num_branches, blocks, num_blocks, num_channels) + self.fuse_layers = self._make_fuse_layers() + self.relu = nn.ReLU(inplace=relu_inplace) + + def _check_branches(self, num_branches, num_blocks, num_inchannels, num_channels): + if num_branches != len(num_blocks): + error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format( + num_branches, len(num_blocks)) + raise ValueError(error_msg) + + if num_branches != len(num_channels): + error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format( + num_branches, len(num_channels)) + raise ValueError(error_msg) + + if num_branches != len(num_inchannels): + error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format( + num_branches, len(num_inchannels)) + raise ValueError(error_msg) + + def _make_one_branch(self, branch_index, block, num_blocks, num_channels, + stride=1): + downsample = None + if stride != 1 or \ + self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.num_inchannels[branch_index], + num_channels[branch_index] * block.expansion, + kernel_size=1, stride=stride, bias=False), + self.norm_layer(num_channels[branch_index] * block.expansion), + ) + + layers = [] + layers.append(block(self.num_inchannels[branch_index], + num_channels[branch_index], stride, + downsample=downsample, norm_layer=self.norm_layer)) + self.num_inchannels[branch_index] = \ + num_channels[branch_index] * block.expansion + for i in range(1, num_blocks[branch_index]): + layers.append(block(self.num_inchannels[branch_index], + num_channels[branch_index], + norm_layer=self.norm_layer)) + + return nn.Sequential(*layers) + + def _make_branches(self, num_branches, block, num_blocks, num_channels): + branches = [] + + for i in range(num_branches): + branches.append( + self._make_one_branch(i, block, num_blocks, num_channels)) + + return nn.ModuleList(branches) + + def _make_fuse_layers(self): + if self.num_branches == 1: + return None + + num_branches = self.num_branches + num_inchannels = self.num_inchannels + fuse_layers = [] + for i in range(num_branches if self.multi_scale_output else 1): + fuse_layer = [] + for j in range(num_branches): + if j > i: + fuse_layer.append(nn.Sequential( + nn.Conv2d(in_channels=num_inchannels[j], + 
out_channels=num_inchannels[i], + kernel_size=1, + bias=False), + self.norm_layer(num_inchannels[i]))) + elif j == i: + fuse_layer.append(None) + else: + conv3x3s = [] + for k in range(i - j): + if k == i - j - 1: + num_outchannels_conv3x3 = num_inchannels[i] + conv3x3s.append(nn.Sequential( + nn.Conv2d(num_inchannels[j], + num_outchannels_conv3x3, + kernel_size=3, stride=2, padding=1, bias=False), + self.norm_layer(num_outchannels_conv3x3))) + else: + num_outchannels_conv3x3 = num_inchannels[j] + conv3x3s.append(nn.Sequential( + nn.Conv2d(num_inchannels[j], + num_outchannels_conv3x3, + kernel_size=3, stride=2, padding=1, bias=False), + self.norm_layer(num_outchannels_conv3x3), + nn.ReLU(inplace=relu_inplace))) + fuse_layer.append(nn.Sequential(*conv3x3s)) + fuse_layers.append(nn.ModuleList(fuse_layer)) + + return nn.ModuleList(fuse_layers) + + def get_num_inchannels(self): + return self.num_inchannels + + def forward(self, x): + if self.num_branches == 1: + return [self.branches[0](x[0])] + + for i in range(self.num_branches): + x[i] = self.branches[i](x[i]) + + x_fuse = [] + for i in range(len(self.fuse_layers)): + y = x[0] if i == 0 else self.fuse_layers[i][0](x[0]) + for j in range(1, self.num_branches): + if i == j: + y = y + x[j] + elif j > i: + width_output = x[i].shape[-1] + height_output = x[i].shape[-2] + y = y + F.interpolate( + self.fuse_layers[i][j](x[j]), + size=[height_output, width_output], + mode='bilinear', align_corners=self.align_corners) + else: + y = y + self.fuse_layers[i][j](x[j]) + x_fuse.append(self.relu(y)) + + return x_fuse + + +class HighResolutionNet(nn.Module): + def __init__(self, width, num_classes, ocr_width=256, small=False, + norm_layer=nn.BatchNorm2d, align_corners=True): + super(HighResolutionNet, self).__init__() + self.norm_layer = norm_layer + self.width = width + self.ocr_width = ocr_width + self.align_corners = align_corners + + self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = norm_layer(64) + self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False) + self.bn2 = norm_layer(64) + self.relu = nn.ReLU(inplace=relu_inplace) + + num_blocks = 2 if small else 4 + + stage1_num_channels = 64 + self.layer1 = self._make_layer(BottleneckV1b, 64, stage1_num_channels, blocks=num_blocks) + stage1_out_channel = BottleneckV1b.expansion * stage1_num_channels + + self.stage2_num_branches = 2 + num_channels = [width, 2 * width] + num_inchannels = [ + num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))] + self.transition1 = self._make_transition_layer( + [stage1_out_channel], num_inchannels) + self.stage2, pre_stage_channels = self._make_stage( + BasicBlockV1b, num_inchannels=num_inchannels, num_modules=1, num_branches=self.stage2_num_branches, + num_blocks=2 * [num_blocks], num_channels=num_channels) + + self.stage3_num_branches = 3 + num_channels = [width, 2 * width, 4 * width] + num_inchannels = [ + num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))] + self.transition2 = self._make_transition_layer( + pre_stage_channels, num_inchannels) + self.stage3, pre_stage_channels = self._make_stage( + BasicBlockV1b, num_inchannels=num_inchannels, + num_modules=3 if small else 4, num_branches=self.stage3_num_branches, + num_blocks=3 * [num_blocks], num_channels=num_channels) + + self.stage4_num_branches = 4 + num_channels = [width, 2 * width, 4 * width, 8 * width] + num_inchannels = [ + num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))] + 
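The fuse layers above exchange information between HRNet branches: a lower-resolution branch is channel-matched with a 1x1 conv and bilinearly upsampled before being summed into a higher-resolution branch (the opposite direction uses stride-2 3x3 convs). A minimal sketch of the upsampling direction with two toy branches (illustrative shapes only):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

hi = torch.randn(1, 18, 64, 64)                        # high-resolution branch
lo = torch.randn(1, 36, 32, 32)                        # low-resolution branch
match = nn.Conv2d(36, 18, kernel_size=1, bias=False)   # channel matching
fused = hi + F.interpolate(match(lo), size=hi.shape[-2:],
                           mode='bilinear', align_corners=True)
print(fused.shape)  # torch.Size([1, 18, 64, 64])
```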
self.transition3 = self._make_transition_layer( + pre_stage_channels, num_inchannels) + self.stage4, pre_stage_channels = self._make_stage( + BasicBlockV1b, num_inchannels=num_inchannels, num_modules=2 if small else 3, + num_branches=self.stage4_num_branches, + num_blocks=4 * [num_blocks], num_channels=num_channels) + + last_inp_channels = np.int(np.sum(pre_stage_channels)) + if self.ocr_width > 0: + ocr_mid_channels = 2 * self.ocr_width + ocr_key_channels = self.ocr_width + + self.conv3x3_ocr = nn.Sequential( + nn.Conv2d(last_inp_channels, ocr_mid_channels, + kernel_size=3, stride=1, padding=1), + norm_layer(ocr_mid_channels), + nn.ReLU(inplace=relu_inplace), + ) + self.ocr_gather_head = SpatialGather_Module(num_classes) + + self.ocr_distri_head = SpatialOCR_Module(in_channels=ocr_mid_channels, + key_channels=ocr_key_channels, + out_channels=ocr_mid_channels, + scale=1, + dropout=0.05, + norm_layer=norm_layer, + align_corners=align_corners) + self.cls_head = nn.Conv2d( + ocr_mid_channels, num_classes, kernel_size=1, stride=1, padding=0, bias=True) + + self.aux_head = nn.Sequential( + nn.Conv2d(last_inp_channels, last_inp_channels, + kernel_size=1, stride=1, padding=0), + norm_layer(last_inp_channels), + nn.ReLU(inplace=relu_inplace), + nn.Conv2d(last_inp_channels, num_classes, + kernel_size=1, stride=1, padding=0, bias=True) + ) + else: + self.cls_head = nn.Sequential( + nn.Conv2d(last_inp_channels, last_inp_channels, + kernel_size=3, stride=1, padding=1), + norm_layer(last_inp_channels), + nn.ReLU(inplace=relu_inplace), + nn.Conv2d(last_inp_channels, num_classes, + kernel_size=1, stride=1, padding=0, bias=True) + ) + + def _make_transition_layer( + self, num_channels_pre_layer, num_channels_cur_layer): + num_branches_cur = len(num_channels_cur_layer) + num_branches_pre = len(num_channels_pre_layer) + + transition_layers = [] + for i in range(num_branches_cur): + if i < num_branches_pre: + if num_channels_cur_layer[i] != num_channels_pre_layer[i]: + transition_layers.append(nn.Sequential( + nn.Conv2d(num_channels_pre_layer[i], + num_channels_cur_layer[i], + kernel_size=3, + stride=1, + padding=1, + bias=False), + self.norm_layer(num_channels_cur_layer[i]), + nn.ReLU(inplace=relu_inplace))) + else: + transition_layers.append(None) + else: + conv3x3s = [] + for j in range(i + 1 - num_branches_pre): + inchannels = num_channels_pre_layer[-1] + outchannels = num_channels_cur_layer[i] \ + if j == i - num_branches_pre else inchannels + conv3x3s.append(nn.Sequential( + nn.Conv2d(inchannels, outchannels, + kernel_size=3, stride=2, padding=1, bias=False), + self.norm_layer(outchannels), + nn.ReLU(inplace=relu_inplace))) + transition_layers.append(nn.Sequential(*conv3x3s)) + + return nn.ModuleList(transition_layers) + + def _make_layer(self, block, inplanes, planes, blocks, stride=1): + downsample = None + if stride != 1 or inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + self.norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append(block(inplanes, planes, stride, + downsample=downsample, norm_layer=self.norm_layer)) + inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(inplanes, planes, norm_layer=self.norm_layer)) + + return nn.Sequential(*layers) + + def _make_stage(self, block, num_inchannels, + num_modules, num_branches, num_blocks, num_channels, + fuse_method='SUM', + multi_scale_output=True): + modules = [] + for i in 
range(num_modules): + # multi_scale_output is only used last module + if not multi_scale_output and i == num_modules - 1: + reset_multi_scale_output = False + else: + reset_multi_scale_output = True + modules.append( + HighResolutionModule(num_branches, + block, + num_blocks, + num_inchannels, + num_channels, + fuse_method, + reset_multi_scale_output, + norm_layer=self.norm_layer, + align_corners=self.align_corners) + ) + num_inchannels = modules[-1].get_num_inchannels() + + return nn.Sequential(*modules), num_inchannels + + def forward(self, x, additional_features=None): + feats = self.compute_hrnet_feats(x, additional_features) + if self.ocr_width > 0: + out_aux = self.aux_head(feats) + feats = self.conv3x3_ocr(feats) + + context = self.ocr_gather_head(feats, out_aux) + feats = self.ocr_distri_head(feats, context) + out = self.cls_head(feats) + return [out, out_aux] + else: + return [self.cls_head(feats), None] + + def compute_hrnet_feats(self, x, additional_features): + x = self.compute_pre_stage_features(x, additional_features) + x = self.layer1(x) + + x_list = [] + for i in range(self.stage2_num_branches): + if self.transition1[i] is not None: + x_list.append(self.transition1[i](x)) + else: + x_list.append(x) + y_list = self.stage2(x_list) + + x_list = [] + for i in range(self.stage3_num_branches): + if self.transition2[i] is not None: + if i < self.stage2_num_branches: + x_list.append(self.transition2[i](y_list[i])) + else: + x_list.append(self.transition2[i](y_list[-1])) + else: + x_list.append(y_list[i]) + y_list = self.stage3(x_list) + + x_list = [] + for i in range(self.stage4_num_branches): + if self.transition3[i] is not None: + if i < self.stage3_num_branches: + x_list.append(self.transition3[i](y_list[i])) + else: + x_list.append(self.transition3[i](y_list[-1])) + else: + x_list.append(y_list[i]) + x = self.stage4(x_list) + + return self.aggregate_hrnet_features(x) + + def compute_pre_stage_features(self, x, additional_features): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + if additional_features is not None: + x = x + additional_features + x = self.conv2(x) + x = self.bn2(x) + return self.relu(x) + + def aggregate_hrnet_features(self, x): + # Upsampling + x0_h, x0_w = x[0].size(2), x[0].size(3) + x1 = F.interpolate(x[1], size=(x0_h, x0_w), + mode='bilinear', align_corners=self.align_corners) + x2 = F.interpolate(x[2], size=(x0_h, x0_w), + mode='bilinear', align_corners=self.align_corners) + x3 = F.interpolate(x[3], size=(x0_h, x0_w), + mode='bilinear', align_corners=self.align_corners) + + return torch.cat([x[0], x1, x2, x3], 1) + + def load_pretrained_weights(self, pretrained_path=''): + model_dict = self.state_dict() + + if not os.path.exists(pretrained_path): + print(f'\nFile "{pretrained_path}" does not exist.') + print('You need to specify the correct path to the pre-trained weights.\n' + 'You can download the weights for HRNet from the repository:\n' + 'https://github.com/HRNet/HRNet-Image-Classification') + exit(1) + pretrained_dict = torch.load(pretrained_path, map_location={'cuda:0': 'cpu'}) + pretrained_dict = {k.replace('last_layer', 'aux_head').replace('model.', ''): v for k, v in + pretrained_dict.items()} + + pretrained_dict = {k: v for k, v in pretrained_dict.items() + if k in model_dict.keys()} + + model_dict.update(pretrained_dict) + self.load_state_dict(model_dict) diff --git a/isegm/model/modeling/models_vit.py b/isegm/model/modeling/models_vit.py new file mode 100644 index 
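load_pretrained_weights above loads an ImageNet HRNet checkpoint by renaming keys and keeping only those present in the current model. A generic sketch of that partial-loading pattern (assumptions: the 'model.' prefix strip is illustrative, and the shape check is an extra safeguard not in the original loader):

```python
import torch

def load_partial(model, ckpt_path):
    state = model.state_dict()
    pretrained = torch.load(ckpt_path, map_location='cpu')
    # illustrative key rename; the HRNet loader also maps 'last_layer' -> 'aux_head'
    pretrained = {k.replace('model.', ''): v for k, v in pretrained.items()}
    pretrained = {k: v for k, v in pretrained.items()
                  if k in state and v.shape == state[k].shape}
    state.update(pretrained)
    model.load_state_dict(state)
    return model
```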
0000000000000000000000000000000000000000..1c1a9723ea7412927d8bd95a40334d5dcddfec25 --- /dev/null +++ b/isegm/model/modeling/models_vit.py @@ -0,0 +1,322 @@ +import torch +import torch.nn as nn + +from functools import partial +from collections import OrderedDict +from .pos_embed import interpolate_pos_embed + + +class Mlp(nn.Module): + """ MLP as used in Vision Transformer, MLP-Mixer and related networks + """ + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() if act_layer else nn.GELU() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + ''' Multi-head self-attention ''' + def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + qkv = qkv.permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1,2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., mlp_drop=0., qkv_bias=False, attn_drop=0., + proj_drop=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.norm2 = norm_layer(dim) + + self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, + proj_drop=proj_drop) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=mlp_drop) + + def forward(self, x): + x = x + self.attn(self.norm1(x)) + x = x + self.mlp(self.norm2(x)) + return x + + +class PatchEmbed(nn.Module): + """ 2D Image to Patch Embedding + """ + def __init__(self, img_size=(224,224), patch_size=(16,16), in_chans=3, embed_dim=768, + norm_layer=None, flatten=True): + super().__init__() + self.in_chans = in_chans + self.img_size = img_size + self.patch_size = patch_size + self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.flatten = flatten + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x): + # B, C, H, W = x.shape + # assert H % self.img_size[0] == 0 and W == self.img_size[1], \ + # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
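PatchEmbed above turns an image into a token sequence with a single strided convolution: a 16x16 kernel with stride 16 maps each non-overlapping patch to one embedding vector. A shape check with toy sizes:

```python
import torch
import torch.nn as nn

proj = nn.Conv2d(3, 768, kernel_size=16, stride=16)
img = torch.randn(2, 3, 224, 224)
tokens = proj(img).flatten(2).transpose(1, 2)   # BCHW -> BNC
print(tokens.shape)  # torch.Size([2, 196, 768]), N = (224 // 16) ** 2
```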
+ # assert C == self.in_chans, \ + # f"Input image chanel ({C}) doesn't match model ({self.in_chans})" + x = self.proj(x) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCHW -> BNC + x = self.norm(x) + return x + + +class VisionTransformer(nn.Module): + """ Vision Transformer with support for global average pooling + """ + def __init__(self, img_size=(224,224), patch_size=(16, 16), in_chans=3, num_classes=1000, embed_dim=768, + depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, pos_drop_rate=0., attn_drop_rate=0., + proj_drop_rate=0., norm_layer=None, act_layer=None, cls_feature_dim=None, global_pool=False, enable_gra=False): + super().__init__() + self.global_pool = global_pool + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim + + self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, + embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) + self.pos_drop = nn.Dropout(p=pos_drop_rate) + + norm_layer = norm_layer if norm_layer else partial(nn.LayerNorm, eps=1e-6) + self.blocks = nn.Sequential(*[ + Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, + attn_drop=attn_drop_rate, proj_drop=proj_drop_rate, norm_layer=norm_layer, + act_layer=act_layer) + for _ in range(depth)]) + + self.fc_norm = norm_layer(embed_dim) + + self.enable_gra = enable_gra + if self.enable_gra: + self.gra_embed = nn.Embedding(10, embed_dim) + + # feature representation for classification + if cls_feature_dim: + self.num_features = cls_feature_dim + self.pre_logits = nn.Sequential(OrderedDict([ + ('fc', nn.Linear(embed_dim, cls_feature_dim)), + ('act', nn.Tanh()) + ])) + else: + self.pre_logits = nn.Identity() + + # classification head(s) + self.head = nn.Linear(self.num_features, num_classes) + + self.init_weights() + + def init_weights_from_pretrained(self, pretrained_path): + if pretrained_path: + checkpoint = torch.load(pretrained_path, map_location='cpu') + print("Load pre-trained checkpoint from: %s" % pretrained_path) + checkpoint_model = checkpoint['model'] + + # interpolate position embedding + interpolate_pos_embed(self, checkpoint_model) + + # load pre-trained model + msg = self.load_state_dict(checkpoint_model, strict=False) + print(msg) + + def init_weights(self): + # initialize patch_embed like nn.Linear (instead of nn.Conv2d) + w = self.patch_embed.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + + # timm's trunc_normal_(std=.02) is effectively similar to normal_(std=0.02) + # as the default cutoff in trunc_normal_(std=.02) is too big (-2., 2.) 
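When enable_gra is set, the model above adds a granularity embedding chosen from 10 learned slots; forward_backbone (further below) buckets a continuous granularity value in (0, 1] into an index with clamp(gra * 10 - 1, 0, 9). A quick check of that mapping on toy values:

```python
import torch

gra = torch.tensor([0.05, 0.1, 0.55, 1.0])       # granularity values in (0, 1]
idx = torch.clamp(gra * 10 - 1, 0, 9).long()     # bucket index into gra_embed
print(idx)  # tensor([0, 0, 4, 9])
```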
+ nn.init.normal_(self.cls_token, std=.02) + nn.init.normal_(self.pos_embed, std=.02) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + # we use xavier_uniform following official JAX ViT: + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def no_weight_decay(self): + return {'pos_embed', 'cls_token', 'dist_token'} + + def shuffle(self, x): + """ + in: x (B, N, C) + out: x_shuffle (B, N, C), ids_restore (B, N) + """ + B, N, C = x.shape + noise = torch.rand(B, N, device=x.device) + ids_shuffle = torch.argsort(noise, dim=1) + ids_restore = torch.argsort(ids_shuffle, dim=1) + + x_shuffle = torch.gather(x, 1, index=ids_shuffle.unsqueeze(-1).repeat(1, 1, C)) + + return x_shuffle, ids_restore + + def unshuffle(self, x, ids_restore): + B, N, C = x.shape + x_unshuffle = torch.gather(x, 1, index=ids_restore.unsqueeze(-1).repeat(1, 1, C)) + + return x_unshuffle + + def split(self, x): + B, N, C = x.shape + num_tokens_per_split = 224 * 224 + num_splits = max(1, N // num_tokens_per_split) + out = [] + for i in range(num_splits): + if i == num_splits - 1: + out.append(x[:, i*num_tokens_per_split:]) + return out + out.append(x[:, i*num_tokens_per_split:(i+1)*num_tokens_per_split]) + + # window split for finetuning on larger size (the pretraining size should be 224 x 224) + def patchify(self, x): + """ + in: (B, N, C) + out: (B*win_w*win_h, N//(win_w*win_h), C) + """ + B, N, C = x.shape + grid_h, grid_w = self.patch_embed.grid_size + win_h_grid = 224 // self.patch_embed.patch_size[0] + win_w_grid = 224 // self.patch_embed.patch_size[1] + win_h, win_w = grid_h // win_h_grid, grid_w // win_w_grid + x = x.view(B, win_h, grid_h // win_h, win_w, grid_w // win_w, C) + x_patchified = x.permute((0, 1, 3, 2, 4, 5)).contiguous() + x_patchified = x_patchified.view(B * win_h * win_w, grid_h * grid_w // (win_h * win_w), C) + + return x_patchified + + # recover the window split + def unpatchify(self, x): + """ + in: (B*win_h*win_w, N//(win_h*win_w), C) + out: (B, N, C) + """ + B, N, C = x.shape + grid_h, grid_w = self.patch_embed.grid_size + win_h_grid = 224 // self.patch_embed.patch_size[0] + win_w_grid = 224 // self.patch_embed.patch_size[1] + win_h, win_w = grid_h // win_h_grid, grid_w // win_w_grid + x = x.view(B // (win_h * win_w), win_h, win_w, grid_h // win_h, grid_w // win_w, C) + x = x.permute((0, 1, 3, 2, 4, 5)).contiguous().view(B // (win_h * win_w), win_h * win_w * N, C) + + return x + + def forward_backbone(self, x, additional_features=None, gra=None, shuffle=False): + x = self.patch_embed(x) + if additional_features is not None: + x += additional_features + + if self.enable_gra and gra is not None: + gra_idx = torch.clamp(gra * 10 - 1, 0, 9).long() + x += self.gra_embed(gra_idx).repeat(1, x.shape[1], 1) + + x = self.pos_drop(x + self.pos_embed[:, 1:]) + num_blocks = len(self.blocks) + assert num_blocks % 4 == 0 + + if shuffle: + for i in range(1, num_blocks + 1): + x, ids_restore = self.shuffle(x) + x_split = self.split(x) + x_split = [self.blocks[i-1](x_split[j]) for j in range(len(x_split))] + x = torch.cat(x_split, dim=1) + x = self.unshuffle(x, ids_restore) + else: + num_blocks_per_group = 6 if num_blocks == 12 else num_blocks // 4 + is_patchified = False + for i in range(1, num_blocks + 1): + if i % num_blocks_per_group: + if not is_patchified: + x = self.patchify(x) + is_patchified = True + else: + pass # 
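shuffle and unshuffle above permute the token sequence with a random permutation derived from argsort over noise; applying argsort again to that permutation yields the indices that restore the original order. A round-trip check on a toy tensor:

```python
import torch

x = torch.arange(12.).reshape(1, 6, 2)           # (B, N, C) token sequence
noise = torch.rand(1, 6)
ids_shuffle = torch.argsort(noise, dim=1)        # random permutation
ids_restore = torch.argsort(ids_shuffle, dim=1)  # its inverse

idx = ids_shuffle.unsqueeze(-1).repeat(1, 1, 2)
shuffled = torch.gather(x, 1, idx)
restored = torch.gather(shuffled, 1, ids_restore.unsqueeze(-1).repeat(1, 1, 2))
print(torch.equal(restored, x))  # True
```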
do nothing + else: + x = self.unpatchify(x) + is_patchified = False + x = self.blocks[i-1](x) + return x + + def forward(self, x): + x = self.patch_embed(x) + cls_token = self.cls_token.expand(x.shape[0], -1, -1) + x = torch.cat((cls_token, x), dim=1) + x = self.pos_drop(x + self.pos_embed) + x = self.blocks(x) + + if self.global_pool: + x = x[:, 1:].mean(dim=1) # global pool without cls token + x = self.fc_norm(x) + else: + x = self.fc_norm(x) + x = x[:, 0] + x = self.pre_logits(x) + x = self.head(x) + return x + +def vit_base_patch16(**kwargs): + model = VisionTransformer( + patch_size=(16, 16), embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, **kwargs) + return model + +def vit_large_patch16(**kwargs): + model = VisionTransformer( + patch_size=(16, 16), embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, **kwargs) + return model + +def vit_huge_patch14(**kwargs): + model = VisionTransformer( + patch_size=(14,14), embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, qkv_bias=True, **kwargs) + return model \ No newline at end of file diff --git a/isegm/model/modeling/models_vit_lora.py b/isegm/model/modeling/models_vit_lora.py new file mode 100644 index 0000000000000000000000000000000000000000..89a464013027922657cf4792f48952cec2d9388c --- /dev/null +++ b/isegm/model/modeling/models_vit_lora.py @@ -0,0 +1,333 @@ +import torch +import torch.nn as nn + +from functools import partial +from collections import OrderedDict +from .pos_embed import interpolate_pos_embed +import loralib as lora + + +class Mlp(nn.Module): + """ MLP as used in Vision Transformer, MLP-Mixer and related networks + """ + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() if act_layer else nn.GELU() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + ''' Multi-head self-attention ''' + def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + # self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.qkv = lora.MergedLinear(dim, dim * 3, r=8, enable_lora=[True, False, True], bias=qkv_bias) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + qkv = qkv.permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1,2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., mlp_drop=0., qkv_bias=False, attn_drop=0., + proj_drop=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.norm2 = norm_layer(dim) + + self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, + proj_drop=proj_drop) + mlp_hidden_dim = int(dim * 
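The LoRA variant of the attention block above swaps the qkv projection for loralib's MergedLinear (rank 8, adapting only the q and v groups via enable_lora). A plain-PyTorch sketch of the underlying low-rank idea (an illustration, not loralib's implementation): the frozen base weight is augmented with a trainable update B @ A, so only about 2 * r * dim parameters per adapted projection are trained.

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    def __init__(self, dim, out_dim, r=8):
        super().__init__()
        self.base = nn.Linear(dim, out_dim)
        self.base.weight.requires_grad_(False)          # frozen pre-trained weight
        self.A = nn.Parameter(torch.randn(r, dim) * 0.01)
        self.B = nn.Parameter(torch.zeros(out_dim, r))  # zero init: no change at start

    def forward(self, x):
        return self.base(x) + x @ (self.B @ self.A).t()

layer = LoRALinear(768, 768 * 3)
print(layer(torch.randn(2, 196, 768)).shape)  # torch.Size([2, 196, 2304])
```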
mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=mlp_drop) + + def forward(self, x): + x = x + self.attn(self.norm1(x)) + x = x + self.mlp(self.norm2(x)) + return x + + +class PatchEmbed(nn.Module): + """ 2D Image to Patch Embedding + """ + def __init__(self, img_size=(224,224), patch_size=(16,16), in_chans=3, embed_dim=768, + norm_layer=None, flatten=True): + super().__init__() + self.in_chans = in_chans + self.img_size = img_size + self.patch_size = patch_size + self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.flatten = flatten + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x): + # B, C, H, W = x.shape + # assert H % self.img_size[0] == 0 and W == self.img_size[1], \ + # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + # assert C == self.in_chans, \ + # f"Input image chanel ({C}) doesn't match model ({self.in_chans})" + x = self.proj(x) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCHW -> BNC + x = self.norm(x) + return x + + +class VisionTransformer_lora(nn.Module): + """ Vision Transformer with support for global average pooling + """ + def __init__(self, img_size=(224,224), patch_size=(16, 16), in_chans=3, num_classes=1000, embed_dim=768, + depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, pos_drop_rate=0., attn_drop_rate=0., + proj_drop_rate=0., norm_layer=None, act_layer=None, cls_feature_dim=None, global_pool=False, enable_gra=False): + super().__init__() + self.global_pool = global_pool + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim + + self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, + embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) # learnable positional embedding + self.pos_drop = nn.Dropout(p=pos_drop_rate) + + norm_layer = norm_layer if norm_layer else partial(nn.LayerNorm, eps=1e-6) + self.blocks = nn.Sequential(*[ + Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, + attn_drop=attn_drop_rate, proj_drop=proj_drop_rate, norm_layer=norm_layer, + act_layer=act_layer) + for _ in range(depth)]) + + self.fc_norm = norm_layer(embed_dim) + + self.enable_gra = enable_gra + if self.enable_gra: + self.gra_embed = nn.Embedding(10, embed_dim) + + # feature representation for classification + if cls_feature_dim: + self.num_features = cls_feature_dim + self.pre_logits = nn.Sequential(OrderedDict([ + ('fc', nn.Linear(embed_dim, cls_feature_dim)), + ('act', nn.Tanh()) + ])) + else: + self.pre_logits = nn.Identity() + + # classification head(s) + self.head = nn.Linear(self.num_features, num_classes) + + self.init_weights() + + def init_weights_from_pretrained(self, pretrained_path): + if pretrained_path: + checkpoint = torch.load(pretrained_path, map_location='cpu') + print("Load pre-trained checkpoint from: %s" % pretrained_path) + checkpoint_model = checkpoint['model'] + + # interpolate position embedding + interpolate_pos_embed(self, checkpoint_model) + + # load pre-trained model + msg = self.load_state_dict(checkpoint_model, strict=False) + print(msg) + + + def init_weights(self): + # 
initialize patch_embed like nn.Linear (instead of nn.Conv2d) + w = self.patch_embed.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + + # timm's trunc_normal_(std=.02) is effectively similar to normal_(std=0.02) + # as the default cutoff in trunc_normal_(std=.02) is too big (-2., 2.) + nn.init.normal_(self.cls_token, std=.02) + nn.init.normal_(self.pos_embed, std=.02) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + # we use xavier_uniform following official JAX ViT: + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def no_weight_decay(self): + return {'pos_embed', 'cls_token', 'dist_token'} + + def shuffle(self, x): + """ + in: x (B, N, C) + out: x_shuffle (B, N, C), ids_restore (B, N) + """ + B, N, C = x.shape + noise = torch.rand(B, N, device=x.device) + ids_shuffle = torch.argsort(noise, dim=1) + ids_restore = torch.argsort(ids_shuffle, dim=1) + + x_shuffle = torch.gather(x, 1, index=ids_shuffle.unsqueeze(-1).repeat(1, 1, C)) + + return x_shuffle, ids_restore + + def unshuffle(self, x, ids_restore): + B, N, C = x.shape + x_unshuffle = torch.gather(x, 1, index=ids_restore.unsqueeze(-1).repeat(1, 1, C)) + + return x_unshuffle + + def split(self, x): + B, N, C = x.shape + num_tokens_per_split = 224 * 224 + num_splits = max(1, N // num_tokens_per_split) + out = [] + for i in range(num_splits): + if i == num_splits - 1: + out.append(x[:, i*num_tokens_per_split:]) + return out + out.append(x[:, i*num_tokens_per_split:(i+1)*num_tokens_per_split]) + + # window split for finetuning on larger size (the pretraining size should be 224 x 224) + def patchify(self, x): + """ + in: (B, N, C) + out: (B*win_w*win_h, N//(win_w*win_h), C) + """ + B, N, C = x.shape + grid_h, grid_w = self.patch_embed.grid_size + win_h_grid = 224 // self.patch_embed.patch_size[0] + win_w_grid = 224 // self.patch_embed.patch_size[1] + win_h, win_w = grid_h // win_h_grid, grid_w // win_w_grid + x = x.view(B, win_h, grid_h // win_h, win_w, grid_w // win_w, C) + x_patchified = x.permute((0, 1, 3, 2, 4, 5)).contiguous() + x_patchified = x_patchified.view(B * win_h * win_w, grid_h * grid_w // (win_h * win_w), C) + + return x_patchified + + # recover the window split + def unpatchify(self, x): + """ + in: (B*win_h*win_w, N//(win_h*win_w), C) + out: (B, N, C) + """ + B, N, C = x.shape + grid_h, grid_w = self.patch_embed.grid_size + win_h_grid = 224 // self.patch_embed.patch_size[0] + win_w_grid = 224 // self.patch_embed.patch_size[1] + win_h, win_w = grid_h // win_h_grid, grid_w // win_w_grid + x = x.view(B // (win_h * win_w), win_h, win_w, grid_h // win_h, grid_w // win_w, C) + x = x.permute((0, 1, 3, 2, 4, 5)).contiguous().view(B // (win_h * win_w), win_h * win_w * N, C) + + return x + + def forward_backbone(self, x, additional_features=None, gra=None, shuffle=False): + x = self.patch_embed(x) + if additional_features is not None: + x += additional_features + + if self.enable_gra and gra is not None: + gra_idx = torch.clamp(gra * 10 - 1, 0, 9).long() + x += self.gra_embed(gra_idx).repeat(1, x.shape[1], 1) + + x = self.pos_drop(x + self.pos_embed[:, 1:]) + num_blocks = len(self.blocks) + assert num_blocks % 4 == 0 + + if shuffle: + for i in range(1, num_blocks + 1): + x, ids_restore = self.shuffle(x) + x_split = self.split(x) + x_split = [self.blocks[i-1](x_split[j]) for j in range(len(x_split))] + x = 
torch.cat(x_split, dim=1) + x = self.unshuffle(x, ids_restore) + else: + num_blocks_per_group = 6 if num_blocks == 12 else num_blocks // 4 + is_patchified = False + for i in range(1, num_blocks + 1): + if i % num_blocks_per_group: + if not is_patchified: + x = self.patchify(x) + is_patchified = True + else: + pass # do nothing + else: + x = self.unpatchify(x) + is_patchified = False + x = self.blocks[i-1](x) + return x + + def forward(self, x): + x = self.patch_embed(x) + cls_token = self.cls_token.expand(x.shape[0], -1, -1) + x = torch.cat((cls_token, x), dim=1) + x = self.pos_drop(x + self.pos_embed) + x = self.blocks(x) + + if self.global_pool: + x = x[:, 1:].mean(dim=1) # global pool without cls token + x = self.fc_norm(x) + else: + x = self.fc_norm(x) + x = x[:, 0] + x = self.pre_logits(x) + x = self.head(x) + return x + + +def vit_tiny_patch16(**kwargs): + model = VisionTransformer( + patch_size=(16, 16), embed_dim=160, depth=8, num_heads=4, mlp_ratio=4, qkv_bias=True, **kwargs) + return model + + +def vit_base_patch16(**kwargs): + model = VisionTransformer( + patch_size=(16, 16), embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, **kwargs) + return model + +def vit_large_patch16(**kwargs): + model = VisionTransformer( + patch_size=(16, 16), embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, **kwargs) + return model + +def vit_huge_patch14(**kwargs): + model = VisionTransformer( + patch_size=(14,14), embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, qkv_bias=True, **kwargs) + return model \ No newline at end of file diff --git a/isegm/model/modeling/ocr.py b/isegm/model/modeling/ocr.py new file mode 100644 index 0000000000000000000000000000000000000000..df3b4f67959fc6a088b93ee7a34b15c1e07402df --- /dev/null +++ b/isegm/model/modeling/ocr.py @@ -0,0 +1,141 @@ +import torch +import torch.nn as nn +import torch._utils +import torch.nn.functional as F + + +class SpatialGather_Module(nn.Module): + """ + Aggregate the context features according to the initial + predicted probability distribution. + Employ the soft-weighted method to aggregate the context. + """ + + def __init__(self, cls_num=0, scale=1): + super(SpatialGather_Module, self).__init__() + self.cls_num = cls_num + self.scale = scale + + def forward(self, feats, probs): + batch_size, c, h, w = probs.size(0), probs.size(1), probs.size(2), probs.size(3) + probs = probs.view(batch_size, c, -1) + feats = feats.view(batch_size, feats.size(1), -1) + feats = feats.permute(0, 2, 1) # batch x hw x c + probs = F.softmax(self.scale * probs, dim=2) # batch x k x hw + ocr_context = torch.matmul(probs, feats) \ + .permute(0, 2, 1).unsqueeze(3) # batch x k x c + return ocr_context + + +class SpatialOCR_Module(nn.Module): + """ + Implementation of the OCR module: + We aggregate the global object representation to update the representation for each pixel. 
+ """ + + def __init__(self, + in_channels, + key_channels, + out_channels, + scale=1, + dropout=0.1, + norm_layer=nn.BatchNorm2d, + align_corners=True): + super(SpatialOCR_Module, self).__init__() + self.object_context_block = ObjectAttentionBlock2D(in_channels, key_channels, scale, + norm_layer, align_corners) + _in_channels = 2 * in_channels + + self.conv_bn_dropout = nn.Sequential( + nn.Conv2d(_in_channels, out_channels, kernel_size=1, padding=0, bias=False), + nn.Sequential(norm_layer(out_channels), nn.ReLU(inplace=True)), + nn.Dropout2d(dropout) + ) + + def forward(self, feats, proxy_feats): + context = self.object_context_block(feats, proxy_feats) + + output = self.conv_bn_dropout(torch.cat([context, feats], 1)) + + return output + + +class ObjectAttentionBlock2D(nn.Module): + ''' + The basic implementation for object context block + Input: + N X C X H X W + Parameters: + in_channels : the dimension of the input feature map + key_channels : the dimension after the key/query transform + scale : choose the scale to downsample the input feature maps (save memory cost) + bn_type : specify the bn type + Return: + N X C X H X W + ''' + + def __init__(self, + in_channels, + key_channels, + scale=1, + norm_layer=nn.BatchNorm2d, + align_corners=True): + super(ObjectAttentionBlock2D, self).__init__() + self.scale = scale + self.in_channels = in_channels + self.key_channels = key_channels + self.align_corners = align_corners + + self.pool = nn.MaxPool2d(kernel_size=(scale, scale)) + self.f_pixel = nn.Sequential( + nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels, + kernel_size=1, stride=1, padding=0, bias=False), + nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)), + nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels, + kernel_size=1, stride=1, padding=0, bias=False), + nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)) + ) + self.f_object = nn.Sequential( + nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels, + kernel_size=1, stride=1, padding=0, bias=False), + nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)), + nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels, + kernel_size=1, stride=1, padding=0, bias=False), + nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)) + ) + self.f_down = nn.Sequential( + nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels, + kernel_size=1, stride=1, padding=0, bias=False), + nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)) + ) + self.f_up = nn.Sequential( + nn.Conv2d(in_channels=self.key_channels, out_channels=self.in_channels, + kernel_size=1, stride=1, padding=0, bias=False), + nn.Sequential(norm_layer(self.in_channels), nn.ReLU(inplace=True)) + ) + + def forward(self, x, proxy): + batch_size, h, w = x.size(0), x.size(2), x.size(3) + if self.scale > 1: + x = self.pool(x) + + query = self.f_pixel(x).view(batch_size, self.key_channels, -1) + query = query.permute(0, 2, 1) + key = self.f_object(proxy).view(batch_size, self.key_channels, -1) + value = self.f_down(proxy).view(batch_size, self.key_channels, -1) + value = value.permute(0, 2, 1) + + sim_map = torch.matmul(query, key) + sim_map = (self.key_channels ** -.5) * sim_map + sim_map = F.softmax(sim_map, dim=-1) + + # add bg context ... 
+ context = torch.matmul(sim_map, value) + context = context.permute(0, 2, 1).contiguous() + context = context.view(batch_size, self.key_channels, *x.size()[2:]) + context = self.f_up(context) + if self.scale > 1: + context = F.interpolate(input=context, size=(h, w), + mode='bilinear', align_corners=self.align_corners) + + return context diff --git a/isegm/model/modeling/pos_embed.py b/isegm/model/modeling/pos_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..9059901686ae1588490cc4070b65b79a58203597 --- /dev/null +++ b/isegm/model/modeling/pos_embed.py @@ -0,0 +1,130 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# Position embedding utils +# -------------------------------------------------------- + +import numpy as np + +import torch + +# -------------------------------------------------------- +# 2D sine-cosine position embedding +# References: +# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py +# MoCo v3: https://github.com/facebookresearch/moco-v3 +# -------------------------------------------------------- +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float) + omega /= embed_dim / 2. + omega = 1. 
/ 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +# -------------------------------------------------------- +# Interpolate position embeddings for high-resolution +# References: +# DeiT: https://github.com/facebookresearch/deit +# -------------------------------------------------------- +def interpolate_pos_embed(model, checkpoint_model): + if 'pos_embed' in checkpoint_model: + pos_embed_checkpoint = checkpoint_model['pos_embed'] + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = model.patch_embed.num_patches + num_extra_tokens = model.pos_embed.shape[-2] - num_patches + # height (== width) for the checkpoint position embedding + orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) + # height (== width) for the new position embedding + new_size = int(num_patches ** 0.5) + # class_token and dist_token are kept unchanged + if orig_size != new_size: + print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + checkpoint_model['pos_embed'] = new_pos_embed + + +def interpolate_pos_embed_inference(model, infer_img_size, device): + pos_embed = model.pos_embed + embedding_size = pos_embed.shape[-1] + + patch_embed = model.patch_embed + + num_patches = patch_embed.num_patches + num_extra_tokens = pos_embed.shape[-2] - num_patches + grid_size = patch_embed.grid_size + + patch_size = patch_embed.patch_size + infer_grid_size = (infer_img_size[0] // patch_size[0], \ + infer_img_size[1] // patch_size[1]) + + orig_size, new_size = grid_size, infer_grid_size + if orig_size != new_size: + # print("Position interpolate from %dx%d to %dx%d" % (orig_size[0], orig_size[1], + # new_size[0], new_size[1])) + extra_tokens = pos_embed[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape(-1, orig_size[0], orig_size[1], embedding_size).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, size=new_size, mode='bicubic', align_corners=False) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + new_pos_embed = torch.nn.Parameter(new_pos_embed).to(device) + + model.pos_embed = new_pos_embed + model.patch_embed.grid_size = infer_grid_size + + diff --git a/isegm/model/modeling/resnet.py b/isegm/model/modeling/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..65fe949cef0035ba691ee319b25a0132d8ad37fe --- /dev/null +++ b/isegm/model/modeling/resnet.py @@ -0,0 +1,43 @@ +import torch +from .resnetv1b import resnet34_v1b, resnet50_v1s, resnet101_v1s, resnet152_v1s + + +class ResNetBackbone(torch.nn.Module): + def __init__(self, backbone='resnet50', pretrained_base=True, dilated=True, **kwargs): + 
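interpolate_pos_embed_inference above resizes the learned position embedding when the inference grid differs from the training grid: the position tokens are reshaped into their 2D grid, bicubically interpolated, and flattened back. The core of that operation on toy sizes (cls-token handling omitted):

```python
import torch
import torch.nn.functional as F

pos = torch.randn(1, 14 * 14, 768)                       # trained on a 14x14 grid
grid = pos.reshape(1, 14, 14, 768).permute(0, 3, 1, 2)   # (1, C, 14, 14)
grid = F.interpolate(grid, size=(28, 28), mode='bicubic', align_corners=False)
pos_new = grid.permute(0, 2, 3, 1).flatten(1, 2)         # (1, 28*28, C)
print(pos_new.shape)  # torch.Size([1, 784, 768])
```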
super(ResNetBackbone, self).__init__() + + if backbone == 'resnet34': + pretrained = resnet34_v1b(pretrained=pretrained_base, dilated=dilated, **kwargs) + elif backbone == 'resnet50': + pretrained = resnet50_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs) + elif backbone == 'resnet101': + pretrained = resnet101_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs) + elif backbone == 'resnet152': + pretrained = resnet152_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs) + else: + raise RuntimeError(f'unknown backbone: {backbone}') + + self.conv1 = pretrained.conv1 + self.bn1 = pretrained.bn1 + self.relu = pretrained.relu + self.maxpool = pretrained.maxpool + self.layer1 = pretrained.layer1 + self.layer2 = pretrained.layer2 + self.layer3 = pretrained.layer3 + self.layer4 = pretrained.layer4 + + def forward(self, x, additional_features=None): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + if additional_features is not None: + x = x + torch.nn.functional.pad(additional_features, + [0, 0, 0, 0, 0, x.size(1) - additional_features.size(1)], + mode='constant', value=0) + x = self.maxpool(x) + c1 = self.layer1(x) + c2 = self.layer2(c1) + c3 = self.layer3(c2) + c4 = self.layer4(c3) + + return c1, c2, c3, c4 diff --git a/isegm/model/modeling/resnetv1b.py b/isegm/model/modeling/resnetv1b.py new file mode 100644 index 0000000000000000000000000000000000000000..4ad24cef5bde19f2627cfd3f755636f37cfb39ac --- /dev/null +++ b/isegm/model/modeling/resnetv1b.py @@ -0,0 +1,276 @@ +import torch +import torch.nn as nn +GLUON_RESNET_TORCH_HUB = 'rwightman/pytorch-pretrained-gluonresnet' + + +class BasicBlockV1b(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, + previous_dilation=1, norm_layer=nn.BatchNorm2d): + super(BasicBlockV1b, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, + padding=dilation, dilation=dilation, bias=False) + self.bn1 = norm_layer(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, + padding=previous_dilation, dilation=previous_dilation, bias=False) + self.bn2 = norm_layer(planes) + + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out = out + residual + out = self.relu(out) + + return out + + +class BottleneckV1b(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, + previous_dilation=1, norm_layer=nn.BatchNorm2d): + super(BottleneckV1b, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = norm_layer(planes) + + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, + padding=dilation, dilation=dilation, bias=False) + self.bn2 = norm_layer(planes) + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) + self.bn3 = norm_layer(planes * self.expansion) + + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = 
self.downsample(x) + + out = out + residual + out = self.relu(out) + + return out + + +class ResNetV1b(nn.Module): + """ Pre-trained ResNetV1b Model, which produces the strides of 8 featuremaps at conv5. + + Parameters + ---------- + block : Block + Class for the residual block. Options are BasicBlockV1, BottleneckV1. + layers : list of int + Numbers of layers in each block + classes : int, default 1000 + Number of classification classes. + dilated : bool, default False + Applying dilation strategy to pretrained ResNet yielding a stride-8 model, + typically used in Semantic Segmentation. + norm_layer : object + Normalization layer used (default: :class:`nn.BatchNorm2d`) + deep_stem : bool, default False + Whether to replace the 7x7 conv1 with 3 3x3 convolution layers. + avg_down : bool, default False + Whether to use average pooling for projection skip connection between stages/downsample. + final_drop : float, default 0.0 + Dropout ratio before the final classification layer. + + Reference: + - He, Kaiming, et al. "Deep residual learning for image recognition." + Proceedings of the IEEE conference on computer vision and pattern recognition. 2016. + + - Yu, Fisher, and Vladlen Koltun. "Multi-scale context aggregation by dilated convolutions." + """ + def __init__(self, block, layers, classes=1000, dilated=True, deep_stem=False, stem_width=32, + avg_down=False, final_drop=0.0, norm_layer=nn.BatchNorm2d): + self.inplanes = stem_width*2 if deep_stem else 64 + super(ResNetV1b, self).__init__() + if not deep_stem: + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) + else: + self.conv1 = nn.Sequential( + nn.Conv2d(3, stem_width, kernel_size=3, stride=2, padding=1, bias=False), + norm_layer(stem_width), + nn.ReLU(True), + nn.Conv2d(stem_width, stem_width, kernel_size=3, stride=1, padding=1, bias=False), + norm_layer(stem_width), + nn.ReLU(True), + nn.Conv2d(stem_width, 2*stem_width, kernel_size=3, stride=1, padding=1, bias=False) + ) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(True) + self.maxpool = nn.MaxPool2d(3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0], avg_down=avg_down, + norm_layer=norm_layer) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, avg_down=avg_down, + norm_layer=norm_layer) + if dilated: + self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2, + avg_down=avg_down, norm_layer=norm_layer) + self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4, + avg_down=avg_down, norm_layer=norm_layer) + else: + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, + avg_down=avg_down, norm_layer=norm_layer) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2, + avg_down=avg_down, norm_layer=norm_layer) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.drop = None + if final_drop > 0.0: + self.drop = nn.Dropout(final_drop) + self.fc = nn.Linear(512 * block.expansion, classes) + + def _make_layer(self, block, planes, blocks, stride=1, dilation=1, + avg_down=False, norm_layer=nn.BatchNorm2d): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = [] + if avg_down: + if dilation == 1: + downsample.append( + nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True, count_include_pad=False) + ) + else: + downsample.append( + nn.AvgPool2d(kernel_size=1, stride=1, ceil_mode=True, count_include_pad=False) + ) + downsample.extend([ + nn.Conv2d(self.inplanes, out_channels=planes * 
block.expansion, + kernel_size=1, stride=1, bias=False), + norm_layer(planes * block.expansion) + ]) + downsample = nn.Sequential(*downsample) + else: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, out_channels=planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + norm_layer(planes * block.expansion) + ) + + layers = [] + if dilation in (1, 2): + layers.append(block(self.inplanes, planes, stride, dilation=1, downsample=downsample, + previous_dilation=dilation, norm_layer=norm_layer)) + elif dilation == 4: + layers.append(block(self.inplanes, planes, stride, dilation=2, downsample=downsample, + previous_dilation=dilation, norm_layer=norm_layer)) + else: + raise RuntimeError("=> unknown dilation size: {}".format(dilation)) + + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes, dilation=dilation, + previous_dilation=dilation, norm_layer=norm_layer)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + x = x.view(x.size(0), -1) + if self.drop is not None: + x = self.drop(x) + x = self.fc(x) + + return x + + +def _safe_state_dict_filtering(orig_dict, model_dict_keys): + filtered_orig_dict = {} + for k, v in orig_dict.items(): + if k in model_dict_keys: + filtered_orig_dict[k] = v + else: + print(f"[ERROR] Failed to load <{k}> in backbone") + return filtered_orig_dict + + +def resnet34_v1b(pretrained=False, **kwargs): + model = ResNetV1b(BasicBlockV1b, [3, 4, 6, 3], **kwargs) + if pretrained: + model_dict = model.state_dict() + filtered_orig_dict = _safe_state_dict_filtering( + torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet34_v1b', pretrained=True).state_dict(), + model_dict.keys() + ) + model_dict.update(filtered_orig_dict) + model.load_state_dict(model_dict) + return model + + +def resnet50_v1s(pretrained=False, **kwargs): + model = ResNetV1b(BottleneckV1b, [3, 4, 6, 3], deep_stem=True, stem_width=64, **kwargs) + if pretrained: + model_dict = model.state_dict() + filtered_orig_dict = _safe_state_dict_filtering( + torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet50_v1s', pretrained=True).state_dict(), + model_dict.keys() + ) + model_dict.update(filtered_orig_dict) + model.load_state_dict(model_dict) + return model + + +def resnet101_v1s(pretrained=False, **kwargs): + model = ResNetV1b(BottleneckV1b, [3, 4, 23, 3], deep_stem=True, stem_width=64, **kwargs) + if pretrained: + model_dict = model.state_dict() + filtered_orig_dict = _safe_state_dict_filtering( + torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet101_v1s', pretrained=True).state_dict(), + model_dict.keys() + ) + model_dict.update(filtered_orig_dict) + model.load_state_dict(model_dict) + return model + + +def resnet152_v1s(pretrained=False, **kwargs): + model = ResNetV1b(BottleneckV1b, [3, 8, 36, 3], deep_stem=True, stem_width=64, **kwargs) + if pretrained: + model_dict = model.state_dict() + filtered_orig_dict = _safe_state_dict_filtering( + torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet152_v1s', pretrained=True).state_dict(), + model_dict.keys() + ) + model_dict.update(filtered_orig_dict) + model.load_state_dict(model_dict) + return model diff --git a/isegm/model/modeling/segformer.py b/isegm/model/modeling/segformer.py new file mode 100644 index 0000000000000000000000000000000000000000..eeab95c8ae2cd45af94174b15988c9ab62e18f55 --- 
/dev/null +++ b/isegm/model/modeling/segformer.py @@ -0,0 +1,478 @@ +import math +import warnings + +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule +from mmcv.cnn import (Conv2d, build_activation_layer, build_norm_layer, + constant_init, normal_init, trunc_normal_init) +from mmcv.cnn.bricks.drop import build_dropout +from mmcv.cnn.bricks.transformer import MultiheadAttention +from mmcv.runner import BaseModule, ModuleList, Sequential, _load_checkpoint + +from .transformer_helper import PatchEmbed, nchw_to_nlc, nlc_to_nchw, resize, \ + get_root_logger, BaseDecodeHead, HEADS, BACKBONES + + +class MixFFN(BaseModule): + """An implementation of MixFFN of Segformer. + + The differences between MixFFN & FFN: + 1. Use 1X1 Conv to replace Linear layer. + 2. Introduce 3X3 Conv to encode positional information. + + Args: + embed_dims (int): The feature dimension. Same as + `MultiheadAttention`. Defaults: 256. + feedforward_channels (int): The hidden dimension of FFNs. + Defaults: 1024. + act_cfg (dict, optional): The activation config for FFNs. + Default: dict(type='ReLU') + ffn_drop (float, optional): Probability of an element to be + zeroed in FFN. Default 0.0. + dropout_layer (obj:`ConfigDict`): The dropout_layer used + when adding the shortcut. + init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Default: None. + """ + + def __init__(self, + embed_dims, + feedforward_channels, + act_cfg=dict(type='GELU'), + ffn_drop=0., + dropout_layer=None, + init_cfg=None): + super(MixFFN, self).__init__(init_cfg) + + self.embed_dims = embed_dims + self.feedforward_channels = feedforward_channels + self.act_cfg = act_cfg + self.activate = build_activation_layer(act_cfg) + + in_channels = embed_dims + fc1 = Conv2d( + in_channels=in_channels, + out_channels=feedforward_channels, + kernel_size=1, + stride=1, + bias=True) + # 3x3 depth wise conv to provide positional encode information + pe_conv = Conv2d( + in_channels=feedforward_channels, + out_channels=feedforward_channels, + kernel_size=3, + stride=1, + padding=(3 - 1) // 2, + bias=True, + groups=feedforward_channels) + fc2 = Conv2d( + in_channels=feedforward_channels, + out_channels=in_channels, + kernel_size=1, + stride=1, + bias=True) + drop = nn.Dropout(ffn_drop) + layers = [fc1, pe_conv, self.activate, drop, fc2, drop] + self.layers = Sequential(*layers) + self.dropout_layer = build_dropout( + dropout_layer) if dropout_layer else torch.nn.Identity() + + def forward(self, x, hw_shape, identity=None): + out = nlc_to_nchw(x, hw_shape) + out = self.layers(out) + out = nchw_to_nlc(out) + if identity is None: + identity = x + return identity + self.dropout_layer(out) + + +class EfficientMultiheadAttention(MultiheadAttention): + """An implementation of Efficient Multi-head Attention of Segformer. + + This module is modified from MultiheadAttention which is a module from + mmcv.cnn.bricks.transformer. + + Args: + embed_dims (int): The embedding dimension. + num_heads (int): Parallel attention heads. + attn_drop (float): A Dropout layer on attn_output_weights. + Default: 0.0. + proj_drop (float): A Dropout layer after `nn.MultiheadAttention`. + Default: 0.0. + dropout_layer (obj:`ConfigDict`): The dropout_layer used + when adding the shortcut. Default: None. + init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Default: None. + batch_first (bool): Key, Query and Value are shape of + (batch, n, embed_dim) + or (n, batch, embed_dim). Default: False. + qkv_bias (bool): enable bias for qkv if True. Default True. 
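# A minimal sketch of the MixFFN data flow defined above: tokens in (B, L, C) layout
# are folded back into a (B, C, H, W) map so a 3x3 depthwise conv can inject
# positional information, then flattened again and added to the shortcut. Plain
# torch reshapes stand in for the nlc_to_nchw / nchw_to_nlc helpers.
import torch
import torch.nn as nn

B, H, W, C, hidden = 2, 16, 16, 64, 256
x = torch.randn(B, H * W, C)                                       # token sequence (NLC)
fc1 = nn.Conv2d(C, hidden, 1)                                      # 1x1 conv replaces the Linear
pe_conv = nn.Conv2d(hidden, hidden, 3, padding=1, groups=hidden)   # depthwise positional encoding
fc2 = nn.Conv2d(hidden, C, 1)

feat = x.transpose(1, 2).reshape(B, C, H, W)                       # nlc -> nchw
feat = fc2(nn.GELU()(pe_conv(fc1(feat))))
out = x + feat.flatten(2).transpose(1, 2)                          # nchw -> nlc, plus shortcut
assert out.shape == (B, H * W, C)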
+ norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN'). + sr_ratio (int): The ratio of spatial reduction of Efficient Multi-head + Attention of Segformer. Default: 1. + """ + + def __init__(self, + embed_dims, + num_heads, + attn_drop=0., + proj_drop=0., + dropout_layer=None, + init_cfg=None, + batch_first=True, + qkv_bias=False, + norm_cfg=dict(type='LN'), + sr_ratio=1): + super().__init__( + embed_dims, + num_heads, + attn_drop, + proj_drop, + dropout_layer=dropout_layer, + init_cfg=init_cfg, + batch_first=batch_first, + bias=qkv_bias) + + self.sr_ratio = sr_ratio + if sr_ratio > 1: + self.sr = Conv2d( + in_channels=embed_dims, + out_channels=embed_dims, + kernel_size=sr_ratio, + stride=sr_ratio) + # The ret[0] of build_norm_layer is norm name. + self.norm = build_norm_layer(norm_cfg, embed_dims)[1] + + def forward(self, x, hw_shape, identity=None): + + x_q = x + if self.sr_ratio > 1: + x_kv = nlc_to_nchw(x, hw_shape) + x_kv = self.sr(x_kv) + x_kv = nchw_to_nlc(x_kv) + x_kv = self.norm(x_kv) + else: + x_kv = x + + if identity is None: + identity = x_q + + # `need_weights=True` will let nn.MultiHeadAttention + # `return attn_output, attn_output_weights.sum(dim=1) / num_heads` + # The `attn_output_weights.sum(dim=1)` may cause cuda error. So, we set + # `need_weights=False` to ignore `attn_output_weights.sum(dim=1)`. + # This issue - `https://github.com/pytorch/pytorch/issues/37583` report + # the error that large scale tensor sum operation may cause cuda error. + out = self.attn(query=x_q, key=x_kv, value=x_kv, need_weights=False)[0] + + return identity + self.dropout_layer(self.proj_drop(out)) + + +class TransformerEncoderLayer(BaseModule): + """Implements one encoder layer in Segformer. + + Args: + embed_dims (int): The feature dimension. + num_heads (int): Parallel attention heads. + feedforward_channels (int): The hidden dimension for FFNs. + drop_rate (float): Probability of an element to be zeroed. + after the feed forward layer. Default 0.0. + attn_drop_rate (float): The drop out rate for attention layer. + Default 0.0. + drop_path_rate (float): stochastic depth rate. Default 0.0. + qkv_bias (bool): enable bias for qkv if True. + Default: True. + act_cfg (dict): The activation config for FFNs. + Defalut: dict(type='GELU'). + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN'). + batch_first (bool): Key, Query and Value are shape of + (batch, n, embed_dim) + or (n, batch, embed_dim). Default: False. + init_cfg (dict, optional): Initialization config dict. + Default:None. + sr_ratio (int): The ratio of spatial reduction of Efficient Multi-head + Attention of Segformer. Default: 1. + """ + + def __init__(self, + embed_dims, + num_heads, + feedforward_channels, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + qkv_bias=True, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + batch_first=True, + sr_ratio=1): + super(TransformerEncoderLayer, self).__init__() + + # The ret[0] of build_norm_layer is norm name. + self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1] + + self.attn = EfficientMultiheadAttention( + embed_dims=embed_dims, + num_heads=num_heads, + attn_drop=attn_drop_rate, + proj_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + batch_first=batch_first, + qkv_bias=qkv_bias, + norm_cfg=norm_cfg, + sr_ratio=sr_ratio) + + # The ret[0] of build_norm_layer is norm name. 
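# A shape-only sketch of the spatial-reduction attention used above: queries keep the
# full token length, while keys/values are downsampled by a conv with
# kernel = stride = sr_ratio, so attention cost drops from O(L^2) to O(L^2 / sr^2).
# torch's nn.MultiheadAttention stands in for the mmcv wrapper here.
import torch
import torch.nn as nn

B, H, W, C, sr = 2, 32, 32, 64, 8
x = torch.randn(B, H * W, C)                       # stage-1 tokens, L = 1024
sr_conv = nn.Conv2d(C, C, kernel_size=sr, stride=sr)
norm = nn.LayerNorm(C)

kv = x.transpose(1, 2).reshape(B, C, H, W)         # tokens back to a feature map
kv = norm(sr_conv(kv).flatten(2).transpose(1, 2))  # (B, L / sr^2, C) = (B, 16, C)
attn = nn.MultiheadAttention(C, num_heads=1, batch_first=True)
out, _ = attn(query=x, key=kv, value=kv, need_weights=False)
assert out.shape == x.shape                        # 1024 queries attend over only 16 keys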
+ self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1] + + self.ffn = MixFFN( + embed_dims=embed_dims, + feedforward_channels=feedforward_channels, + ffn_drop=drop_rate, + dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), + act_cfg=act_cfg) + + def forward(self, x, hw_shape): + x = self.attn(self.norm1(x), hw_shape, identity=x) + x = self.ffn(self.norm2(x), hw_shape, identity=x) + return x + + +@BACKBONES.register_module() +class MixVisionTransformer(BaseModule): + """The backbone of Segformer. + + A PyTorch implement of : `SegFormer: Simple and Efficient Design for + Semantic Segmentation with Transformers` - + https://arxiv.org/pdf/2105.15203.pdf + + Args: + in_channels (int): Number of input channels. Default: 3. + embed_dims (int): Embedding dimension. Default: 768. + num_stags (int): The num of stages. Default: 4. + num_layers (Sequence[int]): The layer number of each transformer encode + layer. Default: [3, 4, 6, 3]. + num_heads (Sequence[int]): The attention heads of each transformer + encode layer. Default: [1, 2, 4, 8]. + patch_sizes (Sequence[int]): The patch_size of each overlapped patch + embedding. Default: [7, 3, 3, 3]. + strides (Sequence[int]): The stride of each overlapped patch embedding. + Default: [4, 2, 2, 2]. + sr_ratios (Sequence[int]): The spatial reduction rate of each + transformer encode layer. Default: [8, 4, 2, 1]. + out_indices (Sequence[int] | int): Output from which stages. + Default: (0, 1, 2, 3). + mlp_ratio (int): ratio of mlp hidden dim to embedding dim. + Default: 4. + qkv_bias (bool): Enable bias for qkv if True. Default: True. + drop_rate (float): Probability of an element to be zeroed. + Default 0.0 + attn_drop_rate (float): The drop out rate for attention layer. + Default 0.0 + drop_path_rate (float): stochastic depth rate. Default 0.0 + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN') + act_cfg (dict): The activation config for FFNs. + Defalut: dict(type='GELU'). + pretrain_style (str): Choose to use official or mmcls pretrain weights. + Default: official. + pretrained (str, optional): model pretrained path. Default: None. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + """ + + def __init__(self, + in_channels=64, + embed_dims=64, + num_stages=4, + num_layers=[3, 4, 6, 3], + num_heads=[1, 2, 4, 8], + patch_sizes=[7, 3, 3, 3], + strides=[2, 2, 2, 2], + sr_ratios=[8, 4, 2, 1], + out_indices=(0, 1, 2, 3), + mlp_ratio=4, + qkv_bias=True, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN', eps=1e-6), + pretrain_style='official', + pretrained=None, + init_cfg=None): + super().__init__() + + assert pretrain_style in [ + 'official', 'mmcls' + ], 'we only support official weights or mmcls weights.' 
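# A quick sketch of the stage geometry implied by the defaults above: stage i widens
# the channels to embed_dims * num_heads[i], and its overlapped patch embedding
# downsamples by strides[i], so the cumulative stride grows multiplicatively relative
# to whatever feature map this backbone variant receives as input.
embed_dims, num_heads, strides = 64, [1, 2, 4, 8], [2, 2, 2, 2]

cum_stride = 1
for i, (h, s) in enumerate(zip(num_heads, strides)):
    cum_stride *= s
    print(f"stage {i}: {embed_dims * h:3d} channels at 1/{cum_stride} resolution")
# stage 0:  64 channels at 1/2 resolution ... stage 3: 512 channels at 1/16 resolution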
+ + if isinstance(pretrained, str) or pretrained is None: + warnings.warn('DeprecationWarning: pretrained is a deprecated, ' + 'please use "init_cfg" instead') + else: + raise TypeError('pretrained must be a str or None') + + self.embed_dims = embed_dims + + self.num_stages = num_stages + self.num_layers = num_layers + self.num_heads = num_heads + self.patch_sizes = patch_sizes + self.strides = strides + self.sr_ratios = sr_ratios + assert num_stages == len(num_layers) == len(num_heads) \ + == len(patch_sizes) == len(strides) == len(sr_ratios) + + self.out_indices = out_indices + assert max(out_indices) < self.num_stages + self.pretrain_style = pretrain_style + self.pretrained = pretrained + self.init_cfg = init_cfg + + # transformer encoder + dpr = [ + x.item() + for x in torch.linspace(0, drop_path_rate, sum(num_layers)) + ] # stochastic num_layer decay rule + + cur = 0 + self.layers = ModuleList() + for i, num_layer in enumerate(num_layers): + embed_dims_i = embed_dims * num_heads[i] + patch_embed = PatchEmbed( + in_channels=in_channels, + embed_dims=embed_dims_i, + kernel_size=patch_sizes[i], + stride=strides[i], + padding=patch_sizes[i] // 2, + pad_to_patch_size=False, + norm_cfg=norm_cfg) + layer = ModuleList([ + TransformerEncoderLayer( + embed_dims=embed_dims_i, + num_heads=num_heads[i], + feedforward_channels=mlp_ratio * embed_dims_i, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=dpr[cur + idx], + qkv_bias=qkv_bias, + act_cfg=act_cfg, + norm_cfg=norm_cfg, + sr_ratio=sr_ratios[i]) for idx in range(num_layer) + ]) + in_channels = embed_dims_i + # The ret[0] of build_norm_layer is norm name. + norm = build_norm_layer(norm_cfg, embed_dims_i)[1] + self.layers.append(ModuleList([patch_embed, layer, norm])) + cur += num_layer + + def init_weights(self): + if self.pretrained is None: + for m in self.modules(): + if isinstance(m, nn.Linear): + trunc_normal_init(m.weight, std=.02) + if m.bias is not None: + constant_init(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + constant_init(m.bias, 0) + constant_init(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[ + 1] * m.out_channels + fan_out //= m.groups + normal_init(m.weight, 0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + constant_init(m.bias, 0) + elif isinstance(self.pretrained, str): + logger = get_root_logger() + checkpoint = _load_checkpoint( + self.pretrained, logger=logger, map_location='cpu') + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + else: + state_dict = checkpoint + + + # only use this code if when adopt v3 + ori_proj_weight = state_dict['layers.0.0.projection.weight'] + state_dict['layers.0.0.projection.weight'] = torch.cat([ori_proj_weight, ori_proj_weight], dim=1) + + self.load_state_dict(state_dict, True) + + + def forward(self, x, additional_features=None): + + outs = [] + for i, layer in enumerate(self.layers): + x, H, W = layer[0](x), layer[0].DH, layer[0].DW + hw_shape = (H, W) + for block in layer[1]: + x = block(x, hw_shape) + x = layer[2](x) + x = nlc_to_nchw(x, hw_shape) + if i in self.out_indices: + outs.append(x) + + return outs + + +@HEADS.register_module() +class SegformerHead(BaseDecodeHead): + """The all mlp Head of segformer. + + This head is the implementation of + `Segformer ` _. + + Args: + interpolate_mode: The interpolate mode of MLP head upsample operation. + Default: 'bilinear'. 
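# A minimal sketch of the stochastic-depth schedule built in the backbone above:
# drop-path probabilities grow linearly from 0 to drop_path_rate over all blocks,
# and each stage consumes its contiguous slice of that list.
import torch

num_layers, drop_path_rate = [3, 4, 6, 3], 0.1
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(num_layers))]

cur = 0
for i, depth in enumerate(num_layers):
    print(f"stage {i}: {[round(p, 3) for p in dpr[cur:cur + depth]]}")
    cur += depth
# only the deepest block reaches the full drop_path_rate of 0.1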
+ """ + + def __init__(self, interpolate_mode='bilinear', **kwargs): + super().__init__(input_transform='multiple_select', **kwargs) + + self.interpolate_mode = interpolate_mode + num_inputs = len(self.in_channels) + + assert num_inputs == len(self.in_index) + + self.convs = nn.ModuleList() + for i in range(num_inputs): + self.convs.append( + ConvModule( + in_channels=self.in_channels[i], + out_channels=self.channels, + kernel_size=1, + stride=1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + + self.fusion_conv = ConvModule( + in_channels=self.channels * num_inputs, + out_channels=self.channels, + kernel_size=1, + norm_cfg=self.norm_cfg) + + def forward(self, inputs): + # Receive 4 stage backbone feature map: 1/4, 1/8, 1/16, 1/32 + inputs = self._transform_inputs(inputs) + outs = [] + for idx in range(len(inputs)): + x = inputs[idx] + conv = self.convs[idx] + outs.append( + resize( + input=conv(x), + size=inputs[0].shape[2:], + mode=self.interpolate_mode, + align_corners=self.align_corners)) + + out = self.fusion_conv(torch.cat(outs, dim=1)) + + out = self.cls_seg(out) + + return out diff --git a/isegm/model/modeling/swin_transformer.py b/isegm/model/modeling/swin_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..c9100e5ed1587db61e5d7959253ec997a55f34b2 --- /dev/null +++ b/isegm/model/modeling/swin_transformer.py @@ -0,0 +1,724 @@ +# -------------------------------------------------------- +# Swin Transformer +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Ze Liu, Yutong Lin, Yixuan Wei +# -------------------------------------------------------- + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +import numpy as np +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from mmcv.cnn import ConvModule + +# from mmcv_custom import load_checkpoint +# from mmseg.utils import get_root_logger +# from ..builder import BACKBONES +from .swin_transformer_helper.checkpoint import load_checkpoint +from .swin_transformer_helper.logger import get_root_logger +from .swin_transformer_helper.builder import BACKBONES +from .transformer_helper import resize, BaseDecodeHead + + +class Mlp(nn.Module): + """ Multilayer perceptron.""" + + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = 
windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + """ Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ Forward function. 
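# A small sanity sketch for the window helpers above (restated here so the check is
# self-contained): partitioning a (B, H, W, C) map into non-overlapping windows and
# reversing the operation are exact inverses, since both are pure reshape/permute.
import torch

def window_partition(x, ws):
    B, H, W, C = x.shape
    x = x.view(B, H // ws, ws, W // ws, ws, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, ws, ws, C)

def window_reverse(windows, ws, H, W):
    B = int(windows.shape[0] / (H * W / ws / ws))
    x = windows.view(B, H // ws, W // ws, ws, ws, -1)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)

x = torch.randn(2, 14, 14, 96)
win = window_partition(x, 7)                       # 2 images * 4 windows each
assert win.shape == (8, 7, 7, 96)
assert torch.equal(window_reverse(win, 7, 14, 14), x)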
+ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class SwinTransformerBlock(nn.Module): + """ Swin Transformer Block. + + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + self.H = None + self.W = None + + def forward(self, x, mask_matrix): + """ Forward function. + + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + mask_matrix: Attention mask for cyclic shift. 
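# A tiny sketch of the relative-position-bias bookkeeping in WindowAttention above,
# for a 2x2 window: every (query, key) pair inside a window indexes one row of a
# learned (2*Wh-1)*(2*Ww-1) x num_heads table via its relative offset. meshgrid is
# given indexing='ij' explicitly, matching the older default the code above relies on.
import torch

Wh, Ww, num_heads = 2, 2, 3
coords = torch.stack(torch.meshgrid(torch.arange(Wh), torch.arange(Ww), indexing='ij'))
flat = torch.flatten(coords, 1)                          # (2, Wh*Ww)
rel = (flat[:, :, None] - flat[:, None, :]).permute(1, 2, 0).contiguous()
rel[:, :, 0] += Wh - 1                                   # shift offsets to start at 0
rel[:, :, 1] += Ww - 1
rel[:, :, 0] *= 2 * Ww - 1
index = rel.sum(-1)                                      # (N, N) indices into the table

table = torch.zeros((2 * Wh - 1) * (2 * Ww - 1), num_heads)
bias = table[index.view(-1)].view(Wh * Ww, Wh * Ww, -1)  # (N, N, num_heads)
assert bias.shape == (4, 4, 3) and int(index.max()) < table.shape[0]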
+ """ + B, L, C = x.shape + H, W = self.H, self.W + assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # pad feature maps to multiples of window size + pad_l = pad_t = 0 + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + attn_mask = mask_matrix + else: + shifted_x = x + attn_mask = None + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + +class PatchMerging(nn.Module): + """ Patch Merging Layer + + Args: + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + def __init__(self, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x, H, W): + """ Forward function. + + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + """ + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + x = x.view(B, H, W, C) + + # padding + pad_input = (H % 2 == 1) or (W % 2 == 1) + if pad_input: + x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of feature channels + depth (int): Depths of this stage. + num_heads (int): Number of attention head. + window_size (int): Local window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. 
Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, + dim, + depth, + num_heads, + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False): + super().__init__() + self.window_size = window_size + self.shift_size = window_size // 2 + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock( + dim=dim, + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, H, W): + """ Forward function. + + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + """ + + # calculate attention mask for SW-MSA + Hp = int(np.ceil(H / self.window_size)) * self.window_size + Wp = int(np.ceil(W / self.window_size)) * self.window_size + img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + for blk in self.blocks: + blk.H, blk.W = H, W + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, attn_mask) + else: + x = blk(x, attn_mask) + if self.downsample is not None: + x_down = self.downsample(x, H, W) + Wh, Ww = (H + 1) // 2, (W + 1) // 2 + return x, H, W, x_down, Wh, Ww + else: + return x, H, W, x, H, W + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + + Args: + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. 
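# A compact sketch of the SW-MSA mask construction in BasicLayer.forward above: the
# padded feature map is labelled with region ids via three h/w slices, the label map
# is windowed exactly like the features, and token pairs whose ids differ get a large
# negative additive bias so shifted windows never attend across wrapped-around content.
import torch

ws, shift, Hp, Wp = 4, 2, 8, 8
img_mask = torch.zeros(1, Hp, Wp, 1)
slices = (slice(0, -ws), slice(-ws, -shift), slice(-shift, None))
cnt = 0
for h in slices:
    for w in slices:
        img_mask[:, h, w, :] = cnt                 # 9 region ids in total
        cnt += 1

win = img_mask.view(1, Hp // ws, ws, Wp // ws, ws, 1)
win = win.permute(0, 1, 3, 2, 4, 5).reshape(-1, ws * ws)       # (num_windows, ws*ws)
attn_mask = win.unsqueeze(1) - win.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0).masked_fill(attn_mask == 0, 0.0)
assert attn_mask.shape == (4, ws * ws, ws * ws)    # one additive mask per window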
Default: None + """ + + def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + patch_size = to_2tuple(patch_size) + self.patch_size = patch_size + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + """Forward function.""" + # padding + _, _, H, W = x.size() + if W % self.patch_size[1] != 0: + x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1])) + if H % self.patch_size[0] != 0: + x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) + + x = self.proj(x) # B C Wh Ww + if self.norm is not None: + Wh, Ww = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww) + + return x + + +@BACKBONES.register_module() +class SwinTransformer(nn.Module): + """ Swin Transformer backbone. + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + + Args: + pretrain_img_size (int): Input image size for training the pretrained model, + used in absolute postion embedding. Default 224. + patch_size (int | tuple(int)): Patch size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + depths (tuple[int]): Depths of each Swin Transformer stage. + num_heads (tuple[int]): Number of attention head of each stage. + window_size (int): Window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. + drop_rate (float): Dropout rate. + attn_drop_rate (float): Attention dropout rate. Default: 0. + drop_path_rate (float): Stochastic depth rate. Default: 0.2. + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False. + patch_norm (bool): If True, add normalization after patch embedding. Default: True. + out_indices (Sequence[int]): Output from which stages. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
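# A short sketch of the patch-embedding shape rule above: inputs are padded on the
# right/bottom to a multiple of the patch size before the strided projection conv,
# so the token grid is always ceil(H / p) x ceil(W / p).
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

p, embed_dim = 4, 96
proj = nn.Conv2d(3, embed_dim, kernel_size=p, stride=p)

x = torch.randn(1, 3, 225, 223)                    # deliberately not divisible by 4
if x.shape[-1] % p:
    x = F.pad(x, (0, p - x.shape[-1] % p))         # pad width on the right
if x.shape[-2] % p:
    x = F.pad(x, (0, 0, 0, p - x.shape[-2] % p))   # pad height at the bottom
tokens = proj(x)                                   # (1, 96, 57, 56)
assert tokens.shape[-2:] == (math.ceil(225 / p), math.ceil(223 / p))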
+ """ + + def __init__(self, + pretrain_img_size=224, + patch_size=4, + in_chans=3, + in_coord_chans=3, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + ape=False, + patch_norm=True, + out_indices=(0, 1, 2, 3), + frozen_stages=-1, + use_checkpoint=False): + super().__init__() + + self.pretrain_img_size = pretrain_img_size + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.out_indices = out_indices + self.frozen_stages = frozen_stages + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + self.patch_embed_coords = PatchEmbed( + patch_size=patch_size, in_chans=in_coord_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + pretrain_img_size = to_2tuple(pretrain_img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]] + + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer( + dim=int(embed_dim * 2 ** i_layer), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint) + self.layers.append(layer) + + num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] + self.num_features = num_features + + # add a norm layer for each output + for i_layer in out_indices: + layer = norm_layer(num_features[i_layer]) + layer_name = f'norm{i_layer}' + self.add_module(layer_name, layer) + + self._freeze_stages() + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + if self.frozen_stages >= 1 and self.ape: + self.absolute_pos_embed.requires_grad = False + + if self.frozen_stages >= 2: + self.pos_drop.eval() + for i in range(0, self.frozen_stages - 1): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. 
+ """ + + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + if isinstance(pretrained, str): + self.apply(_init_weights) + logger = get_root_logger() + load_checkpoint(self, pretrained, strict=False, logger=logger) + elif pretrained is None: + self.apply(_init_weights) + else: + raise TypeError('pretrained must be a str or None') + + def forward(self, x, coords=None): + """Forward function.""" + x = self.patch_embed(x) + coords = self.patch_embed(coords) + x = x + coords + + Wh, Ww = x.size(2), x.size(3) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic') + x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C + else: + x = x.flatten(2).transpose(1, 2) + x = self.pos_drop(x) + + outs = [] + for i in range(self.num_layers): + layer = self.layers[i] + x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww) + + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + x_out = norm_layer(x_out) + + out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous() + outs.append(out) + + return tuple(outs) + + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + super(SwinTransformer, self).train(mode) + self._freeze_stages() + + +@BACKBONES.register_module() +class SwinTransfomerSegHead(BaseDecodeHead): + """The all mlp Head of segformer. + + This head is the implementation of + `Segformer ` _. + + Args: + interpolate_mode: The interpolate mode of MLP head upsample operation. + Default: 'bilinear'. 
+ """ + + def __init__(self, upsample='x1', interpolate_mode='bilinear', **kwargs): + super().__init__(input_transform='multiple_select', **kwargs) + self.unsample = upsample + self.out_channels = {'x1': self.channels, 'x2': self.channels * 2, + 'x4': self.channels * 4}[upsample] + + self.interpolate_mode = interpolate_mode + num_inputs = len(self.in_channels) + + assert num_inputs == len(self.in_index) + + self.convs = nn.ModuleList() + for i in range(num_inputs): + self.convs.append( + ConvModule( + in_channels=self.in_channels[i], + out_channels=self.out_channels, + kernel_size=1, + stride=1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + + self.fusion_conv = ConvModule( + in_channels=self.out_channels * num_inputs, + out_channels=self.out_channels, + kernel_size=1, + norm_cfg=self.norm_cfg) + + self.up_conv1 = nn.Sequential( + nn.ConvTranspose2d(self.out_channels, self.out_channels // 2, 2, stride=2), + nn.GroupNorm(1, self.out_channels // 2), + nn.Conv2d(self.out_channels // 2, self.out_channels // 2, 1), + nn.GroupNorm(1, self.out_channels // 2), + nn.GELU() + ) + + self.up_conv2 = nn.Sequential( + nn.ConvTranspose2d(self.out_channels // 2, self.out_channels // 4, 2, stride=2), + nn.GroupNorm(1, self.out_channels // 4), + nn.Conv2d(self.out_channels // 4, self.out_channels // 4, 1), + nn.GroupNorm(1, self.out_channels // 4), + nn.GELU() + ) + + def forward(self, inputs): + # Receive 4 stage backbone feature map: 1/4, 1/8, 1/16, 1/32 + inputs = self._transform_inputs(inputs) + outs = [] + for idx in range(len(inputs)): + x = inputs[idx] + conv = self.convs[idx] + outs.append( + resize( + input=conv(x), + size=inputs[0].shape[2:], + mode=self.interpolate_mode, + align_corners=self.align_corners)) + + out = self.fusion_conv(torch.cat(outs, dim=1)) + if self.unsample == 'x2': + out = self.up_conv1(out) + + if self.unsample == 'x4': + out = self.up_conv2(self.up_conv1(out)) + + out = self.cls_seg(out) + + return out diff --git a/isegm/model/modeling/swin_transformer_helper/__init__.py b/isegm/model/modeling/swin_transformer_helper/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/isegm/model/modeling/swin_transformer_helper/builder.py b/isegm/model/modeling/swin_transformer_helper/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..bdd09279b0f0d8efd7a08fe9a13109341801eada --- /dev/null +++ b/isegm/model/modeling/swin_transformer_helper/builder.py @@ -0,0 +1,66 @@ +import warnings + +from mmcv.utils import Registry, build_from_cfg +from torch import nn + +BACKBONES = Registry('backbone') +NECKS = Registry('neck') +HEADS = Registry('head') +LOSSES = Registry('loss') +SEGMENTORS = Registry('segmentor') + + +def build(cfg, registry, default_args=None): + """Build a module. + + Args: + cfg (dict, list[dict]): The config of modules, is is either a dict + or a list of configs. + registry (:obj:`Registry`): A registry the module belongs to. + default_args (dict, optional): Default arguments to build the module. + Defaults to None. + + Returns: + nn.Module: A built nn module. 
+ """ + + if isinstance(cfg, list): + modules = [ + build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg + ] + return nn.Sequential(*modules) + else: + return build_from_cfg(cfg, registry, default_args) + + +def build_backbone(cfg): + """Build backbone.""" + return build(cfg, BACKBONES) + + +def build_neck(cfg): + """Build neck.""" + return build(cfg, NECKS) + + +def build_head(cfg): + """Build head.""" + return build(cfg, HEADS) + + +def build_loss(cfg): + """Build loss.""" + return build(cfg, LOSSES) + + +def build_segmentor(cfg, train_cfg=None, test_cfg=None): + """Build segmentor.""" + if train_cfg is not None or test_cfg is not None: + warnings.warn( + 'train_cfg and test_cfg is deprecated, ' + 'please specify them in model', UserWarning) + assert cfg.get('train_cfg') is None or train_cfg is None, \ + 'train_cfg specified in both outer field and model field ' + assert cfg.get('test_cfg') is None or test_cfg is None, \ + 'test_cfg specified in both outer field and model field ' + return build(cfg, SEGMENTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg)) \ No newline at end of file diff --git a/isegm/model/modeling/swin_transformer_helper/checkpoint.py b/isegm/model/modeling/swin_transformer_helper/checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..e2f3878beebfa36921f4e34d1b96840c2bfd5df1 --- /dev/null +++ b/isegm/model/modeling/swin_transformer_helper/checkpoint.py @@ -0,0 +1,500 @@ +# Copyright (c) Open-MMLab. All rights reserved. +import io +import os +import os.path as osp +import pkgutil +import time +import warnings +from collections import OrderedDict +from importlib import import_module +from tempfile import TemporaryDirectory + +import torch +import torchvision +from torch.optim import Optimizer +from torch.utils import model_zoo +from torch.nn import functional as F + +import mmcv +from mmcv.fileio import FileClient +from mmcv.fileio import load as load_file +from mmcv.parallel import is_module_wrapper +from mmcv.utils import mkdir_or_exist +from mmcv.runner import get_dist_info + +ENV_MMCV_HOME = 'MMCV_HOME' +ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME' +DEFAULT_CACHE_DIR = '~/.cache' + + +def _get_mmcv_home(): + mmcv_home = os.path.expanduser( + os.getenv( + ENV_MMCV_HOME, + os.path.join( + os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'mmcv'))) + + mkdir_or_exist(mmcv_home) + return mmcv_home + + +def load_state_dict(module, state_dict, strict=False, logger=None): + """Load state_dict to a module. + + This method is modified from :meth:`torch.nn.Module.load_state_dict`. + Default value for ``strict`` is set to ``False`` and the message for + param mismatch will be shown even if strict is False. + + Args: + module (Module): Module that receives the state_dict. + state_dict (OrderedDict): Weights. + strict (bool): whether to strictly enforce that the keys + in :attr:`state_dict` match the keys returned by this module's + :meth:`~torch.nn.Module.state_dict` function. Default: ``False``. + logger (:obj:`logging.Logger`, optional): Logger to log the error + message. If not specified, print function will be used. 
+ """ + unexpected_keys = [] + all_missing_keys = [] + err_msg = [] + + metadata = getattr(state_dict, '_metadata', None) + state_dict = state_dict.copy() + if metadata is not None: + state_dict._metadata = metadata + + # use _load_from_state_dict to enable checkpoint version control + def load(module, prefix=''): + # recursively check parallel module in case that the model has a + # complicated structure, e.g., nn.Module(nn.Module(DDP)) + if is_module_wrapper(module): + module = module.module + local_metadata = {} if metadata is None else metadata.get( + prefix[:-1], {}) + module._load_from_state_dict(state_dict, prefix, local_metadata, True, + all_missing_keys, unexpected_keys, + err_msg) + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + '.') + + load(module) + load = None # break load->load reference cycle + + # ignore "num_batches_tracked" of BN layers + missing_keys = [ + key for key in all_missing_keys if 'num_batches_tracked' not in key + ] + + if unexpected_keys: + err_msg.append('unexpected key in source ' + f'state_dict: {", ".join(unexpected_keys)}\n') + if missing_keys: + err_msg.append( + f'missing keys in source state_dict: {", ".join(missing_keys)}\n') + + rank, _ = get_dist_info() + if len(err_msg) > 0 and rank == 0: + err_msg.insert( + 0, 'The model and loaded state dict do not match exactly\n') + err_msg = '\n'.join(err_msg) + if strict: + raise RuntimeError(err_msg) + elif logger is not None: + logger.warning(err_msg) + else: + print(err_msg) + + +def load_url_dist(url, model_dir=None): + """In distributed setting, this function only download checkpoint at local + rank 0.""" + rank, world_size = get_dist_info() + rank = int(os.environ.get('LOCAL_RANK', rank)) + if rank == 0: + checkpoint = model_zoo.load_url(url, model_dir=model_dir) + if world_size > 1: + torch.distributed.barrier() + if rank > 0: + checkpoint = model_zoo.load_url(url, model_dir=model_dir) + return checkpoint + + +def load_pavimodel_dist(model_path, map_location=None): + """In distributed setting, this function only download checkpoint at local + rank 0.""" + try: + from pavi import modelcloud + except ImportError: + raise ImportError( + 'Please install pavi to load checkpoint from modelcloud.') + rank, world_size = get_dist_info() + rank = int(os.environ.get('LOCAL_RANK', rank)) + if rank == 0: + model = modelcloud.get(model_path) + with TemporaryDirectory() as tmp_dir: + downloaded_file = osp.join(tmp_dir, model.name) + model.download(downloaded_file) + checkpoint = torch.load(downloaded_file, map_location=map_location) + if world_size > 1: + torch.distributed.barrier() + if rank > 0: + model = modelcloud.get(model_path) + with TemporaryDirectory() as tmp_dir: + downloaded_file = osp.join(tmp_dir, model.name) + model.download(downloaded_file) + checkpoint = torch.load( + downloaded_file, map_location=map_location) + return checkpoint + + +def load_fileclient_dist(filename, backend, map_location): + """In distributed setting, this function only download checkpoint at local + rank 0.""" + rank, world_size = get_dist_info() + rank = int(os.environ.get('LOCAL_RANK', rank)) + allowed_backends = ['ceph'] + if backend not in allowed_backends: + raise ValueError(f'Load from Backend {backend} is not supported.') + if rank == 0: + fileclient = FileClient(backend=backend) + buffer = io.BytesIO(fileclient.get(filename)) + checkpoint = torch.load(buffer, map_location=map_location) + if world_size > 1: + torch.distributed.barrier() + if rank > 0: + fileclient = 
FileClient(backend=backend) + buffer = io.BytesIO(fileclient.get(filename)) + checkpoint = torch.load(buffer, map_location=map_location) + return checkpoint + + +def get_torchvision_models(): + model_urls = dict() + for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__): + if ispkg: + continue + _zoo = import_module(f'torchvision.models.{name}') + if hasattr(_zoo, 'model_urls'): + _urls = getattr(_zoo, 'model_urls') + model_urls.update(_urls) + return model_urls + + +def get_external_models(): + mmcv_home = _get_mmcv_home() + default_json_path = osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json') + default_urls = load_file(default_json_path) + assert isinstance(default_urls, dict) + external_json_path = osp.join(mmcv_home, 'open_mmlab.json') + if osp.exists(external_json_path): + external_urls = load_file(external_json_path) + assert isinstance(external_urls, dict) + default_urls.update(external_urls) + + return default_urls + + +def get_mmcls_models(): + mmcls_json_path = osp.join(mmcv.__path__[0], 'model_zoo/mmcls.json') + mmcls_urls = load_file(mmcls_json_path) + + return mmcls_urls + + +def get_deprecated_model_names(): + deprecate_json_path = osp.join(mmcv.__path__[0], + 'model_zoo/deprecated.json') + deprecate_urls = load_file(deprecate_json_path) + assert isinstance(deprecate_urls, dict) + + return deprecate_urls + + +def _process_mmcls_checkpoint(checkpoint): + state_dict = checkpoint['state_dict'] + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + if k.startswith('backbone.'): + new_state_dict[k[9:]] = v + new_checkpoint = dict(state_dict=new_state_dict) + + return new_checkpoint + + +def _load_checkpoint(filename, map_location=None): + """Load checkpoint from somewhere (modelzoo, file, url). + + Args: + filename (str): Accept local filepath, URL, ``torchvision://xxx``, + ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for + details. + map_location (str | None): Same as :func:`torch.load`. Default: None. + + Returns: + dict | OrderedDict: The loaded checkpoint. It can be either an + OrderedDict storing model weights or a dict containing other + information, which depends on the checkpoint. 
+ """ + if filename.startswith('modelzoo://'): + warnings.warn('The URL scheme of "modelzoo://" is deprecated, please ' + 'use "torchvision://" instead') + model_urls = get_torchvision_models() + model_name = filename[11:] + checkpoint = load_url_dist(model_urls[model_name]) + elif filename.startswith('torchvision://'): + model_urls = get_torchvision_models() + model_name = filename[14:] + checkpoint = load_url_dist(model_urls[model_name]) + elif filename.startswith('open-mmlab://'): + model_urls = get_external_models() + model_name = filename[13:] + deprecated_urls = get_deprecated_model_names() + if model_name in deprecated_urls: + warnings.warn(f'open-mmlab://{model_name} is deprecated in favor ' + f'of open-mmlab://{deprecated_urls[model_name]}') + model_name = deprecated_urls[model_name] + model_url = model_urls[model_name] + # check if is url + if model_url.startswith(('http://', 'https://')): + checkpoint = load_url_dist(model_url) + else: + filename = osp.join(_get_mmcv_home(), model_url) + if not osp.isfile(filename): + raise IOError(f'{filename} is not a checkpoint file') + checkpoint = torch.load(filename, map_location=map_location) + elif filename.startswith('mmcls://'): + model_urls = get_mmcls_models() + model_name = filename[8:] + checkpoint = load_url_dist(model_urls[model_name]) + checkpoint = _process_mmcls_checkpoint(checkpoint) + elif filename.startswith(('http://', 'https://')): + checkpoint = load_url_dist(filename) + elif filename.startswith('pavi://'): + model_path = filename[7:] + checkpoint = load_pavimodel_dist(model_path, map_location=map_location) + elif filename.startswith('s3://'): + checkpoint = load_fileclient_dist( + filename, backend='ceph', map_location=map_location) + else: + if not osp.isfile(filename): + raise IOError(f'{filename} is not a checkpoint file') + checkpoint = torch.load(filename, map_location=map_location) + return checkpoint + + +def load_checkpoint(model, + filename, + map_location='cpu', + strict=False, + logger=None): + """Load checkpoint from a file or URI. + + Args: + model (Module): Module to load checkpoint. + filename (str): Accept local filepath, URL, ``torchvision://xxx``, + ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for + details. + map_location (str): Same as :func:`torch.load`. + strict (bool): Whether to allow different params for the model and + checkpoint. + logger (:mod:`logging.Logger` or None): The logger for error message. + + Returns: + dict or OrderedDict: The loaded checkpoint. 
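`_load_checkpoint` dispatches purely on the filename prefix, so one call site can read weights from torchvision, the open-mmlab zoo, a plain URL, or a local path. Illustrative calls (the model names and paths below are examples only and must exist in the corresponding zoo json or on disk for the call to succeed):

ckpt = _load_checkpoint('torchvision://resnet50')             # torchvision model zoo
ckpt = _load_checkpoint('open-mmlab://resnet50_v1c')          # json-mapped URL or file under MMCV_HOME
ckpt = _load_checkpoint('https://example.com/swin_tiny.pth')  # raw URL, fetched on local rank 0
ckpt = _load_checkpoint('weights/swin_tiny.pth', map_location='cpu')  # plain local file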
+ """ + checkpoint = _load_checkpoint(filename, map_location) + # OrderedDict is a subclass of dict + if not isinstance(checkpoint, dict): + raise RuntimeError( + f'No state_dict found in checkpoint file {filename}') + # get state_dict from checkpoint + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + elif 'model' in checkpoint: + state_dict = checkpoint['model'] + else: + state_dict = checkpoint + # strip prefix of state_dict + if list(state_dict.keys())[0].startswith('module.'): + state_dict = {k[7:]: v for k, v in state_dict.items()} + + # for MoBY, load model of online branch + if sorted(list(state_dict.keys()))[0].startswith('encoder'): + state_dict = {k.replace('encoder.', ''): v for k, v in state_dict.items() if k.startswith('encoder.')} + + # reshape absolute position embedding + if state_dict.get('absolute_pos_embed') is not None: + absolute_pos_embed = state_dict['absolute_pos_embed'] + N1, L, C1 = absolute_pos_embed.size() + N2, C2, H, W = model.absolute_pos_embed.size() + if N1 != N2 or C1 != C2 or L != H*W: + logger.warning("Error in loading absolute_pos_embed, pass") + else: + state_dict['absolute_pos_embed'] = absolute_pos_embed.view(N2, H, W, C2).permute(0, 3, 1, 2) + + # interpolate position bias table if needed + relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k] + for table_key in relative_position_bias_table_keys: + table_pretrained = state_dict[table_key] + table_current = model.state_dict()[table_key] + L1, nH1 = table_pretrained.size() + L2, nH2 = table_current.size() + if nH1 != nH2: + logger.warning(f"Error in loading {table_key}, pass") + else: + if L1 != L2: + S1 = int(L1 ** 0.5) + S2 = int(L2 ** 0.5) + table_pretrained_resized = F.interpolate( + table_pretrained.permute(1, 0).view(1, nH1, S1, S1), + size=(S2, S2), mode='bicubic') + state_dict[table_key] = table_pretrained_resized.view(nH2, L2).permute(1, 0) + + # load state_dict + load_state_dict(model, state_dict, strict, logger) + return checkpoint + + +def weights_to_cpu(state_dict): + """Copy a model state_dict to cpu. + + Args: + state_dict (OrderedDict): Model weights on GPU. + + Returns: + OrderedDict: Model weights on GPU. + """ + state_dict_cpu = OrderedDict() + for key, val in state_dict.items(): + state_dict_cpu[key] = val.cpu() + return state_dict_cpu + + +def _save_to_state_dict(module, destination, prefix, keep_vars): + """Saves module state to `destination` dictionary. + + This method is modified from :meth:`torch.nn.Module._save_to_state_dict`. + + Args: + module (nn.Module): The module to generate state_dict. + destination (dict): A dict where state will be stored. + prefix (str): The prefix for parameters and buffers used in this + module. + """ + for name, param in module._parameters.items(): + if param is not None: + destination[prefix + name] = param if keep_vars else param.detach() + for name, buf in module._buffers.items(): + # remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d + if buf is not None: + destination[prefix + name] = buf if keep_vars else buf.detach() + + +def get_state_dict(module, destination=None, prefix='', keep_vars=False): + """Returns a dictionary containing a whole state of the module. + + Both parameters and persistent buffers (e.g. running averages) are + included. Keys are corresponding parameter and buffer names. 
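The bias-table resizing in `load_checkpoint` above is the usual trick for loading a Swin checkpoint trained with one window size into a model with another: the (2*Wh-1)*(2*Ww-1) table is treated as one small 2-D map per head and resized bicubically. A self-contained sketch of the same operation (the function name and window sizes are illustrative):

import torch
import torch.nn.functional as F


def resize_rel_pos_bias(table, src_window, dst_window):
    """table: ((2*src_window-1)**2, num_heads) -> ((2*dst_window-1)**2, num_heads)."""
    L1, n_heads = table.shape
    S1, S2 = 2 * src_window - 1, 2 * dst_window - 1
    assert L1 == S1 * S1
    resized = F.interpolate(
        table.permute(1, 0).view(1, n_heads, S1, S1),   # one small 2-D map per head
        size=(S2, S2), mode='bicubic')
    return resized.view(n_heads, S2 * S2).permute(1, 0)


# e.g. a checkpoint trained with 7x7 windows loaded into a 12x12-window model
new_table = resize_rel_pos_bias(torch.randn(13 * 13, 4), src_window=7, dst_window=12)
assert new_table.shape == (23 * 23, 4)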
+ + This method is modified from :meth:`torch.nn.Module.state_dict` to + recursively check parallel module in case that the model has a complicated + structure, e.g., nn.Module(nn.Module(DDP)). + + Args: + module (nn.Module): The module to generate state_dict. + destination (OrderedDict): Returned dict for the state of the + module. + prefix (str): Prefix of the key. + keep_vars (bool): Whether to keep the variable property of the + parameters. Default: False. + + Returns: + dict: A dictionary containing a whole state of the module. + """ + # recursively check parallel module in case that the model has a + # complicated structure, e.g., nn.Module(nn.Module(DDP)) + if is_module_wrapper(module): + module = module.module + + # below is the same as torch.nn.Module.state_dict() + if destination is None: + destination = OrderedDict() + destination._metadata = OrderedDict() + destination._metadata[prefix[:-1]] = local_metadata = dict( + version=module._version) + _save_to_state_dict(module, destination, prefix, keep_vars) + for name, child in module._modules.items(): + if child is not None: + get_state_dict( + child, destination, prefix + name + '.', keep_vars=keep_vars) + for hook in module._state_dict_hooks.values(): + hook_result = hook(module, destination, prefix, local_metadata) + if hook_result is not None: + destination = hook_result + return destination + + +def save_checkpoint(model, filename, optimizer=None, meta=None): + """Save checkpoint to file. + + The checkpoint will have 3 fields: ``meta``, ``state_dict`` and + ``optimizer``. By default ``meta`` will contain version and time info. + + Args: + model (Module): Module whose params are to be saved. + filename (str): Checkpoint filename. + optimizer (:obj:`Optimizer`, optional): Optimizer to be saved. + meta (dict, optional): Metadata to be saved in checkpoint. 
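`get_state_dict` mirrors `nn.Module.state_dict()` but unwraps parallel containers first, so the saved keys carry no `module.` prefix. A quick illustration with the helpers above in scope; this assumes mmcv's default `MODULE_WRAPPERS` registration, which recognizes `torch.nn.DataParallel`:

import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
wrapped = nn.DataParallel(model)

raw = wrapped.state_dict()        # keys look like 'module.0.weight', 'module.1.running_mean', ...
clean = get_state_dict(wrapped)   # keys look like '0.weight', '1.running_mean', ...

assert all(not k.startswith('module.') for k in clean)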
+ """ + if meta is None: + meta = {} + elif not isinstance(meta, dict): + raise TypeError(f'meta must be a dict or None, but got {type(meta)}') + meta.update(mmcv_version=mmcv.__version__, time=time.asctime()) + + if is_module_wrapper(model): + model = model.module + + if hasattr(model, 'CLASSES') and model.CLASSES is not None: + # save class name to the meta + meta.update(CLASSES=model.CLASSES) + + checkpoint = { + 'meta': meta, + 'state_dict': weights_to_cpu(get_state_dict(model)) + } + # save optimizer state dict in the checkpoint + if isinstance(optimizer, Optimizer): + checkpoint['optimizer'] = optimizer.state_dict() + elif isinstance(optimizer, dict): + checkpoint['optimizer'] = {} + for name, optim in optimizer.items(): + checkpoint['optimizer'][name] = optim.state_dict() + + if filename.startswith('pavi://'): + try: + from pavi import modelcloud + from pavi.exception import NodeNotFoundError + except ImportError: + raise ImportError( + 'Please install pavi to load checkpoint from modelcloud.') + model_path = filename[7:] + root = modelcloud.Folder() + model_dir, model_name = osp.split(model_path) + try: + model = modelcloud.get(model_dir) + except NodeNotFoundError: + model = root.create_training_model(model_dir) + with TemporaryDirectory() as tmp_dir: + checkpoint_file = osp.join(tmp_dir, model_name) + with open(checkpoint_file, 'wb') as f: + torch.save(checkpoint, f) + f.flush() + model.create_file(checkpoint_file, name=model_name) + else: + mmcv.mkdir_or_exist(osp.dirname(filename)) + # immediately flush buffer + with open(filename, 'wb') as f: + torch.save(checkpoint, f) + f.flush() \ No newline at end of file diff --git a/isegm/model/modeling/swin_transformer_helper/logger.py b/isegm/model/modeling/swin_transformer_helper/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..25477bab309b0f42592f413de40e37a4137fe3cb --- /dev/null +++ b/isegm/model/modeling/swin_transformer_helper/logger.py @@ -0,0 +1,27 @@ +import logging + +from mmcv.utils import get_logger + + +def get_root_logger(log_file=None, log_level=logging.INFO): + """Get the root logger. + + The logger will be initialized if it has not been initialized. By default a + StreamHandler will be added. If `log_file` is specified, a FileHandler will + also be added. The name of the root logger is the top-level package name, + e.g., "mmseg". + + Args: + log_file (str | None): The log filename. If specified, a FileHandler + will be added to the root logger. + log_level (int): The root logger level. Note that only the process of + rank 0 is affected, while other processes will set the level to + "Error" and be silent most of the time. + + Returns: + logging.Logger: The root logger. 
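A typical round trip with `save_checkpoint` and `load_checkpoint` as defined above; the output path and the meta values are illustrative:

import torch.nn as nn
import torch.optim as optim

model = nn.Linear(16, 4)
optimizer = optim.SGD(model.parameters(), lr=0.01)

# writes {'meta': ..., 'state_dict': ..., 'optimizer': ...} under the given path
save_checkpoint(model, 'work_dirs/demo/latest.pth',
                optimizer=optimizer, meta=dict(epoch=12, iter=48000))

# restore the weights later; the returned dict still carries meta and optimizer state
ckpt = load_checkpoint(model, 'work_dirs/demo/latest.pth', map_location='cpu')
print(ckpt['meta'])   # epoch/iter plus the mmcv version and timestamp added by save_checkpoint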
+ """ + + logger = get_logger(name='mmseg', log_file=log_file, log_level=log_level) + + return logger \ No newline at end of file diff --git a/isegm/model/modeling/swin_unet.py b/isegm/model/modeling/swin_unet.py new file mode 100644 index 0000000000000000000000000000000000000000..55045da38da473b4d961c647c43182d72194317e --- /dev/null +++ b/isegm/model/modeling/swin_unet.py @@ -0,0 +1,751 @@ +import torch +import torch.nn as nn +import torch.utils.checkpoint as checkpoint +from einops import rearrange +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
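`window_partition` and `window_reverse` above are exact inverses as long as H and W are multiples of the window size; a quick shape check with the two helpers in scope:

import torch

x = torch.randn(2, 14, 14, 96)                  # B, H, W, C with H and W divisible by 7
windows = window_partition(x, window_size=7)    # 2 images * 2 * 2 windows each
assert windows.shape == (8, 7, 7, 96)

restored = window_reverse(windows, window_size=7, H=14, W=14)
assert torch.equal(restored, x)                 # partition followed by reverse is lossless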
Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. 
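The relative-position bookkeeping in `WindowAttention.__init__` above boils down to: for every pair of tokens in a window, take their (dy, dx) offset, shift it to be non-negative, and flatten it into a single index into the bias table. A tiny standalone check for a 3x3 window:

import torch

ws = 3  # window size (Wh = Ww = 3), so the bias table has (2*3-1)**2 = 25 entries
coords = torch.stack(torch.meshgrid(torch.arange(ws), torch.arange(ws)))  # 2, Wh, Ww
flat = torch.flatten(coords, 1)                      # 2, Wh*Ww
rel = flat[:, :, None] - flat[:, None, :]            # 2, N, N pairwise (dy, dx) offsets
rel = rel.permute(1, 2, 0).contiguous()              # N, N, 2
rel[:, :, 0] += ws - 1                               # shift dy into [0, 2*ws-2]
rel[:, :, 1] += ws - 1                               # shift dx into [0, 2*ws-2]
rel[:, :, 0] *= 2 * ws - 1                           # row-major flattening of (dy, dx)
index = rel.sum(-1)                                  # N, N indices into the bias table

assert index.shape == (9, 9)
assert index.max().item() == 24                      # offset (+2, +2) maps to the last entry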
+ window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + # calculate attention mask for SW-MSA + H, W = self.input_resolution + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def forward(self, x): + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = 
window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + +class PatchExpand(nn.Module): + def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.expand = nn.Linear(dim, 2*dim, bias=False) if dim_scale==2 else nn.Identity() + self.norm = norm_layer(dim // dim_scale) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + x = self.expand(x) + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + x = x.view(B, H, W, C) + x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C//4) + x = x.view(B,-1,C//4) + x= self.norm(x) + + return x + +class FinalPatchExpand_X4(nn.Module): + def __init__(self, input_resolution, dim, dim_scale=4, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.dim_scale = dim_scale + self.expand = nn.Linear(dim, 16*dim, bias=False) + self.output_dim = dim + self.norm = norm_layer(self.output_dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + x = self.expand(x) + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + x = x.view(B, H, W, C) + x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w 
p2) c', p1=self.dim_scale, p2=self.dim_scale, c=C//(self.dim_scale**2)) + x = x.view(B,-1,self.output_dim) + x= self.norm(x) + + return x + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + +class BasicLayer_up(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. 
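`PatchExpand` above doubles the spatial resolution by first widening channels with a linear layer and then unshuffling them into a 2x2 spatial block via the einops rearrange, so channels go C -> C/2 while H, W go to 2H, 2W. A shape-only sketch of the rearrange step (dimensions chosen arbitrarily):

import torch
from einops import rearrange

B, H, W, dim = 1, 7, 7, 96
x = torch.randn(B, H * W, 2 * dim)   # token sequence after the Linear(dim, 2*dim) expansion

x = x.view(B, H, W, 2 * dim)
x = rearrange(x, 'b h w (p1 p2 c) -> b (h p1) (w p2) c', p1=2, p2=2, c=dim // 2)
assert x.shape == (B, 2 * H, 2 * W, dim // 2)   # 2x spatial upsample, half the channels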
Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, upsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if upsample is not None: + self.upsample = PatchExpand(input_resolution, dim=dim, dim_scale=2, norm_layer=norm_layer) + else: + self.upsample = None + + def forward(self, x): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + if self.upsample is not None: + x = self.upsample(x) + return x + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops + + +class SwinTransformerSys(nn.Module): + r""" Swin Transformer + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + + Args: + img_size (int | tuple(int)): Input image size. Default 224 + patch_size (int | tuple(int)): Patch size. Default: 4 + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. 
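The `PatchEmbed` defined just above turns an image into a token sequence with a single strided convolution: a 224x224 input with patch_size=4 yields 56*56 = 3136 tokens of width embed_dim. A quick shape check, assuming that class is in scope:

import torch

embed = PatchEmbed(img_size=224, patch_size=4, in_chans=3, embed_dim=96)
tokens = embed(torch.randn(2, 3, 224, 224))

assert tokens.shape == (2, 56 * 56, 96)   # B, num_patches, embed_dim
assert embed.num_patches == 56 * 56       # (224 / 4) ** 2 patches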
Default: 1000 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, + embed_dim=96, depths=[2, 2, 2, 2], depths_decoder=[1, 2, 2, 2], num_heads=[3, 6, 12, 24], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, final_upsample="expand_first", **kwargs): + super().__init__() + + print("SwinTransformerSys expand initial----depths:{};depths_decoder:{};drop_path_rate:{};num_classes:{}".format(depths, + depths_decoder,drop_path_rate,num_classes)) + + self.num_classes = num_classes + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) + self.num_features_up = int(embed_dim * 2) + self.mlp_ratio = mlp_ratio + self.final_upsample = final_upsample + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build encoder and bottleneck layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), + input_resolution=(patches_resolution[0] // (2 ** i_layer), + patches_resolution[1] // (2 ** i_layer)), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint) + self.layers.append(layer) + + # build decoder layers + self.layers_up = nn.ModuleList() + self.concat_back_dim = nn.ModuleList() + for i_layer in range(self.num_layers): + concat_linear = 
nn.Linear(2*int(embed_dim*2**(self.num_layers-1-i_layer)), + int(embed_dim*2**(self.num_layers-1-i_layer))) if i_layer > 0 else nn.Identity() + if i_layer ==0 : + layer_up = PatchExpand(input_resolution=(patches_resolution[0] // (2 ** (self.num_layers-1-i_layer)), + patches_resolution[1] // (2 ** (self.num_layers-1-i_layer))), dim=int(embed_dim * 2 ** (self.num_layers-1-i_layer)), dim_scale=2, norm_layer=norm_layer) + else: + layer_up = BasicLayer_up(dim=int(embed_dim * 2 ** (self.num_layers-1-i_layer)), + input_resolution=(patches_resolution[0] // (2 ** (self.num_layers-1-i_layer)), + patches_resolution[1] // (2 ** (self.num_layers-1-i_layer))), + depth=depths[(self.num_layers-1-i_layer)], + num_heads=num_heads[(self.num_layers-1-i_layer)], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:(self.num_layers-1-i_layer)]):sum(depths[:(self.num_layers-1-i_layer) + 1])], + norm_layer=norm_layer, + upsample=PatchExpand if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint) + self.layers_up.append(layer_up) + self.concat_back_dim.append(concat_linear) + + self.norm = norm_layer(self.num_features) + self.norm_up= norm_layer(self.embed_dim) + + if self.final_upsample == "expand_first": + print("---final upsample expand_first---") + self.up = FinalPatchExpand_X4(input_resolution=(img_size//patch_size,img_size//patch_size),dim_scale=4,dim=embed_dim) + self.output = nn.Conv2d(in_channels=embed_dim,out_channels=self.num_classes,kernel_size=1,bias=False) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + #Encoder and Bottleneck + def forward_features(self, x): + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + x_downsample = [] + + for layer in self.layers: + x_downsample.append(x) + x = layer(x) + + x = self.norm(x) # B L C + + return x, x_downsample + + #Dencoder and Skip connection + def forward_up_features(self, x, x_downsample): + for inx, layer_up in enumerate(self.layers_up): + if inx == 0: + x = layer_up(x) + else: + x = torch.cat([x,x_downsample[3-inx]],-1) + x = self.concat_back_dim[inx](x) + x = layer_up(x) + + x = self.norm_up(x) # B L C + + return x + + def up_x4(self, x): + H, W = self.patches_resolution + B, L, C = x.shape + assert L == H*W, "input features has wrong size" + + if self.final_upsample=="expand_first": + x = self.up(x) + x = x.view(B,4*H,4*W,-1) + x = x.permute(0,3,1,2) #B,C,H,W + x = self.output(x) + + return x + + def forward(self, x): + x, x_downsample = self.forward_features(x) + x = self.forward_up_features(x,x_downsample) + x = self.up_x4(x) + + return x + + def flops(self): + flops = 0 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers): + flops += layer.flops() + flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) + flops += self.num_features * self.num_classes + return flops \ No newline at end of file diff --git 
a/isegm/model/modeling/transformer_helper/__init__.py b/isegm/model/modeling/transformer_helper/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..068c849953b934389fd986ca7e18abfee13fc63d --- /dev/null +++ b/isegm/model/modeling/transformer_helper/__init__.py @@ -0,0 +1,13 @@ +from .embed import PatchEmbed +from .shape_convert import nchw_to_nlc, nlc_to_nchw +from .wrappers import resize, Upsample +from .logger import get_root_logger +from .decode_head import BaseDecodeHead +from .builder import (BACKBONES, HEADS, LOSSES, SEGMENTORS, build_backbone, + build_head, build_loss, build_segmentor) + +__all__ = [ + 'PatchEmbed', 'nchw_to_nlc', 'nlc_to_nchw', 'resize', 'Upsample', + 'get_root_logger', 'BaseDecodeHead', 'BACKBONES', 'HEADS', 'LOSSES', + 'SEGMENTORS', 'build_backbone', 'build_head', 'build_loss', 'build_segmentor' +] diff --git a/isegm/model/modeling/transformer_helper/accuracy.py b/isegm/model/modeling/transformer_helper/accuracy.py new file mode 100644 index 0000000000000000000000000000000000000000..f2cd16b7f9be41d96ebc7ae3bffd027ea1353460 --- /dev/null +++ b/isegm/model/modeling/transformer_helper/accuracy.py @@ -0,0 +1,79 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn + + +def accuracy(pred, target, topk=1, thresh=None): + """Calculate accuracy according to the prediction and target. + + Args: + pred (torch.Tensor): The model prediction, shape (N, num_class, ...) + target (torch.Tensor): The target of each prediction, shape (N, , ...) + topk (int | tuple[int], optional): If the predictions in ``topk`` + matches the target, the predictions will be regarded as + correct ones. Defaults to 1. + thresh (float, optional): If not None, predictions with scores under + this threshold are considered incorrect. Default to None. + + Returns: + float | tuple[float]: If the input ``topk`` is a single integer, + the function will return a single float as accuracy. If + ``topk`` is a tuple containing multiple integers, the + function will return a tuple containing accuracies of + each ``topk`` number. + """ + assert isinstance(topk, (int, tuple)) + if isinstance(topk, int): + topk = (topk, ) + return_single = True + else: + return_single = False + + maxk = max(topk) + if pred.size(0) == 0: + accu = [pred.new_tensor(0.) for i in range(len(topk))] + return accu[0] if return_single else accu + assert pred.ndim == target.ndim + 1 + assert pred.size(0) == target.size(0) + assert maxk <= pred.size(1), \ + f'maxk {maxk} exceeds pred dimension {pred.size(1)}' + pred_value, pred_label = pred.topk(maxk, dim=1) + # transpose to shape (maxk, N, ...) + pred_label = pred_label.transpose(0, 1) + correct = pred_label.eq(target.unsqueeze(0).expand_as(pred_label)) + if thresh is not None: + # Only prediction values larger than thresh are counted as correct + correct = correct & (pred_value > thresh).t() + res = [] + for k in topk: + correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / target.numel())) + return res[0] if return_single else res + + +class Accuracy(nn.Module): + """Accuracy calculation module.""" + + def __init__(self, topk=(1, ), thresh=None): + """Module to calculate the accuracy. + + Args: + topk (tuple, optional): The criterion used to calculate the + accuracy. Defaults to (1,). + thresh (float, optional): If not None, predictions with scores + under this threshold are considered incorrect. Default to None. 
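The `accuracy` function above supports top-k evaluation and an optional score threshold. A small usage sketch on a toy 3-class prediction, with the function in scope:

import torch

pred = torch.tensor([[0.1, 0.7, 0.2],    # sample 0: top-1 is class 1
                     [0.3, 0.6, 0.1],    # sample 1: top-1 is class 1, top-2 adds class 0
                     [0.2, 0.3, 0.5]])   # sample 2: top-1 is class 2
target = torch.tensor([1, 0, 2])

top1 = accuracy(pred, target, topk=1)              # 2 of 3 correct -> roughly 66.7
top1_top2 = accuracy(pred, target, topk=(1, 2))    # ~66.7 for top-1, 100.0 within top-2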
+ """ + super().__init__() + self.topk = topk + self.thresh = thresh + + def forward(self, pred, target): + """Forward function to calculate accuracy. + + Args: + pred (torch.Tensor): Prediction of models. + target (torch.Tensor): Target for each prediction. + + Returns: + tuple[float]: The accuracies under different topk criterions. + """ + return accuracy(pred, target, self.topk, self.thresh) diff --git a/isegm/model/modeling/transformer_helper/base_pixel_sampler.py b/isegm/model/modeling/transformer_helper/base_pixel_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..03672cd478a2e464cc734ae92686c86f219da0a9 --- /dev/null +++ b/isegm/model/modeling/transformer_helper/base_pixel_sampler.py @@ -0,0 +1,13 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABCMeta, abstractmethod + + +class BasePixelSampler(metaclass=ABCMeta): + """Base class of pixel sampler.""" + + def __init__(self, **kwargs): + pass + + @abstractmethod + def sample(self, seg_logit, seg_label): + """Placeholder for sample function.""" diff --git a/isegm/model/modeling/transformer_helper/builder.py b/isegm/model/modeling/transformer_helper/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..829807d630e01de7bab4cc40d088823d2b2be5d7 --- /dev/null +++ b/isegm/model/modeling/transformer_helper/builder.py @@ -0,0 +1,55 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +from mmcv.cnn import MODELS as MMCV_MODELS +from mmcv.cnn.bricks.registry import ATTENTION as MMCV_ATTENTION +from mmcv.utils import Registry, build_from_cfg + + +PIXEL_SAMPLERS = Registry('pixel sampler') +MODELS = Registry('models', parent=MMCV_MODELS) +ATTENTION = Registry('attention', parent=MMCV_ATTENTION) + +BACKBONES = MODELS +NECKS = MODELS +HEADS = MODELS +LOSSES = MODELS +SEGMENTORS = MODELS + +def build_pixel_sampler(cfg, **default_args): + """Build pixel sampler for segmentation map.""" + return build_from_cfg(cfg, PIXEL_SAMPLERS, default_args) + + +def build_backbone(cfg): + """Build backbone.""" + return BACKBONES.build(cfg) + + +def build_neck(cfg): + """Build neck.""" + return NECKS.build(cfg) + + +def build_head(cfg): + """Build head.""" + return HEADS.build(cfg) + + +def build_loss(cfg): + """Build loss.""" + return LOSSES.build(cfg) + + +def build_segmentor(cfg, train_cfg=None, test_cfg=None): + """Build segmentor.""" + if train_cfg is not None or test_cfg is not None: + warnings.warn( + 'train_cfg and test_cfg is deprecated, ' + 'please specify them in model', UserWarning) + assert cfg.get('train_cfg') is None or train_cfg is None, \ + 'train_cfg specified in both outer field and model field ' + assert cfg.get('test_cfg') is None or test_cfg is None, \ + 'test_cfg specified in both outer field and model field ' + return SEGMENTORS.build( + cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg)) diff --git a/isegm/model/modeling/transformer_helper/cross_entropy_loss.py b/isegm/model/modeling/transformer_helper/cross_entropy_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..a1d1c27eadf57ee2c32dd1fd206205c17fc5430e --- /dev/null +++ b/isegm/model/modeling/transformer_helper/cross_entropy_loss.py @@ -0,0 +1,199 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
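The builder module above follows the usual mmcv registry pattern: components register themselves under a string name and are later instantiated from config dicts. A hedged sketch under the assumption that the package imports resolve; `ToyHead` is a made-up class for illustration, not part of this repo:

import torch.nn as nn
from isegm.model.modeling.transformer_helper import HEADS, build_head


@HEADS.register_module()
class ToyHead(nn.Module):     # illustrative only
    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.cls = nn.Conv2d(in_channels, num_classes, kernel_size=1)

    def forward(self, x):
        return self.cls(x)


# 'type' selects the registered class; the remaining keys become constructor kwargs
head = build_head(dict(type='ToyHead', in_channels=256, num_classes=2))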
+import torch +import torch.nn as nn +import torch.nn.functional as F + +from .builder import LOSSES +from .utils import get_class_weight, weight_reduce_loss + + +def cross_entropy(pred, + label, + weight=None, + class_weight=None, + reduction='mean', + avg_factor=None, + ignore_index=-100): + """The wrapper function for :func:`F.cross_entropy`""" + # class_weight is a manual rescaling weight given to each class. + # If given, has to be a Tensor of size C element-wise losses + loss = F.cross_entropy( + pred, + label, + weight=class_weight, + reduction='none', + ignore_index=ignore_index) + + # apply weights and do the reduction + if weight is not None: + weight = weight.float() + loss = weight_reduce_loss( + loss, weight=weight, reduction=reduction, avg_factor=avg_factor) + + return loss + + +def _expand_onehot_labels(labels, label_weights, target_shape, ignore_index): + """Expand onehot labels to match the size of prediction.""" + bin_labels = labels.new_zeros(target_shape) + valid_mask = (labels >= 0) & (labels != ignore_index) + inds = torch.nonzero(valid_mask, as_tuple=True) + + if inds[0].numel() > 0: + if labels.dim() == 3: + bin_labels[inds[0], labels[valid_mask], inds[1], inds[2]] = 1 + else: + bin_labels[inds[0], labels[valid_mask]] = 1 + + valid_mask = valid_mask.unsqueeze(1).expand(target_shape).float() + if label_weights is None: + bin_label_weights = valid_mask + else: + bin_label_weights = label_weights.unsqueeze(1).expand(target_shape) + bin_label_weights *= valid_mask + + return bin_labels, bin_label_weights + + +def binary_cross_entropy(pred, + label, + weight=None, + reduction='mean', + avg_factor=None, + class_weight=None, + ignore_index=255): + """Calculate the binary CrossEntropy loss. + + Args: + pred (torch.Tensor): The prediction with shape (N, 1). + label (torch.Tensor): The learning label of the prediction. + weight (torch.Tensor, optional): Sample-wise loss weight. + reduction (str, optional): The method used to reduce the loss. + Options are "none", "mean" and "sum". + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + class_weight (list[float], optional): The weight for each class. + ignore_index (int | None): The label index to be ignored. Default: 255 + + Returns: + torch.Tensor: The calculated loss + """ + if pred.dim() != label.dim(): + assert (pred.dim() == 2 and label.dim() == 1) or ( + pred.dim() == 4 and label.dim() == 3), \ + 'Only pred shape [N, C], label shape [N] or pred shape [N, C, ' \ + 'H, W], label shape [N, H, W] are supported' + label, weight = _expand_onehot_labels(label, weight, pred.shape, + ignore_index) + + # weighted element-wise losses + if weight is not None: + weight = weight.float() + loss = F.binary_cross_entropy_with_logits( + pred, label.float(), pos_weight=class_weight, reduction='none') + # do the reduction for the weighted loss + loss = weight_reduce_loss( + loss, weight, reduction=reduction, avg_factor=avg_factor) + + return loss + + +def mask_cross_entropy(pred, + target, + label, + reduction='mean', + avg_factor=None, + class_weight=None, + ignore_index=None): + """Calculate the CrossEntropy loss for masks. + + Args: + pred (torch.Tensor): The prediction with shape (N, C), C is the number + of classes. + target (torch.Tensor): The learning label of the prediction. + label (torch.Tensor): ``label`` indicates the class label of the mask' + corresponding object. 
This will be used to select the mask in the + of the class which the object belongs to when the mask prediction + if not class-agnostic. + reduction (str, optional): The method used to reduce the loss. + Options are "none", "mean" and "sum". + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + class_weight (list[float], optional): The weight for each class. + ignore_index (None): Placeholder, to be consistent with other loss. + Default: None. + + Returns: + torch.Tensor: The calculated loss + """ + assert ignore_index is None, 'BCE loss does not support ignore_index' + # TODO: handle these two reserved arguments + assert reduction == 'mean' and avg_factor is None + num_rois = pred.size()[0] + inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device) + pred_slice = pred[inds, label].squeeze(1) + return F.binary_cross_entropy_with_logits( + pred_slice, target, weight=class_weight, reduction='mean')[None] + + +@LOSSES.register_module() +class CrossEntropyLoss(nn.Module): + """CrossEntropyLoss. + + Args: + use_sigmoid (bool, optional): Whether the prediction uses sigmoid + of softmax. Defaults to False. + use_mask (bool, optional): Whether to use mask cross entropy loss. + Defaults to False. + reduction (str, optional): . Defaults to 'mean'. + Options are "none", "mean" and "sum". + class_weight (list[float] | str, optional): Weight of each class. If in + str format, read them from a file. Defaults to None. + loss_weight (float, optional): Weight of the loss. Defaults to 1.0. + """ + + def __init__(self, + use_sigmoid=False, + use_mask=False, + reduction='mean', + class_weight=None, + loss_weight=1.0): + super(CrossEntropyLoss, self).__init__() + assert (use_sigmoid is False) or (use_mask is False) + self.use_sigmoid = use_sigmoid + self.use_mask = use_mask + self.reduction = reduction + self.loss_weight = loss_weight + self.class_weight = get_class_weight(class_weight) + + if self.use_sigmoid: + self.cls_criterion = binary_cross_entropy + elif self.use_mask: + self.cls_criterion = mask_cross_entropy + else: + self.cls_criterion = cross_entropy + + def forward(self, + cls_score, + label, + weight=None, + avg_factor=None, + reduction_override=None, + **kwargs): + """Forward function.""" + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + if self.class_weight is not None: + class_weight = cls_score.new_tensor(self.class_weight) + else: + class_weight = None + loss_cls = self.loss_weight * self.cls_criterion( + cls_score, + label, + weight, + class_weight=class_weight, + reduction=reduction, + avg_factor=avg_factor, + **kwargs) + return loss_cls diff --git a/isegm/model/modeling/transformer_helper/decode_head.py b/isegm/model/modeling/transformer_helper/decode_head.py new file mode 100644 index 0000000000000000000000000000000000000000..3da7c0ba94eeed93fabe9c0350fb6519b91e2232 --- /dev/null +++ b/isegm/model/modeling/transformer_helper/decode_head.py @@ -0,0 +1,231 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABCMeta, abstractmethod + +import torch +import torch.nn as nn +from mmcv.runner import BaseModule, auto_fp16, force_fp32 + +from .builder import build_pixel_sampler, build_loss +from .wrappers import resize +from .accuracy import accuracy +from .cross_entropy_loss import CrossEntropyLoss + + +class BaseDecodeHead(BaseModule, metaclass=ABCMeta): + """Base class for BaseDecodeHead. 
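`CrossEntropyLoss` above wraps three flavours behind one module: plain softmax CE, sigmoid/BCE with on-the-fly one-hot expansion (`use_sigmoid=True`), and mask CE. A minimal usage sketch for the softmax segmentation case, built through the registry exported by this package (tensor shapes are illustrative):

import torch
from isegm.model.modeling.transformer_helper import build_loss

criterion = build_loss(dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))

logits = torch.randn(2, 19, 64, 64)             # N, num_classes, H, W
labels = torch.randint(0, 19, (2, 64, 64))      # N, H, W integer class map

loss = criterion(logits, labels, ignore_index=255)   # scalar, mean-reduced by default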
+ + Args: + in_channels (int|Sequence[int]): Input channels. + channels (int): Channels after modules, before conv_seg. + num_classes (int): Number of classes. + dropout_ratio (float): Ratio of dropout layer. Default: 0.1. + conv_cfg (dict|None): Config of conv layers. Default: None. + norm_cfg (dict|None): Config of norm layers. Default: None. + act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU') + in_index (int|Sequence[int]): Input feature index. Default: -1 + input_transform (str|None): Transformation type of input features. + Options: 'resize_concat', 'multiple_select', None. + 'resize_concat': Multiple feature maps will be resize to the + same size as first one and than concat together. + Usually used in FCN head of HRNet. + 'multiple_select': Multiple feature maps will be bundle into + a list and passed into decode head. + None: Only one select feature map is allowed. + Default: None. + loss_decode (dict): Config of decode loss. + Default: dict(type='CrossEntropyLoss'). + ignore_index (int | None): The label index to be ignored. When using + masked BCE loss, ignore_index should be set to None. Default: 255 + sampler (dict|None): The config of segmentation map sampler. + Default: None. + align_corners (bool): align_corners argument of F.interpolate. + Default: False. + init_cfg (dict or list[dict], optional): Initialization config dict. + """ + + def __init__(self, + in_channels, + channels, + *, + num_classes, + loss_decode, + dropout_ratio=0.1, + conv_cfg=None, + norm_cfg=None, + act_cfg=dict(type='ReLU'), + in_index=-1, + input_transform=None, + ignore_index=255, + sampler=None, + align_corners=False, + init_cfg=dict( + type='Normal', std=0.01, override=dict(name='conv_seg'))): + super(BaseDecodeHead, self).__init__(init_cfg) + self._init_inputs(in_channels, in_index, input_transform) + self.channels = channels + self.num_classes = num_classes + self.dropout_ratio = dropout_ratio + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.in_index = in_index + # self.loss_decode = build_loss(loss_decode) + self.loss_decode = loss_decode + self.ignore_index = ignore_index + self.align_corners = align_corners + if sampler is not None: + self.sampler = build_pixel_sampler(sampler, context=self) + else: + self.sampler = None + + self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1) + if dropout_ratio > 0: + self.dropout = nn.Dropout2d(dropout_ratio) + else: + self.dropout = None + self.fp16_enabled = False + + def extra_repr(self): + """Extra repr.""" + s = f'input_transform={self.input_transform}, ' \ + f'ignore_index={self.ignore_index}, ' \ + f'align_corners={self.align_corners}' + return s + + def _init_inputs(self, in_channels, in_index, input_transform): + """Check and initialize input transforms. + + The in_channels, in_index and input_transform must match. + Specifically, when input_transform is None, only single feature map + will be selected. So in_channels and in_index must be of type int. + When input_transform + + Args: + in_channels (int|Sequence[int]): Input channels. + in_index (int|Sequence[int]): Input feature index. + input_transform (str|None): Transformation type of input features. + Options: 'resize_concat', 'multiple_select', None. + 'resize_concat': Multiple feature maps will be resize to the + same size as first one and than concat together. + Usually used in FCN head of HRNet. + 'multiple_select': Multiple feature maps will be bundle into + a list and passed into decode head. 
+ None: Only one select feature map is allowed. + """ + + if input_transform is not None: + assert input_transform in ['resize_concat', 'multiple_select'] + self.input_transform = input_transform + self.in_index = in_index + if input_transform is not None: + assert isinstance(in_channels, (list, tuple)) + assert isinstance(in_index, (list, tuple)) + assert len(in_channels) == len(in_index) + if input_transform == 'resize_concat': + self.in_channels = sum(in_channels) + else: + self.in_channels = in_channels + else: + assert isinstance(in_channels, int) + assert isinstance(in_index, int) + self.in_channels = in_channels + + def _transform_inputs(self, inputs): + """Transform inputs for decoder. + + Args: + inputs (list[Tensor]): List of multi-level img features. + + Returns: + Tensor: The transformed inputs + """ + + if self.input_transform == 'resize_concat': + inputs = [inputs[i] for i in self.in_index] + upsampled_inputs = [ + resize( + input=x, + size=inputs[0].shape[2:], + mode='bilinear', + align_corners=self.align_corners) for x in inputs + ] + inputs = torch.cat(upsampled_inputs, dim=1) + elif self.input_transform == 'multiple_select': + inputs = [inputs[i] for i in self.in_index] + else: + inputs = inputs[self.in_index] + + return inputs + + @auto_fp16() + @abstractmethod + def forward(self, inputs): + """Placeholder of forward function.""" + pass + + def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg): + """Forward function for training. + Args: + inputs (list[Tensor]): List of multi-level img features. + img_metas (list[dict]): List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:Collect`. + gt_semantic_seg (Tensor): Semantic segmentation masks + used if the architecture supports semantic segmentation task. + train_cfg (dict): The training config. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + seg_logits = self.forward(inputs) + losses = self.losses(seg_logits, gt_semantic_seg) + return losses + + def forward_test(self, inputs, img_metas, test_cfg): + """Forward function for testing. + + Args: + inputs (list[Tensor]): List of multi-level img features. + img_metas (list[dict]): List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:Collect`. + test_cfg (dict): The testing config. + + Returns: + Tensor: Output segmentation map. 
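`_transform_inputs` above supports three modes; with 'resize_concat' every selected level is bilinearly resized to the first level's resolution and concatenated along channels. An illustrative shape walk-through of that branch (the channel counts and strides are made up for the example):

import torch
from isegm.model.modeling.transformer_helper import resize

# three feature levels from a hypothetical backbone
feats = [torch.randn(1, 96, 56, 56),
         torch.randn(1, 192, 28, 28),
         torch.randn(1, 384, 14, 14)]

# what the 'resize_concat' branch of _transform_inputs does:
upsampled = [resize(input=f, size=feats[0].shape[2:], mode='bilinear', align_corners=False)
             for f in feats]
fused = torch.cat(upsampled, dim=1)
assert fused.shape == (1, 96 + 192 + 384, 56, 56)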
+ """ + return self.forward(inputs) + + def cls_seg(self, feat): + """Classify each pixel.""" + if self.dropout is not None: + feat = self.dropout(feat) + output = self.conv_seg(feat) + return output + + @force_fp32(apply_to=('seg_logit', )) + def losses(self, seg_logit, seg_label): + """Compute segmentation loss.""" + loss = dict() + seg_logit = resize( + input=seg_logit, + size=seg_label.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + if self.sampler is not None: + seg_weight = self.sampler.sample(seg_logit, seg_label) + else: + seg_weight = None + seg_label = seg_label.squeeze(1) + loss['loss_seg'] = self.loss_decode( + seg_logit, + seg_label, + weight=seg_weight, + ignore_index=self.ignore_index) + loss['acc_seg'] = accuracy(seg_logit, seg_label) + return loss diff --git a/isegm/model/modeling/transformer_helper/embed.py b/isegm/model/modeling/transformer_helper/embed.py new file mode 100644 index 0000000000000000000000000000000000000000..c0cf143488eafcd1dcd9e80f824807aa47de34fd --- /dev/null +++ b/isegm/model/modeling/transformer_helper/embed.py @@ -0,0 +1,100 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn.functional as F +from mmcv.cnn import build_conv_layer, build_norm_layer +from mmcv.runner.base_module import BaseModule +from torch.nn.modules.utils import _pair as to_2tuple + + +# Modified from Pytorch-Image-Models +class PatchEmbed(BaseModule): + """Image to Patch Embedding V2. + + We use a conv layer to implement PatchEmbed. + Args: + in_channels (int): The num of input channels. Default: 3 + embed_dims (int): The dimensions of embedding. Default: 768 + conv_type (dict, optional): The config dict for conv layers type + selection. Default: None. + kernel_size (int): The kernel_size of embedding conv. Default: 16. + stride (int): The slide stride of embedding conv. + Default: None (Default to be equal with kernel_size). + padding (int): The padding length of embedding conv. Default: 0. + dilation (int): The dilation rate of embedding conv. Default: 1. + pad_to_patch_size (bool, optional): Whether to pad feature map shape + to multiple patch size. Default: True. + norm_cfg (dict, optional): Config dict for normalization layer. + init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization. + Default: None. + """ + + def __init__(self, + in_channels=3, + embed_dims=768, + conv_type=None, + kernel_size=16, + stride=16, + padding=0, + dilation=1, + pad_to_patch_size=True, + norm_cfg=None, + init_cfg=None): + super(PatchEmbed, self).__init__() + + self.embed_dims = embed_dims + self.init_cfg = init_cfg + + if stride is None: + stride = kernel_size + + self.pad_to_patch_size = pad_to_patch_size + + # The default setting of patch size is equal to kernel size. 
+ patch_size = kernel_size + if isinstance(patch_size, int): + patch_size = to_2tuple(patch_size) + elif isinstance(patch_size, tuple): + if len(patch_size) == 1: + patch_size = to_2tuple(patch_size[0]) + assert len(patch_size) == 2, \ + f'The size of patch should have length 1 or 2, ' \ + f'but got {len(patch_size)}' + + self.patch_size = patch_size + + # Use conv layer to embed + conv_type = conv_type or 'Conv2d' + self.projection = build_conv_layer( + dict(type=conv_type), + in_channels=in_channels, + out_channels=embed_dims, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation) + + if norm_cfg is not None: + self.norm = build_norm_layer(norm_cfg, embed_dims)[1] + else: + self.norm = None + + def forward(self, x): + H, W = x.shape[2], x.shape[3] + + # TODO: Process overlapping op + if self.pad_to_patch_size: + # Modify H, W to multiple of patch size. + if H % self.patch_size[0] != 0: + x = F.pad( + x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) + if W % self.patch_size[1] != 0: + x = F.pad( + x, (0, self.patch_size[1] - W % self.patch_size[1], 0, 0)) + + x = self.projection(x) + self.DH, self.DW = x.shape[2], x.shape[3] + x = x.flatten(2).transpose(1, 2) + + if self.norm is not None: + x = self.norm(x) + + return x diff --git a/isegm/model/modeling/transformer_helper/logger.py b/isegm/model/modeling/transformer_helper/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..05d2f13439ff501aa51b248ce7396ea5d41a38fb --- /dev/null +++ b/isegm/model/modeling/transformer_helper/logger.py @@ -0,0 +1,27 @@ +import logging + +from mmcv.utils import get_logger + + +def get_root_logger(log_file=None, log_level=logging.INFO): + """Get the root logger. + + The logger will be initialized if it has not been initialized. By default a + StreamHandler will be added. If `log_file` is specified, a FileHandler will + also be added. The name of the root logger is the top-level package name, + e.g., "mmseg". + + Args: + log_file (str | None): The log filename. If specified, a FileHandler + will be added to the root logger. + log_level (int): The root logger level. Note that only the process of + rank 0 is affected, while other processes will set the level to + "Error" and be silent most of the time. + + Returns: + logging.Logger: The root logger. + """ + + logger = get_logger(name='mmseg', log_file=log_file, log_level=log_level) + + return logger diff --git a/isegm/model/modeling/transformer_helper/shape_convert.py b/isegm/model/modeling/transformer_helper/shape_convert.py new file mode 100644 index 0000000000000000000000000000000000000000..34c8648c4ac299cdc16d42e4c88c843df268772d --- /dev/null +++ b/isegm/model/modeling/transformer_helper/shape_convert.py @@ -0,0 +1,29 @@ +# Copyright (c) OpenMMLab. All rights reserved. +def nlc_to_nchw(x, hw_shape): + """Convert [N, L, C] shape tensor to [N, C, H, W] shape tensor. + + Args: + x (Tensor): The input tensor of shape [N, L, C] before convertion. + hw_shape (Sequence[int]): The height and width of output feature map. + + Returns: + Tensor: The output tensor of shape [N, C, H, W] after convertion. + """ + H, W = hw_shape + assert len(x.shape) == 3 + B, L, C = x.shape + assert L == H * W, 'The seq_len doesn\'t match H, W' + return x.transpose(1, 2).reshape(B, C, H, W) + + +def nchw_to_nlc(x): + """Flatten [N, C, H, W] shape tensor to [N, L, C] shape tensor. + + Args: + x (Tensor): The input tensor of shape [N, C, H, W] before convertion. 
+ + Returns: + Tensor: The output tensor of shape [N, L, C] after convertion. + """ + assert len(x.shape) == 4 + return x.flatten(2).transpose(1, 2).contiguous() diff --git a/isegm/model/modeling/transformer_helper/utils.py b/isegm/model/modeling/transformer_helper/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c57e4b18a8a9445271b9165ca8d0bd26b4659a7b --- /dev/null +++ b/isegm/model/modeling/transformer_helper/utils.py @@ -0,0 +1,122 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import functools + +import mmcv +import numpy as np +import torch.nn.functional as F + + +def get_class_weight(class_weight): + """Get class weight for loss function. + + Args: + class_weight (list[float] | str | None): If class_weight is a str, + take it as a file name and read from it. + """ + if isinstance(class_weight, str): + # take it as a file path + if class_weight.endswith('.npy'): + class_weight = np.load(class_weight) + else: + # pkl, json or yaml + class_weight = mmcv.load(class_weight) + + return class_weight + + +def reduce_loss(loss, reduction): + """Reduce loss as specified. + + Args: + loss (Tensor): Elementwise loss tensor. + reduction (str): Options are "none", "mean" and "sum". + + Return: + Tensor: Reduced loss tensor. + """ + reduction_enum = F._Reduction.get_enum(reduction) + # none: 0, elementwise_mean:1, sum: 2 + if reduction_enum == 0: + return loss + elif reduction_enum == 1: + return loss.mean() + elif reduction_enum == 2: + return loss.sum() + + +def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None): + """Apply element-wise weight and reduce loss. + + Args: + loss (Tensor): Element-wise loss. + weight (Tensor): Element-wise weights. + reduction (str): Same as built-in losses of PyTorch. + avg_factor (float): Avarage factor when computing the mean of losses. + + Returns: + Tensor: Processed loss values. + """ + # if weight is specified, apply element-wise weight + if weight is not None: + assert weight.dim() == loss.dim() + if weight.dim() > 1: + assert weight.size(1) == 1 or weight.size(1) == loss.size(1) + loss = loss * weight + + # if avg_factor is not specified, just reduce the loss + if avg_factor is None: + loss = reduce_loss(loss, reduction) + else: + # if reduction is mean, then average the loss by avg_factor + if reduction == 'mean': + loss = loss.sum() / avg_factor + # if reduction is 'none', then do nothing, otherwise raise an error + elif reduction != 'none': + raise ValueError('avg_factor can not be used with reduction="sum"') + return loss + + +def weighted_loss(loss_func): + """Create a weighted version of a given loss function. + + To use this decorator, the loss function must have the signature like + `loss_func(pred, target, **kwargs)`. The function only needs to compute + element-wise loss without any reduction. This decorator will add weight + and reduction arguments to the function. The decorated function will have + the signature like `loss_func(pred, target, weight=None, reduction='mean', + avg_factor=None, **kwargs)`. + + :Example: + + >>> import torch + >>> @weighted_loss + >>> def l1_loss(pred, target): + >>> return (pred - target).abs() + + >>> pred = torch.Tensor([0, 2, 3]) + >>> target = torch.Tensor([1, 1, 1]) + >>> weight = torch.Tensor([1, 0, 1]) + + >>> l1_loss(pred, target) + tensor(1.3333) + >>> l1_loss(pred, target, weight) + tensor(1.) 
+    >>> l1_loss(pred, target, reduction='none')
+    tensor([1., 1., 2.])
+    >>> l1_loss(pred, target, weight, avg_factor=2)
+    tensor(1.5000)
+    """
+
+    @functools.wraps(loss_func)
+    def wrapper(pred,
+                target,
+                weight=None,
+                reduction='mean',
+                avg_factor=None,
+                **kwargs):
+        # get element-wise loss
+        loss = loss_func(pred, target, **kwargs)
+        loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
+        return loss
+
+    return wrapper
diff --git a/isegm/model/modeling/transformer_helper/wrappers.py b/isegm/model/modeling/transformer_helper/wrappers.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce67e4bebe1ed463072858f97dd950e596ca6a28
--- /dev/null
+++ b/isegm/model/modeling/transformer_helper/wrappers.py
@@ -0,0 +1,51 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def resize(input,
+           size=None,
+           scale_factor=None,
+           mode='nearest',
+           align_corners=None,
+           warning=True):
+    if warning:
+        if size is not None and align_corners:
+            input_h, input_w = tuple(int(x) for x in input.shape[2:])
+            output_h, output_w = tuple(int(x) for x in size)
+            if output_h > input_h or output_w > input_w:
+                if ((output_h > 1 and output_w > 1 and input_h > 1
+                     and input_w > 1) and (output_h - 1) % (input_h - 1)
+                        and (output_w - 1) % (input_w - 1)):
+                    warnings.warn(
+                        f'When align_corners={align_corners}, '
+                        'the output would be more aligned if '
+                        f'input size {(input_h, input_w)} is `x+1` and '
+                        f'out size {(output_h, output_w)} is `nx+1`')
+    return F.interpolate(input, size, scale_factor, mode, align_corners)
+
+
+class Upsample(nn.Module):
+
+    def __init__(self,
+                 size=None,
+                 scale_factor=None,
+                 mode='nearest',
+                 align_corners=None):
+        super(Upsample, self).__init__()
+        self.size = size
+        if isinstance(scale_factor, tuple):
+            self.scale_factor = tuple(float(factor) for factor in scale_factor)
+        else:
+            self.scale_factor = float(scale_factor) if scale_factor else None
+        self.mode = mode
+        self.align_corners = align_corners
+
+    def forward(self, x):
+        if not self.size:
+            size = [int(t * self.scale_factor) for t in x.shape[-2:]]
+        else:
+            size = self.size
+        return resize(x, size, None, self.mode, self.align_corners)
diff --git a/isegm/model/modeling/twoway_transformer.py b/isegm/model/modeling/twoway_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc28c5aa7b365f07a8274ea4bf68f44b2e4d7edb
--- /dev/null
+++ b/isegm/model/modeling/twoway_transformer.py
@@ -0,0 +1,333 @@
+import math
+from typing import Tuple, Type
+import numpy as np
+import torch
+from torch import nn, Tensor
+
+
+# Lightly adapted from
+# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
+class MLPBlock(nn.Module):
+    def __init__(
+        self,
+        input_dim: int,
+        hidden_dim: int,
+        output_dim: int,
+        num_layers: int,
+        act: Type[nn.Module],
+    ) -> None:
+        super().__init__()
+        self.num_layers = num_layers
+        h = [hidden_dim] * (num_layers - 1)
+        self.layers = nn.ModuleList(
+            nn.Sequential(nn.Linear(n, k), act())
+            for n, k in zip([input_dim] + h, [hidden_dim] * num_layers)
+        )
+        self.fc = nn.Linear(hidden_dim, output_dim)
+
+    def forward(self, x):
+        for layer in self.layers:
+            x = layer(x)
+        return self.fc(x)
+
+# From https://github.com/yformer/EfficientSAM/blob/main/efficient_sam/efficient_sam_decoder.py
+class PositionEmbeddingRandom(nn.Module):
+    """
+    Positional encoding using random spatial frequencies.
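# The encoding below is a random Fourier-feature PE: with a fixed Gaussian
# matrix G of shape (2, num_pos_feats) and coords x normalized to [0, 1]^2,
#   pe(x) = [sin(2*pi * (2x - 1) @ G), cos(2*pi * (2x - 1) @ G)]
# so the output has 2 * num_pos_feats channels. Illustrative usage:
import torch
pe_layer = PositionEmbeddingRandom(num_pos_feats=128)
grid_pe = pe_layer((64, 64))      # shape (256, 64, 64), i.e. C x H x W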
+ """ + + def __init__(self, num_pos_feats: int) -> None: + super().__init__() + self.register_buffer( + "positional_encoding_gaussian_matrix", torch.randn((2, num_pos_feats)) + ) + + def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: + """Positionally encode points that are normalized to [0,1].""" + # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + coords = 2 * coords - 1 + coords = coords @ self.positional_encoding_gaussian_matrix + coords = 2 * np.pi * coords + # outputs d_1 x ... x d_n x C shape + return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) + + def forward(self, size: Tuple[int, int]) -> torch.Tensor: + """Generate positional encoding for a grid of the specified size.""" + h, w = size + device = self.positional_encoding_gaussian_matrix.device + grid = torch.ones([h, w], device=device, dtype=torch.float32) + y_embed = grid.cumsum(dim=0) - 0.5 + x_embed = grid.cumsum(dim=1) - 0.5 + y_embed = y_embed / h + x_embed = x_embed / w + + pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) + return pe.permute(2, 0, 1) # C x H x W + + def forward_with_coords( + self, coords_input: torch.Tensor, image_size: Tuple[int, int] + ) -> torch.Tensor: + """Positionally encode points that are not normalized to [0,1].""" + coords = coords_input.clone() + coords[:, :, 0] = coords[:, :, 0] / image_size[1] + coords[:, :, 1] = coords[:, :, 1] / image_size[0] + return self._pe_encoding(coords.to(torch.float)) # B x N x C + + +# From https://github.com/yformer/EfficientSAM/blob/main/efficient_sam/build_efficient_sam.py +class TwoWayTransformer(nn.Module): + def __init__( + self, + depth: int, + embedding_dim: int, + num_heads: int, + mlp_dim: int, + activation: Type[nn.Module], + normalize_before_activation: bool, + attention_downsample_rate: int = 2, + ) -> None: + """ + A transformer decoder that attends to an input image using + queries whose positional embedding is supplied. + + Args: + depth (int): number of layers in the transformer + embedding_dim (int): the channel dimension for the input embeddings + num_heads (int): the number of heads for multihead attention. Must + divide embedding_dim + mlp_dim (int): the channel dimension internal to the MLP block + activation (nn.Module): the activation to use in the MLP block + """ + super().__init__() + self.depth = depth + self.embedding_dim = embedding_dim + self.num_heads = num_heads + self.mlp_dim = mlp_dim + self.layers = nn.ModuleList() + + for i in range(depth): + curr_layer = TwoWayAttentionBlock( + embedding_dim=embedding_dim, + num_heads=num_heads, + mlp_dim=mlp_dim, + activation=activation, + normalize_before_activation=normalize_before_activation, + attention_downsample_rate=attention_downsample_rate, + skip_first_layer_pe=(i == 0), + ) + self.layers.append(curr_layer) + + self.final_attn_token_to_image = AttentionForTwoWayAttentionBlock( + embedding_dim, + num_heads, + downsample_rate=attention_downsample_rate, + ) + self.norm_final_attn = nn.LayerNorm(embedding_dim) + + def forward( + self, + image_embedding: Tensor, + image_pe: Tensor, + point_embedding: Tensor, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + image_embedding (torch.Tensor): image to attend to. Should be shape + B x embedding_dim x h x w for any h and w. + image_pe (torch.Tensor): the positional encoding to add to the image. Must + have the same shape as image_embedding. + point_embedding (torch.Tensor): the embedding to add to the query points. + Must have shape B x N_points x embedding_dim for any N_points. 
+ + Returns: + torch.Tensor: the processed point_embedding + torch.Tensor: the processed image_embedding + """ + + # BxCxHxW -> BxHWxC == B x N_image_tokens x C + # bs, c, h, w = image_embedding.shape + if len(image_embedding.shape) == 4: + image_embedding = image_embedding.flatten(2).permute(0, 2, 1) + image_pe = image_pe.flatten(2).permute(0, 2, 1) + + # Prepare queries + queries = point_embedding + keys = image_embedding + + # Apply transformer blocks and final layernorm + for idx, layer in enumerate(self.layers): + queries, keys = layer( + queries=queries, + keys=keys, + query_pe=point_embedding, + key_pe=image_pe, + ) + + # Apply the final attention layer from the points to the image + q = queries + point_embedding + k = keys + image_pe + attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm_final_attn(queries) + return queries, keys + + +class TwoWayAttentionBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + num_heads: int, + mlp_dim: int, + activation: Type[nn.Module], + normalize_before_activation: bool, + attention_downsample_rate: int = 2, + skip_first_layer_pe: bool = False, + ) -> None: + """ + A transformer block with four layers: (1) self-attention of sparse + inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp + block on sparse inputs, and (4) cross attention of dense inputs to sparse + inputs. + + Arguments: + embedding_dim (int): the channel dimension of the embeddings + num_heads (int): the number of heads in the attention layers + mlp_dim (int): the hidden dimension of the mlp block + activation (nn.Module): the activation of the mlp block + skip_first_layer_pe (bool): skip the PE on the first layer + """ + super().__init__() + self.self_attn = AttentionForTwoWayAttentionBlock(embedding_dim, num_heads) + self.norm1 = nn.LayerNorm(embedding_dim) + + self.cross_attn_token_to_image = AttentionForTwoWayAttentionBlock( + embedding_dim, + num_heads, + downsample_rate=attention_downsample_rate, + ) + self.norm2 = nn.LayerNorm(embedding_dim) + + self.mlp = MLPBlock( + embedding_dim, + mlp_dim, + embedding_dim, + 1, + activation, + ) + + self.norm3 = nn.LayerNorm(embedding_dim) + + self.norm4 = nn.LayerNorm(embedding_dim) + self.cross_attn_image_to_token = AttentionForTwoWayAttentionBlock( + embedding_dim, + num_heads, + downsample_rate=attention_downsample_rate, + ) + + self.skip_first_layer_pe = skip_first_layer_pe + + def forward( + self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor + ) -> Tuple[Tensor, Tensor]: + # Self attention block + if not self.skip_first_layer_pe: + queries = queries + query_pe + attn_out = self.self_attn(q=queries, k=queries, v=queries) + queries = queries + attn_out + queries = self.norm1(queries) + + # Cross attention block, tokens attending to image embedding + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm2(queries) + + # MLP block + mlp_out = self.mlp(queries) + queries = queries + mlp_out + queries = self.norm3(queries) + + # Cross attention block, image embedding attending to tokens + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) + keys = keys + attn_out + keys = self.norm4(keys) + + return queries, keys + + +class AttentionForTwoWayAttentionBlock(nn.Module): + """ + An attention layer that allows for downscaling the size of the embedding + after projection to 
queries, keys, and values. + """ + + def __init__( + self, + embedding_dim: int, + num_heads: int, + downsample_rate: int = 1, + ) -> None: + super().__init__() + self.embedding_dim = embedding_dim + self.internal_dim = embedding_dim // downsample_rate + self.num_heads = num_heads + assert ( + self.internal_dim % num_heads == 0 + ), "num_heads must divide embedding_dim." + + self.q_proj = nn.Linear(embedding_dim, self.internal_dim) + self.k_proj = nn.Linear(embedding_dim, self.internal_dim) + self.v_proj = nn.Linear(embedding_dim, self.internal_dim) + self.out_proj = nn.Linear(self.internal_dim, embedding_dim) + self._reset_parameters() + + def _reset_parameters(self) -> None: + # The fan_out is incorrect, but matches pytorch's initialization + # for which qkv is a single 3*embedding_dim x embedding_dim matrix + fan_in = self.embedding_dim + fan_out = 3 * self.internal_dim + # Xavier uniform with our custom fan_out + bnd = math.sqrt(6 / (fan_in + fan_out)) + nn.init.uniform_(self.q_proj.weight, -bnd, bnd) + nn.init.uniform_(self.k_proj.weight, -bnd, bnd) + nn.init.uniform_(self.v_proj.weight, -bnd, bnd) + # out_proj.weight is left with default initialization, like pytorch attention + nn.init.zeros_(self.q_proj.bias) + nn.init.zeros_(self.k_proj.bias) + nn.init.zeros_(self.v_proj.bias) + nn.init.zeros_(self.out_proj.bias) + + def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: + b, n, c = x.shape + x = x.reshape(b, n, num_heads, c // num_heads) + return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head + + def _recombine_heads(self, x: Tensor) -> Tensor: + b, n_heads, n_tokens, c_per_head = x.shape + x = x.transpose(1, 2) + return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: + # Input projections + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + # Attention + _, _, _, c_per_head = q.shape + attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens + attn = attn / math.sqrt(c_per_head) + attn = torch.softmax(attn, dim=-1) + # Get output + out = attn @ v + out = self._recombine_heads(out) + out = self.out_proj(out) + return out diff --git a/isegm/model/modifiers.py b/isegm/model/modifiers.py new file mode 100644 index 0000000000000000000000000000000000000000..046221838069e90ae201b9169db159cc69c13244 --- /dev/null +++ b/isegm/model/modifiers.py @@ -0,0 +1,11 @@ + + +class LRMult(object): + def __init__(self, lr_mult=1.): + self.lr_mult = lr_mult + + def __call__(self, m): + if getattr(m, 'weight', None) is not None: + m.weight.lr_mult = self.lr_mult + if getattr(m, 'bias', None) is not None: + m.bias.lr_mult = self.lr_mult diff --git a/isegm/model/ops.py b/isegm/model/ops.py new file mode 100644 index 0000000000000000000000000000000000000000..9be9c73cbef7b83645af93e1fa7338fa6513a92b --- /dev/null +++ b/isegm/model/ops.py @@ -0,0 +1,116 @@ +import torch +from torch import nn as nn +import numpy as np +import isegm.model.initializer as initializer + + +def select_activation_function(activation): + if isinstance(activation, str): + if activation.lower() == 'relu': + return nn.ReLU + elif activation.lower() == 'softplus': + return nn.Softplus + else: + raise ValueError(f"Unknown activation type {activation}") + elif isinstance(activation, nn.Module): + return activation + else: + raise ValueError(f"Unknown 
activation type {activation}") + + +class BilinearConvTranspose2d(nn.ConvTranspose2d): + def __init__(self, in_channels, out_channels, scale, groups=1): + kernel_size = 2 * scale - scale % 2 + self.scale = scale + + super().__init__( + in_channels, out_channels, + kernel_size=kernel_size, + stride=scale, + padding=1, + groups=groups, + bias=False) + + self.apply(initializer.Bilinear(scale=scale, in_channels=in_channels, groups=groups)) + + +class DistMaps(nn.Module): + def __init__(self, norm_radius, spatial_scale=1.0, cpu_mode=False, use_disks=False): + super(DistMaps, self).__init__() + self.spatial_scale = spatial_scale + self.norm_radius = norm_radius + self.cpu_mode = cpu_mode + self.use_disks = use_disks + if self.cpu_mode: + from isegm.utils.cython import get_dist_maps + self._get_dist_maps = get_dist_maps + + def get_coord_features(self, points, batchsize, rows, cols): + if self.cpu_mode: + coords = [] + for i in range(batchsize): + norm_delimeter = 1.0 if self.use_disks else self.spatial_scale * self.norm_radius + coords.append(self._get_dist_maps(points[i].cpu().float().numpy(), rows, cols, + norm_delimeter)) + coords = torch.from_numpy(np.stack(coords, axis=0)).to(points.device).float() + else: + num_points = points.shape[1] // 2 + points = points.view(-1, points.size(2)) + points, points_order = torch.split(points, [2, 1], dim=1) + + invalid_points = torch.max(points, dim=1, keepdim=False)[0] < 0 + row_array = torch.arange(start=0, end=rows, step=1, dtype=torch.float32, device=points.device) + col_array = torch.arange(start=0, end=cols, step=1, dtype=torch.float32, device=points.device) + + coord_rows, coord_cols = torch.meshgrid(row_array, col_array) + coords = torch.stack((coord_rows, coord_cols), dim=0).unsqueeze(0).repeat(points.size(0), 1, 1, 1) + + add_xy = (points * self.spatial_scale).view(points.size(0), points.size(1), 1, 1) + coords.add_(-add_xy) + if not self.use_disks: + coords.div_(self.norm_radius * self.spatial_scale) + coords.mul_(coords) + + coords[:, 0] += coords[:, 1] + coords = coords[:, :1] + + coords[invalid_points, :, :, :] = 1e6 + + coords = coords.view(-1, num_points, 1, rows, cols) + coords = coords.min(dim=1)[0] # -> (bs * num_masks * 2) x 1 x h x w + coords = coords.view(-1, 2, rows, cols) + + if self.use_disks: + coords = (coords <= (self.norm_radius * self.spatial_scale) ** 2).float() + else: + coords.sqrt_().mul_(2).tanh_() + + return coords + + def forward(self, x, coords): + return self.get_coord_features(coords, x.shape[0], x.shape[2], x.shape[3]) + + +class ScaleLayer(nn.Module): + def __init__(self, init_value=1.0, lr_mult=1): + super().__init__() + self.lr_mult = lr_mult + self.scale = nn.Parameter( + torch.full((1,), init_value / lr_mult, dtype=torch.float32) + ) + + def forward(self, x): + scale = torch.abs(self.scale * self.lr_mult) + return x * scale + + +class BatchImageNormalize: + def __init__(self, mean, std, dtype=torch.float): + self.mean = torch.as_tensor(mean, dtype=dtype)[None, :, None, None] + self.std = torch.as_tensor(std, dtype=dtype)[None, :, None, None] + + def __call__(self, tensor): + tensor = tensor.clone() + + tensor.sub_(self.mean.to(tensor.device)).div_(self.std.to(tensor.device)) + return tensor diff --git a/isegm/model/sam_modeling/__init__.py b/isegm/model/sam_modeling/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d31ff27b438e4e387dd0663f07031cd46a8f094d --- /dev/null +++ b/isegm/model/sam_modeling/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) Meta Platforms, Inc. 
and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .sam import Sam +from .is_model_sam import SAMISWrapper +from .image_encoder import ImageEncoderViT +from .image_encoder_lora import ImageEncoderViT_lora +from .mask_decoder import MaskDecoder +from .prompt_encoder import PromptEncoder +from .transformer import TwoWayTransformer diff --git a/isegm/model/sam_modeling/common.py b/isegm/model/sam_modeling/common.py new file mode 100644 index 0000000000000000000000000000000000000000..2bf15236a3eb24d8526073bc4fa2b274cccb3f96 --- /dev/null +++ b/isegm/model/sam_modeling/common.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn + +from typing import Type + + +class MLPBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + mlp_dim: int, + act: Type[nn.Module] = nn.GELU, + ) -> None: + super().__init__() + self.lin1 = nn.Linear(embedding_dim, mlp_dim) + self.lin2 = nn.Linear(mlp_dim, embedding_dim) + self.act = act() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.lin2(self.act(self.lin1(x))) + + +# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa +# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x diff --git a/isegm/model/sam_modeling/image_encoder.py b/isegm/model/sam_modeling/image_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..940a79cf9da30dfee81ee0bb318fba1dfa45fb61 --- /dev/null +++ b/isegm/model/sam_modeling/image_encoder.py @@ -0,0 +1,395 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from typing import Optional, Tuple, Type + +from .common import LayerNorm2d, MLPBlock + + +# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa +class ImageEncoderViT(nn.Module): + def __init__( + self, + img_size: int = 1024, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4.0, + out_chans: int = 256, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_abs_pos: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + global_attn_indexes: Tuple[int, ...] = (), + ) -> None: + """ + Args: + img_size (int): Input image size. 
+ patch_size (int): Patch size. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + depth (int): Depth of ViT. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_abs_pos (bool): If True, use absolute positional embeddings. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. + global_attn_indexes (list): Indexes for blocks using global attention. + """ + super().__init__() + self.img_size = img_size + + self.patch_embed = PatchEmbed( + kernel_size=(patch_size, patch_size), + stride=(patch_size, patch_size), + in_chans=in_chans, + embed_dim=embed_dim, + ) + + self.pos_embed: Optional[nn.Parameter] = None + if use_abs_pos: + # Initialize absolute positional embedding with pretrain image size. + self.pos_embed = nn.Parameter( + torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim) + ) + + self.blocks = nn.ModuleList() + for i in range(depth): + block = Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + norm_layer=norm_layer, + act_layer=act_layer, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + window_size=window_size if i not in global_attn_indexes else 0, + input_size=(img_size // patch_size, img_size // patch_size), + ) + self.blocks.append(block) + + self.neck = nn.Sequential( + nn.Conv2d( + embed_dim, + out_chans, + kernel_size=1, + bias=False, + ), + LayerNorm2d(out_chans), + nn.Conv2d( + out_chans, + out_chans, + kernel_size=3, + padding=1, + bias=False, + ), + LayerNorm2d(out_chans), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.patch_embed(x) + + if self.pos_embed is not None: + x = x + self.pos_embed + + for blk in self.blocks: + x = blk(x) + + x = self.neck(x.permute(0, 3, 1, 2)) + + return x + + +class Block(nn.Module): + """Transformer blocks with support of window attention and residual propagation blocks""" + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. If it equals 0, then + use global attention. + input_size (tuple(int, int) or None): Input resolution for calculating the relative + positional parameter size. 
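# Illustrative shape check for one Block: inputs are channels-last feature maps
# (B, H, W, C), and window attention preserves the spatial shape.
import torch
blk = Block(dim=768, num_heads=12, window_size=14, input_size=(64, 64))
x = torch.randn(1, 64, 64, 768)   # 1024 / 16 = 64 tokens per side
out = blk(x)                      # internally padded to 70x70 windows, then cropped back
assert out.shape == x.shape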
+ """ + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + input_size=input_size if window_size == 0 else (window_size, window_size), + ) + + self.norm2 = norm_layer(dim) + self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) + + self.window_size = window_size + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shortcut = x + x = self.norm1(x) + # Window partition + if self.window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, self.window_size) + + x = self.attn(x) + # Reverse window partition + if self.window_size > 0: + x = window_unpartition(x, self.window_size, pad_hw, (H, W)) + + x = shortcut + x + x = x + self.mlp(self.norm2(x)) + + return x + + +class Attention(nn.Module): + """Multi-head Attention block with relative position embeddings.""" + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + input_size (tuple(int, int) or None): Input resolution for calculating the relative + positional parameter size. + """ + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + self.use_rel_pos = use_rel_pos + if self.use_rel_pos: + assert ( + input_size is not None + ), "Input size must be provided if using relative positional encoding." + # initialize relative positional embeddings + self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, H, W, _ = x.shape + # qkv with shape (3, B, nHead, H * W, C) + qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + # q, k, v with shape (B * nHead, H * W, C) + q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) + + attn = (q * self.scale) @ k.transpose(-2, -1) + + if self.use_rel_pos: + attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) + + attn = attn.softmax(dim=-1) + x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) + x = self.proj(x) + + return x + + +def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: + """ + Partition into non-overlapping windows with padding if needed. + Args: + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. + + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, C]. 
+ (Hp, Wp): padded height and width before partition + """ + B, H, W, C = x.shape + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows, (Hp, Wp) + + +def window_unpartition( + windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] +) -> torch.Tensor: + """ + Window unpartition into original sequences and removing padding. + Args: + windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. + window_size (int): window size. + pad_hw (Tuple): padded height and width (Hp, Wp). + hw (Tuple): original height and width (H, W) before padding. + + Returns: + x: unpartitioned sequences with [B, H, W, C]. + """ + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + + if Hp > H or Wp > W: + x = x[:, :H, :W, :].contiguous() + return x + + +def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + Args: + q_size (int): size of query q. + k_size (int): size of key k. + rel_pos (Tensor): relative position embeddings (L, C). + + Returns: + Extracted positional embeddings according to relative positions. + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel pos. + rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode="linear", + ) + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) + else: + rel_pos_resized = rel_pos + + # Scale the coords with short length if shapes for q and k are different. + q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + + +def add_decomposed_rel_pos( + attn: torch.Tensor, + q: torch.Tensor, + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, + q_size: Tuple[int, int], + k_size: Tuple[int, int], +) -> torch.Tensor: + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 + Args: + attn (Tensor): attention map. + q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). + rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. + rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. + q_size (Tuple): spatial sequence size of query q with (q_h, q_w). + k_size (Tuple): spatial sequence size of key k with (k_h, k_w). + + Returns: + attn (Tensor): attention map with added relative positional embeddings. 
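# The decomposition used above replaces a full (q_h*q_w) x (k_h*k_w) bias table
# with two per-axis terms; roughly,
#   attn[b, (qh, qw), (kh, kw)] += <r_q[b, qh, qw], Rh[qh, kh]> + <r_q[b, qh, qw], Rw[qw, kw]>
# which is what the two einsum calls compute before broadcasting. Illustrative call:
import torch
q = torch.randn(1, 8 * 8, 64)            # (B, q_h * q_w, head_dim)
rel_pos = torch.randn(2 * 8 - 1, 64)     # (2 * size - 1, head_dim)
attn = add_decomposed_rel_pos(torch.zeros(1, 64, 64), q, rel_pos, rel_pos, (8, 8), (8, 8))
# attn.shape == (1, 64, 64)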
+ """ + q_h, q_w = q_size + k_h, k_w = k_size + Rh = get_rel_pos(q_h, k_h, rel_pos_h) + Rw = get_rel_pos(q_w, k_w, rel_pos_w) + + B, _, dim = q.shape + r_q = q.reshape(B, q_h, q_w, dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) + rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) + + attn = ( + attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] + ).view(B, q_h * q_w, k_h * k_w) + + return attn + + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding. + """ + + def __init__( + self, + kernel_size: Tuple[int, int] = (16, 16), + stride: Tuple[int, int] = (16, 16), + padding: Tuple[int, int] = (0, 0), + in_chans: int = 3, + embed_dim: int = 768, + ) -> None: + """ + Args: + kernel_size (Tuple): kernel size of the projection layer. + stride (Tuple): stride of the projection layer. + padding (Tuple): padding size of the projection layer. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + """ + super().__init__() + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + # B C H W -> B H W C + x = x.permute(0, 2, 3, 1) + return x diff --git a/isegm/model/sam_modeling/image_encoder_lora.py b/isegm/model/sam_modeling/image_encoder_lora.py new file mode 100644 index 0000000000000000000000000000000000000000..69e9a2d3f56fe343e430ca20abba6eee984f7f80 --- /dev/null +++ b/isegm/model/sam_modeling/image_encoder_lora.py @@ -0,0 +1,397 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from typing import Optional, Tuple, Type + +from .common import LayerNorm2d, MLPBlock +import loralib as lora + +# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa +class ImageEncoderViT_lora(nn.Module): + def __init__( + self, + img_size: int = 1024, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4.0, + out_chans: int = 256, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_abs_pos: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + global_attn_indexes: Tuple[int, ...] = (), + ) -> None: + """ + Args: + img_size (int): Input image size. + patch_size (int): Patch size. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + depth (int): Depth of ViT. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_abs_pos (bool): If True, use absolute positional embeddings. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. 
+ global_attn_indexes (list): Indexes for blocks using global attention. + """ + super().__init__() + self.img_size = img_size + + self.patch_embed = PatchEmbed( + kernel_size=(patch_size, patch_size), + stride=(patch_size, patch_size), + in_chans=in_chans, + embed_dim=embed_dim, + ) + + self.pos_embed: Optional[nn.Parameter] = None + if use_abs_pos: + # Initialize absolute positional embedding with pretrain image size. + self.pos_embed = nn.Parameter( + torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim) + ) + + self.blocks = nn.ModuleList() + for i in range(depth): + block = Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + norm_layer=norm_layer, + act_layer=act_layer, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + window_size=window_size if i not in global_attn_indexes else 0, + input_size=(img_size // patch_size, img_size // patch_size), + ) + self.blocks.append(block) + + self.neck = nn.Sequential( + nn.Conv2d( + embed_dim, + out_chans, + kernel_size=1, + bias=False, + ), + LayerNorm2d(out_chans), + nn.Conv2d( + out_chans, + out_chans, + kernel_size=3, + padding=1, + bias=False, + ), + LayerNorm2d(out_chans), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.patch_embed(x) + + if self.pos_embed is not None: + x = x + self.pos_embed + + for blk in self.blocks: + x = blk(x) + + x = self.neck(x.permute(0, 3, 1, 2)) + + return x + + +class Block(nn.Module): + """Transformer blocks with support of window attention and residual propagation blocks""" + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. If it equals 0, then + use global attention. + input_size (tuple(int, int) or None): Input resolution for calculating the relative + positional parameter size. 
+ """ + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + input_size=input_size if window_size == 0 else (window_size, window_size), + ) + + self.norm2 = norm_layer(dim) + self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) + + self.window_size = window_size + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shortcut = x + x = self.norm1(x) + # Window partition + if self.window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, self.window_size) + + x = self.attn(x) + # Reverse window partition + if self.window_size > 0: + x = window_unpartition(x, self.window_size, pad_hw, (H, W)) + + x = shortcut + x + x = x + self.mlp(self.norm2(x)) + + return x + + +class Attention(nn.Module): + """Multi-head Attention block with relative position embeddings.""" + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + input_size (tuple(int, int) or None): Input resolution for calculating the relative + positional parameter size. + """ + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + # self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.qkv = lora.MergedLinear(dim, dim * 3, r=8, lora_dropout=1e-9, enable_lora=[True, False, True], bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + self.use_rel_pos = use_rel_pos + if self.use_rel_pos: + assert ( + input_size is not None + ), "Input size must be provided if using relative positional encoding." + # initialize relative positional embeddings + self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, H, W, _ = x.shape + # qkv with shape (3, B, nHead, H * W, C) + qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + # q, k, v with shape (B * nHead, H * W, C) + q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) + + attn = (q * self.scale) @ k.transpose(-2, -1) + + if self.use_rel_pos: + attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) + + attn = attn.softmax(dim=-1) + x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) + x = self.proj(x) + + return x + + +def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: + """ + Partition into non-overlapping windows with padding if needed. + Args: + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. + + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, C]. 
+ (Hp, Wp): padded height and width before partition + """ + B, H, W, C = x.shape + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows, (Hp, Wp) + + +def window_unpartition( + windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] +) -> torch.Tensor: + """ + Window unpartition into original sequences and removing padding. + Args: + windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. + window_size (int): window size. + pad_hw (Tuple): padded height and width (Hp, Wp). + hw (Tuple): original height and width (H, W) before padding. + + Returns: + x: unpartitioned sequences with [B, H, W, C]. + """ + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + + if Hp > H or Wp > W: + x = x[:, :H, :W, :].contiguous() + return x + + +def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + Args: + q_size (int): size of query q. + k_size (int): size of key k. + rel_pos (Tensor): relative position embeddings (L, C). + + Returns: + Extracted positional embeddings according to relative positions. + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel pos. + rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode="linear", + ) + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) + else: + rel_pos_resized = rel_pos + + # Scale the coords with short length if shapes for q and k are different. + q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + + +def add_decomposed_rel_pos( + attn: torch.Tensor, + q: torch.Tensor, + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, + q_size: Tuple[int, int], + k_size: Tuple[int, int], +) -> torch.Tensor: + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 + Args: + attn (Tensor): attention map. + q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). + rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. + rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. + q_size (Tuple): spatial sequence size of query q with (q_h, q_w). + k_size (Tuple): spatial sequence size of key k with (k_h, k_w). + + Returns: + attn (Tensor): attention map with added relative positional embeddings. 
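# With the lora.MergedLinear qkv projections defined earlier in this file, a
# common fine-tuning setup freezes the backbone and trains only the LoRA
# factors (illustrative sketch; the constructor arguments are placeholders):
import loralib as lora
encoder = ImageEncoderViT_lora(depth=12, embed_dim=768, num_heads=12)
lora.mark_only_lora_as_trainable(encoder)   # only lora_A / lora_B keep requires_grad
num_trainable = sum(p.numel() for p in encoder.parameters() if p.requires_grad)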
+ """ + q_h, q_w = q_size + k_h, k_w = k_size + Rh = get_rel_pos(q_h, k_h, rel_pos_h) + Rw = get_rel_pos(q_w, k_w, rel_pos_w) + + B, _, dim = q.shape + r_q = q.reshape(B, q_h, q_w, dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) + rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) + + attn = ( + attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] + ).view(B, q_h * q_w, k_h * k_w) + + return attn + + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding. + """ + + def __init__( + self, + kernel_size: Tuple[int, int] = (16, 16), + stride: Tuple[int, int] = (16, 16), + padding: Tuple[int, int] = (0, 0), + in_chans: int = 3, + embed_dim: int = 768, + ) -> None: + """ + Args: + kernel_size (Tuple): kernel size of the projection layer. + stride (Tuple): stride of the projection layer. + padding (Tuple): padding size of the projection layer. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + """ + super().__init__() + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + # B C H W -> B H W C + x = x.permute(0, 2, 3, 1) + return x diff --git a/isegm/model/sam_modeling/is_model_sam.py b/isegm/model/sam_modeling/is_model_sam.py new file mode 100644 index 0000000000000000000000000000000000000000..1f1e5f63bea7d5ef3ba55ac6999061ce2fe33e4a --- /dev/null +++ b/isegm/model/sam_modeling/is_model_sam.py @@ -0,0 +1,274 @@ +import torch +import torch.nn as nn +import numpy as np + +from typing import Any, Dict, List, Tuple +from torch.nn import functional as F +from functools import partial + +from isegm.model.ops import DistMaps, BatchImageNormalize, ScaleLayer +from isegm.utils.serialization import serialize +from .image_encoder import ImageEncoderViT +from .image_encoder_lora import ImageEncoderViT_lora +from .mask_decoder import MaskDecoder +from .prompt_encoder import PromptEncoder +from .transformer import TwoWayTransformer + +class SAMISWrapper(nn.Module): + mask_threshold: float = 0.0 + image_format: str = "RGB" + + @serialize + def __init__( + self, + encoder_embed_dim=1280, + encoder_depth=32, + encoder_num_heads=16, + encoder_global_attn_indexes=[7,15,23,31], + enable_lora=True, + enable_gra=True, + pixel_mean: List[float] = [123.675, 116.28, 103.53], + pixel_std: List[float] = [58.395, 57.12, 57.375], + with_aux_output=False, + with_prev_mask=False, + norm_mean_std=([.485, .456, .406], [.229, .224, .225]), + image_size=1024, + ): + super().__init__() + self.with_aux_output = with_aux_output + self.with_prev_mask = with_prev_mask + self.normalization = BatchImageNormalize(norm_mean_std[0], norm_mean_std[1]) + + prompt_embed_dim = 256 + image_size = image_size + vit_patch_size = 16 + image_embedding_size = image_size // vit_patch_size + + if enable_lora: + self.image_encoder = ImageEncoderViT_lora( + depth=encoder_depth, + embed_dim=encoder_embed_dim, + img_size=image_size, + mlp_ratio=4, + norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), + num_heads=encoder_num_heads, + patch_size=vit_patch_size, + qkv_bias=True, + use_rel_pos=True, + global_attn_indexes=encoder_global_attn_indexes, + window_size=14, + out_chans=prompt_embed_dim, + ) + else: + self.image_encoder = ImageEncoderViT( + depth=encoder_depth, + embed_dim=encoder_embed_dim, + img_size=image_size, + mlp_ratio=4, + norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), + num_heads=encoder_num_heads, + 
patch_size=vit_patch_size, + qkv_bias=True, + use_rel_pos=True, + global_attn_indexes=encoder_global_attn_indexes, + window_size=14, + out_chans=prompt_embed_dim, + ) + self.prompt_encoder = PromptEncoder( + embed_dim=prompt_embed_dim, + image_embedding_size=(image_embedding_size, image_embedding_size), + input_image_size=(image_size, image_size), + mask_in_chans=16, + ) + self.mask_decoder = MaskDecoder( + num_multimask_outputs=3, + transformer=TwoWayTransformer( + depth=2, + embedding_dim=prompt_embed_dim, + mlp_dim=2048, + num_heads=8, + ), + transformer_dim=prompt_embed_dim, + enable_gra=enable_gra, + iou_head_depth=3, + iou_head_hidden_dim=256, + ) + self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) + self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) + + @property + def device(self) -> Any: + return self.pixel_mean.device + + def forward(self, image, points, gra=None, multimask_output=False, return_logits=True): + image, prev_mask = self.prepare_input(image) + point_coords, point_labels = self.get_model_input(points) + batched_input = [] + for bindx in range(image.shape[0]): + batched_input.append( + { + "image": image[bindx], + "point_coords": point_coords[bindx:bindx+1], + "point_labels": point_labels[bindx:bindx+1], + "mask_inputs": prev_mask[bindx:bindx+1], + "gra": gra[bindx] if gra is not None else None, + } + ) + + input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0) + image_embeddings = self.image_encoder(input_images) + + output_masks = [] + output_iou_predictions = [] + output_low_res_masks = [] + for image_record, curr_embedding in zip(batched_input, image_embeddings): + if "point_coords" in image_record: + points = (image_record["point_coords"], image_record["point_labels"]) + else: + points = None + sparse_embeddings, dense_embeddings = self.prompt_encoder( + points=points, + boxes=image_record.get("boxes", None), + masks=image_record.get("mask_inputs", None), + ) + low_res_masks, iou_predictions = self.mask_decoder( + image_embeddings=curr_embedding.unsqueeze(0), + gra=image_record["gra"], + image_pe=self.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + ) + masks = self.postprocess_masks( + low_res_masks, + input_size=image_record["image"].shape[-2:], + original_size=image_record["image"].shape[-2:], + ) + if not return_logits: + masks = masks > self.mask_threshold + + output_masks.append(masks) + output_iou_predictions.append(iou_predictions) + output_low_res_masks.append(low_res_masks) + + return { + "instances": torch.cat(output_masks, dim=0), + "iou_predictions": torch.cat(output_iou_predictions, dim=0), + "low_res_logits": torch.cat(output_low_res_masks, dim=0), + } + + def postprocess_masks( + self, + masks: torch.Tensor, + input_size: Tuple[int, ...], + original_size: Tuple[int, ...], + ) -> torch.Tensor: + """ + Remove padding and upscale masks to the original image size. + + Arguments: + masks (torch.Tensor): Batched masks from the mask_decoder, + in BxCxHxW format. + input_size (tuple(int, int)): The size of the image input to the + model, in (H, W) format. Used to remove padding. + original_size (tuple(int, int)): The original size of the image + before resizing for input to the model, in (H, W) format. + + Returns: + (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) + is given by original_size. 
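+
+        Note (editorial): masks are first upsampled to the square encoder input
+        (img_size x img_size), cropped to input_size to remove the padding, and
+        then resized to original_size.  In this wrapper's forward() both sizes are
+        taken from the same image tensor, so the final resize is effectively a no-op.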
+ """ + masks = F.interpolate( + masks, + (self.image_encoder.img_size, self.image_encoder.img_size), + mode="bilinear", + align_corners=False, + ) + masks = masks[..., : input_size[0], : input_size[1]] + masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) + return masks + + def preprocess(self, x: torch.Tensor) -> torch.Tensor: + """Normalize pixel values and pad to a square input.""" + # Normalize colors + x = (x - self.pixel_mean) / self.pixel_std + + # Pad + h, w = x.shape[-2:] + padh = self.image_encoder.img_size - h + padw = self.image_encoder.img_size - w + x = F.pad(x, (0, padw, 0, padh)) + return x + + def prepare_input(self, image): + prev_mask = None + if self.with_prev_mask: + prev_mask = image[:, 3:, :, :] + image = image[:, :3, :, :] + + image = self.normalization(image) + return image, prev_mask + + def backbone_forward(self, image, coord_features=None): + raise NotImplementedError + + def get_model_input(self, points_nd): + device = points_nd.device + points_nd = points_nd.cpu().numpy() + points_coords = [] + points_labels = [] + for bindx in range(points_nd.shape[0]): + points = points_nd[bindx] + point_length = len(points) // 2 + point_coords = [] + point_labels = [] + for i, point in enumerate(points): + if point[0] == -1: + point_labels.append(-1) + else: + if i < point_length: + point_labels.append(1) + else: + point_labels.append(0) + point_coords.append([point[1], point[0]]) + points_coords.append(point_coords) + points_labels.append(point_labels) + coords_torch = torch.as_tensor(np.array(points_coords), dtype=torch.float, device=device) + labels_torch = torch.as_tensor(np.array(points_labels), dtype=torch.int, device=device) + + return coords_torch, labels_torch + +def split_points_by_order(tpoints: torch.Tensor, groups): + points = tpoints.cpu().numpy() + num_groups = len(groups) + bs = points.shape[0] + num_points = points.shape[1] // 2 + + groups = [x if x > 0 else num_points for x in groups] + group_points = [np.full((bs, 2 * x, 3), -1, dtype=np.float32) + for x in groups] + + last_point_indx_group = np.zeros((bs, num_groups, 2), dtype=np.int_) + for group_indx, group_size in enumerate(groups): + last_point_indx_group[:, group_indx, 1] = group_size + + for bindx in range(bs): + for pindx in range(2 * num_points): + point = points[bindx, pindx, :] + group_id = int(point[2]) + if group_id < 0: + continue + + is_negative = int(pindx >= num_points) + if group_id >= num_groups or (group_id == 0 and is_negative): # disable negative first click + group_id = num_groups - 1 + + new_point_indx = last_point_indx_group[bindx, group_id, is_negative] + last_point_indx_group[bindx, group_id, is_negative] += 1 + + group_points[group_id][bindx, new_point_indx, :] = point + + group_points = [torch.tensor(x, dtype=tpoints.dtype, device=tpoints.device) + for x in group_points] + + return group_points diff --git a/isegm/model/sam_modeling/mask_decoder.py b/isegm/model/sam_modeling/mask_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..6fb87cec1b37d8745d832cbadcfe5b69e3371d3d --- /dev/null +++ b/isegm/model/sam_modeling/mask_decoder.py @@ -0,0 +1,197 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
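+
+# Editorial note (not part of the upstream SAM release): this copy of the SAM mask
+# decoder adds an optional granularity ("gra") embedding.  When enable_gra=True, a
+# 10-entry nn.Embedding is indexed via clamp(gra * 10 - 1, 0, 9) and the looked-up
+# vector is added to the sparse prompt tokens before the two-way transformer, so the
+# predicted mask can be conditioned on the requested granularity.
+#
+# Minimal usage sketch (hypothetical shapes, assuming 256-dim embeddings and a
+# 64x64 image-embedding grid):
+#   decoder = MaskDecoder(
+#       transformer_dim=256,
+#       transformer=TwoWayTransformer(depth=2, embedding_dim=256, mlp_dim=2048, num_heads=8),
+#       enable_gra=True,
+#   )
+#   masks, iou_pred = decoder(
+#       image_embeddings, image_pe, sparse_embeddings, dense_embeddings,
+#       multimask_output=False, gra=torch.tensor([0.5]),
+#   )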
+ +import torch +from torch import nn +from torch.nn import functional as F + +from typing import List, Tuple, Type + +from .common import LayerNorm2d + + +class MaskDecoder(nn.Module): + def __init__( + self, + *, + transformer_dim: int, + transformer: nn.Module, + num_multimask_outputs: int = 3, + activation: Type[nn.Module] = nn.GELU, + iou_head_depth: int = 3, + iou_head_hidden_dim: int = 256, + enable_gra=False, + ) -> None: + """ + Predicts masks given an image and prompt embeddings, using a + transformer architecture. + + Arguments: + transformer_dim (int): the channel dimension of the transformer + transformer (nn.Module): the transformer used to predict masks + num_multimask_outputs (int): the number of masks to predict + when disambiguating masks + activation (nn.Module): the type of activation to use when + upscaling masks + iou_head_depth (int): the depth of the MLP used to predict + mask quality + iou_head_hidden_dim (int): the hidden dimension of the MLP + used to predict mask quality + """ + super().__init__() + self.transformer_dim = transformer_dim + self.transformer = transformer + self.enable_gra = enable_gra + + self.num_multimask_outputs = num_multimask_outputs + + self.iou_token = nn.Embedding(1, transformer_dim) + self.num_mask_tokens = num_multimask_outputs + 1 + self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) + + self.output_upscaling = nn.Sequential( + nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), + LayerNorm2d(transformer_dim // 4), + activation(), + nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), + activation(), + ) + self.output_hypernetworks_mlps = nn.ModuleList( + [ + MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) + for i in range(self.num_mask_tokens) + ] + ) + + self.iou_prediction_head = MLP( + transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth + ) + + if self.enable_gra: + self.gra_embed = nn.Embedding(10, transformer_dim) + + def forward( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + multimask_output: bool, + gra=None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Predict masks given image and prompt embeddings. + + Arguments: + image_embeddings (torch.Tensor): the embeddings from the image encoder + image_pe (torch.Tensor): positional encoding with the shape of image_embeddings + sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes + dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs + multimask_output (bool): Whether to return multiple masks or a single + mask. 
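+          gra (torch.Tensor or None): optional granularity value(s) used to look up
+            the granularity embedding; only has an effect when the decoder was
+            constructed with enable_gra=True (editorial note, inferred from the code
+            below).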
+ + Returns: + torch.Tensor: batched predicted masks + torch.Tensor: batched predictions of mask quality + """ + + if self.enable_gra and gra is not None: + gra_idx = torch.clamp(gra * 10 - 1, 0, 9).long() + gra_embeddings = self.gra_embed(gra_idx) + else: + gra_embeddings = None + + masks, iou_pred = self.predict_masks( + image_embeddings=image_embeddings, + gra_embeddings=gra_embeddings, + image_pe=image_pe, + sparse_prompt_embeddings=sparse_prompt_embeddings, + dense_prompt_embeddings=dense_prompt_embeddings, + ) + + # Select the correct mask or masks for output + if multimask_output: + mask_slice = slice(1, None) + else: + mask_slice = slice(0, 1) + masks = masks[:, mask_slice, :, :] + iou_pred = iou_pred[:, mask_slice] + + # Prepare output + return masks, iou_pred + + def predict_masks( + self, + image_embeddings: torch.Tensor, + gra_embeddings, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Predicts masks. See 'forward' for more details.""" + # Concatenate output tokens + output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) # [1+4, 256] + output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) # [1, 5, 256] + + # attach the granularity embedding + if gra_embeddings is not None: # [1, 256] + # tokens = torch.cat((output_tokens, sparse_prompt_embeddings, gra_embeddings.unsqueeze(1)), dim=1) + tokens = torch.cat((output_tokens, sparse_prompt_embeddings + gra_embeddings.unsqueeze(1)), dim=1) # selection 2 + else: + tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) # [1, 49 + 5, 256] + + # Expand per-image data in batch direction to be per-mask + src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) # image_embeddings: [1, 256, 64, 64], dense_prompt_embeddings: [1, 256, 64, 64] + src = src + dense_prompt_embeddings + pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) # [1, 256, 64, 64] + b, c, h, w = src.shape + + # Run the transformer + hs, src = self.transformer(src, pos_src, tokens) # hs: [1, 54, 256], src: [1, 4096, 256] + iou_token_out = hs[:, 0, :] + mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] # [1, 4, 256] + + # Upscale mask embeddings and predict masks using the mask tokens + src = src.transpose(1, 2).view(b, c, h, w) # [1, 256, 64, 64] + upscaled_embedding = self.output_upscaling(src) # [1, 32, 256, 256] + hyper_in_list: List[torch.Tensor] = [] + for i in range(self.num_mask_tokens): + hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) + hyper_in = torch.stack(hyper_in_list, dim=1) # [1, 4, 32] + b, c, h, w = upscaled_embedding.shape + masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) # [1, 4, 256, 256] + + # Generate mask quality predictions + iou_pred = self.iou_prediction_head(iou_token_out) + + return masks, iou_pred + + +# Lightly adapted from +# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa +class MLP(nn.Module): + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_layers: int, + sigmoid_output: bool = False, + ) -> None: + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) + ) + self.sigmoid_output = sigmoid_output + + def forward(self, 
x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + if self.sigmoid_output: + x = F.sigmoid(x) + return x diff --git a/isegm/model/sam_modeling/prompt_encoder.py b/isegm/model/sam_modeling/prompt_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..ccd2024f0707fd7be26c0b29f6a5d27f7df376eb --- /dev/null +++ b/isegm/model/sam_modeling/prompt_encoder.py @@ -0,0 +1,216 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from torch import nn +from torchvision.transforms import Resize + +from typing import Any, Optional, Tuple, Type + +from .common import LayerNorm2d + + +class PromptEncoder(nn.Module): + def __init__( + self, + embed_dim: int, + image_embedding_size: Tuple[int, int], + input_image_size: Tuple[int, int], + mask_in_chans: int, + activation: Type[nn.Module] = nn.GELU, + ) -> None: + """ + Encodes prompts for input to SAM's mask decoder. + + Arguments: + embed_dim (int): The prompts' embedding dimension + image_embedding_size (tuple(int, int)): The spatial size of the + image embedding, as (H, W). + input_image_size (int): The padded size of the image as input + to the image encoder, as (H, W). + mask_in_chans (int): The number of hidden channels used for + encoding input masks. + activation (nn.Module): The activation to use when encoding + input masks. + """ + super().__init__() + self.embed_dim = embed_dim + self.input_image_size = input_image_size + self.image_embedding_size = image_embedding_size + self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) + + self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners + point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)] + self.point_embeddings = nn.ModuleList(point_embeddings) + self.not_a_point_embed = nn.Embedding(1, embed_dim) + + self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1]) + self.mask_downscaling = nn.Sequential( + nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans // 4), + activation(), + nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans), + activation(), + nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), + ) + self.no_mask_embed = nn.Embedding(1, embed_dim) + + def get_dense_pe(self) -> torch.Tensor: + """ + Returns the positional encoding used to encode point prompts, + applied to a dense set of points the shape of the image encoding. 
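+        For the default 1024x1024 input with 16-pixel patches this is a
+        1 x 256 x 64 x 64 tensor, which the mask decoder consumes as image_pe
+        (editorial note).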
+ + Returns: + torch.Tensor: Positional encoding with shape + 1x(embed_dim)x(embedding_h)x(embedding_w) + """ + return self.pe_layer(self.image_embedding_size).unsqueeze(0) + + def _embed_points( + self, + points: torch.Tensor, + labels: torch.Tensor, + pad: bool, + ) -> torch.Tensor: + """Embeds point prompts.""" + points = points + 0.5 # Shift to center of pixel + if pad: + padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) + padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) + points = torch.cat([points, padding_point], dim=1) + labels = torch.cat([labels, padding_label], dim=1) + point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) + point_embedding[labels == -1] = 0.0 + point_embedding[labels == -1] += self.not_a_point_embed.weight + point_embedding[labels == 0] += self.point_embeddings[0].weight + point_embedding[labels == 1] += self.point_embeddings[1].weight + return point_embedding + + def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: + """Embeds box prompts.""" + boxes = boxes + 0.5 # Shift to center of pixel + coords = boxes.reshape(-1, 2, 2) + corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) + corner_embedding[:, 0, :] += self.point_embeddings[2].weight + corner_embedding[:, 1, :] += self.point_embeddings[3].weight + return corner_embedding + + def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: + """Embeds mask inputs.""" + mask_embedding = self.mask_downscaling(masks) + return mask_embedding + + def _get_batch_size( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> int: + """ + Gets the batch size of the output given the batch size of the input prompts. + """ + if points is not None: + return points[0].shape[0] + elif boxes is not None: + return boxes.shape[0] + elif masks is not None: + return masks.shape[0] + else: + return 1 + + def _get_device(self) -> torch.device: + return self.point_embeddings[0].weight.device + + def forward( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Embeds different types of prompts, returning both sparse and dense + embeddings. + + Arguments: + points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates + and labels to embed. + boxes (torch.Tensor or none): boxes to embed + masks (torch.Tensor or none): masks to embed + + Returns: + torch.Tensor: sparse embeddings for the points and boxes, with shape + BxNx(embed_dim), where N is determined by the number of input points + and boxes. 
+ torch.Tensor: dense embeddings for the masks, in the shape + Bx(embed_dim)x(embed_H)x(embed_W) + """ + bs = self._get_batch_size(points, boxes, masks) + sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) + if points is not None: + coords, labels = points + point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) + sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) + if boxes is not None: + box_embeddings = self._embed_boxes(boxes) + sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) + + if masks is not None: + masks = Resize(self.mask_input_size)(masks) + dense_embeddings = self._embed_masks(masks) + else: + dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( + bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] + ) + + return sparse_embeddings, dense_embeddings + + +class PositionEmbeddingRandom(nn.Module): + """ + Positional encoding using random spatial frequencies. + """ + + def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: + super().__init__() + if scale is None or scale <= 0.0: + scale = 1.0 + self.register_buffer( + "positional_encoding_gaussian_matrix", + scale * torch.randn((2, num_pos_feats)), + ) + + def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: + """Positionally encode points that are normalized to [0,1].""" + # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + coords = 2 * coords - 1 + coords = coords @ self.positional_encoding_gaussian_matrix + coords = 2 * np.pi * coords + # outputs d_1 x ... x d_n x C shape + return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) + + def forward(self, size: Tuple[int, int]) -> torch.Tensor: + """Generate positional encoding for a grid of the specified size.""" + h, w = size + device: Any = self.positional_encoding_gaussian_matrix.device + grid = torch.ones((h, w), device=device, dtype=torch.float32) + y_embed = grid.cumsum(dim=0) - 0.5 + x_embed = grid.cumsum(dim=1) - 0.5 + y_embed = y_embed / h + x_embed = x_embed / w + + pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) + return pe.permute(2, 0, 1) # C x H x W + + def forward_with_coords( + self, coords_input: torch.Tensor, image_size: Tuple[int, int] + ) -> torch.Tensor: + """Positionally encode points that are not normalized to [0,1].""" + coords = coords_input.clone() + coords[:, :, 0] = coords[:, :, 0] / image_size[1] + coords[:, :, 1] = coords[:, :, 1] / image_size[0] + return self._pe_encoding(coords.to(torch.float)) # B x N x C diff --git a/isegm/model/sam_modeling/sam.py b/isegm/model/sam_modeling/sam.py new file mode 100644 index 0000000000000000000000000000000000000000..226d169500ba3e8066db53ba5f8e03d8aaba857d --- /dev/null +++ b/isegm/model/sam_modeling/sam.py @@ -0,0 +1,176 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
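+
+# Editorial note: this file appears to be the reference SAM model definition kept
+# alongside the GraCo-specific wrapper in is_model_sam.py; relative to the upstream
+# release it mainly adds a return_logits flag to forward() so raw mask logits can
+# be returned instead of thresholded masks.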
+ +import torch +from torch import nn +from torch.nn import functional as F + +from typing import Any, Dict, List, Tuple + +from .image_encoder import ImageEncoderViT +from .mask_decoder import MaskDecoder +from .prompt_encoder import PromptEncoder + + +class Sam(nn.Module): + mask_threshold: float = 0.0 + image_format: str = "RGB" + + def __init__( + self, + image_encoder: ImageEncoderViT, + prompt_encoder: PromptEncoder, + mask_decoder: MaskDecoder, + pixel_mean: List[float] = [123.675, 116.28, 103.53], + pixel_std: List[float] = [58.395, 57.12, 57.375], + ) -> None: + """ + SAM predicts object masks from an image and input prompts. + + Arguments: + image_encoder (ImageEncoderViT): The backbone used to encode the + image into image embeddings that allow for efficient mask prediction. + prompt_encoder (PromptEncoder): Encodes various types of input prompts. + mask_decoder (MaskDecoder): Predicts masks from the image embeddings + and encoded prompts. + pixel_mean (list(float)): Mean values for normalizing pixels in the input image. + pixel_std (list(float)): Std values for normalizing pixels in the input image. + """ + super().__init__() + self.image_encoder = image_encoder + self.prompt_encoder = prompt_encoder + self.mask_decoder = mask_decoder + self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) + self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) + + @property + def device(self) -> Any: + return self.pixel_mean.device + + @torch.no_grad() + def forward( + self, + batched_input: List[Dict[str, Any]], + multimask_output: bool, + return_logits: bool = False + ) -> List[Dict[str, torch.Tensor]]: + """ + Predicts masks end-to-end from provided images and prompts. + If prompts are not known in advance, using SamPredictor is + recommended over calling the model directly. + + Arguments: + batched_input (list(dict)): A list over input images, each a + dictionary with the following keys. A prompt key can be + excluded if it is not present. + 'image': The image as a torch tensor in 3xHxW format, + already transformed for input to the model. + 'original_size': (tuple(int, int)) The original size of + the image before transformation, as (H, W). + 'point_coords': (torch.Tensor) Batched point prompts for + this image, with shape BxNx2. Already transformed to the + input frame of the model. + 'point_labels': (torch.Tensor) Batched labels for point prompts, + with shape BxN. + 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. + Already transformed to the input frame of the model. + 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, + in the form Bx1xHxW. + multimask_output (bool): Whether the model should predict multiple + disambiguating masks, or return a single mask. + + Returns: + (list(dict)): A list over input images, where each element is + as dictionary with the following keys. + 'masks': (torch.Tensor) Batched binary mask predictions, + with shape BxCxHxW, where B is the number of input prompts, + C is determined by multimask_output, and (H, W) is the + original size of the image. + 'iou_predictions': (torch.Tensor) The model's predictions + of mask quality, in shape BxC. + 'low_res_logits': (torch.Tensor) Low resolution logits with + shape BxCxHxW, where H=W=256. Can be passed as mask input + to subsequent iterations of prediction. 
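+
+        Example (illustrative sketch; assumes a `sam` instance built with the
+        standard 1024x1024 ViT image encoder):
+            >>> img = torch.zeros(3, 1024, 1024)
+            >>> pts = torch.tensor([[[512.0, 512.0]]])
+            >>> lbl = torch.ones((1, 1), dtype=torch.int)
+            >>> out = sam([{"image": img, "original_size": (1024, 1024),
+            ...             "point_coords": pts, "point_labels": lbl}],
+            ...           multimask_output=False)
+            >>> out[0]["masks"].shape
+            torch.Size([1, 1, 1024, 1024])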
+ """ + input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0) + image_embeddings = self.image_encoder(input_images) + + outputs = [] + for image_record, curr_embedding in zip(batched_input, image_embeddings): + if "point_coords" in image_record: + points = (image_record["point_coords"], image_record["point_labels"]) + else: + points = None + sparse_embeddings, dense_embeddings = self.prompt_encoder( + points=points, + boxes=image_record.get("boxes", None), + masks=image_record.get("mask_inputs", None), + ) + low_res_masks, iou_predictions = self.mask_decoder( + image_embeddings=curr_embedding.unsqueeze(0), + image_pe=self.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + ) + masks = self.postprocess_masks( + low_res_masks, + input_size=image_record["image"].shape[-2:], + original_size=image_record["original_size"], + ) + if not return_logits: + masks = masks > self.mask_threshold + outputs.append( + { + "masks": masks, + "iou_predictions": iou_predictions, + "low_res_logits": low_res_masks, + } + ) + return outputs + + def postprocess_masks( + self, + masks: torch.Tensor, + input_size: Tuple[int, ...], + original_size: Tuple[int, ...], + ) -> torch.Tensor: + """ + Remove padding and upscale masks to the original image size. + + Arguments: + masks (torch.Tensor): Batched masks from the mask_decoder, + in BxCxHxW format. + input_size (tuple(int, int)): The size of the image input to the + model, in (H, W) format. Used to remove padding. + original_size (tuple(int, int)): The original size of the image + before resizing for input to the model, in (H, W) format. + + Returns: + (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) + is given by original_size. + """ + masks = F.interpolate( + masks, + (self.image_encoder.img_size, self.image_encoder.img_size), + mode="bilinear", + align_corners=False, + ) + masks = masks[..., : input_size[0], : input_size[1]] + masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) + return masks + + def preprocess(self, x: torch.Tensor) -> torch.Tensor: + """Normalize pixel values and pad to a square input.""" + # Normalize colors + x = (x - self.pixel_mean) / self.pixel_std + + # Pad + h, w = x.shape[-2:] + padh = self.image_encoder.img_size - h + padw = self.image_encoder.img_size - w + x = F.pad(x, (0, padw, 0, padh)) + return x diff --git a/isegm/model/sam_modeling/transformer.py b/isegm/model/sam_modeling/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..28fafea52288603fea275f3a100790471825c34a --- /dev/null +++ b/isegm/model/sam_modeling/transformer.py @@ -0,0 +1,240 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import Tensor, nn + +import math +from typing import Tuple, Type + +from .common import MLPBlock + + +class TwoWayTransformer(nn.Module): + def __init__( + self, + depth: int, + embedding_dim: int, + num_heads: int, + mlp_dim: int, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + ) -> None: + """ + A transformer decoder that attends to an input image using + queries whose positional embedding is supplied. 
+ + Args: + depth (int): number of layers in the transformer + embedding_dim (int): the channel dimension for the input embeddings + num_heads (int): the number of heads for multihead attention. Must + divide embedding_dim + mlp_dim (int): the channel dimension internal to the MLP block + activation (nn.Module): the activation to use in the MLP block + """ + super().__init__() + self.depth = depth + self.embedding_dim = embedding_dim + self.num_heads = num_heads + self.mlp_dim = mlp_dim + self.layers = nn.ModuleList() + + for i in range(depth): + self.layers.append( + TwoWayAttentionBlock( + embedding_dim=embedding_dim, + num_heads=num_heads, + mlp_dim=mlp_dim, + activation=activation, + attention_downsample_rate=attention_downsample_rate, + skip_first_layer_pe=(i == 0), + ) + ) + + self.final_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm_final_attn = nn.LayerNorm(embedding_dim) + + def forward( + self, + image_embedding: Tensor, + image_pe: Tensor, + point_embedding: Tensor, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + image_embedding (torch.Tensor): image to attend to. Should be shape + B x embedding_dim x h x w for any h and w. + image_pe (torch.Tensor): the positional encoding to add to the image. Must + have the same shape as image_embedding. + point_embedding (torch.Tensor): the embedding to add to the query points. + Must have shape B x N_points x embedding_dim for any N_points. + + Returns: + torch.Tensor: the processed point_embedding + torch.Tensor: the processed image_embedding + """ + # BxCxHxW -> BxHWxC == B x N_image_tokens x C + bs, c, h, w = image_embedding.shape + image_embedding = image_embedding.flatten(2).permute(0, 2, 1) + image_pe = image_pe.flatten(2).permute(0, 2, 1) + + # Prepare queries + queries = point_embedding + keys = image_embedding + + # Apply transformer blocks and final layernorm + for layer in self.layers: + queries, keys = layer( + queries=queries, + keys=keys, + query_pe=point_embedding, + key_pe=image_pe, + ) + + # Apply the final attention layer from the points to the image + q = queries + point_embedding + k = keys + image_pe + attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm_final_attn(queries) + + return queries, keys + + +class TwoWayAttentionBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + num_heads: int, + mlp_dim: int = 2048, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + skip_first_layer_pe: bool = False, + ) -> None: + """ + A transformer block with four layers: (1) self-attention of sparse + inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp + block on sparse inputs, and (4) cross attention of dense inputs to sparse + inputs. 
+ + Arguments: + embedding_dim (int): the channel dimension of the embeddings + num_heads (int): the number of heads in the attention layers + mlp_dim (int): the hidden dimension of the mlp block + activation (nn.Module): the activation of the mlp block + skip_first_layer_pe (bool): skip the PE on the first layer + """ + super().__init__() + self.self_attn = Attention(embedding_dim, num_heads) + self.norm1 = nn.LayerNorm(embedding_dim) + + self.cross_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm2 = nn.LayerNorm(embedding_dim) + + self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) + self.norm3 = nn.LayerNorm(embedding_dim) + + self.norm4 = nn.LayerNorm(embedding_dim) + self.cross_attn_image_to_token = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + + self.skip_first_layer_pe = skip_first_layer_pe + + def forward( + self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor + ) -> Tuple[Tensor, Tensor]: + # Self attention block + if self.skip_first_layer_pe: + queries = self.self_attn(q=queries, k=queries, v=queries) + else: + q = queries + query_pe + attn_out = self.self_attn(q=q, k=q, v=queries) + queries = queries + attn_out + queries = self.norm1(queries) + + # Cross attention block, tokens attending to image embedding + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm2(queries) + + # MLP block + mlp_out = self.mlp(queries) + queries = queries + mlp_out + queries = self.norm3(queries) + + # Cross attention block, image embedding attending to tokens + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) + keys = keys + attn_out + keys = self.norm4(keys) + + return queries, keys + + +class Attention(nn.Module): + """ + An attention layer that allows for downscaling the size of the embedding + after projection to queries, keys, and values. + """ + + def __init__( + self, + embedding_dim: int, + num_heads: int, + downsample_rate: int = 1, + ) -> None: + super().__init__() + self.embedding_dim = embedding_dim + self.internal_dim = embedding_dim // downsample_rate + self.num_heads = num_heads + assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." 
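+
+        # Queries, keys and values are projected down to internal_dim
+        # (embedding_dim // downsample_rate, e.g. 256 -> 128 for downsample_rate=2)
+        # so cross-attention over the full grid of image tokens stays cheap;
+        # out_proj maps the result back to embedding_dim.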
+ + self.q_proj = nn.Linear(embedding_dim, self.internal_dim) + self.k_proj = nn.Linear(embedding_dim, self.internal_dim) + self.v_proj = nn.Linear(embedding_dim, self.internal_dim) + self.out_proj = nn.Linear(self.internal_dim, embedding_dim) + + def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: + b, n, c = x.shape + x = x.reshape(b, n, num_heads, c // num_heads) + return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head + + def _recombine_heads(self, x: Tensor) -> Tensor: + b, n_heads, n_tokens, c_per_head = x.shape + x = x.transpose(1, 2) + return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: + # Input projections + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + # Attention + _, _, _, c_per_head = q.shape + attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens + attn = attn / math.sqrt(c_per_head) + attn = torch.softmax(attn, dim=-1) + + # Get output + out = attn @ v + out = self._recombine_heads(out) + out = self.out_proj(out) + + return out diff --git a/isegm/utils/cython/__init__.py b/isegm/utils/cython/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..eb66bdbba883b9477bbc1a52d8355131d32a04cb --- /dev/null +++ b/isegm/utils/cython/__init__.py @@ -0,0 +1,2 @@ +# noinspection PyUnresolvedReferences +from .dist_maps import get_dist_maps \ No newline at end of file diff --git a/isegm/utils/cython/dist_maps.py b/isegm/utils/cython/dist_maps.py new file mode 100644 index 0000000000000000000000000000000000000000..8ffa1e3f25231cd7c48b66ef8ef5167235c3ea4e --- /dev/null +++ b/isegm/utils/cython/dist_maps.py @@ -0,0 +1,3 @@ +import pyximport; pyximport.install(pyximport=True, language_level=3) +# noinspection PyUnresolvedReferences +from ._get_dist_maps import get_dist_maps \ No newline at end of file diff --git a/isegm/utils/distributed.py b/isegm/utils/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..a1e48f50500ee7440be035b17107573e86bb5d24 --- /dev/null +++ b/isegm/utils/distributed.py @@ -0,0 +1,67 @@ +import torch +from torch import distributed as dist +from torch.utils import data + + +def get_rank(): + if not dist.is_available() or not dist.is_initialized(): + return 0 + return dist.get_rank() + + +def synchronize(): + if not dist.is_available() or not dist.is_initialized() or dist.get_world_size() == 1: + return + dist.barrier() + + +def get_world_size(): + if not dist.is_available() or not dist.is_initialized(): + return 1 + + return dist.get_world_size() + + +def reduce_loss_dict(loss_dict): + world_size = get_world_size() + + if world_size < 2: + return loss_dict + + with torch.no_grad(): + keys = [] + losses = [] + + for k in loss_dict.keys(): + keys.append(k) + losses.append(loss_dict[k]) + + losses = torch.stack(losses, 0) + dist.reduce(losses, dst=0) + + if dist.get_rank() == 0: + losses /= world_size + + reduced_losses = {k: v for k, v in zip(keys, losses)} + + return reduced_losses + + +def get_sampler(dataset, shuffle, distributed): + if distributed: + return data.distributed.DistributedSampler(dataset, shuffle=shuffle) + + if shuffle: + return data.RandomSampler(dataset) + else: + return data.SequentialSampler(dataset) + + +def get_dp_wrapper(distributed): + class DPWrapper(torch.nn.parallel.DistributedDataParallel 
if distributed else torch.nn.DataParallel): + def __getattr__(self, name): + try: + return super().__getattr__(name) + except AttributeError: + return getattr(self.module, name) + return DPWrapper diff --git a/isegm/utils/exp.py b/isegm/utils/exp.py new file mode 100644 index 0000000000000000000000000000000000000000..171ca8bd18862840ba2ac2dee5ca1a106b6e7a86 --- /dev/null +++ b/isegm/utils/exp.py @@ -0,0 +1,187 @@ +import os +import sys +import shutil +import pprint +from pathlib import Path +from datetime import datetime + +import yaml +import torch +from easydict import EasyDict as edict + +from .log import logger, add_logging +from .distributed import synchronize, get_world_size + + +def init_experiment(args, model_name): + model_path = Path(args.model_path) + ftree = get_model_family_tree(model_path, model_name=model_name) + + if ftree is None: + print('Models can only be located in the "models" directory in the root of the repository') + sys.exit(1) + + cfg = load_config(model_path) + update_config(cfg, args) + + cfg.distributed = args.distributed + cfg.local_rank = args.local_rank + if cfg.distributed: + torch.distributed.init_process_group(backend='nccl', init_method='env://') + if args.workers > 0: + torch.multiprocessing.set_start_method('forkserver', force=True) + + experiments_path = Path(cfg.EXPS_PATH) + exp_parent_path = experiments_path / '/'.join(ftree) + exp_parent_path.mkdir(parents=True, exist_ok=True) + + if cfg.resume_exp: + exp_path = find_resume_exp(exp_parent_path, cfg.resume_exp) + else: + last_exp_indx = find_last_exp_indx(exp_parent_path) + exp_name = f'{last_exp_indx:03d}' + if cfg.exp_name: + exp_name += '_' + cfg.exp_name + exp_path = exp_parent_path / exp_name + synchronize() + if cfg.local_rank == 0: + exp_path.mkdir(parents=True) + + cfg.EXP_PATH = exp_path + cfg.CHECKPOINTS_PATH = exp_path / 'checkpoints' + cfg.VIS_PATH = exp_path / 'vis' + cfg.LOGS_PATH = exp_path / 'logs' + + if cfg.local_rank == 0: + cfg.LOGS_PATH.mkdir(exist_ok=True) + cfg.CHECKPOINTS_PATH.mkdir(exist_ok=True) + cfg.VIS_PATH.mkdir(exist_ok=True) + + dst_script_path = exp_path / (model_path.stem + datetime.strftime(datetime.today(), '_%Y-%m-%d-%H-%M-%S.py')) + if args.temp_model_path: + shutil.copy(args.temp_model_path, dst_script_path) + os.remove(args.temp_model_path) + else: + shutil.copy(model_path, dst_script_path) + + synchronize() + + if cfg.gpus != '': + gpu_ids = [int(id) for id in cfg.gpus.split(',')] + else: + gpu_ids = list(range(max(cfg.ngpus, get_world_size()))) + cfg.gpus = ','.join([str(id) for id in gpu_ids]) + + cfg.gpu_ids = gpu_ids + cfg.ngpus = len(gpu_ids) + cfg.multi_gpu = cfg.ngpus > 1 + + if cfg.distributed: + cfg.device = torch.device('cuda') + cfg.gpu_ids = [cfg.gpu_ids[cfg.local_rank]] + torch.cuda.set_device(cfg.gpu_ids[0]) + else: + if cfg.multi_gpu: + os.environ['CUDA_VISIBLE_DEVICES'] = cfg.gpus + ngpus = torch.cuda.device_count() + assert ngpus >= cfg.ngpus + cfg.device = torch.device(f'cuda:{cfg.gpu_ids[0]}') + + if cfg.local_rank == 0: + add_logging(cfg.LOGS_PATH, prefix='train_') + logger.info(f'Number of GPUs: {cfg.ngpus}') + if cfg.distributed: + logger.info(f'Multi-Process Multi-GPU Distributed Training') + + logger.info('Run experiment with config:') + logger.info(pprint.pformat(cfg, indent=4)) + + return cfg + + +def get_model_family_tree(model_path, terminate_name='models', model_name=None): + if model_name is None: + model_name = model_path.stem + family_tree = [model_name] + for x in model_path.parents: + if x.stem == terminate_name: + break + 
family_tree.append(x.stem) + else: + return None + + return family_tree[::-1] + + +def find_last_exp_indx(exp_parent_path): + indx = 0 + for x in exp_parent_path.iterdir(): + if not x.is_dir(): + continue + + exp_name = x.stem + if exp_name[:3].isnumeric(): + indx = max(indx, int(exp_name[:3]) + 1) + + return indx + + +def find_resume_exp(exp_parent_path, exp_pattern): + candidates = sorted(exp_parent_path.glob(f'{exp_pattern}*')) + if len(candidates) == 0: + print(f'No experiments could be found that satisfies the pattern = "*{exp_pattern}"') + sys.exit(1) + elif len(candidates) > 1: + print('More than one experiment found:') + for x in candidates: + print(x) + sys.exit(1) + else: + exp_path = candidates[0] + print(f'Continue with experiment "{exp_path}"') + + return exp_path + + +def update_config(cfg, args): + for param_name, value in vars(args).items(): + if param_name.lower() in cfg or param_name.upper() in cfg: + continue + cfg[param_name] = value + + +def load_config(model_path): + model_name = model_path.stem + config_path = model_path.parent / (model_name + '.yml') + + if config_path.exists(): + cfg = load_config_file(config_path) + else: + cfg = dict() + + cwd = Path.cwd() + config_parent = config_path.parent.absolute() + while len(config_parent.parents) > 0: + config_path = config_parent / 'config.yml' + + if config_path.exists(): + local_config = load_config_file(config_path, model_name=model_name) + cfg.update({k: v for k, v in local_config.items() if k not in cfg}) + + if config_parent.absolute() == cwd: + break + config_parent = config_parent.parent + + return edict(cfg) + + +def load_config_file(config_path, model_name=None, return_edict=False): + with open(config_path, 'r') as f: + cfg = yaml.safe_load(f) + + if 'SUBCONFIGS' in cfg: + if model_name is not None and model_name in cfg['SUBCONFIGS']: + cfg.update(cfg['SUBCONFIGS'][model_name]) + del cfg['SUBCONFIGS'] + + return edict(cfg) if return_edict else cfg diff --git a/isegm/utils/exp_imports/default.py b/isegm/utils/exp_imports/default.py new file mode 100644 index 0000000000000000000000000000000000000000..b5ca593a31ad65abc6960abd16bad998b6319a98 --- /dev/null +++ b/isegm/utils/exp_imports/default.py @@ -0,0 +1,22 @@ +import torch +from functools import partial +from easydict import EasyDict as edict +from albumentations import * + +from isegm.data.datasets import * +from isegm.model.losses import * +from isegm.data.transforms import * +from isegm.engine.trainer import ISTrainer +from isegm.model.metrics import AdaptiveIoU +from isegm.data.points_sampler import MultiPointSampler +from isegm.utils.log import logger +from isegm.model import initializer + +from isegm.model.is_hrnet_model import HRNetModel +from isegm.model.is_deeplab_model import DeeplabModel +from isegm.model.is_segformer_model import SegformerModel +from isegm.model.is_hrformer_model import HRFormerModel +from isegm.model.is_swinformer_model import SwinformerModel +from isegm.model.is_plainvit_model import PlainVitModel +from isegm.model.is_plainvit_model_lora import PlainVitModel_lora +from isegm.model.is_text_graco_model import TextGraCoModel \ No newline at end of file diff --git a/isegm/utils/log.py b/isegm/utils/log.py new file mode 100644 index 0000000000000000000000000000000000000000..1f9f8bdb4bdd74d72514db8cf9cecef51001a588 --- /dev/null +++ b/isegm/utils/log.py @@ -0,0 +1,97 @@ +import io +import time +import logging +from datetime import datetime + +import numpy as np +from torch.utils.tensorboard import SummaryWriter + +LOGGER_NAME = 'root' 
+LOGGER_DATEFMT = '%Y-%m-%d %H:%M:%S' + +handler = logging.StreamHandler() + +logger = logging.getLogger(LOGGER_NAME) +logger.setLevel(logging.INFO) +logger.addHandler(handler) + + +def add_logging(logs_path, prefix): + log_name = prefix + datetime.strftime(datetime.today(), '%Y-%m-%d_%H-%M-%S') + '.log' + stdout_log_path = logs_path / log_name + + fh = logging.FileHandler(str(stdout_log_path)) + formatter = logging.Formatter(fmt='(%(levelname)s) %(asctime)s: %(message)s', + datefmt=LOGGER_DATEFMT) + fh.setFormatter(formatter) + logger.addHandler(fh) + + +class TqdmToLogger(io.StringIO): + logger = None + level = None + buf = '' + + def __init__(self, logger, level=None, mininterval=5): + super(TqdmToLogger, self).__init__() + self.logger = logger + self.level = level or logging.INFO + self.mininterval = mininterval + self.last_time = 0 + + def write(self, buf): + self.buf = buf.strip('\r\n\t ') + + def flush(self): + if len(self.buf) > 0 and time.time() - self.last_time > self.mininterval: + self.logger.log(self.level, self.buf) + self.last_time = time.time() + + +class SummaryWriterAvg(SummaryWriter): + def __init__(self, *args, dump_period=20, **kwargs): + super().__init__(*args, **kwargs) + self._dump_period = dump_period + self._avg_scalars = dict() + + def add_scalar(self, tag, value, global_step=None, disable_avg=False): + if disable_avg or isinstance(value, (tuple, list, dict)): + super().add_scalar(tag, np.array(value), global_step=global_step) + else: + if tag not in self._avg_scalars: + self._avg_scalars[tag] = ScalarAccumulator(self._dump_period) + avg_scalar = self._avg_scalars[tag] + avg_scalar.add(value) + + if avg_scalar.is_full(): + super().add_scalar(tag, avg_scalar.value, + global_step=global_step) + avg_scalar.reset() + + +class ScalarAccumulator(object): + def __init__(self, period): + self.sum = 0 + self.cnt = 0 + self.period = period + + def add(self, value): + self.sum += value + self.cnt += 1 + + @property + def value(self): + if self.cnt > 0: + return self.sum / self.cnt + else: + return 0 + + def reset(self): + self.cnt = 0 + self.sum = 0 + + def is_full(self): + return self.cnt >= self.period + + def __len__(self): + return self.cnt diff --git a/isegm/utils/lr_decay.py b/isegm/utils/lr_decay.py new file mode 100644 index 0000000000000000000000000000000000000000..eb17bd186e7fd239d109fb98931027173aa90b50 --- /dev/null +++ b/isegm/utils/lr_decay.py @@ -0,0 +1,85 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# ELECTRA https://github.com/google-research/electra +# BEiT: https://github.com/microsoft/unilm/tree/master/beit +# -------------------------------------------------------- + +import json + + +def param_groups_lrd(model, lr, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75): + """ + Parameter groups for layer-wise lr decay + Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58 + """ + param_group_names = {} + param_groups = {} + num_layers = len(model.backbone.blocks) + 1 + layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1)) + for n, p in model.backbone.named_parameters(): + if not p.requires_grad: + continue + + # no decay: all 1D parameters and model specific ones + if p.ndim == 1 or n in no_weight_decay_list: + g_decay = "no_decay" + this_decay = 0. 
+ else: + g_decay = "decay" + this_decay = weight_decay + + layer_id = get_layer_id_for_vit(n, num_layers) + group_name = "layer_%d_%s" % (layer_id, g_decay) + + if group_name not in param_group_names: + this_scale = layer_scales[layer_id] + + param_group_names[group_name] = { + "lr_scale": this_scale, + "lr": lr * this_scale, + "weight_decay": this_decay, + "params": [], + } + param_groups[group_name] = { + "lr_scale": this_scale, + "lr": lr * this_scale, + "weight_decay": this_decay, + "params": [], + } + + param_group_names[group_name]["params"].append(n) + param_groups[group_name]["params"].append(p) + + params = list(param_groups.values()) + + for n, p in model.neck.named_parameters(): + if not p.requires_grad: + continue + params.append({"params": p, "weight_decay": weight_decay}) + + for n, p in model.head.named_parameters(): + if not p.requires_grad: + continue + params.append({"params": p, "weight_decay": weight_decay}) + + return params + + +def get_layer_id_for_vit(name, num_layers): + """ + Assign a parameter with its layer id + Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33 + """ + if name in ['cls_token', 'pos_embed']: + return 0 + elif name.startswith('patch_embed'): + return 0 + elif name.startswith('blocks'): + return int(name.split('.')[1]) + 1 + else: + return num_layers \ No newline at end of file diff --git a/isegm/utils/misc.py b/isegm/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..fb746eec4ea005fd701ae870218379133b63da34 --- /dev/null +++ b/isegm/utils/misc.py @@ -0,0 +1,92 @@ +import torch +import numpy as np + +from .log import logger + + +def get_dims_with_exclusion(dim, exclude=None): + dims = list(range(dim)) + if exclude is not None: + dims.remove(exclude) + + return dims + +def part_state_dict(state_dict): + return {k: state_dict[k] for k in state_dict if ('lora_' in k) or ('gra_embed' in k)} + +def save_checkpoint(net, checkpoints_path, epoch=None, prefix='', verbose=True, multi_gpu=False, save_lora=False): + if epoch is None: + checkpoint_name = 'last_checkpoint.pth' + else: + checkpoint_name = f'{epoch:03d}.pth' + + if prefix: + checkpoint_name = f'{prefix}_{checkpoint_name}' + + if not checkpoints_path.exists(): + checkpoints_path.mkdir(parents=True) + + checkpoint_path = checkpoints_path / checkpoint_name + if verbose: + logger.info(f'Save checkpoint to {str(checkpoint_path)}') + + net = net.module if multi_gpu else net + + if save_lora: + torch.save({'state_dict': part_state_dict(net.state_dict()), + 'config': net._config}, str(checkpoint_path)) + else: + torch.save({'state_dict': net.state_dict(), + 'config': net._config}, str(checkpoint_path)) + +def get_bbox_from_mask(mask): + rows = np.any(mask, axis=1) + cols = np.any(mask, axis=0) + rmin, rmax = np.where(rows)[0][[0, -1]] + cmin, cmax = np.where(cols)[0][[0, -1]] + + return rmin, rmax, cmin, cmax + + +def expand_bbox(bbox, expand_ratio, min_crop_size=None): + rmin, rmax, cmin, cmax = bbox + rcenter = 0.5 * (rmin + rmax) + ccenter = 0.5 * (cmin + cmax) + height = expand_ratio * (rmax - rmin + 1) + width = expand_ratio * (cmax - cmin + 1) + if min_crop_size is not None: + height = max(height, min_crop_size) + width = max(width, min_crop_size) + + rmin = int(round(rcenter - 0.5 * height)) + rmax = int(round(rcenter + 0.5 * height)) + cmin = int(round(ccenter - 0.5 * width)) + cmax = int(round(ccenter + 0.5 * width)) + + return rmin, rmax, cmin, cmax + + +def clamp_bbox(bbox, rmin, rmax, cmin, cmax): + return (max(rmin, 
bbox[0]), min(rmax, bbox[1]), + max(cmin, bbox[2]), min(cmax, bbox[3])) + + +def get_bbox_iou(b1, b2): + h_iou = get_segments_iou(b1[:2], b2[:2]) + w_iou = get_segments_iou(b1[2:4], b2[2:4]) + return h_iou * w_iou + + +def get_segments_iou(s1, s2): + a, b = s1 + c, d = s2 + intersection = max(0, min(b, d) - max(a, c) + 1) + union = max(1e-6, max(b, d) - min(a, c) + 1) + return intersection / union + + +def get_labels_with_sizes(x): + obj_sizes = np.bincount(x.flatten()) + labels = np.nonzero(obj_sizes)[0].tolist() + labels = [x for x in labels if x != 0] + return labels, obj_sizes[labels].tolist() diff --git a/isegm/utils/serialization.py b/isegm/utils/serialization.py new file mode 100644 index 0000000000000000000000000000000000000000..4fe7b6e6614412f2518ab1eff07123bf26505081 --- /dev/null +++ b/isegm/utils/serialization.py @@ -0,0 +1,110 @@ +from functools import wraps +from copy import deepcopy +import inspect +import torch.nn as nn + + +def serialize(init): + parameters = list(inspect.signature(init).parameters) + + @wraps(init) + def new_init(self, *args, **kwargs): + params = deepcopy(kwargs) + for pname, value in zip(parameters[1:], args): + params[pname] = value + + config = { + 'class': get_classname(self.__class__), + 'params': dict() + } + specified_params = set(params.keys()) + + for pname, param in get_default_params(self.__class__).items(): + if pname not in params: + params[pname] = param.default + + for name, value in list(params.items()): + param_type = 'builtin' + if inspect.isclass(value): + param_type = 'class' + value = get_classname(value) + + config['params'][name] = { + 'type': param_type, + 'value': value, + 'specified': name in specified_params + } + + setattr(self, '_config', config) + init(self, *args, **kwargs) + + return new_init + + +def load_model(config, eval_ritm, **kwargs): + model_class = get_class_from_str(config['class']) + model_default_params = get_default_params(model_class) + + model_args = dict() + for pname, param in config['params'].items(): + value = param['value'] + if param['type'] == 'class': + value = get_class_from_str(value) + + if pname not in model_default_params and not param['specified']: + continue + + assert pname in model_default_params + if not param['specified'] and model_default_params[pname].default == value: + continue + model_args[pname] = value + model_args.update(kwargs) + + # This ugly hardcode is only to support evalution for RITM models + # Ignore it if you are evaluting SimpleClick models. + if eval_ritm: + model_args['use_rgb_conv'] = True + + return model_class(**model_args) + + +def get_config_repr(config): + config_str = f'Model: {config["class"]}\n' + for pname, param in config['params'].items(): + value = param["value"] + if param['type'] == 'class': + value = value.split('.')[-1] + param_str = f'{pname:<22} = {str(value):<12}' + if not param['specified']: + param_str += ' (default)' + config_str += param_str + '\n' + return config_str + + +def get_default_params(some_class): + params = dict() + for mclass in some_class.mro(): + if mclass is nn.Module or mclass is object: + continue + mclass_params = inspect.signature(mclass.__init__).parameters + for pname, param in mclass_params.items(): + if param.default != param.empty and pname not in params: + params[pname] = param + + return params + + +def get_classname(cls): + module = cls.__module__ + name = cls.__qualname__ + if module is not None and module != "__builtin__": + name = module + "." 
+ name + return name + + +def get_class_from_str(class_str): + components = class_str.split('.') + mod = __import__('.'.join(components[:-1])) + for comp in components[1:]: + mod = getattr(mod, comp) + return mod diff --git a/isegm/utils/vis.py b/isegm/utils/vis.py new file mode 100644 index 0000000000000000000000000000000000000000..914ebcc4c997daf4bebe4d8c65d588ea85e9d929 --- /dev/null +++ b/isegm/utils/vis.py @@ -0,0 +1,154 @@ +from functools import lru_cache + +import cv2 +import numpy as np + + +def visualize_instances(imask, bg_color=255, + boundaries_color=None, boundaries_width=1, boundaries_alpha=0.8): + num_objects = imask.max() + 1 + palette = get_palette(num_objects) + if bg_color is not None: + palette[0] = bg_color + + result = palette[imask].astype(np.uint8) + if boundaries_color is not None: + boundaries_mask = get_boundaries(imask, boundaries_width=boundaries_width) + tresult = result.astype(np.float32) + tresult[boundaries_mask] = boundaries_color + tresult = tresult * boundaries_alpha + (1 - boundaries_alpha) * result + result = tresult.astype(np.uint8) + + return result + + +@lru_cache(maxsize=16) +def get_palette(num_cls): + palette = np.zeros(3 * num_cls, dtype=np.int32) + + for j in range(0, num_cls): + lab = j + i = 0 + + while lab > 0: + palette[j*3 + 0] |= (((lab >> 0) & 1) << (7-i)) + palette[j*3 + 1] |= (((lab >> 1) & 1) << (7-i)) + palette[j*3 + 2] |= (((lab >> 2) & 1) << (7-i)) + i = i + 1 + lab >>= 3 + + return palette.reshape((-1, 3)) + + +def visualize_mask(mask, num_cls): + palette = get_palette(num_cls) + mask[mask == -1] = 0 + + return palette[mask].astype(np.uint8) + + +def visualize_proposals(proposals_info, point_color=(255, 0, 0), point_radius=1): + proposal_map, colors, candidates = proposals_info + + proposal_map = draw_probmap(proposal_map) + for x, y in candidates: + proposal_map = cv2.circle(proposal_map, (y, x), point_radius, point_color, -1) + + return proposal_map + + +def draw_probmap(x): + return cv2.applyColorMap((x * 255).astype(np.uint8), cv2.COLORMAP_HOT) + + +def draw_points(image, points, color, radius=3): + image = image.copy() + for p in points: + if p[0] < 0: + continue + if len(p) == 3: + pradius = {0: 8, 1: 6, 2: 4}[p[2]] if p[2] < 3 else 2 + else: + pradius = radius + image = cv2.circle(image, (int(p[1]), int(p[0])), pradius, color, -1) + + return image + + +def draw_instance_map(x, palette=None): + num_colors = x.max() + 1 + if palette is None: + palette = get_palette(num_colors) + + return palette[x].astype(np.uint8) + + +def blend_mask(image, mask, alpha=0.6): + if mask.min() == -1: + mask = mask.copy() + 1 + + imap = draw_instance_map(mask) + result = (image * (1 - alpha) + alpha * imap).astype(np.uint8) + return result + + +def get_boundaries(instances_masks, boundaries_width=1): + boundaries = np.zeros((instances_masks.shape[0], instances_masks.shape[1]), dtype=bool) + + for obj_id in np.unique(instances_masks.flatten()): + if obj_id == 0: + continue + + obj_mask = instances_masks == obj_id + kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3)) + inner_mask = cv2.erode(obj_mask.astype(np.uint8), kernel, iterations=boundaries_width).astype(bool) + + obj_boundary = np.logical_xor(obj_mask, np.logical_and(inner_mask, obj_mask)) + boundaries = np.logical_or(boundaries, obj_boundary) + return boundaries + + +def draw_with_blend_and_clicks(img, mask=None, alpha=0.6, clicks_list=None, pos_color=(0, 255, 0), + neg_color=(255, 0, 0), radius=4): + result = img.copy() + + if mask is not None: + palette = 
get_palette(np.max(mask) + 1) + rgb_mask = palette[mask.astype(np.uint8)] + + mask_region = (mask > 0).astype(np.uint8) + result = result * (1 - mask_region[:, :, np.newaxis]) + \ + (1 - alpha) * mask_region[:, :, np.newaxis] * result + \ + alpha * rgb_mask + result = result.astype(np.uint8) + + # result = (result * (1 - alpha) + alpha * rgb_mask).astype(np.uint8) + + if clicks_list is not None and len(clicks_list) > 0: + pos_points = [click.coords for click in clicks_list if click.is_positive] + neg_points = [click.coords for click in clicks_list if not click.is_positive] + + result = draw_points(result, pos_points, pos_color, radius=radius) + result = draw_points(result, neg_points, neg_color, radius=radius) + + return result + +def draw_contour(img, mask, color=(253, 211, 106), thickness=2): + contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE) + img = cv2.drawContours(img, contours, -1, color=color, thickness=thickness) + return img + + +def draw_mask(img, mask, opacity=0.6): + mask = (mask > 0)[..., None] + img = img * mask + img * ~mask * (1 - opacity) + return img.astype(np.uint8) + + +def draw_click(img, clicks, radius=5): + for click in clicks: + color = (146, 208, 80) if click.is_positive else (192, 0, 0) + coords = (click.coords[1], click.coords[0]) + img = cv2.circle(img.copy(), coords, int(radius * 1.5), (0, 0, 0), -1) + img = cv2.circle(img, coords, radius, color, -1) + return img \ No newline at end of file diff --git a/isegm/utils/visualization.py b/isegm/utils/visualization.py new file mode 100644 index 0000000000000000000000000000000000000000..4460f13d361389b9f62cfc368a93d7e9db7bbfbb --- /dev/null +++ b/isegm/utils/visualization.py @@ -0,0 +1,89 @@ +import os +import os.path as osp +import cv2 +import torch +import numpy as np + +from isegm.inference.clicker import Clicker +from isegm.inference import utils + + +def inference(image, gt_mask, predictor, threshold=0.5, min_clicks=1, max_clicks=20): + clicker = Clicker(gt_mask=gt_mask) + pred_mask = np.zeros_like(gt_mask) + ious_list = [] + probs_list = [] + masks_list = [] + + with torch.no_grad(): + predictor.set_input_image(image) + + for click_indx in range(max_clicks): + clicker.make_next_click(pred_mask) + pred_probs = predictor.get_prediction(clicker) + pred_mask = pred_probs > threshold + + iou = utils.get_iou(gt_mask, pred_mask) + ious_list.append(iou) + + probs_list.append(pred_probs.copy()) + masks_list.append(pred_mask.copy()) + + return clicker.clicks_list, np.array(ious_list, dtype=np.float32), probs_list, masks_list + + +def visualization(sample, predictor, mask=True, score=False, contour=True, click=True, threshold=0.5, min_clicks=1, max_clicks=20, out_dir=None): + mask = False if score else mask + if out_dir is not None: + out_dir = osp.join(out_dir, str(sample.sample_id)) + os.makedirs(out_dir, exist_ok=True) + + clicks, ious, probs, masks = inference(sample.image, sample.gt_mask(sample.objects_ids[0]), + predictor, threshold=threshold, + min_clicks=min_clicks, max_clicks=max_clicks) + + outputs = [] + + show = cv2.cvtColor(sample.image.copy(), cv2.COLOR_RGB2BGR) + gt_mask = sample.gt_mask(sample.objects_ids[0]).astype(np.bool8) + if mask: + show[~gt_mask] = (show[~gt_mask] * 0.4).astype(np.uint8) + if contour: + contours, _ = cv2.findContours(gt_mask.astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE) + show = cv2.drawContours(show, contours, -1, (106, 211, 253), 2) + if out_dir is not None: + cv2.imwrite(osp.join(out_dir, f'gt_mask.jpg'), show) + 
outputs.append(show) + + for i in range(len(clicks)): + show = cv2.cvtColor(sample.image.copy(), cv2.COLOR_RGB2BGR) + + if score: + score_map = cv2.applyColorMap((probs[i] * 255).astype(np.uint8), cv2.COLORMAP_JET) + show = cv2.addWeighted(show, 0.5, score_map, 0.5, 0) + + if mask: + show[~masks[i]] = (show[~masks[i]] * 0.4).astype(np.uint8) + + if contour: + contours, _ = cv2.findContours(masks[i].astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE) + min_area_threshold = 10 + for cur_contour in contours: + area = cv2.contourArea(cur_contour) + if area > min_area_threshold: + show = cv2.drawContours(show, [cur_contour], -1, (106, 211, 253), 2) + # show = cv2.drawContours(show, contours, -1, (106, 211, 253), 4) + + if click: + for j in range(i + 1): + color = (80, 208, 146) if clicks[j].is_positive else (0, 0, 192) + coords = (clicks[j].coords[1], clicks[j].coords[0]) + show = cv2.circle(show, coords, 7, (0, 0, 0), -1) + show = cv2.circle(show, coords, 5, color, -1) + + outputs.append(show) + + if out_dir is not None: + cv2.imwrite(osp.join(out_dir, f'{i+1}_{ious[i]:.2f}.jpg'), show) + + return outputs, ious diff --git a/web_app/__init__.py b/web_app/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ae436fb84e2114f8f71292a0dedff49f5179cfbb --- /dev/null +++ b/web_app/__init__.py @@ -0,0 +1 @@ +from .app import GraCoWebApplication \ No newline at end of file diff --git a/web_app/app.py b/web_app/app.py new file mode 100644 index 0000000000000000000000000000000000000000..65223c1b42ad3baa9cf8ffdce857da352237064a --- /dev/null +++ b/web_app/app.py @@ -0,0 +1,34 @@ +import gradio as gr +import torch + +from .segmentation import InteractiveSegmentationInterface + + +_HEADER = """ +
<div align="center">
+  <h1>GraCo: Granularity-Controllable Interactive Segmentation</h1>
+</div>
+"""
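+# GraCoWebApplication assembles the Gradio UI: it renders _HEADER as Markdown,
+# mounts InteractiveSegmentationInterface inside a single tab, and exposes
+# launch() to start the Blocks server.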
+ + +class GraCoWebApplication(object): + + def __init__(self, device: torch.device = None): + if device is None: + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + with gr.Blocks() as self._blocks: + gr.Markdown(_HEADER) + with gr.Tab('Granularity-Cotrollable Interactive Segmentation'): + InteractiveSegmentationInterface(device=device) + + def launch(self): + self._blocks.launch() diff --git a/web_app/segmentation.py b/web_app/segmentation.py new file mode 100644 index 0000000000000000000000000000000000000000..0a8975aa4cadfb4f743002836df9277e671ac191 --- /dev/null +++ b/web_app/segmentation.py @@ -0,0 +1,207 @@ +import gc + +import gradio as gr +import numpy as np +import torch + +from isegm.inference.clicker import Click, Clicker +from isegm.inference.predictors import BasePredictor +from isegm.inference.transforms import ZoomIn +from isegm.inference.utils import load_single_is_model +from isegm.utils.vis import draw_click, draw_contour, draw_mask + + +class InteractiveSegmentationInterface(object): + + def __init__(self, device: torch.device): + self.device = device + + self._clicker = Clicker() + + self._pretrained_models = { + 'GraCo_SimpleClick_ViT-B': {"weights": './weights/simpleclick/sbd_vit_base.pth', "lora": './weights/GraCo/sbd_vit_base_lora.pth'} + } + self._predictor = None + + self._pred_prob = None + self._masked_img = None + + self._build_interface() + self._add_functions() + + def _build_interface(self): + with gr.Row(): + with gr.Column(): + with gr.Row(): + choices = list(self._pretrained_models.keys()) + self.model_name = gr.Dropdown(choices=choices, value=choices[0], label='Model') + self.loaded_model = gr.Textbox(label='Loaded Model', interactive=False) + self.load_button = gr.Button(value='Load Model') + with gr.Row(): + self.input_img = gr.Image(label='Input Image') + self.click_map = gr.Image( + label='Click Map', show_download_button=False, interactive=False) + + with gr.Row(): + self.add_button = gr.Button(value='Add Click', interactive=False) + self.undo_button = gr.Button(value='Undo', interactive=False) + self.submit_button = gr.Button(value='Segment', interactive=False) + + self.drawing_board = gr.Image( + label='Add Click', + tool='sketch', + interactive=False, + visible=False, + brush_radius=15) + with gr.Row(): + self.pos_button = gr.Button(value='Add Positive', visible=False) + self.neg_button = gr.Button(value='Add Negative', visible=False) + self.cancel_button = gr.Button(value='Cancel', visible=False) + + with gr.Column(): + self.threshold = gr.Slider( + label='Threshold', + minimum=0.0, + maximum=1.0, + value=0.5, + step=0.01, + interactive=False) + self.granularity = gr.Slider( + label='Granularity', + minimum=0.0, + maximum=1.0, + value=1.0, + step=0.01, + interactive=False) + self.seg_mask = gr.Image( + label='Segmentation', show_download_button=False, interactive=False) + + def _add_functions(self): + self.input_img.upload( + fn=self._load_image, + inputs=self.input_img, + outputs=[ + self.click_map, self.seg_mask, self.add_button, self.undo_button, + self.submit_button, self.threshold, self.granularity, self.drawing_board, self.pos_button, + self.neg_button, self.cancel_button + ]) + + self.load_button.click( + fn=self._load_model, + inputs=[self.model_name, self.input_img], + outputs=[self.loaded_model, self.submit_button]) + + self.add_button.click( + fn=self._create_click, + outputs=[self.drawing_board, self.pos_button, self.neg_button, self.cancel_button]) + self.undo_button.click( + fn=self._undo_click, + 
outputs=[self.click_map, self.drawing_board, self.undo_button, self.submit_button]) + + self.pos_button.click( + fn=self._add_pos_click, + inputs=self.drawing_board, + outputs=[ + self.click_map, self.undo_button, self.submit_button, self.drawing_board, + self.pos_button, self.neg_button, self.cancel_button + ]) + self.neg_button.click( + fn=self._add_neg_click, + inputs=self.drawing_board, + outputs=[ + self.click_map, self.undo_button, self.submit_button, self.drawing_board, + self.pos_button, self.neg_button, self.cancel_button + ]) + self.cancel_button.click( + fn=self._cancel, + outputs=[self.drawing_board, self.pos_button, self.neg_button, self.cancel_button]) + + self.submit_button.click( + fn=self._segment, + inputs=[self.input_img, self.threshold, self.granularity], + outputs=[self.seg_mask, self.click_map, self.drawing_board, self.threshold, self.granularity]) + self.threshold.release( + fn=self._show_mask, + inputs=self.threshold, + outputs=[self.seg_mask, self.click_map, self.drawing_board]) + + @property + def _click_map(self): + if self._img is None: + return None + img = self._img if self._masked_img is None else self._masked_img + return draw_click(img, self._clicker.get_clicks()) + + def _load_image(self, img): + self._img = img + self._img_size = img.shape[:2] + self._clicker.reset_clicks() + self._pred_prob = None + self._masked_img = None + return (self._click_map, None, gr.update(interactive=True), gr.update(interactive=False), gr.update(interactive=False), + gr.update(interactive=False), gr.update(interactive=True), *self._cancel()) + + def _load_model(self, model_name, img): + if self._predictor is not None: + del self._predictor + self._predictor = None + gc.collect() + torch.cuda.empty_cache() + state_dict = torch.load(self._pretrained_models[model_name]["weights"], map_location='cpu') + model = load_single_is_model(state_dict, device=self.device, lora_checkpoint=self._pretrained_models[model_name]["lora"], eval_ritm=False) + zoom_in = ZoomIn(skip_clicks=-1, target_size=(448, 448)) + self._predictor = BasePredictor(model, device=self.device, zoom_in=zoom_in, with_flip=True) + enable_submit = img is not None and len(self._clicker) > 0 + return model_name, gr.update(interactive=enable_submit) + + def _create_click(self): + return gr.update( + value=self._click_map, interactive=True, + visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True) + + def _cancel(self): + return gr.update( + interactive=False, visible=False), gr.update(visible=False), gr.update( + visible=False), gr.update(visible=False) + + def _add_click(self, inp, is_positive): + coords = np.nonzero(inp['mask'].sum(axis=-1)) + if len(coords[0]) == 0: + return (self._click_map, gr.update(interactive=False), gr.update(interactive=False), + *self._cancel()) + coords = (round(coords[0].mean()), round(coords[1].mean())) + click = Click(is_positive=is_positive, coords=coords) + self._clicker.add_click(click) + return (self._click_map, gr.update(interactive=True), + gr.update(interactive=self._predictor is not None), *self._cancel()) + + def _add_pos_click(self, inp): + return self._add_click(inp, is_positive=True) + + def _add_neg_click(self, inp): + return self._add_click(inp, is_positive=False) + + def _undo_click(self): + self._clicker._remove_last_click() + has_clicks = len(self._clicker) > 0 + click_map = self._click_map + return ( + click_map, + click_map, + gr.update(interactive=has_clicks), + gr.update(interactive=has_clicks), + ) + + @torch.no_grad() + def 
_segment(self, img, threshold, granularity): + self._predictor.set_input_image(img) + self._pred_prob = self._predictor.get_prediction(self._clicker, gra=granularity) + return (*self._show_mask(threshold), gr.update(value=0.5, interactive=True), gr.update(interactive=True)) + + def _show_mask(self, threshold): + mask = self._pred_prob > threshold + img = draw_mask(self._img, mask) + img = draw_contour(img, mask) + self._masked_img = img + click_map = self._click_map + return img, click_map, click_map
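For reference, the same predictor flow that `_segment` and `_show_mask` drive from the UI can be exercised headlessly. The sketch below is only illustrative: it assumes the two checkpoint paths registered in `_pretrained_models` have already been downloaded, and the image path, click coordinates, and granularity value are placeholders, not values from this patch.

import cv2
import torch

from isegm.inference.clicker import Click, Clicker
from isegm.inference.predictors import BasePredictor
from isegm.inference.transforms import ZoomIn
from isegm.inference.utils import load_single_is_model
from isegm.utils.vis import draw_click, draw_contour, draw_mask

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Checkpoints as registered in web_app/segmentation.py (assumed to exist locally).
state_dict = torch.load('./weights/simpleclick/sbd_vit_base.pth', map_location='cpu')
model = load_single_is_model(state_dict, device=device,
                             lora_checkpoint='./weights/GraCo/sbd_vit_base_lora.pth',
                             eval_ritm=False)
predictor = BasePredictor(model, device=device,
                          zoom_in=ZoomIn(skip_clicks=-1, target_size=(448, 448)),
                          with_flip=True)

# Placeholder image path; the web app passes RGB images straight to the predictor,
# so convert from OpenCV's BGR layout here.
image = cv2.cvtColor(cv2.imread('example.jpg'), cv2.COLOR_BGR2RGB)

# Clicks use (row, col) coordinates; positive clicks mark the object, negative the background.
clicker = Clicker()
clicker.add_click(Click(is_positive=True, coords=(120, 200)))

with torch.no_grad():
    predictor.set_input_image(image)
    # `gra` is the granularity knob exposed by the UI slider (0.0 .. 1.0).
    pred_prob = predictor.get_prediction(clicker, gra=0.8)

mask = pred_prob > 0.5                       # same default threshold as the UI slider
vis = draw_mask(image, mask)                 # dim the background
vis = draw_contour(vis, mask)                # outline the predicted region
vis = draw_click(vis, clicker.get_clicks())  # overlay the clicks
cv2.imwrite('result.jpg', cv2.cvtColor(vis, cv2.COLOR_RGB2BGR))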