import abc
import os
from pathlib import Path
import cv2
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import OmegaConf
from skimage.io import imread, imsave
from PIL import Image
from torch.optim.lr_scheduler import LambdaLR
from renderer.neus_networks import SDFNetwork, RenderingNetwork, SingleVarianceNetwork, SDFHashGridNetwork, RenderingFFNetwork
from renderer.ngp_renderer import NGPNetwork
from util import instantiate_from_config, read_pickle, concat_images_list
DEFAULT_RADIUS = np.sqrt(3)/2  # radius of the sphere circumscribing the unit cube
DEFAULT_SIDE_LENGTH = 0.6      # half side length of the inner bounding box
def sample_pdf(bins, weights, n_samples, det=True):
device = bins.device
dtype = bins.dtype
# This implementation is from NeRF
# Get pdf
weights = weights + 1e-5 # prevent nans
pdf = weights / torch.sum(weights, -1, keepdim=True)
cdf = torch.cumsum(pdf, -1)
cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)
# Take uniform samples
if det:
u = torch.linspace(0. + 0.5 / n_samples, 1. - 0.5 / n_samples, steps=n_samples, dtype=dtype, device=device)
u = u.expand(list(cdf.shape[:-1]) + [n_samples])
else:
u = torch.rand(list(cdf.shape[:-1]) + [n_samples], dtype=dtype, device=device)
# Invert CDF
u = u.contiguous()
inds = torch.searchsorted(cdf, u, right=True)
below = torch.max(torch.zeros_like(inds - 1), inds - 1)
above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2)
matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
denom = (cdf_g[..., 1] - cdf_g[..., 0])
denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
t = (u - cdf_g[..., 0]) / denom
samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
return samples
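# Usage sketch (illustrative only, not part of the pipeline): with uniform weights
# the inverse-CDF sampling above reduces to evenly spaced samples.
#   bins = torch.linspace(0., 1., 9).expand(2, 9)      # 2 rays, 8 intervals
#   weights = torch.ones(2, 8)                         # uniform pdf -> linear cdf
#   sample_pdf(bins, weights, n_samples=4, det=True)   # ~[0.125, 0.375, 0.625, 0.875] per ray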
def near_far_from_sphere(rays_o, rays_d, radius=DEFAULT_RADIUS):
a = torch.sum(rays_d ** 2, dim=-1, keepdim=True)
b = torch.sum(rays_o * rays_d, dim=-1, keepdim=True)
mid = -b / a
near = mid - radius
far = mid + radius
return near, far
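# Derivation note: along x(t) = rays_o + t * rays_d the squared distance to the origin
# is |o|^2 + 2t(o.d) + t^2|d|^2, minimized at t_mid = -(o.d)/|d|^2 = -b/a. For
# unit-length rays_d the true sphere hits are t_mid +- sqrt(r^2 - m^2) with m the
# closest approach, so the looser interval [t_mid - r, t_mid + r] used here always
# brackets the sphere.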
class BackgroundRemoval:
def __init__(self, device='cuda'):
from carvekit.api.high import HiInterface
self.interface = HiInterface(
object_type="object", # Can be "object" or "hairs-like".
batch_size_seg=5,
batch_size_matting=1,
device=device,
seg_mask_size=640, # Use 640 for Tracer B7 and 320 for U2Net
matting_mask_size=2048,
trimap_prob_threshold=231,
trimap_dilation=30,
trimap_erosion_iters=5,
fp16=True,
)
@torch.no_grad()
def __call__(self, image):
# image: [H, W, 3] array in [0, 255].
image = Image.fromarray(image)
image = self.interface([image])[0]
image = np.array(image)
return image
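# Usage sketch (assumes carvekit is installed; 'input.png' is a hypothetical path):
#   remover = BackgroundRemoval(device='cuda')
#   rgb = imread('input.png')        # H,W,3 uint8
#   rgba = remover(rgb)              # H,W,4 uint8; the alpha channel holds the matte
#   mask = rgba[:, :, 3].astype(np.float32) / 255.0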
class BaseRenderer(nn.Module):
def __init__(self, train_batch_num, test_batch_num):
super().__init__()
self.train_batch_num = train_batch_num
self.test_batch_num = test_batch_num
@abc.abstractmethod
def render_impl(self, ray_batch, is_train, step):
pass
@abc.abstractmethod
def render_with_loss(self, ray_batch, is_train, step):
pass
def render(self, ray_batch, is_train, step):
batch_num = self.train_batch_num if is_train else self.test_batch_num
ray_num = ray_batch['rays_o'].shape[0]
outputs = {}
for ri in range(0, ray_num, batch_num):
cur_ray_batch = {}
for k, v in ray_batch.items():
cur_ray_batch[k] = v[ri:ri + batch_num]
cur_outputs = self.render_impl(cur_ray_batch, is_train, step)
for k, v in cur_outputs.items():
if k not in outputs: outputs[k] = []
outputs[k].append(v)
for k, v in outputs.items():
outputs[k] = torch.cat(v, 0)
return outputs
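# Chunking sketch (o, d below are placeholder ray tensors): render() never evaluates
# all rays at once. A batch of 100_000 rays with test_batch_num=8192 is processed as
# 13 slices whose per-key outputs are concatenated back, so subclasses only need to
# implement render_impl on a chunk:
#   out = renderer.render({'rays_o': o, 'rays_d': d}, is_train=False, step=0)
#   out['rgb'].shape   # (100_000, 3), matching the leading dim of the input rays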
class NeuSRenderer(BaseRenderer):
def __init__(self, train_batch_num, test_batch_num, lambda_eikonal_loss=0.1, use_mask=True,
lambda_rgb_loss=1.0, lambda_mask_loss=0.0, rgb_loss='soft_l1', coarse_sn=64, fine_sn=64):
super().__init__(train_batch_num, test_batch_num)
self.n_samples = coarse_sn
self.n_importance = fine_sn
self.up_sample_steps = 4
self.anneal_end = 200
self.use_mask = use_mask
self.lambda_eikonal_loss = lambda_eikonal_loss
self.lambda_rgb_loss = lambda_rgb_loss
self.lambda_mask_loss = lambda_mask_loss
self.rgb_loss = rgb_loss
self.sdf_network = SDFNetwork(d_out=257, d_in=3, d_hidden=256, n_layers=8, skip_in=[4], multires=6, bias=0.5, scale=1.0, geometric_init=True, weight_norm=True)
self.color_network = RenderingNetwork(d_feature=256, d_in=9, d_out=3, d_hidden=256, n_layers=4, weight_norm=True, multires_view=4, squeeze_out=True)
self.default_dtype = torch.float32
self.deviation_network = SingleVarianceNetwork(0.3)
@torch.no_grad()
def get_vertex_colors(self, vertices):
"""
@param vertices: n,3
@return:
"""
V = vertices.shape[0]
bn = 20480
verts_colors = []
with torch.no_grad():
for vi in range(0, V, bn):
verts = torch.from_numpy(vertices[vi:vi+bn].astype(np.float32)).cuda()
feats = self.sdf_network(verts)[..., 1:]
gradients = self.sdf_network.gradient(verts) # ...,3
gradients = F.normalize(gradients, dim=-1)
colors = self.color_network(verts, gradients, gradients, feats)
colors = torch.clamp(colors,min=0,max=1).cpu().numpy()
verts_colors.append(colors)
verts_colors = (np.concatenate(verts_colors, 0)*255).astype(np.uint8)
return verts_colors
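    # Usage sketch (`vertices` would typically come from marching cubes on the SDF):
    #   colors = renderer.get_vertex_colors(vertices)  # (n,3) uint8, queried in chunks of 20480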
def upsample(self, rays_o, rays_d, z_vals, sdf, n_importance, inv_s):
"""
        Up-sample depth values given a fixed inv_s.
"""
device = rays_o.device
batch_size, n_samples = z_vals.shape
pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None] # n_rays, n_samples, 3
inner_mask = self.get_inner_mask(pts)
# radius = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=False)
inside_sphere = inner_mask[:, :-1] | inner_mask[:, 1:]
sdf = sdf.reshape(batch_size, n_samples)
prev_sdf, next_sdf = sdf[:, :-1], sdf[:, 1:]
prev_z_vals, next_z_vals = z_vals[:, :-1], z_vals[:, 1:]
mid_sdf = (prev_sdf + next_sdf) * 0.5
cos_val = (next_sdf - prev_sdf) / (next_z_vals - prev_z_vals + 1e-5)
prev_cos_val = torch.cat([torch.zeros([batch_size, 1], dtype=self.default_dtype, device=device), cos_val[:, :-1]], dim=-1)
cos_val = torch.stack([prev_cos_val, cos_val], dim=-1)
cos_val, _ = torch.min(cos_val, dim=-1, keepdim=False)
cos_val = cos_val.clip(-1e3, 0.0) * inside_sphere
dist = (next_z_vals - prev_z_vals)
prev_esti_sdf = mid_sdf - cos_val * dist * 0.5
next_esti_sdf = mid_sdf + cos_val * dist * 0.5
prev_cdf = torch.sigmoid(prev_esti_sdf * inv_s)
next_cdf = torch.sigmoid(next_esti_sdf * inv_s)
alpha = (prev_cdf - next_cdf + 1e-5) / (prev_cdf + 1e-5)
weights = alpha * torch.cumprod(
torch.cat([torch.ones([batch_size, 1], dtype=self.default_dtype, device=device), 1. - alpha + 1e-7], -1), -1)[:, :-1]
z_samples = sample_pdf(z_vals, weights, n_importance, det=True).detach()
return z_samples
def cat_z_vals(self, rays_o, rays_d, z_vals, new_z_vals, sdf, last=False):
batch_size, n_samples = z_vals.shape
_, n_importance = new_z_vals.shape
pts = rays_o[:, None, :] + rays_d[:, None, :] * new_z_vals[..., :, None]
z_vals = torch.cat([z_vals, new_z_vals], dim=-1)
z_vals, index = torch.sort(z_vals, dim=-1)
if not last:
device = pts.device
new_sdf = self.sdf_network.sdf(pts.reshape(-1, 3)).reshape(batch_size, n_importance)
sdf = torch.cat([sdf, new_sdf], dim=-1)
xx = torch.arange(batch_size)[:, None].expand(batch_size, n_samples + n_importance).reshape(-1).to(device)
index = index.reshape(-1)
sdf = sdf[(xx, index)].reshape(batch_size, n_samples + n_importance)
return z_vals, sdf
def sample_depth(self, rays_o, rays_d, near, far, perturb):
n_samples = self.n_samples
n_importance = self.n_importance
up_sample_steps = self.up_sample_steps
device = rays_o.device
# sample points
batch_size = len(rays_o)
z_vals = torch.linspace(0.0, 1.0, n_samples, dtype=self.default_dtype, device=device) # sn
z_vals = near + (far - near) * z_vals[None, :] # rn,sn
if perturb > 0:
t_rand = (torch.rand([batch_size, 1]).to(device) - 0.5)
z_vals = z_vals + t_rand * 2.0 / n_samples
# Up sample
with torch.no_grad():
pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None]
sdf = self.sdf_network.sdf(pts).reshape(batch_size, n_samples)
for i in range(up_sample_steps):
rn, sn = z_vals.shape
inv_s = torch.ones(rn, sn - 1, dtype=self.default_dtype, device=device) * 64 * 2 ** i
new_z_vals = self.upsample(rays_o, rays_d, z_vals, sdf, n_importance // up_sample_steps, inv_s)
z_vals, sdf = self.cat_z_vals(rays_o, rays_d, z_vals, new_z_vals, sdf, last=(i + 1 == up_sample_steps))
return z_vals
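    # Budget note: with the defaults (coarse_sn=64, fine_sn=64, up_sample_steps=4),
    # each ray gets 64 uniform samples plus 4 rounds of 64//4 = 16 importance samples
    # at progressively sharper inv_s = 64, 128, 256, 512, ending with 128 sorted
    # depth values per ray.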
def compute_sdf_alpha(self, points, dists, dirs, cos_anneal_ratio, step):
# points [...,3] dists [...] dirs[...,3]
sdf_nn_output = self.sdf_network(points)
sdf = sdf_nn_output[..., 0]
feature_vector = sdf_nn_output[..., 1:]
gradients = self.sdf_network.gradient(points) # ...,3
inv_s = self.deviation_network(points).clip(1e-6, 1e6) # ...,1
inv_s = inv_s[..., 0]
true_cos = (dirs * gradients).sum(-1) # [...]
iter_cos = -(F.relu(-true_cos * 0.5 + 0.5) * (1.0 - cos_anneal_ratio) +
F.relu(-true_cos) * cos_anneal_ratio) # always non-positive
# Estimate signed distances at section points
estimated_next_sdf = sdf + iter_cos * dists * 0.5
estimated_prev_sdf = sdf - iter_cos * dists * 0.5
prev_cdf = torch.sigmoid(estimated_prev_sdf * inv_s)
next_cdf = torch.sigmoid(estimated_next_sdf * inv_s)
p = prev_cdf - next_cdf
c = prev_cdf
alpha = ((p + 1e-5) / (c + 1e-5)).clip(0.0, 1.0) # [...]
return alpha, gradients, feature_vector, inv_s, sdf
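    # NeuS opacity recap: with Phi_s(x) = sigmoid(inv_s * x), the section opacity is
    #   alpha_i = clip((Phi_s(sdf_prev) - Phi_s(sdf_next)) / Phi_s(sdf_prev), 0, 1)
    # where sdf_prev/sdf_next are estimated from the mid-point sdf and the (annealed)
    # cosine between the ray and the SDF gradient, exactly as computed above.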
def get_anneal_val(self, step):
if self.anneal_end < 0:
return 1.0
else:
return np.min([1.0, step / self.anneal_end])
def get_inner_mask(self, points):
return torch.sum(torch.abs(points)<=DEFAULT_SIDE_LENGTH,-1)==3
def render_impl(self, ray_batch, is_train, step):
near, far = near_far_from_sphere(ray_batch['rays_o'], ray_batch['rays_d'])
rays_o, rays_d = ray_batch['rays_o'], ray_batch['rays_d']
z_vals = self.sample_depth(rays_o, rays_d, near, far, is_train)
batch_size, n_samples = z_vals.shape
# section length in original space
dists = z_vals[..., 1:] - z_vals[..., :-1] # rn,sn-1
dists = torch.cat([dists, dists[..., -1:]], -1) # rn,sn
mid_z_vals = z_vals + dists * 0.5
points = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * mid_z_vals.unsqueeze(-1) # rn, sn, 3
inner_mask = self.get_inner_mask(points)
dirs = rays_d.unsqueeze(-2).expand(batch_size, n_samples, 3)
dirs = F.normalize(dirs, dim=-1)
device = rays_o.device
        alpha = torch.zeros(batch_size, n_samples, dtype=self.default_dtype, device=device)
        sampled_color = torch.zeros(batch_size, n_samples, 3, dtype=self.default_dtype, device=device)
        gradient_error = torch.zeros(batch_size, n_samples, dtype=self.default_dtype, device=device)
        normal = torch.zeros(batch_size, n_samples, 3, dtype=self.default_dtype, device=device)
if torch.sum(inner_mask) > 0:
cos_anneal_ratio = self.get_anneal_val(step) if is_train else 1.0
alpha[inner_mask], gradients, feature_vector, inv_s, sdf = self.compute_sdf_alpha(points[inner_mask], dists[inner_mask], dirs[inner_mask], cos_anneal_ratio, step)
sampled_color[inner_mask] = self.color_network(points[inner_mask], gradients, -dirs[inner_mask], feature_vector)
# Eikonal loss
gradient_error[inner_mask] = (torch.linalg.norm(gradients, ord=2, dim=-1) - 1.0) ** 2 # rn,sn
normal[inner_mask] = F.normalize(gradients, dim=-1)
weights = alpha * torch.cumprod(torch.cat([torch.ones([batch_size, 1], dtype=self.default_dtype, device=device), 1. - alpha + 1e-7], -1), -1)[..., :-1] # rn,sn
mask = torch.sum(weights,dim=1).unsqueeze(-1) # rn,1
color = (sampled_color * weights[..., None]).sum(dim=1) + (1 - mask) # add white background
normal = (normal * weights[..., None]).sum(dim=1)
outputs = {
'rgb': color, # rn,3
'gradient_error': gradient_error, # rn,sn
'inner_mask': inner_mask, # rn,sn
'normal': normal, # rn,3
'mask': mask, # rn,1
}
return outputs
def render_with_loss(self, ray_batch, is_train, step):
render_outputs = self.render(ray_batch, is_train, step)
rgb_gt = ray_batch['rgb']
rgb_pr = render_outputs['rgb']
if self.rgb_loss == 'soft_l1':
epsilon = 0.001
rgb_loss = torch.sqrt(torch.sum((rgb_gt - rgb_pr) ** 2, dim=-1) + epsilon)
        elif self.rgb_loss == 'mse':
rgb_loss = F.mse_loss(rgb_pr, rgb_gt, reduction='none')
else:
raise NotImplementedError
rgb_loss = torch.mean(rgb_loss)
eikonal_loss = torch.sum(render_outputs['gradient_error'] * render_outputs['inner_mask']) / torch.sum(render_outputs['inner_mask'] + 1e-5)
loss = rgb_loss * self.lambda_rgb_loss + eikonal_loss * self.lambda_eikonal_loss
loss_batch = {
'eikonal': eikonal_loss,
'rendering': rgb_loss,
# 'mask': mask_loss,
}
if self.lambda_mask_loss>0 and self.use_mask:
mask_loss = F.mse_loss(render_outputs['mask'], ray_batch['mask'], reduction='none').mean()
loss += mask_loss * self.lambda_mask_loss
loss_batch['mask'] = mask_loss
return loss, loss_batch
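# Loss note: 'soft_l1' is a Charbonnier-style robust loss sqrt(|rgb_gt - rgb_pr|^2 + eps)
# with eps = 1e-3; e.g. a per-channel error of 0.1 yields sqrt(3*0.01 + 0.001) ~= 0.176.
# It behaves quadratically near zero and linearly for large errors, so outlier pixels
# dominate the gradient less than under plain MSE.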
class NeRFRenderer(BaseRenderer):
def __init__(self, train_batch_num, test_batch_num, bound=0.5, use_mask=False, lambda_rgb_loss=1.0, lambda_mask_loss=0.0):
super().__init__(train_batch_num, test_batch_num)
self.train_batch_num = train_batch_num
self.test_batch_num = test_batch_num
self.use_mask = use_mask
self.field = NGPNetwork(bound=bound)
self.update_interval = 16
self.fp16 = True
self.lambda_rgb_loss = lambda_rgb_loss
self.lambda_mask_loss = lambda_mask_loss
def render_impl(self, ray_batch, is_train, step):
rays_o, rays_d = ray_batch['rays_o'], ray_batch['rays_d']
with torch.cuda.amp.autocast(enabled=self.fp16):
if step % self.update_interval==0:
self.field.update_extra_state()
outputs = self.field.render(rays_o, rays_d,)
renderings={
'rgb': outputs['image'],
'depth': outputs['depth'],
'mask': outputs['weights_sum'].unsqueeze(-1),
}
return renderings
def render_with_loss(self, ray_batch, is_train, step):
render_outputs = self.render(ray_batch, is_train, step)
rgb_gt = ray_batch['rgb']
rgb_pr = render_outputs['rgb']
epsilon = 0.001
rgb_loss = torch.sqrt(torch.sum((rgb_gt - rgb_pr) ** 2, dim=-1) + epsilon)
rgb_loss = torch.mean(rgb_loss)
loss = rgb_loss * self.lambda_rgb_loss
loss_batch = {'rendering': rgb_loss}
if self.use_mask:
mask_loss = F.mse_loss(render_outputs['mask'], ray_batch['mask'], reduction='none')
mask_loss = torch.mean(mask_loss)
loss = loss + mask_loss * self.lambda_mask_loss
loss_batch['mask'] = mask_loss
return loss, loss_batch
def cartesian_to_spherical(xyz):
    ptsnew = np.hstack((xyz, np.zeros(xyz.shape)))  # unused below; kept for reference
xy = xyz[:, 0] ** 2 + xyz[:, 1] ** 2
z = np.sqrt(xy + xyz[:, 2] ** 2)
theta = np.arctan2(np.sqrt(xy), xyz[:, 2]) # for elevation angle defined from Z-axis down
# ptsnew[:,4] = np.arctan2(xyz[:,2], np.sqrt(xy)) # for elevation angle defined from XY-plane up
azimuth = np.arctan2(xyz[:, 1], xyz[:, 0])
return np.array([theta, azimuth, z])
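# Worked example: cartesian_to_spherical(np.array([[0., 0., 1.]])) returns
# theta = 0 (on the +Z axis), azimuth = 0, z = 1, while [[1., 0., 0.]] returns
# theta = pi/2, azimuth = 0, z = 1; theta is measured down from +Z rather than up
# from the XY plane (see the commented-out alternative above).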
def get_pose(target_RT):
R, T = target_RT[:3, :3], target_RT[:, -1]
T_target = -R.T @ T
theta_target, azimuth_target, z_target = cartesian_to_spherical(T_target[None, :])
return theta_target, azimuth_target, z_target
class RendererTrainer(pl.LightningModule):
def __init__(self, image_path, data_path, total_steps, warm_up_steps, log_dir, train_batch_fg_num=0,
use_cube_feats=False, cube_ckpt=None, cube_cfg=None, cube_bound=0.5,
train_batch_num=4096, test_batch_num=8192, use_warm_up=True, use_mask=True,
lambda_rgb_loss=1.0, lambda_mask_loss=0.0, renderer='neus',
# used in neus
lambda_eikonal_loss=0.1,
coarse_sn=64, fine_sn=64):
super().__init__()
        self.num_images = 36  # TODO: ours uses 36 fixed views; SyncDreamer uses 16
self.image_size = 256
self.log_dir = log_dir
(Path(log_dir)/'images').mkdir(exist_ok=True, parents=True)
self.train_batch_num = train_batch_num
self.train_batch_fg_num = train_batch_fg_num
self.test_batch_num = test_batch_num
self.image_path = image_path
self.data_path = data_path
self.total_steps = total_steps
self.warm_up_steps = warm_up_steps
self.use_mask = use_mask
self.lambda_eikonal_loss = lambda_eikonal_loss
self.lambda_rgb_loss = lambda_rgb_loss
self.lambda_mask_loss = lambda_mask_loss
self.use_warm_up = use_warm_up
self.use_cube_feats, self.cube_cfg, self.cube_ckpt = use_cube_feats, cube_cfg, cube_ckpt
self._init_dataset()
if renderer=='neus':
self.renderer = NeuSRenderer(train_batch_num, test_batch_num,
lambda_rgb_loss=lambda_rgb_loss,
lambda_eikonal_loss=lambda_eikonal_loss,
lambda_mask_loss=lambda_mask_loss,
coarse_sn=coarse_sn, fine_sn=fine_sn)
elif renderer=='ngp':
self.renderer = NeRFRenderer(train_batch_num, test_batch_num, bound=cube_bound, use_mask=use_mask, lambda_mask_loss=lambda_mask_loss, lambda_rgb_loss=lambda_rgb_loss,)
else:
raise NotImplementedError
self.validation_index = 0
def _construct_ray_batch(self, images_info):
image_num = images_info['images'].shape[0]
_, h, w, _ = images_info['images'].shape
        coords = torch.stack(torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij'), -1)[:, :, (1, 0)] # h,w,2 in (x,y) order
coords = coords.float()[None, :, :, :].repeat(image_num, 1, 1, 1) # imn,h,w,2
coords = coords.reshape(image_num, h * w, 2)
coords = torch.cat([coords, torch.ones(image_num, h * w, 1, dtype=torch.float32)], 2) # imn,h*w,3
# imn,h*w,3 @ imn,3,3 => imn,h*w,3
rays_d = coords @ torch.inverse(images_info['Ks']).permute(0, 2, 1)
poses = images_info['poses'] # imn,3,4
R, t = poses[:, :, :3], poses[:, :, 3:]
rays_d = rays_d @ R
rays_d = F.normalize(rays_d, dim=-1)
rays_o = -R.permute(0,2,1) @ t # imn,3,3 @ imn,3,1
rays_o = rays_o.permute(0, 2, 1).repeat(1, h*w, 1) # imn,h*w,3
ray_batch = {
'rgb': images_info['images'].reshape(image_num*h*w,3),
'mask': images_info['masks'].reshape(image_num*h*w,1),
'rays_o': rays_o.reshape(image_num*h*w,3).float(),
'rays_d': rays_d.reshape(image_num*h*w,3).float(),
}
return ray_batch
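    # Geometry recap for the block above: with a world-to-camera pose [R|t]
    # (x_cam = R @ x_world + t), the camera center is C = -R^T @ t and pixel (u, v)
    # maps to world direction R^T @ K^{-1} @ (u, v, 1)^T; the row-vector products
    # `coords @ inv(K).T` and `rays_d @ R` are exactly these expressions transposed.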
@staticmethod
def load_model(cfg, ckpt):
config = OmegaConf.load(cfg)
model = instantiate_from_config(config.model)
print(f'loading model from {ckpt} ...')
ckpt = torch.load(ckpt)
model.load_state_dict(ckpt['state_dict'])
model = model.cuda().eval()
return model
def _init_dataset(self):
mask_predictor = BackgroundRemoval()
# syncdreamer fixed 16 views
# self.K, self.azs, self.els, self.dists, self.poses = read_pickle(f'meta_info/camera-{self.num_images}.pkl')
# for ours+NeuS, we pre fix 36 views
self.K = np.array([[280.,0.,128.],[0.,280.,128.],[0.,0.,1.]], dtype=np.float32)
data_dir = os.path.join(self.data_path, "mario/render_sync_36_single/model/") # fixed 36 views
# get all files .npy
self.azs = []
self.els = []
self.dists = []
self.poses = []
for index in range(self.num_images):
pose = np.load(os.path.join(data_dir, "%03d.npy"%index))[:3, :] # in blender
self.poses.append(pose)
theta, azimuth, radius = get_pose(pose)
self.azs.append(azimuth)
self.els.append(theta)
self.dists.append(radius)
# stack to numpy along axis 0
        self.azs = np.stack(self.azs, axis=0)      # [num_images,]
        self.els = np.stack(self.els, axis=0)      # [num_images,]
        self.dists = np.stack(self.dists, axis=0)  # [num_images,]
        self.poses = np.stack(self.poses, axis=0)  # [num_images, 3, 4]
self.images_info = {'images': [] ,'masks': [], 'Ks': [], 'poses':[]}
img = imread(self.image_path)
for index in range(self.num_images):
rgb = np.copy(img[:,index*self.image_size:(index+1)*self.image_size,:])
# predict mask
if self.use_mask:
imsave(f'{self.log_dir}/input-{index}.png', rgb)
masked_image = mask_predictor(rgb)
imsave(f'{self.log_dir}/masked-{index}.png', masked_image)
mask = masked_image[:,:,3].astype(np.float32)/255
else:
h, w, _ = rgb.shape
mask = np.zeros([h,w], np.float32)
rgb = rgb.astype(np.float32)/255
K, pose = np.copy(self.K), self.poses[index]
self.images_info['images'].append(torch.from_numpy(rgb.astype(np.float32))) # h,w,3
self.images_info['masks'].append(torch.from_numpy(mask.astype(np.float32))) # h,w
self.images_info['Ks'].append(torch.from_numpy(K.astype(np.float32)))
self.images_info['poses'].append(torch.from_numpy(pose.astype(np.float32)))
for k, v in self.images_info.items(): self.images_info[k] = torch.stack(v, 0) # stack all values
self.train_batch = self._construct_ray_batch(self.images_info)
self.train_batch_pseudo_fg = {}
        # pixels that are not pure white in all channels are treated as pseudo-foreground
        pseudo_fg_mask = torch.sum(self.train_batch['rgb']>0.99,1)!=3
for k, v in self.train_batch.items():
self.train_batch_pseudo_fg[k] = v[pseudo_fg_mask]
self.train_ray_fg_num = int(torch.sum(pseudo_fg_mask).cpu().numpy())
self.train_ray_num = self.num_images * self.image_size ** 2
self._shuffle_train_batch()
self._shuffle_train_fg_batch()
def _shuffle_train_batch(self):
self.train_batch_i = 0
shuffle_idxs = torch.randperm(self.train_ray_num, device='cpu') # shuffle
for k, v in self.train_batch.items():
self.train_batch[k] = v[shuffle_idxs]
def _shuffle_train_fg_batch(self):
self.train_batch_fg_i = 0
shuffle_idxs = torch.randperm(self.train_ray_fg_num, device='cpu') # shuffle
for k, v in self.train_batch_pseudo_fg.items():
self.train_batch_pseudo_fg[k] = v[shuffle_idxs]
def training_step(self, batch, batch_idx):
train_ray_batch = {k: v[self.train_batch_i:self.train_batch_i + self.train_batch_num].cuda() for k, v in self.train_batch.items()}
self.train_batch_i += self.train_batch_num
if self.train_batch_i + self.train_batch_num >= self.train_ray_num: self._shuffle_train_batch()
if self.train_batch_fg_num>0:
train_ray_batch_fg = {k: v[self.train_batch_fg_i:self.train_batch_fg_i+self.train_batch_fg_num].cuda() for k, v in self.train_batch_pseudo_fg.items()}
self.train_batch_fg_i += self.train_batch_fg_num
if self.train_batch_fg_i + self.train_batch_fg_num >= self.train_ray_fg_num: self._shuffle_train_fg_batch()
for k, v in train_ray_batch_fg.items():
train_ray_batch[k] = torch.cat([train_ray_batch[k], v], 0)
loss, loss_batch = self.renderer.render_with_loss(train_ray_batch, is_train=True, step=self.global_step)
self.log_dict(loss_batch, prog_bar=True, logger=True, on_step=True, on_epoch=False, rank_zero_only=True)
self.log('step', self.global_step, prog_bar=True, on_step=True, on_epoch=False, logger=False, rank_zero_only=True)
lr = self.optimizers().param_groups[0]['lr']
self.log('lr', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False, rank_zero_only=True)
return loss
def _slice_images_info(self, index):
return {k:v[index:index+1] for k, v in self.images_info.items()}
@torch.no_grad()
    def validation_step(self, batch, batch_idx):
        if self.global_rank == 0:
            # render one validation view per step, cycling through the input views
            images_info = self._slice_images_info(self.validation_index)
            self.validation_index += 1
            self.validation_index %= self.num_images
            test_ray_batch = self._construct_ray_batch(images_info)
            test_ray_batch = {k: v.cuda() for k, v in test_ray_batch.items()}
            test_ray_batch['near'], test_ray_batch['far'] = near_far_from_sphere(test_ray_batch['rays_o'], test_ray_batch['rays_d'])
            render_outputs = self.renderer.render(test_ray_batch, False, self.global_step)
            process = lambda x: (x.cpu().numpy() * 255).astype(np.uint8)
            h, w = self.image_size, self.image_size
            rgb = torch.clamp(render_outputs['rgb'].reshape(h, w, 3), max=1.0, min=0.0)
            mask = torch.clamp(render_outputs['mask'].reshape(h, w, 1), max=1.0, min=0.0)
            mask_ = torch.repeat_interleave(mask, 3, dim=-1)
            output_image = concat_images_list(process(rgb), process(mask_))
            if 'normal' in render_outputs:
                normal = torch.clamp((render_outputs['normal'].reshape(h, w, 3) + 1) / 2, max=1.0, min=0.0)
                normal = normal * mask  # only show foreground normals
                output_image = concat_images_list(output_image, process(normal))
            # save the concatenated rgb / mask / normal preview
            imsave(f'{self.log_dir}/images/{self.global_step}.jpg', output_image)
def configure_optimizers(self):
        lr = self.learning_rate  # expected to be set externally by the training script
opt = torch.optim.AdamW([{"params": self.renderer.parameters(), "lr": lr},], lr=lr)
def schedule_fn(step):
total_step = self.total_steps
warm_up_step = self.warm_up_steps
warm_up_init = 0.02
warm_up_end = 1.0
final_lr = 0.02
interval = 1000
times = total_step // interval
ratio = np.power(final_lr, 1/times)
if step<warm_up_step:
learning_rate = (step / warm_up_step) * (warm_up_end - warm_up_init) + warm_up_init
else:
learning_rate = ratio ** (step // interval) * warm_up_end
return learning_rate
if self.use_warm_up:
scheduler = [{
'scheduler': LambdaLR(opt, lr_lambda=schedule_fn),
'interval': 'step',
'frequency': 1
}]
else:
scheduler = []
return [opt], scheduler
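# Schedule sketch: the LR multiplier warms up linearly from 0.02 to 1.0 over
# warm_up_steps, then decays stepwise every interval=1000 steps by
# ratio = final_lr ** (1 / (total_steps // 1000)). E.g. total_steps = 2000 gives
# ratio = 0.02 ** 0.5 ~= 0.141, so the multiplier is ~0.141 from step 1000 and
# reaches ~0.02 by the end of training.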