from torch.utils.data import Dataset
from PIL import Image
import torch
import json
import h5py
import bisect
CAPTION_LENGTH = 25
SIMPLE_PREFIX = "This image shows "
def prep_strings(text, tokenizer, template=None, retrieved_caps=None, k=None, is_test=False, max_length=None):
    """Build decoder input ids and label ids for one caption.

    At test time only the prompt (prefix) is returned; at training time the
    target caption is appended and the labels mask the prefix with -100.
    """
    if is_test:
        padding = False
        truncation = False
    else:
        padding = True
        truncation = True

    if retrieved_caps is not None:
        infix = '\n\n'.join(retrieved_caps[:k]) + '.'
        prefix = template.replace('||', infix)
    else:
        prefix = SIMPLE_PREFIX

    prefix_ids = tokenizer.encode(prefix)
    len_prefix = len(prefix_ids)

    text_ids = tokenizer.encode(text, add_special_tokens=False)
    if truncation:
        text_ids = text_ids[:CAPTION_LENGTH]
    input_ids = prefix_ids + text_ids if not is_test else prefix_ids

    # mask the prefix in the labels (minus one because the first subtoken of
    # the prefix is only conditioned on, never predicted)
    label_ids = [-100] * (len_prefix - 1) + text_ids + [tokenizer.eos_token_id]
    if padding:
        input_ids += [tokenizer.pad_token_id] * (max_length - len(input_ids))
        label_ids += [-100] * (max_length - len(label_ids))

    if is_test:
        return input_ids
    else:
        return input_ids, label_ids
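# Minimal usage sketch (assumes a GPT-2 tokenizer; the model name and values
# below are illustrative, not fixed by this module):
#   from transformers import AutoTokenizer
#   tok = AutoTokenizer.from_pretrained('gpt2')
#   tok.pad_token = tok.eos_token  # GPT-2 ships without a pad token
#   ids, labels = prep_strings('a dog runs on the beach', tok, max_length=40)
#   assert len(ids) == len(labels) == 40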
def postprocess_preds(pred, tokenizer):
    """Strip the prompt and any special tokens from a decoded prediction."""
    pred = pred.split(SIMPLE_PREFIX)[-1]
    pred = pred.replace(tokenizer.pad_token, '')
    if pred.startswith(tokenizer.bos_token):
        pred = pred[len(tokenizer.bos_token):]
    if pred.endswith(tokenizer.eos_token):
        pred = pred[:-len(tokenizer.eos_token)]
    return pred
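# E.g. (illustrative, assuming GPT-2 where pad/eos are both '<|endoftext|>'):
#   postprocess_preds('This image shows a dog.<|endoftext|>', tok)  ->  'a dog.'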
class TrainDataset(Dataset):
    def __init__(self, df, features_path, tokenizer, rag=False, template_path=None, k=None, max_caption_length=25):
        self.df = df
        self.tokenizer = tokenizer
        self.features = h5py.File(features_path, 'r')

        if rag:
            assert k is not None
            self.template = open(template_path).read().strip() + ' '
            self.max_target_length = (max_caption_length  # target caption
                                      + max_caption_length * k  # retrieved captions
                                      + len(tokenizer.encode(self.template))  # template
                                      + len(tokenizer.encode('\n\n')) * (k - 1)  # separator between captions
                                      )
            self.k = k
        else:
            # without retrieval the target is just the caption itself
            self.max_target_length = max_caption_length
        self.rag = rag
    def __len__(self):
        return len(self.df)
    def __getitem__(self, idx):
        text = self.df['text'][idx]
        if self.rag:
            caps = self.df['caps'][idx]
            decoder_input_ids, labels = prep_strings(text, self.tokenizer, template=self.template,
                                                     retrieved_caps=caps, k=self.k, max_length=self.max_target_length)
        else:
            decoder_input_ids, labels = prep_strings(text, self.tokenizer, max_length=self.max_target_length)
        # load precomputed image features
        encoder_outputs = self.features[self.df['cocoid'][idx]][()]
        encoding = {"encoder_outputs": torch.tensor(encoder_outputs),
                    "decoder_input_ids": torch.tensor(decoder_input_ids),
                    "labels": torch.tensor(labels)}
        return encoding
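# Usage sketch (paths, feature file, and template file are assumptions; the
# dataset expects the lists returned by load_data_for_training, defined below,
# wrapped in a pandas DataFrame):
#   import pandas as pd
#   from torch.utils.data import DataLoader
#   data = load_data_for_training('data/dataset_coco.json', caps_path='data/retrieved_caps.json')
#   train_df = pd.DataFrame(data['train'])
#   train_set = TrainDataset(train_df, 'features/train.hdf5', tok,
#                            rag=True, template_path='src/template.txt', k=4)
#   loader = DataLoader(train_set, batch_size=8)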
def load_data_for_training(annot_path, caps_path=None):
    annotations = json.load(open(annot_path))['images']
    if caps_path is not None:
        retrieved_caps = json.load(open(caps_path))
    data = {'train': [], 'val': []}

    for item in annotations:
        file_name = item['filename'].split('_')[-1]
        if caps_path is not None:
            caps = retrieved_caps[str(item['cocoid'])]
        else:
            caps = None
        samples = []
        for sentence in item['sentences']:
            samples.append({'file_name': file_name, 'cocoid': str(item['cocoid']),
                            'caps': caps, 'text': ' '.join(sentence['tokens'])})
        if item['split'] == 'train' or item['split'] == 'restval':
            data['train'] += samples
        elif item['split'] == 'val':
            data['val'] += samples
    return data
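# The annotation file is expected to follow the Karpathy-split COCO format
# (fields actually read above; the concrete values are illustrative):
#   {"images": [{"filename": "COCO_val2014_000000391895.jpg", "cocoid": 391895,
#                "split": "train",
#                "sentences": [{"tokens": ["a", "man", "riding", ...]}, ...]},
#               ...]}
# and caps_path maps str(cocoid) -> list of retrieved caption strings.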
def load_data_for_inference(annot_path, caps_path=None):
    annotations = json.load(open(annot_path))['images']
    if caps_path is not None:
        retrieved_caps = json.load(open(caps_path))
    data = {'test': [], 'val': []}

    for item in annotations:
        file_name = item['filename'].split('_')[-1]
        if caps_path is not None:
            caps = retrieved_caps[str(item['cocoid'])]
        else:
            caps = None
        image = {'file_name': file_name, 'caps': caps, 'image_id': str(item['cocoid'])}
        if item['split'] == 'test':
            data['test'].append(image)
        elif item['split'] == 'val':
            data['val'].append(image)
    return data
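# Returned structure (illustrative values):
#   {'test': [{'file_name': '000000391895.jpg', 'caps': [...], 'image_id': '391895'}, ...],
#    'val':  [...]}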