# Source: Hugging Face repository "trumans" (uploaded by jnnan).
# Commit aeba71c ("Upload 68 files", verified); original file size 4.18 kB.
import torch
import hydra
import numpy as np
from einops import rearrange
import random
import os
def seed_everything(seed: int):
    """Seed every random number generator used here and make cuDNN deterministic.

    Seeds Python's ``random`` module, NumPy's global RNG and torch (CPU and
    all CUDA devices), and records ``PYTHONHASHSEED`` in the environment.

    Note: setting ``PYTHONHASHSEED`` only affects child processes started
    afterwards; the running interpreter's ``hash()`` randomisation is fixed
    at startup.

    Args:
        seed: value used to seed every generator.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Explicit for clarity; silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode autotunes convolution algorithms and may pick different
    # kernels from run to run, which defeats the deterministic flag above.
    torch.backends.cudnn.benchmark = False
def transform_points(x, mat):
    """Apply a per-batch affine transform to sequences of flattened 3D points.

    Args:
        x: tensor of shape ``(B, T, J * 3)``, i.e. T frames of J xyz points
            flattened along the last axis.
        mat: batch of matrices; only the top-left 3x3 block (linear part)
            and the first three rows of column 3 (translation) are read, so
            it needs at least 3 rows and 4 columns (presumably ``(B, 4, 4)``
            homogeneous transforms; confirm against callers).

    Returns:
        Tensor with the same shape as ``x`` in which every point ``p`` is
        replaced by ``mat[:, :3, :3] @ p + mat[:, :3, 3]``.
    """
    pts = rearrange(x, 'b t (j c) -> b (t j) c', c=3)  # (B, N, 3)
    rot = mat[:, :3, :3]
    trans = mat[:, :3, 3]
    # (B, 3, 3) x (B, 3, N) -> (B, 3, N)
    rotated = torch.einsum('bpc,bck->bpk', rot, pts.transpose(1, 2))
    # back to (B, N, 3) and add the translation per batch element
    moved = rotated.transpose(1, 2) + trans.unsqueeze(1)
    return moved.reshape(x.shape)
def create_meshgrid(bbox, size, batch_size=1):
    """Build ``batch_size`` copies of a regular 3D sample grid over a box.

    Args:
        bbox: ``(x_min, x_max, y_min, y_max, z_min, z_max)``.
        size: number of samples along x, y and z respectively.
        batch_size: number of copies stacked along a new leading dimension.

    Returns:
        Tensor of shape ``(batch_size, size[0] * size[1] * size[2], 3)``.
        Points are ordered with x varying slowest and z fastest.
    """
    axes = [
        torch.linspace(bbox[2 * axis], bbox[2 * axis + 1], size[axis])
        for axis in range(3)
    ]
    # 'ij' indexing keeps the axis order (x, y, z) in the flattened result.
    points = torch.stack(torch.meshgrid(*axes, indexing='ij'), dim=-1)
    flat = points.reshape(-1, 3)
    return flat.repeat(batch_size, 1, 1)
def zup_to_yup(coord):
    """Convert coordinates from a Z-up frame to a Y-up frame.

    Each point ``(x, y, z)`` becomes ``(x, z, -y)``, so the old +z (up) axis
    becomes +y. This is a proper rotation about the x axis, not a
    reflection.

    Works on a single point of shape ``(3,)`` or on any array/tensor whose
    last axis has size 3. The same ellipsis indexing covers both cases, so
    no rank check is needed.

    Args:
        coord: NumPy array or torch tensor with a trailing xyz axis.

    Returns:
        A new array/tensor of the same shape. The input is not modified,
        because advanced (list) indexing returns a copy.
    """
    coord = coord[..., [0, 2, 1]]
    coord[..., 2] *= -1
    return coord
def rigid_transform_3D(A, B, scale=False):
    """Estimate the transform that best aligns point set ``B`` onto ``A``.

    Least-squares fit via SVD of the cross-covariance matrix (Kabsch), with
    an optional uniform scale. The result satisfies
    ``A[i] ~= c * R @ B[i] + t`` for corresponding rows.

    All matrix products use ``@``. The previous ``*`` only meant matrix
    multiplication for ``np.matrix`` inputs. For plain ``ndarray`` it was
    element-wise: it failed to broadcast, or silently produced garbage when
    N == 3. ``@`` gives identical results for ``np.matrix`` and fixes
    ``ndarray`` input.

    Args:
        A: (N, 3) target points (ndarray, ``np.matrix`` or nested sequence).
        B: (N, 3) source points; row i corresponds to row i of ``A``.
        scale: if true, also estimate a uniform scale factor. The estimate
            ``var(A) / sum(S)`` recovers the true scale when ``A`` is an
            exact scaled rigid copy of ``B``.

    Returns:
        ``(c, R, t)``, where:
        - ``c`` is the scale (``1`` when ``scale`` is false);
        - ``R`` is a 3x3 rotation matrix;
        - ``t`` is the translation: shape (3,) for ndarray input, or a
          (3, 1) matrix for ``np.matrix`` input.
    """
    assert len(A) == len(B)
    # asanyarray accepts lists but keeps np.matrix subclasses intact, so
    # matrix callers get the same return types as before.
    A = np.asanyarray(A)
    B = np.asanyarray(B)
    N = A.shape[0]  # total points

    centroid_A = np.mean(A, axis=0)
    centroid_B = np.mean(B, axis=0)
    # center the points (broadcasting replaces the explicit np.tile)
    AA = A - centroid_A
    BB = B - centroid_B

    # 3x3 cross-covariance of the centred point sets
    H = BB.T @ AA
    if scale:
        H = H / N
    U, S, Vt = np.linalg.svd(H)
    R = Vt.T @ U.T

    # special reflection case: flip the axis of the smallest singular value
    # so that R is a proper rotation (det = +1)
    if np.linalg.det(R) < 0:
        print("Reflection detected")
        Vt[-1, :] *= -1
        R = Vt.T @ U.T

    if scale:
        varA = np.var(A, axis=0).sum()
        c = varA / np.sum(S)  # scale factor
        t = -R @ (centroid_B.T * c) + centroid_A.T
    else:
        c = 1
        t = -R @ centroid_B.T + centroid_A.T
    return c, R, t
def find_free_port():
    """Ask the OS for a currently unused TCP port and return it as a string.

    The probe socket is closed before returning, so the port is only known
    to be free at the moment of the check. ``SO_REUSEADDR`` is set after
    ``bind``, so it does not influence which port is chosen.
    """
    import socket

    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(('', 0))  # port 0: let the kernel pick a free port
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        port = sock.getsockname()[1]
    return str(port)
def extract(a, t, x_shape):
    """Pick ``a[t]`` for each batch element, shaped to broadcast against ``x_shape``.

    Args:
        a: tensor of per-timestep values, gathered along its last axis. It
            must live on the CPU, because the indices are moved there
            (typically 1-D, e.g. a diffusion schedule; confirm against
            callers).
        t: index tensor with the batch on dimension 0.
        x_shape: shape of the tensor the result will be broadcast against.

    Returns:
        Tensor of shape ``(batch, 1, ..., 1)`` with ``len(x_shape)`` dims,
        placed on ``t``'s device.
    """
    n = t.shape[0]
    picked = a.gather(-1, t.cpu())
    target_shape = (n,) + (1,) * (len(x_shape) - 1)
    return picked.reshape(target_shape).to(t.device)
def linear_beta_schedule(timesteps):
    """Return ``timesteps`` betas spaced linearly from 1e-4 to 0.02 (inclusive).

    The endpoints are presumably the DDPM (Ho et al., 2020) linear-schedule
    defaults.
    """
    return torch.linspace(start=1e-4, end=0.02, steps=timesteps)
def init_model(model_cfg, device, eval, load_state_dict=False):
    """Build a model from its Hydra config and prepare it for eval or training.

    NOTE(review): in the copy of this file under review, three pieces were
    truncated and are reconstructed below with minimal plausible code:
    - the ``.to(device)`` assignment;
    - the tail of the ``DistributedDataParallel(...)`` call (the original
      may have passed extra kwargs such as ``find_unused_parameters``);
    - the body of the ``load_state_dict`` branch.
    Compare against the upstream repository.

    Args:
        model_cfg: Hydra config passed to ``hydra.utils.instantiate``. Its
            ``ckpt`` field is read when ``eval`` or ``load_state_dict`` is
            true.
        device: target device, forwarded to ``load_state_dict_eval`` in the
            eval branch. In the training branch it is also the single entry
            of DDP ``device_ids`` (presumably the local CUDA rank; confirm
            against callers).
        eval: if true, load ``model_cfg.ckpt`` via ``load_state_dict_eval``
            and return the unwrapped model. (Shadows the builtin ``eval``;
            kept for caller compatibility.)
        load_state_dict: training branch only. If true, also load
            ``model_cfg.ckpt`` into the DDP-wrapped model.

    Returns:
        The bare model when ``eval`` is true, otherwise the model wrapped in
        ``DistributedDataParallel``.
    """
    model = hydra.utils.instantiate(model_cfg)
    if eval:
        load_state_dict_eval(model, model_cfg.ckpt, device=device)
    else:
        model = model.to(device)
        # broadcast_buffers=False: DDP does not re-broadcast module buffers
        # from rank 0 at the start of each forward pass.
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[device],
            broadcast_buffers=False,
        )
        if load_state_dict:
            # NOTE(review): reconstructed. Keys are loaded into the DDP
            # wrapper as-is, so the checkpoint is expected to carry the
            # 'module.' prefix (i.e. saved from a DDP model); confirm.
            # Loading to CPU first lets load_state_dict copy into the
            # parameters already on `device`.
            state_dict = torch.load(model_cfg.ckpt, map_location='cpu')
            model.load_state_dict(state_dict)
    return model
def load_state_dict_eval(model, state_dict_path, map_location='cuda:0', device='cuda'):
    """Load a (possibly DDP-saved) checkpoint into ``model`` for inference.

    The previous version renamed the keys but never applied them to
    ``model`` and ignored ``device``. It also used ``str.replace``, which
    mangled any key containing ``'module.'`` mid-name (e.g.
    ``'submodule.weight'`` became ``'subweight'``). Now only the leading
    ``'module.'`` prefix(es) added by ``DistributedDataParallel`` are
    stripped.

    Args:
        model: ``torch.nn.Module`` to receive the weights (loaded strictly).
        state_dict_path: path to a file saved with ``torch.save(state_dict)``.
        map_location: forwarded to ``torch.load``.
        device: device the model is moved to after loading.

    Returns:
        ``model``, on ``device`` and in eval mode. Returning it is new;
        existing callers that ignore the result are unaffected.
    """
    state_dict = torch.load(state_dict_path, map_location=map_location)
    prefix = 'module.'
    # Rename in place (rather than building a new dict) so any `_metadata`
    # attribute attached to the loaded state dict is preserved.
    for old_key in list(state_dict.keys()):
        new_key = old_key
        while new_key.startswith(prefix):  # handles nested DDP wrapping too
            new_key = new_key[len(prefix):]
        state_dict[new_key] = state_dict.pop(old_key)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model
class dotDict(dict):
    """Dictionary whose keys can also be read, set and deleted as attributes.

    - Reading an attribute returns ``dict.get(self, name)``, so a missing key
      yields ``None`` instead of raising ``AttributeError``. As a result,
      ``hasattr`` is always true.
    - A value whose exact type is ``dict`` is returned wrapped in a *new*
      ``dotDict`` (a shallow copy). Chained reads such as ``d.a.b`` work, but
      assignments through that copy (``d.a.b = 1``) do not reach ``d``.
    - Attribute assignment and deletion go straight to the dict; deleting a
      missing key raises ``KeyError``.
    """

    def __getattr__(self, name):
        value = dict.get(self, name)
        if type(value) is dict:
            return dotDict(value)
        return value

    def __setattr__(self, name, value):
        dict.__setitem__(self, name, value)

    def __delattr__(self, name):
        dict.__delitem__(self, name)