Spaces:
Runtime error
Runtime error
from torch.utils.data import Dataset, DataLoader | |
from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \ | |
CenterCrop | |
from transformers import BertTokenizer | |
import pytorch_lightning as pl | |
from PIL import Image | |
import os | |
class flickr30k_CNA(Dataset):
    """Image-caption dataset read from a tab-separated annotation file.

    Each non-blank line of the annotation file holds an image key field and a
    caption field separated by a tab. The key is everything before the first
    ``#`` in the first field. The image file is ``<key>.jpg`` under
    ``img_root_path``. The key itself is returned as the sample label, so all
    captions that share a key share a label.

    Args:
        img_root_path: Directory that the image file names are joined onto.
        annot_path: Path to the UTF-8 encoded annotation file.
        transform: Callable applied to the opened image. If ``None``, the
            loaded image object is returned as-is.
        tokenizer_name: Name or path passed to ``BertTokenizer.from_pretrained``.
        context_length: Token length captions are padded/truncated to.
    """

    def __init__(self, img_root_path,
                 annot_path,
                 transform=None,
                 tokenizer_name="hfl/chinese-roberta-wwm-ext",
                 context_length=77):
        self.root = img_root_path
        self.images, self.captions, self.labels = self._read_annotations(annot_path)
        self.transforms = transform
        self.tokenizer = BertTokenizer.from_pretrained(tokenizer_name)
        # NOTE: large model (original author's note; its exact intent is unclear)
        self.context_length = context_length

    @staticmethod
    def _read_annotations(annot_path, encoding='utf-8'):
        """Parse the annotation file into ``(image_paths, captions, labels)``.

        Blank lines are skipped. Only the first two tab-separated fields of a
        line are used.

        Raises:
            ValueError: If a non-blank line has no caption field.
        """
        images, captions, labels = [], [], []
        # Explicit encoding: the captions are Chinese text, and relying on the
        # platform default (e.g. cp1252/GBK on Windows) would mis-decode them.
        with open(annot_path, 'r', encoding=encoding) as f:
            for lineno, raw in enumerate(f, start=1):
                fields = raw.strip().split('\t')
                if fields == ['']:
                    # Blank (or whitespace-only) line, e.g. a trailing newline.
                    continue
                if len(fields) < 2:
                    raise ValueError(
                        f"{annot_path}:{lineno}: expected a tab-separated key and caption, "
                        f"got {raw.rstrip()!r}"
                    )
                key = fields[0].split('#')[0]
                images.append(key + '.jpg')
                captions.append(fields[1])
                labels.append(key)
        return images, captions, labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, token_ids, label)`` for sample *idx*.

        ``token_ids`` is ``input_ids[0]`` of the tokenizer output, i.e. one
        caption padded/truncated to ``self.context_length`` tokens.
        ``label`` is the image key string.
        """
        img_path = str(self.images[idx])
        with Image.open(os.path.join(self.root, img_path)) as img:
            # Decode the pixels now so the image stays usable once the
            # context manager closes the underlying file.
            img.load()
            image = img if self.transforms is None else self.transforms(img)
        text = self.tokenizer(str(self.captions[idx]), max_length=self.context_length,
                              padding='max_length', truncation=True,
                              return_tensors='pt')['input_ids'][0]
        label = self.labels[idx]
        return image, text, label
def _convert_to_rgb(image): | |
return image.convert('RGB') | |
def image_transform(
        image_size: int,
        is_train: bool,
        mean=(0.48145466, 0.4578275, 0.40821073),
        std=(0.26862954, 0.26130258, 0.27577711)
):
    """Build the torchvision preprocessing pipeline for one data split.

    The two pipelines differ only in their spatial stage:
    - Training uses a bicubic ``RandomResizedCrop`` with scale 0.9-1.0.
    - Evaluation uses a bicubic ``Resize`` followed by a ``CenterCrop``.

    Both then convert the image to RGB, turn it into a tensor, and normalize
    it. The default *mean*/*std* look like the OpenAI CLIP statistics
    (confirm).

    Args:
        image_size: Size passed to the crop/resize transforms.
        is_train: Select the random-crop training pipeline when true.
        mean: Per-channel mean given to ``Normalize``.
        std: Per-channel standard deviation given to ``Normalize``.

    Returns:
        A ``Compose`` chaining the transforms described above.
    """
    if not is_train:
        spatial = [
            Resize(image_size, interpolation=InterpolationMode.BICUBIC),
            CenterCrop(image_size),
        ]
    else:
        spatial = [
            RandomResizedCrop(image_size, scale=(0.9, 1.0),
                              interpolation=InterpolationMode.BICUBIC),
        ]
    # Stages shared by both splits, in order: RGB conversion, tensor, normalize.
    tail = [_convert_to_rgb, ToTensor(), Normalize(mean=mean, std=std)]
    return Compose(spatial + tail)
class FlickrDataModule(pl.LightningDataModule):
    """Lightning data module that serves the train/val/test caption splits.

    ``args`` must provide:
    - ``batch_size``, ``num_workers`` and ``pretrain_model``;
    - for each split, ``<split>_filename`` (the annotation file) and
      ``<split>_root`` (the image directory).
    """

    def __init__(self, args):
        # Initialise LightningDataModule's own state first (Lightning's
        # documented pattern). This avoids relying only on the hand-set
        # attributes below.
        super().__init__()
        self.batch_size = args.batch_size
        self.train_filename = args.train_filename  # NOTE: annotation file
        self.train_root = args.train_root  # NOTE: image directory
        self.val_filename = args.val_filename
        self.val_root = args.val_root
        self.test_filename = args.test_filename
        self.test_root = args.test_root
        self.pretrain_model = args.pretrain_model  # stored; not used within this class
        self.image_size = 224
        # Explicit values kept from the original implementation.
        self.prepare_data_per_node = True
        self._log_hyperparams = False
        self.num_workers = args.num_workers

    def setup(self, stage=None):
        """Create all three datasets. *stage* is accepted but not used."""
        train_transform = image_transform(self.image_size, True)
        val_transform = image_transform(self.image_size, False)
        test_transform = image_transform(self.image_size, False)
        self.train_dataset = flickr30k_CNA(self.train_root, self.train_filename, transform=train_transform)
        self.val_dataset = flickr30k_CNA(self.val_root, self.val_filename, transform=val_transform)
        self.test_dataset = flickr30k_CNA(self.test_root, self.test_filename, transform=test_transform)

    def train_dataloader(self):
        """Return a shuffled loader over the training split."""
        # Shuffle so each epoch sees a fresh order. Without it, lines that are
        # adjacent in the annotation file (possibly several captions of the
        # same image) always land in the same batch.
        return DataLoader(self.train_dataset, batch_size=self.batch_size,
                          shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        """Return a loader over the validation split, in file order."""
        return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        """Return a loader over the test split, in file order."""
        return DataLoader(self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers)