import torch
import numpy as np

from .log import logger


def get_dims_with_exclusion(dim, exclude=None):
    """Return the axis indices ``0 .. dim - 1``, optionally leaving one out.

    Args:
        dim: Number of axes.
        exclude: Axis index to drop from the result; ``None`` keeps every axis.

    Returns:
        List of the remaining axis indices in ascending order.

    Raises:
        ValueError: If ``exclude`` is not one of ``0 .. dim - 1`` (negative
            indices are not normalised).
    """
    dims = list(range(dim))
    if exclude is not None:
        dims.remove(exclude)
    return dims


def part_state_dict(state_dict):
    """Return only the entries whose key contains ``'lora_'`` or ``'gra_embed'``."""
    return {
        key: value
        for key, value in state_dict.items()
        if 'lora_' in key or 'gra_embed' in key
    }


def save_checkpoint(net, checkpoints_path, epoch=None, prefix='', verbose=True,
                    multi_gpu=False, save_lora=False):
    """Serialise the weights and config of ``net`` into ``checkpoints_path``.

    The file is named ``last_checkpoint.pth`` when ``epoch`` is ``None`` and
    ``<epoch, zero-padded to 3 digits>.pth`` otherwise; a non-empty ``prefix``
    is prepended as ``<prefix>_``. The saved object is a dict with the keys
    ``'state_dict'`` and ``'config'``.

    Args:
        net: Model exposing ``state_dict()`` and a ``_config`` attribute.
        checkpoints_path: ``pathlib.Path``-like target directory; it is created
            (including parents) when missing.
        epoch: Integer epoch used in the file name, or ``None``.
        prefix: Optional file-name prefix.
        verbose: Log the destination path before saving.
        multi_gpu: When true, the model is read from ``net.module``
            (presumably a DataParallel-style wrapper -- confirm with callers).
        save_lora: When true, only the entries selected by
            ``part_state_dict`` are written instead of the full state dict.
    """
    if epoch is None:
        checkpoint_name = 'last_checkpoint.pth'
    else:
        checkpoint_name = f'{epoch:03d}.pth'

    if prefix:
        checkpoint_name = f'{prefix}_{checkpoint_name}'

    # exist_ok=True closes the window between an exists() check and mkdir()
    # in which another process could create the directory first.
    checkpoints_path.mkdir(parents=True, exist_ok=True)

    checkpoint_path = checkpoints_path / checkpoint_name
    if verbose:
        logger.info(f'Save checkpoint to {str(checkpoint_path)}')

    net = net.module if multi_gpu else net
    state_dict = net.state_dict()
    if save_lora:
        state_dict = part_state_dict(state_dict)

    torch.save({'state_dict': state_dict, 'config': net._config},
               str(checkpoint_path))


def get_bbox_from_mask(mask):
    """Return the tight bounding box around the truthy entries of ``mask``.

    Args:
        mask: Array reduced along axes 0 and 1 (assumed to be 2-D -- confirm
            with callers).

    Returns:
        ``(rmin, rmax, cmin, cmax)``: first and last row index and first and
        last column index that contain a truthy value (inclusive bounds).

    Raises:
        IndexError: If ``mask`` contains no truthy entry.
    """
    rows = np.any(mask, axis=1)
    cols = np.any(mask, axis=0)
    rmin, rmax = np.where(rows)[0][[0, -1]]
    cmin, cmax = np.where(cols)[0][[0, -1]]
    return rmin, rmax, cmin, cmax


def expand_bbox(bbox, expand_ratio, min_crop_size=None):
    """Scale a bounding box about its centre.

    Args:
        bbox: ``(rmin, rmax, cmin, cmax)`` with inclusive bounds.
        expand_ratio: Factor applied to the box height and width.
        min_crop_size: Optional lower bound applied to the scaled height and
            width.

    Returns:
        The expanded ``(rmin, rmax, cmin, cmax)`` as Python ints. The result is
        not clipped to any image extent.
    """
    rmin, rmax, cmin, cmax = bbox
    rcenter = 0.5 * (rmin + rmax)
    ccenter = 0.5 * (cmin + cmax)

    # Bounds are inclusive, hence the +1 when measuring the extent.
    height = expand_ratio * (rmax - rmin + 1)
    width = expand_ratio * (cmax - cmin + 1)
    if min_crop_size is not None:
        height = max(height, min_crop_size)
        width = max(width, min_crop_size)

    rmin = int(round(rcenter - 0.5 * height))
    rmax = int(round(rcenter + 0.5 * height))
    cmin = int(round(ccenter - 0.5 * width))
    cmax = int(round(ccenter + 0.5 * width))
    return rmin, rmax, cmin, cmax


def clamp_bbox(bbox, rmin, rmax, cmin, cmax):
    """Clip ``bbox`` to the limits ``rmin``/``rmax``/``cmin``/``cmax``.

    Args:
        bbox: ``(rmin, rmax, cmin, cmax)`` box to clip.
        rmin: Smallest allowed row index.
        rmax: Largest allowed row index.
        cmin: Smallest allowed column index.
        cmax: Largest allowed column index.

    Returns:
        The clipped ``(rmin, rmax, cmin, cmax)`` tuple.
    """
    return (
        max(rmin, bbox[0]),
        min(rmax, bbox[1]),
        max(cmin, bbox[2]),
        min(cmax, bbox[3]),
    )


def get_bbox_iou(b1, b2):
    """Return the product of the row-wise and column-wise segment IoUs.

    Each box is ``(rmin, rmax, cmin, cmax)``. Because ``get_segments_iou``
    divides by the span covering both segments, the product equals the
    intersection area divided by the area of the smallest box enclosing both
    inputs.
    """
    h_iou = get_segments_iou(b1[:2], b2[:2])
    w_iou = get_segments_iou(b1[2:4], b2[2:4])
    return h_iou * w_iou


def get_segments_iou(s1, s2):
    """Return overlap length over covering-span length for two segments.

    Args:
        s1: ``(start, end)`` with inclusive end points.
        s2: ``(start, end)`` with inclusive end points.

    Returns:
        Length of the overlap (0 when the segments are disjoint) divided by
        the length of the span from the smaller start to the larger end.
    """
    a, b = s1
    c, d = s2
    intersection = max(0, min(b, d) - max(a, c) + 1)
    # Floor the denominator so degenerate input cannot divide by zero.
    union = max(1e-6, max(b, d) - min(a, c) + 1)
    return intersection / union


def get_labels_with_sizes(x):
    """Return the non-zero labels present in ``x`` and their element counts.

    Args:
        x: Array of non-negative integers (required by ``np.bincount``).

    Returns:
        ``(labels, sizes)``: ascending list of the values other than 0 that
        occur in ``x``, and a parallel list with the number of occurrences of
        each.
    """
    obj_sizes = np.bincount(x.flatten())
    labels = np.nonzero(obj_sizes)[0]
    labels = labels[labels != 0]  # label 0 is left out of the result
    return labels.tolist(), obj_sizes[labels].tolist()