from model.data_utils import to_mesh
from model.serializaiton import BPT_deserialize
import spaces
import os
import torch
import trimesh
from accelerate.utils import set_seed
import numpy as np
import gradio as gr
import time
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
from matplotlib.animation import FuncAnimation
import yaml
from huggingface_hub import snapshot_download
from model.model import MeshTransformer
from utils import apply_normalize, joint_filter, sample_pc

CONFIG_PATH = 'config/BPT-open-8k-8-16.yaml'
with open(CONFIG_PATH, "r") as f:
    config = yaml.load(f, Loader=yaml.FullLoader)


def download_models():
    os.makedirs("weights", exist_ok=True)
    try:
        snapshot_download(
            repo_id="whaohan/bpt",
            local_dir="./weights",
            resume_download=True,
        )
        print("Successfully downloaded the BPT model weights")
    except Exception as e:
        print(f"Error downloading the BPT model weights: {e}")
    model_path = 'weights/bpt-8-16-500m.pt'
    return model_path


MODEL_PATH = download_models()

# Prepare the model with fp16 precision.
model = MeshTransformer(
    dim=config['dim'],
    attn_depth=config['depth'],
    max_seq_len=config['max_seq_len'],
    dropout=config['dropout'],
    mode=config['mode'],
    num_discrete_coors=2 ** int(config['quant_bit']),
    block_size=config['block_size'],
    offset_size=config['offset_size'],
    conditioned_on_pc=config['conditioned_on_pc'],
    use_special_block=config['use_special_block'],
    encoder_name=config['encoder_name'],
    encoder_freeze=config['encoder_freeze'],
)
model.load(MODEL_PATH)
model = model.eval()
model = model.half()
model = model.cuda()
device = torch.device('cuda')
print('Model loaded')


def create_animation(mesh):
    # Reorder the axes so the mesh stands upright in matplotlib's frame.
    mesh.vertices = mesh.vertices[:, [2, 0, 1]]
    bounding_box = mesh.bounds
    center = mesh.centroid
    scale = np.ptp(bounding_box, axis=0).max()

    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111, projection='3d')
    ax.set_axis_off()

    # Extract vertices and faces for plotting
    vertices = mesh.vertices
    faces = mesh.faces

    # Plot faces
    ax.add_collection3d(Poly3DCollection(
        vertices[faces] * 1.4,
        facecolors=[120 / 255, 154 / 255, 192 / 255, 255 / 255],
        edgecolors='k',
        linewidths=0.5,
    ))

    # Set limits and center the view on the object
    ax.set_xlim(center[0] - scale / 2, center[0] + scale / 2)
    ax.set_ylim(center[1] - scale / 2, center[1] + scale / 2)
    ax.set_zlim(center[2] - scale / 2, center[2] + scale / 2)

    # Function to update the view angle
    def update_view(num, ax):
        ax.view_init(elev=20, azim=num)
        return ax,

    # Create the animation
    ani = FuncAnimation(fig, update_view, frames=np.arange(0, 360, 10),
                        interval=100, fargs=(ax,), blit=False)

    # Save the animation as a GIF
    output_path = f'model_{int(time.time())}.gif'
    ani.save(output_path, writer='pillow', fps=10)

    # Close the figure
    plt.close(fig)
    return output_path


@spaces.GPU(duration=480)
def do_inference(input_3d, sample_seed=0, temperature=0.5, top_k_value=50, top_p_value=0.9):
    print('Start Inference')
    set_seed(sample_seed)
    print("Seed value:", sample_seed)

    # Load the input, normalize it, and sample a point cloud with normals.
    mesh = trimesh.load(input_3d, force='mesh')
    mesh = apply_normalize(mesh)
    pc_normal = sample_pc(mesh, pc_num=4096, with_normal=True)

    pc_coor = pc_normal[:, :3]
    normals = pc_normal[:, 3:]
    assert (np.linalg.norm(normals, axis=-1) > 0.99).all(), \
        "normals should be unit vectors, something is wrong"
    normalized_pc_normal = np.concatenate([pc_coor, normals], axis=-1, dtype=np.float16)

    input_pc = torch.tensor(normalized_pc_normal, dtype=torch.float16, device=device)[None]
    print("Data loaded")
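    # Sampling note: `joint_filter` takes both `k` and `p` kwargs, so it
    # combines top-k and top-p (nucleus) filtering of the logits at each
    # generation step; together with `temperature` these control how
    # diverse the sampled meshes are.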
    with torch.no_grad():
        code = model.generate(
            batch_size=1,
            temperature=temperature,
            pc=input_pc,
            filter_logits_fn=joint_filter,
            filter_kwargs=dict(k=top_k_value, p=top_p_value),
            return_codes=True,
        )[0]
    print("Model inference done")

    # Convert the generated token sequence back to a mesh.
    code = code[code != model.pad_id].cpu().numpy()
    vertices = BPT_deserialize(
        code,
        block_size=model.block_size,
        offset_size=model.offset_size,
        use_special_block=model.use_special_block,
    )
    # Every three consecutive vertices form one triangle (faces are 1-indexed).
    faces = torch.arange(1, len(vertices) + 1).view(-1, 3)
    artist_mesh = to_mesh(vertices, faces, transpose=False, post_process=True)

    # Add a uniform face color for visualization.
    num_faces = len(artist_mesh.faces)
    face_color = np.array([120, 154, 192, 255], dtype=np.uint8)
    face_colors = np.tile(face_color, (num_faces, 1))
    artist_mesh.visual.face_colors = face_colors

    # Add a timestamp to the filename to avoid stale caches.
    save_name = f"output_{int(time.time())}.obj"
    artist_mesh.export(save_name)
    output_render = create_animation(artist_mesh)
    return save_name, output_render


_HEADER_ = '''

Official 🤗 Gradio Demo for the paper **Scaling Mesh Generation with Compressive Tokenization**

'''

_CITE_ = r"""
If you find our model helpful, please ⭐ the GitHub repo.

Code: GitHub. Paper: arXiv.

📧 **Contact**: If you have any questions, feel free to contact Haohan Weng.
"""

output_model_obj = gr.Model3D(
    label="Generated Mesh (OBJ Format)",
    display_mode="wireframe",
    scale=2,
)

output_image_render = gr.Image(
    label="Wireframe Render of Generated Mesh",
    scale=1,
)

with gr.Blocks() as demo:
    gr.Markdown(_HEADER_)
    with gr.Row(variant="panel"):
        with gr.Column(scale=1):
            with gr.Row():
                input_3d = gr.Model3D(
                    label="Input Mesh",
                )
            with gr.Row():
                sample_seed = gr.Number(value=0, label="Seed Value", precision=0)
                temperature = gr.Number(value=0.5, label="Temperature For Sampling", precision=None)
            with gr.Row():
                top_k_value = gr.Number(value=50, label="TopK For Sampling", precision=0)
                top_p_value = gr.Number(value=0.9, label="TopP For Sampling", precision=None)
            with gr.Row():
                submit = gr.Button("Generate", elem_id="generate", variant="primary")
            with gr.Row(variant="panel"):
                mesh_examples = gr.Examples(
                    examples=[
                        os.path.join("examples", img_name)
                        for img_name in sorted(os.listdir("examples"))
                    ],
                    inputs=input_3d,
                    outputs=[output_model_obj, output_image_render],
                    fn=do_inference,
                    cache_examples=False,
                    examples_per_page=10,
                )
            with gr.Row():
                gr.Markdown('''Try a different Seed Value or Temperature if the result is unsatisfying.''')
        with gr.Column(scale=2):
            with gr.Row(equal_height=True):
                output_model_obj.render()
                output_image_render.render()
    gr.Markdown(_CITE_)

    mv_images = gr.State()

    submit.click(
        fn=do_inference,
        inputs=[input_3d, sample_seed, temperature, top_k_value, top_p_value],
        outputs=[output_model_obj, output_image_render],
    )

demo.launch(share=True)
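# Programmatic usage (a minimal sketch; "examples/your_mesh.obj" below is a
# hypothetical path, and any mesh format trimesh can load should work):
#
#   obj_path, gif_path = do_inference(
#       "examples/your_mesh.obj",  # hypothetical input mesh
#       sample_seed=0,
#       temperature=0.5,
#       top_k_value=50,
#       top_p_value=0.9,
#   )
#
# `obj_path` points to the exported OBJ mesh and `gif_path` to its turntable
# GIF render; the Gradio UI above is the intended entry point for the demo.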