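# Gradio Space: DepthPro + Rerun demo (runs on ZeroGPU).
# Runs Apple's DepthPro metric depth model on frames sampled from an uploaded
# video and streams the predictions to an embedded Rerun viewer.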
import rerun as rr
import rerun.blueprint as rrb
import depth_pro
import subprocess
import torch
import os
import gradio as gr
from gradio_rerun import Rerun
import spaces
from PIL import Image
import tempfile
import cv2
# Run the script to get pretrained models
if not os.path.exists("./checkpoints/depth_pro.pt"):
    print("downloading pretrained model")
    subprocess.run(["bash", "get_pretrained_models.sh"])
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Load model and preprocessing transform
print("loading model...")
model, transform = depth_pro.create_model_and_transforms()
model = model.to(device)
model.eval()
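# resize_image: shrink a frame (H x W x 3 array) so that its longest side equals
# `max_size` pixels, write it to a temporary PNG, and return the file path.
# The caller is responsible for deleting the temporary file.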
def resize_image(image_buffer, max_size=256):
    with Image.fromarray(image_buffer) as img:
        # Calculate the new size while maintaining aspect ratio
        ratio = max_size / max(img.size)
        new_size = tuple([int(x * ratio) for x in img.size])

        # Resize the image
        img = img.resize(new_size, Image.LANCZOS)

        # Create a temporary file
        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
            img.save(temp_file, format="PNG")
            return temp_file.name
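# predict_depth: batched DepthPro inference. The @spaces.GPU(duration=20)
# decorator asks ZeroGPU to allocate a GPU for up to 20 seconds per call.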
@spaces.GPU(duration=20)
def predict_depth(input_images):
    results = [depth_pro.load_rgb(image) for image in input_images]
    images = torch.stack([transform(result[0]) for result in results])
    images = images.to(device)

    # Run inference
    with torch.no_grad():
        prediction = model.infer(images)
        depth = prediction["depth"]  # Depth in [m]
        focallength_px = prediction["focallength_px"]  # Focal length in pixels

    # Convert depth to a numpy array if it's a torch tensor
    if isinstance(depth, torch.Tensor):
        depth = depth.cpu().numpy()

    # Convert the focal lengths to a list of floats if they are a torch tensor
    if isinstance(focallength_px, torch.Tensor):
        focallength_px = [focal_length.item() for focal_length in focallength_px]
    # Ensure depth is a BxHxW array so callers can iterate over the batch,
    # even when the batch contains a single image
    depth = depth.squeeze()
    if depth.ndim == 2:
        depth = depth[None, ...]

    # Clip depth values to 0m - 10m
    depth = depth.clip(0, 10)

    return depth, focallength_px
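# run_rerun: generator that logs depth predictions for the uploaded video and
# yields the Rerun binary stream incrementally, so the gradio_rerun viewer
# updates while frames are still being processed.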
@rr.thread_local_stream("rerun_example_ml_depth_pro")
def run_rerun(path_to_video):
    print("video path:", path_to_video)
    stream = rr.binary_stream()

    blueprint = rrb.Blueprint(
        rrb.Vertical(
            rrb.Spatial3DView(origin="/"),
            rrb.Horizontal(
                rrb.Spatial2DView(
                    origin="/world/camera/depth",
                ),
                rrb.Spatial2DView(origin="/world/camera/frame"),
            ),
        ),
        collapse_panels=True,
    )
    rr.send_blueprint(blueprint)
    yield stream.read()
    video_asset = rr.AssetVideo(path=path_to_video)
    rr.log("world/video", video_asset, static=True)

    # Send automatically determined video frame timestamps.
    frame_timestamps_ns = video_asset.read_frame_timestamps_ns()

    cap = cv2.VideoCapture(path_to_video)
    num_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)
    fps_video = cap.get(cv2.CAP_PROP_FPS)

    # limit the number of frames to 10 seconds of video
    max_frames = min(10 * fps_video, num_frames)

    free_vram, _ = torch.cuda.mem_get_info(device)
    free_vram = free_vram / 1024 / 1024 / 1024
    # batch size is determined by the amount of free vram; keep it at least 1
    # so the range step below can never be zero
    batch_size = int(max(1, min(min(4, free_vram // 4), max_frames)))

    # go through all the frames in the video, using the batch size
    for i in range(0, int(max_frames), batch_size):
        if i >= max_frames:
            raise gr.Error("Reached the maximum number of frames to process")

        frames = []
        frame_indices = list(range(i, min(i + batch_size, int(max_frames))))
        for _ in range(batch_size):
            ret, frame = cap.read()
            if not ret:
                break
            # OpenCV decodes frames as BGR; convert to RGB before handing them
            # to PIL and the depth model
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        temp_files = []
        try:
            # Resize the images to make the inference faster
            temp_files = [resize_image(frame, max_size=256) for frame in frames]
            depths, focal_lengths = predict_depth(temp_files)

            for depth, focal_length, frame_idx in zip(
                depths, focal_lengths, frame_indices
            ):
                # find the x and y scale factors that map the full-resolution
                # frame onto the depth image's pixel grid
                x_scale = depth.shape[1] / frames[0].shape[1]
                y_scale = depth.shape[0] / frames[0].shape[0]

                rr.set_time_nanos("video_time", frame_timestamps_ns[frame_idx])
                rr.log(
                    "world/camera/depth",
                    rr.DepthImage(depth, meter=1),
                )
                rr.log(
                    "world/camera/frame",
                    rr.VideoFrameReference(
                        timestamp=rr.components.VideoTimestamp(
                            nanoseconds=frame_timestamps_ns[frame_idx]
                        ),
                        video_reference="world/video",
                    ),
                    rr.Transform3D(scale=(x_scale, y_scale, 1)),
                )
                rr.log(
                    "world/camera",
                    rr.Pinhole(
                        focal_length=focal_length,
                        width=depth.shape[1],
                        height=depth.shape[0],
                        principal_point=(depth.shape[1] / 2, depth.shape[0] / 2),
                        camera_xyz=rr.ViewCoordinates.FLU,
                        image_plane_distance=depth.max(),
                    ),
                )

                yield stream.read()
        except Exception as e:
            raise gr.Error(f"An error has occurred: {e}")
        finally:
            # Clean up the temporary files
            for temp_file in temp_files:
                if temp_file and os.path.exists(temp_file):
                    os.remove(temp_file)

    yield stream.read()
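# Gradio UI: a video upload next to an embedded Rerun viewer; the button wires
# the upload to run_rerun and streams its output into the viewer.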
with gr.Blocks() as interface:
    gr.Markdown(
        """
        # DepthPro Rerun Demo

        [DepthPro](https://huggingface.co/apple/DepthPro) is a fast metric depth prediction model. Simply upload a video to visualize its depth predictions in real time.
        High-resolution videos are automatically resized so that their longest side is 256 pixels, which speeds up inference and lets multiple frames be visualized.
        """
    )
    with gr.Row():
        with gr.Column(variant="compact"):
            video = gr.Video(
                format="mp4", interactive=True, label="Video", include_audio=False
            )
            visualize = gr.Button("Visualize ML Depth Pro")
        with gr.Column():
            viewer = Rerun(
                streaming=True,
            )

    visualize.click(run_rerun, inputs=[video], outputs=[viewer])

if __name__ == "__main__":
    interface.launch()