# EscherNet/4DoF/dataset.py
import os
import math
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import webdataset as wds
from PIL import Image
from torch.utils.data import Dataset
from torch.utils.data.distributed import DistributedSampler
from torchvision import transforms


class ObjaverseDataLoader():
    def __init__(self, root_dir, batch_size, total_view=12, num_workers=4):
        self.root_dir = root_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.total_view = total_view

        image_transforms = [transforms.Resize((256, 256)),
                            transforms.ToTensor(),
                            transforms.Normalize([0.5], [0.5])]  # map pixel values to [-1, 1]
        self.image_transforms = transforms.Compose(image_transforms)

    def train_dataloader(self):
        dataset = ObjaverseData(root_dir=self.root_dir, total_view=self.total_view, validation=False,
                                image_transforms=self.image_transforms)
        # sampler = DistributedSampler(dataset)
        return wds.WebLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
        # sampler=sampler)

    def val_dataloader(self):
        dataset = ObjaverseData(root_dir=self.root_dir, total_view=self.total_view, validation=True,
                                image_transforms=self.image_transforms)
        sampler = DistributedSampler(dataset)  # NOTE: constructed but never passed to the loader
        return wds.WebLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
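

# Usage sketch (hypothetical paths, not part of the original file): building
# the loaders for a local copy of the rendered views.
#
#   loader = ObjaverseDataLoader(root_dir='./views', batch_size=4, total_view=12)
#   train_loader = loader.train_dataloader()
#   batch = next(iter(train_loader))  # dict with image_input/image_target/pose_in/pose_out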


def cartesian_to_spherical(xyz):
    """Convert Nx3 Cartesian coordinates to (theta, azimuth, radius)."""
    xy = xyz[:, 0] ** 2 + xyz[:, 1] ** 2
    z = np.sqrt(xy + xyz[:, 2] ** 2)  # radius
    theta = np.arctan2(np.sqrt(xy), xyz[:, 2])  # elevation angle defined from Z-axis down
    # theta = np.arctan2(xyz[:, 2], np.sqrt(xy))  # for elevation angle defined from XY-plane up
    azimuth = np.arctan2(xyz[:, 1], xyz[:, 0])
    return np.array([theta, azimuth, z])


def get_pose(target_RT):
    """Map a world-to-camera [R|T] matrix to a 4DoF pose (theta, azimuth, normalised log-radius, 0)."""
    target_RT = target_RT[:3, :]
    R, T = target_RT[:3, :3], target_RT[:, -1]
    T_target = -R.T @ T  # camera centre in world coordinates
    theta_target, azimuth_target, z_target = cartesian_to_spherical(T_target[None, :])
    # clamp the radius if it falls outside the rendering range [1.5, 2.2]
    if z_target.item() < 1.5 or z_target.item() > 2.2:
        # print('z_target out of range 1.5-2.2', z_target.item())
        z_target = np.clip(z_target.item(), 1.5, 2.2)
    # radius is log-scaled, then normalised to [0, pi]
    target_T = torch.tensor([theta_target.item(), azimuth_target.item(),
                             (np.log(z_target.item()) - np.log(1.5)) / (np.log(2.2) - np.log(1.5)) * torch.pi,
                             0.])
    assert torch.all(target_T <= torch.pi) and torch.all(target_T >= -torch.pi)
    return target_T.numpy()
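

# Illustrative sanity check (added here as a sketch, not part of the original
# training code): a camera whose extrinsics place it at (1.8, 0, 0) in world
# space should map to theta = pi/2, azimuth = 0, and a log-radius in [0, pi].
def _demo_get_pose():
    # world-to-camera [R|T] with R = I and camera centre at (1.8, 0, 0)
    RT = np.hstack([np.eye(3), -np.array([[1.8, 0., 0.]]).T])
    pose = get_pose(RT)
    assert abs(pose[0] - np.pi / 2) < 1e-6  # elevation measured from the +Z axis
    assert abs(pose[1]) < 1e-6              # azimuth
    assert 0. <= pose[2] <= np.pi           # normalised log-radius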


class ObjaverseData(Dataset):
    def __init__(self,
                 root_dir='.objaverse/hf-objaverse-v1/views',
                 image_transforms=None,
                 total_view=12,
                 validation=False,
                 T_in=1,
                 T_out=1,
                 fix_sample=False,
                 ) -> None:
        """Dataset over pre-rendered Objaverse views.

        Each object folder under root_dir holds total_view renders
        ('%03d.png') with matching camera extrinsics ('%03d.npy').
        """
        self.root_dir = Path(root_dir)
        self.total_view = total_view
        self.T_in = T_in
        self.T_out = T_out
        self.fix_sample = fix_sample
        self.paths = []
        # # include all folders
        # for folder in os.listdir(self.root_dir):
        #     if os.path.isdir(os.path.join(self.root_dir, folder)):
        #         self.paths.append(folder)
        # load ids from .npy so we have exactly the same ids/order
        self.paths = np.load("../scripts/obj_ids.npy")
        # # only use 100K objects for ablation study
        # self.paths = self.paths[:100000]
        total_objects = len(self.paths)
        assert total_objects == 790152, 'total objects %d' % total_objects
        if validation:
            self.paths = self.paths[math.floor(total_objects / 100. * 99.):]  # last 1% as validation
        else:
            self.paths = self.paths[:math.floor(total_objects / 100. * 99.)]  # first 99% as training
        print('============= length of dataset %d =============' % len(self.paths))
        self.tform = image_transforms
        # renders are 512x512 with focal length 560; views are used at 256x256
        downscale = 512 / 256.
        self.fx = 560. / downscale
        self.fy = 560. / downscale
        self.intrinsic = torch.tensor([[self.fx, 0, 128., 0, self.fy, 128., 0, 0, 1.]], dtype=torch.float64).view(3, 3)
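        # Resulting pinhole intrinsics for the 256x256 views (560 / 2 = 280,
        # principal point at the image centre):
        #   K = [[280,   0, 128],
        #        [  0, 280, 128],
        #        [  0,   0,   1]]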

    def __len__(self):
        return len(self.paths)

    def cartesian_to_spherical(self, xyz):
        """Convert Nx3 Cartesian coordinates to (theta, azimuth, radius)."""
        xy = xyz[:, 0] ** 2 + xyz[:, 1] ** 2
        z = np.sqrt(xy + xyz[:, 2] ** 2)  # radius
        theta = np.arctan2(np.sqrt(xy), xyz[:, 2])  # elevation angle defined from Z-axis down
        # theta = np.arctan2(xyz[:, 2], np.sqrt(xy))  # for elevation angle defined from XY-plane up
        azimuth = np.arctan2(xyz[:, 1], xyz[:, 0])
        return np.array([theta, azimuth, z])

    def get_T(self, target_RT, cond_RT):
        """Relative pose between target and condition cameras:
        (d_theta, sin(d_azimuth), cos(d_azimuth), d_radius)."""
        R, T = target_RT[:3, :3], target_RT[:, -1]
        T_target = -R.T @ T
        R, T = cond_RT[:3, :3], cond_RT[:, -1]
        T_cond = -R.T @ T

        theta_cond, azimuth_cond, z_cond = self.cartesian_to_spherical(T_cond[None, :])
        theta_target, azimuth_target, z_target = self.cartesian_to_spherical(T_target[None, :])

        d_theta = theta_target - theta_cond
        d_azimuth = (azimuth_target - azimuth_cond) % (2 * math.pi)
        d_z = z_target - z_cond

        # azimuth is embedded as (sin, cos) so the encoding is continuous at the 2*pi wrap
        d_T = torch.tensor([d_theta.item(), math.sin(d_azimuth.item()), math.cos(d_azimuth.item()), d_z.item()])
        return d_T

    def get_pose(self, target_RT):
        """Absolute 4DoF pose (theta, azimuth, normalised log-radius, 0)."""
        R, T = target_RT[:3, :3], target_RT[:, -1]
        T_target = -R.T @ T  # camera centre in world coordinates
        theta_target, azimuth_target, z_target = self.cartesian_to_spherical(T_target[None, :])
        # clamp the radius if it falls outside the rendering range [1.5, 2.2]
        if z_target.item() < 1.5 or z_target.item() > 2.2:
            # print('z_target out of range 1.5-2.2', z_target.item())
            z_target = np.clip(z_target.item(), 1.5, 2.2)
        # radius is log-scaled, then normalised to [0, pi]
        target_T = torch.tensor([theta_target.item(), azimuth_target.item(),
                                 (np.log(z_target.item()) - np.log(1.5)) / (np.log(2.2) - np.log(1.5)) * torch.pi,
                                 0.])
        assert torch.all(target_T <= torch.pi) and torch.all(target_T >= -torch.pi)
        return target_T

    def load_im(self, path, color):
        """Load an RGBA render and replace fully transparent pixels with `color`."""
        try:
            img = plt.imread(path)
        except Exception:
            print('failed to load image', path)
            sys.exit()
        img[img[:, :, -1] == 0.] = color  # alpha == 0 -> background colour
        img = Image.fromarray(np.uint8(img[:, :, :3] * 255.))
        return img

    def __getitem__(self, index):
        data = {}
        total_view = self.total_view

        if self.fix_sample:
            # deterministic assignment of views to inputs and targets
            if self.T_out > 1:
                indexes = range(total_view)
                # NOTE: assumes T_out > 2; for T_out == 2 the negative slice
                # below would select all views
                index_targets = list(indexes[:2]) + list(indexes[-(self.T_out - 2):])
                index_inputs = indexes[1:self.T_in + 1]  # one overlap identity
            else:
                indexes = range(total_view)
                index_targets = indexes[:self.T_out]
                index_inputs = indexes[self.T_out - 1:self.T_in + self.T_out - 1]  # one overlap identity
        else:
            assert self.T_in + self.T_out <= total_view
            # random sampling with replacement, so identity pairs can occur
            indexes = np.random.choice(range(total_view), self.T_in + self.T_out, replace=True)
            index_inputs = indexes[:self.T_in]
            index_targets = indexes[self.T_in:]

        filename = os.path.join(self.root_dir, self.paths[index])
        color = [1., 1., 1., 1.]  # white background

        try:
            input_ims = []
            target_ims = []
            target_Ts = []
            cond_Ts = []
            for i, index_input in enumerate(index_inputs):
                input_im = self.process_im(self.load_im(os.path.join(filename, '%03d.png' % index_input), color))
                input_ims.append(input_im)
                input_RT = np.load(os.path.join(filename, '%03d.npy' % index_input))
                cond_Ts.append(self.get_pose(input_RT))
            for i, index_target in enumerate(index_targets):
                target_im = self.process_im(self.load_im(os.path.join(filename, '%03d.png' % index_target), color))
                target_ims.append(target_im)
                target_RT = np.load(os.path.join(filename, '%03d.npy' % index_target))
                target_Ts.append(self.get_pose(target_RT))
        except Exception:
            print('error loading data ', filename)
            filename = os.path.join(self.root_dir, '0a01f314e2864711aa7e33bace4bd8c8')  # this one we know is valid
            input_ims = []
            target_ims = []
            target_Ts = []
            cond_Ts = []
            # very hacky solution, sorry about this
            for i, index_input in enumerate(index_inputs):
                input_im = self.process_im(self.load_im(os.path.join(filename, '%03d.png' % index_input), color))
                input_ims.append(input_im)
                input_RT = np.load(os.path.join(filename, '%03d.npy' % index_input))
                cond_Ts.append(self.get_pose(input_RT))
            for i, index_target in enumerate(index_targets):
                target_im = self.process_im(self.load_im(os.path.join(filename, '%03d.png' % index_target), color))
                target_ims.append(target_im)
                target_RT = np.load(os.path.join(filename, '%03d.npy' % index_target))
                target_Ts.append(self.get_pose(target_RT))

        # stack to batch
        data['image_input'] = torch.stack(input_ims, dim=0)
        data['image_target'] = torch.stack(target_ims, dim=0)
        data['pose_out'] = torch.stack(target_Ts, dim=0)
        data['pose_in'] = torch.stack(cond_Ts, dim=0)
        return data

    def process_im(self, im):
        # drop alpha and apply the resize / normalisation transforms
        im = im.convert("RGB")
        return self.tform(im)
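

# Illustrative check of the relative-pose encoding (a sketch, not part of the
# original file): get_T only uses cartesian_to_spherical, so we can bypass
# __init__ via __new__ and feed two synthetic cameras at the same elevation
# and radius, 90 degrees apart in azimuth.
def _demo_get_T():
    ds = ObjaverseData.__new__(ObjaverseData)  # skip __init__ (no files needed)
    target_RT = np.hstack([np.eye(3), -np.array([[0., 1.8, 0.]]).T])  # camera at (0, 1.8, 0)
    cond_RT = np.hstack([np.eye(3), -np.array([[1.8, 0., 0.]]).T])    # camera at (1.8, 0, 0)
    d_T = ds.get_T(target_RT, cond_RT)
    assert abs(d_T[0]) < 1e-6 and abs(d_T[3]) < 1e-6      # same elevation and radius
    assert abs(d_T[1] - 1.) < 1e-6 and abs(d_T[2]) < 1e-6  # sin/cos of d_azimuth = pi/2


if __name__ == '__main__':
    _demo_get_pose()
    _demo_get_T()
    print('pose encoding sanity checks passed')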