import spaces import subprocess # Install flash attention, skipping CUDA build if necessary subprocess.run( "pip install flash-attn --no-build-isolation", env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"}, shell=True, ) import os import torch import trimesh from accelerate.utils import set_seed from accelerate import Accelerator import numpy as np import gradio as gr from main import get_args, load_model from mesh_to_pc import process_mesh_to_pc import time import matplotlib.pyplot as plt from mpl_toolkits.mplot3d.art3d import Poly3DCollection from PIL import Image import io args = get_args() model = load_model(args) device = torch.device('cuda') accelerator = Accelerator( mixed_precision="fp16", ) model = accelerator.prepare(model) model.eval() print("Model loaded to device") def wireframe_render(mesh): views = [ (90, 20), (270, 20) ] mesh.vertices = mesh.vertices[:, [0, 2, 1]] bounding_box = mesh.bounds center = mesh.centroid scale = np.ptp(bounding_box, axis=0).max() fig = plt.figure(figsize=(10, 10)) # Function to render and return each view as an image def render_view(mesh, azimuth, elevation): ax = fig.add_subplot(111, projection='3d') ax.set_axis_off() # Extract vertices and faces for plotting vertices = mesh.vertices faces = mesh.faces # Plot faces ax.add_collection3d(Poly3DCollection( vertices[faces], facecolors=(0.8, 0.5, 0.2, 1.0), # Brownish yellow edgecolors='k', linewidths=0.5, )) # Set limits and center the view on the object ax.set_xlim(center[0] - scale / 2, center[0] + scale / 2) ax.set_ylim(center[1] - scale / 2, center[1] + scale / 2) ax.set_zlim(center[2] - scale / 2, center[2] + scale / 2) # Set view angle ax.view_init(elev=elevation, azim=azimuth) # Save the figure to a buffer buf = io.BytesIO() plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0, dpi=300) plt.clf() buf.seek(0) return Image.open(buf) # Render each view and store in a list images = [render_view(mesh, az, el) for az, el in views] # Combine images horizontally widths, heights = zip(*(i.size for i in images)) total_width = sum(widths) max_height = max(heights) combined_image = Image.new('RGBA', (total_width, max_height)) x_offset = 0 for img in images: combined_image.paste(img, (x_offset, 0)) x_offset += img.width # Save the combined image save_path = f"combined_mesh_view_{int(time.time())}.png" combined_image.save(save_path) plt.close(fig) return save_path @spaces.GPU(duration=300) def do_inference(input_3d, sample_seed=0, do_sampling=False, do_marching_cubes=False): set_seed(sample_seed) print("Seed value:", sample_seed) input_mesh = trimesh.load(input_3d) pc_list, mesh_list = process_mesh_to_pc([input_mesh], marching_cubes = do_marching_cubes) pc_normal = pc_list[0] # 4096, 6 mesh = mesh_list[0] vertices = mesh.vertices pc_coor = pc_normal[:, :3] normals = pc_normal[:, 3:] bounds = np.array([vertices.min(axis=0), vertices.max(axis=0)]) # scale mesh and pc vertices = vertices - (bounds[0] + bounds[1])[None, :] / 2 vertices = vertices / (bounds[1] - bounds[0]).max() mesh.vertices = vertices pc_coor = pc_coor - (bounds[0] + bounds[1])[None, :] / 2 pc_coor = pc_coor / (bounds[1] - bounds[0]).max() mesh.merge_vertices() mesh.update_faces(mesh.unique_faces()) mesh.fix_normals() if mesh.visual.vertex_colors is not None: orange_color = np.array([255, 165, 0, 255], dtype=np.uint8) mesh.visual.vertex_colors = np.tile(orange_color, (mesh.vertices.shape[0], 1)) else: orange_color = np.array([255, 165, 0, 255], dtype=np.uint8) mesh.visual.vertex_colors = np.tile(orange_color, (mesh.vertices.shape[0], 1)) input_save_name = f"processed_input_{int(time.time())}.obj" mesh.export(input_save_name) pc_coor = pc_coor / np.abs(pc_coor).max() * 0.9995 # input should be from -1 to 1 assert (np.linalg.norm(normals, axis=-1) > 0.99).all(), "normals should be unit vectors, something wrong" normalized_pc_normal = np.concatenate([pc_coor, normals], axis=-1, dtype=np.float16) input = torch.tensor(normalized_pc_normal, dtype=torch.float16, device=device)[None] print("Data loaded") # with accelerator.autocast(): with accelerator.autocast(): outputs = model(input, do_sampling) print("Model inference done") recon_mesh = outputs[0] recon_mesh = recon_mesh[~torch.isnan(recon_mesh[:, 0, 0])] # nvalid_face x 3 x 3 vertices = recon_mesh.reshape(-1, 3).cpu() vertices_index = np.arange(len(vertices)) # 0, 1, ..., 3 x face triangles = vertices_index.reshape(-1, 3) artist_mesh = trimesh.Trimesh(vertices=vertices, faces=triangles, force="mesh", merge_primitives=True) artist_mesh.merge_vertices() artist_mesh.update_faces(artist_mesh.unique_faces()) artist_mesh.fix_normals() if artist_mesh.visual.vertex_colors is not None: orange_color = np.array([255, 165, 0, 255], dtype=np.uint8) artist_mesh.visual.vertex_colors = np.tile(orange_color, (artist_mesh.vertices.shape[0], 1)) else: orange_color = np.array([255, 165, 0, 255], dtype=np.uint8) artist_mesh.visual.vertex_colors = np.tile(orange_color, (artist_mesh.vertices.shape[0], 1)) num_faces = len(artist_mesh.faces) brown_color = np.array([165, 42, 42, 255], dtype=np.uint8) face_colors = np.tile(brown_color, (num_faces, 1)) artist_mesh.visual.face_colors = face_colors # add time stamp to avoid cache save_name = f"output_{int(time.time())}.obj" artist_mesh.export(save_name) mesh.vertices = mesh.vertices[:, [0, 2, 1]] artist_mesh.vertices = artist_mesh.vertices[:, [0, 2, 1]] input_wireframe_save_name = f"input_wireframe_{int(time.time())}.obj" output_wireframe_save_name = f"output_wireframe_{int(time.time())}.obj" mesh.export(input_wireframe_save_name) artist_mesh.export(output_wireframe_save_name) return input_save_name, input_wireframe_save_name, save_name, output_wireframe_save_name _HEADER_ = '''