import os
import math
from pathlib import Path

import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision import transforms
from PIL import Image
import numpy as np
import webdataset as wds
import matplotlib.pyplot as plt
class ObjaverseDataLoader:
    def __init__(self, root_dir, batch_size, total_view=12, num_workers=4):
        self.root_dir = root_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.total_view = total_view

        image_transforms = [transforms.Resize((256, 256)),
                            transforms.ToTensor(),
                            transforms.Normalize([0.5], [0.5])]
        self.image_transforms = transforms.Compose(image_transforms)

    def train_dataloader(self):
        dataset = ObjaverseData(root_dir=self.root_dir, total_view=self.total_view, validation=False,
                                image_transforms=self.image_transforms)
        # sampler = DistributedSampler(dataset)
        return wds.WebLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
        # sampler=sampler)

    def val_dataloader(self):
        dataset = ObjaverseData(root_dir=self.root_dir, total_view=self.total_view, validation=True,
                                image_transforms=self.image_transforms)
        # sampler = DistributedSampler(dataset)  # unused: the loader below is built without a sampler
        return wds.WebLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
def cartesian_to_spherical(xyz):
    xy = xyz[:, 0] ** 2 + xyz[:, 1] ** 2
    z = np.sqrt(xy + xyz[:, 2] ** 2)
    theta = np.arctan2(np.sqrt(xy), xyz[:, 2])  # elevation angle measured down from the +Z axis
    # np.arctan2(xyz[:, 2], np.sqrt(xy)) would measure elevation up from the XY plane instead
    azimuth = np.arctan2(xyz[:, 1], xyz[:, 0])
    return np.array([theta, azimuth, z])
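
# Quick sanity check of the convention above (illustrative values, not from the
# dataset): a point directly on the +Z axis has zero elevation offset from Z and
# an (arbitrary) azimuth of zero:
#
#   cartesian_to_spherical(np.array([[0., 0., 2.]]))
#   # -> theta = 0, azimuth = 0, z (radius) = 2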
def get_pose(target_RT):
    target_RT = target_RT[:3, :]
    R, T = target_RT[:3, :3], target_RT[:, -1]
    T_target = -R.T @ T  # camera center in world coordinates
    theta_target, azimuth_target, z_target = cartesian_to_spherical(T_target[None, :])
    # clamp the radius if it falls outside the expected rendering range
    if z_target.item() < 1.5 or z_target.item() > 2.2:
        # print('z_target out of range 1.5-2.2', z_target.item())
        z_target = np.clip(z_target.item(), 1.5, 2.2)
    # radius is log-normalized to [0, pi] so all four entries share one scale
    target_T = torch.tensor([theta_target.item(), azimuth_target.item(),
                             (np.log(z_target.item()) - np.log(1.5)) / (np.log(2.2) - np.log(1.5)) * torch.pi,
                             0.])
    assert torch.all(target_T <= torch.pi) and torch.all(target_T >= -torch.pi)
    return target_T.numpy()
class ObjaverseData(Dataset):
    def __init__(self,
                 root_dir='.objaverse/hf-objaverse-v1/views',
                 image_transforms=None,
                 total_view=12,
                 validation=False,
                 T_in=1,
                 T_out=1,
                 fix_sample=False,
                 ) -> None:
        """Dataset over a folder of multi-view Objaverse renderings.

        Each object directory under `root_dir` holds `total_view` rendered
        views (`%03d.png`) with matching camera poses (`%03d.npy`).
        """
        self.root_dir = Path(root_dir)
        self.total_view = total_view
        self.T_in = T_in
        self.T_out = T_out
        self.fix_sample = fix_sample

        self.paths = []
        # # include all folders
        # for folder in os.listdir(self.root_dir):
        #     if os.path.isdir(os.path.join(self.root_dir, folder)):
        #         self.paths.append(folder)
        # load ids from .npy so we have exactly the same ids/order
        self.paths = np.load("../scripts/obj_ids.npy")
        # # only use 100K objects for ablation study
        # self.paths = self.paths[:100000]
        total_objects = len(self.paths)
        assert total_objects == 790152, 'total objects %d' % total_objects
        if validation:
            self.paths = self.paths[math.floor(total_objects / 100. * 99.):]  # last 1% as validation
        else:
            self.paths = self.paths[:math.floor(total_objects / 100. * 99.)]  # first 99% as training
        print('============= length of dataset %d =============' % len(self.paths))
        self.tform = image_transforms

        # intrinsics of the 512x512 renders (f = 560 px), rescaled to 256x256
        downscale = 512 / 256.
        self.fx = 560. / downscale
        self.fy = 560. / downscale
        self.intrinsic = torch.tensor([[self.fx, 0., 128.],
                                       [0., self.fy, 128.],
                                       [0., 0., 1.]], dtype=torch.float64)
    def __len__(self):
        return len(self.paths)
    def cartesian_to_spherical(self, xyz):
        xy = xyz[:, 0] ** 2 + xyz[:, 1] ** 2
        z = np.sqrt(xy + xyz[:, 2] ** 2)
        theta = np.arctan2(np.sqrt(xy), xyz[:, 2])  # elevation angle measured down from the +Z axis
        # np.arctan2(xyz[:, 2], np.sqrt(xy)) would measure elevation up from the XY plane instead
        azimuth = np.arctan2(xyz[:, 1], xyz[:, 0])
        return np.array([theta, azimuth, z])
    def get_T(self, target_RT, cond_RT):
        R, T = target_RT[:3, :3], target_RT[:, -1]
        T_target = -R.T @ T
        R, T = cond_RT[:3, :3], cond_RT[:, -1]
        T_cond = -R.T @ T

        theta_cond, azimuth_cond, z_cond = self.cartesian_to_spherical(T_cond[None, :])
        theta_target, azimuth_target, z_target = self.cartesian_to_spherical(T_target[None, :])

        d_theta = theta_target - theta_cond
        d_azimuth = (azimuth_target - azimuth_cond) % (2 * math.pi)
        d_z = z_target - z_cond

        # the azimuth difference is encoded as (sin, cos) to avoid the 0/2pi wrap-around
        d_T = torch.tensor([d_theta.item(), math.sin(d_azimuth.item()), math.cos(d_azimuth.item()), d_z.item()])
        return d_T
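
    # Sketch of what get_T encodes (illustrative values, not from the dataset):
    # two cameras at the same elevation and radius, 90 degrees apart in azimuth,
    # yield d_T = [0, sin(pi/2), cos(pi/2), 0] = [0, 1, 0, 0].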
    def get_pose(self, target_RT):
        R, T = target_RT[:3, :3], target_RT[:, -1]
        T_target = -R.T @ T
        theta_target, azimuth_target, z_target = self.cartesian_to_spherical(T_target[None, :])
        # clamp the radius if it falls outside the expected rendering range
        if z_target.item() < 1.5 or z_target.item() > 2.2:
            # print('z_target out of range 1.5-2.2', z_target.item())
            z_target = np.clip(z_target.item(), 1.5, 2.2)
        # radius is log-normalized to [0, pi] so all four entries share one scale
        target_T = torch.tensor([theta_target.item(), azimuth_target.item(),
                                 (np.log(z_target.item()) - np.log(1.5)) / (np.log(2.2) - np.log(1.5)) * torch.pi,
                                 0.])
        assert torch.all(target_T <= torch.pi) and torch.all(target_T >= -torch.pi)
        return target_T
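
    # The log-normalization above maps the clamped radius onto [0, pi]:
    # z = 1.5 -> 0, z = 2.2 -> pi, and e.g. z = sqrt(1.5 * 2.2) -> pi / 2.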
    def load_im(self, path, color):
        """Load a rendering and replace fully transparent pixels with `color`."""
        try:
            img = plt.imread(path)
        except Exception as exc:
            print(path, exc)
            raise
        img[img[:, :, -1] == 0.] = color  # alpha == 0 marks background
        img = Image.fromarray(np.uint8(img[:, :, :3] * 255.))
        return img
    def __getitem__(self, index):
        data = {}
        total_view = self.total_view
        if self.fix_sample:
            indexes = range(total_view)
            if self.T_out > 1:
                # first two views plus the last (T_out - 2) views as targets
                index_targets = list(indexes[:2]) + (list(indexes[-(self.T_out - 2):]) if self.T_out > 2 else [])
                index_inputs = indexes[1:self.T_in + 1]  # one overlap identity
            else:
                index_targets = indexes[:self.T_out]
                index_inputs = indexes[self.T_out - 1:self.T_in + self.T_out - 1]  # one overlap identity
        else:
            assert self.T_in + self.T_out <= total_view
            # sample with replacement, so an input view may also appear as a target (identity)
            indexes = np.random.choice(range(total_view), self.T_in + self.T_out, replace=True)
            index_inputs = indexes[:self.T_in]
            index_targets = indexes[self.T_in:]
        filename = os.path.join(self.root_dir, self.paths[index])
        color = [1., 1., 1., 1.]  # white background

        try:
            input_ims, target_ims, target_Ts, cond_Ts = self._load_views(filename, index_inputs, index_targets, color)
        except Exception:
            print('error loading data ', filename)
            # very hacky solution, sorry about this: retry with an object we know is valid
            filename = os.path.join(self.root_dir, '0a01f314e2864711aa7e33bace4bd8c8')
            input_ims, target_ims, target_Ts, cond_Ts = self._load_views(filename, index_inputs, index_targets, color)

        # stack to batch
        data['image_input'] = torch.stack(input_ims, dim=0)
        data['image_target'] = torch.stack(target_ims, dim=0)
        data['pose_out'] = torch.stack(target_Ts, dim=0)
        data['pose_in'] = torch.stack(cond_Ts, dim=0)
        return data

    def _load_views(self, filename, index_inputs, index_targets, color):
        """Load the input/target images and their poses for one object."""
        input_ims, target_ims, target_Ts, cond_Ts = [], [], [], []
        for index_input in index_inputs:
            input_ims.append(self.process_im(self.load_im(os.path.join(filename, '%03d.png' % index_input), color)))
            input_RT = np.load(os.path.join(filename, '%03d.npy' % index_input))
            cond_Ts.append(self.get_pose(input_RT))
        for index_target in index_targets:
            target_ims.append(self.process_im(self.load_im(os.path.join(filename, '%03d.png' % index_target), color)))
            target_RT = np.load(os.path.join(filename, '%03d.npy' % index_target))
            target_Ts.append(self.get_pose(target_RT))
        return input_ims, target_ims, target_Ts, cond_Ts
    def process_im(self, im):
        im = im.convert("RGB")
        return self.tform(im)
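
# Minimal usage sketch. It assumes the default data layout above exists locally
# (the Objaverse renders plus ../scripts/obj_ids.npy); the guard keeps imports
# of this module side-effect free.
if __name__ == "__main__":
    loader = ObjaverseDataLoader(root_dir='.objaverse/hf-objaverse-v1/views', batch_size=2)
    batch = next(iter(loader.train_dataloader()))
    # image_input: (B, T_in, 3, 256, 256); pose_in: (B, T_in, 4)
    print(batch['image_input'].shape, batch['pose_in'].shape)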