# Data loading and preprocessing for monocular depth estimation (NYU Depth V2 / KITTI).
import torch
from torch.utils.data import Dataset, DataLoader
import torch.utils.data.distributed
from torchvision import transforms

import numpy as np
from PIL import Image
import os
import random
import copy
import cv2

from utils import DistributedSamplerNoEvenlyDivisible


def _is_pil_image(img):
    return isinstance(img, Image.Image)


def _is_numpy_image(img):
    return isinstance(img, np.ndarray) and (img.ndim in {2, 3})


def preprocessing_transforms(mode):
    return transforms.Compose([
        ToTensor(mode=mode)
    ])


class NewDataLoader(object):
    def __init__(self, args, mode):
        if mode == 'train':
            self.training_samples = DataLoadPreprocess(args, mode, transform=preprocessing_transforms(mode))
            if args.distributed:
                self.train_sampler = torch.utils.data.distributed.DistributedSampler(self.training_samples)
            else:
                self.train_sampler = None

            # Shuffle only when no sampler is set; a DistributedSampler does its own shuffling.
            self.data = DataLoader(self.training_samples, args.batch_size,
                                   shuffle=(self.train_sampler is None),
                                   num_workers=args.num_threads,
                                   pin_memory=True,
                                   sampler=self.train_sampler)

        elif mode == 'online_eval':
            self.testing_samples = DataLoadPreprocess(args, mode, transform=preprocessing_transforms(mode))
            if args.distributed:
                # Keeps the remainder samples instead of padding, so every image is evaluated exactly once.
                self.eval_sampler = DistributedSamplerNoEvenlyDivisible(self.testing_samples, shuffle=False)
            else:
                self.eval_sampler = None
            self.data = DataLoader(self.testing_samples, 1,
                                   shuffle=False,
                                   num_workers=1,
                                   pin_memory=True,
                                   sampler=self.eval_sampler)

        elif mode == 'test':
            self.testing_samples = DataLoadPreprocess(args, mode, transform=preprocessing_transforms(mode))
            self.data = DataLoader(self.testing_samples, 1, shuffle=False, num_workers=1)

        else:
            print("mode should be one of 'train', 'test', 'online_eval'. Got {}".format(mode))


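# Each line in the filenames file holds whitespace-separated fields: an RGB
# image path followed by a depth ground-truth path. Hypothetical examples:
#   nyu:   kitchen_0028b/rgb_00045.jpg kitchen_0028b/sync_depth_00045.png
#   kitti: 2011_09_26/<drive>/image_02/data/0000000005.png <drive>/proj_depth/groundtruth/image_02/0000000005.png
# For KITTI training, the loader prepends the drive/date folder (the first path
# component of the RGB path) to the depth path.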
class DataLoadPreprocess(Dataset):
    def __init__(self, args, mode, transform=None, is_for_online_eval=False):
        self.args = args
        if mode == 'online_eval':
            with open(args.filenames_file_eval, 'r') as f:
                self.filenames = f.readlines()
        else:
            with open(args.filenames_file, 'r') as f:
                self.filenames = f.readlines()

        self.mode = mode
        self.transform = transform
        self.to_tensor = ToTensor(mode)  # instantiate rather than storing the class
        self.is_for_online_eval = is_for_online_eval

    def __getitem__(self, idx):
        sample_path = self.filenames[idx]

        # Focal length in pixels, hard-coded (the NYU Depth V2 value) for all samples.
        focal = 518.8579

        if self.mode == 'train':
            if self.args.dataset == 'kitti':
                rgb_file = sample_path.split()[0]
                # The depth path is relative to the drive/date folder, i.e. the
                # first component of the RGB path.
                depth_file = os.path.join(sample_path.split()[0].split('/')[0], sample_path.split()[1])
                if self.args.use_right and random.random() > 0.5:
                    # Randomly swap in the right stereo camera (image_03) for augmentation.
                    rgb_file = rgb_file.replace('image_02', 'image_03')
                    depth_file = depth_file.replace('image_02', 'image_03')
            else:
                rgb_file = sample_path.split()[0]
                depth_file = sample_path.split()[1]

            image_path = os.path.join(self.args.data_path, rgb_file)
            depth_path = os.path.join(self.args.gt_path, depth_file)

            image = Image.open(image_path)
            depth_gt = Image.open(depth_path)

            if self.args.do_kb_crop:
                # KITTI benchmark crop: a 352 x 1216 window anchored to the bottom-center.
                height = image.height
                width = image.width
                top_margin = int(height - 352)
                left_margin = int((width - 1216) / 2)
                depth_gt = depth_gt.crop((left_margin, top_margin, left_margin + 1216, top_margin + 352))
                image = image.crop((left_margin, top_margin, left_margin + 1216, top_margin + 352))

            if self.args.dataset == 'nyu':
                if self.args.input_height == 480:
                    # Keep the full 640 x 480 resolution but zero out the invalid
                    # border left by the NYU depth projection.
                    depth_gt = np.array(depth_gt)
                    valid_mask = np.zeros_like(depth_gt)
                    valid_mask[45:472, 43:608] = 1
                    depth_gt[valid_mask == 0] = 0
                    depth_gt = Image.fromarray(depth_gt)
                else:
                    # Otherwise crop both image and depth to the valid region.
                    depth_gt = depth_gt.crop((43, 45, 608, 472))
                    image = image.crop((43, 45, 608, 472))

            if self.args.do_random_rotate:
                # Random rotation in [-degree, degree]; nearest-neighbour resampling
                # keeps depth values unmixed.
                random_angle = (random.random() - 0.5) * 2 * self.args.degree
                image = self.rotate_image(image, random_angle)
                depth_gt = self.rotate_image(depth_gt, random_angle, flag=Image.NEAREST)

            image = np.asarray(image, dtype=np.float32) / 255.0
            depth_gt = np.asarray(depth_gt, dtype=np.float32)
            depth_gt = np.expand_dims(depth_gt, axis=2)

            # Convert stored integer depth to metres: NYU PNGs hold millimetres,
            # KITTI PNGs hold metres scaled by 256.
            if self.args.dataset == 'nyu':
                depth_gt = depth_gt / 1000.0
            else:
                depth_gt = depth_gt / 256.0

            if image.shape[0] != self.args.input_height or image.shape[1] != self.args.input_width:
                image, depth_gt = self.random_crop(image, depth_gt, self.args.input_height, self.args.input_width)
            image, depth_gt = self.train_preprocess(image, depth_gt)
            image, depth_gt = self.Cut_Flip(image, depth_gt)
            sample = {'image': image, 'depth': depth_gt, 'focal': focal}

        else:
            if self.mode == 'online_eval':
                data_path = self.args.data_path_eval
            else:
                data_path = self.args.data_path

            image_path = os.path.join(data_path, sample_path.split()[0])
            image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0
            # Resize to the NYU input resolution; note cv2.resize takes (width, height).
            image = cv2.resize(image, (640, 480))

            if self.mode == 'online_eval':
                gt_path = self.args.gt_path_eval
                depth_path = os.path.join(gt_path, sample_path.split()[1])
                if self.args.dataset == 'kitti':
                    depth_path = os.path.join(gt_path, sample_path.split()[0].split('/')[0], sample_path.split()[1])
                has_valid_depth = False
                try:
                    depth_gt = Image.open(depth_path)
                    has_valid_depth = True
                except IOError:
                    # Some frames have no ground truth; flag them so evaluation can skip.
                    depth_gt = False

                if has_valid_depth:
                    depth_gt = np.asarray(depth_gt, dtype=np.uint16)
                    # Undo the 3-bit circular shift applied when these depth PNGs
                    # were written (rotate each 16-bit value right by 3 bits).
                    depth_gt = np.bitwise_or(np.right_shift(depth_gt, 3), np.left_shift(depth_gt, 16 - 3))
                    depth_gt = np.expand_dims(depth_gt, axis=2)
                    if self.args.dataset == 'nyu':
                        depth_gt = depth_gt.astype(np.float32) / 1000.0  # millimetres to metres
                    else:
                        depth_gt = depth_gt.astype(np.float32) / 256.0  # KITTI scale factor

            if self.args.do_kb_crop:
                # Same 352 x 1216 bottom-center crop as in training.
                height = image.shape[0]
                width = image.shape[1]
                top_margin = int(height - 352)
                left_margin = int((width - 1216) / 2)
                image = image[top_margin:top_margin + 352, left_margin:left_margin + 1216, :]
                if self.mode == 'online_eval' and has_valid_depth:
                    depth_gt = depth_gt[top_margin:top_margin + 352, left_margin:left_margin + 1216, :]

            if self.mode == 'online_eval':
                sample = {'image': image, 'depth': depth_gt, 'focal': focal, 'has_valid_depth': has_valid_depth}
            else:
                sample = {'image': image, 'focal': focal}

        if self.transform:
            sample = self.transform(sample)

        return sample

    def rotate_image(self, image, angle, flag=Image.BILINEAR):
        result = image.rotate(angle, resample=flag)
        return result

    def random_crop(self, img, depth, height, width):
        # Crop the same random window from image and depth.
        assert img.shape[0] >= height
        assert img.shape[1] >= width
        assert img.shape[0] == depth.shape[0]
        assert img.shape[1] == depth.shape[1]
        x = random.randint(0, img.shape[1] - width)
        y = random.randint(0, img.shape[0] - height)
        img = img[y:y + height, x:x + width, :]
        depth = depth[y:y + height, x:x + width, :]
        return img, depth

    def train_preprocess(self, image, depth_gt):
        # Random horizontal flip, applied jointly to image and depth.
        do_flip = random.random()
        if do_flip > 0.5:
            image = (image[:, ::-1, :]).copy()
            depth_gt = (depth_gt[:, ::-1, :]).copy()

        # Random photometric augmentation of the image only.
        do_augment = random.random()
        if do_augment > 0.5:
            image = self.augment_image(image)

        return image, depth_gt

    def augment_image(self, image):
        # Random gamma.
        gamma = random.uniform(0.9, 1.1)
        image_aug = image ** gamma

        # Random brightness; NYU uses a wider range.
        if self.args.dataset == 'nyu':
            brightness = random.uniform(0.75, 1.25)
        else:
            brightness = random.uniform(0.9, 1.1)
        image_aug = image_aug * brightness

        # Random per-channel color jitter.
        colors = np.random.uniform(0.9, 1.1, size=3)
        white = np.ones((image.shape[0], image.shape[1]))
        color_image = np.stack([white * colors[i] for i in range(3)], axis=2)
        image_aug *= color_image
        image_aug = np.clip(image_aug, 0, 1)

        return image_aug

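    # Worked example for Cut_Flip below (N == 2, h == 100): if the random cut
    # lands at row 30, then h_list == [0, 30, 100], h_list_inv == [100, 70, 0]
    # and h_interval_list == [30, 70], so output rows 0:30 come from source rows
    # 70:100 and output rows 30:100 come from source rows 0:70 -- the two
    # horizontal strips swap places.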
    def Cut_Flip(self, image, depth):
        # With probability 0.5, cut the sample into N horizontal strips at random
        # heights and reassemble the strips in reverse vertical order (each strip
        # itself is not flipped), applied jointly to image and depth.
        p = random.random()
        if p < 0.5:
            return image, depth
        image_copy = copy.deepcopy(image)
        depth_copy = copy.deepcopy(depth)
        h, w, c = image.shape

        N = 2  # number of strips
        h_list = []
        h_interval_list = []  # height of each strip
        for i in range(N - 1):
            h_list.append(random.randint(int(0.2 * h), int(0.8 * h)))
        h_list.append(h)
        h_list.append(0)
        h_list.sort()
        h_list_inv = np.array([h] * (N + 1)) - np.array(h_list)
        for i in range(len(h_list) - 1):
            h_interval_list.append(h_list[i + 1] - h_list[i])
        for i in range(N):
            image[h_list[i]:h_list[i + 1], :, :] = image_copy[h_list_inv[i] - h_interval_list[i]:h_list_inv[i], :, :]
            depth[h_list[i]:h_list[i + 1], :, :] = depth_copy[h_list_inv[i] - h_interval_list[i]:h_list_inv[i], :, :]

        return image, depth

    def __len__(self):
        return len(self.filenames)


class ToTensor(object):
    def __init__(self, mode):
        self.mode = mode
        # ImageNet mean/std normalization.
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    def __call__(self, sample):
        image, focal = sample['image'], sample['focal']
        image = self.to_tensor(image)
        image = self.normalize(image)

        if self.mode == 'test':
            return {'image': image, 'focal': focal}

        depth = sample['depth']
        if self.mode == 'train':
            depth = self.to_tensor(depth)
            return {'image': image, 'depth': depth, 'focal': focal}
        else:
            # online_eval: depth stays a numpy array and may be False when
            # no ground truth exists.
            has_valid_depth = sample['has_valid_depth']
            return {'image': image, 'depth': depth, 'focal': focal, 'has_valid_depth': has_valid_depth}

    def to_tensor(self, pic):
        if not (_is_pil_image(pic) or _is_numpy_image(pic)):
            raise TypeError(
                'pic should be PIL Image or ndarray. Got {}'.format(type(pic)))

        if isinstance(pic, np.ndarray):
            # HWC numpy array -> CHW tensor.
            img = torch.from_numpy(pic.transpose((2, 0, 1)))
            return img

        # Handle PIL images (legacy torchvision-style conversion).
        if pic.mode == 'I':
            img = torch.from_numpy(np.array(pic, np.int32, copy=False))
        elif pic.mode == 'I;16':
            img = torch.from_numpy(np.array(pic, np.int16, copy=False))
        else:
            img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))

        if pic.mode == 'YCbCr':
            nchannel = 3
        elif pic.mode == 'I;16':
            nchannel = 1
        else:
            nchannel = len(pic.mode)
        img = img.view(pic.size[1], pic.size[0], nchannel)

        # HWC -> CHW.
        img = img.transpose(0, 1).transpose(0, 2).contiguous()
        if isinstance(img, torch.ByteTensor):
            return img.float()
        else:
            return img
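

# Minimal usage sketch. The argument values below are hypothetical placeholders,
# assuming an argparse-style `args` object with the attributes referenced above
# and filename lists/paths that exist on disk.
if __name__ == '__main__':
    from argparse import Namespace

    args = Namespace(
        dataset='nyu', distributed=False, batch_size=8, num_threads=4,
        data_path='datasets/nyu/train', gt_path='datasets/nyu/train',
        filenames_file='train_files.txt', filenames_file_eval='eval_files.txt',
        data_path_eval='datasets/nyu/test', gt_path_eval='datasets/nyu/test',
        input_height=480, input_width=640, do_kb_crop=False,
        do_random_rotate=True, degree=2.5, use_right=False)

    loader = NewDataLoader(args, mode='train')
    for batch in loader.data:
        print(batch['image'].shape, batch['depth'].shape)  # e.g. [8, 3, 480, 640]
        break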