jhshao committed on
Commit 02f6d94
1 Parent(s): 8bd250a
.gitattributes copy ADDED
@@ -0,0 +1,37 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ files/sora_1764106507569053773.mp4 filter=lfs diff=lfs merge=lfs -text
37
+ files/sora_e2.mp4 filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,13 +1,25 @@
1
  ---
2
  title: ChronoDepth
3
- emoji: 👁
4
- colorFrom: gray
5
- colorTo: gray
6
  sdk: gradio
7
- sdk_version: 4.36.1
8
  app_file: app.py
9
  pinned: false
10
- license: mit
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
  ---
2
  title: ChronoDepth
3
+ emoji: 🔥
4
+ colorFrom: pink
5
+ colorTo: blue
6
  sdk: gradio
7
+ sdk_version: 4.36.0
8
  app_file: app.py
9
  pinned: false
10
+ license: cc-by-4.0
11
  ---
12
 
13
+
14
+ This is a demo of the monocular video depth estimation pipeline described in the paper ["Learning Temporally Consistent Video Depth from Video Diffusion Priors"](https://arxiv.org/abs/2406.01493).
15
+
16
+ ```bibtex
17
+ @misc{shao2024learning,
18
+ title={Learning Temporally Consistent Video Depth from Video Diffusion Priors},
19
+ author={Jiahao Shao and Yuanbo Yang and Hongyu Zhou and Youmin Zhang and Yujun Shen and Matteo Poggi and Yiyi Liao},
20
+ year={2024},
21
+ eprint={2406.01493},
22
+ archivePrefix={arXiv},
23
+ primaryClass={cs.CV}
24
+ }
25
+ ```
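For reference, a minimal usage sketch of the pipeline this commit adds. This is only a sketch: the frame paths are hypothetical, and it assumes the `jhshao/ChronoDepth` checkpoint and the `chronodepth_pipeline.py` module added below; the settings mirror the demo defaults used in `app.py`.

```python
# Minimal sketch: run ChronoDepthPipeline on a short clip of RGB frames,
# using the same settings as the Gradio demo (5 steps, 10-frame clips).
import torch
from PIL import Image

from chronodepth_pipeline import ChronoDepthPipeline

# Hypothetical paths: any 10 consecutive RGB frames of a video.
frames = [Image.open(f"frame_{i:02d}.png").convert("RGB") for i in range(10)]

device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = ChronoDepthPipeline.from_pretrained("jhshao/ChronoDepth").to(device)

out = pipe(
    frames,
    num_frames=len(frames),
    num_inference_steps=5,
    decode_chunk_size=10,
    motion_bucket_id=127,
    fps=7,
    noise_aug_strength=0.0,
    generator=torch.Generator(device=pipe.device).manual_seed(2024),
)
# out.depth_np: depth maps in [0, 1], one per input frame
# out.depth_colored: list of 10 colorized PIL images (Spectral colormap)
```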
app.py ADDED
@@ -0,0 +1,362 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) 2024 Jiahao Shao
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ import functools
24
+ import os
25
+ import zipfile
26
+ import tempfile
27
+ from io import BytesIO
28
+
29
+ import spaces
30
+ import gradio as gr
31
+ import numpy as np
32
+ import torch as torch
33
+ from PIL import Image
34
+ from tqdm import tqdm
35
+ import mediapy as media
36
+
37
+ from huggingface_hub import login
38
+
39
+ from chronodepth_pipeline import ChronoDepthPipeline
40
+ from gradio_patches.examples import Examples
41
+
42
+ default_seed = 2024
43
+
44
+ default_num_inference_steps = 5
45
+ default_num_frames = 10
46
+ default_window_size = 9
47
+ default_video_processing_resolution = 768
48
+ default_video_out_max_frames = 10
49
+ default_decode_chunk_size = 10
50
+
51
+ def process_video(
52
+ pipe,
53
+ path_input,
54
+ num_inference_steps=default_num_inference_steps,
55
+ num_frames=default_num_frames,
56
+ window_size=default_window_size,
57
+ out_max_frames=default_video_out_max_frames,
58
+ progress=gr.Progress(),
59
+ ):
60
+ if path_input is None:
61
+ raise gr.Error(
62
+ "Missing video in the first pane: upload a file or use one from the gallery below."
63
+ )
64
+
65
+ name_base, name_ext = os.path.splitext(os.path.basename(path_input))
66
+ print(f"Processing video {name_base}{name_ext}")
67
+
68
+ path_output_dir = tempfile.mkdtemp()
69
+ path_out_vis = os.path.join(path_output_dir, f"{name_base}_depth_colored.mp4")
70
+ path_out_16bit = os.path.join(path_output_dir, f"{name_base}_depth_16bit.zip")
71
+
72
+ generator = torch.Generator(device=pipe.device).manual_seed(default_seed)
73
+
74
+ import time
75
+ start_time = time.time()
76
+ zipf = None
77
+ try:
78
+ if window_size is None or window_size == num_frames:
79
+ inpaint_inference = False
80
+ else:
81
+ inpaint_inference = True
82
+ data_ls = []
83
+ video_data = media.read_video(path_input)
84
+ video_length = len(video_data)
85
+ fps = video_data.metadata.fps
86
+
87
+ duration_sec = video_length / fps
88
+
89
+ out_duration_sec = out_max_frames / fps
90
+ if duration_sec > out_duration_sec:
91
+ gr.Warning(
92
+ f"Only the first ~{int(out_duration_sec)} seconds will be processed; "
93
+ f"use alternative setups such as ChronoDepth on github for full processing"
94
+ )
95
+ video_length = out_max_frames
96
+
97
+ for i in tqdm(range(video_length-num_frames+1)):
98
+ is_first_clip = i == 0
99
+ is_last_clip = i == video_length - num_frames
100
+ is_new_clip = (
101
+ (inpaint_inference and i % window_size == 0)
102
+ or (not inpaint_inference and i % num_frames == 0)
103
+ )
104
+ if is_first_clip or is_last_clip or is_new_clip:
105
+ data_ls.append(np.array(video_data[i: i+num_frames])) # [t, H, W, 3]
106
+
107
+ zipf = zipfile.ZipFile(path_out_16bit, "w", zipfile.ZIP_DEFLATED)
108
+
109
+ depth_colored_pred = []
110
+ depth_pred = []
111
+ # -------------------- Inference and saving --------------------
112
+ with torch.no_grad():
113
+ for iter, batch in enumerate(tqdm(data_ls)):
114
+ rgb_int = batch
115
+ input_images = [Image.fromarray(rgb_int[i]) for i in range(num_frames)]
116
+
117
+ # Predict depth
118
+ if iter == 0: # First clip
119
+ pipe_out = pipe(
120
+ input_images,
121
+ num_frames=len(input_images),
122
+ num_inference_steps=num_inference_steps,
123
+ decode_chunk_size=default_decode_chunk_size,
124
+ motion_bucket_id=127,
125
+ fps=7,
126
+ noise_aug_strength=0.0,
127
+ generator=generator,
128
+ )
129
+ elif inpaint_inference and (iter == len(data_ls) - 1): # temporal inpaint inference for last clip
130
+ last_window_size = window_size if video_length%window_size == 0 else video_length%window_size
131
+ pipe_out = pipe(
132
+ input_images,
133
+ num_frames=num_frames,
134
+ num_inference_steps=num_inference_steps,
135
+ decode_chunk_size=default_decode_chunk_size,
136
+ motion_bucket_id=127,
137
+ fps=7,
138
+ noise_aug_strength=0.0,
139
+ generator=generator,
140
+ depth_pred_last=depth_frames_pred_ts[last_window_size:],
141
+ )
142
+ elif inpaint_inference and iter > 0: # temporal inpaint inference
143
+ pipe_out = pipe(
144
+ input_images,
145
+ num_frames=num_frames,
146
+ num_inference_steps=num_inference_steps,
147
+ decode_chunk_size=default_decode_chunk_size,
148
+ motion_bucket_id=127,
149
+ fps=7,
150
+ noise_aug_strength=0.0,
151
+ generator=generator,
152
+ depth_pred_last=depth_frames_pred_ts[window_size:],
153
+ )
154
+ else: # separate inference
155
+ pipe_out = pipe(
156
+ input_images,
157
+ num_frames=num_frames,
158
+ num_inference_steps=num_inference_steps,
159
+ decode_chunk_size=default_decode_chunk_size,
160
+ motion_bucket_id=127,
161
+ fps=7,
162
+ noise_aug_strength=0.0,
163
+ generator=generator,
164
+ )
165
+
166
+ depth_frames_pred = [pipe_out.depth_np[i] for i in range(num_frames)]
167
+
168
+ depth_frames_colored_pred = []
169
+ for i in range(num_frames):
170
+ depth_frame_colored_pred = np.array(pipe_out.depth_colored[i])
171
+ depth_frames_colored_pred.append(depth_frame_colored_pred)
172
+ depth_frames_colored_pred = np.stack(depth_frames_colored_pred, axis=0)
173
+
174
+ depth_frames_pred = np.stack(depth_frames_pred, axis=0)
175
+ depth_frames_pred_ts = torch.from_numpy(depth_frames_pred).to(pipe.device)
176
+ depth_frames_pred_ts = depth_frames_pred_ts * 2 - 1
177
+
178
+ if not inpaint_inference:
179
+ if iter == len(data_ls) - 1:
180
+ last_window_size = num_frames if video_length%num_frames == 0 else video_length%num_frames
181
+ depth_colored_pred.append(depth_frames_colored_pred[-last_window_size:])
182
+ depth_pred.append(depth_frames_pred[-last_window_size:])
183
+ else:
184
+ depth_colored_pred.append(depth_frames_colored_pred)
185
+ depth_pred.append(depth_frames_pred)
186
+ else:
187
+ if iter == 0:
188
+ depth_colored_pred.append(depth_frames_colored_pred)
189
+ depth_pred.append(depth_frames_pred)
190
+ elif iter == len(data_ls) - 1:
191
+ depth_colored_pred.append(depth_frames_colored_pred[-last_window_size:])
192
+ depth_pred.append(depth_frames_pred[-last_window_size:])
193
+ else:
194
+ depth_colored_pred.append(depth_frames_colored_pred[-window_size:])
195
+ depth_pred.append(depth_frames_pred[-window_size:])
196
+
197
+ depth_colored_pred = np.concatenate(depth_colored_pred, axis=0)
198
+ depth_pred = np.concatenate(depth_pred, axis=0)
199
+
200
+ # -------------------- Save results --------------------
201
+ # Save images
202
+ for i in tqdm(range(len(depth_pred))):
203
+ archive_path = os.path.join(
204
+ f"{name_base}_depth_16bit", f"{i:05d}.png"
205
+ )
206
+ img_byte_arr = BytesIO()
207
+ depth_16bit = Image.fromarray((depth_pred[i] * 65535.0).astype(np.uint16))
208
+ depth_16bit.save(img_byte_arr, format="png")
209
+ img_byte_arr.seek(0)
210
+ zipf.writestr(archive_path, img_byte_arr.read())
211
+
212
+ # Export to video
213
+ media.write_video(path_out_vis, depth_colored_pred, fps=fps)
214
+ finally:
215
+ if zipf is not None:
216
+ zipf.close()
217
+
218
+ end_time = time.time()
219
+ print(f"Processing time: {end_time - start_time} seconds")
220
+ return (
221
+ path_out_vis,
222
+ [path_out_vis, path_out_16bit],
223
+ )
224
+
225
+
226
+ def run_demo_server(pipe):
227
+ process_pipe_video = spaces.GPU(
228
+ functools.partial(process_video, pipe), duration=210
229
+ )
230
+ os.environ["GRADIO_ALLOW_FLAGGING"] = "never"
231
+
232
+ with gr.Blocks(
233
+ analytics_enabled=False,
234
+ title="ChronoDepth Video Depth Estimation",
235
+ css="""
236
+ #download {
237
+ height: 118px;
238
+ }
239
+ .slider .inner {
240
+ width: 5px;
241
+ background: #FFF;
242
+ }
243
+ .viewport {
244
+ aspect-ratio: 4/3;
245
+ }
246
+ h1 {
247
+ text-align: center;
248
+ display: block;
249
+ }
250
+ h2 {
251
+ text-align: center;
252
+ display: block;
253
+ }
254
+ h3 {
255
+ text-align: center;
256
+ display: block;
257
+ }
258
+ """,
259
+ ) as demo:
260
+ gr.Markdown(
261
+ """
262
+ # ChronoDepth Video Depth Estimation
263
+
264
+ <p align="center">
265
+ <a title="Website" href="https://jhaoshao.github.io/ChronoDepth/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
266
+ <img src="https://img.shields.io/website?url=https%3A%2F%2Fjhaoshao.github.io%2FChronoDepth%2F&up_message=ChronoDepth&up_color=blue&style=flat&logo=timescale&logoColor=%23FFDC0F">
267
+ </a>
268
+ <a title="arXiv" href="https://arxiv.org/abs/2312.02145" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
269
+ <img src="https://img.shields.io/badge/arXiv-PDF-b31b1b">
270
+ </a>
271
+ <a title="Github" href="https://github.com/jhaoshao/ChronoDepth" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
272
+ <img src="https://img.shields.io/github/stars/jhaoshao/ChronoDepth?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
273
+ </a>
274
+ </p>
275
+
276
+ ChronoDepth is the state-of-the-art video depth estimator for videos in the wild.
277
+ Upload your video and give it a try!<br>
278
+ We use 5 denoising steps, 10 frames per video clip, and an overlap of 1 frame between consecutive clips.
279
+
280
+ """
281
+ )
282
+
283
+ with gr.Row():
284
+ with gr.Column():
285
+ video_input = gr.Video(
286
+ label="Input Video",
287
+ sources=["upload"],
288
+ )
289
+ with gr.Row():
290
+ video_submit_btn = gr.Button(
291
+ value="Compute Depth", variant="primary"
292
+ )
293
+ video_reset_btn = gr.Button(value="Reset")
294
+ with gr.Column():
295
+ video_output_video = gr.Video(
296
+ label="Output video depth (red-near, blue-far)",
297
+ interactive=False,
298
+ )
299
+ video_output_files = gr.Files(
300
+ label="Depth outputs",
301
+ elem_id="download",
302
+ interactive=False,
303
+ )
304
+ Examples(
305
+ fn=process_pipe_video,
306
+ examples=[
307
+ os.path.join("files", name)
308
+ for name in [
309
+ "sora_e2.mp4",
310
+ "sora_1758192960116785459.mp4",
311
+ ]
312
+ ],
313
+ inputs=[video_input],
314
+ outputs=[video_output_video, video_output_files],
315
+ cache_examples=True,
316
+ directory_name="examples_video",
317
+ )
318
+
319
+ video_submit_btn.click(
320
+ fn=process_pipe_video,
321
+ inputs=[video_input],
322
+ outputs=[video_output_video, video_output_files],
323
+ concurrency_limit=1,
324
+ )
325
+
326
+ video_reset_btn.click(
327
+ fn=lambda: (None, None, None),
328
+ inputs=[],
329
+ outputs=[video_input, video_output_video, video_output_files],
330
+ concurrency_limit=1,
331
+ )
332
+
333
+ demo.queue(
334
+ api_open=False,
335
+ ).launch(
336
+ server_name="0.0.0.0",
337
+ server_port=7860,
338
+ )
339
+
340
+
341
+ def main():
342
+ CHECKPOINT = "jhshao/ChronoDepth"
343
+
344
+ if "HF_TOKEN_LOGIN" in os.environ:
345
+ login(token=os.environ["HF_TOKEN_LOGIN"])
346
+
347
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
348
+ print(f"Running on device: {device}")
349
+ pipe = ChronoDepthPipeline.from_pretrained(CHECKPOINT)
350
+ try:
351
+ import xformers
352
+
353
+ pipe.enable_xformers_memory_efficient_attention()
354
+ except Exception:
355
+ pass # run without xformers
356
+
357
+ pipe = pipe.to(device)
358
+ run_demo_server(pipe)
359
+
360
+
361
+ if __name__ == "__main__":
362
+ main()
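The trickiest part of `process_video` above is the clip schedule, so here is a standalone sketch (a hypothetical helper, not part of `app.py`) of which clip start indices are kept: the first clip, every `window_size`-th start when clips overlap, and the final clip.

```python
# Standalone sketch (hypothetical helper): reproduce the clip-selection rule used by
# process_video. With window_size != num_frames the clips overlap, and later clips are
# denoised with the overlapping frames pinned to the previous prediction ("inpaint" mode).
def clip_starts(video_length, num_frames=10, window_size=9):
    inpaint = window_size is not None and window_size != num_frames
    starts = []
    for i in range(video_length - num_frames + 1):
        is_first = i == 0
        is_last = i == video_length - num_frames
        is_new = (inpaint and i % window_size == 0) or (not inpaint and i % num_frames == 0)
        if is_first or is_last or is_new:
            starts.append(i)
    return starts


print(clip_starts(25))  # [0, 9, 15]: clips cover frames [0:10], [9:19], [15:25]
```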
chronodepth_pipeline.py ADDED
@@ -0,0 +1,530 @@
1
+ # Adapted from Marigold: https://github.com/prs-eth/Marigold and diffusers
2
+
3
+ import inspect
4
+ from typing import Union, Optional, List
5
+
6
+ import torch
7
+ import numpy as np
8
+ import matplotlib.pyplot as plt
9
+ from tqdm.auto import tqdm
10
+ import PIL
11
+ from PIL import Image
12
+ from diffusers import (
13
+ DiffusionPipeline,
14
+ EulerDiscreteScheduler,
15
+ UNetSpatioTemporalConditionModel,
16
+ AutoencoderKLTemporalDecoder,
17
+ )
18
+ from diffusers.image_processor import VaeImageProcessor
19
+ from diffusers.utils import BaseOutput
20
+ from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
21
+ from transformers import (
22
+ CLIPVisionModelWithProjection,
23
+ CLIPImageProcessor,
24
+ )
25
+ from einops import rearrange, repeat
26
+
27
+
28
+ class ChronoDepthOutput(BaseOutput):
29
+ r"""
30
+ Output class for zero-shot text-to-video pipeline.
31
+
32
+ Args:
33
+ frames (`[List[PIL.Image.Image]`, `np.ndarray`]):
34
+ List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,
35
+ num_channels)`.
36
+ """
37
+ depth_np: np.ndarray
38
+ depth_colored: Union[List[PIL.Image.Image], np.ndarray]
39
+
40
+
41
+ class ChronoDepthPipeline(DiffusionPipeline):
42
+ model_cpu_offload_seq = "image_encoder->unet->vae"
43
+ _callback_tensor_inputs = ["latents"]
44
+ rgb_latent_scale_factor = 0.18215
45
+ depth_latent_scale_factor = 0.18215
46
+
47
+ def __init__(
48
+ self,
49
+ vae: AutoencoderKLTemporalDecoder,
50
+ image_encoder: CLIPVisionModelWithProjection,
51
+ unet: UNetSpatioTemporalConditionModel,
52
+ scheduler: EulerDiscreteScheduler,
53
+ feature_extractor: CLIPImageProcessor,
54
+ ):
55
+ super().__init__()
56
+
57
+ self.register_modules(
58
+ vae=vae,
59
+ image_encoder=image_encoder,
60
+ unet=unet,
61
+ scheduler=scheduler,
62
+ feature_extractor=feature_extractor,
63
+ )
64
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
65
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
66
+ if not hasattr(self, "dtype"):
67
+ self.dtype = self.unet.dtype
68
+
69
+ def encode_RGB(self,
70
+ image: torch.Tensor,
71
+ ):
72
+ video_length = image.shape[1]
73
+ image = rearrange(image, "b f c h w -> (b f) c h w")
74
+ latents = self.vae.encode(image).latent_dist.sample()
75
+ latents = rearrange(latents, "(b f) c h w -> b f c h w", f=video_length)
76
+ latents = latents * self.vae.config.scaling_factor
77
+
78
+ return latents
79
+
80
+ def _encode_image(self, image, device, discard=True):
81
+ '''
82
+ set image to zero tensor discards the image embeddings if discard is True
83
+ '''
84
+ dtype = next(self.image_encoder.parameters()).dtype
85
+
86
+ if not isinstance(image, torch.Tensor):
87
+ image = self.image_processor.pil_to_numpy(image)
88
+ if discard:
89
+ image = np.zeros_like(image)
90
+ image = self.image_processor.numpy_to_pt(image)
91
+
92
+ # We normalize the image before resizing to match with the original implementation.
93
+ # Then we unnormalize it after resizing.
94
+ image = image * 2.0 - 1.0
95
+ image = _resize_with_antialiasing(image, (224, 224))
96
+ image = (image + 1.0) / 2.0
97
+
98
+ # Normalize the image for CLIP input
99
+ image = self.feature_extractor(
100
+ images=image,
101
+ do_normalize=True,
102
+ do_center_crop=False,
103
+ do_resize=False,
104
+ do_rescale=False,
105
+ return_tensors="pt",
106
+ ).pixel_values
107
+
108
+ image = image.to(device=device, dtype=dtype)
109
+ image_embeddings = self.image_encoder(image).image_embeds
110
+ image_embeddings = image_embeddings.unsqueeze(1)
111
+
112
+ return image_embeddings
113
+
114
+ def decode_depth(self, depth_latent: torch.Tensor, decode_chunk_size=5) -> torch.Tensor:
115
+ num_frames = depth_latent.shape[1]
116
+ depth_latent = rearrange(depth_latent, "b f c h w -> (b f) c h w")
117
+
118
+ depth_latent = depth_latent / self.vae.config.scaling_factor
119
+
120
+ forward_vae_fn = self.vae._orig_mod.forward if is_compiled_module(self.vae) else self.vae.forward
121
+ accepts_num_frames = "num_frames" in set(inspect.signature(forward_vae_fn).parameters.keys())
122
+
123
+ depth_frames = []
124
+ for i in range(0, depth_latent.shape[0], decode_chunk_size):
125
+ num_frames_in = depth_latent[i : i + decode_chunk_size].shape[0]
126
+ decode_kwargs = {}
127
+ if accepts_num_frames:
128
+ # we only pass num_frames_in if it's expected
129
+ decode_kwargs["num_frames"] = num_frames_in
130
+
131
+ depth_frame = self.vae.decode(depth_latent[i : i + decode_chunk_size], **decode_kwargs).sample
132
+ depth_frames.append(depth_frame)
133
+
134
+ depth_frames = torch.cat(depth_frames, dim=0)
135
+ depth_frames = depth_frames.reshape(-1, num_frames, *depth_frames.shape[1:])
136
+ depth_mean = depth_frames.mean(dim=2, keepdim=True)
137
+
138
+ return depth_mean
139
+
140
+ def _get_add_time_ids(self,
141
+ fps,
142
+ motion_bucket_id,
143
+ noise_aug_strength,
144
+ dtype,
145
+ batch_size,
146
+ ):
147
+ add_time_ids = [fps, motion_bucket_id, noise_aug_strength]
148
+
149
+ passed_add_embed_dim = self.unet.config.addition_time_embed_dim * \
150
+ len(add_time_ids)
151
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
152
+
153
+ if expected_add_embed_dim != passed_add_embed_dim:
154
+ raise ValueError(
155
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
156
+ )
157
+
158
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
159
+ add_time_ids = add_time_ids.repeat(batch_size, 1)
160
+ return add_time_ids
161
+
162
+ def decode_latents(self, latents, num_frames, decode_chunk_size=14):
163
+ # [batch, frames, channels, height, width] -> [batch*frames, channels, height, width]
164
+ latents = latents.flatten(0, 1)
165
+
166
+ latents = 1 / self.vae.config.scaling_factor * latents
167
+
168
+ forward_vae_fn = self.vae._orig_mod.forward if is_compiled_module(self.vae) else self.vae.forward
169
+ accepts_num_frames = "num_frames" in set(inspect.signature(forward_vae_fn).parameters.keys())
170
+
171
+ # decode decode_chunk_size frames at a time to avoid OOM
172
+ frames = []
173
+ for i in range(0, latents.shape[0], decode_chunk_size):
174
+ num_frames_in = latents[i : i + decode_chunk_size].shape[0]
175
+ decode_kwargs = {}
176
+ if accepts_num_frames:
177
+ # we only pass num_frames_in if it's expected
178
+ decode_kwargs["num_frames"] = num_frames_in
179
+
180
+ frame = self.vae.decode(latents[i : i + decode_chunk_size], **decode_kwargs).sample
181
+ frames.append(frame)
182
+ frames = torch.cat(frames, dim=0)
183
+
184
+ # [batch*frames, channels, height, width] -> [batch, channels, frames, height, width]
185
+ frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4)
186
+
187
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
188
+ frames = frames.float()
189
+ return frames
190
+
191
+ def check_inputs(self, image, height, width):
192
+ if (
193
+ not isinstance(image, torch.Tensor)
194
+ and not isinstance(image, PIL.Image.Image)
195
+ and not isinstance(image, list)
196
+ ):
197
+ raise ValueError(
198
+ "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
199
+ f" {type(image)}"
200
+ )
201
+
202
+ if height % 64 != 0 or width % 64 != 0:
203
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
204
+
205
+ def prepare_latents(
206
+ self,
207
+ shape,
208
+ dtype,
209
+ device,
210
+ generator,
211
+ latent=None,
212
+ ):
213
+ if isinstance(generator, list) and len(generator) != shape[0]:
214
+ raise ValueError(
215
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
216
+ f" size of {shape[0]}. Make sure the batch size matches the length of the generators."
217
+ )
218
+
219
+ if latent is None:
220
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
221
+ else:
222
+ latents = latent.to(device)
223
+
224
+ # scale the initial noise by the standard deviation required by the scheduler
225
+ latents = latents * self.scheduler.init_noise_sigma
226
+ return latents
227
+
228
+ @property
229
+ def num_timesteps(self):
230
+ return self._num_timesteps
231
+
232
+ @torch.no_grad()
233
+ def __call__(
234
+ self,
235
+ input_image: Union[List[PIL.Image.Image], torch.FloatTensor],
236
+ height: int = 576,
237
+ width: int = 768,
238
+ num_frames: Optional[int] = None,
239
+ num_inference_steps: int = 10,
240
+ fps: int = 7,
241
+ motion_bucket_id: int = 127,
242
+ noise_aug_strength: float = 0.02,
243
+ decode_chunk_size: Optional[int] = None,
244
+ color_map: str="Spectral",
245
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
246
+ show_progress_bar: bool = True,
247
+ match_input_res: bool = True,
248
+ depth_pred_last: Optional[torch.FloatTensor] = None,
249
+ ):
250
+ assert height >= 0 and width >=0
251
+ assert num_inference_steps >=1
252
+
253
+ num_frames = num_frames if num_frames is not None else self.unet.config.num_frames
254
+ decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames
255
+
256
+ # 1. Check inputs. Raise error if not correct
257
+ self.check_inputs(input_image, height, width)
258
+
259
+ # 2. Define call parameters
260
+ if isinstance(input_image, list):
261
+ batch_size = 1
262
+ input_size = input_image[0].size
263
+ elif isinstance(input_image, torch.Tensor):
264
+ batch_size = input_image.shape[0]
265
+ input_size = input_image.shape[:-3:-1]
266
+ assert batch_size == 1, "Batch size must be 1 for now"
267
+ device = self._execution_device
268
+
269
+ # 3. Encode input image
270
+ image_embeddings = self._encode_image(input_image[0], device)
271
+ image_embeddings = image_embeddings.repeat((batch_size, 1, 1))
272
+
273
+ # NOTE: Stable Diffusion Video was conditioned on fps - 1, which
274
+ # is why it is reduced here.
275
+ # See: https://github.com/Stability-AI/generative-models/blob/ed0997173f98eaf8f4edf7ba5fe8f15c6b877fd3/scripts/sampling/simple_video_sample.py#L188
276
+ fps = fps - 1
277
+
278
+ # 4. Encode input image using VAE
279
+ input_image = self.image_processor.preprocess(input_image, height=height, width=width).to(device)
280
+ assert input_image.min() >= -1.0 and input_image.max() <= 1.0
281
+ noise = randn_tensor(input_image.shape, generator=generator, device=device, dtype=input_image.dtype)
282
+ input_image = input_image + noise_aug_strength * noise
283
+ if depth_pred_last is not None:
284
+ depth_pred_last = depth_pred_last.to(device)
285
+ # resize depth
286
+ from torchvision.transforms import InterpolationMode
287
+ from torchvision.transforms.functional import resize
288
+ depth_pred_last = resize(depth_pred_last.unsqueeze(1), (height, width), InterpolationMode.NEAREST_EXACT, antialias=True)
289
+ depth_pred_last = repeat(depth_pred_last, 'f c h w ->b f c h w', b=batch_size)
290
+
291
+ rgb_batch = repeat(input_image, 'f c h w ->b f c h w', b=batch_size)
292
+
293
+ added_time_ids = self._get_add_time_ids(
294
+ fps,
295
+ motion_bucket_id,
296
+ noise_aug_strength,
297
+ image_embeddings.dtype,
298
+ batch_size,
299
+ )
300
+ added_time_ids = added_time_ids.to(device)
301
+
302
+ depth_pred_raw = self.single_infer(rgb_batch,
303
+ image_embeddings,
304
+ added_time_ids,
305
+ num_inference_steps,
306
+ show_progress_bar,
307
+ generator,
308
+ depth_pred_last=depth_pred_last,
309
+ decode_chunk_size=decode_chunk_size)
310
+
311
+ depth_colored_img_list = []
312
+ depth_frames = []
313
+ for i in range(num_frames):
314
+ depth_frame = depth_pred_raw[:, i].squeeze()
315
+
316
+ # Convert to numpy
317
+ depth_frame = depth_frame.cpu().numpy().astype(np.float32)
318
+
319
+ if match_input_res:
320
+ pred_img = Image.fromarray(depth_frame)
321
+ pred_img = pred_img.resize(input_size, resample=Image.NEAREST)
322
+ depth_frame = np.asarray(pred_img)
323
+
324
+ # Clip output range: current size is the original size
325
+ depth_frame = depth_frame.clip(0, 1)
326
+
327
+ # Colorize
328
+ depth_colored = plt.get_cmap(color_map)(depth_frame, bytes=True)[..., :3]
329
+ depth_colored_img = Image.fromarray(depth_colored)
330
+
331
+ depth_colored_img_list.append(depth_colored_img)
332
+ depth_frames.append(depth_frame)
333
+
334
+ depth_frames = np.stack(depth_frames)
335
+
336
+ self.maybe_free_model_hooks()
337
+
338
+ return ChronoDepthOutput(
339
+ depth_np = depth_frames,
340
+ depth_colored = depth_colored_img_list,
341
+ )
342
+
343
+ @torch.no_grad()
344
+ def single_infer(self,
345
+ input_rgb: torch.Tensor,
346
+ image_embeddings: torch.Tensor,
347
+ added_time_ids: torch.Tensor,
348
+ num_inference_steps: int,
349
+ show_pbar: bool,
350
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]],
351
+ depth_pred_last: Optional[torch.Tensor] = None,
352
+ decode_chunk_size=1,
353
+ ):
354
+ device = input_rgb.device
355
+
356
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
357
+ if needs_upcasting:
358
+ self.vae.to(dtype=torch.float32)
359
+
360
+ rgb_latent = self.encode_RGB(input_rgb)
361
+ rgb_latent = rgb_latent.to(image_embeddings.dtype)
362
+ if depth_pred_last is not None:
363
+ depth_pred_last = depth_pred_last.repeat(1, 1, 3, 1, 1)
364
+ depth_pred_last_latent = self.encode_RGB(depth_pred_last)
365
+ depth_pred_last_latent = depth_pred_last_latent.to(image_embeddings.dtype)
366
+ else:
367
+ depth_pred_last_latent = None
368
+
369
+ # cast back to fp16 if needed
370
+ if needs_upcasting:
371
+ self.vae.to(dtype=torch.float16)
372
+
373
+ # Prepare timesteps
374
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
375
+ timesteps = self.scheduler.timesteps
376
+
377
+ depth_latent = self.prepare_latents(
378
+ rgb_latent.shape,
379
+ image_embeddings.dtype,
380
+ device,
381
+ generator
382
+ )
383
+
384
+ if show_pbar:
385
+ iterable = tqdm(
386
+ enumerate(timesteps),
387
+ total=len(timesteps),
388
+ leave=False,
389
+ desc=" " * 4 + "Diffusion denoising",
390
+ )
391
+ else:
392
+ iterable = enumerate(timesteps)
393
+
394
+ for i, t in iterable:
395
+ if depth_pred_last_latent is not None:
396
+ known_frames_num = depth_pred_last_latent.shape[1]
397
+ epsilon = randn_tensor(
398
+ depth_pred_last_latent.shape,
399
+ generator=generator,
400
+ device=device,
401
+ dtype=image_embeddings.dtype
402
+ )
403
+ depth_latent[:, :known_frames_num] = depth_pred_last_latent + epsilon * self.scheduler.sigmas[i]
404
+ depth_latent = self.scheduler.scale_model_input(depth_latent, t)
405
+ unet_input = torch.cat([rgb_latent, depth_latent], dim=2)
406
+
407
+ noise_pred = self.unet(
408
+ unet_input, t, image_embeddings, added_time_ids=added_time_ids
409
+ )[0]
410
+
411
+ # compute the previous noisy sample x_t -> x_t-1
412
+ depth_latent = self.scheduler.step(noise_pred, t, depth_latent).prev_sample
413
+
414
+ torch.cuda.empty_cache()
415
+ if needs_upcasting:
416
+ self.vae.to(dtype=torch.float16)
417
+ depth = self.decode_depth(depth_latent, decode_chunk_size=decode_chunk_size)
418
+ # clip prediction
419
+ depth = torch.clip(depth, -1.0, 1.0)
420
+ # shift to [0, 1]
421
+ depth = (depth + 1.0) / 2.0
422
+
423
+ return depth
424
+
425
+ # resizing utils
426
+ def _resize_with_antialiasing(input, size, interpolation="bicubic", align_corners=True):
427
+ h, w = input.shape[-2:]
428
+ factors = (h / size[0], w / size[1])
429
+
430
+ # First, we have to determine sigma
431
+ # Taken from skimage: https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171
432
+ sigmas = (
433
+ max((factors[0] - 1.0) / 2.0, 0.001),
434
+ max((factors[1] - 1.0) / 2.0, 0.001),
435
+ )
436
+
437
+ # Now kernel size. Good results are for 3 sigma, but that is kind of slow. Pillow uses 1 sigma
438
+ # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206
439
+ # But they do it in the 2 passes, which gives better results. Let's try 2 sigmas for now
440
+ ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3))
441
+
442
+ # Make sure it is odd
443
+ if (ks[0] % 2) == 0:
444
+ ks = ks[0] + 1, ks[1]
445
+
446
+ if (ks[1] % 2) == 0:
447
+ ks = ks[0], ks[1] + 1
448
+
449
+ input = _gaussian_blur2d(input, ks, sigmas)
450
+
451
+ output = torch.nn.functional.interpolate(input, size=size, mode=interpolation, align_corners=align_corners)
452
+ return output
453
+
454
+
455
+ def _compute_padding(kernel_size):
456
+ """Compute padding tuple."""
457
+ # 4 or 6 ints: (padding_left, padding_right,padding_top,padding_bottom)
458
+ # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad
459
+ if len(kernel_size) < 2:
460
+ raise AssertionError(kernel_size)
461
+ computed = [k - 1 for k in kernel_size]
462
+
463
+ # for even kernels we need to do asymmetric padding :(
464
+ out_padding = 2 * len(kernel_size) * [0]
465
+
466
+ for i in range(len(kernel_size)):
467
+ computed_tmp = computed[-(i + 1)]
468
+
469
+ pad_front = computed_tmp // 2
470
+ pad_rear = computed_tmp - pad_front
471
+
472
+ out_padding[2 * i + 0] = pad_front
473
+ out_padding[2 * i + 1] = pad_rear
474
+
475
+ return out_padding
476
+
477
+
478
+ def _filter2d(input, kernel):
479
+ # prepare kernel
480
+ b, c, h, w = input.shape
481
+ tmp_kernel = kernel[:, None, ...].to(device=input.device, dtype=input.dtype)
482
+
483
+ tmp_kernel = tmp_kernel.expand(-1, c, -1, -1)
484
+
485
+ height, width = tmp_kernel.shape[-2:]
486
+
487
+ padding_shape: list[int] = _compute_padding([height, width])
488
+ input = torch.nn.functional.pad(input, padding_shape, mode="reflect")
489
+
490
+ # kernel and input tensor reshape to align element-wise or batch-wise params
491
+ tmp_kernel = tmp_kernel.reshape(-1, 1, height, width)
492
+ input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1))
493
+
494
+ # convolve the tensor with the kernel.
495
+ output = torch.nn.functional.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1)
496
+
497
+ out = output.view(b, c, h, w)
498
+ return out
499
+
500
+
501
+ def _gaussian(window_size: int, sigma):
502
+ if isinstance(sigma, float):
503
+ sigma = torch.tensor([[sigma]])
504
+
505
+ batch_size = sigma.shape[0]
506
+
507
+ x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1)
508
+
509
+ if window_size % 2 == 0:
510
+ x = x + 0.5
511
+
512
+ gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0)))
513
+
514
+ return gauss / gauss.sum(-1, keepdim=True)
515
+
516
+
517
+ def _gaussian_blur2d(input, kernel_size, sigma):
518
+ if isinstance(sigma, tuple):
519
+ sigma = torch.tensor([sigma], dtype=input.dtype)
520
+ else:
521
+ sigma = sigma.to(dtype=input.dtype)
522
+
523
+ ky, kx = int(kernel_size[0]), int(kernel_size[1])
524
+ bs = sigma.shape[0]
525
+ kernel_x = _gaussian(kx, sigma[:, 1].view(bs, 1))
526
+ kernel_y = _gaussian(ky, sigma[:, 0].view(bs, 1))
527
+ out_x = _filter2d(input, kernel_x[..., None, :])
528
+ out = _filter2d(out_x, kernel_y[..., None])
529
+
530
+ return out
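Temporal consistency across overlapping clips comes from the `depth_pred_last` branch of `single_infer`: at every denoising step, the latents of the frames shared with the previous clip are overwritten with their known clean latents plus noise scaled to the scheduler's current sigma. A small illustrative sketch of that single step (standalone code with assumed tensor shapes, not the pipeline's exact implementation):

```python
# Illustrative sketch (not the pipeline's exact code): pin the K overlapping frames to their
# known clean latents, re-noised to the current sigma, so only the new frames get denoised.
import torch

def pin_known_frames(depth_latent, known_latent, sigma, generator=None):
    # depth_latent: (B, F, C, H, W) current noisy depth latents
    # known_latent: (B, K, C, H, W) clean latents of the K frames shared with the previous clip
    k = known_latent.shape[1]
    eps = torch.randn(known_latent.shape, generator=generator,
                      device=known_latent.device, dtype=known_latent.dtype)
    out = depth_latent.clone()
    out[:, :k] = known_latent + eps * sigma  # sigma-scaled noise, as in single_infer
    return out
```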
gradio_patches/examples.py ADDED
@@ -0,0 +1,13 @@
1
+ from pathlib import Path
2
+
3
+ import gradio
4
+ from gradio.utils import get_cache_folder
5
+
6
+
7
+ class Examples(gradio.helpers.Examples):
8
+ def __init__(self, *args, directory_name=None, **kwargs):
9
+ super().__init__(*args, **kwargs, _initiated_directly=False)
10
+ if directory_name is not None:
11
+ self.cached_folder = get_cache_folder() / directory_name
12
+ self.cached_file = Path(self.cached_folder) / "log.csv"
13
+ self.create()
requirements.txt ADDED
@@ -0,0 +1,14 @@
1
+ spaces
2
+ gradio>=4.32.1
3
+ diffusers==0.26.0
4
+ easydict==1.13
5
+ einops==0.8.0
6
+ matplotlib==3.8.4
7
+ mediapy==1.2.2
8
+ numpy==1.26.4
9
+ Pillow==10.3.0
10
+ torch==2.0.1
11
+ torchvision==0.15.2
12
+ tqdm==4.66.2
13
+ accelerate==0.28.0
14
+ transformers==4.36.2