import torch
import torch.nn.functional as F
from torch import nn, Tensor

from transformers import AutoModelForCausalLM, AutoConfig, AutoModel
from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange

from MeshAnything.miche.encode import load_model
from MeshAnything.models.shape_opt import ShapeOPTConfig


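# Decodes quantized face tokens back into continuous triangle coordinates. A small
# BERT-style encoder attends over per-face embeddings together with point-cloud
# conditioning tokens and predicts a discrete coordinate bin for every vertex axis.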
class NoiseResistantDecoder(nn.Module):

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.pad_id = -1
        self.num_quantizers = 3

        self.discrete_num = 128
        self.codebook_size = args.codebook_size
        self.codebook_dim = args.codebook_dim

        config = AutoConfig.from_pretrained("bert-base-uncased")
        config.num_hidden_layers = 6
        self.decoder = AutoModel.from_config(config=config).to_bettertransformer().encoder
        self.n_embd = self.decoder.config.hidden_size

        self.pos_embedding = nn.Embedding(18000, self.n_embd)
        self.layernorm = nn.LayerNorm(self.n_embd)
        self.point_layernorm = nn.LayerNorm(self.n_embd)

        self.cond_length = 257
        self.cond_dim = 768
        self.point_pe = nn.Embedding(self.cond_length, self.n_embd)
        self.cond_proj = nn.Linear(self.cond_dim, self.n_embd)
        self.cond_head_proj = nn.Linear(self.cond_dim, self.n_embd)

        self.project_down_codebook = nn.Linear(self.codebook_dim * 3, self.n_embd)
        self.to_coor_logits = nn.Sequential(
            nn.Linear(self.n_embd, self.discrete_num * 9),
            Rearrange('... (v c) -> ... v c', v=9)
        )

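    # Project the point-cloud features to the decoder's hidden size and add a learned
    # positional embedding; the first token gets its own head projection, the rest share one.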
    def process_point_feature(self, encode_feature):
        point_feature = torch.zeros(
            encode_feature.shape[0], self.cond_length, self.n_embd,
            device=self.cond_head_proj.weight.device, dtype=self.cond_head_proj.weight.dtype
        )
        point_feature[:, 0] = self.cond_head_proj(encode_feature[:, 0])
        point_feature[:, 1:] = self.cond_proj(encode_feature[:, 1:])

        point_feature = self.point_layernorm(point_feature + self.point_pe.weight[None, :point_feature.shape[1]])
        return point_feature

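    # input_ids: flattened face-token ids (pad_id marks padding); input_embeds: the
    # corresponding codebook embeddings. Returns per-face vertex coordinates, with
    # padded faces filled with NaN.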
    def forward(self, input_ids, input_embeds, point_feature=None):
        input_ids = input_ids.reshape(input_ids.shape[0], -1)
        point_feature = self.process_point_feature(point_feature)

        # Group per-vertex embeddings into per-face embeddings (3 vertices per face).
        face_embeds = rearrange(input_embeds, 'b (nf nv) d -> b nf (nv d)', nv=3)
        face_embeds = self.project_down_codebook(face_embeds)

        # A face is valid only if none of its tokens are padding.
        face_mask = reduce(input_ids != self.pad_id, 'b (nf nv q) -> b nf', 'all', nv=3, q=self.num_quantizers)
        face_embeds[~face_mask] = 0

        face_embeds = self.layernorm(face_embeds + self.pos_embedding.weight[None, :face_embeds.shape[1]])

        outputs = self.decoder(
            hidden_states=torch.concatenate([point_feature, face_embeds], dim=1),
        )
        decoded = outputs.last_hidden_state[:, self.cond_length:]
        decoded = decoded.masked_fill(~face_mask.unsqueeze(-1), 0.)

        # Predict a discrete bin for each of the nine coordinates per face, then map the
        # bins back to continuous values; padded faces become NaN.
        pred_face_logits = self.to_coor_logits(decoded)
        pred_face_coords = rearrange(pred_face_logits.argmax(dim=-1), '... (v c) -> ... v c', v=3)

        continuous_coors = undiscretize(
            pred_face_coords,
            num_discrete=self.discrete_num,
            low=-0.5,
            high=0.5
        )
        continuous_coors = continuous_coors.masked_fill(~rearrange(face_mask, 'b nf -> b nf 1 1'), float('nan'))

        return continuous_coors


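# Full pipeline: a point-cloud encoder conditions a ShapeOPT causal language model that
# autoregressively generates quantized face tokens, which the NoiseResistantDecoder above
# turns back into triangle coordinates.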
class MeshAnything(nn.Module):

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.point_encoder = load_model(ckpt_path=None)
        self.tokenizer = NoiseResistantDecoder(args)

        self.num_quantizers = 3
        self.face_per_token = self.num_quantizers * 3
        self.cond_length = 257
        self.cond_dim = 768
        self.max_length = args.n_max_triangles * self.face_per_token + 2 + self.cond_length

        self.config = ShapeOPTConfig.from_pretrained(
            args.llm,
            n_positions=18259,
            max_position_embeddings=18259,
            vocab_size=self.tokenizer.codebook_size + 3,
            _attn_implementation="flash_attention_2"
        )
        self.bos_token_id = 0
        self.eos_token_id = 1
        self.pad_token_id = 2
        self.config.bos_token_id = self.bos_token_id
        self.config.eos_token_id = self.eos_token_id
        self.config.pad_token_id = self.pad_token_id
        self.config.quantize_codebook_dim = self.tokenizer.codebook_dim
        self.config.face_per_token = self.face_per_token
        self.config._attn_implementation = "flash_attention_2"
        self.config.cond_length = self.cond_length
        if self.config.word_embed_proj_dim != self.config.hidden_size:
            self.config.word_embed_proj_dim = self.config.hidden_size
        self.transformer = AutoModelForCausalLM.from_config(
            config=self.config, use_flash_attention_2=True
        )
        self.transformer.to_bettertransformer()
        self.transformer.model.decoder.quantize_codebooks = nn.Parameter(
            torch.zeros(1, self.tokenizer.codebook_size, self.tokenizer.codebook_dim)
        )

        self.cond_head_proj = nn.Linear(self.cond_dim, self.config.word_embed_proj_dim)
        self.cond_proj = nn.Linear(self.cond_dim * 2, self.config.word_embed_proj_dim)

        self.eval()

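    # Build the conditioning sequence for the language model: the first latent token gets
    # its own projection, and each remaining point latent is concatenated with its shape
    # latent before being projected to the LM embedding width.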
    def process_point_feature(self, point_feature):
        encode_feature = torch.zeros(
            point_feature.shape[0], self.cond_length, self.config.word_embed_proj_dim,
            device=self.cond_head_proj.weight.device, dtype=self.cond_head_proj.weight.dtype
        )
        encode_feature[:, 0] = self.cond_head_proj(point_feature[:, 0])
        shape_latents = self.point_encoder.to_shape_latents(point_feature[:, 1:])
        encode_feature[:, 1:] = self.cond_proj(torch.cat([point_feature[:, 1:], shape_latents], dim=-1))

        return encode_feature

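    # Inference entry point: encode the point cloud, autoregressively generate face tokens
    # (greedy by default, nucleus sampling when sampling=True), then decode them into
    # continuous triangle coordinates via the NoiseResistantDecoder.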
    @torch.no_grad()
    def forward(self, pc_normal, sampling=False) -> Tensor:
        batch_size = pc_normal.shape[0]
        point_feature = self.point_encoder.encode_latents(pc_normal)
        processed_point_feature = self.process_point_feature(point_feature)

        generate_length = self.max_length - self.cond_length
        net_device = next(self.parameters()).device
        # Pre-fill the output buffer with EOS; positions not overwritten by `generate` stay as EOS.
        outputs = torch.ones(batch_size, generate_length).long().to(net_device) * self.eos_token_id
        if not sampling:
            results = self.transformer.generate(
                inputs_embeds=processed_point_feature,
                max_new_tokens=generate_length,
                num_beams=1,
                bos_token_id=self.bos_token_id,
                eos_token_id=self.eos_token_id,
                pad_token_id=self.pad_token_id,
            )
        else:
            results = self.transformer.generate(
                inputs_embeds=processed_point_feature,
                max_new_tokens=generate_length,
                do_sample=True,
                top_k=50,
                top_p=0.95,
                bos_token_id=self.bos_token_id,
                eos_token_id=self.eos_token_id,
                pad_token_id=self.pad_token_id,
            )
        assert results.shape[1] <= generate_length
        outputs[:, :results.shape[1]] = results

        # Drop the first and last positions, then map every special token to the decoder's pad_id.
        outputs = outputs[:, 1:-1]

        outputs[outputs == self.bos_token_id] = self.tokenizer.pad_id
        outputs[outputs == self.eos_token_id] = self.tokenizer.pad_id
        outputs[outputs == self.pad_token_id] = self.tokenizer.pad_id

        # Shift past the three special tokens so the remaining ids index the codebook directly.
        outputs[outputs != self.tokenizer.pad_id] -= 3
        code_embed = self.get_codes(outputs)
        decoder_output = self.tokenizer(outputs, code_embed, point_feature=point_feature)

        return decoder_output

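    # Look up codebook embeddings for a batch of token indices, zeroing padded positions
    # and summing the embeddings across the quantizer levels.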
    def get_codes(self, indices):
        indices = indices.reshape(indices.shape[0], -1)
        indices = rearrange(indices, 'b (n q) -> b n q', q=self.num_quantizers)

        batch, quantize_dim = indices.shape[0], indices.shape[-1]
        indices, ps = pack([indices], 'b * q')

        # Pad the quantizer dimension with -1 if fewer than num_quantizers levels are present.
        if quantize_dim < self.num_quantizers:
            indices = F.pad(indices, (0, self.num_quantizers - quantize_dim), value=-1)

        # Remember which positions are padding, then clamp them to a valid index for the lookup.
        mask = indices == -1.
        indices = indices.masked_fill(mask, 0)

        all_codes = self.transformer.model.decoder.quantize_codebooks[0][indices]
        all_codes = all_codes.permute(2, 0, 1, 3)

        # Zero out the codes that came from padded positions.
        all_codes = all_codes.masked_fill(rearrange(mask, 'b n q -> q b n 1'), 0.)

        codes, = unpack(all_codes, ps, 'q b * d')

        # Sum the per-quantizer codes into a single embedding per vertex token.
        codes_summed = reduce(codes, 'q ... -> ...', 'sum')
        return codes_summed

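# Map discrete coordinate bin indices back to continuous values in the [low, high) range.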
def undiscretize(
    t,
    low,
    high,
    num_discrete
) -> Tensor:
    t = t.float()
    t /= num_discrete
    return t * (high - low) + low
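
# Example usage (sketch only; the `args` values below are hypothetical and not taken from
# this file, which expects them to come from the project's argument parser):
#
#   from argparse import Namespace
#   args = Namespace(
#       llm="facebook/opt-350m",   # hypothetical base config for ShapeOPTConfig
#       codebook_size=8192,        # hypothetical codebook settings
#       codebook_dim=192,
#       n_max_triangles=800,
#   )
#   model = MeshAnything(args).cuda().half()   # flash-attention kernels expect fp16/bf16
#   pc_normal = torch.rand(1, 8192, 6, device="cuda", dtype=torch.float16)  # assumed xyz + normal layout
#   recon_faces = model(pc_normal)             # (B, n_faces, 3, 3); padded faces are NaN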