"""Gradio demo for Marigold monocular depth estimation.

At start-up the Marigold repository is cloned into ``Marigold/`` (relative to
the working directory); each uploaded image is then processed by invoking its
``run.py`` script in a subprocess.
"""

import os
import shutil
import subprocess
import sys

import gradio as gr

desc = """
Marigold is the new state-of-the-art depth estimator for images in the wild. Upload your image into the pane on the left side, or explore the examples listed at the bottom.
"""

# Upstream repository and the working-directory-relative folder it is cloned into.
MARIGOLD_REPO_URL = "https://github.com/prs-eth/Marigold.git"
MARIGOLD_DIR = "Marigold"

# Directory containing this script; the bundled example images live under ``files/``.
_HERE = os.path.dirname(__file__)


def download_code():
    """Clone the Marigold repository into ``MARIGOLD_DIR`` unless it already exists.

    A failed clone is reported on stderr but does not abort start-up, so the
    app (and its cached examples) can still launch; inference requests will
    then fail with a "Processing failed" error.
    """
    if os.path.isdir(MARIGOLD_DIR):
        # Cloning into an existing directory would only fail; nothing to do.
        return
    try:
        returncode = subprocess.run(
            ["git", "clone", MARIGOLD_REPO_URL, MARIGOLD_DIR], check=False
        ).returncode
    except OSError as err:  # e.g. git is not installed
        print(f"warning: could not run git: {err}", file=sys.stderr)
        return
    if returncode != 0:
        print(
            f"warning: 'git clone {MARIGOLD_REPO_URL}' exited with status {returncode}",
            file=sys.stderr,
        )


def find_first_png(directory):
    """Return the path of the first ``.png`` file in *directory*, or ``None``.

    Files are examined in sorted name order so the result is deterministic.
    ``None`` is also returned when *directory* does not exist (for example,
    because the producing run failed before creating it).
    """
    if not os.path.isdir(directory):
        return None
    for name in sorted(os.listdir(directory)):
        if name.lower().endswith(".png"):
            return os.path.join(directory, name)
    return None


def marigold_process(path_input, path_out_vis=None, path_out_pred=None):
    """Estimate depth for one image by running Marigold's ``run.py``.

    Args:
        path_input: Path of the input image (Gradio ``type="filepath"``).
        path_out_vis: Precomputed colored depth image; normally populated only
            by the bundled examples through a hidden input.
        path_out_pred: Precomputed depth prediction image; normally populated
            only by the bundled examples through a hidden input.

    Returns:
        ``(path_out_vis, path_out_pred)``: the first PNG found in the run's
        ``depth_colored`` and ``depth_bw`` output folders respectively, or the
        precomputed paths when both were supplied.

    Raises:
        gr.Error: if no image was given or the Marigold run did not succeed.
    """
    # Examples ship with precomputed outputs; return them without running the model.
    # NOTE(review): if the hidden inputs keep an example's values after the user
    # uploads a new image, those example outputs would be returned for the new
    # image — confirm against the Gradio version in use.
    if path_out_vis is not None and path_out_pred is not None:
        return path_out_vis, path_out_pred
    if path_input is None:
        raise gr.Error("Please upload an input image.")

    # Absolute paths: the subprocess runs with MARIGOLD_DIR as its working directory.
    path_input = os.path.abspath(path_input)
    path_input_dir = path_input + ".input"
    path_output_dir = path_input + ".output"
    os.makedirs(path_input_dir, exist_ok=True)
    os.makedirs(path_output_dir, exist_ok=True)
    shutil.copy(path_input, path_input_dir)

    # Pass an argument list (no shell) so characters in the uploaded file name
    # can never be interpreted as shell syntax; sys.executable keeps the child
    # on the same interpreter/environment as this app.
    command = [
        sys.executable, "run.py",
        "--input_rgb_dir", path_input_dir,
        "--output_dir", path_output_dir,
        "--n_infer", "10",
        "--denoise_steps", "10",
    ]
    try:
        returncode = subprocess.run(command, cwd=MARIGOLD_DIR, check=False).returncode
    except OSError as err:  # e.g. MARIGOLD_DIR is missing because the clone failed
        raise gr.Error("Processing failed") from err
    if returncode != 0:
        raise gr.Error("Processing failed")

    path_out_vis = find_first_png(os.path.join(path_output_dir, "depth_colored"))
    path_out_pred = find_first_png(os.path.join(path_output_dir, "depth_bw"))
    if path_out_vis is None or path_out_pred is None:
        raise gr.Error("Processing failed")
    return path_out_vis, path_out_pred


def _example(name):
    """Return the ``[input, visualization, prediction]`` paths of a bundled example."""
    return [
        os.path.join(_HERE, f"files/{name}.jpg"),
        os.path.join(_HERE, f"files/{name}_vis.jpg"),
        os.path.join(_HERE, f"files/{name}_pred.jpg"),
    ]


iface = gr.Interface(
    title="Marigold Depth Estimation",
    description=desc,
    thumbnail="marigold_logo_square.jpg",
    fn=marigold_process,
    inputs=[
        gr.Image(
            label="Input Image",
            type="filepath",
        ),
        # Hidden inputs that carry the precomputed outputs of the examples
        # (see the short-circuit at the top of marigold_process).
        gr.Image(
            label="Predicted depth (red-near, blue-far)",
            type="filepath",
            visible=False,
        ),
        gr.Image(
            label="Predicted depth",
            type="filepath",
            visible=False,
        ),
    ],
    outputs=[
        gr.Image(
            label="Predicted depth (red-near, blue-far)",
            type="pil",
        ),
        gr.Image(
            label="Predicted depth",
            type="pil",
            elem_classes="imgdownload",
        ),
    ],
    allow_flagging="never",
    examples=[_example(name) for name in ("bee", "cat", "swings")],
    css="""
    .viewport {
        aspect-ratio: 4/3;
    }
    .imgdownload {
        height: 64px;
    }
    """,
    cache_examples=True,
)

if __name__ == "__main__":
    download_code()
    iface.queue().launch(server_name="0.0.0.0", server_port=7860)