# MCC_slim/app.py
import gradio as gr
import numpy as np
import cv2
from tqdm import tqdm
import torch
from pytorch3d.io.obj_io import load_obj
import tempfile
import main_mcc
import mcc_model
import util.misc as misc
from engine_mcc import prepare_data
from plyfile import PlyData, PlyElement
import trimesh


def run_inference(model, samples, device, temperature, args):
    model.eval()
    seen_xyz, valid_seen_xyz, unseen_xyz, unseen_rgb, labels, seen_images = prepare_data(
        samples, device, is_train=False, args=args, is_viz=True
    )
    pred_occupy = []
    pred_colors = []

    # Query points are fed through the decoder in chunks to bound memory use.
    max_n_unseen_fwd = 2000

    model.cached_enc_feat = None
    num_passes = int(np.ceil(unseen_xyz.shape[1] / max_n_unseen_fwd))
    for p_idx in range(num_passes):
        p_start = p_idx * max_n_unseen_fwd
        p_end = (p_idx + 1) * max_n_unseen_fwd
        cur_unseen_xyz = unseen_xyz[:, p_start:p_end]
        cur_unseen_rgb = unseen_rgb[:, p_start:p_end].zero_()
        cur_labels = labels[:, p_start:p_end].zero_()

        with torch.no_grad():
            _, pred = model(
                seen_images=seen_images,
                seen_xyz=seen_xyz,
                unseen_xyz=cur_unseen_xyz,
                unseen_rgb=cur_unseen_rgb,
                unseen_occupy=cur_labels,
                cache_enc=True,
                valid_seen_xyz=valid_seen_xyz,
            )

        if device == "cuda":
            pred_occupy.append(pred[..., 0].cuda())
        else:
            pred_occupy.append(pred[..., 0].cpu())

        if args.regress_color:
            pred_colors.append(pred[..., 1:].reshape((-1, 3)))
        else:
            # Decode the 256-bin color distribution per channel: softmax over
            # bins (sharpened by `temperature`), then the expectation over bin
            # centers spaced evenly in [0, 1].
            pred_colors.append(
                (
                    torch.nn.Softmax(dim=2)(
                        pred[..., 1:].reshape((-1, 3, 256)) / temperature
                    ) * torch.linspace(0, 1, 256, device=pred.device)
                ).sum(axis=2)
            )

    pred_occupy = torch.cat(pred_occupy, dim=1)
    pred_occupy = torch.nn.Sigmoid()(pred_occupy)
    return torch.cat(pred_colors, dim=0).cpu().numpy(), pred_occupy.cpu().numpy(), unseen_xyz.cpu().numpy()
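
# For scale: infer() below queries 20,000 candidate points, so with
# max_n_unseen_fwd = 2000 run_inference makes ceil(20000 / 2000) = 10 decoder
# passes, reusing the encoder features cached on the first pass (cache_enc=True).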


def pad_image(im, value):
    # Pad the shorter spatial dimension with `value` to make the image square.
    if im.shape[0] > im.shape[1]:
        diff = im.shape[0] - im.shape[1]
        return torch.cat([im, (torch.zeros((im.shape[0], diff, im.shape[2])) + value)], dim=1)
    else:
        diff = im.shape[1] - im.shape[0]
        return torch.cat([im, (torch.zeros((diff, im.shape[1], im.shape[2])) + value)], dim=0)
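
# Worked example: a (300, 200, 3) crop has shape[0] > shape[1], so 100 columns
# of `value` are appended on the right, giving (300, 300, 3). infer() below
# pads XYZ maps with inf and RGB images with 0.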


def backproject_depth_to_pointcloud(depth, rotation=np.eye(3), translation=np.zeros(3)):
    # Take the principal point to be the center of the image
    principal_point = [depth.shape[1] / 2, depth.shape[0] / 2]
    intrinsics = get_intrinsics(depth.shape[0], depth.shape[1], principal_point)

    # Get the depth map shape
    height, width = depth.shape

    # Create a matrix of homogeneous pixel coordinates
    u, v = np.meshgrid(np.arange(width), np.arange(height))
    uv_homogeneous = np.stack((u, v, np.ones_like(u)), axis=-1).reshape(-1, 3)

    # Invert the intrinsic matrix
    inv_intrinsics = np.linalg.inv(intrinsics)

    # Convert depth to the camera coordinate system
    points_cam_homogeneous = np.dot(uv_homogeneous, inv_intrinsics.T) * depth.flatten()[:, np.newaxis]

    # Convert to 3D homogeneous coordinates
    points_cam_homogeneous = np.concatenate((points_cam_homogeneous, np.ones((len(points_cam_homogeneous), 1))), axis=1)

    # Apply the rotation and translation to get the point cloud in world coordinates
    extrinsics = np.hstack((rotation, translation[:, np.newaxis]))
    pointcloud = np.dot(points_cam_homogeneous, extrinsics.T)

    # Flip the y and z axes
    pointcloud[:, 1:] *= -1

    # Reshape the point cloud back to the original depth map shape
    pointcloud = pointcloud[:, :3].reshape(height, width, 3)
    return pointcloud
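
# The backprojection above is the standard pinhole relation
#   X_cam = d(u, v) * K^{-1} [u, v, 1]^T
# followed by the rigid transform [R | t] into world coordinates
# (plus the y/z sign flip).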


# estimate camera intrinsics
def get_intrinsics(H, W, principal_point):
    """
    Intrinsics for a pinhole camera model.
    Assumes a 55-degree field of view and the given principal point.
    (H is unused; the focal length is derived from the width alone.)
    """
    f = 0.5 * W / np.tan(0.5 * 55 * np.pi / 180.0)
    cx, cy = principal_point
    return np.array([[f, 0, cx],
                     [0, f, cy],
                     [0, 0, 1]])
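
# Worked example (not executed): for a 640-pixel-wide image,
#   f = 0.5 * 640 / tan(27.5 deg) ≈ 614.7 px,
# so a 480x640 image with principal point (320, 240) gives
#   K = [[614.7, 0, 320], [0, 614.7, 240], [0, 0, 1]].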


def normalize(seen_xyz):
    # Rescale by the mean per-axis standard deviation of the finite points,
    # then center on their mean; inf entries (masked pixels) stay inf.
    seen_xyz = seen_xyz / (seen_xyz[torch.isfinite(seen_xyz.sum(dim=-1))].var(dim=0) ** 0.5).mean()
    seen_xyz = seen_xyz - seen_xyz[torch.isfinite(seen_xyz.sum(dim=-1))].mean(axis=0)
    return seen_xyz
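
# Effect in brief (a sketch): a cloud whose finite points have per-axis std
# (2, 2, 2) is divided by 2 and shifted to zero mean, so the model always
# sees the object at roughly unit scale regardless of the input depth units.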


def voxel_grid_downsample(points, colors, voxel_size):
    # Compute voxel indices
    voxel_indices = np.floor(points / voxel_size).astype(int)

    # Group points by the voxel they fall into
    unique_voxel_indices, inverse_indices = np.unique(voxel_indices, axis=0, return_inverse=True)

    # Compute the centroid of the points and the average color in each voxel
    centroids = np.empty_like(unique_voxel_indices, dtype=float)
    avg_colors = np.empty((len(unique_voxel_indices), colors.shape[1]), dtype=colors.dtype)
    for i in range(len(unique_voxel_indices)):
        centroids[i] = points[inverse_indices == i].mean(axis=0)
        avg_colors[i] = colors[inverse_indices == i].mean(axis=0)

    # Reverse the color channel order (RGB <-> BGR)
    avg_colors = avg_colors[:, ::-1]
    return centroids, avg_colors
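
# The per-voxel loop above is O(num_voxels * num_points). A vectorized
# alternative using np.add.at (a sketch; hypothetical helper, not wired into
# the app; same inputs and outputs as voxel_grid_downsample):
def voxel_grid_downsample_vectorized(points, colors, voxel_size):
    voxel_indices = np.floor(points / voxel_size).astype(int)
    unique_voxel_indices, inverse_indices = np.unique(voxel_indices, axis=0, return_inverse=True)
    # Accumulate per-voxel sums in one pass, then divide by the point counts.
    counts = np.bincount(inverse_indices).astype(float)[:, None]
    centroids = np.zeros((len(unique_voxel_indices), points.shape[1]))
    np.add.at(centroids, inverse_indices, points)
    centroids /= counts
    avg_colors = np.zeros((len(unique_voxel_indices), colors.shape[1]))
    np.add.at(avg_colors, inverse_indices, colors)
    avg_colors /= counts
    return centroids, avg_colors[:, ::-1]  # match the channel flip above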


def infer(
    image,
    depth_image,
    seg,
    granularity,
    temperature,
):
    args.viz_granularity = granularity

    rgb = image
    depth_image = cv2.imread(depth_image.name, -1)
    depth_image = depth_image.astype(np.float32) / 256
    seen_xyz = backproject_depth_to_pointcloud(depth_image)

    # Reorder color channels (RGB -> BGR)
    seen_rgb = (torch.tensor(rgb).float() / 255)[..., [2, 1, 0]]
    H, W = seen_rgb.shape[:2]
    seen_rgb = torch.nn.functional.interpolate(
        seen_rgb.permute(2, 0, 1)[None],
        size=[H, W],
        mode="bilinear",
        align_corners=False,
    )[0].permute(1, 2, 0)

    # Mask out background pixels by sending them to infinity
    seg = cv2.imread(seg.name, cv2.IMREAD_UNCHANGED)
    mask = torch.tensor(cv2.resize(seg, (W, H))).bool()
    seen_xyz[~mask] = float('inf')

    seen_xyz = torch.tensor(seen_xyz).float()
    seen_xyz = normalize(seen_xyz)

    # Crop to the segmentation's bounding box with a 40-pixel margin
    bottom, right = mask.nonzero().max(dim=0)[0]
    top, left = mask.nonzero().min(dim=0)[0]
    bottom = bottom + 40
    right = right + 40
    top = max(top - 40, 0)
    left = max(left - 40, 0)
    seen_xyz = seen_xyz[top:bottom+1, left:right+1]
    seen_rgb = seen_rgb[top:bottom+1, left:right+1]

    # Pad to square, then resize RGB to 800x800 and XYZ to 112x112
    seen_xyz = pad_image(seen_xyz, float('inf'))
    seen_rgb = pad_image(seen_rgb, 0)
    seen_rgb = torch.nn.functional.interpolate(
        seen_rgb.permute(2, 0, 1)[None],
        size=[800, 800],
        mode="bilinear",
        align_corners=False,
    )
    seen_xyz = torch.nn.functional.interpolate(
        seen_xyz.permute(2, 0, 1)[None],
        size=[112, 112],
        mode="bilinear",
        align_corners=False,
    ).permute(0, 2, 3, 1)

    samples = [
        [seen_xyz, seen_rgb],
        [torch.zeros((20000, 3)), torch.zeros((20000, 3))],
    ]
    pred_colors, pred_occupy, unseen_xyz = run_inference(model, samples, device, temperature, args)

    # Keep only query points the model judges occupied
    _masks = pred_occupy > 0.1
    unseen_xyz = unseen_xyz[_masks]
    pred_colors = pred_colors[None, ...][_masks] * 255

    # Prepare data for PlyElement
    vertex = np.core.records.fromarrays(np.hstack((unseen_xyz, pred_colors)).transpose(),
                                        names='x, y, z, red, green, blue',
                                        formats='f8, f8, f8, u1, u1, u1')

    # Create PlyElement
    element = PlyElement.describe(vertex, 'vertex')

    # Save point cloud data to a temporary file
    with tempfile.NamedTemporaryFile(suffix=".ply", delete=False) as f:
        PlyData([element], text=True).write(f)
        temp_file_name = f.name

    # Perform voxel grid downsampling
    voxel_size = 0.2  # Edge length of the output cubes
    downsampled_xyz, downsampled_colors = voxel_grid_downsample(unseen_xyz, pred_colors, voxel_size)

    # Build one cube mesh per occupied voxel
    meshes = []
    for point, color in zip(downsampled_xyz, downsampled_colors):
        # Create a cube mesh at the given point
        cube = trimesh.creation.box(extents=[voxel_size] * 3)
        cube.apply_translation(point)
        # Assign the average color to the vertices
        cube.visual.vertex_colors = np.hstack([color, 255])  # Set alpha to 255
        meshes.append(cube)

    # Create a temporary file for the mesh
    with tempfile.NamedTemporaryFile(suffix=".obj", delete=False) as f:
        temp_obj_file = f.name
    print(temp_obj_file)

    # Combine all the cubes into a single mesh and save it
    combined = trimesh.util.concatenate(meshes)
    combined.export(temp_obj_file)

    return temp_file_name, temp_obj_file
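
# Typical launch, assuming main_mcc's argument parser exposes MCC's usual
# checkpoint flag (an assumption; check get_args_parser for the exact name):
#   python app.py --resume <path-to-mcc-checkpoint>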


if __name__ == '__main__':
    device = "cpu"
    # device = "cuda" if torch.cuda.is_available() else "cpu"

    parser = main_mcc.get_args_parser()
    parser.set_defaults(eval=True)
    args = parser.parse_args()

    model = mcc_model.get_mcc_model(
        occupancy_weight=1.0,
        rgb_weight=0.01,
        args=args,
    )
    if device == "cuda":
        model = model.cuda()
    misc.load_model(args=args, model_without_ddp=model, optimizer=None, loss_scaler=None)

    demo = gr.Interface(fn=infer,
                        inputs=[gr.Image(label="Input Image"),
                                gr.File(label="Depth Image"),
                                gr.File(label="Segmentation File"),
                                gr.Slider(minimum=0.05, maximum=0.5, step=0.05, value=0.2, label="Grain Size"),
                                gr.Slider(minimum=0, maximum=1.0, step=0.1, value=0.1, label="Color Temperature")],
                        outputs=[gr.File(label="Point Cloud"),
                                 gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], label="3D Model")],
                        examples=[["demo/quest2.jpg", "demo/quest2_depth.png", "demo/quest2_seg.png", 0.2, 0.1]],
                        cache_examples=True)
    demo.launch(server_name="0.0.0.0", server_port=7860)