File size: 4,286 Bytes
14788de
5fff857
14788de
 
20fff88
 
9ec56ae
20fff88
9ec56ae
20fff88
 
 
 
9ec56ae
20fff88
9ec56ae
 
 
14788de
89d8b05
1bc0980
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14788de
5fff857
 
 
 
 
 
9ec56ae
 
 
 
5fff857
 
 
 
20fff88
9ec56ae
20fff88
 
 
 
1bc0980
 
 
 
9ec56ae
 
 
 
20fff88
9ec56ae
 
 
 
 
20fff88
9ec56ae
89d8b05
 
 
 
 
 
5fff857
 
89d8b05
5fff857
 
 
89d8b05
 
 
 
 
 
 
 
 
 
 
 
5fff857
 
 
 
 
 
 
 
 
14788de
9ec56ae
14788de
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import gradio as gr

import spaces
import torch
from gradio_rerun import Rerun
import rerun as rr
import rerun.blueprint as rrb
from pathlib import Path
import uuid

from mini_dust3r.api import OptimizedResult, inferece_dust3r, log_optimized_result
from mini_dust3r.model import AsymmetricCroCo3DStereo

# Prefer the GPU whenever torch reports that CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Instantiate the model once at import time from the named pretrained checkpoint,
# then move it onto the selected device so every request reuses the same instance.
model = AsymmetricCroCo3DStereo.from_pretrained(
    "naver/DUSt3R_ViTLarge_BaseDecoder_512_dpt"
)
model = model.to(DEVICE)


def create_blueprint(image_name_list: list[str], log_path: Path) -> rrb.Blueprint:
    """Build the Rerun viewer layout for a reconstruction.

    Args:
        image_name_list: The input images; only the number of entries is used.
        log_path: Root entity path under which the results are logged.

    Returns:
        A blueprint with collapsed panels. With more than 4 images it holds
        only a 3D view rooted at ``log_path``; otherwise the 3D view shares
        the row (3:1) with a vertical stack of one 2D view per image, each
        rooted at ``{log_path}/camera_{i}/pinhole/``.
    """
    num_images = len(image_name_list)
    view_3d = rrb.Spatial3DView(origin=f"{log_path}")

    # With many images the per-camera 2D views would clutter the viewer, so skip them.
    if num_images > 4:
        return rrb.Blueprint(
            rrb.Horizontal(view_3d),
            collapse_panels=True,
        )

    # One 2D view per input image, each showing everything below its pinhole entity.
    camera_views = [
        rrb.Spatial2DView(
            origin=f"{log_path}/camera_{idx}/pinhole/",
            contents=["+ $origin/**"],
        )
        for idx in range(num_images)
    ]
    layout = rrb.Horizontal(
        contents=[view_3d, rrb.Vertical(contents=camera_views)],
        column_shares=[3, 1],
    )
    return rrb.Blueprint(layout, collapse_panels=True)


@spaces.GPU
def predict(image_name_list: list[str] | str) -> str:
    """Run DUSt3R inference on the given image(s) and save a Rerun recording.

    Args:
        image_name_list: A single image path, or a list of image paths.

    Returns:
        The POSIX-style path of the ``.rrd`` recording written for this call.

    Raises:
        gr.Error: If ``image_name_list`` is neither a list nor a string
            (for example ``None`` when nothing was uploaded).
    """
    if not isinstance(image_name_list, (list, str)):
        raise gr.Error(
            f"Input must be a list of strings or a string, got: {type(image_name_list)}"
        )
    # Normalise the single-image case so the rest of the function only sees a list.
    if isinstance(image_name_list, str):
        image_name_list = [image_name_list]

    # A fresh UUID names both the recording and its output file, so separate
    # calls never write to the same path.
    uuid_str = str(uuid.uuid4())
    filename = Path(f"/tmp/gradio/{uuid_str}.rrd")
    # Do not rely on Gradio having already created its temp directory:
    # make sure the destination exists before the recording is saved.
    filename.parent.mkdir(parents=True, exist_ok=True)
    rr.init(uuid_str)
    log_path = Path("world")

    optimized_results: OptimizedResult = inferece_dust3r(
        image_dir_or_list=image_name_list,
        model=model,
        device=DEVICE,
        batch_size=1,
    )

    blueprint: rrb.Blueprint = create_blueprint(image_name_list, log_path)
    rr.send_blueprint(blueprint)

    rr.set_time_sequence("sequence", 0)
    log_optimized_result(optimized_results, log_path)
    rr.save(filename.as_posix())
    return filename.as_posix()


# Build the Gradio UI: a centered header followed by two tabs ("Single Image"
# and "Multi Image"), each with its own input, Run button and Rerun viewer.
with gr.Blocks(
    css=""".gradio-container {margin: 0 !important; min-width: 100%};""",
    title="Mini-DUSt3R Demo",
) as demo:
    # NOTE(review): an earlier comment here described saved scene state for
    # tweaking conf_thr/cam_size without re-running inference; this file defines
    # no such state or controls, so every Run click performs a full inference.
    gr.HTML('<h2 style="text-align: center;">Mini-DUSt3R Demo</h2>')
    gr.HTML(
        '<p style="text-align: center;">Unofficial DUSt3R demo using the mini-dust3r pip package</p>'
    )
    gr.HTML(
        '<p style="text-align: center;">More info <a href="https://github.com/pablovela5620/mini-dust3r">here</a></p>'
    )
    with gr.Tab(label="Single Image"):
        with gr.Column():
            # type="filepath" so that predict is handed a path rather than pixel data;
            # predict wraps a lone string into a one-element list.
            single_image = gr.Image(type="filepath", height=300)
            run_btn_single = gr.Button("Run")
            rerun_viewer_single = Rerun(height=900)
            # predict returns the path of the saved .rrd file, which is passed to the viewer.
            run_btn_single.click(
                fn=predict, inputs=[single_image], outputs=[rerun_viewer_single]
            )

            # Offer every PNG under examples/single_image as an example, in sorted order.
            example_single_dir = Path("examples/single_image")
            example_single_files = sorted(example_single_dir.glob("*.png"))

            # NOTE(review): cache_examples="lazy" presumably defers running predict on an
            # example until it is first selected — confirm against the installed Gradio version.
            examples_single = gr.Examples(
                examples=example_single_files,
                inputs=[single_image],
                outputs=[rerun_viewer_single],
                fn=predict,
                cache_examples="lazy",
            )
    with gr.Tab(label="Multi Image"):
        with gr.Column():
            # NOTE(review): presumably delivers a list of file paths to predict (which
            # accepts a list) — confirm against the installed Gradio version.
            multi_files = gr.File(file_count="multiple")
            run_btn_multi = gr.Button("Run")
            rerun_viewer_multi = Rerun(height=900)
            run_btn_multi.click(
                fn=predict, inputs=[multi_files], outputs=[rerun_viewer_multi]
            )


# Launch the app. There is no __main__ guard, so this also runs if the module is imported.
demo.launch()