HaloMaster's picture
add fengshen
50f0fbb
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
CenterCrop
from transformers import BertTokenizer
import pytorch_lightning as pl
from PIL import Image
import os
class flickr30k_CNA(Dataset):
def __init__(self, img_root_path,
annot_path,
transform=None):
self.images = []
self.captions = []
self.labels = []
self.root = img_root_path
with open(annot_path, 'r') as f:
for line in f:
line = line.strip().split('\t')
key, caption = line[0].split('#')[0], line[1]
img_path = key + '.jpg'
self.images.append(img_path)
self.captions.append(caption)
self.labels.append(key)
self.transforms = transform
self.tokenizer = BertTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext")
# NOTE large 模型
self.context_length = 77
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
img_path = str(self.images[idx])
image = self.transforms(Image.open(os.path.join(self.root, img_path)))
text = self.tokenizer(str(self.captions[idx]), max_length=self.context_length,
padding='max_length', truncation=True, return_tensors='pt')['input_ids'][0]
label = self.labels[idx]
return image, text, label
def _convert_to_rgb(image):
return image.convert('RGB')
def image_transform(
image_size: int,
is_train: bool,
mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711)
):
normalize = Normalize(mean=mean, std=std)
if is_train:
return Compose([
RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC),
_convert_to_rgb,
ToTensor(),
normalize,
])
else:
return Compose([
Resize(image_size, interpolation=InterpolationMode.BICUBIC),
CenterCrop(image_size),
_convert_to_rgb,
ToTensor(),
normalize,
])
class FlickrDataModule(pl.LightningDataModule):
def __init__(self, args):
self.batch_size = args.batch_size
self.train_filename = args.train_filename # NOTE 标注的文件夹
self.train_root = args.train_root # NOTE 图片地址
self.val_filename = args.val_filename
self.val_root = args.val_root
self.test_filename = args.test_filename
self.test_root = args.test_root
self.pretrain_model = args.pretrain_model
self.image_size = 224
self.prepare_data_per_node = True
self._log_hyperparams = False
self.num_workers = args.num_workers
def setup(self, stage=None):
# dataset
train_transform = image_transform(224, True)
val_transform = image_transform(224, False)
test_transform = image_transform(224, False)
self.train_dataset = flickr30k_CNA(self.train_root, self.train_filename, transform=train_transform)
self.val_dataset = flickr30k_CNA(self.val_root, self.val_filename, transform=val_transform)
self.test_dataset = flickr30k_CNA(self.test_root, self.test_filename, transform=test_transform)
def train_dataloader(self):
return DataLoader(self.train_dataset, batch_size=self.batch_size, num_workers=self.num_workers)
def val_dataloader(self):
return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers)
def test_dataloader(self):
return DataLoader(self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers)