Spaces:
Running
on
Zero
Running
on
Zero
wbhu-tc
commited on
Commit
•
7c1a14b
0
Parent(s):
update
Browse files- .gitattributes +37 -0
- .gitignore +169 -0
- LICENSE +32 -0
- README.md +105 -0
- app.py +141 -0
- depthcrafter/__init__.py +0 -0
- depthcrafter/depth_crafter_ppl.py +366 -0
- depthcrafter/unet.py +142 -0
- depthcrafter/utils.py +92 -0
- examples/example_01.mp4 +3 -0
- requirements.txt +5 -0
- run.py +210 -0
.gitattributes
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
*.gif filter=lfs diff=lfs merge=lfs -text
|
37 |
+
*.mp4 filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
### Python template
|
2 |
+
# Byte-compiled / optimized / DLL files
|
3 |
+
__pycache__/
|
4 |
+
*.py[cod]
|
5 |
+
*$py.class
|
6 |
+
|
7 |
+
# C extensions
|
8 |
+
*.so
|
9 |
+
|
10 |
+
# Distribution / packaging
|
11 |
+
.Python
|
12 |
+
build/
|
13 |
+
develop-eggs/
|
14 |
+
dist/
|
15 |
+
downloads/
|
16 |
+
eggs/
|
17 |
+
.eggs/
|
18 |
+
lib/
|
19 |
+
lib64/
|
20 |
+
parts/
|
21 |
+
sdist/
|
22 |
+
var/
|
23 |
+
wheels/
|
24 |
+
share/python-wheels/
|
25 |
+
*.egg-info/
|
26 |
+
.installed.cfg
|
27 |
+
*.egg
|
28 |
+
MANIFEST
|
29 |
+
|
30 |
+
# PyInstaller
|
31 |
+
# Usually these files are written by a python script from a template
|
32 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
33 |
+
*.manifest
|
34 |
+
*.spec
|
35 |
+
|
36 |
+
# Installer logs
|
37 |
+
pip-log.txt
|
38 |
+
pip-delete-this-directory.txt
|
39 |
+
|
40 |
+
# Unit test / coverage reports
|
41 |
+
htmlcov/
|
42 |
+
.tox/
|
43 |
+
.nox/
|
44 |
+
.coverage
|
45 |
+
.coverage.*
|
46 |
+
.cache
|
47 |
+
nosetests.xml
|
48 |
+
coverage.xml
|
49 |
+
*.cover
|
50 |
+
*.py,cover
|
51 |
+
.hypothesis/
|
52 |
+
.pytest_cache/
|
53 |
+
cover/
|
54 |
+
|
55 |
+
# Translations
|
56 |
+
*.mo
|
57 |
+
*.pot
|
58 |
+
|
59 |
+
# Django stuff:
|
60 |
+
*.log
|
61 |
+
local_settings.py
|
62 |
+
db.sqlite3
|
63 |
+
db.sqlite3-journal
|
64 |
+
|
65 |
+
# Flask stuff:
|
66 |
+
instance/
|
67 |
+
.webassets-cache
|
68 |
+
|
69 |
+
# Scrapy stuff:
|
70 |
+
.scrapy
|
71 |
+
|
72 |
+
# Sphinx documentation
|
73 |
+
docs/_build/
|
74 |
+
|
75 |
+
# PyBuilder
|
76 |
+
.pybuilder/
|
77 |
+
target/
|
78 |
+
|
79 |
+
# Jupyter Notebook
|
80 |
+
.ipynb_checkpoints
|
81 |
+
|
82 |
+
# IPython
|
83 |
+
profile_default/
|
84 |
+
ipython_config.py
|
85 |
+
|
86 |
+
# pyenv
|
87 |
+
# For a library or package, you might want to ignore these files since the code is
|
88 |
+
# intended to run in multiple environments; otherwise, check them in:
|
89 |
+
# .python-version
|
90 |
+
|
91 |
+
# pipenv
|
92 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
93 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
94 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
95 |
+
# install all needed dependencies.
|
96 |
+
#Pipfile.lock
|
97 |
+
|
98 |
+
# poetry
|
99 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
100 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
101 |
+
# commonly ignored for libraries.
|
102 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
103 |
+
#poetry.lock
|
104 |
+
|
105 |
+
# pdm
|
106 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
107 |
+
#pdm.lock
|
108 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
109 |
+
# in version control.
|
110 |
+
# https://pdm.fming.dev/#use-with-ide
|
111 |
+
.pdm.toml
|
112 |
+
|
113 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
114 |
+
__pypackages__/
|
115 |
+
|
116 |
+
# Celery stuff
|
117 |
+
celerybeat-schedule
|
118 |
+
celerybeat.pid
|
119 |
+
|
120 |
+
# SageMath parsed files
|
121 |
+
*.sage.py
|
122 |
+
|
123 |
+
# Environments
|
124 |
+
.env
|
125 |
+
.venv
|
126 |
+
env/
|
127 |
+
venv/
|
128 |
+
ENV/
|
129 |
+
env.bak/
|
130 |
+
venv.bak/
|
131 |
+
|
132 |
+
# Spyder project settings
|
133 |
+
.spyderproject
|
134 |
+
.spyproject
|
135 |
+
|
136 |
+
# Rope project settings
|
137 |
+
.ropeproject
|
138 |
+
|
139 |
+
# mkdocs documentation
|
140 |
+
/site
|
141 |
+
|
142 |
+
# mypy
|
143 |
+
.mypy_cache/
|
144 |
+
.dmypy.json
|
145 |
+
dmypy.json
|
146 |
+
|
147 |
+
# Pyre type checker
|
148 |
+
.pyre/
|
149 |
+
|
150 |
+
# pytype static type analyzer
|
151 |
+
.pytype/
|
152 |
+
|
153 |
+
# Cython debug symbols
|
154 |
+
cython_debug/
|
155 |
+
|
156 |
+
# PyCharm
|
157 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
158 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
159 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
160 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
161 |
+
.idea/
|
162 |
+
|
163 |
+
/logs
|
164 |
+
/gin-config
|
165 |
+
*.json
|
166 |
+
/eval/*csv
|
167 |
+
*__pycache__
|
168 |
+
scripts/
|
169 |
+
eval/
|
LICENSE
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. The below software in this distribution may have been modified by THL A29 Limited ("Tencent Modifications").
|
2 |
+
|
3 |
+
License Terms of the inference code of DepthCrafter:
|
4 |
+
--------------------------------------------------------------------
|
5 |
+
|
6 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy of this Software and associated documentation files, to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, and/or sublicense copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
7 |
+
|
8 |
+
- You agree to use the DepthCrafter only for academic, research and education purposes, and refrain from using it for any commercial or production purposes under any circumstances.
|
9 |
+
|
10 |
+
- The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
11 |
+
|
12 |
+
For avoidance of doubts, “Software” means the DepthCrafter model inference code and weights made available under this license excluding any pre-trained data and other AI components.
|
13 |
+
|
14 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
15 |
+
|
16 |
+
|
17 |
+
Other dependencies and licenses:
|
18 |
+
|
19 |
+
Open Source Software Licensed under the MIT License:
|
20 |
+
--------------------------------------------------------------------
|
21 |
+
1. Stability AI - Code
|
22 |
+
Copyright (c) 2023 Stability AI
|
23 |
+
|
24 |
+
Terms of the MIT License:
|
25 |
+
--------------------------------------------------------------------
|
26 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
27 |
+
|
28 |
+
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
29 |
+
|
30 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
31 |
+
|
32 |
+
**You may find the code license of Stability AI at the following links: https://github.com/Stability-AI/generative-models/blob/main/LICENSE-CODE
|
README.md
ADDED
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## ___***DepthCrafter: Generating Consistent Long Depth Sequences for Open-world Videos***___
|
2 |
+
<div align="center">
|
3 |
+
<img src='https://depthcrafter.github.io/img/logo.png' style="height:140px"></img>
|
4 |
+
|
5 |
+
|
6 |
+
|
7 |
+
<a href='https://arxiv.org/abs/2409.02095'><img src='https://img.shields.io/badge/arXiv-2409.02095-b31b1b.svg'></a>
|
8 |
+
<a href='https://depthcrafter.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a>
|
9 |
+
|
10 |
+
|
11 |
+
_**[Wenbo Hu<sup>1* †</sup>](https://wbhu.github.io),
|
12 |
+
[Xiangjun Gao<sup>2*</sup>](https://scholar.google.com/citations?user=qgdesEcAAAAJ&hl=en),
|
13 |
+
[Xiaoyu Li<sup>1* †</sup>](https://xiaoyu258.github.io),
|
14 |
+
[Sijie Zhao<sup>1</sup>](https://scholar.google.com/citations?user=tZ3dS3MAAAAJ&hl=en),
|
15 |
+
[Xiaodong Cun<sup>1</sup>](https://vinthony.github.io/academic), <br>
|
16 |
+
[Yong Zhang<sup>1</sup>](https://yzhang2016.github.io),
|
17 |
+
[Long Quan<sup>2</sup>](https://home.cse.ust.hk/~quan),
|
18 |
+
[Ying Shan<sup>3, 1</sup>](https://scholar.google.com/citations?user=4oXBp9UAAAAJ&hl=en)**_
|
19 |
+
<br><br>
|
20 |
+
<sup>1</sup>Tencent AI Lab
|
21 |
+
<sup>2</sup>The Hong Kong University of Science and Technology
|
22 |
+
<sup>3</sup>ARC Lab, Tencent PCG
|
23 |
+
|
24 |
+
arXiv preprint, 2024
|
25 |
+
|
26 |
+
</div>
|
27 |
+
|
28 |
+
## 🔆 Introduction
|
29 |
+
🤗 DepthCrafter can generate temporally consistent long depth sequences with fine-grained details for open-world videos,
|
30 |
+
without requiring additional information such as camera poses or optical flow.
|
31 |
+
|
32 |
+
## 🎥 Visualization
|
33 |
+
We provide some demos of unprojected point cloud sequences, with reference RGB and estimated depth videos.
|
34 |
+
Please refer to our [project page](https://depthcrafter.github.io) for more details.
|
35 |
+
|
36 |
+
|
37 |
+
https://github.com/user-attachments/assets/62141cc8-04d0-458f-9558-fe50bc04cc21
|
38 |
+
|
39 |
+
|
40 |
+
|
41 |
+
|
42 |
+
## 🚀 Quick Start
|
43 |
+
|
44 |
+
### 🛠️ Installation
|
45 |
+
1. Clone this repo:
|
46 |
+
```bash
|
47 |
+
git clone https://github.com/Tencent/DepthCrafter.git
|
48 |
+
```
|
49 |
+
2. Install dependencies (please refer to [requirements.txt](requirements.txt)):
|
50 |
+
```bash
|
51 |
+
pip install -r requirements.txt
|
52 |
+
```
|
53 |
+
|
54 |
+
## 🤗 Model Zoo
|
55 |
+
[DepthCrafter](https://huggingface.co/tencent/DepthCrafter) is available in the Hugging Face Model Hub.
|
56 |
+
|
57 |
+
### 🏃♂️ Inference
|
58 |
+
#### 1. High-resolution inference, requires a GPU with ~26GB memory for 1024x576 resolution:
|
59 |
+
- Full inference (~0.6 fps on A100, recommended for high-quality results):
|
60 |
+
|
61 |
+
```bash
|
62 |
+
python run.py --video-path examples/example_01.mp4
|
63 |
+
```
|
64 |
+
|
65 |
+
|
66 |
+
- Fast inference through 4-step denoising and without classifier-free guidance (~2.3 fps on A100):
|
67 |
+
|
68 |
+
```bash
|
69 |
+
python run.py --video-path examples/example_01.mp4 --num-inference-steps 4 --guidance-scale 1.0
|
70 |
+
```
|
71 |
+
|
72 |
+
|
73 |
+
#### 2. Low-resolution inference, requires a GPU with ~9GB memory for 512x256 resolution:
|
74 |
+
|
75 |
+
- Full inference (~2.3 fps on A100):
|
76 |
+
|
77 |
+
```bash
|
78 |
+
python run.py --video-path examples/example_01.mp4 --max-res 512
|
79 |
+
```
|
80 |
+
|
81 |
+
- Fast inference through 4-step denoising and without classifier-free guidance (~9.4 fps on A100):
|
82 |
+
```bash
|
83 |
+
python run.py --video-path examples/example_01.mp4 --max-res 512 --num-inference-steps 4 --guidance-scale 1.0
|
84 |
+
```
|
85 |
+
|
86 |
+
## 🤖 Gradio Demo
|
87 |
+
We provide a local Gradio demo for DepthCrafter, which can be launched by running:
|
88 |
+
```bash
|
89 |
+
gradio app.py
|
90 |
+
```
|
91 |
+
|
92 |
+
## 🤝 Contributing
|
93 |
+
- Welcome to open issues and pull requests.
|
94 |
+
- Welcome to optimize the inference speed and memory usage, e.g., through model quantization, distillation, or other acceleration techniques.
|
95 |
+
|
96 |
+
## 📜 Citation
|
97 |
+
If you find this work helpful, please consider citing:
|
98 |
+
```bibtex
|
99 |
+
@article{hu2024-DepthCrafter,
|
100 |
+
author = {Hu, Wenbo and Gao, Xiangjun and Li, Xiaoyu and Zhao, Sijie and Cun, Xiaodong and Zhang, Yong and Quan, Long and Shan, Ying},
|
101 |
+
title = {DepthCrafter: Generating Consistent Long Depth Sequences for Open-world Videos},
|
102 |
+
journal = {arXiv preprint arXiv:2409.02095},
|
103 |
+
year = {2024}
|
104 |
+
}
|
105 |
+
```
|
app.py
ADDED
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gc
|
2 |
+
import os
|
3 |
+
from copy import deepcopy
|
4 |
+
|
5 |
+
import gradio as gr
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
from diffusers.training_utils import set_seed
|
9 |
+
|
10 |
+
from depthcrafter.depth_crafter_ppl import DepthCrafterPipeline
|
11 |
+
from depthcrafter.unet import DiffusersUNetSpatioTemporalConditionModelDepthCrafter
|
12 |
+
from depthcrafter.utils import read_video_frames, vis_sequence_depth, save_video
|
13 |
+
from run import DepthCrafterDemo
|
14 |
+
|
15 |
+
examples = [
|
16 |
+
["examples/example_01.mp4", 25, 1.2, 1024, 195],
|
17 |
+
]
|
18 |
+
|
19 |
+
|
20 |
+
def construct_demo():
|
21 |
+
with gr.Blocks(analytics_enabled=False) as depthcrafter_iface:
|
22 |
+
gr.Markdown(
|
23 |
+
"""
|
24 |
+
<div align='center'> <h1> DepthCrafter: Generating Consistent Long Depth Sequences for Open-world Videos </span> </h1> \
|
25 |
+
<h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>\
|
26 |
+
<a href='https://wbhu.github.io'>Wenbo Hu</a>, \
|
27 |
+
<a href='https://scholar.google.com/citations?user=qgdesEcAAAAJ&hl=en'>Xiangjun Gao</a>, \
|
28 |
+
<a href='https://xiaoyu258.github.io/'>Xiaoyu Li</a>, \
|
29 |
+
<a href='https://scholar.google.com/citations?user=tZ3dS3MAAAAJ&hl=en'>Sijie Zhao</a>, \
|
30 |
+
<a href='https://vinthony.github.io/academic'> Xiaodong Cun</a>, \
|
31 |
+
<a href='https://yzhang2016.github.io'>Yong Zhang</a>, \
|
32 |
+
<a href='https://home.cse.ust.hk/~quan'>Long Quan</a>, \
|
33 |
+
<a href='https://scholar.google.com/citations?user=4oXBp9UAAAAJ&hl=en'>Ying Shan</a>\
|
34 |
+
</h2> \
|
35 |
+
<a style='font-size:18px;color: #000000'>If you find DepthCrafter useful, please help star the </a>\
|
36 |
+
<a style='font-size:18px;color: #FF5DB0' href='https://github.com/wbhu/DepthCrafter'>[Github Repo]</a>\
|
37 |
+
<a style='font-size:18px;color: #000000'>, which is important to Open-Source projects. Thanks!</a>\
|
38 |
+
<a style='font-size:18px;color: #000000' href='https://arxiv.org/abs/2409.02095'> [ArXiv] </a>\
|
39 |
+
<a style='font-size:18px;color: #000000' href='https://depthcrafter.github.io/'> [Project Page] </a> </div>
|
40 |
+
"""
|
41 |
+
)
|
42 |
+
# demo
|
43 |
+
depthcrafter_demo = DepthCrafterDemo(
|
44 |
+
unet_path="tencent/DepthCrafter",
|
45 |
+
pre_train_path="stabilityai/stable-video-diffusion-img2vid-xt",
|
46 |
+
)
|
47 |
+
|
48 |
+
with gr.Row(equal_height=True):
|
49 |
+
with gr.Column(scale=1):
|
50 |
+
input_video = gr.Video(label="Input Video")
|
51 |
+
|
52 |
+
# with gr.Tab(label="Output"):
|
53 |
+
with gr.Column(scale=2):
|
54 |
+
with gr.Row(equal_height=True):
|
55 |
+
output_video_1 = gr.Video(
|
56 |
+
label="Preprocessed video",
|
57 |
+
interactive=False,
|
58 |
+
autoplay=True,
|
59 |
+
loop=True,
|
60 |
+
show_share_button=True,
|
61 |
+
scale=5,
|
62 |
+
)
|
63 |
+
output_video_2 = gr.Video(
|
64 |
+
label="Generated Depth Video",
|
65 |
+
interactive=False,
|
66 |
+
autoplay=True,
|
67 |
+
loop=True,
|
68 |
+
show_share_button=True,
|
69 |
+
scale=5,
|
70 |
+
)
|
71 |
+
|
72 |
+
with gr.Row(equal_height=True):
|
73 |
+
with gr.Column(scale=1):
|
74 |
+
with gr.Row(equal_height=False):
|
75 |
+
with gr.Accordion("Advanced Settings", open=False):
|
76 |
+
num_denoising_steps = gr.Slider(
|
77 |
+
label="num denoising steps",
|
78 |
+
minimum=1,
|
79 |
+
maximum=25,
|
80 |
+
value=25,
|
81 |
+
step=1,
|
82 |
+
)
|
83 |
+
guidance_scale = gr.Slider(
|
84 |
+
label="cfg scale",
|
85 |
+
minimum=1.0,
|
86 |
+
maximum=1.2,
|
87 |
+
value=1.2,
|
88 |
+
step=0.1,
|
89 |
+
)
|
90 |
+
max_res = gr.Slider(
|
91 |
+
label="max resolution",
|
92 |
+
minimum=512,
|
93 |
+
maximum=2048,
|
94 |
+
value=1024,
|
95 |
+
step=64,
|
96 |
+
)
|
97 |
+
process_length = gr.Slider(
|
98 |
+
label="process length",
|
99 |
+
minimum=1,
|
100 |
+
maximum=280,
|
101 |
+
value=195,
|
102 |
+
step=1,
|
103 |
+
)
|
104 |
+
generate_btn = gr.Button("Generate")
|
105 |
+
with gr.Column(scale=2):
|
106 |
+
pass
|
107 |
+
|
108 |
+
gr.Examples(
|
109 |
+
examples=examples,
|
110 |
+
inputs=[
|
111 |
+
input_video,
|
112 |
+
num_denoising_steps,
|
113 |
+
guidance_scale,
|
114 |
+
max_res,
|
115 |
+
process_length,
|
116 |
+
],
|
117 |
+
outputs=[output_video_1, output_video_2],
|
118 |
+
fn=depthcrafter_demo.run,
|
119 |
+
cache_examples=False,
|
120 |
+
)
|
121 |
+
|
122 |
+
generate_btn.click(
|
123 |
+
fn=depthcrafter_demo.run,
|
124 |
+
inputs=[
|
125 |
+
input_video,
|
126 |
+
num_denoising_steps,
|
127 |
+
guidance_scale,
|
128 |
+
max_res,
|
129 |
+
process_length,
|
130 |
+
],
|
131 |
+
outputs=[output_video_1, output_video_2],
|
132 |
+
)
|
133 |
+
|
134 |
+
return depthcrafter_iface
|
135 |
+
|
136 |
+
|
137 |
+
demo = construct_demo()
|
138 |
+
|
139 |
+
if __name__ == "__main__":
|
140 |
+
demo.queue()
|
141 |
+
demo.launch(server_name="0.0.0.0", server_port=80, debug=True)
|
depthcrafter/__init__.py
ADDED
File without changes
|
depthcrafter/depth_crafter_ppl.py
ADDED
@@ -0,0 +1,366 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Callable, Dict, List, Optional, Union
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
|
6 |
+
from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import (
|
7 |
+
_resize_with_antialiasing,
|
8 |
+
StableVideoDiffusionPipelineOutput,
|
9 |
+
StableVideoDiffusionPipeline,
|
10 |
+
retrieve_timesteps,
|
11 |
+
)
|
12 |
+
from diffusers.utils import logging
|
13 |
+
from diffusers.utils.torch_utils import randn_tensor
|
14 |
+
|
15 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
16 |
+
|
17 |
+
|
18 |
+
class DepthCrafterPipeline(StableVideoDiffusionPipeline):
|
19 |
+
|
20 |
+
@torch.inference_mode()
|
21 |
+
def encode_video(
|
22 |
+
self,
|
23 |
+
video: torch.Tensor,
|
24 |
+
chunk_size: int = 14,
|
25 |
+
) -> torch.Tensor:
|
26 |
+
"""
|
27 |
+
:param video: [b, c, h, w] in range [-1, 1], the b may contain multiple videos or frames
|
28 |
+
:param chunk_size: the chunk size to encode video
|
29 |
+
:return: image_embeddings in shape of [b, 1024]
|
30 |
+
"""
|
31 |
+
|
32 |
+
video_224 = _resize_with_antialiasing(video.float(), (224, 224))
|
33 |
+
video_224 = (video_224 + 1.0) / 2.0 # [-1, 1] -> [0, 1]
|
34 |
+
|
35 |
+
embeddings = []
|
36 |
+
for i in range(0, video_224.shape[0], chunk_size):
|
37 |
+
tmp = self.feature_extractor(
|
38 |
+
images=video_224[i : i + chunk_size],
|
39 |
+
do_normalize=True,
|
40 |
+
do_center_crop=False,
|
41 |
+
do_resize=False,
|
42 |
+
do_rescale=False,
|
43 |
+
return_tensors="pt",
|
44 |
+
).pixel_values.to(video.device, dtype=video.dtype)
|
45 |
+
embeddings.append(self.image_encoder(tmp).image_embeds) # [b, 1024]
|
46 |
+
|
47 |
+
embeddings = torch.cat(embeddings, dim=0) # [t, 1024]
|
48 |
+
return embeddings
|
49 |
+
|
50 |
+
@torch.inference_mode()
|
51 |
+
def encode_vae_video(
|
52 |
+
self,
|
53 |
+
video: torch.Tensor,
|
54 |
+
chunk_size: int = 14,
|
55 |
+
):
|
56 |
+
"""
|
57 |
+
:param video: [b, c, h, w] in range [-1, 1], the b may contain multiple videos or frames
|
58 |
+
:param chunk_size: the chunk size to encode video
|
59 |
+
:return: vae latents in shape of [b, c, h, w]
|
60 |
+
"""
|
61 |
+
video_latents = []
|
62 |
+
for i in range(0, video.shape[0], chunk_size):
|
63 |
+
video_latents.append(
|
64 |
+
self.vae.encode(video[i : i + chunk_size]).latent_dist.mode()
|
65 |
+
)
|
66 |
+
video_latents = torch.cat(video_latents, dim=0)
|
67 |
+
return video_latents
|
68 |
+
|
69 |
+
@staticmethod
|
70 |
+
def check_inputs(video, height, width):
|
71 |
+
"""
|
72 |
+
:param video:
|
73 |
+
:param height:
|
74 |
+
:param width:
|
75 |
+
:return:
|
76 |
+
"""
|
77 |
+
if not isinstance(video, torch.Tensor) and not isinstance(video, np.ndarray):
|
78 |
+
raise ValueError(
|
79 |
+
f"Expected `video` to be a `torch.Tensor` or `VideoReader`, but got a {type(video)}"
|
80 |
+
)
|
81 |
+
|
82 |
+
if height % 8 != 0 or width % 8 != 0:
|
83 |
+
raise ValueError(
|
84 |
+
f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
|
85 |
+
)
|
86 |
+
|
87 |
+
@torch.no_grad()
|
88 |
+
def __call__(
|
89 |
+
self,
|
90 |
+
video: Union[np.ndarray, torch.Tensor],
|
91 |
+
height: int = 576,
|
92 |
+
width: int = 1024,
|
93 |
+
num_inference_steps: int = 25,
|
94 |
+
guidance_scale: float = 1.0,
|
95 |
+
window_size: Optional[int] = 110,
|
96 |
+
noise_aug_strength: float = 0.02,
|
97 |
+
decode_chunk_size: Optional[int] = None,
|
98 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
99 |
+
latents: Optional[torch.FloatTensor] = None,
|
100 |
+
output_type: Optional[str] = "pil",
|
101 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
102 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
103 |
+
return_dict: bool = True,
|
104 |
+
overlap: int = 25,
|
105 |
+
track_time: bool = False,
|
106 |
+
):
|
107 |
+
"""
|
108 |
+
:param video: in shape [t, h, w, c] if np.ndarray or [t, c, h, w] if torch.Tensor, in range [0, 1]
|
109 |
+
:param height:
|
110 |
+
:param width:
|
111 |
+
:param num_inference_steps:
|
112 |
+
:param guidance_scale:
|
113 |
+
:param window_size: sliding window processing size
|
114 |
+
:param fps:
|
115 |
+
:param motion_bucket_id:
|
116 |
+
:param noise_aug_strength:
|
117 |
+
:param decode_chunk_size:
|
118 |
+
:param generator:
|
119 |
+
:param latents:
|
120 |
+
:param output_type:
|
121 |
+
:param callback_on_step_end:
|
122 |
+
:param callback_on_step_end_tensor_inputs:
|
123 |
+
:param return_dict:
|
124 |
+
:return:
|
125 |
+
"""
|
126 |
+
# 0. Default height and width to unet
|
127 |
+
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
128 |
+
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
129 |
+
num_frames = video.shape[0]
|
130 |
+
decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else 8
|
131 |
+
if num_frames <= window_size:
|
132 |
+
window_size = num_frames
|
133 |
+
overlap = 0
|
134 |
+
stride = window_size - overlap
|
135 |
+
|
136 |
+
# 1. Check inputs. Raise error if not correct
|
137 |
+
self.check_inputs(video, height, width)
|
138 |
+
|
139 |
+
# 2. Define call parameters
|
140 |
+
batch_size = 1
|
141 |
+
device = self._execution_device
|
142 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
143 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
144 |
+
# corresponds to doing no classifier free guidance.
|
145 |
+
self._guidance_scale = guidance_scale
|
146 |
+
|
147 |
+
# 3. Encode input video
|
148 |
+
if isinstance(video, np.ndarray):
|
149 |
+
video = torch.from_numpy(video.transpose(0, 3, 1, 2))
|
150 |
+
else:
|
151 |
+
assert isinstance(video, torch.Tensor)
|
152 |
+
video = video.to(device=device, dtype=self.dtype)
|
153 |
+
video = video * 2.0 - 1.0 # [0,1] -> [-1,1], in [t, c, h, w]
|
154 |
+
|
155 |
+
if track_time:
|
156 |
+
start_event = torch.cuda.Event(enable_timing=True)
|
157 |
+
encode_event = torch.cuda.Event(enable_timing=True)
|
158 |
+
denoise_event = torch.cuda.Event(enable_timing=True)
|
159 |
+
decode_event = torch.cuda.Event(enable_timing=True)
|
160 |
+
start_event.record()
|
161 |
+
|
162 |
+
video_embeddings = self.encode_video(
|
163 |
+
video, chunk_size=decode_chunk_size
|
164 |
+
).unsqueeze(
|
165 |
+
0
|
166 |
+
) # [1, t, 1024]
|
167 |
+
torch.cuda.empty_cache()
|
168 |
+
# 4. Encode input image using VAE
|
169 |
+
noise = randn_tensor(
|
170 |
+
video.shape, generator=generator, device=device, dtype=video.dtype
|
171 |
+
)
|
172 |
+
video = video + noise_aug_strength * noise # in [t, c, h, w]
|
173 |
+
|
174 |
+
# pdb.set_trace()
|
175 |
+
needs_upcasting = (
|
176 |
+
self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
177 |
+
)
|
178 |
+
if needs_upcasting:
|
179 |
+
self.vae.to(dtype=torch.float32)
|
180 |
+
|
181 |
+
video_latents = self.encode_vae_video(
|
182 |
+
video.to(self.vae.dtype),
|
183 |
+
chunk_size=decode_chunk_size,
|
184 |
+
).unsqueeze(
|
185 |
+
0
|
186 |
+
) # [1, t, c, h, w]
|
187 |
+
|
188 |
+
if track_time:
|
189 |
+
encode_event.record()
|
190 |
+
torch.cuda.synchronize()
|
191 |
+
elapsed_time_ms = start_event.elapsed_time(encode_event)
|
192 |
+
print(f"Elapsed time for encoding video: {elapsed_time_ms} ms")
|
193 |
+
|
194 |
+
torch.cuda.empty_cache()
|
195 |
+
|
196 |
+
# cast back to fp16 if needed
|
197 |
+
if needs_upcasting:
|
198 |
+
self.vae.to(dtype=torch.float16)
|
199 |
+
|
200 |
+
# 5. Get Added Time IDs
|
201 |
+
added_time_ids = self._get_add_time_ids(
|
202 |
+
7,
|
203 |
+
127,
|
204 |
+
noise_aug_strength,
|
205 |
+
video_embeddings.dtype,
|
206 |
+
batch_size,
|
207 |
+
1,
|
208 |
+
False,
|
209 |
+
) # [1 or 2, 3]
|
210 |
+
added_time_ids = added_time_ids.to(device)
|
211 |
+
|
212 |
+
# 6. Prepare timesteps
|
213 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
214 |
+
self.scheduler, num_inference_steps, device, None, None
|
215 |
+
)
|
216 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
217 |
+
self._num_timesteps = len(timesteps)
|
218 |
+
|
219 |
+
# 7. Prepare latent variables
|
220 |
+
num_channels_latents = self.unet.config.in_channels
|
221 |
+
latents_init = self.prepare_latents(
|
222 |
+
batch_size,
|
223 |
+
window_size,
|
224 |
+
num_channels_latents,
|
225 |
+
height,
|
226 |
+
width,
|
227 |
+
video_embeddings.dtype,
|
228 |
+
device,
|
229 |
+
generator,
|
230 |
+
latents,
|
231 |
+
) # [1, t, c, h, w]
|
232 |
+
latents_all = None
|
233 |
+
|
234 |
+
idx_start = 0
|
235 |
+
if overlap > 0:
|
236 |
+
weights = torch.linspace(0, 1, overlap, device=device)
|
237 |
+
weights = weights.view(1, overlap, 1, 1, 1)
|
238 |
+
else:
|
239 |
+
weights = None
|
240 |
+
|
241 |
+
torch.cuda.empty_cache()
|
242 |
+
|
243 |
+
# inference strategy for long videos
|
244 |
+
# two main strategies: 1. noise init from previous frame, 2. segments stitching
|
245 |
+
while idx_start < num_frames - overlap:
|
246 |
+
idx_end = min(idx_start + window_size, num_frames)
|
247 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
248 |
+
|
249 |
+
# 9. Denoising loop
|
250 |
+
latents = latents_init[:, : idx_end - idx_start].clone()
|
251 |
+
latents_init = torch.cat(
|
252 |
+
[latents_init[:, -overlap:], latents_init[:, :stride]], dim=1
|
253 |
+
)
|
254 |
+
|
255 |
+
video_latents_current = video_latents[:, idx_start:idx_end]
|
256 |
+
video_embeddings_current = video_embeddings[:, idx_start:idx_end]
|
257 |
+
|
258 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
259 |
+
for i, t in enumerate(timesteps):
|
260 |
+
if latents_all is not None and i == 0:
|
261 |
+
latents[:, :overlap] = (
|
262 |
+
latents_all[:, -overlap:]
|
263 |
+
+ latents[:, :overlap]
|
264 |
+
/ self.scheduler.init_noise_sigma
|
265 |
+
* self.scheduler.sigmas[i]
|
266 |
+
)
|
267 |
+
|
268 |
+
latent_model_input = latents # [1, t, c, h, w]
|
269 |
+
latent_model_input = self.scheduler.scale_model_input(
|
270 |
+
latent_model_input, t
|
271 |
+
) # [1, t, c, h, w]
|
272 |
+
latent_model_input = torch.cat(
|
273 |
+
[latent_model_input, video_latents_current], dim=2
|
274 |
+
)
|
275 |
+
noise_pred = self.unet(
|
276 |
+
latent_model_input,
|
277 |
+
t,
|
278 |
+
encoder_hidden_states=video_embeddings_current,
|
279 |
+
added_time_ids=added_time_ids,
|
280 |
+
return_dict=False,
|
281 |
+
)[0]
|
282 |
+
# perform guidance
|
283 |
+
if self.do_classifier_free_guidance:
|
284 |
+
latent_model_input = latents
|
285 |
+
latent_model_input = self.scheduler.scale_model_input(
|
286 |
+
latent_model_input, t
|
287 |
+
)
|
288 |
+
latent_model_input = torch.cat(
|
289 |
+
[latent_model_input, torch.zeros_like(latent_model_input)],
|
290 |
+
dim=2,
|
291 |
+
)
|
292 |
+
noise_pred_uncond = self.unet(
|
293 |
+
latent_model_input,
|
294 |
+
t,
|
295 |
+
encoder_hidden_states=torch.zeros_like(
|
296 |
+
video_embeddings_current
|
297 |
+
),
|
298 |
+
added_time_ids=added_time_ids,
|
299 |
+
return_dict=False,
|
300 |
+
)[0]
|
301 |
+
|
302 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (
|
303 |
+
noise_pred - noise_pred_uncond
|
304 |
+
)
|
305 |
+
latents = self.scheduler.step(noise_pred, t, latents).prev_sample
|
306 |
+
|
307 |
+
if callback_on_step_end is not None:
|
308 |
+
callback_kwargs = {}
|
309 |
+
for k in callback_on_step_end_tensor_inputs:
|
310 |
+
callback_kwargs[k] = locals()[k]
|
311 |
+
callback_outputs = callback_on_step_end(
|
312 |
+
self, i, t, callback_kwargs
|
313 |
+
)
|
314 |
+
|
315 |
+
latents = callback_outputs.pop("latents", latents)
|
316 |
+
|
317 |
+
if i == len(timesteps) - 1 or (
|
318 |
+
(i + 1) > num_warmup_steps
|
319 |
+
and (i + 1) % self.scheduler.order == 0
|
320 |
+
):
|
321 |
+
progress_bar.update()
|
322 |
+
|
323 |
+
if latents_all is None:
|
324 |
+
latents_all = latents.clone()
|
325 |
+
else:
|
326 |
+
assert weights is not None
|
327 |
+
# latents_all[:, -overlap:] = (
|
328 |
+
# latents[:, :overlap] + latents_all[:, -overlap:]
|
329 |
+
# ) / 2.0
|
330 |
+
latents_all[:, -overlap:] = latents[
|
331 |
+
:, :overlap
|
332 |
+
] * weights + latents_all[:, -overlap:] * (1 - weights)
|
333 |
+
latents_all = torch.cat([latents_all, latents[:, overlap:]], dim=1)
|
334 |
+
|
335 |
+
idx_start += stride
|
336 |
+
|
337 |
+
if track_time:
|
338 |
+
denoise_event.record()
|
339 |
+
torch.cuda.synchronize()
|
340 |
+
elapsed_time_ms = encode_event.elapsed_time(denoise_event)
|
341 |
+
print(f"Elapsed time for denoising video: {elapsed_time_ms} ms")
|
342 |
+
|
343 |
+
if not output_type == "latent":
|
344 |
+
# cast back to fp16 if needed
|
345 |
+
if needs_upcasting:
|
346 |
+
self.vae.to(dtype=torch.float16)
|
347 |
+
frames = self.decode_latents(latents_all, num_frames, decode_chunk_size)
|
348 |
+
|
349 |
+
if track_time:
|
350 |
+
decode_event.record()
|
351 |
+
torch.cuda.synchronize()
|
352 |
+
elapsed_time_ms = denoise_event.elapsed_time(decode_event)
|
353 |
+
print(f"Elapsed time for decoding video: {elapsed_time_ms} ms")
|
354 |
+
|
355 |
+
frames = self.video_processor.postprocess_video(
|
356 |
+
video=frames, output_type=output_type
|
357 |
+
)
|
358 |
+
else:
|
359 |
+
frames = latents_all
|
360 |
+
|
361 |
+
self.maybe_free_model_hooks()
|
362 |
+
|
363 |
+
if not return_dict:
|
364 |
+
return frames
|
365 |
+
|
366 |
+
return StableVideoDiffusionPipelineOutput(frames=frames)
|
depthcrafter/unet.py
ADDED
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Union, Tuple
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from diffusers import UNetSpatioTemporalConditionModel
|
5 |
+
from diffusers.models.unets.unet_spatio_temporal_condition import UNetSpatioTemporalConditionOutput
|
6 |
+
|
7 |
+
|
8 |
+
class DiffusersUNetSpatioTemporalConditionModelDepthCrafter(
|
9 |
+
UNetSpatioTemporalConditionModel
|
10 |
+
):
|
11 |
+
|
12 |
+
def forward(
|
13 |
+
self,
|
14 |
+
sample: torch.Tensor,
|
15 |
+
timestep: Union[torch.Tensor, float, int],
|
16 |
+
encoder_hidden_states: torch.Tensor,
|
17 |
+
added_time_ids: torch.Tensor,
|
18 |
+
return_dict: bool = True,
|
19 |
+
) -> Union[UNetSpatioTemporalConditionOutput, Tuple]:
|
20 |
+
|
21 |
+
# 1. time
|
22 |
+
timesteps = timestep
|
23 |
+
if not torch.is_tensor(timesteps):
|
24 |
+
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
25 |
+
# This would be a good case for the `match` statement (Python 3.10+)
|
26 |
+
is_mps = sample.device.type == "mps"
|
27 |
+
if isinstance(timestep, float):
|
28 |
+
dtype = torch.float32 if is_mps else torch.float64
|
29 |
+
else:
|
30 |
+
dtype = torch.int32 if is_mps else torch.int64
|
31 |
+
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
32 |
+
elif len(timesteps.shape) == 0:
|
33 |
+
timesteps = timesteps[None].to(sample.device)
|
34 |
+
|
35 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
36 |
+
batch_size, num_frames = sample.shape[:2]
|
37 |
+
timesteps = timesteps.expand(batch_size)
|
38 |
+
|
39 |
+
t_emb = self.time_proj(timesteps)
|
40 |
+
|
41 |
+
# `Timesteps` does not contain any weights and will always return f32 tensors
|
42 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
43 |
+
# there might be better ways to encapsulate this.
|
44 |
+
t_emb = t_emb.to(dtype=self.conv_in.weight.dtype)
|
45 |
+
|
46 |
+
emb = self.time_embedding(t_emb) # [batch_size * num_frames, channels]
|
47 |
+
|
48 |
+
time_embeds = self.add_time_proj(added_time_ids.flatten())
|
49 |
+
time_embeds = time_embeds.reshape((batch_size, -1))
|
50 |
+
time_embeds = time_embeds.to(emb.dtype)
|
51 |
+
aug_emb = self.add_embedding(time_embeds)
|
52 |
+
emb = emb + aug_emb
|
53 |
+
|
54 |
+
# Flatten the batch and frames dimensions
|
55 |
+
# sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]
|
56 |
+
sample = sample.flatten(0, 1)
|
57 |
+
# Repeat the embeddings num_video_frames times
|
58 |
+
# emb: [batch, channels] -> [batch * frames, channels]
|
59 |
+
emb = emb.repeat_interleave(num_frames, dim=0)
|
60 |
+
# encoder_hidden_states: [batch, frames, channels] -> [batch * frames, 1, channels]
|
61 |
+
encoder_hidden_states = encoder_hidden_states.flatten(0, 1).unsqueeze(1)
|
62 |
+
|
63 |
+
# 2. pre-process
|
64 |
+
sample = sample.to(dtype=self.conv_in.weight.dtype)
|
65 |
+
assert sample.dtype == self.conv_in.weight.dtype, (
|
66 |
+
f"sample.dtype: {sample.dtype}, "
|
67 |
+
f"self.conv_in.weight.dtype: {self.conv_in.weight.dtype}"
|
68 |
+
)
|
69 |
+
sample = self.conv_in(sample)
|
70 |
+
|
71 |
+
image_only_indicator = torch.zeros(
|
72 |
+
batch_size, num_frames, dtype=sample.dtype, device=sample.device
|
73 |
+
)
|
74 |
+
|
75 |
+
down_block_res_samples = (sample,)
|
76 |
+
for downsample_block in self.down_blocks:
|
77 |
+
if (
|
78 |
+
hasattr(downsample_block, "has_cross_attention")
|
79 |
+
and downsample_block.has_cross_attention
|
80 |
+
):
|
81 |
+
sample, res_samples = downsample_block(
|
82 |
+
hidden_states=sample,
|
83 |
+
temb=emb,
|
84 |
+
encoder_hidden_states=encoder_hidden_states,
|
85 |
+
image_only_indicator=image_only_indicator,
|
86 |
+
)
|
87 |
+
|
88 |
+
else:
|
89 |
+
sample, res_samples = downsample_block(
|
90 |
+
hidden_states=sample,
|
91 |
+
temb=emb,
|
92 |
+
image_only_indicator=image_only_indicator,
|
93 |
+
)
|
94 |
+
|
95 |
+
down_block_res_samples += res_samples
|
96 |
+
|
97 |
+
# 4. mid
|
98 |
+
sample = self.mid_block(
|
99 |
+
hidden_states=sample,
|
100 |
+
temb=emb,
|
101 |
+
encoder_hidden_states=encoder_hidden_states,
|
102 |
+
image_only_indicator=image_only_indicator,
|
103 |
+
)
|
104 |
+
|
105 |
+
# 5. up
|
106 |
+
for i, upsample_block in enumerate(self.up_blocks):
|
107 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
108 |
+
down_block_res_samples = down_block_res_samples[
|
109 |
+
: -len(upsample_block.resnets)
|
110 |
+
]
|
111 |
+
|
112 |
+
if (
|
113 |
+
hasattr(upsample_block, "has_cross_attention")
|
114 |
+
and upsample_block.has_cross_attention
|
115 |
+
):
|
116 |
+
sample = upsample_block(
|
117 |
+
hidden_states=sample,
|
118 |
+
res_hidden_states_tuple=res_samples,
|
119 |
+
temb=emb,
|
120 |
+
encoder_hidden_states=encoder_hidden_states,
|
121 |
+
image_only_indicator=image_only_indicator,
|
122 |
+
)
|
123 |
+
else:
|
124 |
+
sample = upsample_block(
|
125 |
+
hidden_states=sample,
|
126 |
+
res_hidden_states_tuple=res_samples,
|
127 |
+
temb=emb,
|
128 |
+
image_only_indicator=image_only_indicator,
|
129 |
+
)
|
130 |
+
|
131 |
+
# 6. post-process
|
132 |
+
sample = self.conv_norm_out(sample)
|
133 |
+
sample = self.conv_act(sample)
|
134 |
+
sample = self.conv_out(sample)
|
135 |
+
|
136 |
+
# 7. Reshape back to original shape
|
137 |
+
sample = sample.reshape(batch_size, num_frames, *sample.shape[1:])
|
138 |
+
|
139 |
+
if not return_dict:
|
140 |
+
return (sample,)
|
141 |
+
|
142 |
+
return UNetSpatioTemporalConditionOutput(sample=sample)
|
depthcrafter/utils.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import cv2
|
3 |
+
import matplotlib.cm as cm
|
4 |
+
import torch
|
5 |
+
|
6 |
+
|
7 |
+
def read_video_frames(video_path, process_length, target_fps, max_res):
|
8 |
+
# a simple function to read video frames
|
9 |
+
cap = cv2.VideoCapture(video_path)
|
10 |
+
original_fps = cap.get(cv2.CAP_PROP_FPS)
|
11 |
+
original_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
12 |
+
original_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
13 |
+
# round the height and width to the nearest multiple of 64
|
14 |
+
height = round(original_height / 64) * 64
|
15 |
+
width = round(original_width / 64) * 64
|
16 |
+
|
17 |
+
# resize the video if the height or width is larger than max_res
|
18 |
+
if max(height, width) > max_res:
|
19 |
+
scale = max_res / max(original_height, original_width)
|
20 |
+
height = round(original_height * scale / 64) * 64
|
21 |
+
width = round(original_width * scale / 64) * 64
|
22 |
+
|
23 |
+
if target_fps < 0:
|
24 |
+
target_fps = original_fps
|
25 |
+
|
26 |
+
stride = max(round(original_fps / target_fps), 1)
|
27 |
+
|
28 |
+
frames = []
|
29 |
+
frame_count = 0
|
30 |
+
while cap.isOpened():
|
31 |
+
ret, frame = cap.read()
|
32 |
+
if not ret or (process_length > 0 and frame_count >= process_length):
|
33 |
+
break
|
34 |
+
if frame_count % stride == 0:
|
35 |
+
frame = cv2.resize(frame, (width, height))
|
36 |
+
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) # Convert BGR to RGB
|
37 |
+
frames.append(frame.astype("float32") / 255.0)
|
38 |
+
frame_count += 1
|
39 |
+
cap.release()
|
40 |
+
|
41 |
+
frames = np.array(frames)
|
42 |
+
return frames, target_fps
|
43 |
+
|
44 |
+
|
45 |
+
def save_video(
|
46 |
+
video_frames,
|
47 |
+
output_video_path,
|
48 |
+
fps: int = 15,
|
49 |
+
) -> str:
|
50 |
+
# a simple function to save video frames
|
51 |
+
height, width = video_frames[0].shape[:2]
|
52 |
+
is_color = video_frames[0].ndim == 3
|
53 |
+
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
|
54 |
+
video_writer = cv2.VideoWriter(
|
55 |
+
output_video_path, fourcc, fps, (width, height), isColor=is_color
|
56 |
+
)
|
57 |
+
|
58 |
+
for frame in video_frames:
|
59 |
+
frame = (frame * 255).astype(np.uint8)
|
60 |
+
if is_color:
|
61 |
+
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
|
62 |
+
video_writer.write(frame)
|
63 |
+
|
64 |
+
video_writer.release()
|
65 |
+
return output_video_path
|
66 |
+
|
67 |
+
|
68 |
+
class ColorMapper:
|
69 |
+
# a color mapper to map depth values to a certain colormap
|
70 |
+
def __init__(self, colormap: str = "inferno"):
|
71 |
+
self.colormap = torch.tensor(cm.get_cmap(colormap).colors)
|
72 |
+
|
73 |
+
def apply(self, image: torch.Tensor, v_min=None, v_max=None):
|
74 |
+
# assert len(image.shape) == 2
|
75 |
+
if v_min is None:
|
76 |
+
v_min = image.min()
|
77 |
+
if v_max is None:
|
78 |
+
v_max = image.max()
|
79 |
+
image = (image - v_min) / (v_max - v_min)
|
80 |
+
image = (image * 255).long()
|
81 |
+
image = self.colormap[image]
|
82 |
+
return image
|
83 |
+
|
84 |
+
|
85 |
+
def vis_sequence_depth(depths: np.ndarray, v_min=None, v_max=None):
|
86 |
+
visualizer = ColorMapper()
|
87 |
+
if v_min is None:
|
88 |
+
v_min = depths.min()
|
89 |
+
if v_max is None:
|
90 |
+
v_max = depths.max()
|
91 |
+
res = visualizer.apply(torch.tensor(depths), v_min=v_min, v_max=v_max).numpy()
|
92 |
+
return res
|
examples/example_01.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:afb78decc210225793b20d5bca5b13da07c97233e6fabea44bf02eba8a52bdaf
|
3 |
+
size 14393250
|
requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch==2.3.0+cu117
|
2 |
+
diffusers==0.29.1
|
3 |
+
numpy==1.26.4
|
4 |
+
matplotlib==3.8.4
|
5 |
+
opencv-python==4.8.1.78
|
run.py
ADDED
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gc
|
2 |
+
import os
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import argparse
|
6 |
+
from diffusers.training_utils import set_seed
|
7 |
+
|
8 |
+
from depthcrafter.depth_crafter_ppl import DepthCrafterPipeline
|
9 |
+
from depthcrafter.unet import DiffusersUNetSpatioTemporalConditionModelDepthCrafter
|
10 |
+
from depthcrafter.utils import vis_sequence_depth, save_video, read_video_frames
|
11 |
+
|
12 |
+
|
13 |
+
class DepthCrafterDemo:
|
14 |
+
def __init__(
|
15 |
+
self,
|
16 |
+
unet_path: str,
|
17 |
+
pre_train_path: str,
|
18 |
+
cpu_offload: str = "model",
|
19 |
+
):
|
20 |
+
unet = DiffusersUNetSpatioTemporalConditionModelDepthCrafter.from_pretrained(
|
21 |
+
unet_path,
|
22 |
+
subfolder="unet",
|
23 |
+
low_cpu_mem_usage=True,
|
24 |
+
torch_dtype=torch.float16,
|
25 |
+
)
|
26 |
+
# load weights of other components from the provided checkpoint
|
27 |
+
self.pipe = DepthCrafterPipeline.from_pretrained(
|
28 |
+
pre_train_path,
|
29 |
+
unet=unet,
|
30 |
+
torch_dtype=torch.float16,
|
31 |
+
variant="fp16",
|
32 |
+
)
|
33 |
+
|
34 |
+
# for saving memory, we can offload the model to CPU, or even run the model sequentially to save more memory
|
35 |
+
if cpu_offload is not None:
|
36 |
+
if cpu_offload == "sequential":
|
37 |
+
# This will slow, but save more memory
|
38 |
+
self.pipe.enable_sequential_cpu_offload()
|
39 |
+
elif cpu_offload == "model":
|
40 |
+
self.pipe.enable_model_cpu_offload()
|
41 |
+
else:
|
42 |
+
raise ValueError(f"Unknown cpu offload option: {cpu_offload}")
|
43 |
+
else:
|
44 |
+
self.pipe.to("cuda")
|
45 |
+
# enable attention slicing and xformers memory efficient attention
|
46 |
+
try:
|
47 |
+
self.pipe.enable_xformers_memory_efficient_attention()
|
48 |
+
except Exception as e:
|
49 |
+
print(e)
|
50 |
+
print("Xformers is not enabled")
|
51 |
+
self.pipe.enable_attention_slicing()
|
52 |
+
|
53 |
+
def infer(
|
54 |
+
self,
|
55 |
+
video: str,
|
56 |
+
num_denoising_steps: int,
|
57 |
+
guidance_scale: float,
|
58 |
+
save_folder: str = "./demo_output",
|
59 |
+
window_size: int = 110,
|
60 |
+
process_length: int = 195,
|
61 |
+
overlap: int = 25,
|
62 |
+
max_res: int = 1024,
|
63 |
+
target_fps: int = 15,
|
64 |
+
seed: int = 42,
|
65 |
+
track_time: bool = True,
|
66 |
+
save_npz: bool = False,
|
67 |
+
):
|
68 |
+
set_seed(seed)
|
69 |
+
|
70 |
+
frames, target_fps = read_video_frames(
|
71 |
+
video, process_length, target_fps, max_res
|
72 |
+
)
|
73 |
+
print(f"==> video name: {video}, frames shape: {frames.shape}")
|
74 |
+
|
75 |
+
# inference the depth map using the DepthCrafter pipeline
|
76 |
+
with torch.inference_mode():
|
77 |
+
res = self.pipe(
|
78 |
+
frames,
|
79 |
+
height=frames.shape[1],
|
80 |
+
width=frames.shape[2],
|
81 |
+
output_type="np",
|
82 |
+
guidance_scale=guidance_scale,
|
83 |
+
num_inference_steps=num_denoising_steps,
|
84 |
+
window_size=window_size,
|
85 |
+
overlap=overlap,
|
86 |
+
track_time=track_time,
|
87 |
+
).frames[0]
|
88 |
+
# convert the three-channel output to a single channel depth map
|
89 |
+
res = res.sum(-1) / res.shape[-1]
|
90 |
+
# normalize the depth map to [0, 1] across the whole video
|
91 |
+
res = (res - res.min()) / (res.max() - res.min())
|
92 |
+
# visualize the depth map and save the results
|
93 |
+
vis = vis_sequence_depth(res)
|
94 |
+
# save the depth map and visualization with the target FPS
|
95 |
+
save_path = os.path.join(
|
96 |
+
save_folder, os.path.splitext(os.path.basename(video))[0]
|
97 |
+
)
|
98 |
+
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
99 |
+
if save_npz:
|
100 |
+
np.savez_compressed(save_path + ".npz", depth=res)
|
101 |
+
save_video(res, save_path + "_depth.mp4", fps=target_fps)
|
102 |
+
save_video(vis, save_path + "_vis.mp4", fps=target_fps)
|
103 |
+
save_video(frames, save_path + "_input.mp4", fps=target_fps)
|
104 |
+
return [
|
105 |
+
save_path + "_input.mp4",
|
106 |
+
save_path + "_vis.mp4",
|
107 |
+
save_path + "_depth.mp4",
|
108 |
+
]
|
109 |
+
|
110 |
+
def run(
|
111 |
+
self,
|
112 |
+
input_video,
|
113 |
+
num_denoising_steps,
|
114 |
+
guidance_scale,
|
115 |
+
max_res=1024,
|
116 |
+
process_length=195,
|
117 |
+
):
|
118 |
+
res_path = self.infer(
|
119 |
+
input_video,
|
120 |
+
num_denoising_steps,
|
121 |
+
guidance_scale,
|
122 |
+
max_res=max_res,
|
123 |
+
process_length=process_length,
|
124 |
+
)
|
125 |
+
# clear the cache for the next video
|
126 |
+
gc.collect()
|
127 |
+
torch.cuda.empty_cache()
|
128 |
+
return res_path[:2]
|
129 |
+
|
130 |
+
|
131 |
+
if __name__ == "__main__":
|
132 |
+
# running configs
|
133 |
+
# the most important arguments for memory saving are `cpu_offload`, `enable_xformers`, `max_res`, and `window_size`
|
134 |
+
# the most important arguments for trade-off between quality and speed are
|
135 |
+
# `num_inference_steps`, `guidance_scale`, and `max_res`
|
136 |
+
parser = argparse.ArgumentParser(description="DepthCrafter")
|
137 |
+
parser.add_argument(
|
138 |
+
"--video-path", type=str, required=True, help="Path to the input video file(s)"
|
139 |
+
)
|
140 |
+
parser.add_argument(
|
141 |
+
"--save-folder",
|
142 |
+
type=str,
|
143 |
+
default="./demo_output",
|
144 |
+
help="Folder to save the output",
|
145 |
+
)
|
146 |
+
parser.add_argument(
|
147 |
+
"--unet-path",
|
148 |
+
type=str,
|
149 |
+
default="tencent/DepthCrafter",
|
150 |
+
help="Path to the UNet model",
|
151 |
+
)
|
152 |
+
parser.add_argument(
|
153 |
+
"--pre-train-path",
|
154 |
+
type=str,
|
155 |
+
default="stabilityai/stable-video-diffusion-img2vid-xt",
|
156 |
+
help="Path to the pre-trained model",
|
157 |
+
)
|
158 |
+
parser.add_argument(
|
159 |
+
"--process-length", type=int, default=195, help="Number of frames to process"
|
160 |
+
)
|
161 |
+
parser.add_argument(
|
162 |
+
"--cpu-offload",
|
163 |
+
type=str,
|
164 |
+
default="model",
|
165 |
+
choices=["model", "sequential", None],
|
166 |
+
help="CPU offload option",
|
167 |
+
)
|
168 |
+
parser.add_argument(
|
169 |
+
"--target-fps", type=int, default=15, help="Target FPS for the output video"
|
170 |
+
) # -1 for original fps
|
171 |
+
parser.add_argument("--seed", type=int, default=42, help="Random seed")
|
172 |
+
parser.add_argument(
|
173 |
+
"--num-inference-steps", type=int, default=25, help="Number of inference steps"
|
174 |
+
)
|
175 |
+
parser.add_argument(
|
176 |
+
"--guidance-scale", type=float, default=1.2, help="Guidance scale"
|
177 |
+
)
|
178 |
+
parser.add_argument("--window-size", type=int, default=110, help="Window size")
|
179 |
+
parser.add_argument("--overlap", type=int, default=25, help="Overlap size")
|
180 |
+
parser.add_argument("--max-res", type=int, default=1024, help="Maximum resolution")
|
181 |
+
parser.add_argument("--save_npz", type=bool, default=True, help="Save npz file")
|
182 |
+
parser.add_argument("--track_time", type=bool, default=False, help="Track time")
|
183 |
+
|
184 |
+
args = parser.parse_args()
|
185 |
+
|
186 |
+
depthcrafter_demo = DepthCrafterDemo(
|
187 |
+
unet_path=args.unet_path,
|
188 |
+
pre_train_path=args.pre_train_path,
|
189 |
+
cpu_offload=args.cpu_offload,
|
190 |
+
)
|
191 |
+
# process the videos, the video paths are separated by comma
|
192 |
+
video_paths = args.video_path.split(",")
|
193 |
+
for video in video_paths:
|
194 |
+
depthcrafter_demo.infer(
|
195 |
+
video,
|
196 |
+
args.num_inference_steps,
|
197 |
+
args.guidance_scale,
|
198 |
+
save_folder=args.save_folder,
|
199 |
+
window_size=args.window_size,
|
200 |
+
process_length=args.process_length,
|
201 |
+
overlap=args.overlap,
|
202 |
+
max_res=args.max_res,
|
203 |
+
target_fps=args.target_fps,
|
204 |
+
seed=args.seed,
|
205 |
+
track_time=args.track_time,
|
206 |
+
save_npz=args.save_npz,
|
207 |
+
)
|
208 |
+
# clear the cache for the next video
|
209 |
+
gc.collect()
|
210 |
+
torch.cuda.empty_cache()
|