File size: 4,286 Bytes
8cf8c7b
bb54316
d80ec21
 
 
 
 
 
 
8cf8c7b
49f0812
 
 
d80ec21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8cf8c7b
d80ec21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8cf8c7b
49f0812
d80ec21
49f0812
d80ec21
49f0812
 
00c4703
d80ec21
 
00c4703
d80ec21
 
 
 
00c4703
 
d80ec21
 
 
 
 
 
 
 
 
 
 
 
 
00c4703
d80ec21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8cf8c7b
d80ec21
 
 
 
 
 
00c4703
 
d80ec21
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import gradio as gr

import spaces
import torch
from gradio_rerun import Rerun
import rerun as rr
import rerun.blueprint as rrb
from pathlib import Path
import uuid

from mini_dust3r.api import OptimizedResult, inferece_dust3r, log_optimized_result
from mini_dust3r.model import AsymmetricCroCo3DStereo

# Prefer the GPU whenever torch can see one; otherwise run inference on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Load the pretrained DUSt3R checkpoint once, at import time, and move it onto
# DEVICE so that every call to `predict` reuses this same model instance.
model = AsymmetricCroCo3DStereo.from_pretrained(
    "naver/DUSt3R_ViTLarge_BaseDecoder_512_dpt"
).to(DEVICE)


def create_blueprint(image_name_list: list[str], log_path: Path) -> rrb.Blueprint:
    """Build the Rerun viewer layout for a reconstruction of ``image_name_list``.

    The 3D view is always rooted at ``log_path``. With four or fewer images, a
    vertical column of 2D views is placed beside it at a 3:1 width ratio. There
    is one 2D view per image, rooted at ``{log_path}/camera_{i}/pinhole/``. With
    more than four images, only the 3D view is shown.

    Args:
        image_name_list: The input images; only its length is used here.
        log_path: Entity path under which the reconstruction is logged.

    Returns:
        A blueprint created with ``collapse_panels=True``.
    """
    root = f"{log_path}"
    scene_view = rrb.Spatial3DView(origin=root)

    # More than four 2D panels would clutter the viewer, so show 3D only.
    if len(image_name_list) > 4:
        return rrb.Blueprint(
            rrb.Horizontal(scene_view),
            collapse_panels=True,
        )

    camera_views = [
        rrb.Spatial2DView(
            origin=f"{root}/camera_{idx}/pinhole/",
            contents=["+ $origin/**"],
        )
        for idx, _ in enumerate(image_name_list)
    ]
    return rrb.Blueprint(
        rrb.Horizontal(
            contents=[scene_view, rrb.Vertical(contents=camera_views)],
            column_shares=[3, 1],
        ),
        collapse_panels=True,
    )


@spaces.GPU
def predict(image_name_list: list[str] | str):
    """Run Mini-DUSt3R on the given image(s) and return a Rerun recording file.

    Args:
        image_name_list: A single image file path, or a list of image paths.

    Returns:
        POSIX path of the ``.rrd`` file written under ``/tmp/gradio``. The
        ``Rerun`` output component loads it.

    Raises:
        gr.Error: If the input is neither a list nor a string, or if it is an
            empty list.
    """
    # Reject anything that is not a path or a list of paths; gr.Error reports
    # the message back to the Gradio UI instead of a raw traceback.
    if not isinstance(image_name_list, (list, str)):
        raise gr.Error(
            f"Input must be a list of strings or a string, got: {type(image_name_list)}"
        )
    if isinstance(image_name_list, str):
        image_name_list = [image_name_list]
    # An empty list would otherwise fail somewhere deep inside inference.
    if not image_name_list:
        raise gr.Error("Please provide at least one image.")

    # A fresh UUID gives each request its own recording id and output file.
    uuid_str = str(uuid.uuid4())
    filename = Path(f"/tmp/gradio/{uuid_str}.rrd")
    # Make sure the output directory exists before rr.save writes into it.
    filename.parent.mkdir(parents=True, exist_ok=True)
    rr.init(f"{uuid_str}")
    log_path = Path("world")

    optimized_results: OptimizedResult = inferece_dust3r(
        image_dir_or_list=image_name_list,
        model=model,
        device=DEVICE,
        batch_size=1,
    )

    blueprint: rrb.Blueprint = create_blueprint(image_name_list, log_path)
    rr.send_blueprint(blueprint)

    rr.set_time_sequence("sequence", 0)
    log_optimized_result(optimized_results, log_path)
    # NOTE(review): rr.save is called after logging. This relies on Rerun
    # buffering earlier data until a sink is attached; confirm this for the
    # pinned rerun version.
    rr.save(filename.as_posix())
    return filename.as_posix()


with gr.Blocks(
    # No trailing ';' after the rule: a stray top-level semicolon is invalid CSS
    # and can cause a following rule to be discarded.
    css=""".gradio-container {margin: 0 !important; min-width: 100%}""",
    title="Mini-DUSt3R Demo",
) as demo:
    # Page header: title, one-line description and a link to the project repo.
    gr.HTML('<h2 style="text-align: center;">Mini-DUSt3R Demo</h2>')
    gr.HTML(
        '<p style="text-align: center;">Unofficial DUSt3R demo using the mini-dust3r pip package</p>'
    )
    gr.HTML(
        '<p style="text-align: center;">More info <a href="https://github.com/pablovela5620/mini-dust3r">here</a></p>'
    )
    # Single-image tab: one image filepath in, one Rerun recording out.
    with gr.Tab(label="Single Image"):
        with gr.Column():
            single_image = gr.Image(type="filepath", height=300)
            run_btn_single = gr.Button("Run")
            rerun_viewer_single = Rerun(height=900)
            run_btn_single.click(
                fn=predict, inputs=[single_image], outputs=[rerun_viewer_single]
            )

            # Bundled example PNGs; the path is relative to the working directory.
            example_single_dir = Path("examples/single_image")
            example_single_files = sorted(example_single_dir.glob("*.png"))

            # "lazy" caching: an example's output is computed the first time it
            # is selected, then reused.
            examples_single = gr.Examples(
                examples=example_single_files,
                inputs=[single_image],
                outputs=[rerun_viewer_single],
                fn=predict,
                cache_examples="lazy",
            )
    # Multi-image tab: all uploaded files are passed to `predict` in one call.
    with gr.Tab(label="Multi Image"):
        with gr.Column():
            multi_files = gr.File(file_count="multiple")
            run_btn_multi = gr.Button("Run")
            rerun_viewer_multi = Rerun(height=900)
            run_btn_multi.click(
                fn=predict, inputs=[multi_files], outputs=[rerun_viewer_multi]
            )


demo.launch()