# Hugging Face Spaces: Running on Zero
import spaces  # imported first so it is in place before anything initialises CUDA (ZeroGPU)

import os
import time
import uuid

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
import trimesh
import yaml
from accelerate.utils import set_seed
from huggingface_hub import snapshot_download
from matplotlib.animation import FuncAnimation
from mpl_toolkits.mplot3d.art3d import Poly3DCollection

from model.data_utils import to_mesh
from model.model import MeshTransformer
from model.serializaiton import BPT_deserialize
from utils import apply_normalize, joint_filter, sample_pc
# Hyper-parameters used below to construct MeshTransformer.
CONFIG_PATH = 'config/BPT-open-8k-8-16.yaml'
with open(CONFIG_PATH, "r", encoding="utf-8") as f:
    # safe_load: the config is only read as a mapping of plain values below,
    # so no Python-object YAML tags are needed.
    config = yaml.safe_load(f)
def download_models(repo_id="whaohan/bpt", local_dir="weights",
                    filename="bpt-8-16-500m.pt"):
    """Fetch the BPT checkpoint from the Hugging Face Hub and return its path.

    The download is best-effort: if it fails but the checkpoint file already
    exists under ``local_dir`` (e.g. left by a previous run), that copy is used.

    Args:
        repo_id: Hub repository that holds the weights.
        local_dir: Directory the repository snapshot is written into.
        filename: Checkpoint file expected inside the snapshot.

    Returns:
        Path to the checkpoint file (``weights/bpt-8-16-500m.pt`` by default).

    Raises:
        RuntimeError: if the download failed and no local checkpoint exists,
            so the failure surfaces here instead of later in ``model.load``.
    """
    os.makedirs(local_dir, exist_ok=True)
    model_path = os.path.join(local_dir, filename)
    try:
        # Recent huggingface_hub versions always resume interrupted downloads;
        # the old `resume_download` flag is deprecated, so it is not passed.
        snapshot_download(repo_id=repo_id, local_dir=local_dir)
        print(f"Successfully downloaded BPT model from {repo_id}")
    except Exception as e:
        if not os.path.isfile(model_path):
            raise RuntimeError(
                f"Error downloading BPT model from {repo_id} and no local "
                f"checkpoint found at {model_path}"
            ) from e
        print(f"Error downloading BPT model from {repo_id}: {e}; "
              f"using existing checkpoint {model_path}")
    return model_path
# Fetch (or reuse) the checkpoint, build the transformer from the YAML
# hyper-parameters, and prepare it for fp16 inference on the GPU.
MODEL_PATH = download_models()
device = torch.device('cuda')
model = MeshTransformer(
    dim=config['dim'],
    attn_depth=config['depth'],
    max_seq_len=config['max_seq_len'],
    dropout=config['dropout'],
    mode=config['mode'],
    # number of discrete coordinate values: 2 ** quant_bit
    num_discrete_coors=2 ** int(config['quant_bit']),
    block_size=config['block_size'],
    offset_size=config['offset_size'],
    conditioned_on_pc=config['conditioned_on_pc'],
    use_special_block=config['use_special_block'],
    encoder_name=config['encoder_name'],
    encoder_freeze=config['encoder_freeze'],
)
model.load(MODEL_PATH)
# eval() / half() / cuda() each hand back the module, so they can be chained.
model = model.eval().half().cuda()
print('Model loaded')
def create_animation(mesh):
    """Render a turntable GIF of ``mesh`` and return the GIF's file path.

    Faces are drawn in a flat blue with thin black edges while the camera
    orbits 0..350 degrees in 10-degree steps at a fixed 20-degree elevation.
    The caller's mesh is NOT modified: the display axis permutation is applied
    to a local copy of the vertex array only.

    Args:
        mesh: mesh object exposing ``vertices``, ``faces``, ``bounds`` and
            ``centroid`` (a trimesh mesh in this app).

    Returns:
        Path of the written GIF, unique per call.
    """
    # Reorder axes (x, y, z) -> (z, x, y) for display; presumably this maps
    # the mesh's up axis onto matplotlib's z-up view — TODO confirm.
    axis_order = [2, 0, 1]
    vertices = np.asarray(mesh.vertices)[:, axis_order]
    faces = np.asarray(mesh.faces)
    center = np.asarray(mesh.centroid)[axis_order]
    # Largest bounding-box extent; a per-axis range, so axis order is irrelevant.
    scale = np.ptp(mesh.bounds, axis=0).max()
    half = scale / 2

    fig = plt.figure(figsize=(10, 10))
    try:
        ax = fig.add_subplot(111, projection='3d')
        ax.set_axis_off()
        # The 1.4 factor enlarges the drawn polygons about the origin relative
        # to the axis limits set below (i.e. it zooms the render in).
        ax.add_collection3d(Poly3DCollection(
            vertices[faces] * 1.4,
            facecolors=[120/255, 154/255, 192/255, 255/255],
            edgecolors='k',
            linewidths=0.5,
        ))
        # Cubic view volume centred on the object.
        ax.set_xlim(center[0] - half, center[0] + half)
        ax.set_ylim(center[1] - half, center[1] + half)
        ax.set_zlim(center[2] - half, center[2] + half)

        def update_view(num, ax):
            """Rotate the camera to azimuth ``num`` degrees for one frame."""
            ax.view_init(elev=20, azim=num)
            return ax,

        ani = FuncAnimation(fig, update_view, frames=np.arange(0, 360, 10),
                            interval=100, fargs=(ax,), blit=False)
        # Timestamp + random suffix: concurrent requests in the same second
        # must not overwrite each other's GIF.
        output_path = f'model_{int(time.time())}_{uuid.uuid4().hex[:8]}.gif'
        ani.save(output_path, writer='pillow', fps=10)
    finally:
        # Always release the figure, even if rendering or saving fails.
        plt.close(fig)
    return output_path
def _load_point_cloud(input_3d, pc_num=4096):
    """Load the input mesh and build the model's point-cloud conditioning tensor.

    The mesh is normalised with ``apply_normalize``, then ``pc_num`` points
    with normals are requested from ``sample_pc``.

    Args:
        input_3d: file path of the input mesh (``None`` if nothing was uploaded).
        pc_num: number of points requested from ``sample_pc``.

    Returns:
        fp16 tensor on ``device`` with xyz followed by the normal for every
        sampled point, plus a leading batch axis of size 1.

    Raises:
        gr.Error: if no mesh was given or the sampled normals are not unit length.
    """
    if input_3d is None:
        raise gr.Error("Please upload an input mesh first.")
    mesh = trimesh.load(input_3d, force='mesh')
    mesh = apply_normalize(mesh)
    pc_normal = sample_pc(mesh, pc_num=pc_num, with_normal=True)
    pc_coor = pc_normal[:, :3]
    normals = pc_normal[:, 3:]
    # Explicit check instead of `assert` so it still runs under `python -O`.
    if not (np.linalg.norm(normals, axis=-1) > 0.99).all():
        raise gr.Error("normals should be unit vectors, something wrong")
    pc = np.concatenate([pc_coor, normals], axis=-1, dtype=np.float16)
    return torch.tensor(pc, dtype=torch.float16, device=device)[None]


def _decode_mesh(code):
    """Turn one sequence of generated token codes into a coloured mesh.

    Padding tokens are dropped, the rest is deserialised into a flat vertex
    list, and each consecutive triple of vertices becomes one face.

    Raises:
        gr.Error: if the generated sequence decodes to no vertices at all.
    """
    code = code[code != model.pad_id].cpu().numpy()
    vertices = BPT_deserialize(
        code,
        block_size=model.block_size,
        offset_size=model.offset_size,
        use_special_block=model.use_special_block,
    )
    if len(vertices) == 0:
        raise gr.Error("The model produced an empty mesh; "
                       "try a different Seed Value or Temperature.")
    # 1-based vertex ids grouped in threes; to_mesh is handed 1-based indices
    # (presumably it converts them to 0-based — TODO confirm in data_utils).
    faces = torch.arange(1, len(vertices) + 1).view(-1, 3)
    artist_mesh = to_mesh(vertices, faces, transpose=False, post_process=True)
    # uniform face colour for visualization
    face_color = np.array([120, 154, 192, 255], dtype=np.uint8)
    artist_mesh.visual.face_colors = np.tile(face_color, (len(artist_mesh.faces), 1))
    return artist_mesh


# On ZeroGPU Spaces the GPU is only attached inside @spaces.GPU functions;
# outside ZeroGPU the decorator has no effect.
@spaces.GPU
def do_inference(input_3d, sample_seed=0, temperature=0.5, top_k_value=50, top_p_value=0.9):
    """Generate an artist-style mesh conditioned on the uploaded mesh.

    Args:
        input_3d: file path of the input mesh.
        sample_seed: seed passed to ``set_seed`` before sampling.
        temperature: sampling temperature for ``model.generate``.
        top_k_value: ``k`` passed to the ``joint_filter`` logits filter.
        top_p_value: ``p`` passed to the ``joint_filter`` logits filter.

    Returns:
        ``(obj_path, gif_path)``: the exported OBJ file and its turntable render.
    """
    print('Start Inference')
    set_seed(sample_seed)
    print("Seed value:", sample_seed)
    pc = _load_point_cloud(input_3d)
    print("Data loaded")
    with torch.no_grad():
        code = model.generate(
            batch_size=1,
            temperature=temperature,
            pc=pc,
            filter_logits_fn=joint_filter,
            filter_kwargs=dict(k=top_k_value, p=top_p_value),
            return_codes=True,
        )[0]
    print("Model inference done")
    artist_mesh = _decode_mesh(code)
    # Unique name per request: a bare second-resolution timestamp would let
    # concurrent requests overwrite each other (it also avoids stale caching).
    save_name = f"output_{int(time.time())}_{uuid.uuid4().hex[:8]}.obj"
    artist_mesh.export(save_name)
    output_render = create_animation(artist_mesh)
    return save_name, output_render
# Markdown/HTML shown above (_HEADER_) and below (_CITE_) the demo. The emoji
# were previously mojibake (UTF-8 read as ISO-8859-7); blank lines separate
# Markdown paragraphs so "Contact" renders on its own.
_HEADER_ = '''
<h2><b>Official 🤗 Gradio Demo for Paper</b> <a href='https://github.com/whaohan/bpt' target='_blank'><b>Scaling Mesh Generation with Compressive Tokenization</b></a></h2>
'''
_CITE_ = r"""
If you found our model is helpful, please help to ⭐ the <a href='https://github.com/whaohan/bpt' target='_blank'>Github Repo</a>. Code: <a href='https://github.com/whaohan/bpt' target='_blank'>GitHub</a>. Arxiv Paper: <a href='https://arxiv.org/abs/2411.07025' target='_blank'>ArXiv</a>.

📧 **Contact**

If you have any questions, feel free to contact <a href='https://whaohan.github.io' target='_blank'>Haohan Weng</a>.
"""
# Output components are created up front and placed later with .render(),
# because gr.Examples in the left column references them before the right
# column that displays them is built.
output_model_obj = gr.Model3D(
    label="Generated Mesh (OBJ Format)",
    display_mode="wireframe",
    scale=2,
)
output_image_render = gr.Image(
    label="Wireframe Render of Generated Mesh",
    scale=1,
)

with gr.Blocks() as demo:
    gr.Markdown(_HEADER_)
    with gr.Row(variant="panel"):
        # Left column: input mesh, sampling controls, examples.
        with gr.Column(scale=1):
            with gr.Row():
                input_3d = gr.Model3D(
                    label="Input Mesh",
                )
            with gr.Row():
                sample_seed = gr.Number(value=0, label="Seed Value", precision=0)
                temperature = gr.Number(value=0.5, label="Temperature For Sampling", precision=None)
            with gr.Row():
                top_k_value = gr.Number(value=50, label="TopK For Sampling", precision=0)
                top_p_value = gr.Number(value=0.9, label="TopP For Sampling", precision=None)
            with gr.Row():
                submit = gr.Button("Generate", elem_id="generate", variant="primary")
            with gr.Row(variant="panel"):
                gr.Examples(
                    examples=[
                        os.path.join("examples", name)
                        for name in sorted(os.listdir("examples"))
                        # skip hidden files such as .gitkeep / .DS_Store
                        if not name.startswith(".")
                    ],
                    inputs=input_3d,
                    outputs=[output_model_obj, output_image_render],
                    fn=do_inference,
                    cache_examples=False,
                    examples_per_page=10,
                )
            with gr.Row():
                gr.Markdown('''Try different <b>Seed Value</b> or <b>Temperature</b> if the result is unsatisfying''')
        # Right column: generated mesh and its turntable render.
        with gr.Column(scale=2):
            with gr.Row(equal_height=True):
                output_model_obj.render()
                output_image_render.render()
    gr.Markdown(_CITE_)
    submit.click(
        fn=do_inference,
        inputs=[input_3d, sample_seed, temperature, top_k_value, top_p_value],
        outputs=[output_model_obj, output_image_render],
    )

demo.launch(share=True)