bpt / app.py
whaohan's picture
init commit
ada4b81 verified
raw
history blame
7.89 kB
from model.data_utils import to_mesh
from model.serializaiton import BPT_deserialize
import spaces
import os
import torch
import trimesh
from accelerate.utils import set_seed
import numpy as np
import gradio as gr
import time
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
from matplotlib.animation import FuncAnimation
import yaml
from huggingface_hub import snapshot_download
from model.model import MeshTransformer
from utils import apply_normalize, joint_filter, sample_pc
CONFIG_PATH = 'config/BPT-open-8k-8-16.yaml'

# Model hyper-parameters (dim, depth, quant_bit, block_size, ...) consumed when
# constructing the MeshTransformer. An explicit encoding avoids locale-dependent
# decoding; safe_load assumes the file is plain YAML with no Python-specific
# tags (TODO confirm if the config ever grows custom tags).
with open(CONFIG_PATH, "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)
def download_models(repo_id="whaohan/bpt", local_dir="weights"):
    """Download the BPT weights from the Hugging Face Hub and return the checkpoint path.

    A failed download is reported but tolerated when the checkpoint already
    exists in ``local_dir`` (e.g. from an earlier run); otherwise the original
    exception is re-raised instead of returning a path to a missing file.

    Args:
        repo_id: Hub repository that holds the weights.
        local_dir: Directory the repository snapshot is written into.

    Returns:
        Path of ``bpt-8-16-500m.pt`` inside ``local_dir``.

    Raises:
        Exception: whatever ``snapshot_download`` raised, when the download
            failed and no local checkpoint is available.
    """
    os.makedirs(local_dir, exist_ok=True)
    model_path = os.path.join(local_dir, 'bpt-8-16-500m.pt')
    try:
        snapshot_download(
            repo_id=repo_id,
            local_dir=local_dir,
            resume_download=True,
        )
        print(f"Successfully downloaded {repo_id} model")
    except Exception as e:
        print(f"Error downloading {repo_id}: {e}")
        # Only continue if a usable checkpoint is already on disk; otherwise the
        # later model.load() call would fail with a far less helpful error.
        if not os.path.isfile(model_path):
            raise
    return model_path
MODEL_PATH = download_models()

# Build the mesh transformer from the YAML hyper-parameters. The number of
# discrete coordinate bins is 2 ** quant_bit.
model = MeshTransformer(
    dim=config['dim'],
    attn_depth=config['depth'],
    max_seq_len=config['max_seq_len'],
    dropout=config['dropout'],
    mode=config['mode'],
    num_discrete_coors=2 ** int(config['quant_bit']),
    block_size=config['block_size'],
    offset_size=config['offset_size'],
    conditioned_on_pc=config['conditioned_on_pc'],
    use_special_block=config['use_special_block'],
    encoder_name=config['encoder_name'],
    encoder_freeze=config['encoder_freeze'],
)

# Load the checkpoint, then switch to eval mode, fp16 weights and the GPU.
model.load(MODEL_PATH)
model = model.eval().half().cuda()
device = torch.device('cuda')
print('Model loaded')
def create_animation(mesh):
    """Render a 360-degree turntable GIF of ``mesh`` and return its file path.

    The mesh is drawn with light-blue faces and thin black edges. The view
    azimuth steps from 0 to 350 degrees in 10-degree increments at a fixed
    20-degree elevation, and the frames are saved with the Pillow writer.

    Args:
        mesh: mesh object exposing ``copy()``, ``vertices``, ``faces``,
            ``bounds`` and ``centroid`` (presumably a ``trimesh.Trimesh`` as
            produced by ``to_mesh`` -- TODO confirm). It is not modified; the
            display axis permutation is applied to a copy.

    Returns:
        Relative path of the written GIF (``model_<unix-seconds>.gif``).
    """
    # Work on a copy so the caller's mesh keeps its original axis order.
    mesh = mesh.copy()
    # Reorder axes so new (x, y, z) = old (z, x, y); presumably this maps the
    # model's up axis onto matplotlib's z-up convention (TODO confirm).
    mesh.vertices = mesh.vertices[:, [2, 0, 1]]
    bounding_box = mesh.bounds
    center = mesh.centroid
    # Side length of the cubic view box: the mesh's largest axis extent.
    scale = np.ptp(bounding_box, axis=0).max()

    fig = plt.figure(figsize=(10, 10))
    try:
        ax = fig.add_subplot(111, projection='3d')
        ax.set_axis_off()

        vertices = mesh.vertices
        faces = mesh.faces
        # The 1.4 factor scales coordinates about the origin while the axis
        # limits below use the unscaled extent, which effectively zooms in
        # (assumes the mesh is roughly centred at the origin -- TODO confirm).
        ax.add_collection3d(Poly3DCollection(
            vertices[faces] * 1.4,
            facecolors=[120/255, 154/255, 192/255, 255/255],
            edgecolors='k',
            linewidths=0.5,
        ))

        # Centre a cube of side `scale` on the mesh centroid.
        ax.set_xlim(center[0] - scale / 2, center[0] + scale / 2)
        ax.set_ylim(center[1] - scale / 2, center[1] + scale / 2)
        ax.set_zlim(center[2] - scale / 2, center[2] + scale / 2)

        def update_view(num, ax):
            """Rotate the camera to azimuth ``num`` degrees for frame ``num``."""
            ax.view_init(elev=20, azim=num)
            return ax,

        ani = FuncAnimation(fig, update_view, frames=np.arange(0, 360, 10), interval=100, fargs=(ax,), blit=False)
        output_path = f'model_{int(time.time())}.gif'
        ani.save(output_path, writer='pillow', fps=10)
    finally:
        # Always release the figure, even if rendering or saving fails.
        plt.close(fig)
    return output_path
@spaces.GPU(duration=480)
def do_inference(input_3d, sample_seed=0, temperature=0.5, top_k_value=50, top_p_value=0.9):
    """Generate a mesh conditioned on a point cloud sampled from ``input_3d``.

    Args:
        input_3d: path of the uploaded mesh file (from ``gr.Model3D``),
            loaded with ``trimesh.load(..., force='mesh')``.
        sample_seed: seed passed to ``accelerate.utils.set_seed``.
        temperature: sampling temperature forwarded to ``model.generate``.
        top_k_value: ``k`` forwarded to ``joint_filter`` via ``filter_kwargs``.
        top_p_value: ``p`` forwarded to ``joint_filter`` via ``filter_kwargs``.

    Returns:
        Tuple ``(obj_path, gif_path)``: the exported OBJ file of the generated
        mesh and the path of its rendered turntable GIF.

    Raises:
        gr.Error: if no input mesh was provided.
        ValueError: if the sampled normals are not (close to) unit length.
    """
    if input_3d is None:
        # Clicking "Generate" without an upload would otherwise crash inside trimesh.
        raise gr.Error("Please upload an input mesh before generating.")

    print('Start Inference')
    set_seed(sample_seed)
    print("Seed value:", sample_seed)

    # Normalize the input mesh and sample a 4096-point cloud with normals.
    mesh = trimesh.load(input_3d, force='mesh')
    mesh = apply_normalize(mesh)
    pc_normal = sample_pc(mesh, pc_num=4096, with_normal=True)
    pc_coor = pc_normal[:, :3]
    normals = pc_normal[:, 3:]
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not (np.linalg.norm(normals, axis=-1) > 0.99).all():
        raise ValueError("normals should be unit vectors, something wrong")
    normalized_pc_normal = np.concatenate([pc_coor, normals], axis=-1, dtype=np.float16)
    # Leading batch axis of size 1, matching batch_size=1 below; fp16 matches model.half().
    pc_tensor = torch.tensor(normalized_pc_normal, dtype=torch.float16, device=device)[None]
    print("Data loaded")

    with torch.no_grad():
        code = model.generate(
            batch_size=1,
            temperature=temperature,
            pc=pc_tensor,
            filter_logits_fn=joint_filter,
            filter_kwargs=dict(k=top_k_value, p=top_p_value),
            return_codes=True,
        )[0]
    print("Model inference done")

    # Drop padding tokens and decode the remaining codes into vertex coordinates.
    code = code[code != model.pad_id].cpu().numpy()
    vertices = BPT_deserialize(
        code,
        block_size=model.block_size,
        offset_size=model.offset_size,
        use_special_block=model.use_special_block,
    )
    # Every consecutive vertex triple forms one triangle. Indices start at 1;
    # presumably to_mesh converts 1-based faces (TODO confirm).
    faces = torch.arange(1, len(vertices) + 1).view(-1, 3)
    artist_mesh = to_mesh(vertices, faces, transpose=False, post_process=True)

    # Uniform light-blue face colour for visualization.
    num_faces = len(artist_mesh.faces)
    face_color = np.array([120, 154, 192, 255], dtype=np.uint8)
    artist_mesh.visual.face_colors = np.tile(face_color, (num_faces, 1))

    # Timestamped name so the browser does not serve a cached earlier result.
    save_name = f"output_{int(time.time())}.obj"
    artist_mesh.export(save_name)
    output_render = create_animation(artist_mesh)
    return save_name, output_render
_HEADER_ = '''
<h2><b>Official πŸ€— Gradio Demo for Paper</b> <a href='https://github.com/whaohan/bpt' target='_blank'><b>Scaling Mesh Generation with Compressive Tokenization</b></a></h2>
'''
_CITE_ = r"""
If you found our model is helpful, please help to ⭐ the <a href='https://github.com/whaohan/bpt' target='_blank'>Github Repo</a>. Code: <a href='https://github.com/whaohan/bpt' target='_blank'>GitHub</a>. Arxiv Paper: <a href='https://arxiv.org/abs/2411.07025' target='_blank'>ArXiv</a>.
πŸ“§ **Contact**
If you have any questions, feel free to contact <a href='https://whaohan.github.io' target='_blank'>Haohan Weng</a>.
"""
# Output widgets are built before the layout because gr.Examples (left column)
# refers to them as outputs before they are placed in the right column via
# .render().
output_model_obj = gr.Model3D(label="Generated Mesh (OBJ Format)", display_mode="wireframe", scale=2)
output_image_render = gr.Image(label="Wireframe Render of Generated Mesh", scale=1)
# Two-column layout: input mesh, sampling controls and examples on the left;
# generated mesh and its wireframe render on the right.
with gr.Blocks() as demo:
    gr.Markdown(_HEADER_)
    with gr.Row(variant="panel"):
        with gr.Column(scale=1):
            with gr.Row():
                input_3d = gr.Model3D(
                    label="Input Mesh",
                )
            with gr.Row():
                sample_seed = gr.Number(value=0, label="Seed Value", precision=0)
                temperature = gr.Number(value=0.5, label="Temperature For Sampling", precision=None)
            with gr.Row():
                top_k_value = gr.Number(value=50, label="TopK For Sampling", precision=0)
                top_p_value = gr.Number(value=0.9, label="TopP For Sampling", precision=None)
            with gr.Row():
                submit = gr.Button("Generate", elem_id="generate", variant="primary")
            with gr.Row(variant="panel"):
                # Every file under examples/ is offered as a clickable input.
                mesh_examples = gr.Examples(
                    examples=[
                        os.path.join("examples", img_name) for img_name in sorted(os.listdir("examples"))
                    ],
                    inputs=input_3d,
                    outputs=[output_model_obj, output_image_render],
                    fn=do_inference,
                    cache_examples=False,
                    examples_per_page=10,
                )
            with gr.Row():
                gr.Markdown('''Try different <b>Seed Value</b> or <b>Temperature</b> if the result is unsatisfying''')
        with gr.Column(scale=2):
            with gr.Row(equal_height=True):
                output_model_obj.render()
                output_image_render.render()
    gr.Markdown(_CITE_)

    submit.click(
        fn=do_inference,
        inputs=[input_3d, sample_seed, temperature, top_k_value, top_p_value],
        outputs=[output_model_obj, output_image_render],
    )

# Launch only when run as a script, so importing this module does not start a server.
if __name__ == "__main__":
    demo.launch(share=True)