import os
import os.path as osp
import random

import numpy as np
import torch
import torch.utils.data as data
from PIL import Image

import data.util as util


class imageTestDataset(data.Dataset):
    """Test-time dataset pairing each listed image with a solid-blue placeholder.

    File names are read from the text file at ``opt['txt_path']`` (one name
    per line) and resolved relative to ``opt['data_path']``.  Each item is a
    dict with:

    * ``'GT'`` -- the image loaded through ``util.read_img``, channel-reversed,
      converted to a float tensor and resized to ``IMAGE_SIZE`` with
      nearest-neighbour interpolation; shape ``(1, C, H, W)``.
    * ``'LQ'`` -- a constant float tensor of shape ``(1, 1, 3, H, W)`` whose
      three channels hold ``(0, 0, 1)``, i.e. the colour (R=0, G=0, B=255)
      scaled by 1/255.

    Args:
        opt: Mapping of options.  The keys ``'N_frames'``, ``'data_path'``,
            ``'bit_path'``, ``'txt_path'`` and ``'num_image'`` must be present.
    """

    # Spatial size (height, width) every ground-truth image is resized to.
    IMAGE_SIZE = (512, 512)

    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        self.half_N_frames = opt['N_frames'] // 2
        self.data_path = opt['data_path']
        self.bit_path = opt['bit_path']
        self.txt_path = self.opt['txt_path']
        self.num_image = self.opt['num_image']

        with open(self.txt_path) as f:
            # Strip both LF and CRLF line endings so list files written on
            # Windows do not leave a stray '\r' at the end of each name.
            names = [line.rstrip('\r\n') for line in f]

        # Skip blank lines (e.g. a trailing empty line): an empty name would
        # otherwise resolve to the data directory itself and fail on load.
        self.list_image = sorted(name for name in names if name.strip())
        self.image_list_gt = self.list_image

    def __getitem__(self, index):
        """Return the ``{'LQ': ..., 'GT': ...}`` pair for entry ``index``.

        Raises:
            IndexError: If ``index`` is outside the list of image names.
        """
        path_GT = self.image_list_gt[index]
        img_GT = util.read_img(None, osp.join(self.data_path, path_GT))
        # Reverse the channel order -- presumably BGR -> RGB; confirm against
        # what util.read_img returns.
        img_GT = img_GT[:, :, [2, 1, 0]]
        # HWC numpy array -> (1, C, H, W) float tensor.
        img_GT = torch.from_numpy(
            np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))
        ).float().unsqueeze(0)
        img_GT = torch.nn.functional.interpolate(
            img_GT, size=self.IMAGE_SIZE, mode='nearest'
        )

        # Build the solid-colour placeholder directly: channels 0 and 1 stay
        # at 0.0 and channel 2 is 255 / 255 = 1.0.
        batch, _, height, width = img_GT.shape
        imgs_LQ = torch.zeros((batch, 3, height, width), dtype=torch.float32)
        imgs_LQ[:, 2] = 1.0

        return {
            # Extra leading axis matches the previous torch.stack([...], dim=0).
            'LQ': imgs_LQ.unsqueeze(0),
            'GT': img_GT,
        }

    def __len__(self):
        """Return the number of image names loaded from the list file."""
        return len(self.image_list_gt)