cuhksz-text2image / preprocessing.py
bryandts's picture
Update preprocessing.py
423fc3f verified
import os
import io
import json
from torch.utils.data import Dataset, DataLoader
import numpy as np
from PIL import Image
import torch
from torch.autograd import Variable
import pdb
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer
# Sentence-embedding model instantiated once at import time; process_caption()
# below calls its encode() method for every caption.
model = SentenceTransformer('sentence-transformers/all-MiniLM-L12-v2')
def process_caption(text):
    """Embed *text* with the module-level sentence model.

    Returns the result of ``model.encode(text)`` with a new leading axis of
    length one inserted in front of it.
    """
    encoded = model.encode(text)
    return np.expand_dims(encoded, axis=0)
class Text2ImageDataset(Dataset):
    """Dataset yielding (matching image, caption embedding, mismatching image).

    ``dataset_dir`` must contain a ``descriptions.json`` file holding a list
    of records. Each record is read through its ``'file_name'``, ``'text'``
    and ``'class'`` keys; ``'file_name'`` is resolved relative to
    ``dataset_dir``.
    """

    # Size passed to PIL's ``resize`` for every image that is loaded.
    IMAGE_SIZE = (128, 128)
    # Random draws attempted before falling back to an exhaustive scan when
    # looking for a record of a different class.
    _MAX_RANDOM_DRAWS = 32

    def __init__(self, dataset_dir):
        """Load ``descriptions.json`` from *dataset_dir*."""
        self.dataset_dir = dataset_dir
        self.dataset = self._load_descriptions()
        self.images_path = os.path.join(dataset_dir)

    def __len__(self):
        """Return the number of records in ``descriptions.json``."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Build the training sample for record *idx*.

        Returns a dict with three float tensors:

        * ``'right_images'`` - the record's own image, channel-first, scaled
          from 0..255 to -1..1.
        * ``'right_embed'`` - the output of ``process_caption`` for the
          record's ``'text'``.
        * ``'wrong_images'`` - a randomly chosen image whose ``'class'``
          differs from the record's, scaled like ``'right_images'``.

        Raises:
            ValueError: if no record of a different class exists.
        """
        # Reload lazily in case the records were dropped (set to None).
        if self.dataset is None:
            self.dataset = self._load_descriptions()

        item = self.dataset[idx]

        right_image = self.validate_image(self._open_resized(item))
        right_embed = np.array(process_caption(item['text']), dtype=float)
        # The negative sample must be selected by class label: comparing the
        # caption text against other records' labels would accept almost any
        # record, including ones of the same class.
        wrong_image = self.validate_image(self.find_wrong_image(item['class']))

        return {
            'right_images': self._normalize(torch.FloatTensor(right_image)),
            'right_embed': torch.FloatTensor(right_embed),
            'wrong_images': self._normalize(torch.FloatTensor(wrong_image)),
        }

    def find_wrong_image(self, category):
        """Return a resized image from a random record not in *category*.

        Random draws are tried first; if none succeeds within
        ``_MAX_RANDOM_DRAWS`` attempts, the records are scanned once and the
        pick is made among those that qualify.

        Raises:
            ValueError: if every record belongs to *category* (or there are
                no records at all).
        """
        total = len(self.dataset)
        if total == 0:
            raise ValueError('dataset is empty; cannot pick a wrong image')

        # Cheap path: rejection sampling, bounded so that a dominant (or the
        # only) class cannot make this spin or recurse forever.
        for _ in range(self._MAX_RANDOM_DRAWS):
            candidate = self.dataset[np.random.randint(total)]
            if candidate['class'] != category:
                return self._open_resized(candidate)

        # Slow path: enumerate every record that qualifies and pick one.
        others = [
            entry for entry in self.dataset if entry['class'] != category
        ]
        if not others:
            raise ValueError(
                'no record with a class different from %r' % (category,)
            )
        return self._open_resized(others[np.random.randint(len(others))])

    def validate_image(self, img):
        """Convert a PIL image to a channel-first float array ``(3, H, W)``."""
        # convert('RGB') always produces exactly three channels, so alpha or
        # grayscale inputs need no separate handling afterwards.
        pixels = np.array(img.convert('RGB'), dtype=float)
        return pixels.transpose(2, 0, 1)

    def _load_descriptions(self):
        """Parse and return the contents of ``descriptions.json``."""
        path = os.path.join(self.dataset_dir, 'descriptions.json')
        with open(path, 'r', encoding='utf-8') as file:
            return json.load(file)

    def _open_resized(self, item):
        """Open the image named by ``item['file_name']`` and resize it."""
        image_path = os.path.join(self.images_path, item['file_name'])
        # resize() returns a new, fully loaded image, so the source file can
        # be closed as soon as the block exits.
        with Image.open(image_path) as source:
            return source.resize(self.IMAGE_SIZE)

    @staticmethod
    def _normalize(tensor):
        """Map pixel values from 0..255 to -1..1 in place; return *tensor*."""
        return tensor.sub_(127.5).div_(127.5)