"""Gradio demo for Sketch2Pose: estimate a 3D SMPL pose from a bitmap sketch."""

import hashlib
import os
import shutil
import subprocess
import sys
import textwrap
from pathlib import Path

import gradio as gr
import torch
from huggingface_hub import hf_hub_download

REPO_ID = "kbrodt/sketch2pose"
API_TOKEN = os.environ["sketch2pose"]
ASSET_DIR = Path("./assets")
SAVE_DIR = "output"
# Kept for backward compatibility. ``pose`` builds an argv list instead of
# formatting and whitespace-splitting this template, which broke on paths
# that contain spaces.
CMD = textwrap.dedent("""
    python src/pose.py
        --save-path {}
        --img-path {}
""")
POSE_SCRIPT = "src/pose.py"

TITLE = "Sketch2Pose: Estimating a 3D Character Pose from a Bitmap Sketch"
DESCRIPTION = '''
'''


def prepare():
    """Fetch the SMPL-X archive and run the repository's asset scripts.

    Downloads ``models_smplx_v1_1.zip`` from the Hugging Face repo
    ``REPO_ID`` (authenticated with ``API_TOKEN``) into a hub cache under
    ``ASSET_DIR``. It then copies the archive into ``ASSET_DIR`` itself if it
    is not already there, and runs ``scripts/download.sh`` and
    ``scripts/prepare.sh``.
    """
    filename = "models_smplx_v1_1.zip"
    # NOTE(review): newer huggingface_hub releases deprecate
    # ``use_auth_token`` in favour of ``token``; kept for compatibility.
    smpl_path = hf_hub_download(
        repo_id=REPO_ID,
        repo_type="model",
        filename=filename,
        use_auth_token=API_TOKEN,
        cache_dir=ASSET_DIR,
    )
    if not (ASSET_DIR / filename).is_file():
        shutil.copy(smpl_path, ASSET_DIR)

    # NOTE(review): exit codes are not checked, so a failing script does not
    # abort start-up. Confirm whether that best-effort behaviour is intended.
    subprocess.run(["bash", "./scripts/download.sh"])
    subprocess.run(["bash", "./scripts/prepare.sh"])


def _file_digest(path, chunk_size=1 << 16):
    """Return a short SHA-256 hex digest of the contents of the file at *path*."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()[:16]


def _run_dir(save_dir, img_path, flags):
    """Return the ``--save-path`` directory for one (image contents, options) pair.

    Keying on the file contents and the option flags, rather than only the
    file name, prevents returning a cached mesh that was produced from a
    different image or with different settings.
    """
    tag = "".join("1" if flag else "0" for flag in flags)
    return save_dir / f"{_file_digest(img_path)}_{tag}"


def main():
    """Prepare assets, then build and launch the Gradio interface."""
    prepare()

    save_dir = Path(SAVE_DIR)
    save_dir.mkdir(parents=True, exist_ok=True)

    def pose(img_path, use_cos=True, use_angle_transf=True, use_contacts=False, use_natural=True):
        """Run ``src/pose.py`` on *img_path* and return the path of the ``.glb`` mesh.

        Args:
            img_path: Filesystem path of the uploaded sketch image (or ``None``
                if the user submitted without an image).
            use_cos: Pass ``--use-cos`` (the "Bone lengths" option).
            use_angle_transf: Pass ``--use-angle-transf`` ("Foreshortening").
                It is forced off when *use_cos* is off.
            use_contacts: Pass ``--use-contacts`` ("Self-contacts").
            use_natural: Pass ``--use-natural`` ("Pose naturalness").

        Returns:
            ``str`` path of ``us.glb``. A result already on disk for the same
            image contents and options is reused without re-running the script.

        Raises:
            gr.Error: if no image was given, or the script produced no mesh.
        """
        if img_path is None:
            raise gr.Error("Please upload a sketch image.")

        # The foreshortening option is only applied together with bone lengths.
        if not use_cos:
            use_angle_transf = False

        flags = (use_cos, use_angle_transf, use_contacts, use_natural)
        options = ("--use-cos", "--use-angle-transf", "--use-contacts", "--use-natural")

        run_dir = _run_dir(save_dir, img_path, flags)
        # Assumes (as before) that pose.py writes into <save-path>/<image stem>/.
        mesh_path = run_dir / Path(img_path).stem / "us.glb"
        if mesh_path.is_file():
            return str(mesh_path)

        run_dir.mkdir(parents=True, exist_ok=True)
        argv = [
            sys.executable,
            POSE_SCRIPT,
            "--save-path", str(run_dir),
            "--img-path", str(img_path),
        ]
        argv.extend(option for option, enabled in zip(options, flags) if enabled)

        result = subprocess.run(argv)
        if not mesh_path.is_file():
            raise gr.Error(
                f"Pose estimation failed (exit code {result.returncode}); no mesh was produced."
            )
        return str(mesh_path)

    # Offer one bundled sketch as an example; sorted so the choice is stable.
    examples = [
        [str(img_path), True, True, False, True]
        for img_path in sorted(Path("./data/images").glob("*"))[:1]
    ]

    demo = gr.Interface(
        fn=pose,
        inputs=[
            gr.Image(type="filepath", label="Image"),
            gr.Checkbox(value=True, label="Bone lengths"),
            gr.Checkbox(value=True, label="Foreshortening"),
            gr.Checkbox(
                value=False,
                label="Self-contacts (available with cuda)",
                interactive=torch.cuda.is_available(),
            ),
            gr.Checkbox(value=True, label="Pose naturalness"),
        ],
        outputs=gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], label="SMPL 3D pose"),
        examples=examples,
        title=TITLE,
        description=DESCRIPTION,
    )
    demo.launch()


if __name__ == "__main__":
    main()