"""Generate meshes with a pretrained BPT ``MeshTransformer``.

For every input (a mesh file or an ``.npy`` point-cloud-with-normals file),
the model is conditioned on the point cloud produced by ``Dataset``. It
samples token codes, and those codes are decoded back into a triangle soup.
Two files are written into ``--output_path`` per input:

* ``<uid>_mesh.obj`` -- the generated mesh, uniformly face-coloured.
* ``<uid>_pc.ply``   -- the xyz part of the conditioning point cloud.
"""
import yaml
import torch
import os
import argparse
import trimesh
import numpy as np
from model.serializaiton import BPT_deserialize
from model.model import MeshTransformer
from utils import joint_filter, Dataset
from model.data_utils import to_mesh

# prepare arguments
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, default='config/BPT-pc-open-8k-8-16.yaml')
parser.add_argument('--model_path', type=str)
parser.add_argument('--input_dir', default=None, type=str)
parser.add_argument('--input_path', default=None, type=str)
# NOTE(review): --out_dir is not read anywhere in this script; outputs are
# written to --output_path.
parser.add_argument('--out_dir', default="output", type=str)
parser.add_argument('--input_type', choices=['mesh', 'pc_normal'], default='mesh')
parser.add_argument('--output_path', type=str, default='output')
parser.add_argument('--batch_size', type=int, default=1)
parser.add_argument('--temperature', type=float, default=0.5)  # key sampling parameter
parser.add_argument('--condition', type=str, default='pc')
args = parser.parse_args()

# RGBA colour applied to every face of the exported meshes.
_FACE_COLOR = np.array([120, 154, 192, 255], dtype=np.uint8)


def _build_model(config, model_path):
    """Build a MeshTransformer from ``config`` and load the weights at ``model_path``.

    The returned model is in eval mode, cast to fp16, and moved to CUDA.
    """
    model = MeshTransformer(
        dim=config['dim'],
        attn_depth=config['depth'],
        max_seq_len=config['max_seq_len'],
        dropout=config['dropout'],
        mode=config['mode'],
        num_discrete_coors=2 ** int(config['quant_bit']),
        block_size=config['block_size'],
        offset_size=config['offset_size'],
        conditioned_on_pc=config['conditioned_on_pc'],
        use_special_block=config['use_special_block'],
        encoder_name=config['encoder_name'],
        encoder_freeze=config['encoder_freeze'],
    )
    model.load(model_path)
    return model.eval().half().cuda()


def _collect_inputs(input_dir, input_path, input_type):
    """Return the sorted list of input file paths selected by the CLI options.

    Raises:
        ValueError: if neither ``input_dir`` nor ``input_path`` is given, or
            if ``input_dir`` contains no usable files.
    """
    if input_dir is not None:
        names = sorted(os.listdir(input_dir))
        if input_type == 'pc_normal':
            # npy file with shape (n, 6): point_cloud (n, 3) + normal (n, 3)
            names = [x for x in names if x.endswith('.npy')]
        # otherwise: mesh files (e.g., obj, ply, glb); no extension filter
        input_list = [os.path.join(input_dir, x) for x in names]
        if not input_list:
            raise ValueError(f"No {input_type} inputs found in {input_dir!r}.")
        return input_list
    if input_path is not None:
        return [input_path]
    raise ValueError("input_dir or input_path must be provided.")


def _decode_codes(codes, model):
    """Decode each generated code sequence into a vertex array.

    Padding tokens (``model.pad_id``) are removed before decoding.
    Consecutive vertex triples are later treated as one face.

    A sample that fails to decode is replaced by one degenerate triangle
    (three zero vertices). That way the returned list always has one entry
    per sample and stays aligned with the batch's uids.
    """
    coords = []
    for idx, code in enumerate(codes):
        try:
            code = code[code != model.pad_id].cpu().numpy()
            vertices = BPT_deserialize(
                code,
                block_size=model.block_size,
                offset_size=model.offset_size,
                use_special_block=model.use_special_block,
            )
        except Exception as err:  # best-effort: keep the rest of the batch
            print(f'Failed to decode sample {idx}: {err!r}; '
                  'writing a placeholder triangle instead.')
            vertices = np.zeros((3, 3))
        coords.append(vertices)
    return coords


def _save_sample(uid, vertices, pc_normal, output_path):
    """Export the generated mesh and its conditioning point cloud for ``uid``."""
    # Faces index consecutive vertex triples; indices start at 1 as in the
    # original pipeline (to_mesh is assumed to handle this -- TODO confirm).
    faces = torch.arange(1, len(vertices) + 1).view(-1, 3)
    mesh = to_mesh(vertices, faces, transpose=False, post_process=True)

    # set the color for mesh
    mesh.visual.face_colors = np.tile(_FACE_COLOR, (len(mesh.faces), 1))
    mesh.export(os.path.join(output_path, f'{uid}_mesh.obj'))

    # save pc (xyz columns only)
    point_cloud = trimesh.points.PointCloud(pc_normal[..., 0:3])
    point_cloud.export(os.path.join(output_path, f'{uid}_pc.ply'), "ply")


def main(args):
    """Run generation over all inputs selected by ``args``."""
    if args.condition != 'pc':
        # Only point-cloud conditioning is implemented below.
        raise ValueError(f"Unsupported --condition {args.condition!r}; only 'pc' is supported.")

    # NOTE(review): config comes from a local path given on the CLI;
    # FullLoader is kept for compatibility, but yaml.safe_load would suffice
    # for plain-scalar configs.
    with open(args.config, "r") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    # prepare model with fp16 precision
    model = _build_model(config, args.model_path)
    num_params = sum(p.nelement() for p in model.decoder.parameters())
    print('Number of parameters: %.2f M' % (num_params / 1e6))
    print(f'Block Size: {model.block_size} | Offset Size: {model.offset_size}')

    # prepare data; Dataset is expected to yield dicts with 'pc_normal' and
    # 'uid' keys (the only keys read below).
    input_list = _collect_inputs(args.input_dir, args.input_path, args.input_type)
    dataset = Dataset(args.input_type, input_list)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        drop_last=False,
        shuffle=False,
    )

    os.makedirs(args.output_path, exist_ok=True)

    with torch.no_grad():
        for data in dataloader:
            pc = data['pc_normal'].cuda().half()
            # drop_last=False: the final batch may be smaller than
            # args.batch_size, so always use the real size.
            cur_batch_size = pc.shape[0]

            # generate codes with model
            codes = model.generate(
                batch_size=cur_batch_size,
                temperature=args.temperature,
                pc=pc,
                filter_logits_fn=joint_filter,
                filter_kwargs=dict(k=50, p=0.95),
                return_codes=True,
            )

            # decoding codes to coordinates (one entry per sample)
            coords = _decode_codes(codes, model)

            # convert coordinates to mesh and save alongside the input pc
            for i, (uid, vertices) in enumerate(zip(data['uid'], coords)):
                pcd = data['pc_normal'][i].cpu().numpy()
                _save_sample(uid, vertices, pcd, args.output_path)


if __name__ == '__main__':
    main(args)