bryandts committed on
Commit
a305b79
1 Parent(s): e8c19db

Create preprocessing.py

Browse files
Files changed (1) hide show
  1. preprocessing.py +75 -0
preprocessing.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import io
3
+ import json
4
+ from torch.utils.data import Dataset, DataLoader
5
+ import numpy as np
6
+ from PIL import Image
7
+ import torch
8
+ from torch.autograd import Variable
9
+ import pdb
10
+ import torch.nn.functional as F
11
+
12
class Text2ImageDataset(Dataset):
    """Dataset yielding a matching image, its caption embedding and a mismatching image.

    The dataset directory must contain a ``descriptions.json`` file and a
    ``CUHKSZ_Photos`` folder.  Each JSON entry is read through the keys
    ``'file_name'``, ``'text'`` and ``'class'``.
    """

    # Size (width, height) every image is resized to before conversion.
    IMAGE_SIZE = (128, 128)

    # Random draws attempted before falling back to an exhaustive scan.
    _MAX_SAMPLING_ATTEMPTS = 100

    def __init__(self, dataset_dir: str) -> None:
        """Load the description entries found in *dataset_dir*."""
        self.dataset_dir = dataset_dir
        self.dataset = self._load_descriptions()
        self.images_path = os.path.join(dataset_dir, 'CUHKSZ_Photos')

    def __len__(self) -> int:
        """Return the number of description entries."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Build the training sample for entry *idx*.

        Returns:
            A dict with ``'right_images'`` (image of the entry),
            ``'right_embed'`` (embedding of the entry's text) and
            ``'wrong_images'`` (image of a randomly chosen entry of another
            class).  Both images are channel-first float tensors scaled from
            the 0..255 pixel range to -1..1.

        Raises:
            ValueError: if no entry of a different class exists.
        """
        if self.dataset is None:
            # Lazy reload, kept for callers that reset ``dataset`` to None.
            self.dataset = self._load_descriptions()

        item = self.dataset[idx]

        # NOTE(review): ``process_caption`` is neither defined nor imported in
        # this file; it must be provided by the importing code, otherwise this
        # line raises NameError.
        right_embed = np.array(process_caption(item['text']), dtype=float)

        right_image = self.validate_image(self._load_image(item['file_name']))
        # The mismatching image is selected by class label.  The previous code
        # passed the caption text here, which never equals a class label, so
        # the "wrong" image could come from the very same class.
        wrong_image = self.validate_image(self.find_wrong_image(item['class']))

        return {
            'right_images': torch.FloatTensor(right_image).sub_(127.5).div_(127.5),
            'right_embed': torch.FloatTensor(right_embed),
            'wrong_images': torch.FloatTensor(wrong_image).sub_(127.5).div_(127.5),
        }

    def find_wrong_image(self, category):
        """Return the resized image of a random entry whose class differs from *category*.

        Random draws are tried first; after ``_MAX_SAMPLING_ATTEMPTS`` misses
        the entries are scanned once so the call always terminates.

        Raises:
            ValueError: if the dataset is empty or every entry has class
                *category*.
        """
        total = len(self.dataset)
        if total == 0:
            raise ValueError('cannot sample a wrong image from an empty dataset')

        # Bounded rejection sampling instead of unbounded recursion.
        for _ in range(self._MAX_SAMPLING_ATTEMPTS):
            candidate = self.dataset[np.random.randint(total)]
            if candidate['class'] != category:
                return self._load_image(candidate['file_name'])

        # Other classes are rare or absent: enumerate them explicitly.
        others = [entry for entry in self.dataset if entry['class'] != category]
        if not others:
            raise ValueError(
                'no entry with a class different from %r' % (category,)
            )
        chosen = others[np.random.randint(len(others))]
        return self._load_image(chosen['file_name'])

    def validate_image(self, img):
        """Convert a PIL image to a channel-first float array.

        ``convert('RGB')`` always produces three channels, so grayscale and
        RGBA inputs need no further special handling.
        """
        pixels = np.array(img.convert('RGB'), dtype=float)
        # (height, width, channel) -> (channel, height, width)
        return pixels.transpose(2, 0, 1)

    def _load_descriptions(self):
        """Read and return the parsed content of ``descriptions.json``."""
        json_path = os.path.join(self.dataset_dir, 'descriptions.json')
        with open(json_path, 'r', encoding='utf-8') as file:
            return json.load(file)

    def _load_image(self, file_name):
        """Open *file_name* from the photo folder and return it resized."""
        image_path = os.path.join(self.images_path, file_name)
        # ``resize`` returns a new, fully loaded image, so the source file
        # can be closed as soon as the ``with`` block exits.
        with Image.open(image_path) as img:
            return img.resize(self.IMAGE_SIZE)