import rerun as rr
import rerun.blueprint as rrb
import depth_pro
import subprocess
import torch
import cv2
import numpy as np
import os
from pathlib import Path
import gradio as gr
from gradio_rerun import Rerun
import spaces
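
# Gradio demo: run ML Depth Pro on video frames and stream the predicted
# depth maps into an embedded Rerun viewer.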
# Run the script to get pretrained models
if not os.path.exists("checkpoints/depth_pro.pt"):
print("downloading pretrained model")
subprocess.run(["bash", "get_pretrained_models.sh"])
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Load model and preprocessing transform
print("loading model...")
model, transform = depth_pro.create_model_and_transforms()
model = model.to(device)
model.eval()
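
# On Hugging Face ZeroGPU Spaces, `spaces.GPU` runs the decorated function on
# a GPU worker; `duration` is the requested allocation in seconds.
# NOTE: `predict` is a standalone single-frame variant; the video demo below
# calls `estimate_depth`, which does the same work.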
@spaces.GPU(duration=20)
def predict(frame):
    image = transform(frame)
    image = image.to(device)
    prediction = model.infer(image)
    depth = prediction["depth"].squeeze().detach().cpu().numpy()
    return depth, prediction["focallength_px"].item()
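
# `rr.thread_local_stream` gives each caller thread its own Rerun recording
# stream, so concurrent Gradio sessions don't interleave their data;
# `rr.binary_stream` exposes that recording as encoded bytes we can yield
# incrementally to the viewer.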
@rr.thread_local_stream("rerun_example_ml_depth_pro")
def run_rerun(path_to_video):
    stream = rr.binary_stream()

    blueprint = rrb.Blueprint(
        rrb.Vertical(
            rrb.Spatial3DView(origin="/"),
            rrb.Horizontal(
                rrb.Spatial2DView(origin="/world/camera/depth"),
                rrb.Spatial2DView(origin="/world/camera/image"),
            ),
        ),
        collapse_panels=True,
    )
    rr.send_blueprint(blueprint)

    # Flush the blueprint first so the viewer lays out its panels before any data arrives.
    yield stream.read()
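
    # Decode the video with OpenCV and process it frame by frame.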
print("Loading video from", path_to_video)
video = cv2.VideoCapture(path_to_video)
frame_idx = -1
while True:
read, frame = video.read()
if not read:
break
frame_idx += 1
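
        # Only run inference on every 10th frame to stay within the GPU time budget.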
        if frame_idx % 10 != 0:
            continue

        # Downscale for speed; OpenCV decodes to BGR, so convert to RGB
        # for the model and the viewer.
        frame = cv2.resize(frame, (320, 240))
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        rr.set_time_sequence("frame", frame_idx)
        rr.log("world/camera/image", rr.Image(frame))
        yield stream.read()
        depth, focal_length = estimate_depth(frame)
        rr.log(
            "world/camera",
            rr.Pinhole(
                width=frame.shape[1],
                height=frame.shape[0],
                focal_length=focal_length,
                # Assume the principal point sits at the image center.
                principal_point=(frame.shape[1] / 2, frame.shape[0] / 2),
                image_plane_distance=depth.max(),
            ),
        )
        rr.log(
            "world/camera/depth",
            # need 0.19 stable for this
            # rr.DepthImage(depth, meter=1, depth_range=(depth.min(), depth.max())),
            rr.DepthImage(depth, meter=1),
        )
        yield stream.read()
@spaces.GPU(duration=20)
def estimate_depth(frame):
    image = transform(frame)
    image = image.to(device)
    prediction = model.infer(image)
    depth = prediction["depth"].squeeze().detach().cpu().numpy()
    focal_length = prediction["focallength_px"].item()
    return depth, focal_length
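
# Minimal Gradio UI: upload a video, click the button, and watch the results
# stream into the embedded Rerun viewer.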
with gr.Blocks() as demo:
    video = gr.Video(interactive=True, include_audio=False, label="Video")
    visualize = gr.Button("Visualize ML Depth Pro")
    with gr.Row():
        viewer = Rerun(streaming=True)
    visualize.click(run_rerun, inputs=[video], outputs=[viewer])
if __name__ == "__main__":
    demo.launch()