# Source: Michelangelo / app.py (Hugging Face Space by Maikou)
# Revision: d41cb1d ("update"), 15 kB
# -*- coding: utf-8 -*-
import os
import time
from collections import OrderedDict
from PIL import Image
import torch
import trimesh
from typing import Optional, List
from einops import repeat, rearrange
import numpy as np
from michelangelo.models.tsal.tsal_base import Latent2MeshOutput
from michelangelo.utils.misc import get_config_from_file, instantiate_from_config
from michelangelo.utils.visualizers.pythreejs_viewer import PyThreeJSViewer
from michelangelo.utils.visualizers import html_util
import gradio as gr
from omegaconf import OmegaConf
from huggingface_hub import snapshot_download
# Directory where generated OBJ/HTML files are written; the HTML viewer links
# to it through Gradio's "file/gradio_cached_dir/..." route.
gradio_cached_dir = "./gradio_cached_dir"
os.makedirs(gradio_cached_dir, exist_ok=True)
# NOTE(review): save_mesh is not read anywhere in this file.
save_mesh = False
# Last status string recorded by set_state().
state = ""
# First CUDA GPU when available, otherwise CPU; load_model moves models here.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Half-extent of the sampling volume: image2mesh/text2mesh pass
# bounds=[-box_v, -box_v, -box_v, box_v, box_v, box_v], and the viewer
# spaces neighbouring meshes 2 * box_v apart.
box_v = 1.1
# Single shared viewer: output_to_html_frame adds meshes, renders HTML, then
# resets it.
viewer = PyThreeJSViewer(settings={}, render_mode="WEBSITE")
# Image-conditioned models offered in the UI dropdown.
# "config" is read relative to the working directory; "ckpt_path" is joined
# with model_path (the downloaded Hub snapshot) in load_model.
image_model_config_dict = OrderedDict({
"ASLDM-256-obj": {
# Alternative (commented-out) local paths:
# "config": "./configs/image_cond_diffuser_asl/image-ASLDM-256.yaml",
# "ckpt_path": "./checkpoints/image_cond_diffuser_asl/image-ASLDM-256.ckpt",
"config": "./configs/image_cond_diffuser_asl/image-ASLDM-256.yaml",
"ckpt_path": "checkpoints/image_cond_diffuser_asl/image-ASLDM-256.ckpt",
},
})
# Text-conditioned models offered in the UI dropdown; same layout as the
# image dictionary above ("config" relative to the working directory,
# "ckpt_path" relative to model_path).
text_model_config_dict = OrderedDict({
"ASLDM-256": {
# Alternative (commented-out) local paths:
# "config": "./configs/text_cond_diffuser_asl/text-ASLDM-256.yaml",
# "ckpt_path": "./checkpoints/text_cond_diffuser_asl/text-ASLDM-256.ckpt",
"config": "./configs/text_cond_diffuser_asl/text-ASLDM-256.yaml",
"ckpt_path": "checkpoints/text_cond_diffuser_asl/text-ASLDM-256.ckpt",
},
})
# Fetch the Maikou/Michelangelo repo from the Hugging Face Hub at import time;
# snapshot_download returns the local directory that holds the checkpoints.
model_path = snapshot_download(repo_id="Maikou/Michelangelo")
class InferenceModel:
    """One-slot cache for a loaded model.

    ``load_model`` compares ``name`` with the requested model name and only
    rebuilds ``model`` when they differ.
    """

    # The loaded model instance; None until load_model() stores one.
    model = None
    # Key of the model currently held in ``model``; "" means none loaded yet.
    name = ""


# Separate caches so the text and image pipelines keep their own models.
text2mesh_model = InferenceModel()
image2mesh_model = InferenceModel()
def set_state(s):
    """Record ``s`` as the module-level status string and echo it to stdout."""
    global state
    state = s
    print(state)
def output_to_html_frame(mesh_outputs: List[Latent2MeshOutput], bbox_size: float,
                         image: Optional[np.ndarray] = None,
                         html_frame: bool = False):
    """Render meshes side by side with the shared viewer and return the HTML.

    Args:
        mesh_outputs: Meshes to show. ``None`` entries are skipped, but their
            slot is kept, so the remaining meshes stay at their index offsets.
        bbox_size: Spacing between neighbouring meshes. ``np.max`` is applied
            to it, so a per-axis array is accepted as well as a scalar.
        image: Optional image, shown in a table cell left of the meshes.
        html_frame: When True, wrap the result with ``html_util.to_html_frame``.

    Returns:
        The HTML string for the meshes (and image, if given).
    """
    offset = np.max(bbox_size)
    try:
        for i, mesh in enumerate(mesh_outputs):
            if mesh is None:
                continue
            # Copy so the caller's vertex array is not shifted in place.
            mesh_v = mesh.mesh_v.copy()
            # Lay meshes out along x by index, and push them all along z.
            mesh_v[:, 0] += i * offset
            mesh_v[:, 2] += offset
            viewer.add_mesh(mesh_v, mesh.mesh_f)
        mesh_tag = viewer.to_html(html_frame=False)
    finally:
        # ``viewer`` is a shared module-level object. Reset it even if adding
        # or rendering fails, so this call's meshes are not carried into the
        # next render.
        viewer.reset()

    if image is not None:
        image_tag = html_util.to_image_embed_tag(image)
        frame = f"""
<table border = "1">
<tr>
<td>{image_tag}</td>
<td>{mesh_tag}</td>
</tr>
</table>
"""
    else:
        frame = mesh_tag
    if html_frame:
        frame = html_util.to_html_frame(frame)
    return frame
def load_model(model_name: str, model_config_dict: dict, inference_model: "InferenceModel"):
    """Return the model named ``model_name``, loading it into ``inference_model`` if needed.

    ``inference_model`` is a one-entry cache. When it already holds
    ``model_name``, the cached model is returned unchanged. Otherwise the old
    model is released, and the new one is:

    - built from its config file,
    - given its checkpoint under ``model_path``,
    - moved to ``device`` and switched to eval mode.

    Args:
        model_name: Key into ``model_config_dict``.
        model_config_dict: Maps model names to ``{"config": ..., "ckpt_path": ...}``.
        inference_model: Cache object with ``model`` and ``name`` attributes.

    Returns:
        The loaded model.

    Raises:
        ValueError: If ``model_name`` is not a key of ``model_config_dict``.
    """
    # A hit requires an actual model: an empty cache has name "" and model None.
    if inference_model.name == model_name and inference_model.model is not None:
        return inference_model.model

    if model_name not in model_config_dict:
        raise ValueError(
            f"Unknown model {model_name!r}; available: {list(model_config_dict)}"
        )

    # Release the previous model before building the new one. Also clear the
    # name, so a failed load below cannot leave the cache claiming the old
    # model is still loaded while ``model`` is None.
    inference_model.model = None
    inference_model.name = ""

    config_ckpt_path = model_config_dict[model_name]
    model_config = get_config_from_file(config_ckpt_path["config"])
    if hasattr(model_config, "model"):
        # Configs may nest the model spec under a top-level "model" entry.
        model_config = model_config.model
    # Checkpoint paths are relative to the downloaded Hub snapshot.
    ckpt_path = os.path.join(model_path, config_ckpt_path["ckpt_path"])
    model = instantiate_from_config(model_config, ckpt_path=ckpt_path)
    model = model.to(device).eval()

    inference_model.model = model
    inference_model.name = model_name
    return model
def prepare_img(image: np.ndarray):
    """Turn an ``(h, w, c)`` image into a ``(c, h, w)`` float tensor in [-1, 1].

    Assumes pixel values lie in [0, 255].
    """
    # Rescale [0, 255] -> [0, 1] -> [-1, 1], then move channels to the front.
    normalized = torch.tensor(image).float() / 255 * 2 - 1
    return rearrange(normalized, "h w c -> c h w")
def prepare_model_viewer(fp):
    """Return an HTML snippet that embeds ``fp`` in a ``<model-viewer>`` element.

    Args:
        fp: Path of the 3D asset relative to ``gradio_cached_dir``. It is
            served through Gradio's ``file/`` route.

    Returns:
        The HTML: a ``<head>`` that loads the model-viewer script, plus a
        ``<body>`` holding an auto-rotating viewer.
    """
    # `ar` must be a bare boolean attribute. Written as `ar:true`, HTML parses
    # it as a different attribute literally named "ar:true", so AR stays off.
    content = f"""
<head>
<script
type="module" src="https://ajax.googleapis.com/ajax/libs/model-viewer/3.1.1/model-viewer.min.js">
</script>
</head>
<body>
<model-viewer
style="height: 150px; width: 150px;"
rotation-per-second="10deg"
id="t1"
src="file/gradio_cached_dir/{fp}"
environment-image="neutral"
camera-target="0m 0m 0m"
orientation="0deg 90deg 170deg"
shadow-intensity="1"
ar
auto-rotate
camera-controls>
</model-viewer>
</body>
"""
    return content
def prepare_html_frame(content):
    """Wrap ``content`` in a minimal ``<html><body>`` document, one tag per line."""
    parts = ["", "<html>", "<body>", f"{content}", "</body>", "</html>", ""]
    return "\n".join(parts)
def prepare_html_body(content):
    """Wrap ``content`` in a bare ``<body>`` element (no ``<html>`` root)."""
    return "\n<body>\n" + f"{content}" + "\n</body>\n"
def post_process_mesh_outputs(mesh_outputs):
    """Save generated meshes for download and build the HTML preview.

    Writes a combined HTML viewer page plus one OBJ file per mesh into
    ``gradio_cached_dir``.

    Args:
        mesh_outputs: Sequence of mesh outputs with ``mesh_v`` / ``mesh_f``
            arrays. ``None`` entries (failed samples) are skipped.

    Returns:
        ``(iframe_tag, filelist)``:

        - ``iframe_tag`` embeds the HTML page.
        - ``filelist`` holds the written OBJ paths followed by the HTML path.
    """
    html_content = output_to_html_frame(mesh_outputs, 2 * box_v, image=None, html_frame=False)
    html_frame = prepare_html_frame(html_content)
    filename = f"four-in-one-{time.time()}.html"
    html_filepath = os.path.join(gradio_cached_dir, filename)
    with open(html_filepath, "w", encoding="utf-8") as writer:
        writer.write(html_frame)

    # Gradio only serves files that sit next to the app script, addressed as
    # "file/<relative path>". A bare relative iframe src makes the browser
    # report "No resource with given URL found".
    # See https://github.com/gradio-app/gradio/issues/884
    iframe_tag = f'<iframe src="file/gradio_cached_dir/{filename}" width="600%" height="400" frameborder="0"></iframe>'

    filelist = []
    for i, mesh in enumerate(mesh_outputs):
        if mesh is None:
            # Same policy as output_to_html_frame: skip failed samples, keep
            # the index so file names still match sample positions.
            continue
        # Reverse each face's vertex order (flips winding / normal direction)
        # for the exported OBJ. Use a local view so the caller's mesh is not
        # mutated.
        faces = mesh.mesh_f[:, ::-1]
        mesh_output = trimesh.Trimesh(mesh.mesh_v, faces)
        # NOTE(review): these names are not unique per request, so concurrent
        # users overwrite each other's OBJ files.
        filepath = os.path.join(gradio_cached_dir, f"{i}_out_mesh.obj")
        mesh_output.export(filepath, include_normals=True)
        filelist.append(filepath)
    filelist.append(html_filepath)
    return iframe_tag, filelist
def image2mesh(image: np.ndarray,
               model_name: str = "subsp+pk_asl_perceiver=01_01_udt=03",
               num_samples: int = 4,
               guidance_scale: float = 7.5,
               octree_depth: int = 7):
    """Generate ``num_samples`` meshes conditioned on ``image``.

    Args:
        image: Input image in ``(h, w, c)`` layout with values in [0, 255]
            (the layout prepare_img converts).
        model_name: Key of ``image_model_config_dict``.
        num_samples: Number of meshes to sample (the batch size).
        guidance_scale: Guidance scale forwarded to ``model.sample``.
        octree_depth: Octree depth forwarded to ``model.sample``.

    Returns:
        ``(iframe_tag, gr.update)``. The update shows the downloadable files.
    """
    # NOTE(review): this default model_name is not a key of
    # image_model_config_dict, so calling without an explicit name fails in
    # load_model; the UI always passes one.
    model = load_model(model_name, image_model_config_dict, image2mesh_model)

    # Gradio sliders can deliver numbers as floats (e.g. 4.0). The batch
    # size and octree depth must be ints.
    num_samples = int(num_samples)
    octree_depth = int(octree_depth)

    # Repeat the single image along a new batch axis, once per sample.
    image_pt = repeat(prepare_img(image), "c h w -> b c h w", b=num_samples)
    sample_inputs = {
        "image": image_pt
    }
    # With sample_times=1, index 0 presumably selects that single round's meshes.
    mesh_outputs = model.sample(
        sample_inputs,
        sample_times=1,
        guidance_scale=guidance_scale,
        return_intermediates=False,
        bounds=[-box_v, -box_v, -box_v, box_v, box_v, box_v],
        octree_depth=octree_depth,
    )[0]
    iframe_tag, filelist = post_process_mesh_outputs(mesh_outputs)
    return iframe_tag, gr.update(value=filelist, visible=True)
def text2mesh(text: str,
              model_name: str = "subsp+pk_asl_perceiver=01_01_udt=03",
              num_samples: int = 4,
              guidance_scale: float = 7.5,
              octree_depth: int = 7):
    """Generate ``num_samples`` meshes conditioned on the prompt ``text``.

    Args:
        text: Prompt, e.g. "A 3D model of CATEGORY; DESCRIPTION".
        model_name: Key of ``text_model_config_dict``.
        num_samples: Number of meshes to sample (the batch size).
        guidance_scale: Guidance scale forwarded to ``model.sample``.
        octree_depth: Octree depth forwarded to ``model.sample``.

    Returns:
        ``(iframe_tag, gr.update)``. The update shows the downloadable files.
    """
    # NOTE(review): this default model_name is not a key of
    # text_model_config_dict, so calling without an explicit name fails in
    # load_model; the UI always passes one.
    model = load_model(model_name, text_model_config_dict, text2mesh_model)

    # Gradio sliders can deliver numbers as floats; `[text] * 4.0` would raise
    # TypeError, and octree_depth must be an int as well.
    num_samples = int(num_samples)
    octree_depth = int(octree_depth)

    # One copy of the prompt per sample.
    sample_inputs = {
        "text": [text] * num_samples
    }
    # With sample_times=1, index 0 presumably selects that single round's meshes.
    mesh_outputs = model.sample(
        sample_inputs,
        sample_times=1,
        guidance_scale=guidance_scale,
        return_intermediates=False,
        bounds=[-box_v, -box_v, -box_v, box_v, box_v, box_v],
        octree_depth=octree_depth,
    )[0]
    iframe_tag, filelist = post_process_mesh_outputs(mesh_outputs)
    return iframe_tag, gr.update(value=filelist, visible=True)
example_dir = './gradio_cached_dir/example/img_example'
# Example images placed ahead of all others in the gallery.
# Each entry needs its own comma: adjacent string literals are concatenated
# silently, which would merge several names into one bogus entry.
first_page_items = [
    'alita.jpg',
    'burger.jpg',
    'loopy.jpg',
    'building.jpg',
    'mario.jpg',
    'car.jpg',
    'airplane.jpg',
    'bag.jpg',
    'bench.jpg',
    'ship.jpg',
]
# Every .jpg/.png in example_dir, in directory-listing order. A missing
# directory yields no examples instead of crashing the app at import time.
raw_example_items = [
    os.path.join(example_dir, x)
    for x in (os.listdir(example_dir) if os.path.isdir(example_dir) else [])
    if x.endswith(('.jpg', '.png'))
]
# First-page images first, then the rest; listing order is kept within each group.
example_items = (
    [x for x in raw_example_items if os.path.basename(x) in first_page_items]
    + [x for x in raw_example_items if os.path.basename(x) not in first_page_items]
)
# Example prompts, one single-field row per example. This is a list; a
# trailing comma after the literal would wrap it in a tuple.
example_text = [
    ["A 3D model of a car; Audi A6."],
    ["A 3D model of police car; Highway Patrol Charger"],
]
def set_cache(data: gr.SelectData):
    """Handle a gallery click and return the values for the two outputs.

    Returns ``(path, name)``: the clicked example's path for the Image input,
    and its bare file name for the hidden cache textbox.
    """
    picked = os.path.basename(example_items[data.index])
    return os.path.join(example_dir, picked), picked
def disable_cache():
    """Return an empty string to clear the hidden cache textbox."""
    cleared = ""
    return cleared
# Gradio UI: an "Image to 3D" tab and a "Text to 3D" tab share one output
# column (HTML preview + downloadable files).
with gr.Blocks() as app:
    gr.Markdown("# Michelangelo")
    gr.Markdown("## [Github](https://github.com/NeuralCarver/Michelangelo) | [Arxiv](https://arxiv.org/abs/2306.17115) | [Project Page](https://neuralcarver.github.io/michelangelo/)")
    gr.Markdown("Michelangelo is a conditional 3D shape generation system that trains based on the shape-image-text aligned latent representation.")
    gr.Markdown("### Hint:")
    gr.Markdown("1. We provide two APIs: Image-conditioned generation and Text-conditioned generation")
    gr.Markdown("2. Note that the Image-conditioned model is trained on multiple 3D datasets like ShapeNet and Objaverse")
    gr.Markdown("3. We provide some examples for you to try. You can also upload images or text as input.")
    gr.Markdown("4. To make it convenient to take favor results home, we provide download buttons for each OBJ file and a combined HTML file.")
    gr.Markdown("5. Welcome to share suggestions or amazing results with us, and thanks for your interest in our work!")
    gr.Markdown("6. Please note that the model may require some time to download the weights and set up during the first launch; we are working to fix this issue.")
    with gr.Row():
        with gr.Column():
            with gr.Tab("Image to 3D"):
                img = gr.Image(label="Image")
                gr.Markdown("For the best results, we suggest that the images uploaded meet the following three criteria: 1. The object is positioned at the center of the image, 2. The image size is square, and 3. The background is relatively clean.")
                btn_generate_img2obj = gr.Button(value="Generate")
                # Each tab needs its own slider variables. Re-using one set of
                # names lets the Text tab rebind them, and the Image button
                # would then read the Text tab's settings.
                with gr.Accordion("Advanced settings", open=False):
                    image_dropdown_models = gr.Dropdown(label="Model", value="ASLDM-256-obj", choices=list(image_model_config_dict.keys()))
                    img_num_samples = gr.Slider(label="samples", value=4, minimum=1, maximum=8, step=1)
                    img_guidance_scale = gr.Slider(label="Guidance scale", value=7.5, minimum=3.0, maximum=10.0, step=0.1)
                    img_octree_depth = gr.Slider(label="Octree Depth (for 3D model)", value=7, minimum=4, maximum=8, step=1)
                # Hidden field: name of the last clicked example (cleared on upload).
                cache_dir = gr.Textbox(value="", visible=False)
                examples = gr.Gallery(label='Examples', value=example_items, elem_id="gallery", allow_preview=False, columns=[4], object_fit="contain")
            with gr.Tab("Text to 3D"):
                prompt = gr.Textbox(label="Prompt", placeholder="A 3D model of motorcar; Porche Cayenne Turbo.")
                gr.Markdown("For the best results, we suggest that the prompt follows 'A 3D model of CATEGORY; DESCRIPTION'. For example, A 3D model of motorcar; Porche Cayenne Turbo.")
                btn_generate_txt2obj = gr.Button(value="Generate")
                with gr.Accordion("Advanced settings", open=False):
                    text_dropdown_models = gr.Dropdown(label="Model", value="ASLDM-256", choices=list(text_model_config_dict.keys()))
                    txt_num_samples = gr.Slider(label="samples", value=4, minimum=1, maximum=8, step=1)
                    txt_guidance_scale = gr.Slider(label="Guidance scale", value=7.5, minimum=3.0, maximum=10.0, step=0.1)
                    txt_octree_depth = gr.Slider(label="Octree Depth (for 3D model)", value=7, minimum=4, maximum=8, step=1)
                gr.Markdown("#### Examples:")
                gr.Markdown("1. A 3D model of an airplane; Airbus.")
                gr.Markdown("2. A 3D model of a fighter aircraft; Attack Fighter.")
                gr.Markdown("3. A 3D model of a chair; Simple Wooden Chair.")
                gr.Markdown("4. A 3D model of a laptop computer; Dell Laptop.")
                gr.Markdown("5. A 3D model of a coupe; Audi A6.")
                gr.Markdown("6. A 3D model of a motorcar; Hummer H2 SUT.")
                gr.Markdown("7. A 3D model of a lamp; Light Post.")
                gr.Markdown("8. A 3D model of a rifle; AK47.")
                gr.Markdown("9. A 3D model of a knife; Sword.")
                gr.Markdown("10. A 3D model of a vase; Plant in pot.")
        with gr.Column():
            model_3d = gr.HTML()
            file_out = gr.File(label="Files", visible=False)

    outputs = [model_3d, file_out]

    img.upload(disable_cache, outputs=cache_dir)
    examples.select(set_cache, outputs=[img, cache_dir])

    # Startup diagnostics.
    print(os.path.abspath(os.path.dirname(__file__)), flush=True)
    print(model_path, flush=True)
    fps = os.listdir(model_path)
    print(fps)
    print(f'line:404: {cache_dir}', flush=True)

    btn_generate_img2obj.click(image2mesh, inputs=[img, image_dropdown_models, img_num_samples,
                                                   img_guidance_scale, img_octree_depth],
                               outputs=outputs, api_name="generate_img2obj")
    btn_generate_txt2obj.click(text2mesh, inputs=[prompt, text_dropdown_models, txt_num_samples,
                                                  txt_guidance_scale, txt_octree_depth],
                               outputs=outputs, api_name="generate_txt2obj")

app.launch()