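"""Auto-regressive mesh transformer.

An x_transformers Decoder that predicts discretized mesh tokens (block ids,
offset ids, plus special and eos tokens), optionally cross-attending to a
point-cloud condition produced by PointConditioner.
"""
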
import torch
from torch import nn, Tensor
from torch.nn import Module
import torch.nn.functional as F

from einops import rearrange, repeat, pack

from pytorch_custom_utils import save_load

from beartype import beartype
from beartype.typing import Union, Tuple, Callable, Optional, Any
from x_transformers import Decoder
from x_transformers.x_transformers import LayerIntermediates
from x_transformers.autoregressive_wrapper import (
    eval_decorator,
    top_k,
)

from .miche_conditioner import PointConditioner

from functools import partial
from tqdm import tqdm

from .data_utils import discretize
# helper functions

def exists(v):
    return v is not None

def default(v, d):
    return v if exists(v) else d

def first(it):
    return it[0]

def divisible_by(num, den):
    return (num % den) == 0

def pad_at_dim(t, padding, dim = -1, value = 0):
    ndim = t.ndim
    right_dims = (ndim - dim - 1) if dim >= 0 else (-dim - 1)
    zeros = (0, 0) * right_dims
    return F.pad(t, (*zeros, *padding), value = value)
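
# Quick sanity check of pad_at_dim (illustrative only, shapes follow from F.pad semantics):
#   pad_at_dim(torch.zeros(2, 3), (1, 2), dim = -1).shape  # -> torch.Size([2, 6])
#   pad_at_dim(torch.zeros(2, 3), (0, 1), dim = 0).shape   # -> torch.Size([3, 3])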

# main class of the auto-regressive transformer

class MeshTransformer(Module):
    def __init__(
        self,
        *,
        dim: Union[int, Tuple[int, int]] = 512,  # hidden size of the Transformer
        max_seq_len = 9600,                      # maximum sequence length
        flash_attn = True,                       # whether to use flash attention
        attn_depth = 12,                         # number of layers
        attn_dim_head = 64,                      # dimension of each attention head
        attn_heads = 16,                         # number of attention heads
        attn_kwargs: dict = dict(
            ff_glu = True,
            num_mem_kv = 4,
            attn_qk_norm = True,
        ),
        dropout = 0.,
        pad_id = -1,
        coor_continuous_range = (-1., 1.),
        num_discrete_coors = 128,
        block_size = 8,
        offset_size = 16,
        mode = 'vertices',
        special_token = -2,
        use_special_block = False,
        conditioned_on_pc = False,
        encoder_name = 'miche-256-feature',
        encoder_freeze = True,
    ):
        super().__init__()

        if use_special_block:
            # block_ids, offset_ids, special_block_ids
            vocab_size = block_size**3 + offset_size**3 + block_size**3
            self.sp_block_embed = nn.Parameter(torch.randn(1, dim))
        else:
            # block_ids, offset_ids, special_token
            vocab_size = block_size**3 + offset_size**3 + 1

        self.special_token = special_token
        self.special_token_cb = block_size**3 + offset_size**3
        self.use_special_block = use_special_block

        self.sos_token = nn.Parameter(torch.randn(dim))
        self.eos_token_id = vocab_size

        self.mode = mode
        self.token_embed = nn.Embedding(vocab_size + 1, dim)
        self.num_discrete_coors = num_discrete_coors
        self.coor_continuous_range = coor_continuous_range
        self.block_size = block_size
        self.offset_size = offset_size
        self.abs_pos_emb = nn.Embedding(max_seq_len, dim)
        self.max_seq_len = max_seq_len
        self.conditioner = None
        self.conditioned_on_pc = conditioned_on_pc
        cross_attn_dim_context = None

        self.block_embed = nn.Parameter(torch.randn(1, dim))
        self.offset_embed = nn.Parameter(torch.randn(1, dim))

        assert self.block_size * self.offset_size == self.num_discrete_coors

        # load point cloud encoder
        if conditioned_on_pc:
            print(f'Point cloud encoder: {encoder_name} | freeze: {encoder_freeze}')
            self.conditioner = PointConditioner(model_name=encoder_name, freeze=encoder_freeze)
            cross_attn_dim_context = self.conditioner.dim_latent
        else:
            raise NotImplementedError

        # main autoregressive attention network
        self.decoder = Decoder(
            dim = dim,
            depth = attn_depth,
            dim_head = attn_dim_head,
            heads = attn_heads,
            attn_flash = flash_attn,
            attn_dropout = dropout,
            ff_dropout = dropout,
            cross_attend = conditioned_on_pc,
            cross_attn_dim_context = cross_attn_dim_context,
            cross_attn_num_mem_kv = 4,  # needed for preventing nan when dropping out text condition
            **attn_kwargs
        )

        self.to_logits = nn.Linear(dim, vocab_size + 1)
        self.pad_id = pad_id

        self.discretize_face_coords = partial(
            discretize,
            num_discrete = num_discrete_coors,
            continuous_range = coor_continuous_range
        )
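
    # Token-id layout implied by the constructor above (summary only, no new behavior):
    #   [0, block_size**3)                                block tokens
    #   [block_size**3, block_size**3 + offset_size**3)   offset tokens
    #   block_size**3 + offset_size**3                    special token id (self.special_token_cb),
    #                                                     or the first special block id when
    #                                                     use_special_block = True
    #   vocab_size (= self.eos_token_id)                  eos token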

    @property
    def device(self):
        return next(self.parameters()).device

    # sampling runs in eval mode with gradients disabled
    @eval_decorator
    @torch.no_grad()
    def generate(
        self,
        prompt: Optional[Tensor] = None,
        pc: Optional[Tensor] = None,
        cond_embeds: Optional[Tensor] = None,
        batch_size: Optional[int] = None,
        filter_logits_fn: Callable = top_k,
        filter_kwargs: dict = dict(),
        temperature = 1.,
        return_codes = False,
        cache_kv = True,
        max_seq_len = None,
        face_coords_to_file: Optional[Callable[[Tensor], Any]] = None,
        tqdm_position = 0,
    ):
        max_seq_len = default(max_seq_len, self.max_seq_len)

        if exists(prompt):
            assert not exists(batch_size)

            prompt = rearrange(prompt, 'b ... -> b (...)')
            assert prompt.shape[-1] <= self.max_seq_len

            batch_size = prompt.shape[0]

        # encode point cloud
        if cond_embeds is None:
            if self.conditioned_on_pc:
                cond_embeds = self.conditioner(pc = pc)

        batch_size = default(batch_size, 1)

        codes = default(prompt, torch.empty((batch_size, 0), dtype = torch.long, device = self.device))

        curr_length = codes.shape[-1]
        cache = None

        # predict tokens auto-regressively
        for i in tqdm(range(curr_length, max_seq_len), position=tqdm_position,
                      desc=f'Process: {tqdm_position}', dynamic_ncols=True, leave=False):

            output = self.forward_on_codes(
                codes,
                return_loss = False,
                return_cache = cache_kv,
                append_eos = False,
                cond_embeds = cond_embeds,
                cache = cache
            )

            if cache_kv:
                logits, cache = output
            else:
                logits = output

            # sample the next code from the filtered, temperature-scaled logits
            logits = logits[:, -1]
            filtered_logits = filter_logits_fn(logits, **filter_kwargs)
            probs = F.softmax(filtered_logits / temperature, dim = -1)
            sample = torch.multinomial(probs, 1)

            codes, _ = pack([codes, sample], 'b *')

            # terminate once every row has produced an [eos] token
            is_eos_codes = (codes == self.eos_token_id)
            if is_eos_codes.any(dim = -1).all():
                break

        # mask out to padding anything after the first eos
        mask = is_eos_codes.float().cumsum(dim = -1) >= 1
        codes = codes.masked_fill(mask, self.pad_id)

        # early return of the raw token codes
        if return_codes:
            # codes = rearrange(codes, 'b (n q) -> b n q', q = 2)
            if not self.use_special_block:
                codes[codes == self.special_token_cb] = self.special_token
            return codes

        face_coords, face_mask = self.decode_codes(codes)

        if not exists(face_coords_to_file):
            return face_coords, face_mask

        files = [face_coords_to_file(coords[mask]) for coords, mask in zip(face_coords, face_mask)]
        return files
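
    # Usage sketch for generate (hypothetical tensors; the exact point-cloud format is
    # whatever PointConditioner in .miche_conditioner expects, and decode_codes is
    # assumed to be defined elsewhere on this class):
    #
    #   model = MeshTransformer(conditioned_on_pc = True)
    #   codes = model.generate(pc = pc, temperature = 0.5, return_codes = True)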

    def forward(
        self,
        *,
        codes: Optional[Tensor] = None,
        cache: Optional[LayerIntermediates] = None,
        **kwargs
    ):
        # convert special tokens
        if not self.use_special_block:
            codes[codes == self.special_token] = self.special_token_cb

        return self.forward_on_codes(codes, cache = cache, **kwargs)

    def forward_on_codes(
        self,
        codes = None,
        return_loss = True,
        return_cache = False,
        append_eos = True,
        cache = None,
        pc = None,
        cond_embeds = None,
    ):
        # handle conditions
        attn_context_kwargs = dict()

        if self.conditioned_on_pc:
            assert exists(pc) ^ exists(cond_embeds), 'exactly one of pc or cond_embeds should be given'

            # encode the point cloud if embeddings were not passed in directly
            if not exists(cond_embeds):
                cond_embeds = self.conditioner(
                    pc = pc,
                    pc_embeds = cond_embeds,
                )

            attn_context_kwargs = dict(
                context = cond_embeds,
                context_mask = None,
            )
        # take care of codes that may be flattened
        if codes.ndim > 2:
            codes = rearrange(codes, 'b ... -> b (...)')

        # prepare masks for the positional embeddings of block and offset tokens
        block_mask = (0 <= codes) & (codes < self.block_size**3)
        offset_mask = (self.block_size**3 <= codes) & (codes < self.block_size**3 + self.offset_size**3)
        if self.use_special_block:
            sp_block_mask = (
                self.block_size**3 + self.offset_size**3 <= codes
            ) & (
                codes < self.block_size**3 + self.offset_size**3 + self.block_size**3
            )

        # get shape and device
        batch, seq_len, device = *codes.shape, codes.device
        assert seq_len <= self.max_seq_len, \
            f'received codes of length {seq_len} but needs to be less than {self.max_seq_len}'

        # auto append eos token
        if append_eos:
            assert exists(codes)

            code_lens = ((codes == self.pad_id).cumsum(dim = -1) == 0).sum(dim = -1)
            codes = F.pad(codes, (0, 1), value = 0)  # value=-1

            batch_arange = torch.arange(batch, device = device)
            batch_arange = rearrange(batch_arange, '... -> ... 1')
            code_lens = rearrange(code_lens, '... -> ... 1')

            codes[batch_arange, code_lens] = self.eos_token_id

        # if returning loss, save the labels for cross entropy
        if return_loss:
            assert seq_len > 0
            codes, labels = codes[:, :-1], codes

        # token embed
        codes = codes.masked_fill(codes == self.pad_id, 0)
        codes = self.token_embed(codes)

        # absolute positional embedding
        seq_arange = torch.arange(codes.shape[-2], device = device)
        codes = codes + self.abs_pos_emb(seq_arange)

        # add positional embeddings for block and offset tokens
        block_embed = repeat(self.block_embed, '1 d -> b n d', n = seq_len, b = batch)
        offset_embed = repeat(self.offset_embed, '1 d -> b n d', n = seq_len, b = batch)
        codes[block_mask] += block_embed[block_mask]
        codes[offset_mask] += offset_embed[offset_mask]
        if self.use_special_block:
            sp_block_embed = repeat(self.sp_block_embed, '1 d -> b n d', n = seq_len, b = batch)
            codes[sp_block_mask] += sp_block_embed[sp_block_mask]

        # auto prepend sos token
        sos = repeat(self.sos_token, 'd -> b d', b = batch)
        codes, _ = pack([sos, codes], 'b * d')

        # attention
        attended, intermediates_with_cache = self.decoder(
            codes,
            cache = cache,
            return_hiddens = True,
            **attn_context_kwargs
        )

        # logits
        logits = self.to_logits(attended)

        if not return_loss:
            if not return_cache:
                return logits

            return logits, intermediates_with_cache

        # loss
        ce_loss = F.cross_entropy(
            rearrange(logits, 'b n c -> b c n'),
            labels,
            ignore_index = self.pad_id
        )

        return ce_loss
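
# Training usage sketch (hypothetical, not part of the original module): `codes` is a
# (batch, seq_len) LongTensor of block/offset/special tokens produced by an external
# tokenizer, and `pc` is a point cloud in the format PointConditioner expects.
#
#   model = MeshTransformer(conditioned_on_pc = True)
#   loss = model(codes = codes, pc = pc)  # cross entropy over next-token prediction
#   loss.backward()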