import tempfile
import os
import traceback
# NOTE(review): `spaces` is kept ahead of torch, as in the original ordering;
# presumably the HF ZeroGPU package must be imported before CUDA-related
# packages — confirm before reordering.
import spaces
import numpy as np
import torch
import torch.nn.functional as F
from evo.tools.file_interface import read_kitti_poses_file
from pathlib import Path
import rerun as rr
from typing import Optional, Dict
from visualization.logger import SimulationLogger
from scipy.spatial.transform import Rotation


def load_trajectory_data(traj_file: str, char_file: str, num_cams: int = 30) -> Dict:
    """Load a KITTI-format camera trajectory and a character feature file.

    Args:
        traj_file: Path to a KITTI poses file, read with evo's
            ``read_kitti_poses_file``.
        char_file: Path to a ``.npy`` character-feature array. It is padded
            along its first axis, so it is assumed to be ``(T, C)`` —
            TODO confirm against the dataset.
        num_cams: Fixed number of frames that features are padded to. Inputs
            longer than this are cropped by ``F.pad`` (negative padding).

    Returns:
        A dict with these keys:

        - ``traj_filename`` / ``char_filename``: base names of the inputs.
        - ``traj_feat``: ``(9, num_cams)`` tensor. Each column is the first
          two rotation columns (6D rotation) followed by the translation.
        - ``char_feat``: ``(C, num_cams)`` tensor.
        - ``padding_mask``: ``(num_cams,)`` tensor. It is 1 for frames
          present in the trajectory and 0 for padding.
        - ``raw_matrix_trajectory``: ``(N, 4, 4)`` float32 pose matrices,
          unpadded.
    """
    trajectory = read_kitti_poses_file(traj_file)
    matrix_trajectory = torch.from_numpy(
        np.array(trajectory.poses_se3)).to(torch.float32)

    raw_trans = torch.clone(matrix_trajectory[:, :3, 3])
    raw_rot = matrix_trajectory[:, :3, :3]
    # 6D rotation representation: the first two columns of each rotation
    # matrix, flattened column by column.
    rot6d = raw_rot[:, :, :2].permute(0, 2, 1).reshape(-1, 6)
    trajectory_feature = torch.hstack([rot6d, raw_trans]).permute(1, 0)
    padded_trajectory_feature = F.pad(
        trajectory_feature, (0, num_cams - trajectory_feature.shape[1])
    )

    padding_mask = torch.ones((num_cams))
    padding_mask[trajectory_feature.shape[1]:] = 0

    char_feature = torch.from_numpy(np.load(char_file)).to(torch.float32)
    padding_size = num_cams - char_feature.shape[0]
    # Pad the frame axis (dim 0), then move channels first: (C, num_cams).
    padded_char_feature = F.pad(
        char_feature, (0, 0, 0, padding_size)).permute(1, 0)

    return {
        "traj_filename": Path(traj_file).name,
        "char_filename": Path(char_file).name,
        "traj_feat": padded_trajectory_feature,
        "char_feat": padded_char_feature,
        "padding_mask": padding_mask,
        "raw_matrix_trajectory": matrix_trajectory
    }


class ETLogger(SimulationLogger):
    """Rerun logger for E.T. camera trajectories and character features."""

    def __init__(self):
        super().__init__()
        rr.init("et_visualization")
        rr.log("world", rr.ViewCoordinates.RIGHT_HAND_Y_UP, timeless=True)
        # Fixed pinhole intrinsics (focal 500 px, principal point at the
        # centre of a 640x480 image) used for every logged camera.
        self.K = np.array([
            [500, 0, 320],
            [0, 500, 240],
            [0, 0, 1]
        ])

    def log_trajectory(self, trajectory: np.ndarray, padding_mask: np.ndarray):
        """Log camera positions, the connecting path, and a per-frame camera.

        Args:
            trajectory: ``(N, 4, 4)`` pose matrices. Only the first
                ``padding_mask.sum()`` of them are used.
            padding_mask: 1/0 validity mask whose sum is the number of valid
                frames.
        """
        valid_frames = int(padding_mask.sum())
        valid_trajectory = trajectory[:valid_frames]
        positions = valid_trajectory[:, :3, 3]

        rr.log(
            "world/trajectory/points",
            rr.Points3D(
                positions,
                colors=np.full((len(positions), 4), [0.0, 0.8, 0.8, 1.0])
            ),
            timeless=True
        )

        if len(positions) > 1:
            # One two-point segment per consecutive pair of positions.
            lines = np.stack([positions[:-1], positions[1:]], axis=1)
            rr.log(
                "world/trajectory/line",
                rr.LineStrips3D(
                    lines,
                    colors=[(0.0, 0.8, 0.8, 1.0)]
                ),
                timeless=True
            )

        for k in range(valid_frames):
            rr.set_time_sequence("frame_idx", k)

            translation = valid_trajectory[k, :3, 3]
            rotation_q = Rotation.from_matrix(
                valid_trajectory[k, :3, :3]).as_quat()

            rr.log(
                "world/camera",
                rr.Transform3D(
                    translation=translation,
                    rotation=rr.Quaternion(xyzw=rotation_q),
                ),
            )

            rr.log(
                "world/camera/image",
                rr.Pinhole(
                    image_from_camera=self.K,
                    width=640,
                    height=480,
                ),
            )

    def log_character(self, char_feature: np.ndarray, padding_mask: np.ndarray):
        """Log the character feature of every valid frame as static 3D points.

        Args:
            char_feature: Channel-first ``(C, T)`` array, i.e. ``char_feat``
                from ``load_trajectory_data``. Each frame's ``C`` values are
                assumed to be one or more xyz triples (C a multiple of 3) —
                TODO confirm.
            padding_mask: 1/0 validity mask whose sum is the number of valid
                frames.
        """
        valid_frames = int(padding_mask.sum())
        # Transpose to frame-major (T, C) before grouping into xyz triples.
        # Reshaping the channel-first array directly would pair values from
        # different time steps of the same channel.
        valid_char = char_feature[:, :valid_frames].T
        if valid_char.size > 0:
            points = valid_char.reshape(-1, 3)
            rr.log(
                "world/character",
                rr.Points3D(
                    points,
                    colors=np.full((points.shape[0], 4), [0.8, 0.2, 0.2, 1.0])
                ),
                timeless=True
            )


@spaces.GPU
def visualize_et_data(traj_file: str, char_file: str) -> Optional[str]:
    """Build a Rerun recording for a trajectory/character pair.

    Args:
        traj_file: KITTI poses file path.
        char_file: ``.npy`` character-feature file path.

    Returns:
        Path to the saved ``.rrd`` file inside a fresh temporary directory.
        Returns ``None`` if anything fails; the error and its traceback are
        printed. The temporary directory is not removed, so the caller can
        serve the file.
    """
    try:
        data = load_trajectory_data(traj_file, char_file)

        logger = ETLogger()

        logger.log_trajectory(
            data["raw_matrix_trajectory"].numpy(),
            data["padding_mask"].numpy()
        )

        logger.log_character(
            data["char_feat"].numpy(),
            data["padding_mask"].numpy()
        )

        # Create the output directory only once logging has succeeded, so a
        # failed load does not leave an empty temp dir behind.
        temp_dir = tempfile.mkdtemp()
        rrd_path = os.path.join(temp_dir, "et_visualization.rrd")
        rr.save(rrd_path)

        return rrd_path

    except Exception as e:
        # Deliberate best-effort boundary: report and return None.
        print(f"Error visualizing E.T. data: {str(e)}")
        traceback.print_exc()
        return None