# cuhksz-text2image / preprocessing.py
# bryandts's picture
# Create preprocessing.py
# a305b79 verified
# raw
# history blame
# 2.55 kB
import os
import io
import json
from torch.utils.data import Dataset, DataLoader
import numpy as np
from PIL import Image
import torch
from torch.autograd import Variable
import pdb
import torch.nn.functional as F
class Text2ImageDataset(Dataset):
    """Dataset of (matching image, caption embedding, mismatching image) triples.

    The dataset directory must contain ``descriptions.json`` (a sequence of
    records with ``file_name``, ``text`` and ``class`` keys) and a
    ``CUHKSZ_Photos`` sub-directory holding the image files.

    NOTE(review): ``__getitem__`` calls ``process_caption``, which is neither
    defined nor imported in the visible part of this file; it must be made
    available at module scope for sample loading to work.
    """

    # Number of random draws attempted before falling back to an explicit
    # scan for records of a different class.
    _MAX_RANDOM_DRAWS = 32

    def __init__(self, dataset_dir, image_size=(128, 128)):
        """Load the caption records found under *dataset_dir*.

        Args:
            dataset_dir: Directory containing ``descriptions.json`` and the
                ``CUHKSZ_Photos`` image folder.
            image_size: ``(width, height)`` every image is resized to.
        """
        super().__init__()
        self.dataset_dir = dataset_dir
        self.image_size = tuple(image_size)
        self.images_path = os.path.join(dataset_dir, 'CUHKSZ_Photos')
        self.dataset = self._load_descriptions()

    def __len__(self):
        """Return the number of caption records."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Build the training sample for record *idx*.

        Returns:
            A dict with ``right_images`` (the record's own image),
            ``right_embed`` (the output of ``process_caption`` on the record's
            text, as a float tensor) and ``wrong_images`` (a random image whose
            class differs from the record's). Both image tensors are
            channel-first and scaled from [0, 255] to [-1, 1].
        """
        # Reload lazily in case the records were dropped after construction.
        if self.dataset is None:
            self.dataset = self._load_descriptions()

        item = self.dataset[idx]
        right_image = self._load_image(item['file_name'])
        right_embed = np.array(process_caption(item['text']), dtype=float)
        # The mismatching image is selected by class label, so the record's
        # 'class' field (not its caption text) is what has to be passed here.
        wrong_image = self.find_wrong_image(item['class'])

        return {
            'right_images': self._to_normalized_tensor(self.validate_image(right_image)),
            'right_embed': torch.FloatTensor(right_embed),
            'wrong_images': self._to_normalized_tensor(self.validate_image(wrong_image)),
        }

    def find_wrong_image(self, category):
        """Return a resized image drawn at random from a class other than *category*.

        Raises:
            ValueError: If no record has a class different from *category*.
        """
        records = self.dataset
        total = len(records)
        chosen = None

        # Cheap path: rejection sampling, bounded so that a skewed or
        # single-class dataset cannot loop (or recurse) without end.
        if total:
            for _ in range(self._MAX_RANDOM_DRAWS):
                candidate = records[np.random.randint(total)]
                if candidate['class'] != category:
                    chosen = candidate
                    break

        if chosen is None:
            # Slow path: enumerate every eligible record and pick uniformly.
            eligible = [entry for entry in records if entry['class'] != category]
            if not eligible:
                raise ValueError(
                    'no image with a class different from %r is available' % (category,)
                )
            chosen = eligible[np.random.randint(len(eligible))]

        return self._load_image(chosen['file_name'])

    def validate_image(self, img):
        """Convert an image to a float array in channel-first (3, H, W) layout.

        ``convert('RGB')`` always produces exactly three channels, so
        grayscale and RGBA inputs need no further special handling.
        """
        rgb = np.array(img.convert('RGB'), dtype=float)
        return rgb.transpose(2, 0, 1)

    def _load_descriptions(self):
        """Read and return the parsed contents of ``descriptions.json``."""
        path = os.path.join(self.dataset_dir, 'descriptions.json')
        with open(path, 'r', encoding='utf-8') as file:
            return json.load(file)

    def _load_image(self, file_name):
        """Open *file_name* from the image folder and return it resized."""
        image_path = os.path.join(self.images_path, file_name)
        # resize() returns an independent, fully loaded image, so the
        # underlying file can be closed as soon as the block exits.
        with Image.open(image_path) as source:
            return source.resize(self.image_size)

    @staticmethod
    def _to_normalized_tensor(pixels):
        """Turn a [0, 255] pixel array into a float tensor scaled to [-1, 1]."""
        return torch.FloatTensor(pixels).sub_(127.5).div_(127.5)