import json
import os
import random
from collections import defaultdict
from copy import deepcopy
from io import BytesIO

import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image, ImageDraw

from .base_dataset import BaseDataset, check_filenames_in_zipdata, recalculate_box_and_verify_if_valid

def check_unique(images, fields):
    # Assert that every entry has a distinct value for each of the given fields.
    for field in fields:
        temp_list = []
        for img_info in images:
            temp_list.append(img_info[field])
        assert len(set(temp_list)) == len(temp_list), field
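
# A minimal sketch of the intended behavior, with hypothetical inputs:
#   check_unique([{"id": 1}, {"id": 2}], ["id"])   # passes silently
#   check_unique([{"id": 1}, {"id": 1}], ["id"])   # raises AssertionError: id
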
def clean_data(data):
    for data_info in data:
        data_info.pop("original_img_id", None)
        data_info.pop("original_id", None)
        data_info.pop("sentence_id", None)  # sentence id within an image (one image can have multiple sentences)
        data_info.pop("dataset_name", None)
        data_info.pop("data_source", None)
        data_info["data_id"] = data_info.pop("id")
def clean_annotations(annotations):
    for anno_info in annotations:
        anno_info.pop("iscrowd", None)  # verified to be 0 for all of flickr, vg, and coco
        anno_info.pop("category_id", None)  # verified to be 1 for flickr and vg; not always 1 for coco, but this annotation is not needed here
        anno_info.pop("area", None)
        # anno_info.pop("id", None)
        anno_info["data_id"] = anno_info.pop("image_id")
def draw_box(img, boxes):
    draw = ImageDraw.Draw(img)
    for box in boxes:
        draw.rectangle([box[0], box[1], box[2], box[3]], outline="red", width=2)  # x0 y0 x1 y1
    return img
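
# Hedged usage sketch for visual debugging (the file paths are hypothetical):
#   img = Image.open("example.jpg").convert("RGB")
#   img = draw_box(img, [[10, 10, 100, 120]])  # boxes in pixel xyxy coordinates
#   img.save("example_with_boxes.jpg")
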
def xywh2xyxy(box):
    # Convert a COCO-style [x0, y0, w, h] box to corner form [x0, y0, x1, y1].
    x0, y0, w, h = box
    return [x0, y0, x0 + w, y0 + h]
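
# Worked example: a COCO-style box [x0=10, y0=20, w=30, h=40]
# maps to the corner form [10, 20, 40, 60].
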
class GroundingDataset(BaseDataset):
    def __init__(self,
                 image_root,
                 json_path,
                 annotation_embedding_path,
                 prob_real_caption=1,
                 image_size=256,
                 min_box_size=0.01,
                 max_boxes_per_data=8,
                 max_images=None,  # e.g. set to 30K for evaluation
                 random_crop=False,
                 random_flip=True,
                 ):
        super().__init__(image_root, random_crop, random_flip, image_size)
        self.image_root = image_root
        self.json_path = json_path
        self.annotation_embedding_path = annotation_embedding_path
        self.prob_real_caption = prob_real_caption
        self.min_box_size = min_box_size
        self.max_boxes_per_data = max_boxes_per_data
        self.max_images = max_images
        # Load raw data
        with open(json_path, 'r') as f:
            json_raw = json.load(f)  # keys: 'info', 'images', 'licenses', 'categories', 'annotations'
        self.data = json_raw["images"]  # do not name this 'images': each entry is one caption for an image, so that name would be misleading
        self.annotations = json_raw["annotations"]
# Load preprocessed name embedding | |
if 'bert' in annotation_embedding_path: | |
self.embedding_len = 1280 | |
elif 'clip' in annotation_embedding_path: | |
self.embedding_len = 768 | |
else: | |
assert False | |
        # Clean data and annotations
        check_unique(self.data, ['id'])
        check_unique(self.annotations, ['id'])
        clean_data(self.data)
        clean_annotations(self.annotations)
        self.data_id_list = [datum['data_id'] for datum in self.data]
        self.data = {datum['data_id']: datum for datum in self.data}  # turn self.data from a list into a dict keyed by data_id

        # Map each data point to its annotations
        self.data_id_to_annos = defaultdict(list)
        for anno in self.annotations:
            self.data_id_to_annos[anno["data_id"]].append(anno)
        # These are not used that often, but are useful in some cases
        self.file_names = []  # all training images
        self.file_name_to_data_ids = defaultdict(list)  # each image can have multiple data points (captions)
        for data_id in self.data_id_list:
            file_name = self.data[data_id]["file_name"]
            self.file_names.append(file_name)
            self.file_name_to_data_ids[file_name].append(data_id)
        self.file_names = list(set(self.file_names))
        if self.max_images is not None:
            # Only used for COCO2017P evaluation, where max_images is set to 30K
            assert False, 'The following code has been commented out to save CPU memory'
            # new_data_id_list = []
            # new_file_name_to_data_ids = defaultdict(list)
            # self.file_names = self.file_names[0:self.max_images]
            # for file_name in self.file_names:
            #     data_id = self.file_name_to_data_ids[file_name][0]
            #     new_data_id_list.append(data_id)
            #     new_file_name_to_data_ids[file_name].append(data_id)
            # self.data_id_list = new_data_id_list
            # self.file_name_to_data_ids = new_file_name_to_data_ids
        # Check that all filenames can be found in the zip file
        # all_filenames = [self.data[idx]['file_name'] for idx in self.data_id_list]
        # check_filenames_in_zipdata(all_filenames, image_root)
    def total_images(self):
        return len(self.file_names)
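
    # Note: total_images() counts distinct image files, while __len__ (below)
    # counts (image, caption) data points; one image can appear under several
    # captions, so len(dataset) >= dataset.total_images().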
    def __getitem__(self, index):
        if self.max_boxes_per_data > 99:
            assert False, "Are you sure you want to allow this many boxes per data point?"
        out = {}

        data_id = self.data_id_list[index]
        out['id'] = data_id

        # Image and caption
        file_name = self.data[data_id]['file_name']
        image = self.fetch_image(file_name)
        image_tensor, trans_info = self.transform_image(image)
        out["image"] = image_tensor

        # With probability prob_real_caption use the real caption, otherwise an empty one
        if random.uniform(0, 1) < self.prob_real_caption:
            out["caption"] = self.data[data_id]["caption"]
        else:
            out["caption"] = ""
        # Collect all boxes that remain valid after the image transform,
        # together with their masks and precomputed per-annotation embeddings
        annos = deepcopy(self.data_id_to_annos[data_id])
        areas = []
        all_boxes = []
        all_masks = []
        all_positive_embeddings = []
        for anno in annos:
            x, y, w, h = anno['bbox']
            valid, (x0, y0, x1, y1) = recalculate_box_and_verify_if_valid(x, y, w, h, trans_info, self.image_size, self.min_box_size)
            if valid:
                areas.append((x1 - x0) * (y1 - y0))
                all_boxes.append(torch.tensor([x0, y0, x1, y1]) / self.image_size)  # scale to 0-1
                all_masks.append(1)
                all_positive_embeddings.append(torch.load(os.path.join(self.annotation_embedding_path, str(anno["id"])), map_location='cpu'))
        # Keep at most max_boxes_per_data boxes, preferring the largest ones
        wanted_idxs = torch.tensor(areas).sort(descending=True)[1]
        wanted_idxs = wanted_idxs[0:self.max_boxes_per_data]

        # Pad boxes, masks, and embeddings to a fixed size
        boxes = torch.zeros(self.max_boxes_per_data, 4)
        masks = torch.zeros(self.max_boxes_per_data)
        positive_embeddings = torch.zeros(self.max_boxes_per_data, self.embedding_len)
        for i, idx in enumerate(wanted_idxs):
            boxes[i] = all_boxes[idx]
            masks[i] = all_masks[idx]
            positive_embeddings[i] = all_positive_embeddings[idx]

        out["boxes"] = boxes
        out["masks"] = masks
        out["positive_embeddings"] = positive_embeddings
        return out
    def __len__(self):
        return len(self.data_id_list)
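
if __name__ == "__main__":
    # Minimal smoke-test sketch. The paths below are placeholders (assumptions,
    # not part of the original file): image_root must hold the images, json_path
    # a COCO-style grounding json, and annotation_embedding_path a directory with
    # one saved tensor per annotation id ('clip' in the path selects the 768-dim
    # branch above). Because this module uses a relative import, run it as a
    # module, e.g. `python -m dataset.grounding_dataset` (module path assumed).
    dataset = GroundingDataset(
        image_root="DATA/images",
        json_path="DATA/grounding.json",
        annotation_embedding_path="DATA/clip_embeddings",
    )
    print(len(dataset), dataset.total_images())
    sample = dataset[0]
    print(sample["image"].shape, sample["boxes"].shape, sample["caption"])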