PolyFormer / data /refcoco_dataset.py
jiang
init commit
650c5f6
raw
history blame
12.1 kB
# ------------------------------------------------------------------------
# Modified from OFA (https://github.com/OFA-Sys/OFA)
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.
# ------------------------------------------------------------------------
# Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from io import BytesIO
import logging
import warnings
import numpy as np
import torch
import base64
import utils.transforms as T
import math
from PIL import Image, ImageFile
from data import data_utils
from data.base_dataset import BaseDataset
from bert.tokenization_bert import BertTokenizer
from data.poly_utils import string_to_polygons, downsample_polygons, polygons_to_string, points_to_token_string
import cv2
# Allow PIL to decode image files whose data stream is truncated instead of
# raising while loading.
ImageFile.LOAD_TRUNCATED_IMAGES = True
# Remove PIL's pixel-count limit so very large images can be opened.
# NOTE(review): Pillow documents this limit as `Image.MAX_IMAGE_PIXELS`; the
# assignment on `ImageFile` presumably has no effect on its own -- confirm.
ImageFile.MAX_IMAGE_PIXELS = None
Image.MAX_IMAGE_PIXELS = None
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Ignore UserWarnings whose message starts with "corrupt EXIF data" or
# "Possibly corrupt EXIF data".
warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
# Per-channel (3-value) mean / std used for image normalization when
# `imagenet_default_mean_and_std` is enabled on the dataset.
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
class RefcocoDataset(BaseDataset):
    """Dataset producing image / prompt / box-and-polygon samples.

    Each raw record of ``dataset`` is a tuple of either 7 fields (treated as a
    training record: ``uniq_id, base64 image, base64 mask, text, original
    polygon string, region string, interpolated polygon string``) or 6 fields
    (treated as an evaluation record: the same without the interpolated
    polygon string).

    ``__getitem__`` turns one record into a dict of tensors and
    ``collate`` / ``collater`` merge a list of such dicts into a mini-batch.
    """

    def __init__(
        self,
        split,
        dataset,
        bpe,
        src_dict,
        tgt_dict=None,
        max_src_length=80,
        max_tgt_length=30,
        patch_image_size=512,
        imagenet_default_mean_and_std=False,
        num_bins=1000,
        max_image_size=512
    ):
        """Store the configuration and build the image transform and tokenizer.

        Args:
            split, dataset, bpe, src_dict, tgt_dict: forwarded unchanged to
                ``BaseDataset.__init__``.
            max_src_length: passed to ``self.pre_caption`` when building the
                prompt caption.
            max_tgt_length: stored on the instance; not read elsewhere in
                this class.
            patch_image_size: size given to the resize transform and side
                length the segmentation label is resized to.
            imagenet_default_mean_and_std: if true, normalize with
                ``IMAGENET_DEFAULT_MEAN`` / ``IMAGENET_DEFAULT_STD``,
                otherwise with 0.5 for every channel.
            num_bins: number of quantization bins; normalized coordinates are
                multiplied by ``num_bins - 1``.
            max_image_size: forwarded to ``T.Normalize``.
        """
        super().__init__(split, dataset, bpe, src_dict, tgt_dict)
        self.max_src_length = max_src_length
        self.max_tgt_length = max_tgt_length
        self.patch_image_size = patch_image_size
        self.num_bins = num_bins

        if imagenet_default_mean_and_std:
            mean = IMAGENET_DEFAULT_MEAN
            std = IMAGENET_DEFAULT_STD
        else:
            mean = [0.5, 0.5, 0.5]
            std = [0.5, 0.5, 0.5]

        # for positioning
        self.positioning_transform = T.Compose([
            T.RandomResize([patch_image_size], max_size=patch_image_size),
            T.ToTensor(),
            T.Normalize(mean=mean, std=std, max_image_size=max_image_size)
        ])
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    @staticmethod
    def _quantize_points(points, round_x, round_y):
        """Round every ``(x, y)`` of ``points`` with the given rounding callables.

        Returns a list of ``[round_x(x), round_y(y)]`` pairs, one per point.
        """
        return [[round_x(p[0]), round_y(p[1])] for p in points]

    @staticmethod
    def _interpolation_deltas(quant_box, quant_poly, axis):
        """Return the fractional parts of the quantized coordinates along ``axis``.

        The sequence is laid out as: one ``0`` for the bos token, the two box
        points, then every polygon's points with a single ``0`` between
        consecutive polygons for the separator token (none after the last).

        Separators are inserted *between* polygons rather than appended and
        trimmed afterwards, so a sample without polygons keeps both box
        deltas instead of losing the last one.

        Args:
            quant_box: iterable of box points in bin coordinates.
            quant_poly: list of polygons, each an iterable of points in bin
                coordinates.
            axis: 0 for x, 1 for y.
        """
        deltas = [0]  # [0] for bos token
        deltas.extend(p[axis] - math.floor(p[axis]) for p in quant_box)
        for poly_idx, polygon in enumerate(quant_poly):
            if poly_idx > 0:
                deltas.append(0)  # for separator token
            deltas.extend(pt[axis] - math.floor(pt[axis]) for pt in polygon)
        return torch.tensor(deltas)

    def __getitem__(self, index):
        """Build the training / evaluation example for record ``index``.

        Returns a dict with the prompt string, the transformed image, the
        continuous target points, the four floor/ceil-quantized decoder input
        sequences with their bilinear interpolation weights, and bookkeeping
        values (original size, resize ratios, resized label mask, token
        types, number of polygons, caption).
        """
        data = self.dataset[index]
        if len(data) == 7:
            uniq_id, base64_str, seg64_str, text, poly_original, region_coord, poly_interpolated = data
            train = True
        else:
            uniq_id, base64_str, seg64_str, text, poly, region_coord = data
            train = False

        # load image and segmentation labels
        image = Image.open(BytesIO(base64.urlsafe_b64decode(base64_str))).convert("RGB")
        label = np.asarray(Image.open(BytesIO(base64.urlsafe_b64decode(seg64_str))))
        # Nearest-neighbour keeps the label values unchanged while resizing.
        label = cv2.resize(
            label,
            (self.patch_image_size, self.patch_image_size),
            interpolation=cv2.INTER_NEAREST,
        )

        w, h = image.size
        patch_image = self.positioning_transform(image, target=None)
        resize_h = self.patch_image_size
        resize_w = self.patch_image_size
        patch_mask = torch.tensor([True])

        if train:
            # With probability 0.5 use a randomly down-sampled version of the
            # interpolated polygon, otherwise the original polygon string.
            prob = np.random.uniform()
            if prob < 0.5:
                polygons_interpolated = string_to_polygons(poly_interpolated)
                ds_rate = np.random.randint(25, 41)
                polygons_augmented = downsample_polygons(polygons_interpolated, ds_rate)
                poly = polygons_to_string(polygons_augmented)
            else:
                poly = poly_original

        # Divide polygon coordinates by the image width / height and reshape
        # each flat polygon into (n_point, 2).
        polygons = string_to_polygons(poly)
        polygons_scaled = []
        for polygon in polygons:
            n_point = len(polygon) // 2
            scale = np.tile([w, h], n_point)
            polygons_scaled.append((polygon / scale).reshape(n_point, 2))

        # region_coord is a "x0,y0,x1,y1" string in original-image pixels.
        x0, y0, x1, y1 = region_coord.strip().split(',')
        region = np.array([float(x0), float(y0), float(x1), float(y1)])
        region_points = region / np.array([w, h, w, h])  # scaled to [0,1]
        region_points = torch.tensor(region_points.reshape(2, 2))

        # Map normalized coordinates to bin coordinates, then build the four
        # floor/ceil combinations. Suffix "ab": a describes x, b describes y,
        # with 1 = floor and 2 = ceil.
        quant_box = region_points * (self.num_bins - 1)
        quant_poly = [poly_pts * (self.num_bins - 1) for poly_pts in polygons_scaled]
        roundings = {
            "11": (math.floor, math.floor),
            "21": (math.ceil, math.floor),
            "12": (math.floor, math.ceil),
            "22": (math.ceil, math.ceil),
        }
        region_coords = {}
        token_type = None
        for suffix, (round_x, round_y) in roundings.items():
            box_q = self._quantize_points(quant_box, round_x, round_y)
            poly_q = [self._quantize_points(pts, round_x, round_y) for pts in quant_poly]
            # Only the token types of the last ("22") call are kept.
            region_coords[suffix], token_type = points_to_token_string(box_q, poly_q)

        # compute bilinear interpolation coefficient
        delta_x1 = self._interpolation_deltas(quant_box, quant_poly, 0)
        delta_x2 = 1 - delta_x1
        delta_y1 = self._interpolation_deltas(quant_box, quant_poly, 1)
        delta_y2 = 1 - delta_y1

        token_type.append(2)  # 2 for eos token

        src_caption = self.pre_caption(text, self.max_src_length)
        prompt = ' which region does the text " {} " describe?'.format(src_caption)

        # tgt for input
        tgt_item11 = self.encode_text(region_coords["11"], use_bpe=False)
        tgt_item12 = self.encode_text(region_coords["12"], use_bpe=False)
        tgt_item21 = self.encode_text(region_coords["21"], use_bpe=False)
        tgt_item22 = self.encode_text(region_coords["22"], use_bpe=False)

        # tgt for output
        target_item = region_points
        for poly_pts in polygons_scaled:
            # [0, 0] is padding token for separator and eos
            target_item = torch.cat(
                [target_item, torch.tensor(poly_pts), torch.tensor([[0, 0]])], dim=0
            )

        prev_output_item11 = torch.cat([self.bos_item, tgt_item11])
        prev_output_item12 = torch.cat([self.bos_item, tgt_item12])
        prev_output_item21 = torch.cat([self.bos_item, tgt_item21])
        prev_output_item22 = torch.cat([self.bos_item, tgt_item22])

        example = {
            "id": uniq_id,
            "source": prompt,
            "patch_image": patch_image,
            "patch_mask": patch_mask,
            "target": target_item,
            "prev_output_tokens_11": prev_output_item11,
            "prev_output_tokens_12": prev_output_item12,
            "prev_output_tokens_21": prev_output_item21,
            "prev_output_tokens_22": prev_output_item22,
            "delta_x1": delta_x1,
            "delta_y1": delta_y1,
            "delta_x2": delta_x2,
            "delta_y2": delta_y2,
            "w_resize_ratio": torch.tensor(resize_w / w),
            "h_resize_ratio": torch.tensor(resize_h / h),
            "region_coord": torch.tensor(region),
            "token_type": torch.tensor(token_type),
            "w": torch.tensor(w),
            "h": torch.tensor(h),
            "label": label,
            "n_poly": len(polygons),
            "text": src_caption
        }
        return example

    def collate(self, samples, pad_idx, eos_idx):
        """Merge examples produced by ``__getitem__`` into one batch dict.

        Args:
            samples: list of example dicts; an empty list yields ``{}``.
            pad_idx: padding value for the target and decoder input sequences.
            eos_idx: forwarded to ``data_utils.collate_tokens``.

        Returns:
            dict with ``id``, ``nsentences``, ``ntokens``, a ``net_input``
            sub-dict holding the model inputs, and the remaining per-sample
            fields stacked or listed per key.
        """
        if len(samples) == 0:
            return {}

        def merge(key, padding_item):
            # Pad the per-sample sequences stored under `key` to a common length.
            return data_utils.collate_tokens(
                [s[key] for s in samples],
                padding_item,
                eos_idx=eos_idx,
            )

        ids = np.array([s["id"] for s in samples])

        # Prompts are tokenized here, padded to the longest one in the batch.
        captions = [s["source"] for s in samples]
        tokenized = self.tokenizer.batch_encode_plus(captions, padding="longest", return_tensors="pt")
        src_tokens = tokenized["input_ids"]
        att_masks = tokenized["attention_mask"]
        # NOTE(review): `.sum()` has no `dim`, so this is the number of attended
        # tokens over the whole batch (a single value), not one length per
        # sample. Kept unchanged because downstream consumers are not visible
        # here -- confirm whether `.sum(dim=1)` was intended.
        src_lengths = torch.LongTensor(att_masks.ne(0).long().sum())

        patch_images = torch.stack([sample['patch_image'] for sample in samples], dim=0)
        patch_masks = torch.cat([sample['patch_mask'] for sample in samples])

        w_resize_ratios = torch.stack([s["w_resize_ratio"] for s in samples], dim=0)
        h_resize_ratios = torch.stack([s["h_resize_ratio"] for s in samples], dim=0)

        # delta_*2 equals 1 - delta_*1 per sample, hence the 0 / 1 padding values.
        delta_x1 = merge("delta_x1", 0)
        delta_y1 = merge("delta_y1", 0)
        delta_x2 = merge("delta_x2", 1)
        delta_y2 = merge("delta_y2", 1)

        region_coords = torch.stack([s['region_coord'] for s in samples], dim=0)

        target = merge("target", pad_idx)
        tgt_lengths = torch.LongTensor([s["target"].shape[0] for s in samples])
        ntokens = tgt_lengths.sum().item()

        prev_output_tokens_11 = merge("prev_output_tokens_11", pad_idx)
        prev_output_tokens_12 = merge("prev_output_tokens_12", pad_idx)
        prev_output_tokens_21 = merge("prev_output_tokens_21", pad_idx)
        prev_output_tokens_22 = merge("prev_output_tokens_22", pad_idx)

        token_type = merge("token_type", -1)

        w = torch.stack([s["w"] for s in samples], dim=0)
        h = torch.stack([s["h"] for s in samples], dim=0)
        n_poly = [s['n_poly'] for s in samples]
        labels = np.stack([sample['label'] for sample in samples], 0)
        text = [s["text"] for s in samples]

        batch = {
            "id": ids,
            "nsentences": len(samples),
            "ntokens": ntokens,
            "net_input": {
                "src_tokens": src_tokens,
                "src_lengths": src_lengths,
                "att_masks": att_masks,
                "patch_images": patch_images,
                "patch_masks": patch_masks,
                "prev_output_tokens_11": prev_output_tokens_11,
                "prev_output_tokens_12": prev_output_tokens_12,
                "prev_output_tokens_21": prev_output_tokens_21,
                "prev_output_tokens_22": prev_output_tokens_22,
                "delta_x1": delta_x1,
                "delta_y1": delta_y1,
                "delta_x2": delta_x2,
                "delta_y2": delta_y2
            },
            "target": target,
            "w_resize_ratios": w_resize_ratios,
            "h_resize_ratios": h_resize_ratios,
            "region_coords": region_coords,
            "label": labels,
            "token_type": token_type,
            "w": w,
            "h": h,
            "n_poly": n_poly,
            "text": text
        }
        return batch

    def collater(self, samples, pad_to_length=None):
        """Merge a list of samples to form a mini-batch.

        Delegates to ``collate`` with this dataset's ``pad`` and ``eos``
        values; ``pad_to_length`` is accepted but not used.

        Args:
            samples (List[dict]): samples to collate
        Returns:
            dict: a mini-batch containing the data of the task
        """
        return self.collate(samples, pad_idx=self.pad, eos_idx=self.eos)