# File size: 2,777 Bytes
# Revision: bbde80b
import torch
import numpy as np
import glob
import os
import io
import random
import pickle
from torch.utils.data import Dataset, DataLoader
from lib.data.augmentation import Augmenter3D
from lib.utils.tools import read_pkl
from lib.utils.utils_data import flip_data
class MotionDataset(Dataset):
    """Base dataset indexing every motion file of one split across subsets.

    For each name in ``subset_list`` (visited in the given order) the directory
    ``<args.data_root>/<subset>/<data_split>`` is listed and its entries are
    appended in sorted file-name order. Subclasses provide ``__getitem__``.
    """

    def __init__(self, args, subset_list, data_split):  # data_split: train/test
        # NOTE: reseeds NumPy's *global* RNG every time a dataset is constructed.
        np.random.seed(0)
        self.data_root = args.data_root
        self.subset_list = subset_list
        self.data_split = data_split

        collected = []
        for subset_name in self.subset_list:
            split_dir = os.path.join(self.data_root, subset_name, self.data_split)
            collected.extend(
                os.path.join(split_dir, entry) for entry in sorted(os.listdir(split_dir))
            )
        self.file_list = collected

    def __len__(self):
        """Return how many motion files were indexed."""
        return len(self.file_list)

    def __getitem__(self, index):
        # Abstract: concrete datasets decide how a file becomes a sample.
        raise NotImplementedError
class MotionDataset3D(MotionDataset):
    """Dataset of ``(2D input, 3D target)`` motion clip pairs.

    Each indexed file is loaded with ``read_pkl`` and is expected to provide a
    ``"data_label"`` entry (the 3D motion) and a ``"data_input"`` entry (2D
    detections, which may be ``None``). Both outputs are returned as
    ``torch.FloatTensor``.
    """

    def __init__(self, args, subset_list, data_split):
        super().__init__(args, subset_list, data_split)
        self.flip = args.flip            # train: random flip of samples that carry 2D detections
        self.synthetic = args.synthetic  # train: derive the 2D input from the augmented 3D label
        self.aug = Augmenter3D(args)
        self.gt_2d = args.gt_2d          # use ground-truth xy (c=1) instead of detections

    @staticmethod
    def _gt_to_2d(motion_3d, out=None):
        """Return a 2D input holding the GT xy of ``motion_3d`` and confidence 1.

        Arrays are indexed as ``[:, :, channel]``: channels 0-1 are copied from
        ``motion_3d`` and channel 2 (the confidence) is set to 1. When ``out`` is
        given it is overwritten in place and returned; otherwise a new float32
        array shaped like ``motion_3d`` is allocated.
        """
        if out is None:
            out = np.zeros(motion_3d.shape, dtype=np.float32)
        out[:, :, :2] = motion_3d[:, :, :2]
        out[:, :, 2] = 1
        return out

    def __getitem__(self, index):
        """Load sample ``index`` and return ``(motion_2d, motion_3d)`` tensors.

        Raises:
            ValueError: a training sample has neither synthetic/GT 2D enabled nor
                a ``"data_input"``; a test sample needs ``"data_input"``
                (``gt_2d`` off) but it is ``None``; or ``data_split`` is neither
                ``"train"`` nor ``"test"``.
        """
        file_path = self.file_list[index]
        motion_file = read_pkl(file_path)
        motion_3d = motion_file["data_label"]
        if self.data_split == "train":
            if self.synthetic or self.gt_2d:
                # No 2D detection: augment the 3D label, then use its GT xy with c=1.
                motion_3d = self.aug.augment3D(motion_3d)
                motion_2d = self._gt_to_2d(motion_3d)
            elif motion_file["data_input"] is not None:  # Have 2D detection
                motion_2d = motion_file["data_input"]
                if self.flip and random.random() > 0.5:  # Training augmentation - random flipping
                    motion_2d = flip_data(motion_2d)
                    motion_3d = flip_data(motion_3d)
            else:
                raise ValueError('Training illegal.')
        elif self.data_split == "test":
            motion_2d = motion_file["data_input"]
            if self.gt_2d:
                # GT 2D comes from the label, so detections are optional here; when
                # present they are reused as the output buffer (original behaviour).
                motion_2d = self._gt_to_2d(motion_3d, out=motion_2d)
            elif motion_2d is None:
                # Fail clearly instead of letting torch.FloatTensor(None) raise TypeError.
                raise ValueError(f'Test sample has no 2D input ("data_input"): {file_path}')
        else:
            raise ValueError('Data split unknown.')
        return torch.FloatTensor(motion_2d), torch.FloatTensor(motion_3d)