import os
import cv2
import torch
import albumentations as A
import config as CFG
class PoemTextDataset(torch.utils.data.Dataset):
    """
    torch Dataset for PoemTextModel.
    ...
    Attributes:
    -----------
    dataset_dict : list of dict
        dataset containing poem-text pairs with ids
    encoded_poems : dict
        output of the poem tokenizer for the beyts found in dataset_dict. max_length is specified in configs;
        with truncation=True sequences are cut to max_length, and padding=True pads them to the longest sequence.
    encoded_texts : dict
        output of the text tokenizer for the texts found in dataset_dict. max_length is specified in configs;
        with truncation=True sequences are cut to max_length, and padding=True pads them to the longest sequence.
    Methods:
    --------
    __getitem__(idx)
        returns the item with index idx.
    __len__()
        returns the length of the dataset.
    """

    def __init__(self, dataset_dict):
        """
        Initialize the class: save dataset_dict and precompute the tokenizers' output for every poem and text.
        The tokenizers are chosen based on configs.
        Parameters:
        -----------
        dataset_dict: list of dict
            a list containing dictionaries which have "beyt", "text" and "id" keys.
        """
        self.dataset_dict = dataset_dict
        poem_tokenizer = CFG.tokenizers[CFG.poem_encoder_model].from_pretrained(CFG.poem_tokenizer)
        text_tokenizer = CFG.tokenizers[CFG.text_encoder_model].from_pretrained(CFG.text_tokenizer)
        # tokenize all beyts and texts once up front; __getitem__ then only indexes into the results
        self.encoded_poems = poem_tokenizer(
            [item['beyt'] for item in dataset_dict], padding=True, truncation=True, max_length=CFG.poems_max_length
        )
        self.encoded_texts = text_tokenizer(
            [item['text'] for item in dataset_dict], padding=True, truncation=True, max_length=CFG.text_max_length
        )

    def __getitem__(self, idx):
        """
        Returns a dict holding the data with index idx, used as an input to the PoemTextModel.
        Parameters:
        -----------
        idx: int
            index of the data to get
        Returns:
        --------
        item: dict
            a dict having the tokenizers' output for the poem and text, and the id of the data with index idx
        """
        item = {}
        item["beyt"] = {
            key: torch.tensor(values[idx])
            for key, values in self.encoded_poems.items()
        }
        item["text"] = {
            key: torch.tensor(values[idx])
            for key, values in self.encoded_texts.items()
        }
        item['id'] = self.dataset_dict[idx]['id']
        return item

    def __len__(self):
        """
        Returns the length of the dataset.
        Returns:
        --------
        length: int
            the length of dataset_dict saved in the class
        """
        return len(self.dataset_dict)
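

# Illustrative usage sketch (not part of the original module): wrapping a
# PoemTextDataset in a DataLoader. The sample pairs below are made up, and this
# assumes config.py provides the tokenizer/encoder settings used above.
#
#   from torch.utils.data import DataLoader
#
#   pairs = [
#       {"beyt": "a sample beyt", "text": "a sample description", "id": 0},
#       {"beyt": "another beyt", "text": "another description", "id": 1},
#   ]
#   dataset = PoemTextDataset(pairs)
#   loader = DataLoader(dataset, batch_size=2, shuffle=True)
#   batch = next(iter(loader))
#   # the default collate stacks the nested dicts, so batch["beyt"]["input_ids"]
#   # and batch["text"]["input_ids"] have shape (batch_size, sequence_length)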


class CLIPDataset(torch.utils.data.Dataset):
    """
    torch Dataset for CLIPModel.
    ...
    Attributes:
    -----------
    dataset_dict : list of dict
        dataset containing poem-image or text-image pairs with ids
    encoded : dict
        output of the tokenizer for the beyts/texts found in dataset_dict. max_length is specified in configs;
        with truncation=True sequences are cut to max_length, and padding=True pads them to the longest sequence.
    transforms: albumentations.BasicTransform
        transforms to apply to the images
    Methods:
    --------
    __getitem__(idx)
        returns the item with index idx.
    __len__()
        returns the length of the dataset.
    """

    def __init__(self, dataset_dict, transforms, is_image_poem_pair=True):
        """
        Initialize the class: save dataset_dict and transforms, and precompute the tokenizer's output for every poem/text.
        The tokenizers are chosen based on configs.
        Parameters:
        -----------
        dataset_dict: list of dict
            a list containing dictionaries which have "beyt" or "text", "image" and "id" keys.
        transforms: albumentations.BasicTransform
            transforms to apply to the images
        is_image_poem_pair: bool, optional
            if set to False, the dataset has text-image pairs and must use the corresponding text tokenizer;
            else it has poem-image pairs and uses the poem tokenizer.
        """
        self.dataset_dict = dataset_dict
        # use the poem tokenizer to encode poems, or the text tokenizer to encode texts (tokenizers chosen based on configs)
        if is_image_poem_pair:
            poem_tokenizer = CFG.tokenizers[CFG.poem_encoder_model].from_pretrained(CFG.poem_tokenizer)
            self.encoded = poem_tokenizer(
                [item['beyt'] for item in dataset_dict], padding=True, truncation=True, max_length=CFG.poems_max_length
            )
        else:
            text_tokenizer = CFG.tokenizers[CFG.text_encoder_model].from_pretrained(CFG.text_tokenizer)
            self.encoded = text_tokenizer(
                [item['text'] for item in dataset_dict], padding=True, truncation=True, max_length=CFG.text_max_length
            )
        self.transforms = transforms

    def __getitem__(self, idx):
        """
        Returns a dict holding the data with index idx, used as an input to the CLIPModel.
        Parameters:
        -----------
        idx: int
            index of the data to get
        Returns:
        --------
        item: dict
            a dict having the tokenizer's output for the poem/text and the transformed image of the data with index idx
        """
        item = {}
        # getting the tokenized poem/text from the precomputed encodings
        item["text"] = {
            key: torch.tensor(values[idx])
            for key, values in self.encoded.items()
        }
        # opening the image
        image = cv2.imread(f"{CFG.image_path}{self.dataset_dict[idx]['image']}")
        # converting BGR (OpenCV's default channel order) to RGB for the transforms
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # applying the transforms (resize + normalize)
        image = self.transforms(image=image)['image']
        # permuting dims from HWC to the CHW layout torch models expect
        item['image'] = torch.tensor(image).permute(2, 0, 1).float()
        return item

    def __len__(self):
        """
        Returns the length of the dataset.
        Returns:
        --------
        length: int
            the length of dataset_dict saved in the class
        """
        return len(self.dataset_dict)
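

# Illustrative usage sketch (not part of the original module): building a
# CLIPDataset for poem-image pairs. This assumes CFG.image_path points at the
# image folder and that each dict carries an "image" filename; get_transforms
# is defined just below.
#
#   pairs = [{"beyt": "a sample beyt", "image": "0.jpg", "id": 0}]
#   dataset = CLIPDataset(pairs, transforms=get_transforms(mode="train"))
#   item = dataset[0]
#   # item["text"] holds the tokenized beyt; item["image"] is a float tensor of
#   # shape (3, CFG.size, CFG.size) ready for a CLIP-style image encoder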


def get_transforms(mode="train"):
    """
    Returns the transforms to apply to images, based on mode.
    Parameters:
    -----------
    mode: str, optional
        to distinguish between train and val/test transforms (here they are the same!)
    Returns:
    --------
    transforms: albumentations.Compose
        a composition of resize and normalize transforms to apply to images
    """
    if mode == "train":
        return A.Compose(
            [
                A.Resize(CFG.size, CFG.size, always_apply=True),  # resizing image to CFG.size
                A.Normalize(max_pixel_value=255.0, always_apply=True),  # normalizing image values
            ]
        )
    else:
        return A.Compose(
            [
                A.Resize(CFG.size, CFG.size, always_apply=True),  # resizing image to CFG.size
                A.Normalize(max_pixel_value=255.0, always_apply=True),  # normalizing image values
            ]
        )
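

if __name__ == "__main__":
    # Quick self-check sketch (not part of the original module): apply the
    # transforms to a random dummy image and inspect the result. The output
    # should have shape (CFG.size, CFG.size, 3) and dtype float32, normalized
    # with ImageNet mean/std (albumentations' Normalize defaults).
    import numpy as np

    tfms = get_transforms(mode="train")
    dummy = np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8)
    out = tfms(image=dummy)["image"]
    print(out.shape, out.dtype)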