import os
import json
import random

import torch
import ijson
import numpy as np
from PIL import Image
from torchvision.transforms import ToTensor
from torchvision.ops import box_convert, clip_boxes_to_image

from utils import progressbar


def collate_fn(batch):
    image = torch.stack([s['image'] for s in batch], dim=0)
    image_size = torch.FloatTensor([s['image_size'] for s in batch])

    # boxes are concatenated (not stacked) since a sample may hold several
    bbox = torch.cat([s['bbox'] for s in batch], dim=0)
    bbox_raw = torch.cat([s['bbox_raw'] for s in batch], dim=0)

    expr = [s['expr'] for s in batch]

    tok = None
    if batch[0]['tok'] is not None:
        # dynamic batching: trim padding down to the longest expression in the batch
        max_length = max(s['tok']['length'] for s in batch)
        tok = {
            'input_ids': torch.cat(
                [s['tok']['input_ids'] for s in batch], dim=0
            )[:, :max_length],
            'attention_mask': torch.cat(
                [s['tok']['attention_mask'] for s in batch], dim=0
            )[:, :max_length],
        }

    mask = None
    if batch[0]['mask'] is not None:
        mask = torch.stack([s['mask'] for s in batch], dim=0)

    mask_bbox = None
    if batch[0]['mask_bbox'] is not None:
        mask_bbox = torch.stack([s['mask_bbox'] for s in batch], dim=0)

    tr_param = [s['tr_param'] for s in batch]

    return {
        'image': image,
        'image_size': image_size,
        'bbox': bbox,
        'bbox_raw': bbox_raw,
        'expr': expr,
        'tok': tok,
        'tr_param': tr_param,
        'mask': mask,
        'mask_bbox': mask_bbox,
    }


class RECDataset(torch.utils.data.Dataset):
    def __init__(self, transform=None, tokenizer=None, max_length=32,
                 with_mask_bbox=False):
        super().__init__()
        self.samples = []  # list of samples: [(file_name, expression, bbox)]
        self.transform = transform
        self.tokenizer = tokenizer
        self.max_length = int(max_length)
        self.with_mask_bbox = bool(with_mask_bbox)

    def tokenize(self, inp, max_length):
        return self.tokenizer(
            inp,
            return_tensors='pt',
            padding='max_length',
            return_token_type_ids=False,
            return_attention_mask=True,
            add_special_tokens=True,
            truncation=True,
            max_length=max_length
        )

    def print_stats(self):
        print(f'{len(self.samples)} samples')
        lens = [len(expr.split()) for _, expr, _ in self.samples]
        print('expression lengths stats: '
              f'min={np.min(lens):.1f}, '
              f'mean={np.mean(lens):.1f}, '
              f'median={np.median(lens):.1f}, '
              f'max={np.max(lens):.1f}, '
              f'99.9P={np.percentile(lens, 99.9):.1f}')

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        file_name, expr, bbox = self.samples[idx]

        if not os.path.exists(file_name):
            raise IOError(f'{file_name} not found')
        img = Image.open(file_name).convert('RGB')

        # if isinstance(expr, (list, tuple)):
        #     expr = random.choice(expr)

        # image size as read from disk (PIL)
        W0, H0 = img.size

        # # ensure box coordinates fall inside the image
        # bbox = clip_boxes_to_image(bbox, (H0, W0))
        # assert torch.all(bbox[:, (0, 1)] <= bbox[:, (2, 3)])  # xyxy format

        sample = {
            'image': img,
            'image_size': (H0, W0),  # original image size
            'bbox': bbox.clone(),  # box transformations are in-place ops
            'bbox_raw': bbox.clone(),  # raw boxes w/o any transformation (in pixels)
            'expr': expr,
            'tok': None,
            'tr_param': None,  # set by transforms; collate_fn expects this key
            'mask': torch.ones((1, H0, W0), dtype=torch.float32),  # visibility mask
            'mask_bbox': None,  # target bbox mask
        }

        # apply transforms
        if self.transform is None:
            sample['image'] = ToTensor()(sample['image'])
        else:
            sample = self.transform(sample)

        # tokenize after the transformations (in case a transform applied a
        # left<->right flip that rewrites the expression)
        if self.tokenizer is not None:
            sample['tok'] = self.tokenize(sample['expr'], self.max_length)
            sample['tok']['length'] = sample['tok']['attention_mask'].sum(1).item()

        # bbox segmentation mask
        if self.with_mask_bbox:
            # image size after transforms
            _, H, W = sample['image'].size()

            # transformed bbox in pixels (boxes are normalized after transforms)
            bbox = sample['bbox'].clone()
            bbox[:, (0, 2)] *= W
            bbox[:, (1, 3)] *= H
            bbox = clip_boxes_to_image((bbox + 0.5).long(), (H, W))

            # output mask
            sample['mask_bbox'] = torch.zeros((1, H, W), dtype=torch.float32)
            for x1, y1, x2, y2 in bbox.tolist():
                sample['mask_bbox'][:, y1:y2+1, x1:x2+1] = 1.0

        return sample
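
# Transform contract, as inferred from __getitem__ above: a transform receives
# the whole sample dict and should (a) turn sample['image'] into a CxHxW
# tensor, (b) rescale sample['bbox'] to normalized xyxy coordinates in [0, 1]
# (the mask_bbox branch multiplies by the transformed width/height), and
# (c) record its parameters in sample['tr_param']. The function below is a
# minimal sketch of such a transform, not part of the original pipeline; the
# 512px square target size is an arbitrary choice.
def basic_transform(sample, size=512):
    import torchvision.transforms.functional as TF

    W0, H0 = sample['image'].size  # PIL size is (width, height)
    # resize image and visibility mask to a fixed square
    sample['image'] = TF.to_tensor(TF.resize(sample['image'], (size, size)))
    sample['mask'] = TF.resize(sample['mask'], (size, size))
    # normalize boxes to [0, 1] (in-place on the cloned tensor)
    sample['bbox'][:, (0, 2)] /= W0
    sample['bbox'][:, (1, 3)] /= H0
    # record the applied scaling so it can be undone downstream
    sample['tr_param'] = {'scale': (size / W0, size / H0)}
    return sample
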
class RegionDescriptionsVisualGnome(RECDataset):
    # region descriptions from the Visual Genome dataset
    def __init__(self, data_root, transform=None, tokenizer=None,
                 max_length=32, with_mask_bbox=False):
        super().__init__(transform=transform, tokenizer=tokenizer,
                         max_length=max_length, with_mask_bbox=with_mask_bbox)

        # if available, read the COCO IDs of the val, testA and testB splits
        # of the RefCOCO dataset so that those images can be excluded
        try:
            with open('./refcoco_valtest_ids.txt', 'r') as fh:
                refcoco_ids = [int(lin.strip()) for lin in fh.readlines()]
        except IOError:
            refcoco_ids = []

        def path_from_url(fname):
            return os.path.join(data_root, fname[fname.index('VG_100K'):])

        with open(os.path.join(data_root, 'image_data.json'), 'r') as f:
            image_data = {
                data['image_id']: path_from_url(data['url'])
                for data in json.load(f)
                if data['coco_id'] is None or data['coco_id'] not in refcoco_ids
            }
        print(f'{len(image_data)} images')

        self.samples = []
        with open(os.path.join(data_root, 'region_descriptions.json'), 'r') as f:
            # stream the JSON records instead of loading the file whole
            for record in progressbar(ijson.items(f, 'item.regions.item'),
                                      desc='loading data'):
                if record['image_id'] not in image_data:
                    continue
                file_name = image_data[record['image_id']]
                expr = record['phrase']
                bbox = [record['x'], record['y'], record['width'], record['height']]
                bbox = torch.atleast_2d(torch.FloatTensor(bbox))
                bbox = box_convert(bbox, 'xywh', 'xyxy')  # xyxy
                self.samples.append((file_name, expr, bbox))

        self.print_stats()


class ReferDataset(RECDataset):
    def __init__(self, data_root, dataset, split_by, split, transform=None,
                 tokenizer=None, max_length=32, with_mask_bbox=False):
        super().__init__(transform=transform, tokenizer=tokenizer,
                         max_length=max_length, with_mask_bbox=with_mask_bbox)

        # https://github.com/lichengunc/refer
        try:
            import sys
            sys.path.append('refer')
            from refer import REFER
        except ImportError:
            raise RuntimeError('create a symlink named "refer" to a compiled '
                               'copy of the refer toolkit '
                               '(see https://github.com/lichengunc/refer)')

        refer = REFER(data_root, dataset, split_by)

        ref_ids = sorted(refer.getRefIds(split=split))

        self.samples = []
        for rid in progressbar(ref_ids, desc='loading data'):
            ref = refer.Refs[rid]
            ann = refer.refToAnn[rid]

            file_name = refer.Imgs[ref['image_id']]['file_name']
            if dataset == 'refclef':
                file_name = os.path.join(
                    'refer', 'data', 'images', 'saiapr_tc-12', file_name
                )
            else:
                coco_set = file_name.split('_')[1]
                file_name = os.path.join(
                    'refer', 'data', 'images', 'mscoco', coco_set, file_name
                )

            bbox = ann['bbox']
            bbox = torch.atleast_2d(torch.FloatTensor(bbox))
            bbox = box_convert(bbox, 'xywh', 'xyxy')  # xyxy

            sentences = [s['sent'] for s in ref['sentences']]
            if 'train' in split:
                # remove repeated expressions
                sentences = sorted(set(sentences))

            self.samples += [(file_name, expr, bbox) for expr in sentences]

        self.print_stats()
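
# Expected on-disk layout, as implied by the path handling above (the refer
# toolkit and the image archives must be unpacked accordingly):
#
#   refer/data/images/mscoco/train2014/COCO_train2014_*.jpg   (RefCOCO/+/g)
#   refer/data/images/saiapr_tc-12/...                        (RefCLEF)
#
# The thin wrappers below just pin the dataset name, the split-by convention,
# and the valid split names for each benchmark.
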
class RefCLEF(ReferDataset):
    def __init__(self, *args, **kwargs):
        assert args[0] in ('train', 'val', 'test')
        super().__init__('refer/data', 'refclef', 'berkeley', *args, **kwargs)


class RefCOCO(ReferDataset):
    def __init__(self, *args, **kwargs):
        assert args[0] in ('train', 'val', 'trainval', 'testA', 'testB')
        super().__init__('refer/data', 'refcoco', 'unc', *args, **kwargs)


class RefCOCOp(ReferDataset):
    def __init__(self, *args, **kwargs):
        assert args[0] in ('train', 'val', 'trainval', 'testA', 'testB')
        super().__init__('refer/data', 'refcoco+', 'unc', *args, **kwargs)


class RefCOCOg(ReferDataset):
    def __init__(self, *args, **kwargs):
        assert args[0] in ('train', 'val', 'test')
        super().__init__('refer/data', 'refcocog', 'umd', *args, **kwargs)
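
# Example usage, kept as comments so the module stays import-free (a sketch;
# the tokenizer checkpoint is an assumption, since tokenize() follows the
# HuggingFace tokenizer call signature, any transformers tokenizer should
# work). Note that stacking images in collate_fn requires same-sized images,
# so pass a resizing transform (e.g. basic_transform above) for DataLoader use:
#
#   from torch.utils.data import DataLoader
#   from transformers import AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
#   dataset = RefCOCO('train', transform=basic_transform, tokenizer=tokenizer,
#                     max_length=32, with_mask_bbox=True)
#   loader = DataLoader(dataset, batch_size=16, shuffle=True,
#                       collate_fn=collate_fn)
#   batch = next(iter(loader))
#   print(batch['image'].shape, batch['tok']['input_ids'].shape)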