import os
import cv2
import torch
import albumentations as A

import config as CFG


class PoemTextDataset(torch.utils.data.Dataset):
    """
    torch Dataset for PoemTextModel.
    ...
    Attributes:
    -----------
    dataset_dict : list of dict
        dataset containing poem-text pairs with ids
    encoded_poems : dict
        output of the poem tokenizer for the beyts found in dataset_dict. max_length is specified in configs;
        padding and truncation are set to True so inputs are padded/truncated to max_length.
    encoded_texts : dict
        output of the text tokenizer for the texts found in dataset_dict. max_length is specified in configs;
        padding and truncation are set to True so inputs are padded/truncated to max_length.

    Methods:
    --------
    __getitem__(idx)
        returns the item with index idx.
    __len__()
        returns the length of the dataset.
    """

    def __init__(self, dataset_dict):
        """
        Init class, save dataset_dict and compute the tokenizer output for each text and poem using their corresponding tokenizers.
        The tokenizers are chosen based on configs.

        Parameters:
        -----------
        dataset_dict: list of dict
            a list containing dictionaries which have "beyt", "text" and "id" keys.
        """
        self.dataset_dict = dataset_dict
        # tokenizers are chosen based on the encoder models specified in configs
        poem_tokenizer = CFG.tokenizers[CFG.poem_encoder_model].from_pretrained(CFG.poem_tokenizer)
        text_tokenizer = CFG.tokenizers[CFG.text_encoder_model].from_pretrained(CFG.text_tokenizer)
        # tokenize all beyts and texts up front, padding/truncating to the max lengths set in configs
        self.encoded_poems = poem_tokenizer(
            [item['beyt'] for item in dataset_dict], padding=True, truncation=True, max_length=CFG.poems_max_length
        )
        self.encoded_texts = text_tokenizer(
            [item['text'] for item in dataset_dict], padding=True, truncation=True, max_length=CFG.text_max_length
        )

    def __getitem__(self, idx):
        """
        Returns a dict holding the data with index idx. The dict is used as an input to the PoemTextModel.

        Parameters:
        -----------
        idx: int
            index of the data to get

        Returns:
        --------
        item: dict
            a dict having the tokenizers' output for the poem and text, and the id of the data with index idx
        """
        item = {}
        item["beyt"] = {
            key: torch.tensor(values[idx])
            for key, values in self.encoded_poems.items()
        }
        item["text"] = {
            key: torch.tensor(values[idx])
            for key, values in self.encoded_texts.items()
        }
        item['id'] = self.dataset_dict[idx]['id']
        return item

    def __len__(self):
        """
        Returns the length of the dataset.

        Returns:
        --------
        length: int
            the length of the dataset_dict saved in the class
        """
        return len(self.dataset_dict)
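

# A minimal, hedged usage sketch for PoemTextDataset (not part of the original pipeline).
# It assumes CFG.batch_size exists in config.py and that `dataset_dict` entries follow the
# structure described in the class docstring above.
def _example_poem_text_loader(dataset_dict):
    """Illustrative only: wrap PoemTextDataset in a DataLoader and return the first batch."""
    dataset = PoemTextDataset(dataset_dict)
    loader = torch.utils.data.DataLoader(dataset, batch_size=CFG.batch_size, shuffle=True)
    batch = next(iter(loader))
    # batch["beyt"] and batch["text"] are dicts of stacked tensors, e.g.
    # batch["beyt"]["input_ids"] has shape (batch_size, padded_sequence_length);
    # batch["id"] holds the ids of the sampled pairs.
    return batch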


class CLIPDataset(torch.utils.data.Dataset):
    """
    torch Dataset for CLIPModel.
    ...
    Attributes:
    -----------
    dataset_dict : list of dict
        dataset containing poem-image or text-image pairs with ids
    encoded : dict
        output of the tokenizer for the beyts/texts found in dataset_dict. max_length is specified in configs;
        padding and truncation are set to True so inputs are padded/truncated to max_length.
    transforms : albumentations.BasicTransform
        transforms to apply to the images

    Methods:
    --------
    __getitem__(idx)
        returns the item with index idx.
    __len__()
        returns the length of the dataset.
    """

    def __init__(self, dataset_dict, transforms, is_image_poem_pair=True):
        """
        Init class, save dataset_dict and transforms, and compute the tokenizer output for each beyt/text using the corresponding tokenizer.
        The tokenizers are chosen based on configs.

        Parameters:
        -----------
        dataset_dict: list of dict
            a list containing dictionaries which have "beyt"/"text", "image" and "id" keys.
        transforms: albumentations.BasicTransform
            transforms to apply to the images
        is_image_poem_pair: bool, optional
            if set to False, the dataset has text-image pairs and must use the corresponding text tokenizer;
            else it has poem-image pairs and uses the poem tokenizer.
        """
        self.dataset_dict = dataset_dict
        # use the poem tokenizer to encode poems, or the text tokenizer to encode texts (based on configs)
        if is_image_poem_pair:
            poem_tokenizer = CFG.tokenizers[CFG.poem_encoder_model].from_pretrained(CFG.poem_tokenizer)
            self.encoded = poem_tokenizer(
                [item['beyt'] for item in dataset_dict], padding=True, truncation=True, max_length=CFG.poems_max_length
            )
        else:
            text_tokenizer = CFG.tokenizers[CFG.text_encoder_model].from_pretrained(CFG.text_tokenizer)
            self.encoded = text_tokenizer(
                [item['text'] for item in dataset_dict], padding=True, truncation=True, max_length=CFG.text_max_length
            )
        self.transforms = transforms

    def __getitem__(self, idx):
        """
        Returns a dict holding the data with index idx. The dict is used as an input to the CLIPModel.

        Parameters:
        -----------
        idx: int
            index of the data to get

        Returns:
        --------
        item: dict
            a dict having the tokenizer's output for the beyt/text (under "text") and the transformed image tensor (under "image") for the data with index idx
        """
        item = {}
        # getting the encoded beyt/text
        item["text"] = {
            key: torch.tensor(values[idx])
            for key, values in self.encoded.items()
        }
        # opening the image
        image = cv2.imread(f"{CFG.image_path}{self.dataset_dict[idx]['image']}")
        # converting BGR to RGB for transforms
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # apply transforms
        image = self.transforms(image=image)['image']
        # permute dims of image to C x H x W
        item['image'] = torch.tensor(image).permute(2, 0, 1).float()
        return item

    def __len__(self):
        """
        Returns the length of the dataset.

        Returns:
        --------
        length: int
            the length of the dataset_dict saved in the class
        """
        return len(self.dataset_dict)
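

# A minimal, hedged usage sketch for CLIPDataset (not part of the original pipeline).
# It assumes CFG.batch_size exists in config.py and that each dataset_dict entry has
# "beyt" (or "text"), "image" and "id" keys; the transforms can come from get_transforms() below.
def _example_clip_loader(dataset_dict, transforms):
    """Illustrative only: wrap CLIPDataset in a DataLoader and return the first batch."""
    dataset = CLIPDataset(dataset_dict, transforms, is_image_poem_pair=True)
    loader = torch.utils.data.DataLoader(dataset, batch_size=CFG.batch_size, shuffle=True)
    batch = next(iter(loader))
    # batch["image"] has shape (batch_size, 3, CFG.size, CFG.size);
    # batch["text"]["input_ids"] has shape (batch_size, padded_sequence_length).
    return batch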


def get_transforms(mode="train"):
    """
    Returns the transforms to apply to images, based on mode.

    Parameters:
    -----------
    mode: str, optional
        to distinguish between train and val/test transforms (here they are the same!)

    Returns:
    --------
    transforms: albumentations.Compose
        the composed transforms to apply to the images
    """
    if mode == "train":
        return A.Compose(
            [
                A.Resize(CFG.size, CFG.size, always_apply=True),  # resizing image to CFG.size
                A.Normalize(max_pixel_value=255.0, always_apply=True),  # normalizing image values
            ]
        )
    else:
        return A.Compose(
            [
                A.Resize(CFG.size, CFG.size, always_apply=True),  # resizing image to CFG.size
                A.Normalize(max_pixel_value=255.0, always_apply=True),  # normalizing image values
            ]
        )
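

# A minimal, hedged sketch of applying the transforms above to a dummy image (not part of
# the original pipeline). It assumes CFG.size is an int such as 224; numpy is imported
# locally because it is only needed for this illustration.
def _example_apply_transforms():
    """Illustrative only: run get_transforms() on a random RGB image and tensorize it."""
    import numpy as np

    transforms = get_transforms(mode="train")
    dummy = np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8)  # H x W x C, RGB
    transformed = transforms(image=dummy)['image']  # resized to (CFG.size, CFG.size, 3) and normalized
    # permute to C x H x W, the same layout CLIPDataset feeds to the model
    return torch.tensor(transformed).permute(2, 0, 1).float()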