|
import torch |
|
import numpy as np |
|
import glob |
|
import os |
|
import io |
|
import random |
|
import pickle |
|
from torch.utils.data import Dataset, DataLoader |
|
from lib.data.augmentation import Augmenter3D |
|
from lib.utils.tools import read_pkl |
|
from lib.utils.utils_data import flip_data |
|
|
|
class MotionDataset(Dataset):
    """Base dataset that indexes motion clip files found on disk.

    For every subset in ``subset_list`` the directory
    ``<args.data_root>/<subset>/<data_split>`` is listed, and the resulting
    paths are stored in ``self.file_list``. Loading a clip is left to
    subclasses, which must override ``__getitem__``.
    """

    def __init__(self, args, subset_list, data_split):
        """Collect the paths of all clip files for the requested split.

        Args:
            args: Configuration object; only ``args.data_root`` is read here.
            subset_list: Iterable of subset directory names under the data root.
            data_split: Name of the split sub-directory inside each subset.

        Raises:
            OSError: Propagated from ``os.listdir`` when a split directory
                cannot be listed.
        """
        # Re-seeds NumPy's *global* RNG on every construction.
        np.random.seed(0)

        self.data_root = args.data_root
        self.subset_list = subset_list
        self.data_split = data_split

        clip_paths = []
        for subset_name in self.subset_list:
            split_dir = os.path.join(self.data_root, subset_name, self.data_split)
            # Sort the entries so the index -> file mapping is deterministic.
            entries = os.listdir(split_dir)
            entries.sort()
            clip_paths.extend(os.path.join(split_dir, entry) for entry in entries)
        self.file_list = clip_paths

    def __len__(self):
        """Return the total number of samples (one per indexed file)."""
        return len(self.file_list)

    def __getitem__(self, index):
        """Load one sample; must be implemented by subclasses."""
        raise NotImplementedError
|
|
|
class MotionDataset3D(MotionDataset):
    """Dataset yielding paired (2D input, 3D label) motion clips.

    Each file in ``self.file_list`` is read with ``read_pkl`` and is expected
    to provide the keys ``"data_label"`` (3D motion) and ``"data_input"``
    (2D motion, may be ``None``). The indexing used below implies rank-3
    arrays whose last axis has at least three channels.
    """

    def __init__(self, args, subset_list, data_split):
        """Index the clip files and store the augmentation settings.

        Args:
            args: Configuration object; reads ``data_root`` (via the base
                class), ``flip``, ``synthetic`` and ``gt_2d``, and is passed
                on to ``Augmenter3D``.
            subset_list: Iterable of subset directory names under the data root.
            data_split: ``"train"`` or ``"test"``.
        """
        super().__init__(args, subset_list, data_split)
        self.flip = args.flip
        self.synthetic = args.synthetic
        self.aug = Augmenter3D(args)
        self.gt_2d = args.gt_2d

    @staticmethod
    def _project_to_2d(motion_3d):
        """Build a 2D input from a 3D label: copy channels 0-1, set channel 2 to 1."""
        motion_2d = np.zeros(motion_3d.shape, dtype=np.float32)
        motion_2d[:, :, :2] = motion_3d[:, :, :2]
        # Channel 2 is a constant 1 (presumably a placeholder confidence,
        # since no detector output is used here -- TODO confirm).
        motion_2d[:, :, 2] = 1
        return motion_2d

    def __getitem__(self, index):
        """Generate one sample of data.

        Args:
            index: Position of the clip in ``self.file_list``.

        Returns:
            Tuple ``(motion_2d, motion_3d)`` of ``torch.FloatTensor``.

        Raises:
            ValueError: If the split is neither ``"train"`` nor ``"test"``,
                or if the clip stores no 2D input and none can be derived
                from the 3D label under the current settings.
        """
        motion_file = read_pkl(self.file_list[index])
        motion_3d = motion_file["data_label"]

        if self.data_split == "train":
            if self.synthetic or self.gt_2d:
                # Augment the 3D label first, then derive the 2D input from it.
                motion_3d = self.aug.augment3D(motion_3d)
                motion_2d = self._project_to_2d(motion_3d)
            elif motion_file["data_input"] is not None:
                motion_2d = motion_file["data_input"]
                # Random flip, applied to input and label together so they
                # stay consistent with each other.
                if self.flip and random.random() > 0.5:
                    motion_2d = flip_data(motion_2d)
                    motion_3d = flip_data(motion_3d)
            else:
                raise ValueError('Training illegal.')
        elif self.data_split == "test":
            motion_2d = motion_file["data_input"]
            if self.gt_2d:
                if motion_2d is None:
                    # No stored 2D input: derive it from the label, as the
                    # training branch does, instead of failing on None.
                    motion_2d = self._project_to_2d(motion_3d)
                else:
                    # Overwrite the stored 2D input in place (the array was
                    # freshly loaded for this call).
                    motion_2d[:, :, :2] = motion_3d[:, :, :2]
                    motion_2d[:, :, 2] = 1
            elif motion_2d is None:
                # Mirror the training branch: fail with a clear message
                # rather than an opaque error from torch.FloatTensor(None).
                raise ValueError('Testing illegal.')
        else:
            raise ValueError('Data split unknown.')

        return torch.FloatTensor(motion_2d), torch.FloatTensor(motion_3d)