File size: 6,843 Bytes
eea4530
 
 
b2ee272
eea4530
 
95377ef
eea4530
 
95377ef
3507688
 
ac1916a
eea4530
b2ee272
cb5d809
95377ef
 
b2ee272
 
 
 
95377ef
b2ee272
 
 
 
eea4530
e9ee731
ac1916a
3507688
 
 
 
 
 
 
 
 
 
 
 
 
 
4be800e
 
ad0f5c7
4be800e
3507688
 
ad0f5c7
 
 
 
3507688
 
 
 
 
 
 
ad0f5c7
3507688
ad0f5c7
3507688
 
 
 
 
 
 
 
 
4d29a77
ac1916a
7b059da
0dd7180
eea4530
 
afc0455
 
 
 
d5dadbf
afc0455
d5dadbf
95377ef
0c06861
eea4530
 
 
 
4d29a77
3507688
7b059da
ac1916a
 
 
 
 
 
 
7b059da
0dd7180
 
 
 
ad0f5c7
 
 
 
 
906a232
0dd7180
4be800e
 
0dd7180
 
 
4be800e
ad0f5c7
4be800e
 
 
 
 
ac1916a
4be800e
ac1916a
4be800e
 
 
 
 
ad0f5c7
 
 
4be800e
 
 
 
ad0f5c7
4be800e
 
 
 
 
 
 
 
 
ad0f5c7
4be800e
 
e9ee731
4be800e
 
 
 
 
 
 
 
 
 
 
 
 
 
ac1916a
4be800e
ac1916a
0dd7180
ac1916a
4be800e
 
 
 
eea4530
3507688
eec6e0c
 
3507688
 
 
 
 
7b059da
0dd7180
e9ee731
3507688
 
eec6e0c
e8629d6
e9ee731
0dd7180
e9ee731
e8629d6
 
 
 
 
ac1916a
eea4530
 
 
3507688
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
import rerun as rr
import rerun.blueprint as rrb
import depth_pro
import subprocess

import torch
import os
import gradio as gr
from gradio_rerun import Rerun
import spaces
from PIL import Image
import tempfile
import cv2

# Fetch the pretrained DepthPro checkpoint on first run.
# NOTE(review): presumably create_model_and_transforms() below loads this
# exact path by default — confirm against the depth_pro package config.
if not os.path.exists("./checkpoints/depth_pro.pt"):
    print("downloading pretrained model")
    # check=True makes a failed download stop startup here, instead of the
    # script carrying on and failing later with an unrelated-looking error.
    subprocess.run(["bash", "get_pretrained_models.sh"], check=True)

# Prefer the first CUDA GPU; fall back to CPU when CUDA is unavailable.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load model and preprocessing transform once at import time; they are shared
# by every request handled by this process.
print("loading model...")
model, transform = depth_pro.create_model_and_transforms()
model = model.to(device)
model.eval()  # inference only


def resize_image(image_buffer, max_size=256):
    """Scale an image so its longest side is ``max_size`` px and save it as a temp PNG.

    The aspect ratio is preserved. Images smaller than ``max_size`` are scaled
    *up* as well as large ones being scaled down.

    Args:
        image_buffer: Array accepted by ``PIL.Image.fromarray``. Channel order
            is passed through unchanged (PIL interprets 3-channel uint8 data as
            RGB), so callers must convert e.g. OpenCV BGR frames beforehand.
        max_size: Target length, in pixels, of the longest side.

    Returns:
        Path of a PNG created with ``delete=False``; the caller is responsible
        for deleting it.
    """
    with Image.fromarray(image_buffer) as img:
        # Calculate the new size while maintaining aspect ratio. Each side is
        # clamped to >= 1 px: for very elongated images int() would otherwise
        # truncate the short side to 0, which PIL's resize rejects.
        ratio = max_size / max(img.size)
        new_size = tuple(max(1, int(side * ratio)) for side in img.size)

        resized = img.resize(new_size, Image.LANCZOS)

        # delete=False so the file outlives this function; the caller reads and
        # then removes it.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
            resized.save(temp_file, format="PNG")
            return temp_file.name


@spaces.GPU(duration=20)
def predict_depth(input_images):
    """Run DepthPro on a batch of image files.

    Args:
        input_images: Non-empty sequence of image paths readable by
            ``depth_pro.load_rgb``. The transformed images are stacked into one
            tensor, so presumably they must all share the same size — TODO
            confirm for callers other than ``run_rerun``.

    Returns:
        ``(depth, focallength_px)``: ``depth`` is a NumPy array of shape
        ``(B, H, W)`` in metres, clipped to [0, 10]; ``focallength_px`` is a
        list of ``B`` floats (focal length in pixels). ``B == len(input_images)``
        even when only one image is given.

    Raises:
        ValueError: if ``input_images`` is empty (``torch.stack`` cannot stack
            an empty list).
    """
    if not input_images:
        raise ValueError("predict_depth requires at least one image")

    # load_rgb returns a tuple whose first element is the image itself.
    results = [depth_pro.load_rgb(image) for image in input_images]
    images = torch.stack([transform(result[0]) for result in results]).to(device)

    # Run inference
    with torch.no_grad():
        prediction = model.infer(images)
        depth = prediction["depth"]  # Depth in [m]
        focallength_px = prediction["focallength_px"]  # Focal length in pixels

    if isinstance(depth, torch.Tensor):
        depth = depth.cpu().numpy()

    # NOTE: depth_pro's infer() appears to squeeze its outputs, so a batch of
    # one comes back as an (H, W) depth map and a 0-d focal length. Iterating a
    # 0-d tensor raises TypeError, and iterating an (H, W) map yields rows, so
    # restore the batch axis explicitly.
    depth = depth.reshape(-1, depth.shape[-2], depth.shape[-1])
    focallength_px = torch.as_tensor(focallength_px).reshape(-1).tolist()

    # Clip depth values to 0m - 10m
    depth = depth.clip(0, 10)

    return depth, focallength_px


def _compute_batch_size(max_frames, max_batch=4, gib_per_frame=4):
    """Return how many frames to infer at once (always >= 1).

    Uses one frame per ``gib_per_frame`` GiB of free CUDA memory, capped at
    ``max_batch`` and ``max_frames``. The lower bound of 1 matters: a step of 0
    would make ``range`` raise ValueError.
    """
    if device.type == "cuda":
        free_vram, _ = torch.cuda.mem_get_info(device)
        by_memory = (free_vram / 1024 / 1024 / 1024) // gib_per_frame
    else:
        # torch.cuda.mem_get_info rejects non-CUDA devices; process one frame
        # at a time on CPU.
        by_memory = 1
    return max(1, int(min(max_batch, by_memory, max_frames)))


def _read_rgb_frames(cap, count):
    """Read up to ``count`` frames from ``cap``, converted from OpenCV's BGR to RGB.

    Stops early (returning fewer frames) when the decoder reports no frame.
    """
    frames = []
    for _ in range(count):
        ok, frame = cap.read()
        if not ok:
            break
        # OpenCV decodes to BGR, but resize_image / PIL treat 3 channels as RGB.
        frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    return frames


def _log_depth_frame(depth, focal_length, frame_shape, timestamp_ns):
    """Log one depth map, its video-frame reference and the pinhole camera to rerun.

    ``frame_shape`` is the shape of the full-resolution video frame; it is used
    to scale the video frame onto the (smaller) depth image.
    """
    height, width = depth.shape[0], depth.shape[1]
    x_scale = width / frame_shape[1]
    y_scale = height / frame_shape[0]

    rr.set_time_nanos("video_time", timestamp_ns)
    rr.log(
        "world/camera/depth",
        rr.DepthImage(depth, meter=1),
    )

    rr.log(
        "world/camera/frame",
        rr.VideoFrameReference(
            timestamp=rr.components.VideoTimestamp(nanoseconds=timestamp_ns),
            video_reference="world/video",
        ),
        rr.Transform3D(scale=(x_scale, y_scale, 1)),
    )

    rr.log(
        "world/camera",
        rr.Pinhole(
            focal_length=focal_length,
            width=width,
            height=height,
            principal_point=(width / 2, height / 2),
            camera_xyz=rr.ViewCoordinates.FLU,
            image_plane_distance=depth.max(),
        ),
    )


@rr.thread_local_stream("rerun_example_ml_depth_pro")
def run_rerun(path_to_video):
    """Stream DepthPro predictions for the first 10 s of a video to a Rerun viewer.

    Generator used as a Gradio event handler: each ``yield`` hands the bytes
    logged so far to the streaming ``Rerun`` component.

    Raises:
        gr.Error: if no video was provided, or if preprocessing/inference fails.
    """
    print("video path:", path_to_video)
    if path_to_video is None:
        raise gr.Error("Please upload a video first.")

    stream = rr.binary_stream()

    blueprint = rrb.Blueprint(
        rrb.Vertical(
            rrb.Spatial3DView(origin="/"),
            rrb.Horizontal(
                rrb.Spatial2DView(
                    origin="/world/camera/depth",
                ),
                rrb.Spatial2DView(origin="/world/camera/frame"),
            ),
        ),
        collapse_panels=True,
    )

    rr.send_blueprint(blueprint)

    yield stream.read()

    video_asset = rr.AssetVideo(path=path_to_video)
    rr.log("world/video", video_asset, static=True)

    # Send automatically determined video frame timestamps.
    frame_timestamps_ns = video_asset.read_frame_timestamps_ns()

    cap = cv2.VideoCapture(path_to_video)
    try:
        num_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)
        fps_video = cap.get(cv2.CAP_PROP_FPS)

        # Limit to 10 seconds of video. Also never go past the last frame rerun
        # has a timestamp for: CAP_PROP_FRAME_COUNT may be only an estimate.
        max_frames = int(min(10 * fps_video, num_frames, len(frame_timestamps_ns)))

        batch_size = _compute_batch_size(max_frames)

        for start in range(0, max_frames, batch_size):
            frame_indices = list(range(start, min(start + batch_size, max_frames)))
            # Read only the frames this batch will log, so the final (partial)
            # batch does not run inference on frames past max_frames.
            frames = _read_rgb_frames(cap, len(frame_indices))
            if not frames:
                break  # decoder ran out earlier than reported

            temp_files = []
            try:
                # Resize the images to make the inference faster
                temp_files = [resize_image(frame, max_size=256) for frame in frames]

                depths, focal_lengths = predict_depth(temp_files)

                for depth, focal_length, frame_idx in zip(
                    depths, focal_lengths, frame_indices
                ):
                    _log_depth_frame(
                        depth,
                        focal_length,
                        frames[0].shape,
                        frame_timestamps_ns[frame_idx],
                    )
                    yield stream.read()
            except Exception as e:
                raise gr.Error(f"An error has occurred: {e}") from e
            finally:
                # Clean up the temporary files
                for temp_file in temp_files:
                    if temp_file and os.path.exists(temp_file):
                        os.remove(temp_file)
    finally:
        # Release the decoder even if the client disconnects mid-stream.
        cap.release()

    yield stream.read()


# Gradio UI: a video upload + button on the left, a streaming Rerun viewer on
# the right. Clicking the button streams run_rerun's output into the viewer.
with gr.Blocks() as interface:
    gr.Markdown(
        """
        # DepthPro Rerun Demo

        [DepthPro](https://huggingface.co/apple/DepthPro) is a fast metric depth prediction model. Simply upload a video to visualize the depth predictions in real-time.

        Videos are automatically resized so that their longest side is 256 pixels (keeping the aspect ratio), to speed up the inference and visualize multiple frames. Only the first 10 seconds of each video are processed.
        """
    )
    with gr.Row():
        with gr.Column(variant="compact"):
            video = gr.Video(
                format="mp4", interactive=True, label="Video", include_audio=False
            )
            visualize = gr.Button("Visualize ML Depth Pro")
        with gr.Column():
            # streaming=True lets the viewer consume the byte chunks that
            # run_rerun yields incrementally.
            viewer = Rerun(
                streaming=True,
            )
        visualize.click(run_rerun, inputs=[video], outputs=[viewer])


if __name__ == "__main__":
    interface.launch()