ECON / lib /net /voxelize.py
Yuliang's picture
Support TEXTure
487ee6d
raw
history blame
7.56 kB
from __future__ import division, print_function
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import voxelize_cuda
from torch.autograd import Function
class VoxelizationFunction(Function):
    """
    Autograd ``Function`` wrapping the ``voxelize_cuda`` semantic voxelization op.

    Only ``forward`` is defined in this class; there is no ``backward``, so
    gradients cannot be propagated through the produced volume.
    Currently implemented only for cuda Tensors.
    """

    @staticmethod
    def forward(
        ctx,
        smpl_vertices,
        smpl_face_center,
        smpl_face_normal,
        smpl_vertex_code,
        smpl_face_code,
        smpl_tetrahedrons,
        volume_res,
        sigma,
        smooth_kernel_size,
    ):
        """
        Run the forward semantic voxelization.

        Args:
            ctx: autograd context; receives the bookkeeping attributes set below.
            smpl_vertices: vertex tensor; dim 0 is the batch, dim 1 the vertex
                count (must match ``smpl_vertex_code``).
            smpl_face_center: only shape-checked here, not passed to the kernel.
            smpl_face_normal: only shape-checked here, not passed to the kernel.
            smpl_vertex_code: per-vertex code forwarded to the kernel.
            smpl_face_code: only shape-checked here, not passed to the kernel.
            smpl_tetrahedrons: tetrahedron vertex positions forwarded to the kernel.
            volume_res: edge length (in voxels) of the cubic output volume.
            sigma: forwarded unchanged to the kernel.
            smooth_kernel_size: stored on ``ctx`` only; unused in this pass.

        Returns:
            The semantic volume returned by the kernel.
            Output format: (batch_size, z_dims, y_dims, x_dims, channel_num)
        """
        assert smpl_vertices.size()[1] == smpl_vertex_code.size()[1]
        assert smpl_face_center.size()[1] == smpl_face_normal.size()[1]
        assert smpl_face_center.size()[1] == smpl_face_code.size()[1]

        ctx.batch_size = smpl_vertices.size()[0]
        ctx.volume_res = volume_res
        ctx.sigma = sigma
        ctx.smooth_kernel_size = smooth_kernel_size
        ctx.smpl_vertex_num = smpl_vertices.size()[1]
        ctx.device = smpl_vertices.device

        # The extension receives these tensors directly, so hand it contiguous
        # storage.
        smpl_vertices = smpl_vertices.contiguous()
        smpl_face_center = smpl_face_center.contiguous()
        smpl_face_normal = smpl_face_normal.contiguous()
        smpl_vertex_code = smpl_vertex_code.contiguous()
        smpl_face_code = smpl_face_code.contiguous()
        smpl_tetrahedrons = smpl_tetrahedrons.contiguous()

        # Allocate the output buffers on the same device as the inputs. The
        # legacy ``torch.cuda.FloatTensor(...)`` constructor always used the
        # *current* CUDA device, which mismatches inputs living on another GPU.
        grid_shape = (ctx.batch_size, ctx.volume_res, ctx.volume_res, ctx.volume_res)
        buffer_opts = {"dtype": torch.float32, "device": ctx.device}

        # occ_volume [B, volume_res, volume_res, volume_res]
        occ_volume = torch.zeros(grid_shape, **buffer_opts)
        # semantic_volume [B, volume_res, volume_res, volume_res, 3]
        semantic_volume = torch.zeros(grid_shape + (3, ), **buffer_opts)
        # weight_sum_volume [B, volume_res, volume_res, volume_res]
        # Starts at 1e-3 rather than 0 (presumably to avoid a zero divisor when
        # the kernel normalizes by the weight sum -- confirm in the extension).
        weight_sum_volume = torch.full(grid_shape, 1e-3, **buffer_opts)

        (
            occ_volume,
            semantic_volume,
            weight_sum_volume,
        ) = voxelize_cuda.forward_semantic_voxelization(
            smpl_vertices,
            smpl_vertex_code,
            smpl_tetrahedrons,
            occ_volume,
            semantic_volume,
            weight_sum_volume,
            sigma,
        )

        # Only the semantic volume is exposed; the occupancy and weight-sum
        # volumes are discarded.
        return semantic_volume
class Voxelization(nn.Module):
    """
    Wrapper around the autograd function VoxelizationFunction.

    ``update_param`` must be called before ``forward``: it registers the
    batched buffers (vertex/face codes, face and tetrahedron indices) that
    ``forward`` and the gather helpers read.
    """

    def __init__(
        self,
        smpl_vertex_code,
        smpl_face_code,
        smpl_face_indices,
        smpl_tetraderon_indices,
        volume_res,
        sigma,
        smooth_kernel_size,
        batch_size,
    ):
        """
        Store the un-batched mesh description and voxelization settings.

        Args:
            smpl_vertex_code: per-vertex code, tiled ``batch_size`` times later.
            smpl_face_code: per-face code, tiled ``batch_size`` times later.
            smpl_face_indices: 2-D index array with 3 columns (triangles).
            smpl_tetraderon_indices: 2-D index array with 4 columns (tetrahedra).
            volume_res: forwarded to ``VoxelizationFunction``.
            sigma: forwarded to ``VoxelizationFunction``.
            smooth_kernel_size: forwarded to ``VoxelizationFunction``.
            batch_size: number of copies made when tiling in ``update_param``.
        """
        super().__init__()
        assert len(smpl_face_indices.shape) == 2
        assert len(smpl_tetraderon_indices.shape) == 2
        assert smpl_face_indices.shape[1] == 3
        assert smpl_tetraderon_indices.shape[1] == 4

        self.volume_res = volume_res
        self.sigma = sigma
        self.smooth_kernel_size = smooth_kernel_size
        self.batch_size = batch_size
        self.device = None    # set by update_param

        self.smpl_vertex_code = smpl_vertex_code
        self.smpl_face_code = smpl_face_code
        self.smpl_face_indices = smpl_face_indices
        self.smpl_tetraderon_indices = smpl_tetraderon_indices

    def update_param(self, voxel_faces):
        """
        (Re)build the batched buffers on the device of ``voxel_faces``.

        ``voxel_faces`` replaces the stored tetrahedron indices and is
        registered as-is (it is not tiled here, so it is presumably already
        batched -- confirm with the caller). The codes and face indices are
        tiled ``batch_size`` times along a new leading dimension.
        """
        self.device = voxel_faces.device
        self.smpl_tetraderon_indices = voxel_faces

        repeats = (self.batch_size, 1, 1)
        batched = {
            "smpl_vertex_code_batch": torch.tile(self.smpl_vertex_code, repeats),
            "smpl_face_code_batch": torch.tile(self.smpl_face_code, repeats),
            "smpl_face_indices_batch": torch.tile(self.smpl_face_indices, repeats),
            "smpl_tetraderon_indices_batch": self.smpl_tetraderon_indices,
        }
        for buffer_name, tensor in batched.items():
            self.register_buffer(buffer_name, tensor.contiguous().to(self.device))

    def forward(self, smpl_vertices):
        """
        Generate semantic volumes from SMPL vertices.

        Faces are gathered from all of ``smpl_vertices``; only the leading
        vertices (as many as there are vertex codes) are handed to the
        voxelization function as surface vertices.

        Returns:
            The volume permuted from (b, z, y, x, c) to (b, c, d, h, w).
        """
        self.check_input(smpl_vertices)
        smpl_faces = self.vertices_to_faces(smpl_vertices)
        smpl_tetrahedrons = self.vertices_to_tetrahedrons(smpl_vertices)
        smpl_face_center = self.calc_face_centers(smpl_faces)
        smpl_face_normal = self.calc_face_normals(smpl_faces)
        smpl_surface_vertex_num = self.smpl_vertex_code_batch.size()[1]
        smpl_vertices_surface = smpl_vertices[:, :smpl_surface_vertex_num, :]
        vol = VoxelizationFunction.apply(
            smpl_vertices_surface,
            smpl_face_center,
            smpl_face_normal,
            self.smpl_vertex_code_batch,
            self.smpl_face_code_batch,
            smpl_tetrahedrons,
            self.volume_res,
            self.sigma,
            self.smooth_kernel_size,
        )
        return vol.permute((0, 4, 1, 2, 3))    # (bzyxc --> bcdhw)

    @staticmethod
    def _gather_by_index(vertices, indices):
        """Gather ``vertices[b, indices[b]]`` for every batch item ``b``."""
        assert vertices.ndimension() == 3
        bs, nv = vertices.shape[:2]
        # Offset each batch item's indices into the flattened (bs * nv, 3)
        # vertex table. The offsets are created on the index buffer's own
        # device so they stay valid even if the module was moved after
        # update_param (self.device would be stale in that case).
        offsets = torch.arange(bs, dtype=torch.int64, device=indices.device) * nv
        flat_indices = indices.long() + offsets[:, None, None]
        return vertices.reshape((bs * nv, 3))[flat_indices]

    def vertices_to_faces(self, vertices):
        """Return per-face vertex positions, shape (bs, nf, 3, 3)."""
        return self._gather_by_index(vertices, self.smpl_face_indices_batch)

    def vertices_to_tetrahedrons(self, vertices):
        """Return per-tetrahedron vertex positions, shape (bs, nt, 4, 3)."""
        return self._gather_by_index(vertices, self.smpl_tetraderon_indices_batch)

    def calc_face_centers(self, face_verts):
        """Return the centroid of each triangle; (bs, nf, 3, 3) -> (bs, nf, 3)."""
        assert len(face_verts.shape) == 4
        assert face_verts.shape[2] == 3
        assert face_verts.shape[3] == 3
        bs, nf = face_verts.shape[:2]
        face_centers = (
            face_verts[:, :, 0, :] + face_verts[:, :, 1, :] + face_verts[:, :, 2, :]
        ) / 3.0
        return face_centers.reshape((bs, nf, 3))

    def calc_face_normals(self, face_verts):
        """
        Return the unit normal of each triangle; (bs, nf, 3, 3) -> (bs, nf, 3).

        The normal is ``normalize((v0 - v1) x (v2 - v1))``.
        """
        assert len(face_verts.shape) == 4
        assert face_verts.shape[2] == 3
        assert face_verts.shape[3] == 3
        bs, nf = face_verts.shape[:2]
        face_verts = face_verts.reshape((bs * nf, 3, 3))
        v10 = face_verts[:, 0] - face_verts[:, 1]
        v12 = face_verts[:, 2] - face_verts[:, 1]
        # dim must be explicit: without it torch.cross picks the first
        # dimension of size 3, which is the face axis whenever bs * nf == 3.
        normals = F.normalize(torch.cross(v10, v12, dim=-1), dim=1, eps=1e-5)
        return normals.reshape((bs, nf, 3))

    def check_input(self, x):
        """
        Validate that ``x`` is a float32 CUDA tensor.

        Raises:
            TypeError: if ``x`` is not on a CUDA device or is not float32.
        """
        # A torch.device must not be compared against a plain string; use the
        # tensor's own flags so the right message is raised.
        if not x.is_cuda:
            raise TypeError("Voxelization module supports only cuda tensors")
        if x.dtype != torch.float32:
            raise TypeError("Voxelization module supports only float32 tensors")