import torch
from torch import nn, Tensor
from torch.nn import Module
import torch.nn.functional as F
from einops import rearrange, repeat, pack
from pytorch_custom_utils import save_load
from beartype import beartype
from beartype.typing import Union, Tuple, Callable, Optional, Any
from x_transformers import Decoder
from x_transformers.x_transformers import LayerIntermediates
from x_transformers.autoregressive_wrapper import (
    eval_decorator,
    top_k,
)
from .miche_conditioner import PointConditioner
from functools import partial
from tqdm import tqdm
from .data_utils import discretize

# helper functions

def exists(v):
    return v is not None

def default(v, d):
    return v if exists(v) else d

def first(it):
    return it[0]

def divisible_by(num, den):
    return (num % den) == 0

def pad_at_dim(t, padding, dim = -1, value = 0):
    # pad tensor `t` along an arbitrary dimension; F.pad takes pad widths ordered
    # from the last dimension backwards, so prepend (0, 0) pairs for every
    # dimension to the right of `dim`
    ndim = t.ndim
    right_dims = (ndim - dim - 1) if dim >= 0 else (-dim - 1)
    zeros = (0, 0) * right_dims
    return F.pad(t, (*zeros, *padding), value = value)


# main class: the autoregressive mesh Transformer
@save_load()
class MeshTransformer(Module):
    @beartype
    def __init__(
        self,
        *,
        dim: Union[int, Tuple[int, int]] = 512,  # hidden size of Transformer
        max_seq_len = 9600,                      # max sequence length
        flash_attn = True,                       # whether to use flash attention
        attn_depth = 12,                         # number of layers
        attn_dim_head = 64,                      # dim for each head
        attn_heads = 16,                         # number of heads
        attn_kwargs: dict = dict(
            ff_glu = True,
            num_mem_kv = 4,
            attn_qk_norm = True,
        ),
        dropout = 0.,
        pad_id = -1,
        coor_continuous_range = (-1., 1.),
        num_discrete_coors = 128,
        block_size = 8,
        offset_size = 16,
        mode = 'vertices',
        special_token = -2,
        use_special_block = False,
        conditioned_on_pc = False,
        encoder_name = 'miche-256-feature',
        encoder_freeze = True,
    ):
        super().__init__()
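
        # token-id layout of the flattened vocabulary:
        #   [0, block_size**3)                               -> coarse block ids
        #   [block_size**3, block_size**3 + offset_size**3)  -> fine offset ids
        #   remaining id(s)                                   -> special block ids (use_special_block) or a single special token
        # the eos token id sits one past the vocabulary, hence the vocab_size + 1 embedding / logit entries below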

        if use_special_block:
            # block_ids, offset_ids, special_block_ids
            vocab_size = block_size**3 + offset_size**3 + block_size**3
            self.sp_block_embed = nn.Parameter(torch.randn(1, dim))
        else:
            # block_ids, offset_ids, special_token
            vocab_size = block_size**3 + offset_size**3 + 1
            self.special_token = special_token
            self.special_token_cb = block_size**3 + offset_size**3
            
        self.use_special_block = use_special_block
        
        self.sos_token = nn.Parameter(torch.randn(dim))
        self.eos_token_id = vocab_size
        self.mode = mode
        self.token_embed = nn.Embedding(vocab_size + 1, dim)
        self.num_discrete_coors = num_discrete_coors
        self.coor_continuous_range = coor_continuous_range
        self.block_size = block_size
        self.offset_size = offset_size
        self.abs_pos_emb = nn.Embedding(max_seq_len, dim)
        self.max_seq_len = max_seq_len
        self.conditioner = None
        self.conditioned_on_pc = conditioned_on_pc
        cross_attn_dim_context = None
        
        self.block_embed = nn.Parameter(torch.randn(1, dim))
        self.offset_embed = nn.Parameter(torch.randn(1, dim))
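
        # per axis, the num_discrete_coors quantization levels are factored into
        # block_size coarse levels times offset_size fine levels, hence the check below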
        
        assert self.block_size * self.offset_size == self.num_discrete_coors

        # load point_cloud encoder
        if conditioned_on_pc:
            print(f'Point cloud encoder: {encoder_name} | freeze: {encoder_freeze}')
            self.conditioner = PointConditioner(model_name=encoder_name, freeze=encoder_freeze)
            cross_attn_dim_context = self.conditioner.dim_latent
        else:
            raise NotImplementedError
        
        # main autoregressive attention network
        self.decoder = Decoder(
            dim = dim,
            depth = attn_depth,
            dim_head = attn_dim_head,
            heads = attn_heads,
            attn_flash = flash_attn,
            attn_dropout = dropout,
            ff_dropout = dropout,
            cross_attend = conditioned_on_pc,
            cross_attn_dim_context = cross_attn_dim_context,
            cross_attn_num_mem_kv = 4,  # memory key / values help prevent NaNs when the condition is dropped out
            **attn_kwargs
        )

        self.to_logits = nn.Linear(dim, vocab_size + 1)
        self.pad_id = pad_id
        self.discretize_face_coords = partial(
            discretize, 
            num_discrete = num_discrete_coors, 
            continuous_range = coor_continuous_range
        )

    @property
    def device(self):
        return next(self.parameters()).device


    @eval_decorator
    @torch.no_grad()
    @beartype
    def generate(
        self,
        prompt: Optional[Tensor] = None,
        pc: Optional[Tensor] = None,
        cond_embeds: Optional[Tensor] = None,
        batch_size: Optional[int] = None,
        filter_logits_fn: Callable = top_k,
        filter_kwargs: dict = dict(),
        temperature = 1.,
        return_codes = False,
        cache_kv = True,
        max_seq_len = None,
        face_coords_to_file: Optional[Callable[[Tensor], Any]] = None,
        tqdm_position = 0,
    ):
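        """
        Autoregressively sample token sequences, optionally conditioned on a point
        cloud (pc) or precomputed condition embeddings (cond_embeds). Returns raw
        token codes when return_codes=True; otherwise decodes them into face
        coordinates and a face mask (or files, if face_coords_to_file is given).
        """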
        max_seq_len = default(max_seq_len, self.max_seq_len)

        if exists(prompt):
            assert not exists(batch_size)

            prompt = rearrange(prompt, 'b ... -> b (...)')
            assert prompt.shape[-1] <= self.max_seq_len

            batch_size = prompt.shape[0]

        # encode point cloud
        if cond_embeds is None:
            if self.conditioned_on_pc:
                cond_embeds = self.conditioner(pc = pc)

        batch_size = default(batch_size, 1)

        codes = default(prompt, torch.empty((batch_size, 0), dtype = torch.long, device = self.device))

        curr_length = codes.shape[-1]

        cache = None

        # predict tokens auto-regressively
        for i in tqdm(range(curr_length, max_seq_len), position=tqdm_position, 
                      desc=f'Process: {tqdm_position}', dynamic_ncols=True, leave=False):
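            # one decoding step: run the decoder over everything generated so far,
            # reusing the kv cache across steps when cache_kv is enabled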

            output = self.forward_on_codes(
                codes,
                return_loss = False,
                return_cache = cache_kv,
                append_eos = False,
                cond_embeds = cond_embeds,
                cache = cache
            )

            if cache_kv:
                logits, cache = output

            else:
                logits = output

            # sample the next code from the final-position logits:
            # filter the logits (top-k by default), apply temperature-scaled softmax, then sample
            logits = logits[:, -1]
            filtered_logits = filter_logits_fn(logits, **filter_kwargs)
            probs = F.softmax(filtered_logits / temperature, dim = -1)
            sample = torch.multinomial(probs, 1)
            codes, _ = pack([codes, sample], 'b *')

            # check for all rows to have [eos] to terminate

            is_eos_codes = (codes == self.eos_token_id)

            if is_eos_codes.any(dim = -1).all():
                break

        # mask out to padding anything after the first eos

        is_eos_codes = (codes == self.eos_token_id)  # recomputed so this also holds if the loop above did not run
        mask = is_eos_codes.float().cumsum(dim = -1) >= 1
        codes = codes.masked_fill(mask, self.pad_id)
        
        # early return of the raw token codes

        if return_codes:
            # codes = rearrange(codes, 'b (n q) -> b n q', q = 2)
            if not self.use_special_block:
                # map the in-vocabulary special id back to the sentinel special_token value
                codes[codes == self.special_token_cb] = self.special_token
            return codes

        face_coords, face_mask = self.decode_codes(codes)

        if not exists(face_coords_to_file):
            return face_coords, face_mask

        files = [face_coords_to_file(coords[mask]) for coords, mask in zip(face_coords, face_mask)]
        return files


    def forward(
        self,
        *,
        codes:          Optional[Tensor] = None,
        cache:          Optional[LayerIntermediates] = None,
        **kwargs
    ):
        # map the sentinel special_token value to its in-vocabulary id before embedding
        # (note: this modifies the given codes tensor in place)
        if not self.use_special_block:
            codes[codes == self.special_token] = self.special_token_cb
            
        return self.forward_on_codes(codes, cache = cache, **kwargs)


    def forward_on_codes(
        self,
        codes = None,
        return_loss = True,
        return_cache = False,
        append_eos = True,
        cache = None,
        pc = None,
        cond_embeds = None,
    ):
        # handle conditions

        attn_context_kwargs = dict()
        
        if self.conditioned_on_pc:
            assert exists(pc) ^ exists(cond_embeds), 'exactly one of pc or cond_embeds should be given'
            
            # encode the point cloud into cross-attention condition embeddings
            if not exists(cond_embeds):
                cond_embeds = self.conditioner(
                    pc = pc,
                    pc_embeds = cond_embeds,
                )
            
            attn_context_kwargs = dict(
                context = cond_embeds,
                context_mask = None,
            )

        # take care of codes that may be flattened

        if codes.ndim > 2:
            codes = rearrange(codes, 'b ... -> b (...)')

        # prepare mask for position embedding of block and offset tokens
        block_mask = (0 <= codes) & (codes < self.block_size**3)
        offset_mask = (self.block_size**3 <= codes) & (codes < self.block_size**3 + self.offset_size**3)
        if self.use_special_block:
            sp_block_mask = (
                self.block_size**3 + self.offset_size**3 <= codes
            ) & (
                codes < self.block_size**3 + self.offset_size**3 + self.block_size**3
            )
        

        # unpack batch size, sequence length, and device

        batch, seq_len, device = *codes.shape, codes.device

        assert seq_len <= self.max_seq_len, \
            f'received codes of length {seq_len} but it needs to be at most {self.max_seq_len}'

        # auto append eos token

        if append_eos:
            assert exists(codes)

            code_lens = ((codes == self.pad_id).cumsum(dim = -1) == 0).sum(dim = -1)

            codes = F.pad(codes, (0, 1), value = self.pad_id)  # appended slot stays pad unless overwritten with eos below

            batch_arange = torch.arange(batch, device = device)

            batch_arange = rearrange(batch_arange, '... -> ... 1')
            code_lens = rearrange(code_lens, '... -> ... 1')

            codes[batch_arange, code_lens] = self.eos_token_id


        # if returning loss, save the labels for cross entropy
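        # (standard next-token shift: the inputs drop the final token, the labels keep the full sequence)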

        if return_loss:
            assert seq_len > 0
            codes, labels = codes[:, :-1], codes

        # token embed

        codes = codes.masked_fill(codes == self.pad_id, 0)
        codes = self.token_embed(codes)

        # absolute position embeddings

        seq_arange = torch.arange(codes.shape[-2], device = device)
        codes = codes + self.abs_pos_emb(seq_arange)
        
        # add positional embedding for block and offset token
        block_embed = repeat(self.block_embed, '1 d -> b n d', n = seq_len, b = batch)
        offset_embed = repeat(self.offset_embed, '1 d -> b n d', n = seq_len, b = batch)
        codes[block_mask] += block_embed[block_mask]
        codes[offset_mask] += offset_embed[offset_mask]
        
        if self.use_special_block:
            sp_block_embed = repeat(self.sp_block_embed, '1 d -> b n d', n = seq_len, b = batch)
            codes[sp_block_mask] += sp_block_embed[sp_block_mask]

        # auto prepend sos token

        sos = repeat(self.sos_token, 'd -> b d', b = batch)
        codes, _ = pack([sos, codes], 'b * d')
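        # after prepending sos, the output at position i has attended to sos plus the
        # first i tokens and predicts token i of the (eos-appended) sequence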

        # attention

        attended, intermediates_with_cache = self.decoder(
            codes,
            cache = cache,
            return_hiddens = True,
            **attn_context_kwargs
        )

        # logits

        logits = self.to_logits(attended)

        if not return_loss:
            if not return_cache:
                return logits

            return logits, intermediates_with_cache

        # loss

        ce_loss = F.cross_entropy(
            rearrange(logits, 'b n c -> b c n'),
            labels,
            ignore_index = self.pad_id
        )

        return ce_loss
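
# usage sketch (illustrative only; the point-cloud shape and settings below are
# assumptions for demonstration, not values prescribed by this module)
#
#   model = MeshTransformer(
#       dim = 512,
#       conditioned_on_pc = True,  # unconditional mode currently raises NotImplementedError
#   )
#
#   pc = torch.randn(2, 4096, 3)   # assumed (batch, num_points, xyz) layout for the point-cloud encoder
#
#   # training: codes are flattened block / offset token ids, padded with pad_id (-1)
#   # loss = model(codes = codes, pc = pc)
#   # loss.backward()
#
#   # inference: autoregressively sample token ids conditioned on the point cloud
#   # sampled_codes = model.generate(pc = pc, return_codes = True, temperature = 1.0)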