import gradio as gr
import os
from PIL import Image
import subprocess
from gradio_model4dgs import Model4DGS
import numpy
import hashlib
import shlex

# Install the local CUDA extensions in editable mode before anything below runs.
os.system('pip install -e ./simple-knn')
os.system('pip install -e ./diff-gaussian-rasterization')

from huggingface_hub import hf_hub_download

# LGM checkpoint, passed to `lgm/infer.py` via --resume in stage 2.
ckpt_path = hf_hub_download(repo_id="ashawkey/LGM", filename="model_fp16_fixrot.safetensors")

# Working directory for uploaded images, preprocessed RGBA images and generated videos.
TMP_DIR = 'tmp_data'
# Number of per-frame .ply files (000.ply, 001.ply, ...) returned by stage 2.
# NOTE(review): assumed to match what main_4d.py writes with configs/4d_demo.yaml — confirm.
NUM_FRAMES = 28

# Force the light theme by reloading the page with ?__theme=light if it is not set.
js_func = """
function refresh() {
    const url = new URL(window.location);

    if (url.searchParams.get('__theme') !== 'light') {
        url.searchParams.set('__theme', 'light');
        window.location.href = url.href;
    }
}
"""


def _image_hash(image: Image.Image) -> str:
    """Return the SHA-256 hex digest of the image's raw pixel bytes.

    Used as the file-name stem that links stage 1 and stage 2 outputs for the same image.
    """
    return hashlib.sha256(image.tobytes()).hexdigest()


def _run(cmd: list, env=None) -> None:
    """Echo *cmd* and run it as an argv list (no shell).

    Passing a list instead of a shell string keeps paths containing spaces or
    shell metacharacters (e.g. the HF cache path in ``ckpt_path``) intact.
    The return code is not checked here; callers verify the expected output files.
    """
    print(shlex.join(cmd))
    subprocess.run(cmd, env=env)


# check if there is a picture uploaded or selected
def check_img_input(control_image):
    """Abort the event chain with a UI error when no input image is present."""
    if control_image is None:
        raise gr.Error("Please select or upload an input image")


# check that stage 1 has already produced a video for the current image
def check_video_input(image_block: Image.Image):
    """Abort stage 2 unless the stage-1 video for this exact image exists.

    Raises:
        gr.Error: if no image is selected, or the video has not been generated yet.
    """
    if image_block is None:
        # Previously this crashed with AttributeError on `.tobytes()`.
        raise gr.Error("Please select or upload an input image")
    video_path = os.path.join(TMP_DIR, f'{_image_hash(image_block)}_rgba_generated.mp4')
    if not os.path.exists(video_path):
        raise gr.Error("Please generate a video first")


def optimize_stage_1(image_block: Image.Image, preprocess_chk: bool, seed_slider: int):
    """Stage 1: turn the input image into a generated video.

    Args:
        image_block: input image from the UI (RGBA PIL image).
        preprocess_chk: if True, run ``scripts/process.py`` on the raw image
            to produce ``<hash>_rgba.png``; otherwise save the image as-is
            under that name.
        seed_slider: random seed forwarded to ``scripts/gen_vid.py``.

    Returns:
        Path to ``tmp_data/<hash>_rgba_generated.mp4``.

    Raises:
        gr.Error: if the RGBA image or the video was not produced.
    """
    os.makedirs(TMP_DIR, exist_ok=True)
    img_hash = _image_hash(image_block)
    rgba_path = os.path.join(TMP_DIR, f'{img_hash}_rgba.png')

    if preprocess_chk:
        # save image to a designated path, then preprocess it; the rest of the
        # pipeline reads `<hash>_rgba.png`, which process.py is expected to write
        raw_path = os.path.join(TMP_DIR, f'{img_hash}.png')
        image_block.save(raw_path)
        _run(['python', 'scripts/process.py', raw_path])
    else:
        image_block.save(rgba_path)

    if not os.path.exists(rgba_path):
        raise gr.Error("Image preprocessing failed; see server logs for details")

    # stage 1: video generation. MKL settings are passed via the environment
    # instead of shell `export` statements.
    env = dict(os.environ, MKL_THREADING_LAYER='GNU', MKL_SERVICE_FORCE_INTEL='1')
    # int() so a float slider value (e.g. 42.0) is passed as "42".
    _run(
        ['python', 'scripts/gen_vid.py', '--path', rgba_path,
         '--seed', str(int(seed_slider)), '--bg', 'white'],
        env=env,
    )

    video_path = os.path.join(TMP_DIR, f'{img_hash}_rgba_generated.mp4')
    if not os.path.exists(video_path):
        raise gr.Error("Video generation failed; see server logs for details")
    return video_path


def optimize_stage_2(image_block: Image.Image, seed_slider: int):
    """Stage 2: run LGM inference, then 4D optimization, on the stage-1 RGBA image.

    Args:
        image_block: the same input image used in stage 1 (its hash locates the files).
        seed_slider: accepted for interface symmetry with stage 1; currently unused.

    Returns:
        List of ``NUM_FRAMES`` paths ``logs/<hash>_rgba_frames/NNN.ply``.

    Raises:
        gr.Error: if the frames directory was not produced.
    """
    img_hash = _image_hash(image_block)
    rgba_path = os.path.join(TMP_DIR, f'{img_hash}_rgba.png')

    _run(['python', 'lgm/infer.py', 'big', '--resume', ckpt_path, '--test_path', rgba_path])
    # stage 2
    _run(['python', 'main_4d.py', '--config', os.path.join('configs', '4d_demo.yaml'),
          f'input={rgba_path}'])

    image_dir = os.path.join('logs', f'{img_hash}_rgba_frames')
    if not os.path.isdir(image_dir):
        raise gr.Error("4D generation failed; see server logs for details")
    return [os.path.join(image_dir, f'{t:03d}.ply') for t in range(NUM_FRAMES)]


if __name__ == "__main__":
    _TITLE = '''DreamGaussian4D: Generative 4D Gaussian Splatting'''

    _DESCRIPTION = '''
We present DreamGaussian4D, an efficient 4D generation framework that builds on Gaussian Splatting.
'''
    _IMG_USER_GUIDE = "Please upload an image in the block above (or choose an example above), select a random seed, and click **Generate Video**. After having the video generated, please click **Generate 4D**."

    # load images in 'data' folder as examples
    example_folder = os.path.join(os.path.dirname(__file__), 'data')
    example_fns = sorted(os.listdir(example_folder))
    examples_full = [os.path.join(example_folder, x) for x in example_fns if x.endswith('.png')]

    # Compose demo layout & data flow
    with gr.Blocks(title=_TITLE, theme=gr.themes.Soft(), js=js_func) as demo:
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown('# ' + _TITLE)
        gr.Markdown(_DESCRIPTION)

        # Image-to-4D
        with gr.Row(variant='panel'):
            with gr.Column(scale=4):
                image_block = gr.Image(type='pil', image_mode='RGBA', height=290, label='Input image')

                seed_slider = gr.Slider(0, 100000, value=0, step=1, label='Random Seed')
                gr.Markdown(
                    "random seed for video generation.")

                preprocess_chk = gr.Checkbox(True,
                                             label='Preprocess image automatically (remove background and recenter object)')

                gr.Examples(
                    examples=examples_full,  # NOTE: elements must match inputs list!
                    inputs=[image_block],
                    outputs=[image_block],
                    cache_examples=False,
                    label='Examples (click one of the images below to start)',
                    examples_per_page=40
                )
                img_run_btn = gr.Button("Generate Video")
                fourd_run_btn = gr.Button("Generate 4D")
                img_guide_text = gr.Markdown(_IMG_USER_GUIDE, visible=True)

            with gr.Column(scale=5):
                obj3d = gr.Video(label="video", height=290)
                obj4d = Model4DGS(label="4D Model", height=500, fps=14)

            # Each stage runs only if its validation check succeeds.
            img_run_btn.click(check_img_input, inputs=[image_block], queue=False).success(
                optimize_stage_1,
                inputs=[image_block, preprocess_chk, seed_slider],
                outputs=[obj3d])
            fourd_run_btn.click(check_video_input, inputs=[image_block], queue=False).success(
                optimize_stage_2,
                inputs=[image_block, seed_slider],
                outputs=[obj4d])

    # Queue holds at most 10 pending events (Gradio rejects new events when full).
    demo.queue(max_size=10)
    demo.launch(share=True)