File size: 5,095 Bytes
ada4b81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import yaml
import torch
import os
import argparse
import trimesh
import numpy as np
from model.serializaiton import BPT_deserialize
from model.model import MeshTransformer
from utils import joint_filter, Dataset
from model.data_utils import to_mesh

# prepare arguments
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, default='config/BPT-pc-open-8k-8-16.yaml')
parser.add_argument('--model_path', type=str)
parser.add_argument('--input_dir', default=None, type=str)
parser.add_argument('--input_path', default=None, type=str)
parser.add_argument('--out_dir', default="output", type=str)
parser.add_argument('--input_type', choices=['mesh','pc_normal'], default='mesh')
parser.add_argument('--output_path', type=str, default='output')
parser.add_argument('--batch_size', type=int, default=1)
parser.add_argument('--temperature', type=float, default=0.5)  # key sampling parameter
parser.add_argument('--condition', type=str, default='pc')
args = parser.parse_args()


if __name__ == '__main__':
    with open(args.config, "r") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    # prepare model with fp16 precision
    model = MeshTransformer(
        dim = config['dim'],
        attn_depth = config['depth'],
        max_seq_len = config['max_seq_len'],
        dropout = config['dropout'],
        mode = config['mode'],
        num_discrete_coors= 2**int(config['quant_bit']),
        block_size = config['block_size'],
        offset_size = config['offset_size'],
        conditioned_on_pc = config['conditioned_on_pc'],
        use_special_block = config['use_special_block'],
        encoder_name = config['encoder_name'],
        encoder_freeze = config['encoder_freeze'],
    )
    model.load(args.model_path)
    model = model.eval()
    model = model.half()
    model = model.cuda()
    num_params = sum([param.nelement() for param in model.decoder.parameters()])
    print('Number of parameters: %.2f M' % (num_params / 1e6))
    print(f'Block Size: {model.block_size} | Offset Size: {model.offset_size}')

    # prepare data
    if args.input_dir is not None:
        input_list = sorted(os.listdir(args.input_dir))
        if args.input_type == 'pc_normal':
            # npy file with shape (n, 6):
            # point_cloud (n, 3) + normal (n, 3)
            input_list = [os.path.join(args.input_dir, x) for x in input_list if x.endswith('.npy')]
        else:
            # mesh file (e.g., obj, ply, glb)
            input_list = [os.path.join(args.input_dir, x) for x in input_list]
        dataset = Dataset(args.input_type, input_list)

    elif args.input_path is not None:
        dataset = Dataset(args.input_type, [args.input_path])

    else:
        raise ValueError("input_dir or input_path must be provided.")

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        drop_last = False,
        shuffle = False,
    )

    os.makedirs(args.output_path, exist_ok=True)
    with torch.no_grad():
        for it, data in enumerate(dataloader):
            if args.condition == 'pc':
                # generate codes with model
                codes = model.generate(
                    batch_size = args.batch_size,
                    temperature = args.temperature,
                    pc = data['pc_normal'].cuda().half(),
                    filter_logits_fn = joint_filter,
                    filter_kwargs = dict(k=50, p=0.95),
                    return_codes=True,
                )

            coords = []
            try:
                # decoding codes to coordinates
                for i in range(len(codes)):
                    code = codes[i]
                    code = code[code != model.pad_id].cpu().numpy()
                    vertices = BPT_deserialize(
                        code, 
                        block_size = model.block_size, 
                        offset_size = model.offset_size,
                        use_special_block = model.use_special_block,
                    )
                    coords.append(vertices)
            except:
                coords.append(np.zeros(3, 3))

            # convert coordinates to mesh
            for i in range(args.batch_size):
                uid = data['uid'][i]
                vertices = coords[i]
                faces = torch.arange(1, len(vertices) + 1).view(-1, 3)
                mesh = to_mesh(vertices, faces, transpose=False, post_process=True)
                num_faces = len(mesh.faces)
                # set the color for mesh
                face_color = np.array([120, 154, 192, 255], dtype=np.uint8)
                face_colors = np.tile(face_color, (num_faces, 1))
                mesh.visual.face_colors = face_colors
                mesh.export(f'{args.output_path}/{uid}_mesh.obj')

                # save pc
                if args.condition == 'pc':
                    pcd = data['pc_normal'][i].cpu().numpy()
                    point_cloud = trimesh.points.PointCloud(pcd[..., 0:3])
                    point_cloud.export(f'{args.output_path}/{uid}_pc.ply', "ply")