import os
import math
from pathlib import Path
import torch
from torch.utils.data import Dataset
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import sys

def get_pose(transformation):
    # transformation: 4x4
    return transformation

class ObjaverseData(Dataset):
    def __init__(self,
                 root_dir='.objaverse/hf-objaverse-v1/views',
                 image_transforms=None,
                 total_view=12,
                 validation=False,
                 T_in=1,
                 T_out=1,
                 fix_sample=False,
                 ) -> None:
        """Dataset of multi-view Objaverse renderings.

        Object ids are read from ../scripts/obj_ids.npy so the id order is
        deterministic; each id names a folder under root_dir that holds the
        rendered views ('%03d.png') and their camera poses ('%03d.npy').
        """
        self.root_dir = Path(root_dir)
        self.total_view = total_view
        self.T_in = T_in
        self.T_out = T_out
        self.fix_sample = fix_sample
        self.paths = []
        # # include all folders
        # for folder in os.listdir(self.root_dir):
        #     if os.path.isdir(os.path.join(self.root_dir, folder)):
        #         self.paths.append(folder)
        # load ids from .npy so we have exactly the same ids/order
        self.paths = np.load("../scripts/obj_ids.npy")
        # # only use 100K objects for ablation study
        # self.paths = self.paths[:100000]
        total_objects = len(self.paths)
        assert total_objects == 790152, 'total objects %d' % total_objects
        if validation:
            self.paths = self.paths[math.floor(total_objects / 100. * 99.):]  # used last 1% as validation
        else:
            self.paths = self.paths[:math.floor(total_objects / 100. * 99.)]  # used first 99% as training
        print('============= length of dataset %d =============' % len(self.paths))
        self.tform = image_transforms
        downscale = 512 / 256.
        self.fx = 560. / downscale
        self.fy = 560. / downscale
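        # downscale = 2 maps the 512px renders to the 256px images used here,
        # so fx = fy = 560 / 2 = 280 and the principal point is cx = cy = 128.
        # The flat 9-vector below is the row-major 3x3 intrinsic matrix
        # [[fx, 0, cx], [0, fy, cy], [0, 0, 1]].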
        self.intrinsic = torch.tensor([[self.fx, 0, 128., 0, self.fy, 128., 0, 0, 1.]], dtype=torch.float64).view(3, 3)

    def __len__(self):
        return len(self.paths)

    def get_pose(self, transformation):
        # transformation: 4x4
        return transformation

    def load_im(self, path, color):
        '''
        replace background pixel with random color in rendering
        '''
        try:
            img = plt.imread(path)
        except Exception:
            print(path)
            sys.exit()
        img[img[:, :, -1] == 0.] = color
        img = Image.fromarray(np.uint8(img[:, :, :3] * 255.))
        return img

    def __getitem__(self, index):
        data = {}
        total_view = 12
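        # Two ways of splitting the 12 rendered views into conditioning inputs
        # and prediction targets:
        #   - fix_sample: a fixed, deterministic choice of view indices, with
        #     one view shared between inputs and targets ("one overlap identity");
        #   - otherwise: draw T_in + T_out indices at random with replacement,
        #     so an input view may also appear as a target (identity pairs).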
        if self.fix_sample:
            if self.T_out > 1:
                indexes = range(total_view)
                index_targets = list(indexes[:2]) + list(indexes[-(self.T_out-2):])
                index_inputs = indexes[1:self.T_in+1]  # one overlap identity
            else:
                indexes = range(total_view)
                index_targets = indexes[:self.T_out]
                index_inputs = indexes[self.T_out-1:self.T_in+self.T_out-1]  # one overlap identity
        else:
            assert self.T_in + self.T_out <= total_view
            # training with replace, including identity
            indexes = np.random.choice(range(total_view), self.T_in+self.T_out, replace=True)
            index_inputs = indexes[:self.T_in]
            index_targets = indexes[self.T_in:]
        filename = os.path.join(self.root_dir, self.paths[index])
        color = [1., 1., 1., 1.]
        try:
            input_ims = []
            target_ims = []
            target_Ts = []
            cond_Ts = []
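            # Each view i of an object folder is stored as '%03d.png' (an RGBA
            # render) plus '%03d.npy' (a 3x4 camera pose [R|t]); the pose is
            # made homogeneous by appending the row [0, 0, 0, 1].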
            for i, index_input in enumerate(index_inputs):
                input_im = self.process_im(self.load_im(os.path.join(filename, '%03d.png' % index_input), color))
                input_ims.append(input_im)
                input_RT = np.load(os.path.join(filename, '%03d.npy' % index_input))
                cond_Ts.append(self.get_pose(np.concatenate([input_RT[:3, :], np.array([[0, 0, 0, 1]])], axis=0)))
            for i, index_target in enumerate(index_targets):
                target_im = self.process_im(self.load_im(os.path.join(filename, '%03d.png' % index_target), color))
                target_ims.append(target_im)
                target_RT = np.load(os.path.join(filename, '%03d.npy' % index_target))
                target_Ts.append(self.get_pose(np.concatenate([target_RT[:3, :], np.array([[0, 0, 0, 1]])], axis=0)))
        except Exception:
            print('error loading data ', filename)
            filename = os.path.join(self.root_dir, '0a01f314e2864711aa7e33bace4bd8c8')  # this one we know is valid
            input_ims = []
            target_ims = []
            target_Ts = []
            cond_Ts = []
            # very hacky solution, sorry about this
            for i, index_input in enumerate(index_inputs):
                input_im = self.process_im(self.load_im(os.path.join(filename, '%03d.png' % index_input), color))
                input_ims.append(input_im)
                input_RT = np.load(os.path.join(filename, '%03d.npy' % index_input))
                cond_Ts.append(self.get_pose(np.concatenate([input_RT[:3, :], np.array([[0, 0, 0, 1]])], axis=0)))
            for i, index_target in enumerate(index_targets):
                target_im = self.process_im(self.load_im(os.path.join(filename, '%03d.png' % index_target), color))
                target_ims.append(target_im)
                target_RT = np.load(os.path.join(filename, '%03d.npy' % index_target))
                target_Ts.append(self.get_pose(np.concatenate([target_RT[:3, :], np.array([[0, 0, 0, 1]])], axis=0)))
        # stack to batch
        data['image_input'] = torch.stack(input_ims, dim=0)
        data['image_target'] = torch.stack(target_ims, dim=0)
        data['pose_out'] = np.stack(target_Ts)
        data['pose_out_inv'] = np.linalg.inv(np.stack(target_Ts)).transpose([0, 2, 1])
        data['pose_in'] = np.stack(cond_Ts)
        data['pose_in_inv'] = np.linalg.inv(np.stack(cond_Ts)).transpose([0, 2, 1])
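        # Resulting sample (assuming self.tform returns (3, H, W) tensors):
        #   image_input: (T_in, 3, H, W),  image_target: (T_out, 3, H, W)
        #   pose_in / pose_out: (T_in, 4, 4) / (T_out, 4, 4) numpy arrays;
        #   the *_inv entries hold the transposed matrix inverses.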
        return data

    def process_im(self, im):
        im = im.convert("RGB")
        return self.tform(im)
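

if __name__ == '__main__':
    # Minimal usage sketch, not part of the original training pipeline: it
    # assumes ../scripts/obj_ids.npy and the rendered views exist locally,
    # and the torchvision transform below is only an illustrative choice;
    # the actual training configs may wire in a different image_transforms.
    from torch.utils.data import DataLoader
    from torchvision import transforms

    tform = transforms.Compose([
        transforms.Resize(256),
        transforms.ToTensor(),  # PIL image -> (3, 256, 256) float tensor in [0, 1]
    ])
    dataset = ObjaverseData(root_dir='.objaverse/hf-objaverse-v1/views',
                            image_transforms=tform,
                            T_in=1, T_out=1)
    loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)
    batch = next(iter(loader))
    print(batch['image_input'].shape)  # (4, T_in, 3, 256, 256)
    print(batch['pose_out'].shape)     # (4, T_out, 4, 4)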