EditGuard / data /coco_test_dataset.py
Ricoooo's picture
'folder'
5d21dd2
raw
history blame
No virus
2.06 kB
import os
import os.path as osp
import torch
import torch.utils.data as data
import data.util as util
import random
import numpy as np
from PIL import Image
class imageTestDataset(data.Dataset):
    """Test-time dataset pairing each listed image with a solid-blue placeholder.

    ``opt`` keys (all required except ``img_size``):
        N_frames  -- stored as ``self.half_N_frames`` (``N_frames // 2``);
                     not otherwise used by this class.
        data_path -- directory that the names in ``txt_path`` are joined onto.
        bit_path  -- stored on the instance; not used by this class.
        txt_path  -- text file with one image path per line; blank lines are
                     skipped and the remaining entries are sorted.
        num_image -- stored on the instance; not used by this class.
        img_size  -- optional output spatial size, an int (square) or an
                     ``(H, W)`` pair. Defaults to 512, i.e. 512x512.

    Each item is a dict:
        'GT' -- float32 tensor of shape (1, 3, H, W): the image with its last
                axis reordered by ``[2, 1, 0]`` (presumably BGR -> RGB; depends
                on what ``util.read_img`` returns -- TODO confirm), resized with
                nearest-neighbour interpolation to ``img_size``.
        'LQ' -- float32 tensor of shape (1, T, 3, H, W) with T = 1, every pixel
                equal to ``LQ_COLOR`` (pure blue in RGB order).
    """

    # Fill colour of the placeholder 'LQ' frames, per channel, in [0, 1].
    LQ_COLOR = (0.0, 0.0, 1.0)
    # Output spatial size used when ``opt`` has no 'img_size' entry.
    DEFAULT_IMG_SIZE = 512

    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        self.half_N_frames = opt['N_frames'] // 2
        self.data_path = opt['data_path']
        self.bit_path = opt['bit_path']
        self.txt_path = opt['txt_path']
        self.num_image = opt['num_image']
        # (H, W) that every GT image and LQ frame is resized / built to.
        self.img_size = self._parse_size(opt.get('img_size', self.DEFAULT_IMG_SIZE))

        with open(self.txt_path, encoding='utf-8') as f:
            names = [line.rstrip('\n') for line in f]
        # Blank lines (e.g. an extra empty line at the end of the file) would
        # join to ``data_path`` itself and fail on load, so drop them.
        self.list_image = sorted(name for name in names if name)
        # Kept as an alias of the same list object for existing callers.
        self.image_list_gt = self.list_image

    @staticmethod
    def _parse_size(size):
        """Return ``size`` as an ``(H, W)`` tuple; a single int means a square."""
        if isinstance(size, int):
            return (size, size)
        height, width = size
        return (int(height), int(width))

    def _load_gt(self, path):
        """Read ``path`` and return a float32 (1, 3, H, W) tensor at ``self.img_size``."""
        img = util.read_img(None, path)
        # Reverse the channel axis (presumably BGR from read_img -> RGB).
        img = img[:, :, [2, 1, 0]]
        # HWC -> CHW, then add a leading batch axis of size 1.
        chw = np.ascontiguousarray(np.transpose(img, (2, 0, 1)))
        tensor = torch.from_numpy(chw).float().unsqueeze(0)
        return torch.nn.functional.interpolate(tensor, size=self.img_size, mode='nearest')

    @classmethod
    def _make_lq_frames(cls, num_frames, height, width):
        """Return a contiguous float32 (num_frames, 3, height, width) tensor filled with LQ_COLOR."""
        color = torch.tensor(cls.LQ_COLOR, dtype=torch.float32).view(1, 3, 1, 1)
        return color.expand(num_frames, 3, height, width).contiguous()

    def __getitem__(self, index):
        path_GT = self.image_list_gt[index]
        img_GT = self._load_gt(osp.join(self.data_path, path_GT))
        # img_GT is laid out (N, C, H, W); N is 1 here and becomes the frame count T.
        num_frames, _, height, width = img_GT.shape
        frames = self._make_lq_frames(num_frames, height, width)
        return {
            # Leading sequence axis -> (1, T, 3, H, W).
            'LQ': frames.unsqueeze(0),
            'GT': img_GT,
        }

    def __len__(self):
        return len(self.image_list_gt)