Spaces:
Sleeping
Sleeping
import os | |
import io | |
import json | |
from torch.utils.data import Dataset, DataLoader | |
import numpy as np | |
from PIL import Image | |
import torch | |
from torch.autograd import Variable | |
import pdb | |
import torch.nn.functional as F | |
class Text2ImageDataset(Dataset):
    """Dataset yielding (matching image, caption embedding, mismatching image).

    ``dataset_dir`` must contain a ``descriptions.json`` file holding a list
    of records. Each record is accessed through three keys:

    * ``'file_name'`` -- image path relative to ``dataset_dir``;
    * ``'text'`` -- caption handed to ``process_caption``;
    * ``'class'`` -- label used to choose an image of a *different* class.
    """

    # Size passed to ``Image.resize`` for every image that is loaded.
    IMAGE_SIZE = (128, 128)
    # Random draws attempted before falling back to an exhaustive scan.
    MAX_RANDOM_TRIES = 32

    def __init__(self, dataset_dir):
        """Load ``descriptions.json`` from ``dataset_dir``.

        Args:
            dataset_dir: Directory containing ``descriptions.json`` and the
                image files referenced by it.
        """
        self.dataset_dir = dataset_dir
        self.dataset = self._load_descriptions()
        # Images live directly under the dataset directory.
        self.images_path = os.fspath(dataset_dir)

    def __len__(self):
        """Return the number of records in the dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Build the training sample for record ``idx``.

        Returns:
            A dict with ``'right_images'`` (image of the record),
            ``'right_embed'`` (result of ``process_caption`` on the record's
            caption) and ``'wrong_images'`` (image of a record whose class
            differs). Both image tensors are channel-first and scaled with
            ``(x - 127.5) / 127.5``.
        """
        # Reload on demand in case the records were dropped after construction.
        if self.dataset is None:
            self.dataset = self._load_descriptions()

        item = self.dataset[idx]
        right_image = self.validate_image(self._load_image(item['file_name']))
        # NOTE: process_caption is not defined in this class; it must be
        # provided by the surrounding module.
        right_embed = np.array(process_caption(item['text']), dtype=float)
        # The mismatching image is selected by class label, not by caption.
        wrong_image = self.validate_image(
            self.find_wrong_image(item['class'])
        )

        return {
            'right_images': self._normalize(right_image),
            'right_embed': torch.FloatTensor(right_embed),
            'wrong_images': self._normalize(wrong_image),
        }

    def find_wrong_image(self, category):
        """Return a resized image of a random record whose class != category.

        Raises:
            ValueError: If the dataset is empty or every record has class
                ``category``.
        """
        item = self.dataset[self._pick_wrong_index(category)]
        return self._load_image(item['file_name'])

    def validate_image(self, img):
        """Convert a PIL image to a channel-first float array.

        ``convert('RGB')`` always yields exactly three channels, so grayscale,
        palette and RGBA inputs need no further special-casing.

        Returns:
            ``numpy.ndarray`` of dtype ``float`` with the channel axis first.
        """
        pixels = np.array(img.convert('RGB'), dtype=float)
        return pixels.transpose(2, 0, 1)

    def _load_descriptions(self):
        """Read and return the parsed content of ``descriptions.json``."""
        path = os.path.join(self.dataset_dir, 'descriptions.json')
        with open(path, 'r', encoding='utf-8') as file:
            return json.load(file)

    def _load_image(self, file_name):
        """Open ``file_name`` under ``images_path`` and return it resized."""
        image_path = os.path.join(self.images_path, file_name)
        # ``resize`` returns an independent image, so the file handle can be
        # released as soon as the block exits.
        with Image.open(image_path) as img:
            return img.resize(self.IMAGE_SIZE)

    def _pick_wrong_index(self, category):
        """Return the index of a random record whose class != category."""
        total = len(self.dataset)
        if total == 0:
            raise ValueError('cannot pick a wrong image from an empty dataset')

        # Rejection sampling is cheap when several classes are present.
        for _ in range(self.MAX_RANDOM_TRIES):
            idx = np.random.randint(total)
            if self.dataset[idx]['class'] != category:
                return idx

        # Fall back to an exhaustive scan so the call always terminates.
        candidates = [
            i for i, item in enumerate(self.dataset)
            if item['class'] != category
        ]
        if not candidates:
            raise ValueError(
                'no record with a class different from %r' % (category,)
            )
        return candidates[np.random.randint(len(candidates))]

    @staticmethod
    def _normalize(pixels):
        """Return ``pixels`` as a FloatTensor scaled by ``(x - 127.5) / 127.5``."""
        return torch.FloatTensor(pixels).sub_(127.5).div_(127.5)