import trimesh
import numpy as np
from .data_utils import discretize, undiscretize


def patchified_mesh(mesh: trimesh.Trimesh, special_token=-2, fix_orient=True):
    """Flatten a mesh into a sequence of face patches (triangle fans).

    Repeatedly takes the first unvisited face, picks its vertex with the
    highest remaining degree as the patch center, and chains the unvisited
    faces around that center into a strip of rim vertices.  Each patch is
    emitted as ``center, rim_0, rim_1, ..., [special_token] * 3``.

    Args:
        mesh: Mesh to serialize.  Reads ``faces``, ``vertices``,
            ``vertex_degree`` and ``vertex_faces``.
        special_token: Value used to fill the separator row that terminates
            every patch.
        fix_orient: When true, the rim is reversed if ``(center, rim_0,
            rim_1)`` does not follow the winding of the matching mesh face,
            so the decoded triangles keep the original orientation.

    Returns:
        ``np.ndarray`` with one row per emitted vertex / separator.

    Raises:
        AssertionError: If the per-vertex degree bookkeeping does not reach
            zero once every face has been visited.
    """
    sequence = []
    unvisited = np.full(len(mesh.faces), True)
    degrees = mesh.vertex_degree.copy()

    # with fix_orient=True, the normal would be correct.
    # but this may increase the difficulty for learning.
    face_orient = {}
    if fix_orient:
        for v0, v1, v2 in mesh.faces:
            # The three cyclic rotations keep the winding ...
            for key in ((v0, v1, v2), (v1, v2, v0), (v2, v0, v1)):
                face_orient[key] = True
            # ... and the three reversed orders flip it.
            for key in ((v2, v1, v0), (v1, v0, v2), (v0, v2, v1)):
                face_orient[key] = False

    while unvisited.any():
        # select the patch center: highest-degree vertex of the first
        # unvisited face (argmax on a bool array gives the first True).
        cur_face = mesh.faces[np.argmax(unvisited)]
        max_deg_vertex = cur_face[np.argmax(degrees[cur_face])]

        # find all connected faces, stored as (rim_u, rim_v, face_idx)
        selected_faces = []
        for face_idx in mesh.vertex_faces[max_deg_vertex]:
            # the code treats -1 entries of vertex_faces as "no face"
            if face_idx != -1 and unvisited[face_idx]:
                face = mesh.faces[face_idx]
                u, v = sorted(
                    vertex for vertex in face if vertex != max_deg_vertex
                )
                selected_faces.append((u, v, face_idx))
        selected_faces.sort()
        face_patch = set()

        # select the start vertex, select it if it only appears once
        # (the start or end), else select the lowest index
        cnt = {}
        for u, v, _ in selected_faces:
            cnt[u] = cnt.get(u, 0) + 1
            cnt[v] = cnt.get(v, 0) + 1
        starts = [vertex for vertex, num in cnt.items() if num == 1]
        start_idx = min(starts) if starts else selected_faces[0][0]

        # walk the rim: repeatedly hop across an unused face that touches
        # the current rim vertex until no such face is left.
        res = [start_idx]
        while len(res) <= len(selected_faces):
            vertex = res[-1]
            for u_i, v_i, face_idx_i in selected_faces:
                if face_idx_i not in face_patch and vertex in (u_i, v_i):
                    res.append(v_i if vertex == u_i else u_i)
                    face_patch.add(face_idx_i)
                    break
            if res[-1] == vertex:
                break

        if (
            fix_orient
            and len(res) >= 2
            and not face_orient[(max_deg_vertex, res[0], res[1])]
        ):
            res = res[::-1]

        # reduce the degree of related vertices and mark the visited faces
        degrees[max_deg_vertex] = len(selected_faces) - len(res) + 1
        last_pos = len(res) - 1
        for pos_idx, vertex in enumerate(res):
            # rim end points lose one edge, interior rim vertices lose two
            degrees[vertex] -= 1 if pos_idx in (0, last_pos) else 2
        for face_idx in face_patch:
            unvisited[face_idx] = False

        sequence.extend(
            [mesh.vertices[max_deg_vertex]]
            + [mesh.vertices[vertex_idx] for vertex_idx in res]
            + [[special_token] * 3]
        )

    assert degrees.sum() == 0, 'All degrees should be zero'
    return np.array(sequence)


def get_block_representation(
    sequence,
    block_size=8,
    offset_size=16,
    block_compressed=True,
    special_token=-2,
    use_special_block=True
):
    '''
    convert coordinates from Cartesian system to block indexes.

    Every non-separator row of ``sequence`` is passed through ``discretize``
    (project helper; presumably quantizes to an integer grid of side
    ``block_size * offset_size`` -- TODO confirm) and split into a block id
    in ``[0, block_size**3)`` and an offset id shifted into
    ``[block_size**3, block_size**3 + offset_size**3)``.

    Args:
        sequence: 2-D array of xyz rows; a row counts as a separator when
            any of its components equals ``special_token``.  The input is
            not modified.
        block_size: Number of blocks per axis.
        offset_size: Number of offsets per axis inside one block.
        block_compressed: When true, the block id is omitted for a vertex
            that lies in the same block as the previously emitted block id.
        special_token: Separator value.
        use_special_block: When true, separators are not emitted; instead
            the first block id after a separator is shifted by
            ``block_size**3 + offset_size**3``.

    Returns:
        1-D ``np.int64`` array of codes.
    '''
    special_block_base = block_size**3 + offset_size**3

    # Work on a private copy: the rewrite below assigns into the array and
    # must not leak into the caller's data.
    sequence = np.array(sequence)

    # prepare coordinates
    sp_mask = np.all(sequence != special_token, axis=1)
    coords = discretize(sequence[sp_mask].reshape(-1, 3))

    # convert [x, y, z] to [block_id, offset_id]
    block_xyz = coords // offset_size
    block_id = (
        block_xyz[:, 0] * block_size**2
        + block_xyz[:, 1] * block_size
        + block_xyz[:, 2]
    )
    offset_xyz = coords % offset_size
    offset_id = (
        offset_xyz[:, 0] * offset_size**2
        + offset_xyz[:, 1] * offset_size
        + offset_xyz[:, 2]
    )
    offset_id += block_size**3
    block_coords = np.concatenate(
        [block_id[..., None], offset_id[..., None]], axis=-1
    ).astype(np.int64)
    sequence[:, :2][sp_mask] = block_coords
    sequence = sequence[:, :2]

    # convert to codes; the stream always opens with the first block id
    cur_block_id = sequence[0, 0]
    codes = [cur_block_id]
    for block, offset in sequence:
        if block == special_token:
            if not use_special_block:
                codes.append(special_token)
            cur_block_id = special_token
        elif block == cur_block_id:
            if block_compressed:
                codes.append(offset)
            else:
                codes.extend([block, offset])
        else:
            if use_special_block and cur_block_id == special_token:
                # first vertex of a new patch: mark it with a special block
                new_block = block + special_block_base
            else:
                new_block = block
            codes.extend([new_block, offset])
            cur_block_id = new_block

    return np.array(codes).astype(np.int64).flatten()


def BPT_serialize(mesh: trimesh.Trimesh):
    """Serialize a mesh into BPT codes.

    Runs ``patchified_mesh`` followed by ``get_block_representation`` with
    ``block_size=8``, ``offset_size=16``, compressed blocks and special
    block tokens as patch delimiters.

    Args:
        mesh: Mesh to serialize.

    Returns:
        1-D ``np.int64`` array of codes.
    """
    # 1. patchify faces into patches
    sequence = patchified_mesh(mesh, special_token=-2)

    # 2. convert coordinates to block-wise indexes
    codes = get_block_representation(
        sequence,
        block_size=8,
        offset_size=16,
        block_compressed=True,
        special_token=-2,
        use_special_block=True
    )
    return codes


def decode_block(sequence, compressed=True, block_size=8, offset_size=16):
    """Decode block / offset codes of one patch back to coordinates.

    Args:
        sequence: With ``compressed=True`` a 1-D run of tokens in which a
            token in ``[0, block_size**3)`` sets the current block and a
            token in ``[block_size**3, block_size**3 + offset_size**3)`` is
            an offset inside that block.  With ``compressed=False`` the
            data is split in two halves along the last axis into block ids
            and offset ids.  The input is not modified.
        compressed: Selects the input layout described above.
        block_size: Number of blocks per axis.
        offset_size: Number of offsets per axis inside one block.

    Returns:
        The result of ``undiscretize`` (project helper) applied to the
        decoded integer coordinates, concatenated x, y, z along the last
        axis.
    """
    num_blocks = block_size**3
    if compressed:
        # decode from compressed representation
        pairs = []
        res_block = 0
        for token_id, token in enumerate(sequence):
            if num_blocks + offset_size**3 > token >= num_blocks:
                pairs.append([res_block, token])
            elif num_blocks > token >= 0:
                res_block = token
            else:
                # any token outside both ranges (negative ones included)
                print('[Warning] too large offset idx!', token_id, token)
        sequence = np.array(pairs)
    else:
        # array_split returns views and the arithmetic below is in place,
        # so copy first to leave the caller's array untouched.
        sequence = np.array(sequence)

    block_id, offset_id = np.array_split(sequence, 2, axis=-1)

    # from hash representation to xyz
    coords = []
    offset_id -= num_blocks
    for i in [2, 1, 0]:
        axis = (
            (block_id // block_size**i) * offset_size
            + (offset_id // offset_size**i)
        )
        block_id %= block_size**i
        offset_id %= offset_size**i
        coords.append(axis)
    coords = np.concatenate(coords, axis=-1)  # (nf 3)

    # back to continuous space
    coords = undiscretize(coords)
    return coords


def BPT_deserialize(
    sequence,
    block_size=8,
    offset_size=16,
    compressed=True,
    special_token=-2,
    use_special_block=True
):
    """Decode BPT codes back to triangle corner coordinates.

    The code stream is cut into patches, each patch is decoded with
    ``decode_block`` and expanded as a fan: the first decoded vertex is the
    center and every consecutive pair of the remaining vertices forms one
    triangle with it.

    Args:
        sequence: 1-D sequence of codes.  It is not modified.
        block_size: Number of blocks per axis.
        offset_size: Number of offsets per axis inside one block.
        compressed: Forwarded to ``decode_block``.
        special_token: Patch separator, used when ``use_special_block`` is
            false.
        use_special_block: When true, a patch starts at every token in
            ``[base, base + block_size**3)`` with
            ``base = block_size**3 + offset_size**3``.

    Returns:
        Array with three consecutive rows per triangle.

    Raises:
        ValueError: From ``np.concatenate`` when no triangle was decoded.
    """
    # NOTE(review): with compressed=False the flat code slice is handed to
    # decode_block, which splits it in halves rather than into
    # (block, offset) pairs -- confirm that path is actually supported.
    special_block_base = block_size**3 + offset_size**3
    special_block_end = special_block_base + block_size**3
    last = len(sequence) - 1

    start_idx = 0
    vertices = []
    for i in range(len(sequence)):
        sub_seq = []
        if not use_special_block and (
            sequence[i] == special_token or i == last
        ):
            sub_seq = decode_block(
                sequence[start_idx:i],
                compressed=compressed,
                block_size=block_size,
                offset_size=offset_size,
            )
            start_idx = i + 1
        elif use_special_block and (
            special_block_base <= sequence[i] < special_block_end
            or i == last
        ):
            if i != 0:
                # the final patch also has to include the last token
                stop = i + 1 if i == last else i
                # Copy the slice: turning the special block back into a
                # plain block id must not rewrite the caller's codes,
                # otherwise a second call would miss the patch boundaries.
                sub_seq = np.array(sequence[start_idx:stop])
                if special_block_base <= sub_seq[0] < special_block_end:
                    sub_seq[0] -= special_block_base
                sub_seq = decode_block(
                    sub_seq,
                    compressed=compressed,
                    block_size=block_size,
                    offset_size=offset_size,
                )
                start_idx = i

        if len(sub_seq):
            center, rim = sub_seq[0], sub_seq[1:]
            for j in range(len(rim) - 1):
                vertices.extend([
                    center.reshape(1, 3),
                    rim[j].reshape(1, 3),
                    rim[j + 1].reshape(1, 3),
                ])

    # (nf, 3)
    return np.concatenate(vertices, axis=0)


if __name__ == '__main__':
    # a simple demo for serialize and deserialize mesh with bpt
    # NOTE(review): the module imports `.data_utils` relatively at the top
    # but `data_utils` absolutely here -- confirm how this demo is meant to
    # be launched.
    from data_utils import load_process_mesh, to_mesh
    import torch
    mesh = load_process_mesh('/path/to/your/mesh', quantization_bits=7)
    mesh['faces'] = np.array(mesh['faces'])
    mesh = to_mesh(mesh['vertices'], mesh['faces'], transpose=True)
    mesh.export('gt.obj')
    codes = BPT_serialize(mesh)
    coordinates = BPT_deserialize(codes)
    faces = torch.arange(1, len(coordinates) + 1).view(-1, 3)
    mesh = to_mesh(coordinates, faces, transpose=False, post_process=False)
    mesh.export('reconstructed.obj')