"""
app.py

An interactive Gradio demo of text-guided 3D shape generation with SALAD.
"""
import functools
from pathlib import Path
from typing import Literal

import gradio as gr
import hydra
import plotly.graph_objects as go
import torch
from omegaconf import OmegaConf
from pytorch_lightning import seed_everything

from salad.utils.spaghetti_util import (
    generate_zc_from_sj_gaus,
    get_mesh_from_spaghetti,
    load_mesher,
    load_spaghetti,
)
def load_model(
    model_class: Literal["phase1", "phase2", "lang_phase1", "lang_phase2"],
    device,
):
    """Build a SALAD model from its saved Hydra config and restore its weights.

    The config is read from ``checkpoints/<model_class>/hparams.yaml`` and the
    state dict from ``checkpoints/<model_class>/state_only.ckpt``, both
    relative to the directory containing this file.

    Args:
        model_class: Name of the checkpoint sub-directory to load.
        device: Used as ``map_location`` when loading the state dict and as
            the target of ``model.to``.

    Returns:
        The instantiated model in eval mode, with ``requires_grad`` turned off
        for every parameter, moved to ``device``.
    """
    model_dir = Path(__file__).parent / "checkpoints" / model_class

    config = OmegaConf.load(model_dir / "hparams.yaml")
    net = hydra.utils.instantiate(config)

    state_dict = torch.load(model_dir / "state_only.ckpt", map_location=device)
    net.load_state_dict(state_dict)
    net.eval()

    # The demo only samples from these models, so freeze all parameters.
    for param in net.parameters():
        param.requires_grad_(False)

    return net.to(device)
@functools.lru_cache(maxsize=None)
def _load_pipeline(device: torch.device):
    """Load SPAGHETTI, the mesher and both SALAD phases once per device.

    Loading these checkpoints is expensive, so the result is cached and shared
    by every request instead of being rebuilt on each call.

    Returns:
        A ``(spaghetti, mesher, phase1_model, phase2_model)`` tuple.
    """
    spaghetti = load_spaghetti(device)
    mesher = load_mesher(device)
    phase1_model = load_model("lang_phase1", device)
    # NOTE(review): phase 2 uses the "phase2" checkpoint rather than
    # "lang_phase2"; kept as in the original -- confirm this is intended.
    phase2_model = load_model("phase2", device)
    # Presumably required before `sampling_gaussians` can be called -- confirm.
    phase1_model._build_dataset("val")
    return spaghetti, mesher, phase1_model, phase2_model


def _build_figure(vertices, faces) -> go.Figure:
    """Wrap a triangle mesh in a Plotly figure with all axes hidden.

    Args:
        vertices: Vertex coordinates, indexed as ``vertices[:, 0..2]``.
        faces: Triangle vertex indices, indexed as ``faces[:, 0..2]``.

    Returns:
        A ``go.Figure`` containing a single gray ``Mesh3d`` trace.
    """
    mesh = go.Mesh3d(
        # Axis remap (a 90-degree rotation about x): mesh x -> plot x,
        # mesh z -> plot -y (flip front-back), mesh y -> plot z.
        # Presumably the mesh is y-up while Plotly's scene is z-up -- confirm.
        x=vertices[:, 0],
        y=-vertices[:, 2],
        z=vertices[:, 1],
        i=faces[:, 0],
        j=faces[:, 1],
        k=faces[:, 2],
        color="gray",
    )
    layout = dict(
        scene=dict(
            xaxis=dict(visible=False),
            yaxis=dict(visible=False),
            zaxis=dict(visible=False),
        )
    )
    return go.Figure(data=[mesh], layout=layout)


def run_inference(prompt: str):
    """Generate a shape from a text prompt and return it as a Plotly figure.

    This is the callback the Gradio interface invokes for each request.
    Models are loaded once (see ``_load_pipeline``) and reused afterwards.

    Args:
        prompt: Free-form text description of the shape.

    Returns:
        A ``plotly.graph_objects.Figure`` showing the generated mesh.
    """
    # Prefer CUDA, but fall back to CPU so the demo can still start on
    # machines without a GPU (expect it to be much slower there).
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    seed = 63  # fixed random seed for reproducible samples

    spaghetti, mesher, phase1_model, phase2_model = _load_pipeline(device)

    # Seed *after* loading. Model construction may consume random numbers,
    # and with cached models that would only happen on the first request.
    # Seeding here gives the same shape for the same prompt on every call.
    seed_everything(seed)

    # Phase 1: text -> part extrinsics; phase 2: extrinsics -> intrinsics.
    extrinsics = phase1_model.sampling_gaussians([prompt])
    intrinsics = phase2_model.sample(extrinsics)

    # Decode the latent parts into a mesh and render the first sample.
    zcs = generate_zc_from_sj_gaus(spaghetti, intrinsics, extrinsics)
    vertices, faces = get_mesh_from_spaghetti(spaghetti, mesher, zcs[0], res=256)
    return _build_figure(vertices, faces)
if __name__ == "__main__":
    # Clickable example prompts offered by the interface.
    example_prompts = [
        "an office chair",
        "a chair with armrests",
        "a chair without armrests",
    ]

    # Build the UI: a single text input mapped to a single Plotly output.
    demo = gr.Interface(
        fn=run_inference,
        inputs="text",
        outputs=gr.Plot(),
        title="SALAD: Text-Guided Shape Generation",
        description="Describe a chair",
        examples=example_prompts,
    )

    # Enable request queuing (max_size caps the queue length), then serve.
    demo.queue(max_size=30)
    demo.launch()