Spaces · Running on A10G

nick_93 committed bcec54e
Parent(s): 2fb7fdb

init

Files changed:
- .gitignore +166 -0
- LICENSE +21 -0
- app.py +103 -0
- depth/README.md +63 -0
- depth/configs/base_options.py +56 -0
- depth/configs/test_options.py +27 -0
- depth/configs/train_options.py +50 -0
- depth/inference.py +53 -0
- depth/models_depth/attractor.py +208 -0
- depth/models_depth/checkpoint.py +608 -0
- depth/models_depth/dist_layers.py +121 -0
- depth/models_depth/layers.py +36 -0
- depth/models_depth/localbins_layers.py +169 -0
- depth/models_depth/miniViT.py +45 -0
- depth/models_depth/model.py +666 -0
- depth/models_depth/model_vpd.py +252 -0
- depth/models_depth/optimizer.py +154 -0
- depth/requirements.txt +8 -0
- depth/test_img.jpg +0 -0
- depth/utils.py +525 -0
- depth/utils_depth/criterion.py +22 -0
- depth/utils_depth/logging.py +161 -0
- depth/utils_depth/metrics.py +79 -0
- depth/utils_depth/misc.py +73 -0
- depth/v1-inference.yaml +70 -0
- evp/__init__.py +1 -0
- evp/models.py +349 -0
- refer/README.md +78 -0
- refer/args.py +42 -0
- refer/inference.py +60 -0
- refer/models_refer/__init__.py +1 -0
- refer/models_refer/model.py +301 -0
- refer/requirements.txt +12 -0
- refer/transforms.py +126 -0
- refer/utils.py +222 -0
- refer/v1-inference.yaml +70 -0
.gitignore
ADDED
@@ -0,0 +1,166 @@
```gitignore
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

*.ckpt
*.pth
refer/refer/data/
depth/kitti_dataset/
depth/nyu_depth_v2/

# C extensions
.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
```
LICENSE
ADDED
@@ -0,0 +1,21 @@
```
MIT License

Copyright (c) 2023 Mykola Lavreniuk

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
```
app.py
ADDED
@@ -0,0 +1,103 @@
```python
import os
import sys

depth_directory = os.path.abspath(os.path.join(os.path.dirname(__file__), 'depth'))
sys.path.append(depth_directory)
os.chdir(depth_directory)

import cv2
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from collections import OrderedDict  # needed below to strip the 'module.' prefix
from depth.models_depth.model import EVPDepth
from depth.configs.train_options import TrainOptions
from depth.configs.test_options import TestOptions
import glob
import utils
import torchvision.transforms as transforms
from utils_depth.misc import colorize
from PIL import Image
import torch.nn.functional as F
import gradio as gr
import tempfile


css = """
#img-display-container {
    max-height: 50vh;
}
#img-display-input {
    max-height: 40vh;
}
#img-display-output {
    max-height: 40vh;
}
"""

def create_demo(model, device):
    gr.Markdown("### Depth Prediction demo")
    with gr.Row():
        input_image = gr.Image(label="Input Image", type='pil', elem_id='img-display-input')
        depth_image = gr.Image(label="Depth Map", elem_id='img-display-output')
    raw_file = gr.File(label="16-bit raw depth, multiplier:256")
    submit = gr.Button("Submit")

    def on_submit(image):
        transform = transforms.ToTensor()
        image = transform(image).unsqueeze(0).to(device)
        shape = image.shape
        # resize to 440x480, then pad 40 rows on top so the network sees 480x480
        image = torch.nn.functional.interpolate(image, (440, 480), mode='bilinear', align_corners=True)
        image = F.pad(image, (0, 0, 40, 0))
        with torch.no_grad():
            pred = model(image)['pred_d']

        pred = pred[:, :, 40:, :]  # crop the padded rows back off
        pred = torch.nn.functional.interpolate(pred, shape[2:], mode='bilinear', align_corners=True)
        pred_d_numpy = pred.squeeze().cpu().numpy()
        colored_depth, _, _ = colorize(pred_d_numpy, cmap='gray_r')

        tmp = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
        raw_depth = Image.fromarray((pred_d_numpy * 256).astype('uint16'))
        raw_depth.save(tmp.name)
        return [colored_depth, tmp.name]

    submit.click(on_submit, inputs=[input_image], outputs=[depth_image, raw_file])
    examples = gr.Examples(examples=["test_img.jpg"],
                           inputs=[input_image])


def main():
    opt = TestOptions().initialize()
    opt.add_argument('--img_path', type=str)
    args = opt.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = EVPDepth(args=args, caption_aggregation=True)
    cudnn.benchmark = True
    model.to(device)
    model_weight = torch.load(args.ckpt_dir)['model']
    if 'module' in next(iter(model_weight.items()))[0]:
        # checkpoint was saved from DataParallel/DDP; drop the 'module.' prefix
        model_weight = OrderedDict((k[7:], v) for k, v in model_weight.items())
    model.load_state_dict(model_weight, strict=False)
    model.eval()

    title = "# EVP"
    description = """Official demo for **EVP: Enhanced Visual Perception using Inverse Multi-Attentive Feature
    Refinement and Regularized Image-Text Alignment**.
    EVP is a deep learning model for metric depth estimation from a single image.
    Please refer to our [paper](https://arxiv.org/abs/2312.08548) or [github](https://github.com/Lavreniuk/EVP) for more details."""

    with gr.Blocks() as demo:
        gr.Markdown(title)
        gr.Markdown(description)
        with gr.Tab("Depth Prediction"):
            create_demo(model, device)
        gr.HTML('''<br><br><br><center>You can duplicate this Space to skip the queue:<a href="https://huggingface.co/spaces/shariqfarooq/ZoeDepth?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a><br>
        <p><img src="https://visitor-badge.glitch.me/badge?page_id=shariqfarooq.zoedepth_demo_hf" alt="visitors"></p></center>''')

    demo.queue().launch(share=True)


if __name__ == '__main__':
    main()
```
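The raw depth file that `on_submit` offers for download ("16-bit raw depth, multiplier:256") is just metric depth scaled by 256 and stored as `uint16`, so a consumer recovers meters by dividing by 256. A minimal round-trip sketch of that encoding (toy array, no model involved):

```python
import numpy as np
from PIL import Image

depth_m = np.array([[0.5, 1.25], [4.0, 9.99]], dtype=np.float32)  # depths in meters
raw = (depth_m * 256).astype('uint16')                            # as saved by on_submit
Image.fromarray(raw).save('depth16.png')                          # 16-bit PNG

decoded = np.asarray(Image.open('depth16.png')).astype(np.float32) / 256.0
print(np.max(np.abs(decoded - depth_m)))  # quantization error stays below 1/256 m
```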
depth/README.md
ADDED
@@ -0,0 +1,63 @@
````markdown
# Depth Estimation
## Getting Started

1. Install the [mmcv-full](https://github.com/open-mmlab/mmcv) library and some required packages.

```bash
pip install openmim
mim install mmcv-full
pip install -r requirements.txt
```

2. Prepare NYUDepthV2 datasets following [GLPDepth](https://github.com/vinvino02/GLPDepth) and [BTS](https://github.com/cleinc/bts/tree/master).

```
mkdir nyu_depth_v2
wget http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/nyu_depth_v2_labeled.mat
python extract_official_train_test_set_from_mat.py nyu_depth_v2_labeled.mat splits.mat ./nyu_depth_v2/official_splits/
```

Download sync.zip provided by the authors of BTS from this [url](https://drive.google.com/file/d/1AysroWpfISmm-yRFGBgFTrLy6FjQwvwP/view) and unzip it in the `./nyu_depth_v2` folder.

Your dataset directory should be:

```
│nyu_depth_v2/
├──official_splits/
│  ├── test
│  ├── train
├──sync/
```

## Results and Fine-tuned Models

EVP obtains 0.224 RMSE on the NYUv2 depth estimation benchmark, establishing a new state of the art.

|         | RMSE  | d1    | d2    | d3    | REL   | log_10 |
|---------|-------|-------|-------|-------|-------|--------|
| **EVP** | 0.224 | 0.976 | 0.997 | 0.999 | 0.061 | 0.027  |

EVP obtains 0.048 REL and 0.136 SqREL on the KITTI depth estimation benchmark, establishing a new state of the art.

|         | REL   | SqREL | RMSE  | RMSE log | d1    | d2    | d3    |
|---------|-------|-------|-------|----------|-------|-------|-------|
| **EVP** | 0.048 | 0.136 | 2.015 | 0.073    | 0.980 | 0.998 | 1.000 |

## Training

Run the following instruction to train the EVP-Depth model.

```
bash train.sh <LOG_DIR>
```

## Evaluation
Command format:
```
bash test.sh <CHECKPOINT_PATH>
```

## Custom inference
```
PYTHONPATH="../":$PYTHONPATH python inference.py --img_path test_img.jpg --ckpt_dir nyu.ckpt
```
````
depth/configs/base_options.py
ADDED
@@ -0,0 +1,56 @@
```python
# ------------------------------------------------------------------------------
# The code is from GLPDepth (https://github.com/vinvino02/GLPDepth).
# For non-commercial purpose only (research, evaluation etc).
# Modified by Zigang Geng ([email protected]).
# ------------------------------------------------------------------------------

import argparse


def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')


class BaseOptions():
    def __init__(self):
        pass

    def initialize(self):
        parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        # base configs
        parser.add_argument('--resume_from', type=str, default='')
        parser.add_argument('--exp_name', type=str, default='')
        parser.add_argument('--gpu_or_cpu', type=str, default='gpu')
        parser.add_argument('--data_path', type=str, default='/data/ssd1/')
        parser.add_argument('--dataset', type=str, default='nyudepthv2',
                            choices=['nyudepthv2', 'kitti', 'imagepath'])
        parser.add_argument('--batch_size', type=int, default=8)
        parser.add_argument('--workers', type=int, default=8)

        # depth configs
        parser.add_argument('--max_depth', type=float, default=10.0)
        parser.add_argument('--max_depth_eval', type=float, default=10.0)
        parser.add_argument('--min_depth_eval', type=float, default=1e-3)
        parser.add_argument('--do_kb_crop', type=int, default=1)
        parser.add_argument('--kitti_crop', type=str, default=None,
                            choices=['garg_crop', 'eigen_crop'])

        parser.add_argument('--pretrained', type=str, default='')
        parser.add_argument('--drop_path_rate', type=float, default=0.3)
        parser.add_argument('--use_checkpoint', type=str2bool, default='False')
        parser.add_argument('--num_deconv', type=int, default=3)
        parser.add_argument('--num_filters', nargs='+', type=int, default=[32, 32, 32])
        parser.add_argument('--deconv_kernels', nargs='+', type=int, default=[2, 2, 2])

        parser.add_argument('--shift_window_test', action='store_true')
        parser.add_argument('--shift_size', type=int, default=2)
        parser.add_argument('--flip_test', action='store_true')

        return parser
```
depth/configs/test_options.py
ADDED
@@ -0,0 +1,27 @@
```python
# ------------------------------------------------------------------------------
# The code is from GLPDepth (https://github.com/vinvino02/GLPDepth).
# For non-commercial purpose only (research, evaluation etc).
# ------------------------------------------------------------------------------

from configs.base_options import BaseOptions


class TestOptions(BaseOptions):
    def initialize(self):
        parser = BaseOptions.initialize(self)

        # experiment configs
        parser.add_argument('--ckpt_dir', type=str,
                            default='./ckpt/best_model_nyu.ckpt',
                            help='load ckpt path')
        parser.add_argument('--result_dir', type=str, default='./results',
                            help='save result images into result_dir/exp_name')
        parser.add_argument('--crop_h', type=int, default=448)
        parser.add_argument('--crop_w', type=int, default=576)

        parser.add_argument('--save_eval_pngs', action='store_true',
                            help='save result image in evaluation form')
        parser.add_argument('--save_visualize', action='store_true',
                            help='save result image in visualized form')
        return parser
```
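Note that `initialize()` returns the bare `argparse.ArgumentParser` rather than parsed options, which is why callers such as `app.py` and `inference.py` add extra arguments before calling `parse_args()`. A minimal sketch of that pattern, run from the `depth/` directory (the argument values here are placeholders, not repo defaults):

```python
# Usage sketch of the options classes above; values passed to parse_args are examples.
from configs.test_options import TestOptions

parser = TestOptions().initialize()           # returns an argparse.ArgumentParser
parser.add_argument('--img_path', type=str)   # callers may extend it before parsing
args = parser.parse_args(['--ckpt_dir', 'nyu.ckpt', '--img_path', 'test_img.jpg'])
print(args.ckpt_dir, args.max_depth)          # max_depth default inherited from BaseOptions
```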
depth/configs/train_options.py
ADDED
@@ -0,0 +1,50 @@
```python
# ------------------------------------------------------------------------------
# The code is from GLPDepth (https://github.com/vinvino02/GLPDepth).
# For non-commercial purpose only (research, evaluation etc).
# Modified by Zigang Geng ([email protected]).
# ------------------------------------------------------------------------------

from configs.base_options import BaseOptions
import argparse


def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')


class TrainOptions(BaseOptions):
    def initialize(self):
        parser = BaseOptions.initialize(self)

        # experiment configs
        parser.add_argument('--epochs', type=int, default=25)
        parser.add_argument('--max_lr', type=float, default=5e-4)
        parser.add_argument('--min_lr', type=float, default=3e-5)
        parser.add_argument('--weight_decay', type=float, default=5e-2)
        parser.add_argument('--layer_decay', type=float, default=0.9)

        parser.add_argument('--crop_h', type=int, default=448)
        parser.add_argument('--crop_w', type=int, default=576)
        parser.add_argument('--log_dir', type=str, default='./logs')

        # logging options
        parser.add_argument('--val_freq', type=int, default=1)
        parser.add_argument('--pro_bar', type=str2bool, default='False')
        parser.add_argument('--save_freq', type=int, default=1)
        parser.add_argument('--print_freq', type=int, default=100)
        parser.add_argument('--save_model', action='store_true')
        parser.add_argument(
            '--resume-from', help='the checkpoint file to resume from')
        parser.add_argument('--auto_resume', action='store_true')
        parser.add_argument('--save_result', action='store_true')

        return parser
```
depth/inference.py
ADDED
@@ -0,0 +1,53 @@
```python
import os
import cv2
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from collections import OrderedDict  # needed below to strip the 'module.' prefix
from models_depth.model import EVPDepth
from configs.train_options import TrainOptions
from configs.test_options import TestOptions
import glob
import utils
import torchvision.transforms as transforms
from utils_depth.misc import colorize
from PIL import Image
import torch.nn.functional as F


def main():
    opt = TestOptions().initialize()
    opt.add_argument('--img_path', type=str)
    args = opt.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = EVPDepth(args=args, caption_aggregation=True)
    cudnn.benchmark = True
    model.to(device)
    model_weight = torch.load(args.ckpt_dir)['model']
    if 'module' in next(iter(model_weight.items()))[0]:
        # checkpoint was saved from DataParallel/DDP; drop the 'module.' prefix
        model_weight = OrderedDict((k[7:], v) for k, v in model_weight.items())
    model.load_state_dict(model_weight, strict=False)
    model.eval()

    img_path = args.img_path
    image = cv2.imread(img_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    transform = transforms.ToTensor()
    image = transform(image).unsqueeze(0).to(device)
    shape = image.shape
    # resize to 440x480, then pad 40 rows on top so the network sees 480x480
    image = torch.nn.functional.interpolate(image, (440, 480), mode='bilinear', align_corners=True)
    image = F.pad(image, (0, 0, 40, 0))

    with torch.no_grad():
        pred = model(image)['pred_d']

    pred = pred[:, :, 40:, :]  # crop the padded rows back off
    pred = torch.nn.functional.interpolate(pred, shape[2:], mode='bilinear', align_corners=True)
    pred_d_numpy = pred.squeeze().cpu().numpy()
    pred_d_color, _, _ = colorize(pred_d_numpy, cmap='gray_r')
    Image.fromarray(pred_d_color).save('res.png')

    return 0


if __name__ == '__main__':
    main()
```
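Both `inference.py` and `app.py` share the same shape bookkeeping: resize to 440x480, pad 40 rows on top (an `F.pad` 4-tuple is `(left, right, top, bottom)`), run the model at 480x480, then undo both steps. A standalone sketch of that round-trip with toy sizes, no model involved:

```python
import torch
import torch.nn.functional as F

x = torch.rand(1, 3, 600, 800)             # any input resolution
x = F.interpolate(x, (440, 480), mode='bilinear', align_corners=True)
x = F.pad(x, (0, 0, 40, 0))                # pad 40 rows on top -> 1x3x480x480
assert x.shape == (1, 3, 480, 480)

pred = x[:, :, 40:, :]                     # crop the padded rows -> 1x3x440x480
pred = F.interpolate(pred, (600, 800), mode='bilinear', align_corners=True)
assert pred.shape == (1, 3, 600, 800)      # back at the original resolution
```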
depth/models_depth/attractor.py
ADDED
@@ -0,0 +1,208 @@
```python
# MIT License

# Copyright (c) 2022 Intelligent Systems Lab Org

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# File author: Shariq Farooq Bhat

import torch
import torch.nn as nn


@torch.jit.script
def exp_attractor(dx, alpha: float = 300, gamma: int = 2):
    """Exponential attractor: dc = exp(-alpha*|dx|^gamma) * dx, where dx = a - c, a = attractor point, c = bin center, dc = shift in bin center.

    Args:
        dx (torch.Tensor): The difference tensor dx = Ai - Cj, where Ai is the attractor point and Cj is the bin center.
        alpha (float, optional): Proportional Attractor strength. Determines the absolute strength. Lower alpha = greater attraction. Defaults to 300.
        gamma (int, optional): Exponential Attractor strength. Determines the "region of influence" and indirectly number of bin centers affected. Lower gamma = farther reach. Defaults to 2.

    Returns:
        torch.Tensor : Delta shifts - dc; New bin centers = Old bin centers + dc
    """
    return torch.exp(-alpha*(torch.abs(dx)**gamma)) * (dx)


@torch.jit.script
def inv_attractor(dx, alpha: float = 300, gamma: int = 2):
    """Inverse attractor: dc = dx / (1 + alpha*dx^gamma), where dx = a - c, a = attractor point, c = bin center, dc = shift in bin center
    This is the default one according to the accompanying paper.

    Args:
        dx (torch.Tensor): The difference tensor dx = Ai - Cj, where Ai is the attractor point and Cj is the bin center.
        alpha (float, optional): Proportional Attractor strength. Determines the absolute strength. Lower alpha = greater attraction. Defaults to 300.
        gamma (int, optional): Exponential Attractor strength. Determines the "region of influence" and indirectly number of bin centers affected. Lower gamma = farther reach. Defaults to 2.

    Returns:
        torch.Tensor: Delta shifts - dc; New bin centers = Old bin centers + dc
    """
    return dx.div(1+alpha*dx.pow(gamma))


class AttractorLayer(nn.Module):
    def __init__(self, in_features, n_bins, n_attractors=16, mlp_dim=128, min_depth=1e-3, max_depth=10,
                 alpha=300, gamma=2, kind='sum', attractor_type='exp', memory_efficient=False):
        """
        Attractor layer for bin centers. Bin centers are bounded on the interval (min_depth, max_depth)
        """
        super().__init__()

        self.n_attractors = n_attractors
        self.n_bins = n_bins
        self.min_depth = min_depth
        self.max_depth = max_depth
        self.alpha = alpha
        self.gamma = gamma
        self.kind = kind
        self.attractor_type = attractor_type
        self.memory_efficient = memory_efficient

        self._net = nn.Sequential(
            nn.Conv2d(in_features, mlp_dim, 1, 1, 0),
            nn.ReLU(inplace=True),
            nn.Conv2d(mlp_dim, n_attractors*2, 1, 1, 0),  # x2 for linear norm
            nn.ReLU(inplace=True)
        )

    def forward(self, x, b_prev, prev_b_embedding=None, interpolate=True, is_for_query=False):
        """
        Args:
            x (torch.Tensor) : feature block; shape - n, c, h, w
            b_prev (torch.Tensor) : previous bin centers normed; shape - n, prev_nbins, h, w

        Returns:
            tuple(torch.Tensor,torch.Tensor) : new bin centers normed and scaled; shape - n, nbins, h, w
        """
        if prev_b_embedding is not None:
            if interpolate:
                prev_b_embedding = nn.functional.interpolate(
                    prev_b_embedding, x.shape[-2:], mode='bilinear', align_corners=True)
            x = x + prev_b_embedding

        A = self._net(x)
        eps = 1e-3
        A = A + eps
        n, c, h, w = A.shape
        A = A.view(n, self.n_attractors, 2, h, w)
        A_normed = A / A.sum(dim=2, keepdim=True)  # n, a, 2, h, w
        A_normed = A[:, :, 0, ...]  # n, na, h, w

        b_prev = nn.functional.interpolate(
            b_prev, (h, w), mode='bilinear', align_corners=True)
        b_centers = b_prev

        if self.attractor_type == 'exp':
            dist = exp_attractor
        else:
            dist = inv_attractor

        if not self.memory_efficient:
            func = {'mean': torch.mean, 'sum': torch.sum}[self.kind]
            # .shape N, nbins, h, w
            delta_c = func(dist(A_normed.unsqueeze(
                2) - b_centers.unsqueeze(1)), dim=1)
        else:
            delta_c = torch.zeros_like(b_centers, device=b_centers.device)
            for i in range(self.n_attractors):
                # .shape N, nbins, h, w
                delta_c += dist(A_normed[:, i, ...].unsqueeze(1) - b_centers)

            if self.kind == 'mean':
                delta_c = delta_c / self.n_attractors

        b_new_centers = b_centers + delta_c
        B_centers = (self.max_depth - self.min_depth) * \
            b_new_centers + self.min_depth
        B_centers, _ = torch.sort(B_centers, dim=1)
        B_centers = torch.clip(B_centers, self.min_depth, self.max_depth)
        return b_new_centers, B_centers


class AttractorLayerUnnormed(nn.Module):
    def __init__(self, in_features, n_bins, n_attractors=16, mlp_dim=128, min_depth=1e-3, max_depth=10,
                 alpha=300, gamma=2, kind='sum', attractor_type='exp', memory_efficient=False):
        """
        Attractor layer for bin centers. Bin centers are unbounded
        """
        super().__init__()

        self.n_attractors = n_attractors
        self.n_bins = n_bins
        self.min_depth = min_depth
        self.max_depth = max_depth
        self.alpha = alpha
        self.gamma = gamma
        self.kind = kind
        self.attractor_type = attractor_type
        self.memory_efficient = memory_efficient

        self._net = nn.Sequential(
            nn.Conv2d(in_features, mlp_dim, 1, 1, 0),
            nn.ReLU(inplace=True),
            nn.Conv2d(mlp_dim, n_attractors, 1, 1, 0),
            nn.Softplus()
        )

    def forward(self, x, b_prev, prev_b_embedding=None, interpolate=True, is_for_query=False):
        """
        Args:
            x (torch.Tensor) : feature block; shape - n, c, h, w
            b_prev (torch.Tensor) : previous bin centers normed; shape - n, prev_nbins, h, w

        Returns:
            tuple(torch.Tensor,torch.Tensor) : new bin centers unbounded; shape - n, nbins, h, w. Two outputs just to keep the API consistent with the normed version
        """
        if prev_b_embedding is not None:
            if interpolate:
                prev_b_embedding = nn.functional.interpolate(
                    prev_b_embedding, x.shape[-2:], mode='bilinear', align_corners=True)
            x = x + prev_b_embedding

        A = self._net(x)
        n, c, h, w = A.shape

        b_prev = nn.functional.interpolate(
            b_prev, (h, w), mode='bilinear', align_corners=True)
        b_centers = b_prev

        if self.attractor_type == 'exp':
            dist = exp_attractor
        else:
            dist = inv_attractor

        if not self.memory_efficient:
            func = {'mean': torch.mean, 'sum': torch.sum}[self.kind]
            # .shape N, nbins, h, w
            delta_c = func(
                dist(A.unsqueeze(2) - b_centers.unsqueeze(1)), dim=1)
        else:
            delta_c = torch.zeros_like(b_centers, device=b_centers.device)
            for i in range(self.n_attractors):
                delta_c += dist(A[:, i, ...].unsqueeze(1) -
                                b_centers)  # .shape N, nbins, h, w

            if self.kind == 'mean':
                delta_c = delta_c / self.n_attractors

        b_new_centers = b_centers + delta_c
        B_centers = b_new_centers

        return b_new_centers, B_centers
```
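The two attractor functions above implement `dc = exp(-alpha*|dx|^gamma) * dx` and `dc = dx / (1 + alpha*dx^gamma)`: each bin center is pulled toward an attractor point, with a pull that decays as the distance `dx` grows. A tiny standalone sketch showing how `inv_attractor` nudges a set of bin centers toward a single attractor (toy values, no layers involved; import path assumes running from the `depth/` directory):

```python
import torch
from models_depth.attractor import inv_attractor

# five bin centers in normalized (0, 1) space and one attractor at 0.5
c = torch.tensor([0.1, 0.3, 0.5, 0.7, 0.9])
a = torch.tensor(0.5)

dc = inv_attractor(a - c)  # dc = dx / (1 + alpha * dx^gamma), with dx = a - c
new_c = c + dc             # every center moves toward the attractor...
print(new_c)               # ...and the center already at 0.5 does not move at all
```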
depth/models_depth/checkpoint.py
ADDED
@@ -0,0 +1,608 @@
```python
# ------------------------------------------------------------------------------
# Copyright (c) Microsoft
# Licensed under the MIT License.
# The code is from Swin Transformer.
# (https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/blob/master/mmcv_custom/checkpoint.py)
# ------------------------------------------------------------------------------

import io
import os
import os.path as osp
import pkgutil
import time
import warnings
import numpy as np
from collections import OrderedDict
from importlib import import_module
from tempfile import TemporaryDirectory
from scipy import interpolate

import torch
import torchvision
import torch.distributed as dist
from torch.optim import Optimizer
from torch.utils import model_zoo
from torch.nn import functional as F

import mmcv
from mmcv.fileio import FileClient
from mmcv.fileio import load as load_file
from mmcv.parallel import is_module_wrapper
from mmcv.utils import mkdir_or_exist
from mmcv.runner import get_dist_info
from mmcv.utils import get_logger

import logging

# NOTE: constants assumed from the upstream mmcv_custom/checkpoint.py;
# _get_mmcv_home() below references them but this file did not define them.
ENV_MMCV_HOME = 'MMCV_HOME'
ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
DEFAULT_CACHE_DIR = '~/.cache'


def get_root_logger(log_file=None, log_level=logging.INFO):
    """Get the root logger.

    The logger will be initialized if it has not been initialized. By default a
    StreamHandler will be added. If `log_file` is specified, a FileHandler will
    also be added. The name of the root logger is the top-level package name,
    e.g., "mmseg".

    Args:
        log_file (str | None): The log filename. If specified, a FileHandler
            will be added to the root logger.
        log_level (int): The root logger level. Note that only the process of
            rank 0 is affected, while other processes will set the level to
            "Error" and be silent most of the time.

    Returns:
        logging.Logger: The root logger.
    """

    logger = get_logger(name='mmpose', log_file=log_file, log_level=log_level)

    return logger


def _get_mmcv_home():
    mmcv_home = os.path.expanduser(
        os.getenv(
            ENV_MMCV_HOME,
            os.path.join(
                os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'mmcv')))

    mkdir_or_exist(mmcv_home)
    return mmcv_home


def load_state_dict(module, state_dict, strict=False, logger=None):
    """Load state_dict to a module.

    This method is modified from :meth:`torch.nn.Module.load_state_dict`.
    Default value for ``strict`` is set to ``False`` and the message for
    param mismatch will be shown even if strict is False.

    Args:
        module (Module): Module that receives the state_dict.
        state_dict (OrderedDict): Weights.
        strict (bool): whether to strictly enforce that the keys
            in :attr:`state_dict` match the keys returned by this module's
            :meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
        logger (:obj:`logging.Logger`, optional): Logger to log the error
            message. If not specified, print function will be used.
    """
    unexpected_keys = []
    all_missing_keys = []
    err_msg = []

    metadata = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    # use _load_from_state_dict to enable checkpoint version control
    def load(module, prefix=''):
        # recursively check parallel module in case that the model has a
        # complicated structure, e.g., nn.Module(nn.Module(DDP))
        if is_module_wrapper(module):
            module = module.module
        local_metadata = {} if metadata is None else metadata.get(
            prefix[:-1], {})
        module._load_from_state_dict(state_dict, prefix, local_metadata, True,
                                     all_missing_keys, unexpected_keys,
                                     err_msg)
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + '.')

    load(module)
    load = None  # break load->load reference cycle

    # ignore "num_batches_tracked" of BN layers
    missing_keys = [
        key for key in all_missing_keys if 'num_batches_tracked' not in key
    ]

    if unexpected_keys:
        err_msg.append('unexpected key in source '
                       f'state_dict: {", ".join(unexpected_keys)}\n')
    if missing_keys:
        err_msg.append(
            f'missing keys in source state_dict: {", ".join(missing_keys)}\n')

    rank, _ = get_dist_info()
    if len(err_msg) > 0 and rank == 0:
        err_msg.insert(
            0, 'The model and loaded state dict do not match exactly\n')
        err_msg = '\n'.join(err_msg)
        if strict:
            raise RuntimeError(err_msg)
        elif logger is not None:
            logger.warning(err_msg)
        else:
            print(err_msg)


def load_url_dist(url, model_dir=None):
    """In a distributed setting, this function only downloads the checkpoint
    at local rank 0."""
    rank, world_size = get_dist_info()
    rank = int(os.environ.get('LOCAL_RANK', rank))
    if rank == 0:
        checkpoint = model_zoo.load_url(url, model_dir=model_dir)
    if world_size > 1:
        torch.distributed.barrier()
        if rank > 0:
            checkpoint = model_zoo.load_url(url, model_dir=model_dir)
    return checkpoint


def load_pavimodel_dist(model_path, map_location=None):
    """In a distributed setting, this function only downloads the checkpoint
    at local rank 0."""
    try:
        from pavi import modelcloud
    except ImportError:
        raise ImportError(
            'Please install pavi to load checkpoint from modelcloud.')
    rank, world_size = get_dist_info()
    rank = int(os.environ.get('LOCAL_RANK', rank))
    if rank == 0:
        model = modelcloud.get(model_path)
        with TemporaryDirectory() as tmp_dir:
            downloaded_file = osp.join(tmp_dir, model.name)
            model.download(downloaded_file)
            checkpoint = torch.load(downloaded_file, map_location=map_location)
    if world_size > 1:
        torch.distributed.barrier()
        if rank > 0:
            model = modelcloud.get(model_path)
            with TemporaryDirectory() as tmp_dir:
                downloaded_file = osp.join(tmp_dir, model.name)
                model.download(downloaded_file)
                checkpoint = torch.load(
                    downloaded_file, map_location=map_location)
    return checkpoint


def load_fileclient_dist(filename, backend, map_location):
    """In a distributed setting, this function only downloads the checkpoint
    at local rank 0."""
    rank, world_size = get_dist_info()
    rank = int(os.environ.get('LOCAL_RANK', rank))
    allowed_backends = ['ceph']
    if backend not in allowed_backends:
        raise ValueError(f'Load from Backend {backend} is not supported.')
    if rank == 0:
        fileclient = FileClient(backend=backend)
        buffer = io.BytesIO(fileclient.get(filename))
        checkpoint = torch.load(buffer, map_location=map_location)
    if world_size > 1:
        torch.distributed.barrier()
        if rank > 0:
            fileclient = FileClient(backend=backend)
            buffer = io.BytesIO(fileclient.get(filename))
            checkpoint = torch.load(buffer, map_location=map_location)
    return checkpoint


def get_torchvision_models():
    model_urls = dict()
    for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__):
        if ispkg:
            continue
        _zoo = import_module(f'torchvision.models.{name}')
        if hasattr(_zoo, 'model_urls'):
            _urls = getattr(_zoo, 'model_urls')
            model_urls.update(_urls)
    return model_urls


def get_external_models():
    mmcv_home = _get_mmcv_home()
    default_json_path = osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json')
    default_urls = load_file(default_json_path)
    assert isinstance(default_urls, dict)
    external_json_path = osp.join(mmcv_home, 'open_mmlab.json')
    if osp.exists(external_json_path):
        external_urls = load_file(external_json_path)
        assert isinstance(external_urls, dict)
        default_urls.update(external_urls)

    return default_urls


def get_mmcls_models():
    mmcls_json_path = osp.join(mmcv.__path__[0], 'model_zoo/mmcls.json')
    mmcls_urls = load_file(mmcls_json_path)

    return mmcls_urls


def get_deprecated_model_names():
    deprecate_json_path = osp.join(mmcv.__path__[0],
                                   'model_zoo/deprecated.json')
    deprecate_urls = load_file(deprecate_json_path)
    assert isinstance(deprecate_urls, dict)

    return deprecate_urls


def _process_mmcls_checkpoint(checkpoint):
    state_dict = checkpoint['state_dict']
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        if k.startswith('backbone.'):
            new_state_dict[k[9:]] = v
    new_checkpoint = dict(state_dict=new_state_dict)

    return new_checkpoint


def _load_checkpoint(filename, map_location=None):
    """Load checkpoint from somewhere (modelzoo, file, url).

    Args:
        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
            details.
        map_location (str | None): Same as :func:`torch.load`. Default: None.

    Returns:
        dict | OrderedDict: The loaded checkpoint. It can be either an
            OrderedDict storing model weights or a dict containing other
            information, which depends on the checkpoint.
    """
    if filename.startswith('modelzoo://'):
        warnings.warn('The URL scheme of "modelzoo://" is deprecated, please '
                      'use "torchvision://" instead')
        model_urls = get_torchvision_models()
        model_name = filename[11:]
        checkpoint = load_url_dist(model_urls[model_name])
    elif filename.startswith('torchvision://'):
        model_urls = get_torchvision_models()
        model_name = filename[14:]
        checkpoint = load_url_dist(model_urls[model_name])
    elif filename.startswith('open-mmlab://'):
        model_urls = get_external_models()
        model_name = filename[13:]
        deprecated_urls = get_deprecated_model_names()
        if model_name in deprecated_urls:
            warnings.warn(f'open-mmlab://{model_name} is deprecated in favor '
                          f'of open-mmlab://{deprecated_urls[model_name]}')
            model_name = deprecated_urls[model_name]
        model_url = model_urls[model_name]
        # check if is url
        if model_url.startswith(('http://', 'https://')):
            checkpoint = load_url_dist(model_url)
        else:
            filename = osp.join(_get_mmcv_home(), model_url)
            if not osp.isfile(filename):
                raise IOError(f'{filename} is not a checkpoint file')
            checkpoint = torch.load(filename, map_location=map_location)
    elif filename.startswith('mmcls://'):
        model_urls = get_mmcls_models()
        model_name = filename[8:]
        checkpoint = load_url_dist(model_urls[model_name])
        checkpoint = _process_mmcls_checkpoint(checkpoint)
    elif filename.startswith(('http://', 'https://')):
        checkpoint = load_url_dist(filename)
    elif filename.startswith('pavi://'):
        model_path = filename[7:]
        checkpoint = load_pavimodel_dist(model_path, map_location=map_location)
    elif filename.startswith('s3://'):
        checkpoint = load_fileclient_dist(
            filename, backend='ceph', map_location=map_location)
    else:
        if not osp.isfile(filename):
            raise IOError(f'{filename} is not a checkpoint file')
        checkpoint = torch.load(filename, map_location=map_location)
    return checkpoint


def load_checkpoint_swin(model,
                         filename,
                         map_location='cpu',
                         strict=False,
                         rpe_interpolation='outer_mask',
                         logger=None):
    """Load checkpoint from a file or URI.

    Args:
        model (Module): Module to load checkpoint.
        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
            details.
        map_location (str): Same as :func:`torch.load`.
        strict (bool): Whether to allow different params for the model and
            checkpoint.
        logger (:mod:`logging.Logger` or None): The logger for error message.

    Returns:
        dict or OrderedDict: The loaded checkpoint.
    """
    checkpoint = _load_checkpoint(filename, map_location)
    # OrderedDict is a subclass of dict
    if not isinstance(checkpoint, dict):
        raise RuntimeError(
            f'No state_dict found in checkpoint file {filename}')
    # get state_dict from checkpoint
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    elif 'model' in checkpoint:
        state_dict = checkpoint['model']
    elif 'module' in checkpoint:
        state_dict = checkpoint['module']
    else:
        state_dict = checkpoint
    # strip prefix of state_dict
    if list(state_dict.keys())[0].startswith('module.'):
        state_dict = {k[7:]: v for k, v in state_dict.items()}

    # for MoBY, load model of online branch
    if sorted(list(state_dict.keys()))[2].startswith('encoder'):
        state_dict = {k.replace('encoder.', ''): v for k, v in state_dict.items() if k.startswith('encoder.')}

    # reshape absolute position embedding for Swin
    if state_dict.get('absolute_pos_embed') is not None:
        absolute_pos_embed = state_dict['absolute_pos_embed']
        N1, L, C1 = absolute_pos_embed.size()
        N2, C2, H, W = model.absolute_pos_embed.size()
        if N1 != N2 or C1 != C2 or L != H * W:
            logger.warning("Error in loading absolute_pos_embed, pass")
        else:
            state_dict['absolute_pos_embed'] = absolute_pos_embed.view(N2, H, W, C2).permute(0, 3, 1, 2)

    # interpolate position bias table if needed
    relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k]
    for k in relative_position_bias_table_keys:
        table_pretrained = state_dict[k]
        table_current = model.state_dict()[k]
        L1, nH1 = table_pretrained.size()
        L2, nH2 = table_current.size()
        if nH1 != nH2:
            logger.warning(f"Error in loading {k}, pass")
        else:
            if L1 != L2:
                if rpe_interpolation in ['bicubic', 'bilinear', 'nearest']:
                    logger.info(f"Interpolate relative_position_bias_table using {rpe_interpolation}")
                    S1 = int(L1 ** 0.5)
                    S2 = int(L2 ** 0.5)
                    table_pretrained_resized = F.interpolate(
                        table_pretrained.permute(1, 0).view(1, nH1, S1, S1),
                        size=(S2, S2), mode=rpe_interpolation)
                    state_dict[k] = table_pretrained_resized.view(nH2, L2).permute(1, 0)
                elif rpe_interpolation == 'geo':
                    logger.info("Interpolate relative_position_bias_table using geo.")
                    src_size = int(L1 ** 0.5)
                    dst_size = int(L2 ** 0.5)

                    def geometric_progression(a, r, n):
                        return a * (1.0 - r ** n) / (1.0 - r)

                    left, right = 1.01, 1.5
                    while right - left > 1e-6:
                        q = (left + right) / 2.0
                        gp = geometric_progression(1, q, src_size // 2)
                        if gp > dst_size // 2:
                            right = q
                        else:
                            left = q

                    # if q > 1.13492:
                    #     q = 1.13492

                    dis = []
                    cur = 1
                    for i in range(src_size // 2):
                        dis.append(cur)
                        cur += q ** (i + 1)

                    r_ids = [-_ for _ in reversed(dis)]

                    x = r_ids + [0] + dis
                    y = r_ids + [0] + dis

                    t = dst_size // 2.0
                    dx = np.arange(-t, t + 0.1, 1.0)
                    dy = np.arange(-t, t + 0.1, 1.0)

                    logger.info("Original positions = %s" % str(x))
                    logger.info("Target positions = %s" % str(dx))

                    all_rel_pos_bias = []

                    for i in range(nH1):
                        z = table_pretrained[:, i].view(src_size, src_size).float().numpy()
                        f_cubic = interpolate.interp2d(x, y, z, kind='cubic')
                        all_rel_pos_bias.append(torch.Tensor(f_cubic(dx, dy)).contiguous().view(-1, 1).to(
                            table_pretrained.device))

                    new_rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
                    state_dict[k] = new_rel_pos_bias

    if 'pos_embed' in state_dict:
        pos_embed_checkpoint = state_dict['pos_embed']
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        # height (== width) for the checkpoint position embedding
        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
        # height (== width) for the new position embedding
        new_size = int(num_patches ** 0.5)
        # class_token and dist_token are kept unchanged
        if orig_size != new_size:
            if dist.get_rank() == 0:
                print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
            # only the position tokens are interpolated
            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
            pos_tokens = torch.nn.functional.interpolate(
                pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
            state_dict['pos_embed'] = new_pos_embed

    # load state_dict
    load_state_dict(model, state_dict, strict, logger)
    return checkpoint


def weights_to_cpu(state_dict):
    """Copy a model state_dict to cpu.

    Args:
        state_dict (OrderedDict): Model weights on GPU.

    Returns:
        OrderedDict: Model weights on CPU.
    """
    state_dict_cpu = OrderedDict()
    for key, val in state_dict.items():
        state_dict_cpu[key] = val.cpu()
    return state_dict_cpu


def _save_to_state_dict(module, destination, prefix, keep_vars):
    """Saves module state to `destination` dictionary.

    This method is modified from :meth:`torch.nn.Module._save_to_state_dict`.

    Args:
        module (nn.Module): The module to generate state_dict.
        destination (dict): A dict where state will be stored.
        prefix (str): The prefix for parameters and buffers used in this
            module.
    """
    for name, param in module._parameters.items():
        if param is not None:
            destination[prefix + name] = param if keep_vars else param.detach()
    for name, buf in module._buffers.items():
        # remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d
        if buf is not None:
            destination[prefix + name] = buf if keep_vars else buf.detach()


def get_state_dict(module, destination=None, prefix='', keep_vars=False):
    """Returns a dictionary containing a whole state of the module.

    Both parameters and persistent buffers (e.g. running averages) are
    included. Keys are corresponding parameter and buffer names.

    This method is modified from :meth:`torch.nn.Module.state_dict` to
    recursively check parallel module in case that the model has a complicated
    structure, e.g., nn.Module(nn.Module(DDP)).

    Args:
        module (nn.Module): The module to generate state_dict.
        destination (OrderedDict): Returned dict for the state of the
            module.
        prefix (str): Prefix of the key.
        keep_vars (bool): Whether to keep the variable property of the
            parameters. Default: False.

    Returns:
        dict: A dictionary containing a whole state of the module.
    """
    # recursively check parallel module in case that the model has a
    # complicated structure, e.g., nn.Module(nn.Module(DDP))
    if is_module_wrapper(module):
        module = module.module

    # below is the same as torch.nn.Module.state_dict()
    if destination is None:
        destination = OrderedDict()
        destination._metadata = OrderedDict()
    destination._metadata[prefix[:-1]] = local_metadata = dict(
        version=module._version)
    _save_to_state_dict(module, destination, prefix, keep_vars)
    for name, child in module._modules.items():
        if child is not None:
            get_state_dict(
                child, destination, prefix + name + '.', keep_vars=keep_vars)
    for hook in module._state_dict_hooks.values():
        hook_result = hook(module, destination, prefix, local_metadata)
        if hook_result is not None:
            destination = hook_result
    return destination


def save_checkpoint(model, filename, optimizer=None, meta=None):
    """Save checkpoint to file.

    The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
    ``optimizer``. By default ``meta`` will contain version and time info.

    Args:
        model (Module): Module whose params are to be saved.
        filename (str): Checkpoint filename.
        optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
```
+
meta (dict, optional): Metadata to be saved in checkpoint.
|
557 |
+
"""
|
558 |
+
if meta is None:
|
559 |
+
meta = {}
|
560 |
+
elif not isinstance(meta, dict):
|
561 |
+
raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
|
562 |
+
meta.update(mmcv_version=mmcv.__version__, time=time.asctime())
|
563 |
+
|
564 |
+
if is_module_wrapper(model):
|
565 |
+
model = model.module
|
566 |
+
|
567 |
+
if hasattr(model, 'CLASSES') and model.CLASSES is not None:
|
568 |
+
# save class name to the meta
|
569 |
+
meta.update(CLASSES=model.CLASSES)
|
570 |
+
|
571 |
+
checkpoint = {
|
572 |
+
'meta': meta,
|
573 |
+
'state_dict': weights_to_cpu(get_state_dict(model))
|
574 |
+
}
|
575 |
+
# save optimizer state dict in the checkpoint
|
576 |
+
if isinstance(optimizer, Optimizer):
|
577 |
+
checkpoint['optimizer'] = optimizer.state_dict()
|
578 |
+
elif isinstance(optimizer, dict):
|
579 |
+
checkpoint['optimizer'] = {}
|
580 |
+
for name, optim in optimizer.items():
|
581 |
+
checkpoint['optimizer'][name] = optim.state_dict()
|
582 |
+
|
583 |
+
if filename.startswith('pavi://'):
|
584 |
+
try:
|
585 |
+
from pavi import modelcloud
|
586 |
+
from pavi.exception import NodeNotFoundError
|
587 |
+
except ImportError:
|
588 |
+
raise ImportError(
|
589 |
+
'Please install pavi to load checkpoint from modelcloud.')
|
590 |
+
model_path = filename[7:]
|
591 |
+
root = modelcloud.Folder()
|
592 |
+
model_dir, model_name = osp.split(model_path)
|
593 |
+
try:
|
594 |
+
model = modelcloud.get(model_dir)
|
595 |
+
except NodeNotFoundError:
|
596 |
+
model = root.create_training_model(model_dir)
|
597 |
+
with TemporaryDirectory() as tmp_dir:
|
598 |
+
checkpoint_file = osp.join(tmp_dir, model_name)
|
599 |
+
with open(checkpoint_file, 'wb') as f:
|
600 |
+
torch.save(checkpoint, f)
|
601 |
+
f.flush()
|
602 |
+
model.create_file(checkpoint_file, name=model_name)
|
603 |
+
else:
|
604 |
+
mmcv.mkdir_or_exist(osp.dirname(filename))
|
605 |
+
# immediately flush buffer
|
606 |
+
with open(filename, 'wb') as f:
|
607 |
+
torch.save(checkpoint, f)
|
608 |
+
f.flush()
|
depth/models_depth/dist_layers.py
ADDED
@@ -0,0 +1,121 @@
# MIT License

# Copyright (c) 2022 Intelligent Systems Lab Org

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# File author: Shariq Farooq Bhat

import torch
import torch.nn as nn


def log_binom(n, k, eps=1e-7):
    """ log(nCk) using Stirling approximation """
    n = n + eps
    k = k + eps
    return n * torch.log(n) - k * torch.log(k) - (n-k) * torch.log(n-k+eps)


class LogBinomial(nn.Module):
    def __init__(self, n_classes=256, act=torch.softmax):
        """Compute log binomial distribution for n_classes

        Args:
            n_classes (int, optional): number of output classes. Defaults to 256.
        """
        super().__init__()
        self.K = n_classes
        self.act = act
        self.register_buffer('k_idx', torch.arange(
            0, n_classes).view(1, -1, 1, 1))
        self.register_buffer('K_minus_1', torch.Tensor(
            [self.K-1]).view(1, -1, 1, 1))

    def forward(self, x, t=1., eps=1e-4):
        """Compute log binomial distribution for x

        Args:
            x (torch.Tensor - NCHW): probabilities
            t (float, torch.Tensor - NCHW, optional): Temperature of distribution. Defaults to 1.
            eps (float, optional): Small number for numerical stability. Defaults to 1e-4.

        Returns:
            torch.Tensor - NCHW: log binomial distribution logbinomial(p; t)
        """
        if x.ndim == 3:
            x = x.unsqueeze(1)  # make it nchw

        one_minus_x = torch.clamp(1 - x, eps, 1)
        x = torch.clamp(x, eps, 1)
        y = log_binom(self.K_minus_1, self.k_idx) + self.k_idx * \
            torch.log(x) + (self.K - 1 - self.k_idx) * torch.log(one_minus_x)
        return self.act(y/t, dim=1)


class ConditionalLogBinomial(nn.Module):
    def __init__(self, in_features, condition_dim, n_classes=256, bottleneck_factor=2, p_eps=1e-4, max_temp=50, min_temp=1e-7, act=torch.softmax):
        """Conditional Log Binomial distribution

        Args:
            in_features (int): number of input channels in main feature
            condition_dim (int): number of input channels in condition feature
            n_classes (int, optional): Number of classes. Defaults to 256.
            bottleneck_factor (int, optional): Hidden dim factor. Defaults to 2.
            p_eps (float, optional): small eps value. Defaults to 1e-4.
            max_temp (float, optional): Maximum temperature of output distribution. Defaults to 50.
            min_temp (float, optional): Minimum temperature of output distribution. Defaults to 1e-7.
        """
        super().__init__()
        self.p_eps = p_eps
        self.max_temp = max_temp
        self.min_temp = min_temp
        self.log_binomial_transform = LogBinomial(n_classes, act=act)
        bottleneck = (in_features + condition_dim) // bottleneck_factor
        self.mlp = nn.Sequential(
            nn.Conv2d(in_features + condition_dim, bottleneck,
                      kernel_size=1, stride=1, padding=0),
            nn.GELU(),
            # 2 for p linear norm, 2 for t linear norm
            nn.Conv2d(bottleneck, 2+2, kernel_size=1, stride=1, padding=0),
            nn.Softplus()
        )

    def forward(self, x, cond):
        """Forward pass

        Args:
            x (torch.Tensor - NCHW): Main feature
            cond (torch.Tensor - NCHW): condition feature

        Returns:
            torch.Tensor: Output log binomial distribution
        """
        pt = self.mlp(torch.concat((x, cond), dim=1))
        p, t = pt[:, :2, ...], pt[:, 2:, ...]

        p = p + self.p_eps
        p = p[:, 0, ...] / (p[:, 0, ...] + p[:, 1, ...])

        t = t + self.p_eps
        t = t[:, 0, ...] / (t[:, 0, ...] + t[:, 1, ...])
        t = t.unsqueeze(1)
        t = (self.max_temp - self.min_temp) * t + self.min_temp

        return self.log_binomial_transform(p, t)
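A quick shape check for ConditionalLogBinomial, as a minimal sketch (not part of the commit; the import path and the channel sizes, which mirror how model.py calls it, are assumptions): the module maps a main feature plus a condition feature to a per-pixel distribution over n_classes bins that sums to 1 along the channel dimension.

# Minimal sketch (not part of this commit): shape check for ConditionalLogBinomial.
import torch
from depth.models_depth.dist_layers import ConditionalLogBinomial  # import path assumed

clb = ConditionalLogBinomial(in_features=193, condition_dim=128, n_classes=64)
feat = torch.rand(2, 193, 12, 16)   # main feature (e.g. decoder output + relative depth)
cond = torch.rand(2, 128, 12, 16)   # bin-embedding condition
probs = clb(feat, cond)
print(probs.shape)                                         # torch.Size([2, 64, 12, 16])
print(probs.sum(dim=1).allclose(torch.ones(2, 12, 16)))    # softmax over bins -> True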
depth/models_depth/layers.py
ADDED
@@ -0,0 +1,36 @@
import torch
import torch.nn as nn


class PatchTransformerEncoder(nn.Module):
    def __init__(self, in_channels, patch_size=10, embedding_dim=128, num_heads=4):
        super(PatchTransformerEncoder, self).__init__()
        encoder_layers = nn.TransformerEncoderLayer(embedding_dim, num_heads, dim_feedforward=1024)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers=4)  # takes shape S,N,E

        self.embedding_convPxP = nn.Conv2d(in_channels, embedding_dim,
                                           kernel_size=patch_size, stride=patch_size, padding=0)

        self.positional_encodings = nn.Parameter(torch.rand(900, embedding_dim), requires_grad=True)

    def forward(self, x):
        embeddings = self.embedding_convPxP(x).flatten(2)  # .shape = n,c,s = n, embedding_dim, s
        # embeddings = nn.functional.pad(embeddings, (1,0))  # extra special token at start ?
        embeddings = embeddings + self.positional_encodings[:embeddings.shape[2], :].T.unsqueeze(0)

        # change to S,N,E format required by transformer
        embeddings = embeddings.permute(2, 0, 1)
        x = self.transformer_encoder(embeddings)  # .shape = S, N, E
        return x


class PixelWiseDotProduct(nn.Module):
    def __init__(self):
        super(PixelWiseDotProduct, self).__init__()

    def forward(self, x, K):
        n, c, h, w = x.size()
        _, cout, ck = K.size()
        assert c == ck, "Number of channels in x and Embedding dimension (at dim 2) of K matrix must match"
        y = torch.matmul(x.view(n, c, h * w).permute(0, 2, 1), K.permute(0, 2, 1))  # .shape = n, hw, cout
        return y.permute(0, 2, 1).view(n, cout, h, w)
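A minimal shape walk-through for the two layers (not part of the commit; the import path and tensor sizes are assumptions, and the 16 queries sliced from the sequence are arbitrary):

# Minimal sketch (not part of this commit): shapes through both layers.
import torch
from depth.models_depth.layers import PatchTransformerEncoder, PixelWiseDotProduct  # import path assumed

enc = PatchTransformerEncoder(in_channels=128, patch_size=10, embedding_dim=128)
x = torch.rand(2, 128, 40, 50)    # N, C, H, W
tgt = enc(x)                      # S, N, E with S = (40//10) * (50//10) = 20 patch tokens
print(tgt.shape)                  # torch.Size([20, 2, 128])

dot = PixelWiseDotProduct()
K = tgt[:16].permute(1, 0, 2)     # N, cout, E: 16 sequence elements used as queries
y = dot(x, K)                     # dot product of every pixel feature with every query
print(y.shape)                    # torch.Size([2, 16, 40, 50])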
depth/models_depth/localbins_layers.py
ADDED
@@ -0,0 +1,169 @@
# MIT License

# Copyright (c) 2022 Intelligent Systems Lab Org

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# File author: Shariq Farooq Bhat

import torch
import torch.nn as nn


class SeedBinRegressor(nn.Module):
    def __init__(self, in_features, n_bins=16, mlp_dim=256, min_depth=1e-3, max_depth=10):
        """Bin center regressor network. Bin centers are bounded on (min_depth, max_depth) interval.

        Args:
            in_features (int): input channels
            n_bins (int, optional): Number of bin centers. Defaults to 16.
            mlp_dim (int, optional): Hidden dimension. Defaults to 256.
            min_depth (float, optional): Min depth value. Defaults to 1e-3.
            max_depth (float, optional): Max depth value. Defaults to 10.
        """
        super().__init__()
        self.version = "1_1"
        self.min_depth = min_depth
        self.max_depth = max_depth

        self._net = nn.Sequential(
            nn.Conv2d(in_features, mlp_dim, 1, 1, 0),
            nn.ReLU(inplace=True),
            nn.Conv2d(mlp_dim, n_bins, 1, 1, 0),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        """
        Returns tensor of bin_width vectors (centers). One vector b for every pixel
        """
        B = self._net(x)
        eps = 1e-3
        B = B + eps
        B_widths_normed = B / B.sum(dim=1, keepdim=True)
        B_widths = (self.max_depth - self.min_depth) * \
            B_widths_normed  # .shape NCHW
        # pad has the form (left, right, top, bottom, front, back)
        B_widths = nn.functional.pad(
            B_widths, (0, 0, 0, 0, 1, 0), mode='constant', value=self.min_depth)
        B_edges = torch.cumsum(B_widths, dim=1)  # .shape NCHW

        B_centers = 0.5 * (B_edges[:, :-1, ...] + B_edges[:, 1:, ...])
        return B_widths_normed, B_centers


class SeedBinRegressorUnnormed(nn.Module):
    def __init__(self, in_features, n_bins=16, mlp_dim=256, min_depth=1e-3, max_depth=10):
        """Bin center regressor network. Bin centers are unbounded

        Args:
            in_features (int): input channels
            n_bins (int, optional): Number of bin centers. Defaults to 16.
            mlp_dim (int, optional): Hidden dimension. Defaults to 256.
            min_depth (float, optional): Not used. (for compatibility with SeedBinRegressor)
            max_depth (float, optional): Not used. (for compatibility with SeedBinRegressor)
        """
        super().__init__()
        self.version = "1_1"
        self._net = nn.Sequential(
            nn.Conv2d(in_features, mlp_dim, 1, 1, 0),
            nn.ReLU(inplace=True),
            nn.Conv2d(mlp_dim, n_bins, 1, 1, 0),
            nn.Softplus()
        )

    def forward(self, x):
        """
        Returns tensor of bin_width vectors (centers). One vector b for every pixel
        """
        B_centers = self._net(x)
        return B_centers, B_centers


class Projector(nn.Module):
    def __init__(self, in_features, out_features, mlp_dim=128):
        """Projector MLP

        Args:
            in_features (int): input channels
            out_features (int): output channels
            mlp_dim (int, optional): hidden dimension. Defaults to 128.
        """
        super().__init__()

        self._net = nn.Sequential(
            nn.Conv2d(in_features, mlp_dim, 1, 1, 0),
            nn.ReLU(inplace=True),
            nn.Conv2d(mlp_dim, out_features, 1, 1, 0),
        )

    def forward(self, x):
        return self._net(x)


class LinearSplitter(nn.Module):
    def __init__(self, in_features, prev_nbins, split_factor=2, mlp_dim=128, min_depth=1e-3, max_depth=10):
        super().__init__()

        self.prev_nbins = prev_nbins
        self.split_factor = split_factor
        self.min_depth = min_depth
        self.max_depth = max_depth

        self._net = nn.Sequential(
            nn.Conv2d(in_features, mlp_dim, 1, 1, 0),
            nn.GELU(),
            nn.Conv2d(mlp_dim, prev_nbins * split_factor, 1, 1, 0),
            nn.ReLU()
        )

    def forward(self, x, b_prev, prev_b_embedding=None, interpolate=True, is_for_query=False):
        """
        x : feature block; shape - n, c, h, w
        b_prev : previous bin widths normed; shape - n, prev_nbins, h, w
        """
        if prev_b_embedding is not None:
            if interpolate:
                prev_b_embedding = nn.functional.interpolate(prev_b_embedding, x.shape[-2:], mode='bilinear', align_corners=True)
            x = x + prev_b_embedding
        S = self._net(x)
        eps = 1e-3
        S = S + eps
        n, c, h, w = S.shape
        S = S.view(n, self.prev_nbins, self.split_factor, h, w)
        S_normed = S / S.sum(dim=2, keepdim=True)  # fractional splits

        b_prev = nn.functional.interpolate(b_prev, (h, w), mode='bilinear', align_corners=True)

        b_prev = b_prev / b_prev.sum(dim=1, keepdim=True)  # renormalize for guarantees
        # print(b_prev.shape, S_normed.shape)
        # if is_for_query:(1).expand(-1, b_prev.size(0)//n, -1, -1, -1, -1).flatten(0,1) # TODO ? can replace all this with a single torch.repeat?
        b = b_prev.unsqueeze(2) * S_normed
        b = b.flatten(1, 2)  # .shape n, prev_nbins * split_factor, h, w

        # calculate bin centers for loss calculation
        B_widths = (self.max_depth - self.min_depth) * b  # .shape N, nprev * splitfactor, H, W
        # pad has the form (left, right, top, bottom, front, back)
        B_widths = nn.functional.pad(B_widths, (0, 0, 0, 0, 1, 0), mode='constant', value=self.min_depth)
        B_edges = torch.cumsum(B_widths, dim=1)  # .shape NCHW

        B_centers = 0.5 * (B_edges[:, :-1, ...] + B_edges[:, 1:, ...])
        return b, B_centers
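A minimal sketch (not part of the commit; import path and sizes are assumptions) of what the seed regressor guarantees: per-pixel bin widths that sum to 1, with cumulative bin centers confined to (min_depth, max_depth).

# Minimal sketch (not part of this commit): SeedBinRegressor bin partition check.
import torch
from depth.models_depth.localbins_layers import SeedBinRegressor  # import path assumed

reg = SeedBinRegressor(in_features=256, n_bins=16, min_depth=1e-3, max_depth=10)
feats = torch.rand(2, 256, 8, 8)
widths_normed, centers = reg(feats)
print(widths_normed.shape, centers.shape)                        # both torch.Size([2, 16, 8, 8])
print(widths_normed.sum(dim=1).allclose(torch.ones(2, 8, 8)))    # normalized widths sum to 1
print(float(centers.min()), float(centers.max()))                # centers lie inside (min_depth, max_depth)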
depth/models_depth/miniViT.py
ADDED
@@ -0,0 +1,45 @@
import torch
import torch.nn as nn

from .layers import PatchTransformerEncoder, PixelWiseDotProduct


class mViT(nn.Module):
    def __init__(self, in_channels, n_query_channels=128, patch_size=16, dim_out=256,
                 embedding_dim=128, num_heads=4, norm='linear'):
        super(mViT, self).__init__()
        self.norm = norm
        self.n_query_channels = n_query_channels
        self.patch_transformer = PatchTransformerEncoder(in_channels, patch_size, embedding_dim, num_heads)
        self.dot_product_layer = PixelWiseDotProduct()

        self.conv3x3 = nn.Conv2d(in_channels, embedding_dim, kernel_size=3, stride=1, padding=1)
        self.regressor = nn.Sequential(nn.Linear(embedding_dim, 256),
                                       nn.LeakyReLU(),
                                       nn.Linear(256, 256),
                                       nn.LeakyReLU(),
                                       nn.Linear(256, dim_out))

    def forward(self, x):
        # n, c, h, w = x.size()
        tgt = self.patch_transformer(x.clone())  # .shape = S, N, E

        x = self.conv3x3(x)

        regression_head, queries = tgt[0, ...], tgt[1:self.n_query_channels + 1, ...]

        # Change from S, N, E to N, S, E
        queries = queries.permute(1, 0, 2)
        range_attention_maps = self.dot_product_layer(x, queries)  # .shape = n, n_query_channels, h, w

        y = self.regressor(regression_head)  # .shape = N, dim_out
        if self.norm == 'linear':
            y = torch.relu(y)
            eps = 0.1
            y = y + eps
        elif self.norm == 'softmax':
            return torch.softmax(y, dim=1), range_attention_maps
        else:
            y = torch.sigmoid(y)
        y = y / y.sum(dim=1, keepdim=True)
        return y, range_attention_maps
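A minimal sketch (not part of the commit; import path and input size are assumptions) of the head's two outputs. Note the input must be large enough that the patch sequence length exceeds n_query_channels + 1, since the query tokens are sliced from the sequence; 240x320 with patch_size=16 gives 15 * 20 = 300 tokens.

# Minimal sketch (not part of this commit): output shapes of the adaptive-bins head.
import torch
from depth.models_depth.miniViT import mViT  # import path assumed

head = mViT(in_channels=192, n_query_channels=192, patch_size=16,
            dim_out=256, embedding_dim=192, norm='linear')
x = torch.rand(1, 192, 240, 320)     # S = 300 patch tokens > 192 queries + 1 head token
bin_widths_normed, range_attention_maps = head(x)
print(bin_widths_normed.shape)       # torch.Size([1, 256]); each row sums to 1
print(range_attention_maps.shape)    # torch.Size([1, 192, 240, 320])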
depth/models_depth/model.py
ADDED
@@ -0,0 +1,666 @@
# ------------------------------------------------------------------------------
# Copyright (c) Microsoft
# Licensed under the MIT License.
# The deconvolution code is based on Simple Baseline.
# (https://github.com/microsoft/human-pose-estimation.pytorch/blob/master/lib/models/pose_resnet.py)
# Modified by Zigang Geng ([email protected]).
# ------------------------------------------------------------------------------

import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_, DropPath
from mmcv.cnn import (build_conv_layer, build_norm_layer, build_upsample_layer,
                      constant_init, normal_init)
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config
import torch.nn.functional as F

from evp.models import UNetWrapper, TextAdapterRefer, FrozenCLIPEmbedder
from .miniViT import mViT
from .attractor import AttractorLayer, AttractorLayerUnnormed
from .dist_layers import ConditionalLogBinomial
from .localbins_layers import (Projector, SeedBinRegressor, SeedBinRegressorUnnormed)
import os


def icnr(x, scale=2, init=nn.init.kaiming_normal_):
    """
    Checkerboard artifact free sub-pixel convolution
    https://arxiv.org/abs/1707.02937
    """
    ni, nf, h, w = x.shape
    ni2 = int(ni/(scale**2))
    k = init(torch.zeros([ni2, nf, h, w])).transpose(0, 1)
    k = k.contiguous().view(ni2, nf, -1)
    k = k.repeat(1, 1, scale**2)
    k = k.contiguous().view([nf, ni, h, w]).transpose(0, 1)
    x.data.copy_(k)


class PixelShuffle(nn.Module):
    """
    Real-Time Single Image and Video Super-Resolution
    https://arxiv.org/abs/1609.05158
    """
    def __init__(self, n_channels, scale):
        super(PixelShuffle, self).__init__()
        self.conv = nn.Conv2d(n_channels, n_channels*(scale**2), kernel_size=1)
        icnr(self.conv.weight)
        self.shuf = nn.PixelShuffle(scale)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.shuf(self.relu(self.conv(x)))
        return x


class AttentionModule(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(AttentionModule, self).__init__()

        # Convolutional Layers
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        # Group Normalization
        self.group_norm = nn.GroupNorm(20, out_channels)

        # ReLU Activation
        self.relu = nn.ReLU()

        # Spatial Attention
        self.spatial_attention = nn.Sequential(
            nn.Conv2d(in_channels, 1, kernel_size=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        # Apply spatial attention
        spatial_attention = self.spatial_attention(x)
        x = x * spatial_attention

        # Apply convolutional layer
        x = self.conv1(x)
        x = self.group_norm(x)
        x = self.relu(x)

        return x


class AttentionDownsamplingModule(nn.Module):
    def __init__(self, in_channels, out_channels, scale_factor=2):
        super(AttentionDownsamplingModule, self).__init__()

        # Spatial Attention
        self.spatial_attention = nn.Sequential(
            nn.Conv2d(in_channels, 1, kernel_size=1),
            nn.Sigmoid()
        )

        # Channel Attention
        self.channel_attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, in_channels // 8, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels // 8, in_channels, kernel_size=1),
            nn.Sigmoid()
        )

        # Convolutional Layers
        if scale_factor == 2:
            self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        elif scale_factor == 4:
            self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)

        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1)

        # Group Normalization
        self.group_norm = nn.GroupNorm(20, out_channels)

        # ReLU Activation
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # Apply spatial attention
        spatial_attention = self.spatial_attention(x)
        x = x * spatial_attention

        # Apply channel attention
        channel_attention = self.channel_attention(x)
        x = x * channel_attention

        # Apply convolutional layers
        x = self.conv1(x)
        x = self.group_norm(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.group_norm(x)
        x = self.relu(x)

        return x


class AttentionUpsamplingModule(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(AttentionUpsamplingModule, self).__init__()

        # Spatial Attention for outs[2]
        self.spatial_attention = nn.Sequential(
            nn.Conv2d(in_channels, 1, kernel_size=1),
            nn.Sigmoid()
        )

        # Channel Attention for outs[2]
        self.channel_attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, in_channels // 8, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(in_channels // 8, in_channels, kernel_size=1),
            nn.Sigmoid()
        )

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        # Group Normalization
        self.group_norm = nn.GroupNorm(20, out_channels)

        # ReLU Activation
        self.relu = nn.ReLU()
        self.upscale = PixelShuffle(in_channels, 2)

    def forward(self, x):
        # Apply spatial attention
        spatial_attention = self.spatial_attention(x)
        x = x * spatial_attention

        # Apply channel attention
        channel_attention = self.channel_attention(x)
        x = x * channel_attention

        # Apply convolutional layers
        x = self.conv1(x)
        x = self.group_norm(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.group_norm(x)
        x = self.relu(x)

        # Upsample
        x = self.upscale(x)

        return x


class ConvLayer(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ConvLayer, self).__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1),
            nn.GroupNorm(20, out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        x = self.conv1(x)

        return x


class InverseMultiAttentiveFeatureRefinement(nn.Module):
    def __init__(self, in_channels_list):
        super(InverseMultiAttentiveFeatureRefinement, self).__init__()

        self.layer1 = AttentionModule(in_channels_list[0], in_channels_list[0])
        self.layer2 = AttentionDownsamplingModule(in_channels_list[0], in_channels_list[0]//2, scale_factor=2)
        self.layer3 = ConvLayer(in_channels_list[0]//2 + in_channels_list[1], in_channels_list[1])
        self.layer4 = AttentionDownsamplingModule(in_channels_list[1], in_channels_list[1]//2, scale_factor=2)
        self.layer5 = ConvLayer(in_channels_list[1]//2 + in_channels_list[2], in_channels_list[2])
        self.layer6 = AttentionDownsamplingModule(in_channels_list[2], in_channels_list[2]//2, scale_factor=2)
        self.layer7 = ConvLayer(in_channels_list[2]//2 + in_channels_list[3], in_channels_list[3])

        '''
        self.layer8 = AttentionUpsamplingModule(in_channels_list[3], in_channels_list[3])
        self.layer9 = ConvLayer(in_channels_list[2] + in_channels_list[3], in_channels_list[2])
        self.layer10 = AttentionUpsamplingModule(in_channels_list[2], in_channels_list[2])
        self.layer11 = ConvLayer(in_channels_list[1] + in_channels_list[2], in_channels_list[1])
        self.layer12 = AttentionUpsamplingModule(in_channels_list[1], in_channels_list[1])
        self.layer13 = ConvLayer(in_channels_list[0] + in_channels_list[1], in_channels_list[0])
        '''

    def forward(self, inputs):
        x_c4, x_c3, x_c2, x_c1 = inputs
        x_c4 = self.layer1(x_c4)
        x_c4_3 = self.layer2(x_c4)
        x_c3 = torch.cat([x_c4_3, x_c3], dim=1)
        x_c3 = self.layer3(x_c3)
        x_c3_2 = self.layer4(x_c3)
        x_c2 = torch.cat([x_c3_2, x_c2], dim=1)
        x_c2 = self.layer5(x_c2)
        x_c2_1 = self.layer6(x_c2)
        x_c1 = torch.cat([x_c2_1, x_c1], dim=1)
        x_c1 = self.layer7(x_c1)
        '''
        x_c1_2 = self.layer8(x_c1)
        x_c2 = torch.cat([x_c1_2, x_c2], dim=1)
        x_c2 = self.layer9(x_c2)
        x_c2_3 = self.layer10(x_c2)
        x_c3 = torch.cat([x_c2_3, x_c3], dim=1)
        x_c3 = self.layer11(x_c3)
        x_c3_4 = self.layer12(x_c3)
        x_c4 = torch.cat([x_c3_4, x_c4], dim=1)
        x_c4 = self.layer13(x_c4)
        '''
        return [x_c4, x_c3, x_c2, x_c1]


class EVPDepthEncoder(nn.Module):
    def __init__(self, out_dim=1024, ldm_prior=[320, 680, 1320+1280], sd_path=None, text_dim=768,
                 dataset='nyu', caption_aggregation=False
                 ):
        super().__init__()

        self.layer1 = nn.Sequential(
            nn.Conv2d(ldm_prior[0], ldm_prior[0], 3, stride=2, padding=1),
            nn.GroupNorm(16, ldm_prior[0]),
            nn.ReLU(),
            nn.Conv2d(ldm_prior[0], ldm_prior[0], 3, stride=2, padding=1),
        )

        self.layer2 = nn.Sequential(
            nn.Conv2d(ldm_prior[1], ldm_prior[1], 3, stride=2, padding=1),
        )

        self.out_layer = nn.Sequential(
            nn.Conv2d(sum(ldm_prior), out_dim, 1),
            nn.GroupNorm(16, out_dim),
            nn.ReLU(),
        )

        self.aggregation = InverseMultiAttentiveFeatureRefinement([320, 680, 1320, 1280])

        self.apply(self._init_weights)

        ### stable diffusion layers

        config = OmegaConf.load('./v1-inference.yaml')
        if sd_path is None:
            if os.path.exists('../checkpoints/v1-5-pruned-emaonly.ckpt'):
                config.model.params.ckpt_path = '../checkpoints/v1-5-pruned-emaonly.ckpt'
            else:
                config.model.params.ckpt_path = None
        else:
            config.model.params.ckpt_path = f'../{sd_path}'

        sd_model = instantiate_from_config(config.model)
        self.encoder_vq = sd_model.first_stage_model

        self.unet = UNetWrapper(sd_model.model, use_attn=True)
        if dataset == 'kitti':
            self.unet = UNetWrapper(sd_model.model, use_attn=True, base_size=384)

        del sd_model.cond_stage_model
        del self.encoder_vq.decoder
        del self.unet.unet.diffusion_model.out
        del self.encoder_vq.post_quant_conv.weight
        del self.encoder_vq.post_quant_conv.bias

        for param in self.encoder_vq.parameters():
            param.requires_grad = True

        self.text_adapter = TextAdapterRefer(text_dim=text_dim)
        self.gamma = nn.Parameter(torch.ones(text_dim) * 1e-4)

        if caption_aggregation:
            class_embeddings = torch.load(f'{dataset}_class_embeddings_my_captions.pth')
            #class_embeddings_list = [value['class_embeddings'] for key, value in class_embeddings.items()]
            #stacked_embeddings = torch.stack(class_embeddings_list, dim=0)
            #class_embeddings = torch.mean(stacked_embeddings, dim=0).unsqueeze(0)

            if 'aggregated' in class_embeddings:
                class_embeddings = class_embeddings['aggregated']
            else:
                clip_model = FrozenCLIPEmbedder(max_length=40, pool=False).cuda()
                class_embeddings_new = [clip_model.encode(value['caption'][0]) for key, value in class_embeddings.items()]
                class_embeddings_new = torch.mean(torch.stack(class_embeddings_new, dim=0), dim=0)
                class_embeddings['aggregated'] = class_embeddings_new
                torch.save(class_embeddings, f'{dataset}_class_embeddings_my_captions.pth')
                class_embeddings = class_embeddings['aggregated']
            self.register_buffer('class_embeddings', class_embeddings)
        else:
            self.class_embeddings = torch.load(f'{dataset}_class_embeddings_my_captions.pth')

            self.clip_model = FrozenCLIPEmbedder(max_length=40, pool=False)
            for param in self.clip_model.parameters():
                param.requires_grad = True

        #if dataset == 'kitti':
        #    self.text_adapter_ = TextAdapterRefer(text_dim=text_dim)
        #    self.gamma_ = nn.Parameter(torch.ones(text_dim) * 1e-4)

        self.caption_aggregation = caption_aggregation
        self.dataset = dataset

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            nn.init.constant_(m.bias, 0)

    def forward_features(self, feats):
        x = self.ldm_to_net[0](feats[0])
        for i in range(3):
            if i > 0:
                x = x + self.ldm_to_net[i](feats[i])
            x = self.layers[i](x)
            x = self.upsample_layers[i](x)
        return self.out_conv(x)

    def forward(self, x, class_ids=None, img_paths=None):
        latents = self.encoder_vq.encode(x).mode()

        # add division by std
        if self.dataset == 'nyu':
            latents = latents / 5.07543
        elif self.dataset == 'kitti':
            latents = latents / 4.6211
        else:
            print('Please calculate the STD for the dataset!')

        if class_ids is not None:
            if self.caption_aggregation:
                class_embeddings = self.class_embeddings[[0]*len(class_ids.tolist())]  #[class_ids.tolist()]
            else:
                class_embeddings = []

                for img_path in img_paths:
                    class_embeddings.extend([value['caption'][0] for key, value in self.class_embeddings.items() if key in img_path.replace('//', '/')])

                class_embeddings = self.clip_model.encode(class_embeddings)
        else:
            class_embeddings = self.class_embeddings

        c_crossattn = self.text_adapter(latents, class_embeddings, self.gamma)
        t = torch.ones((x.shape[0],), device=x.device).long()

        #if self.dataset == 'kitti':
        #    c_crossattn_last = self.text_adapter_(latents, class_embeddings, self.gamma_)
        #    outs = self.unet(latents, t, c_crossattn=[c_crossattn, c_crossattn_last])
        #else:
        outs = self.unet(latents, t, c_crossattn=[c_crossattn])
        outs = self.aggregation(outs)

        feats = [outs[0], outs[1], torch.cat([outs[2], F.interpolate(outs[3], scale_factor=2)], dim=1)]
        x = torch.cat([self.layer1(feats[0]), self.layer2(feats[1]), feats[2]], dim=1)
        return self.out_layer(x)

    def get_latent(self, x):
        return self.encoder_vq.encode(x).mode()


class EVPDepth(nn.Module):
    def __init__(self, args=None, caption_aggregation=False):
        super().__init__()
        self.max_depth = args.max_depth
        self.min_depth = args.min_depth_eval

        embed_dim = 192

        channels_in = embed_dim*8
        channels_out = embed_dim

        if args.dataset == 'nyudepthv2':
            self.encoder = EVPDepthEncoder(out_dim=channels_in, dataset='nyu', caption_aggregation=caption_aggregation)
        else:
            self.encoder = EVPDepthEncoder(out_dim=channels_in, dataset='kitti', caption_aggregation=caption_aggregation)

        self.decoder = Decoder(channels_in, channels_out, args)
        self.decoder.init_weights()
        self.mViT = False
        self.custom = False

        if not self.mViT and not self.custom:
            n_bins = 64
            bin_embedding_dim = 128
            num_out_features = [32, 32, 32, 192]
            min_temp = 0.0212
            max_temp = 50
            btlnck_features = 256
            n_attractors = [16, 8, 4, 1]
            attractor_alpha = 1000
            attractor_gamma = 2
            attractor_kind = "mean"
            attractor_type = "inv"
            self.bin_centers_type = "softplus"

            self.bottle_neck = nn.Sequential(
                nn.Conv2d(channels_in, btlnck_features, kernel_size=3, stride=1, padding=1),
                nn.ReLU(inplace=False),
                nn.Conv2d(btlnck_features, btlnck_features, kernel_size=3, stride=1, padding=1))

            for m in self.bottle_neck.modules():
                if isinstance(m, nn.Conv2d):
                    normal_init(m, std=0.001, bias=0)

            SeedBinRegressorLayer = SeedBinRegressorUnnormed
            Attractor = AttractorLayerUnnormed
            self.seed_bin_regressor = SeedBinRegressorLayer(
                btlnck_features, n_bins=n_bins, min_depth=self.min_depth, max_depth=self.max_depth)
            self.seed_projector = Projector(btlnck_features, bin_embedding_dim)
            self.projectors = nn.ModuleList([
                Projector(num_out, bin_embedding_dim)
                for num_out in num_out_features
            ])
            self.attractors = nn.ModuleList([
                Attractor(bin_embedding_dim, n_bins, n_attractors=n_attractors[i], min_depth=self.min_depth, max_depth=self.max_depth,
                          alpha=attractor_alpha, gamma=attractor_gamma, kind=attractor_kind, attractor_type=attractor_type)
                for i in range(len(num_out_features))
            ])

            last_in = 192 + 1
            self.conditional_log_binomial = ConditionalLogBinomial(
                last_in, bin_embedding_dim, n_classes=n_bins, min_temp=min_temp, max_temp=max_temp)
        elif self.mViT and not self.custom:
            n_bins = 256
            self.adaptive_bins_layer = mViT(192, n_query_channels=192, patch_size=16,
                                            dim_out=n_bins,
                                            embedding_dim=192, norm='linear')
            self.conv_out = nn.Sequential(nn.Conv2d(192, n_bins, kernel_size=1, stride=1, padding=0),
                                          nn.Softmax(dim=1))

    def forward(self, x, class_ids=None, img_paths=None):
        b, c, h, w = x.shape
        x = x*2.0 - 1.0  # normalize to [-1, 1]
        if h == 480 and w == 480:
            new_x = torch.zeros(b, c, 512, 512, device=x.device)
            new_x[:, :, 0:480, 0:480] = x
            x = new_x
        elif h == 352 and w == 352:
            new_x = torch.zeros(b, c, 384, 384, device=x.device)
            new_x[:, :, 0:352, 0:352] = x
            x = new_x
        elif h == 512 and w == 512:
            pass
        else:
            print(h, w)
            raise NotImplementedError
        conv_feats = self.encoder(x, class_ids, img_paths)

        if h == 480 or h == 352:
            conv_feats = conv_feats[:, :, :-1, :-1]

        self.decoder.remove_hooks()
        out_depth, out, x_blocks = self.decoder([conv_feats])

        if not self.mViT and not self.custom:
            x = self.bottle_neck(conv_feats)
            _, seed_b_centers = self.seed_bin_regressor(x)

            if self.bin_centers_type == 'normed' or self.bin_centers_type == 'hybrid2':
                b_prev = (seed_b_centers - self.min_depth) / \
                    (self.max_depth - self.min_depth)
            else:
                b_prev = seed_b_centers

            prev_b_embedding = self.seed_projector(x)

            for projector, attractor, x in zip(self.projectors, self.attractors, x_blocks):
                b_embedding = projector(x)
                b, b_centers = attractor(
                    b_embedding, b_prev, prev_b_embedding, interpolate=True)
                b_prev = b.clone()
                prev_b_embedding = b_embedding.clone()

            rel_cond = torch.sigmoid(out_depth) * self.max_depth

            # concat rel depth with last. First interpolate rel depth to last size
            rel_cond = nn.functional.interpolate(
                rel_cond, size=out.shape[2:], mode='bilinear', align_corners=True)
            last = torch.cat([out, rel_cond], dim=1)

            b_embedding = nn.functional.interpolate(
                b_embedding, last.shape[-2:], mode='bilinear', align_corners=True)
            x = self.conditional_log_binomial(last, b_embedding)

            # Now depth value is Sum px * cx, where cx are bin_centers from the last bin tensor
            b_centers = nn.functional.interpolate(
                b_centers, x.shape[-2:], mode='bilinear', align_corners=True)
            out_depth = torch.sum(x * b_centers, dim=1, keepdim=True)

        elif self.mViT and not self.custom:
            bin_widths_normed, range_attention_maps = self.adaptive_bins_layer(out)
            out = self.conv_out(range_attention_maps)

            bin_widths = (self.max_depth - self.min_depth) * bin_widths_normed  # .shape = N, dim_out
            bin_widths = nn.functional.pad(bin_widths, (1, 0), mode='constant', value=self.min_depth)
            bin_edges = torch.cumsum(bin_widths, dim=1)

            centers = 0.5 * (bin_edges[:, :-1] + bin_edges[:, 1:])
            n, dout = centers.size()
            centers = centers.view(n, dout, 1, 1)

            out_depth = torch.sum(out * centers, dim=1, keepdim=True)
        else:
            out_depth = torch.sigmoid(out_depth) * self.max_depth

        return {'pred_d': out_depth}


class Decoder(nn.Module):
    def __init__(self, in_channels, out_channels, args):
        super().__init__()
        self.deconv = args.num_deconv
        self.in_channels = in_channels

        embed_dim = 192

        channels_in = embed_dim*8
        channels_out = embed_dim

        self.deconv_layers, self.intermediate_results = self._make_deconv_layer(
            args.num_deconv,
            args.num_filters,
            args.deconv_kernels,
        )
        self.last_layer_depth = nn.Sequential(
            nn.Conv2d(channels_out, channels_out, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=False),
            nn.Conv2d(channels_out, 1, kernel_size=3, stride=1, padding=1))

        for m in self.last_layer_depth.modules():
            if isinstance(m, nn.Conv2d):
                normal_init(m, std=0.001, bias=0)

        conv_layers = []
        conv_layers.append(
            build_conv_layer(
                dict(type='Conv2d'),
                in_channels=args.num_filters[-1],
                out_channels=out_channels,
                kernel_size=3,
                stride=1,
                padding=1))
        conv_layers.append(
            build_norm_layer(dict(type='BN'), out_channels)[1])
        conv_layers.append(nn.ReLU())
        self.conv_layers = nn.Sequential(*conv_layers)

        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

    def forward(self, conv_feats):
        out = self.deconv_layers(conv_feats[0])
        out = self.conv_layers(out)
        out = self.up(out)
        self.intermediate_results.append(out)
        out = self.up(out)
        out_depth = self.last_layer_depth(out)

        return out_depth, out, self.intermediate_results

    def _make_deconv_layer(self, num_layers, num_filters, num_kernels):
        """Make deconv layers."""

        layers = []
        in_planes = self.in_channels
        intermediate_results = []  # List to store intermediate feature maps

        for i in range(num_layers):
            kernel, padding, output_padding = \
                self._get_deconv_cfg(num_kernels[i])

            planes = num_filters[i]
            layers.append(
                build_upsample_layer(
                    dict(type='deconv'),
                    in_channels=in_planes,
                    out_channels=planes,
                    kernel_size=kernel,
                    stride=2,
                    padding=padding,
                    output_padding=output_padding,
                    bias=False))
            layers.append(nn.BatchNorm2d(planes))
            layers.append(nn.ReLU())
            in_planes = planes

            # Add a hook to store the intermediate result
            layers[-1].register_forward_hook(self._hook_fn(intermediate_results))

        return nn.Sequential(*layers), intermediate_results

    def _hook_fn(self, intermediate_results):
        def hook(module, input, output):
            intermediate_results.append(output)
        return hook

    def remove_hooks(self):
        self.intermediate_results.clear()

    def _get_deconv_cfg(self, deconv_kernel):
        """Get configurations for deconv layers."""
        if deconv_kernel == 4:
            padding = 1
            output_padding = 0
        elif deconv_kernel == 3:
            padding = 1
            output_padding = 1
        elif deconv_kernel == 2:
            padding = 0
            output_padding = 0
        else:
            raise ValueError(f'Not supported num_kernels ({deconv_kernel}).')

        return deconv_kernel, padding, output_padding

    def init_weights(self):
        """Initialize model weights."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                normal_init(m, std=0.001, bias=0)
            elif isinstance(m, nn.BatchNorm2d):
                constant_init(m, 1)
            elif isinstance(m, nn.ConvTranspose2d):
                normal_init(m, std=0.001)
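In the default (attractor) branch above, the final metric depth is the expectation of the bin centers under the per-pixel distribution from ConditionalLogBinomial, i.e. the line out_depth = torch.sum(x * b_centers, dim=1, keepdim=True). A tiny numeric sketch (not part of the commit; the three bins and their values are made up) of that last reduction:

# Minimal sketch (not part of this commit): depth as the expectation
# E[depth] = sum_k p_k * c_k over the per-pixel bin distribution.
import torch

probs = torch.tensor([0.1, 0.6, 0.3]).view(1, 3, 1, 1)     # p_k for one pixel, 3 bins
centers = torch.tensor([1.0, 2.0, 4.0]).view(1, 3, 1, 1)   # bin centers c_k in meters
depth = torch.sum(probs * centers, dim=1, keepdim=True)
print(depth)  # 0.1*1 + 0.6*2 + 0.3*4 = 2.5 -> tensor([[[[2.5000]]]])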
depth/models_depth/model_vpd.py
ADDED
@@ -0,0 +1,252 @@
# ------------------------------------------------------------------------------
# Copyright (c) Microsoft
# Licensed under the MIT License.
# The deconvolution code is based on Simple Baseline.
# (https://github.com/microsoft/human-pose-estimation.pytorch/blob/master/lib/models/pose_resnet.py)
# Modified by Zigang Geng ([email protected]).
# ------------------------------------------------------------------------------

import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_, DropPath
from mmcv.cnn import (build_conv_layer, build_norm_layer, build_upsample_layer,
                      constant_init, normal_init)
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config
import torch.nn.functional as F

from evp.models import UNetWrapper, TextAdapterDepth

class VPDDepthEncoder(nn.Module):
    def __init__(self, out_dim=1024, ldm_prior=[320, 640, 1280+1280], sd_path=None, text_dim=768,
                 dataset='nyu'
                 ):
        super().__init__()

        self.layer1 = nn.Sequential(
            nn.Conv2d(ldm_prior[0], ldm_prior[0], 3, stride=2, padding=1),
            nn.GroupNorm(16, ldm_prior[0]),
            nn.ReLU(),
            nn.Conv2d(ldm_prior[0], ldm_prior[0], 3, stride=2, padding=1),
        )

        self.layer2 = nn.Sequential(
            nn.Conv2d(ldm_prior[1], ldm_prior[1], 3, stride=2, padding=1),
        )

        self.out_layer = nn.Sequential(
            nn.Conv2d(sum(ldm_prior), out_dim, 1),
            nn.GroupNorm(16, out_dim),
            nn.ReLU(),
        )

        self.apply(self._init_weights)

        ### stable diffusion layers

        config = OmegaConf.load('./v1-inference.yaml')
        if sd_path is None:
            config.model.params.ckpt_path = '../checkpoints/v1-5-pruned-emaonly.ckpt'
        else:
            config.model.params.ckpt_path = f'../{sd_path}'

        sd_model = instantiate_from_config(config.model)
        self.encoder_vq = sd_model.first_stage_model

        self.unet = UNetWrapper(sd_model.model, use_attn=False)

        del sd_model.cond_stage_model
        del self.encoder_vq.decoder
        del self.unet.unet.diffusion_model.out

        for param in self.encoder_vq.parameters():
            param.requires_grad = False

        if dataset == 'nyu':
            self.text_adapter = TextAdapterDepth(text_dim=text_dim)
            class_embeddings = torch.load('nyu_class_embeddings.pth')
        else:
            raise NotImplementedError

        self.register_buffer('class_embeddings', class_embeddings)
        self.gamma = nn.Parameter(torch.ones(text_dim) * 1e-4)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            nn.init.constant_(m.bias, 0)

    def forward_features(self, feats):
        # NOTE: this helper appears unused; it references sub-modules
        # (ldm_to_net, layers, upsample_layers, out_conv) that are never
        # defined on this class, so calling it would raise AttributeError.
        x = self.ldm_to_net[0](feats[0])
        for i in range(3):
            if i > 0:
                x = x + self.ldm_to_net[i](feats[i])
            x = self.layers[i](x)
            x = self.upsample_layers[i](x)
        return self.out_conv(x)

    def forward(self, x, class_ids=None, img_paths=None):
        with torch.no_grad():
            latents = self.encoder_vq.encode(x).mode().detach()

        if class_ids is not None:
            class_embeddings = self.class_embeddings[class_ids.tolist()]
        else:
            class_embeddings = self.class_embeddings

        c_crossattn = self.text_adapter(latents, class_embeddings, self.gamma)  # NOTE: here the c_crossattn should be expand_dim as latents
        t = torch.ones((x.shape[0],), device=x.device).long()
        # import pdb; pdb.set_trace()
        outs = self.unet(latents, t, c_crossattn=[c_crossattn])
        feats = [outs[0], outs[1], torch.cat([outs[2], F.interpolate(outs[3], scale_factor=2)], dim=1)]
        x = torch.cat([self.layer1(feats[0]), self.layer2(feats[1]), feats[2]], dim=1)
        return self.out_layer(x)

class VPDDepth(nn.Module):
    def __init__(self, args=None):
        super().__init__()
        self.max_depth = args.max_depth

        embed_dim = 192

        channels_in = embed_dim*8
        channels_out = embed_dim

        if args.dataset == 'nyudepthv2':
            self.encoder = VPDDepthEncoder(out_dim=channels_in, dataset='nyu')
        else:
            raise NotImplementedError

        self.decoder = Decoder(channels_in, channels_out, args)
        self.decoder.init_weights()

        self.last_layer_depth = nn.Sequential(
            nn.Conv2d(channels_out, channels_out, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=False),
            nn.Conv2d(channels_out, 1, kernel_size=3, stride=1, padding=1))

        for m in self.last_layer_depth.modules():
            if isinstance(m, nn.Conv2d):
                normal_init(m, std=0.001, bias=0)

    def forward(self, x, class_ids=None, img_paths=None):
        # import pdb; pdb.set_trace()
        b, c, h, w = x.shape
        x = x*2.0 - 1.0  # normalize to [-1, 1]
        if h == 480 and w == 480:
            new_x = torch.zeros(b, c, 512, 512, device=x.device)
            new_x[:, :, 0:480, 0:480] = x
            x = new_x
        elif h == 352 and w == 352:
            new_x = torch.zeros(b, c, 384, 384, device=x.device)
            new_x[:, :, 0:352, 0:352] = x
            x = new_x
        elif h == 512 and w == 512:
            pass
        else:
            raise NotImplementedError
        conv_feats = self.encoder(x, class_ids)

        if h == 480 or h == 352:
            conv_feats = conv_feats[:, :, :-1, :-1]

        out = self.decoder([conv_feats])
        out_depth = self.last_layer_depth(out)
        out_depth = torch.sigmoid(out_depth) * self.max_depth

        return {'pred_d': out_depth}


class Decoder(nn.Module):
    def __init__(self, in_channels, out_channels, args):
        super().__init__()
        self.deconv = args.num_deconv
        self.in_channels = in_channels

        # import pdb; pdb.set_trace()

        self.deconv_layers = self._make_deconv_layer(
            args.num_deconv,
            args.num_filters,
            args.deconv_kernels,
        )

        conv_layers = []
        conv_layers.append(
            build_conv_layer(
                dict(type='Conv2d'),
                in_channels=args.num_filters[-1],
                out_channels=out_channels,
                kernel_size=3,
                stride=1,
                padding=1))
        conv_layers.append(
            build_norm_layer(dict(type='BN'), out_channels)[1])
        conv_layers.append(nn.ReLU(inplace=True))
        self.conv_layers = nn.Sequential(*conv_layers)

        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

    def forward(self, conv_feats):
        # import pdb; pdb.set_trace()
        out = self.deconv_layers(conv_feats[0])
        out = self.conv_layers(out)

        out = self.up(out)
        out = self.up(out)

        return out

    def _make_deconv_layer(self, num_layers, num_filters, num_kernels):
        """Make deconv layers."""

        layers = []
        in_planes = self.in_channels
        for i in range(num_layers):
            kernel, padding, output_padding = \
                self._get_deconv_cfg(num_kernels[i])

            planes = num_filters[i]
            layers.append(
                build_upsample_layer(
                    dict(type='deconv'),
                    in_channels=in_planes,
                    out_channels=planes,
                    kernel_size=kernel,
                    stride=2,
                    padding=padding,
                    output_padding=output_padding,
                    bias=False))
            layers.append(nn.BatchNorm2d(planes))
            layers.append(nn.ReLU(inplace=True))
            in_planes = planes

        return nn.Sequential(*layers)

    def _get_deconv_cfg(self, deconv_kernel):
        """Get configurations for deconv layers."""
        if deconv_kernel == 4:
            padding = 1
            output_padding = 0
        elif deconv_kernel == 3:
            padding = 1
            output_padding = 1
        elif deconv_kernel == 2:
            padding = 0
            output_padding = 0
        else:
            raise ValueError(f'Not supported num_kernels ({deconv_kernel}).')

        return deconv_kernel, padding, output_padding

    def init_weights(self):
        """Initialize model weights."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                normal_init(m, std=0.001, bias=0)
            elif isinstance(m, nn.BatchNorm2d):
                constant_init(m, 1)
            elif isinstance(m, nn.ConvTranspose2d):
                normal_init(m, std=0.001)
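For orientation, a minimal smoke-test sketch of VPDDepth follows. The args fields mirror what Decoder reads above; the specific values, and the availability of the Stable Diffusion checkpoint and nyu_class_embeddings.pth that VPDDepthEncoder loads, are assumptions for illustration, not this Space's actual configuration.

# Hypothetical smoke test for VPDDepth; all values are illustrative assumptions.
# Requires ./v1-inference.yaml, the SD checkpoint, and nyu_class_embeddings.pth.
from types import SimpleNamespace
import torch

args = SimpleNamespace(max_depth=10.0, dataset='nyudepthv2',
                       num_deconv=3, num_filters=[32, 32, 32],
                       deconv_kernels=[2, 2, 2])
model = VPDDepth(args).cuda().eval()
with torch.no_grad():
    # a 480x480 input is zero-padded to 512x512 inside forward()
    out = model(torch.rand(1, 3, 480, 480).cuda())
print(out['pred_d'].shape)  # predicted depth in [0, args.max_depth]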
depth/models_depth/optimizer.py
ADDED
@@ -0,0 +1,154 @@
# ------------------------------------------------------------------------------
# Copyright (c) Microsoft
# Licensed under the MIT License.
# The code is from SimMIM.
# (https://github.com/microsoft/SimMIM)
# ------------------------------------------------------------------------------

import json
from mmcv.runner import OPTIMIZER_BUILDERS, DefaultOptimizerConstructor
from mmcv.runner import build_optimizer
from mmcv.runner import get_dist_info


def get_num_layer_for_swin(var_name, num_max_layer, layers_per_stage):
    var_name = var_name.replace('encoder', 'backbone') if var_name.startswith('encoder') else var_name
    if var_name in ("backbone.cls_token", "backbone.mask_token",
                    "backbone.pos_embed", "backbone.absolute_pos_embed"):
        return 0
    elif var_name.startswith("backbone.patch_embed"):
        return 0
    elif var_name.startswith("backbone.layers"):
        if var_name.split('.')[3] == "blocks":
            stage_id = int(var_name.split('.')[2])
            layer_id = int(var_name.split('.')[4]) \
                + sum(layers_per_stage[:stage_id])
            return layer_id + 1
        elif var_name.split('.')[3] == "downsample":
            stage_id = int(var_name.split('.')[2])
            layer_id = sum(layers_per_stage[:stage_id + 1])
            return layer_id
    else:
        return num_max_layer - 1

@OPTIMIZER_BUILDERS.register_module()
class LDMOptimizerConstructor(DefaultOptimizerConstructor):
    def add_params(self, params, module, prefix='', is_dcn_module=None):
        """Add all parameters of module to the params list.
        The parameters of the given module will be added to the list of param
        groups, with specific rules defined by paramwise_cfg.
        Args:
            params (list[dict]): A list of param groups, it will be modified
                in place.
            module (nn.Module): The module to be added.
            prefix (str): The prefix of the module
            is_dcn_module (int|float|None): If the current module is a
                submodule of DCN, `is_dcn_module` will be passed to
                control conv_offset layer's learning rate. Defaults to None.
        """
        parameter_groups = {}
        no_decay_names = self.paramwise_cfg.get('no_decay_names', [])
        print("Build LDMOptimizerConstructor")
        weight_decay = self.base_wd

        for name, param in module.named_parameters():
            if not param.requires_grad:
                continue  # frozen weights
            # note: the trailing comma below makes this a tuple; the original
            # ('absolute_pos_embed') was a plain string, so `in` silently did
            # a substring check instead of a membership test
            if len(param.shape) == 1 or name.endswith(".bias") or name in ('absolute_pos_embed',):
                group_name = "no_decay"
                this_weight_decay = 0.
            else:
                group_name = "decay"
                this_weight_decay = weight_decay

            for nd_name in no_decay_names:
                if nd_name in name:
                    group_name = "no_decay"
                    this_weight_decay = 0.
                    break

            if 'unet' in name or 'cond_stage_model' in name or 'encoder_vq' in name or 'clip_model' in name:
                layer_id = 0
            else:
                layer_id = 1
            group_name = "layer_%d_%s" % (layer_id, group_name)

            if group_name not in parameter_groups:
                if layer_id == 0:
                    scale = 0.01
                else:
                    scale = 1.0

                parameter_groups[group_name] = {
                    "weight_decay": this_weight_decay,
                    "params": [],
                    "param_names": [],
                    "lr_scale": scale,
                    "group_name": group_name,
                    "lr": scale * self.base_lr,
                }

            parameter_groups[group_name]["params"].append(param)
            parameter_groups[group_name]["param_names"].append(name)
        rank, _ = get_dist_info()
        if rank == 0:
            to_display = {}
            for key in parameter_groups:
                to_display[key] = {
                    "param_names": parameter_groups[key]["param_names"],
                    "lr_scale": parameter_groups[key]["lr_scale"],
                    "lr": parameter_groups[key]["lr"],
                    "weight_decay": parameter_groups[key]["weight_decay"],
                }

        params.extend(parameter_groups.values())

def build_optimizers(model, cfgs):
    """Build multiple optimizers from configs.

    If `cfgs` contains several dicts for optimizers, a dict containing each
    constructed optimizer will be returned.
    If `cfgs` only contains one optimizer config, the constructed optimizer
    itself will be returned.

    For example,

    1) Multiple optimizer configs:

    .. code-block:: python

        optimizer_cfg = dict(
            model1=dict(type='SGD', lr=lr),
            model2=dict(type='SGD', lr=lr))

    The return dict is
    ``dict('model1': torch.optim.Optimizer, 'model2': torch.optim.Optimizer)``

    2) Single optimizer config:

    .. code-block:: python

        optimizer_cfg = dict(type='SGD', lr=lr)

    The return is ``torch.optim.Optimizer``.

    Args:
        model (:obj:`nn.Module`): The model with parameters to be optimized.
        cfgs (dict): The config dict of the optimizer.

    Returns:
        dict[:obj:`torch.optim.Optimizer`] | :obj:`torch.optim.Optimizer`:
            The initialized optimizers.
    """
    optimizers = {}
    if hasattr(model, 'module'):
        model = model.module
    # determine whether 'cfgs' has several dicts for optimizers
    if all(isinstance(v, dict) for v in cfgs.values()):
        for key, cfg in cfgs.items():
            cfg_ = cfg.copy()
            module = getattr(model, key)
            optimizers[key] = build_optimizer(module, cfg_)
        return optimizers

    return build_optimizer(model, cfgs)
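As a pointer for reviewers: add_params above routes every Stable Diffusion parameter (names containing unet, cond_stage_model, encoder_vq, or clip_model) into groups with lr_scale 0.01, and everything else at full learning rate. A minimal sketch of wiring the constructor up through mmcv; the optimizer type, lr, and no_decay_names here are illustrative assumptions, not this commit's training config.

# Hypothetical wiring of LDMOptimizerConstructor; values are illustrative.
from mmcv.runner import build_optimizer

optimizer_cfg = dict(
    type='AdamW', lr=5e-4, betas=(0.9, 0.999), weight_decay=0.1,
    constructor='LDMOptimizerConstructor',
    paramwise_cfg=dict(no_decay_names=['norm', 'bias']))
optimizer = build_optimizer(model, optimizer_cfg)  # model: e.g. a VPDDepth instance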
depth/requirements.txt
ADDED
@@ -0,0 +1,8 @@
torch>=1.6.0
h5py>=3.6.0
scipy>=1.7.3
opencv-python>=4.5.5
timm>=0.5.4
albumentations>=1.1.0
tensorboardX>=2.4.1
gdown>=4.2.1
depth/test_img.jpg
ADDED
depth/utils.py
ADDED
@@ -0,0 +1,525 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


import os
import math
import time
from collections import defaultdict, deque
import datetime
import numpy as np
from timm.utils import get_state_dict

from pathlib import Path

import torch
import torch.distributed as dist
from torch._six import inf  # NOTE: torch._six was removed in recent PyTorch; on torch >= 1.13 use `from math import inf`

from tensorboardX import SummaryWriter

class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)


class MetricLogger(object):
    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if v is None:
                continue
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
        log_msg = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}'
        ]
        if torch.cuda.is_available():
            log_msg.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB))
                else:
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / len(iterable)))


class TensorboardLogger(object):
    def __init__(self, log_dir):
        self.writer = SummaryWriter(logdir=log_dir)
        self.step = 0

    def set_step(self, step=None):
        if step is not None:
            self.step = step
        else:
            self.step += 1

    def update(self, head='scalar', step=None, **kwargs):
        for k, v in kwargs.items():
            if v is None:
                continue
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.writer.add_scalar(head + "/" + k, v, self.step if step is None else step)

    def flush(self):
        self.writer.flush()


class WandbLogger(object):
    def __init__(self, args):
        self.args = args

        try:
            import wandb
            self._wandb = wandb
        except ImportError:
            raise ImportError(
                "To use the Weights and Biases Logger please install wandb."
                "Run `pip install wandb` to install it."
            )

        # Initialize a W&B run
        if self._wandb.run is None:
            self._wandb.init(
                project=args.project,
                config=args
            )

    def log_epoch_metrics(self, metrics, commit=True):
        """
        Log train/test metrics onto W&B.
        """
        # Log number of model parameters as W&B summary
        self._wandb.summary['n_parameters'] = metrics.get('n_parameters', None)
        metrics.pop('n_parameters', None)

        # Log current epoch
        self._wandb.log({'epoch': metrics.get('epoch')}, commit=False)
        metrics.pop('epoch')

        for k, v in metrics.items():
            if 'train' in k:
                self._wandb.log({f'Global Train/{k}': v}, commit=False)
            elif 'test' in k:
                self._wandb.log({f'Global Test/{k}': v}, commit=False)

        self._wandb.log({})

    def log_checkpoints(self):
        output_dir = self.args.output_dir
        model_artifact = self._wandb.Artifact(
            self._wandb.run.id + "_model", type="model"
        )

        model_artifact.add_dir(output_dir)
        self._wandb.log_artifact(model_artifact, aliases=["latest", "best"])

    def set_steps(self):
        # Set global training step
        self._wandb.define_metric('Rank-0 Batch Wise/*', step_metric='Rank-0 Batch Wise/global_train_step')
        # Set epoch-wise step
        self._wandb.define_metric('Global Train/*', step_metric='epoch')
        self._wandb.define_metric('Global Test/*', step_metric='epoch')


def setup_for_distributed(is_master):
    """
    This function disables printing when not in master process
    """
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def save_on_master(*args, **kwargs):
    if is_main_process():
        torch.save(*args, **kwargs)


def init_distributed_mode(args):

    if args.dist_on_itp:
        args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
        args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
        args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
        args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
        os.environ['LOCAL_RANK'] = str(args.gpu)
        os.environ['RANK'] = str(args.rank)
        os.environ['WORLD_SIZE'] = str(args.world_size)
        # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
    elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()

        os.environ['RANK'] = str(args.rank)
        os.environ['LOCAL_RANK'] = str(args.gpu)
        os.environ['WORLD_SIZE'] = str(args.world_size)
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}, gpu {}'.format(
        args.rank, args.dist_url, args.gpu), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)


def init_distributed_mode_simple(args):

    args.rank = int(os.environ["RANK"])
    args.world_size = int(os.environ['WORLD_SIZE'])
    args.gpu = int(os.environ['LOCAL_RANK'])
    args.dist_url = 'env://'

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}, gpu {}'.format(
        args.rank, args.dist_url, args.gpu), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)

def load_state_dict(model, state_dict, prefix='', ignore_missing="relative_position_index"):
    missing_keys = []
    unexpected_keys = []
    error_msgs = []
    # copy state_dict so _load_from_state_dict can modify it
    metadata = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    def load(module, prefix=''):
        local_metadata = {} if metadata is None else metadata.get(
            prefix[:-1], {})
        module._load_from_state_dict(
            state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + '.')

    load(model, prefix=prefix)

    warn_missing_keys = []
    ignore_missing_keys = []
    for key in missing_keys:
        keep_flag = True
        for ignore_key in ignore_missing.split('|'):
            if ignore_key in key:
                keep_flag = False
                break
        if keep_flag:
            warn_missing_keys.append(key)
        else:
            ignore_missing_keys.append(key)

    missing_keys = warn_missing_keys

    if len(missing_keys) > 0:
        print("Weights of {} not initialized from pretrained model: {}".format(
            model.__class__.__name__, missing_keys))
    if len(unexpected_keys) > 0:
        print("Weights from pretrained model not used in {}: {}".format(
            model.__class__.__name__, unexpected_keys))
    if len(ignore_missing_keys) > 0:
        print("Ignored weights of {} not initialized from pretrained model: {}".format(
            model.__class__.__name__, ignore_missing_keys))
    if len(error_msgs) > 0:
        print('\n'.join(error_msgs))


class NativeScalerWithGradNormCount:
    state_dict_key = "amp_scaler"

    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
        self._scaler.scale(loss).backward(create_graph=create_graph)
        if update_grad:
            if clip_grad is not None:
                assert parameters is not None
                self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
                norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
            else:
                self._scaler.unscale_(optimizer)
                norm = get_grad_norm_(parameters)
            self._scaler.step(optimizer)
            self._scaler.update()
        else:
            norm = None
        return norm

    def state_dict(self):
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        self._scaler.load_state_dict(state_dict)


def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = [p for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if len(parameters) == 0:
        return torch.tensor(0.)
    device = parameters[0].grad.device
    if norm_type == inf:
        total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
    else:
        total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
    return total_norm


def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0,
                     start_warmup_value=0, warmup_steps=-1):
    warmup_schedule = np.array([])
    warmup_iters = warmup_epochs * niter_per_ep
    if warmup_steps > 0:
        warmup_iters = warmup_steps
    print("Set warmup steps = %d" % warmup_iters)
    if warmup_epochs > 0:
        warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)

    iters = np.arange(epochs * niter_per_ep - warmup_iters)
    schedule = np.array(
        [final_value + 0.5 * (base_value - final_value) * (1 + math.cos(math.pi * i / (len(iters)))) for i in iters])

    schedule = np.concatenate((warmup_schedule, schedule))

    assert len(schedule) == epochs * niter_per_ep
    return schedule

def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, model_ema=None):
    output_dir = Path(args.output_dir)
    epoch_name = str(epoch)
    checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)]
    for checkpoint_path in checkpoint_paths:
        to_save = {
            'model': model_without_ddp.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch,
            'scaler': loss_scaler.state_dict(),
            'args': args,
        }

        if model_ema is not None:
            to_save['model_ema'] = get_state_dict(model_ema)

        save_on_master(to_save, checkpoint_path)

    if is_main_process() and isinstance(epoch, int):
        to_del = epoch - args.save_ckpt_num * args.save_ckpt_freq
        old_ckpt = output_dir / ('checkpoint-%s.pth' % to_del)
        if os.path.exists(old_ckpt):
            os.remove(old_ckpt)


def auto_load_model(args, model, model_without_ddp, optimizer, loss_scaler, model_ema=None):
    output_dir = Path(args.output_dir)
    if args.auto_resume and len(args.resume) == 0:
        import glob
        all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*.pth'))
        latest_ckpt = -1
        for ckpt in all_checkpoints:
            t = ckpt.split('-')[-1].split('.')[0]
            if t.isdigit():
                latest_ckpt = max(int(t), latest_ckpt)
        if latest_ckpt >= 0:
            args.resume = os.path.join(output_dir, 'checkpoint-%d.pth' % latest_ckpt)
        print("Auto resume checkpoint: %s" % args.resume)

    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.resume, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        print("Resume checkpoint %s" % args.resume)
        if 'optimizer' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            if not isinstance(checkpoint['epoch'], str):  # does not support resuming with 'best', 'best-ema'
                args.start_epoch = checkpoint['epoch'] + 1
            else:
                assert args.eval, 'Does not support resuming with checkpoint-best'
            if hasattr(args, 'model_ema') and args.model_ema:
                if 'model_ema' in checkpoint.keys():
                    model_ema.ema.load_state_dict(checkpoint['model_ema'])
                else:
                    model_ema.ema.load_state_dict(checkpoint['model'])
            if 'scaler' in checkpoint:
                loss_scaler.load_state_dict(checkpoint['scaler'])
            print("With optim & sched!")
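One detail worth knowing when reading this file: cosine_scheduler returns a numpy array with one learning-rate value per training iteration, not a torch scheduler object. A sketch of the intended consumption pattern; the loop variables and hyperparameters are illustrative assumptions.

# Illustrative consumption of cosine_scheduler: index the table by global step.
lr_schedule = cosine_scheduler(base_value=1e-4, final_value=1e-6,
                               epochs=25, niter_per_ep=1000, warmup_epochs=1)
global_step = 0
for batch in data_loader:  # data_loader and optimizer assumed to exist
    for group in optimizer.param_groups:
        # honor per-group lr_scale set by LDMOptimizerConstructor, if present
        group['lr'] = lr_schedule[global_step] * group.get('lr_scale', 1.0)
    global_step += 1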
depth/utils_depth/criterion.py
ADDED
@@ -0,0 +1,22 @@
# ------------------------------------------------------------------------------
# The code is from GLPDepth (https://github.com/vinvino02/GLPDepth).
# For non-commercial purpose only (research, evaluation etc).
# ------------------------------------------------------------------------------

import torch
import torch.nn as nn


class SiLogLoss(nn.Module):
    def __init__(self, lambd=0.5):
        super().__init__()
        self.lambd = lambd

    def forward(self, pred, target):
        valid_mask = (target > 0).detach()
        diff_log = torch.log(target[valid_mask]) - torch.log(pred[valid_mask])
        loss = torch.sqrt(torch.pow(diff_log, 2).mean() -
                          self.lambd * torch.pow(diff_log.mean(), 2))

        return loss
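SiLogLoss is the scale-invariant log loss of Eigen et al.; only pixels with positive ground-truth depth contribute. A minimal check:

# Minimal SiLogLoss check: zero-depth pixels are excluded by the valid mask.
import torch

criterion = SiLogLoss(lambd=0.5)
pred = torch.rand(2, 1, 8, 8) + 0.1   # keep predictions strictly positive for log()
target = torch.rand(2, 1, 8, 8)
target[:, :, :2, :] = 0               # simulate missing depth readings
print(criterion(pred, target).item())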
depth/utils_depth/logging.py
ADDED
@@ -0,0 +1,161 @@
# ------------------------------------------------------------------------------
# The code is from GLPDepth (https://github.com/vinvino02/GLPDepth).
# For non-commercial purpose only (research, evaluation etc).
# ------------------------------------------------------------------------------

import os
import cv2
import sys
import time
import numpy as np

import torch


TOTAL_BAR_LENGTH = 30.
last_time = time.time()
begin_time = last_time


def progress_bar(current, total, epochs, cur_epoch, msg=None):
    # note: term_width is read from stty (so a real terminal is assumed) but
    # the padding widths below are hard-coded to 157 columns
    _, term_width = os.popen('stty size', 'r').read().split()
    term_width = int(term_width)
    global last_time, begin_time
    if current == 0:
        begin_time = time.time()  # Reset for new bar.

    cur_len = int(TOTAL_BAR_LENGTH * current / total)
    rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1

    sys.stdout.write(' [')
    for i in range(cur_len):
        sys.stdout.write('=')
    sys.stdout.write('>')
    for i in range(rest_len):
        sys.stdout.write('.')
    sys.stdout.write(']')

    cur_time = time.time()
    step_time = cur_time - last_time
    last_time = cur_time
    tot_time = cur_time - begin_time
    remain_time = step_time * (total - current) + \
        (epochs - cur_epoch) * step_time * total

    L = []
    L.append(' Step: %s' % format_time(step_time))
    L.append(' | Tot: %s' % format_time(tot_time))
    L.append(' | Rem: %s' % format_time(remain_time))
    if msg:
        L.append(' | ' + msg)

    msg = ''.join(L)
    sys.stdout.write(msg)
    for i in range(157 - int(TOTAL_BAR_LENGTH) - len(msg) - 3):
        sys.stdout.write(' ')

    # Go back to the center of the bar.
    for i in range(157 - int(TOTAL_BAR_LENGTH / 2) + 2):
        sys.stdout.write('\b')
    sys.stdout.write(' %d/%d ' % (current + 1, total))

    if current < total - 1:
        sys.stdout.write('\r')
    else:
        sys.stdout.write('\n')
    sys.stdout.flush()


class AverageMeter():
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def format_time(seconds):
    days = int(seconds / 3600 / 24)
    seconds = seconds - days * 3600 * 24
    hours = int(seconds / 3600)
    seconds = seconds - hours * 3600
    minutes = int(seconds / 60)
    seconds = seconds - minutes * 60
    secondsf = int(seconds)
    seconds = seconds - secondsf
    millis = int(seconds * 1000)

    f = ''
    i = 1
    if days > 0:
        f += str(days) + 'D'
        i += 1
    if hours > 0 and i <= 2:
        f += str(hours) + 'h'
        i += 1
    if minutes > 0 and i <= 2:
        f += str(minutes).zfill(2) + 'm'
        i += 1
    if secondsf > 0 and i <= 2:
        f += str(secondsf).zfill(2) + 's'
        i += 1
    if millis > 0 and i <= 2:
        f += str(millis).zfill(3) + 'ms'
        i += 1
    if f == '':
        f = '0ms'
    return f


def display_result(result_dict):
    line = "\n"
    line += "=" * 100 + '\n'
    for metric, value in result_dict.items():
        line += "{:>10} ".format(metric)
    line += "\n"
    for metric, value in result_dict.items():
        line += "{:10.4f} ".format(value)
    line += "\n"
    line += "=" * 100 + '\n'

    return line


def save_images(pred, save_path):
    if len(pred.shape) > 3:
        pred = pred.squeeze()

    if isinstance(pred, torch.Tensor):
        pred = pred.cpu().numpy().astype(np.uint8)

    if pred.shape[0] < 4:
        pred = np.transpose(pred, (1, 2, 0))
    cv2.imwrite(save_path, pred, [cv2.IMWRITE_PNG_COMPRESSION, 0])


def check_and_make_dirs(paths):
    if not isinstance(paths, list):
        paths = [paths]
    for path in paths:
        if not os.path.exists(path):
            os.makedirs(path)

def log_args_to_txt(log_txt, args):
    if not os.path.exists(log_txt):
        with open(log_txt, 'w') as txtfile:
            args_ = vars(args)
            args_str = ''
            for k, v in args_.items():
                args_str = args_str + str(k) + ':' + str(v) + ',\t\n'
            txtfile.write(args_str + '\n')
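A small usage note on this file: AverageMeter keeps a batch-weighted running mean, which is how per-batch losses are typically aggregated over an epoch.

# Illustrative AverageMeter usage: a running mean weighted by batch size.
meter = AverageMeter()
for loss_value, batch_size in [(0.9, 16), (0.7, 16), (0.5, 8)]:
    meter.update(loss_value, n=batch_size)
print(meter.val)  # most recent value: 0.5
print(meter.avg)  # sum(val * n) / sum(n) = 0.74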
depth/utils_depth/metrics.py
ADDED
@@ -0,0 +1,79 @@
# ------------------------------------------------------------------------------
# The code is from GLPDepth (https://github.com/vinvino02/GLPDepth).
# For non-commercial purpose only (research, evaluation etc).
# ------------------------------------------------------------------------------

import torch


def eval_depth(pred, target):
    assert pred.shape == target.shape

    thresh = torch.max((target / pred), (pred / target))

    d1 = torch.sum(thresh < 1.25).float() / len(thresh)
    d2 = torch.sum(thresh < 1.25 ** 2).float() / len(thresh)
    d3 = torch.sum(thresh < 1.25 ** 3).float() / len(thresh)

    diff = pred - target
    diff_log = torch.log(pred) - torch.log(target)

    abs_rel = torch.mean(torch.abs(diff) / target)
    sq_rel = torch.mean(torch.pow(diff, 2) / target)

    rmse = torch.sqrt(torch.mean(torch.pow(diff, 2)))

    rmse_log = torch.sqrt(torch.mean(torch.pow(diff_log, 2)))

    log10 = torch.mean(torch.abs(torch.log10(pred) - torch.log10(target)))
    silog = torch.sqrt(torch.pow(diff_log, 2).mean() - 0.5 * torch.pow(diff_log.mean(), 2))

    return {'d1': d1.item(), 'd2': d2.item(), 'd3': d3.item(), 'abs_rel': abs_rel.item(),
            'sq_rel': sq_rel.item(), 'rmse': rmse.item(), 'rmse_log': rmse_log.item(),
            'log10': log10.item(), 'silog': silog.item()}


def cropping_img(args, pred, gt_depth):
    min_depth_eval = args.min_depth_eval

    max_depth_eval = args.max_depth_eval

    pred[torch.isinf(pred)] = max_depth_eval
    pred[torch.isnan(pred)] = min_depth_eval

    valid_mask = torch.logical_and(
        gt_depth > min_depth_eval, gt_depth < max_depth_eval)

    if args.dataset == 'kitti':
        if args.do_kb_crop:
            height, width = gt_depth.shape
            top_margin = int(height - 352)
            left_margin = int((width - 1216) / 2)
            gt_depth = gt_depth[top_margin:top_margin +
                                352, left_margin:left_margin + 1216]

        if args.kitti_crop:
            gt_height, gt_width = gt_depth.shape
            eval_mask = torch.zeros(valid_mask.shape).to(
                device=valid_mask.device)

            if args.kitti_crop == 'garg_crop':
                eval_mask[int(0.40810811 * gt_height):int(0.99189189 * gt_height),
                          int(0.03594771 * gt_width):int(0.96405229 * gt_width)] = 1

            elif args.kitti_crop == 'eigen_crop':
                eval_mask[int(0.3324324 * gt_height):int(0.91351351 * gt_height),
                          int(0.0359477 * gt_width):int(0.96405229 * gt_width)] = 1
        else:
            eval_mask = valid_mask

    elif args.dataset == 'nyudepthv2':
        eval_mask = torch.zeros(valid_mask.shape).to(device=valid_mask.device)
        eval_mask[45:471, 41:601] = 1
    else:
        eval_mask = valid_mask

    valid_mask = torch.logical_and(valid_mask, eval_mask)

    return pred[valid_mask], gt_depth[valid_mask]
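eval_depth assumes its inputs are 1-D tensors of valid pixels (it divides by len(thresh)), which is exactly what cropping_img returns. A minimal sketch of the flow, with random tensors standing in for real predictions:

# Illustrative evaluation flow: mask to valid pixels first, then score.
import torch

pred = torch.rand(480, 640) * 10 + 0.01   # strictly positive for the log metrics
gt = torch.rand(480, 640) * 10
valid = gt > 0
results = eval_depth(pred[valid], gt[valid])
print(results['d1'], results['abs_rel'], results['rmse'])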
depth/utils_depth/misc.py
ADDED
@@ -0,0 +1,73 @@
# ------------------------------------------------------------------------------
# The code is from ZoeDepth (https://github.com/isl-org/ZoeDepth).
# For non-commercial purpose only (research, evaluation etc).
# ------------------------------------------------------------------------------
from scipy import ndimage

import math

import matplotlib
import matplotlib.cm
import numpy as np
import requests
import torch
from PIL import Image
from torchvision.transforms import ToTensor


def colorize(value, vmin=None, vmax=None, cmap='gray_r', invalid_val=-99, invalid_mask=None, background_color=(128, 128, 128, 255), gamma_corrected=False, value_transform=None):
    """Converts a depth map to a color image.

    Args:
        value (torch.Tensor, numpy.ndarray): Input depth map. Shape: (H, W) or (1, H, W) or (1, 1, H, W). All singular dimensions are squeezed
        vmin (float, optional): vmin-valued entries are mapped to start color of cmap. If None, value.min() is used. Defaults to None.
        vmax (float, optional): vmax-valued entries are mapped to end color of cmap. If None, value.max() is used. Defaults to None.
        cmap (str, optional): matplotlib colormap to use. Defaults to 'gray_r'.
        invalid_val (int, optional): Specifies value of invalid pixels that should be colored as 'background_color'. Defaults to -99.
        invalid_mask (numpy.ndarray, optional): Boolean mask for invalid regions. Defaults to None.
        background_color (tuple[int], optional): 4-tuple RGB color to give to invalid pixels. Defaults to (128, 128, 128, 255).
        gamma_corrected (bool, optional): Apply gamma correction to colored image. Defaults to False.
        value_transform (Callable, optional): Apply transform function to valid pixels before coloring. Defaults to None.

    Returns:
        numpy.ndarray, dtype - uint8: Colored depth map. Shape: (H, W, 4)
    """
    if isinstance(value, torch.Tensor):
        value = value.detach().cpu().numpy()

    value = value.squeeze()
    if invalid_mask is None:
        invalid_mask = value == invalid_val
    mask = np.logical_not(invalid_mask)

    # normalize
    vmin = np.percentile(value[mask], 2) if vmin is None else vmin
    vmax = np.percentile(value[mask], 85) if vmax is None else vmax
    if vmin != vmax:
        value = (value - vmin) / (vmax - vmin)  # vmin..vmax
    else:
        # Avoid 0-division
        value = value * 0.

    # squeeze last dim if it exists
    # grey out the invalid values

    value[invalid_mask] = np.nan
    cmapper = matplotlib.colormaps.get_cmap(cmap)
    if value_transform:
        value = value_transform(value)
    # value = value / value.max()
    value = cmapper(value, bytes=True)  # (nxmx4)

    # img = value[:, :, :]
    img = value[...]
    img[invalid_mask] = background_color

    # return img.transpose((2, 0, 1))
    if gamma_corrected:
        # gamma correction
        img = img / 255
        img = np.power(img, 2.2)
        img = img * 255
        img = img.astype(np.uint8)
    return img, vmin, vmax
depth/v1-inference.yaml
ADDED
@@ -0,0 +1,70 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false   # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
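This is the stock Stable Diffusion v1 inference config; VPDDepthEncoder above consumes it through OmegaConf and instantiate_from_config, roughly as follows (mirroring the constructor in model_vpd.py):

# How this config is consumed (mirrors VPDDepthEncoder.__init__ in model_vpd.py).
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config

config = OmegaConf.load('./v1-inference.yaml')
config.model.params.ckpt_path = '../checkpoints/v1-5-pruned-emaonly.ckpt'
sd_model = instantiate_from_config(config.model)  # LatentDiffusion: VAE + UNet + CLIP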
evp/__init__.py
ADDED
@@ -0,0 +1 @@
from .models import UNetWrapper, TextAdapter
evp/models.py
ADDED
@@ -0,0 +1,349 @@
1 |
+
from omegaconf import OmegaConf
|
2 |
+
|
3 |
+
import torch as th
|
4 |
+
import torch
|
5 |
+
import math
|
6 |
+
import abc
|
7 |
+
|
8 |
+
from torch import nn, einsum
|
9 |
+
|
10 |
+
from einops import rearrange, repeat
|
11 |
+
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
12 |
+
from transformers import CLIPTokenizer
|
13 |
+
from transformers.models.clip.modeling_clip import CLIPTextConfig, CLIPTextModel, CLIPTextTransformer#, _expand_mask
|
14 |
+
from inspect import isfunction
|
15 |
+
|
16 |
+
|
17 |
+
def exists(val):
|
18 |
+
return val is not None
|
19 |
+
|
20 |
+
|
21 |
+
def uniq(arr):
|
22 |
+
return{el: True for el in arr}.keys()
|
23 |
+
|
24 |
+
|
25 |
+
def default(val, d):
|
26 |
+
if exists(val):
|
27 |
+
return val
|
28 |
+
return d() if isfunction(d) else d
|
29 |
+
|
30 |
+
|
31 |
+
|
32 |
+
def register_attention_control(model, controller):
|
33 |
+
def ca_forward(self, place_in_unet):
|
34 |
+
def forward(x, context=None, mask=None):
|
35 |
+
h = self.heads
|
36 |
+
|
37 |
+
q = self.to_q(x)
|
38 |
+
is_cross = context is not None
|
39 |
+
context = default(context, x)
|
40 |
+
k = self.to_k(context)
|
41 |
+
v = self.to_v(context)
|
42 |
+
|
43 |
+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
44 |
+
|
45 |
+
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
46 |
+
|
47 |
+
if exists(mask):
|
48 |
+
mask = rearrange(mask, 'b ... -> b (...)')
|
49 |
+
max_neg_value = -torch.finfo(sim.dtype).max
|
50 |
+
mask = repeat(mask, 'b j -> (b h) () j', h=h)
|
51 |
+
sim.masked_fill_(~mask, max_neg_value)
|
52 |
+
|
53 |
+
# attention, what we cannot get enough of
|
54 |
+
attn = sim.softmax(dim=-1)
|
55 |
+
|
56 |
+
attn2 = rearrange(attn, '(b h) k c -> h b k c', h=h).mean(0)
|
57 |
+
controller(attn2, is_cross, place_in_unet)
|
58 |
+
|
59 |
+
out = einsum('b i j, b j d -> b i d', attn, v)
|
60 |
+
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
|
61 |
+
return self.to_out(out)
|
62 |
+
|
63 |
+
return forward
|
64 |
+
|
65 |
+
class DummyController:
|
66 |
+
def __call__(self, *args):
|
67 |
+
return args[0]
|
68 |
+
|
69 |
+
def __init__(self):
|
70 |
+
self.num_att_layers = 0
|
71 |
+
|
72 |
+
if controller is None:
|
73 |
+
controller = DummyController()
|
74 |
+
|
75 |
+
def register_recr(net_, count, place_in_unet):
|
76 |
+
if net_.__class__.__name__ == 'CrossAttention':
|
77 |
+
net_.forward = ca_forward(net_, place_in_unet)
|
78 |
+
return count + 1
|
79 |
+
elif hasattr(net_, 'children'):
|
80 |
+
for net__ in net_.children():
|
81 |
+
count = register_recr(net__, count, place_in_unet)
|
82 |
+
return count
|
83 |
+
|
84 |
+
cross_att_count = 0
|
85 |
+
sub_nets = model.diffusion_model.named_children()
|
86 |
+
|
87 |
+
for net in sub_nets:
|
88 |
+
if "input_blocks" in net[0]:
|
89 |
+
cross_att_count += register_recr(net[1], 0, "down")
|
90 |
+
elif "output_blocks" in net[0]:
|
91 |
+
cross_att_count += register_recr(net[1], 0, "up")
|
92 |
+
elif "middle_block" in net[0]:
|
93 |
+
cross_att_count += register_recr(net[1], 0, "mid")
|
94 |
+
|
95 |
+
controller.num_att_layers = cross_att_count
|
96 |
+
|
97 |
+
|
98 |
+
class AttentionControl(abc.ABC):
|
99 |
+
|
100 |
+
def step_callback(self, x_t):
|
101 |
+
return x_t
|
102 |
+
|
103 |
+
def between_steps(self):
|
104 |
+
return
|
105 |
+
|
106 |
+
@property
|
107 |
+
def num_uncond_att_layers(self):
|
108 |
+
return 0
|
109 |
+
|
110 |
+
@abc.abstractmethod
|
111 |
+
def forward (self, attn, is_cross: bool, place_in_unet: str):
|
112 |
+
raise NotImplementedError
|
113 |
+
|
114 |
+
def __call__(self, attn, is_cross: bool, place_in_unet: str):
|
115 |
+
attn = self.forward(attn, is_cross, place_in_unet)
|
116 |
+
return attn
|
117 |
+
|
118 |
+
def reset(self):
|
119 |
+
self.cur_step = 0
|
120 |
+
self.cur_att_layer = 0
|
121 |
+
|
122 |
+
def __init__(self):
|
123 |
+
self.cur_step = 0
|
124 |
+
self.num_att_layers = -1
|
125 |
+
self.cur_att_layer = 0
|
126 |
+
|
127 |
+
|
128 |
+
class AttentionStore(AttentionControl):
    @staticmethod
    def get_empty_store():
        return {"down_cross": [], "mid_cross": [], "up_cross": [],
                "down_self": [], "mid_self": [], "up_self": []}

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
        if attn.shape[1] <= self.max_size ** 2:  # avoid memory overhead
            self.step_store[key].append(attn)
        return attn

    def between_steps(self):
        if len(self.attention_store) == 0:
            self.attention_store = self.step_store
        else:
            for key in self.attention_store:
                for i in range(len(self.attention_store[key])):
                    self.attention_store[key][i] += self.step_store[key][i]
        self.step_store = self.get_empty_store()

    def get_average_attention(self):
        # Note: this returns the maps gathered for the current step as-is;
        # no division over steps happens here.
        average_attention = {key: [item for item in self.step_store[key]] for key in self.step_store}
        return average_attention

    def reset(self):
        super(AttentionStore, self).reset()
        self.step_store = self.get_empty_store()
        self.attention_store = {}

    def __init__(self, base_size=64, max_size=None):
        super(AttentionStore, self).__init__()
        self.step_store = self.get_empty_store()
        self.attention_store = {}
        self.base_size = base_size
        if max_size is None:
            self.max_size = self.base_size // 2
        else:
            self.max_size = max_size
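For orientation, a minimal sketch (not part of the original file) of how AttentionStore receives maps once register_attention_control has patched every CrossAttention.forward: each head-averaged map is filed under a '<place>_<cross|self>' key. The tensor shapes below are illustrative assumptions.

# Hypothetical standalone exercise of AttentionStore (shapes are assumptions):
import torch

store = AttentionStore(base_size=64)      # max_size defaults to base_size // 2 = 32
attn = torch.rand(2, 32 * 32, 77)         # (batch, query positions, key positions)
store(attn, True, 'down')                 # filed under 'down_cross'
store(attn, False, 'mid')                 # filed under 'mid_self'
print(len(store.step_store['down_cross']), len(store.step_store['mid_self']))  # 1 1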
def register_hier_output(model):
    self = model.diffusion_model
    from ldm.modules.diffusionmodules.util import checkpoint, timestep_embedding

    def forward(x, timesteps=None, context=None, y=None, **kwargs):
        """
        Apply the model to an input batch.
        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param context: conditioning plugged in via crossattn
        :param y: an [N] Tensor of labels, if class-conditional.
        :return: a list of intermediate feature maps (the hierarchical outputs).
        """
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        hs = []
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb = self.time_embed(t_emb)

        if self.num_classes is not None:
            assert y.shape == (x.shape[0],)
            emb = emb + self.label_emb(y)

        h = x.type(self.dtype)
        for module in self.input_blocks:
            # when two context vectors are given, the first conditions the encoder
            if context.shape[1] == 2:
                h = module(h, emb, context[:, 0, :].unsqueeze(1))
            else:
                h = module(h, emb, context)
            hs.append(h)
        if context.shape[1] == 2:
            h = self.middle_block(h, emb, context[:, 0, :].unsqueeze(1))
        else:
            h = self.middle_block(h, emb, context)
        out_list = []

        for i_out, module in enumerate(self.output_blocks):
            h = torch.cat([h, hs.pop()], dim=1)
            # ...and the second context vector conditions the decoder
            if context.shape[1] == 2:
                h = module(h, emb, context[:, 1, :].unsqueeze(1))
            else:
                h = module(h, emb, context)
            if i_out in [1, 4, 7]:
                out_list.append(h)
        h = h.type(x.dtype)

        out_list.append(h)
        return out_list

    self.forward = forward
class UNetWrapper(nn.Module):
    def __init__(self, unet, use_attn=True, base_size=512, max_attn_size=None, attn_selector='up_cross+down_cross') -> None:
        super().__init__()
        self.unet = unet
        self.attention_store = AttentionStore(base_size=base_size // 8, max_size=max_attn_size)
        self.size16 = base_size // 32
        self.size32 = base_size // 16
        self.size64 = base_size // 8
        self.use_attn = use_attn
        if self.use_attn:
            register_attention_control(unet, self.attention_store)
        register_hier_output(unet)
        self.attn_selector = attn_selector.split('+')

    def forward(self, *args, **kwargs):
        if self.use_attn:
            self.attention_store.reset()
        out_list = self.unet(*args, **kwargs)
        if self.use_attn:
            avg_attn = self.attention_store.get_average_attention()
            attn16, attn32, attn64 = self.process_attn(avg_attn)
            out_list[1] = torch.cat([out_list[1], attn16], dim=1)
            out_list[2] = torch.cat([out_list[2], attn32], dim=1)
            if attn64 is not None:
                out_list[3] = torch.cat([out_list[3], attn64], dim=1)
        return out_list[::-1]

    def process_attn(self, avg_attn):
        attns = {self.size16: [], self.size32: [], self.size64: []}
        for k in self.attn_selector:
            for up_attn in avg_attn[k]:
                size = int(math.sqrt(up_attn.shape[1]))
                attns[size].append(rearrange(up_attn, 'b (h w) c -> b c h w', h=size))
        attn16 = torch.stack(attns[self.size16]).mean(0)
        attn32 = torch.stack(attns[self.size32]).mean(0)
        if len(attns[self.size64]) > 0:
            attn64 = torch.stack(attns[self.size64]).mean(0)
        else:
            attn64 = None
        return attn16, attn32, attn64
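A sketch of how UNetWrapper is meant to be driven: one denoising pass at a fixed timestep, whose hierarchical features plus attention maps become a multi-scale backbone pyramid. This assumes `sd_model` is a LatentDiffusion instance loaded via instantiate_from_config, as EVPRefer does in refer/models_refer/model.py below; shapes are illustrative.

# Sketch only; sd_model.model is the DiffusionWrapper around the UNet.
wrapper = UNetWrapper(sd_model.model, base_size=512)
latents = torch.randn(1, 4, 64, 64)            # VAE latents of a 512x512 image
t = torch.ones(1, dtype=torch.long)            # fixed timestep, as in EVPRefer.forward
ctx = torch.randn(1, 77, 768)                  # CLIP token embeddings (assumed shape)
feats = wrapper(latents, t, c_crossattn=[ctx]) # list of 4 feature maps, finest resolution first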
class TextAdapter(nn.Module):
    def __init__(self, text_dim=768, hidden_dim=None):
        super().__init__()
        if hidden_dim is None:
            hidden_dim = text_dim
        self.fc = nn.Sequential(
            nn.Linear(text_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, text_dim)
        )

    def forward(self, latents, texts, gamma):
        n_class, channel = texts.shape
        bs = latents.shape[0]

        texts_after = self.fc(texts)
        texts = texts + gamma * texts_after
        texts = repeat(texts, 'n c -> b n c', b=bs)
        return texts


class TextAdapterRefer(nn.Module):
    def __init__(self, text_dim=768):
        super().__init__()

        self.fc = nn.Sequential(
            nn.Linear(text_dim, text_dim),
            nn.GELU(),
            nn.Linear(text_dim, text_dim)
        )

    def forward(self, latents, texts, gamma):
        texts_after = self.fc(texts)
        texts = texts + gamma * texts_after
        return texts


class TextAdapterDepth(nn.Module):
    def __init__(self, text_dim=768):
        super().__init__()

        self.fc = nn.Sequential(
            nn.Linear(text_dim, text_dim),
            nn.GELU(),
            nn.Linear(text_dim, text_dim)
        )

    def forward(self, latents, texts, gamma):
        # use the learned gamma to blend the adapted embedding with the original
        n_sen, channel = texts.shape
        bs = latents.shape[0]

        texts_after = self.fc(texts)
        texts = texts + gamma * texts_after
        texts = repeat(texts, 'n c -> n b c', b=1)
        return texts
class FrozenCLIPEmbedder(nn.Module):
    """Uses the CLIP transformer encoder for text (from Hugging Face)"""
    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, pool=True):
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length
        self.freeze()

        self.pool = pool

    def freeze(self):
        self.transformer = self.transformer.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        tokens = batch_encoding["input_ids"].to(self.device)
        outputs = self.transformer(input_ids=tokens)

        if self.pool:
            z = outputs.pooler_output
        else:
            z = outputs.last_hidden_state
        return z

    def encode(self, text):
        return self(text)
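A quick sketch of the embedder's two output modes (this downloads the Hugging Face checkpoint on first use; running on CPU here is an assumption for illustration):

# pool=False yields one 768-d vector per token position; pool=True yields
# the single pooled sentence embedding.
embedder = FrozenCLIPEmbedder(device='cpu', pool=False)
z = embedder.encode(['a photo of a green plant'])
print(z.shape)   # torch.Size([1, 77, 768])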
refer/README.md
ADDED
@@ -0,0 +1,78 @@
# Referring Image Segmentation
## Getting Started

1. Install the required packages.

```
pip install -r requirements.txt
```

2. Prepare RefCOCO datasets following [LAVT](https://github.com/yz93/LAVT-RIS).

* Download COCO 2014 Train Images [83K/13GB] from [COCO](https://cocodataset.org/#download), and extract `train2014.zip` to `./refer/data/images/mscoco/images`

* Follow the instructions in `./refer` to download and extract `refclef.zip, refcoco.zip, refcoco+.zip, refcocog.zip` to `./refer/data`

Your dataset directory should be:

```
refer/
├──data/
│  ├── images/mscoco/images/
│  ├── refclef
│  ├── refcoco
│  ├── refcoco+
│  ├── refcocog
├──evaluation/
├──...
```

## Results and Fine-tuned Models of EVP
EVP achieves 76.35 overall IoU and 77.61 mean IoU on the validation set of RefCOCO.

## Training

We count the maximum length of the referring sentences and set the token length of the language model accordingly. The checkpoint of the best epoch will be saved at `./checkpoints/`.

* Train on RefCOCO

```
bash train.sh refcoco /path/to/logdir <NUM_GPUS> --token_length 40
```

* Train on RefCOCO+

```
bash train.sh refcoco+ /path/to/logdir <NUM_GPUS> --token_length 40
```

* Train on RefCOCOg

```
bash train.sh refcocog /path/to/logdir <NUM_GPUS> --token_length 77 --splitBy umd
```

## Evaluation

* Evaluate on RefCOCO

```
bash test.sh refcoco /path/to/evp_ris_refcoco.pth --token_length 40
```

* Evaluate on RefCOCO+

```
bash test.sh refcoco+ /path/to/evp_ris_refcoco+.pth --token_length 40
```

* Evaluate on RefCOCOg

```
bash test.sh refcocog /path/to/evp_ris_gref.pth --token_length 77 --splitBy umd
```

## Custom inference
```
PYTHONPATH="../":$PYTHONPATH python inference.py --img_path test_img.jpg --resume refcoco.pth --token_length 40 --prompt 'green plant'
```
refer/args.py
ADDED
@@ -0,0 +1,42 @@
import argparse


def get_parser():
    parser = argparse.ArgumentParser(description='EVP training and testing')
    parser.add_argument('--amsgrad', action='store_true',
                        help='if true, set amsgrad to True in an Adam or AdamW optimizer.')
    parser.add_argument('-b', '--batch-size', default=8, type=int)
    parser.add_argument('--ck_bert', default='bert-base-uncased', help='pre-trained BERT weights')
    parser.add_argument('--dataset', default='refcoco', help='refcoco, refcoco+, or refcocog')
    parser.add_argument('--ddp_trained_weights', action='store_true',
                        help='Only needs to be specified when testing: '
                             'whether the weights to be loaded are from a DDP-trained model')
    parser.add_argument('--device', default='cuda:0', help='device')  # only used when testing on a single machine
    parser.add_argument('--epochs', default=40, type=int, metavar='N', help='number of total epochs to run')
    parser.add_argument('--fusion_drop', default=0.0, type=float, help='dropout rate for PWAMs')
    parser.add_argument('--img_size', default=480, type=int, help='input image size')
    # both spellings are registered so that old and new torch.distributed launchers work
    parser.add_argument("--local_rank", type=int, default=0, help='local rank for DistributedDataParallel')
    parser.add_argument("--local-rank", type=int, default=0, help='local rank for DistributedDataParallel')
    parser.add_argument('--lr', default=0.00005, type=float, help='the initial learning rate')
    parser.add_argument('--model_id', default='evp', help='name to identify the model')
    parser.add_argument('--output-dir', default='./checkpoints/', help='path where to save checkpoint weights')
    parser.add_argument('--pin_mem', action='store_true',
                        help='If true, pin memory when using the data loader.')
    parser.add_argument('--pretrained_swin_weights', default='',
                        help='path to pre-trained Swin backbone weights')
    parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
    parser.add_argument('--refer_data_root', default='./refer/data/', help='REFER dataset root directory')
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--split', default='val')
    parser.add_argument('--splitBy', default='unc')
    parser.add_argument('--wd', '--weight-decay', default=1e-2, type=float, metavar='W', help='weight decay',
                        dest='weight_decay')
    parser.add_argument('-j', '--workers', default=8, type=int, metavar='N', help='number of data loading workers')
    parser.add_argument('--token_length', default=77, type=int)

    return parser


if __name__ == "__main__":
    parser = get_parser()
    args_dict = parser.parse_args()
refer/inference.py
ADDED
@@ -0,0 +1,60 @@
import os
import cv2
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from collections import OrderedDict
from models_refer.model import EVPRefer
from args import get_parser
import glob
import utils
import torchvision.transforms as transforms
from PIL import Image
import torch.nn.functional as F
from transformers import CLIPTokenizer


def main():
    parser = get_parser()
    parser.add_argument('--img_path', type=str)
    parser.add_argument('--prompt', type=str)
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    model = EVPRefer(sd_path='../checkpoints/v1-5-pruned-emaonly.ckpt')
    cudnn.benchmark = True
    model.to(device)
    model_weight = torch.load(args.resume)['model']
    if 'module' in next(iter(model_weight.items()))[0]:
        # strip the 'module.' prefix left behind by DistributedDataParallel
        model_weight = OrderedDict((k[7:], v) for k, v in model_weight.items())
    model.load_state_dict(model_weight, strict=False)
    model.eval()

    img_path = args.img_path

    image = cv2.imread(img_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image_t = transforms.ToTensor()(image).unsqueeze(0).to(device)
    image_t = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])(image_t)
    shape = image_t.shape
    image_t = torch.nn.functional.interpolate(image_t, (512, 512), mode='bilinear', align_corners=True)
    input_ids = tokenizer(text=args.prompt, truncation=True, max_length=args.token_length, return_length=True,
                          return_overflowing_tokens=False, padding="max_length", return_tensors="pt")['input_ids'].to(device)

    with torch.no_grad():
        pred = model(image_t, input_ids)

    pred = torch.nn.functional.interpolate(pred, shape[2:], mode='bilinear', align_corners=True)
    output_mask = pred.cpu().argmax(1).data.numpy().squeeze()

    alpha = 0.65
    image[output_mask == 0] = (image[output_mask == 0] * alpha).astype(np.uint8)
    contours, _ = cv2.findContours(output_mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cv2.drawContours(image, contours, -1, (0, 255, 0), 2)

    Image.fromarray(image.astype(np.uint8)).save('res.png')

    return 0


if __name__ == '__main__':
    main()
refer/models_refer/__init__.py
ADDED
@@ -0,0 +1 @@
from .model import EVPRefer
refer/models_refer/model.py
ADDED
@@ -0,0 +1,301 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

import sys
from ldm.util import instantiate_from_config
from transformers.models.clip.modeling_clip import CLIPTextModel
from omegaconf import OmegaConf
from lib.mask_predictor import SimpleDecoding

from evp.models import UNetWrapper, TextAdapterRefer


def icnr(x, scale=2, init=nn.init.kaiming_normal_):
    """
    Checkerboard artifact free sub-pixel convolution
    https://arxiv.org/abs/1707.02937
    """
    ni, nf, h, w = x.shape
    ni2 = int(ni / (scale ** 2))
    k = init(torch.zeros([ni2, nf, h, w])).transpose(0, 1)
    k = k.contiguous().view(ni2, nf, -1)
    k = k.repeat(1, 1, scale ** 2)
    k = k.contiguous().view([nf, ni, h, w]).transpose(0, 1)
    x.data.copy_(k)


class PixelShuffle(nn.Module):
    """
    Real-Time Single Image and Video Super-Resolution
    https://arxiv.org/abs/1609.05158
    """
    def __init__(self, n_channels, scale):
        super(PixelShuffle, self).__init__()
        self.conv = nn.Conv2d(n_channels, n_channels * (scale ** 2), kernel_size=1)
        icnr(self.conv.weight)
        self.shuf = nn.PixelShuffle(scale)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.shuf(self.relu(self.conv(x)))
        return x


class AttentionModule(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(AttentionModule, self).__init__()

        # Convolutional Layers
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        # Group Normalization
        self.group_norm = nn.GroupNorm(20, out_channels)

        # ReLU Activation
        self.relu = nn.ReLU()

        # Spatial Attention
        self.spatial_attention = nn.Sequential(
            nn.Conv2d(in_channels, 1, kernel_size=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        # Apply spatial attention
        spatial_attention = self.spatial_attention(x)
        x = x * spatial_attention

        # Apply convolutional layer
        x = self.conv1(x)
        x = self.group_norm(x)
        x = self.relu(x)

        return x


class AttentionDownsamplingModule(nn.Module):
    def __init__(self, in_channels, out_channels, scale_factor=2):
        super(AttentionDownsamplingModule, self).__init__()

        # Spatial Attention
        self.spatial_attention = nn.Sequential(
            nn.Conv2d(in_channels, 1, kernel_size=1),
            nn.Sigmoid()
        )

        # Channel Attention
        self.channel_attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, in_channels // 8, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels // 8, in_channels, kernel_size=1),
            nn.Sigmoid()
        )

        # Convolutional Layers
        if scale_factor == 2:
            self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        elif scale_factor == 4:
            self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)

        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1)

        # Group Normalization
        self.group_norm = nn.GroupNorm(20, out_channels)

        # ReLU Activation
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # Apply spatial attention
        spatial_attention = self.spatial_attention(x)
        x = x * spatial_attention

        # Apply channel attention
        channel_attention = self.channel_attention(x)
        x = x * channel_attention

        # Apply convolutional layers
        x = self.conv1(x)
        x = self.group_norm(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.group_norm(x)
        x = self.relu(x)

        return x


class AttentionUpsamplingModule(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(AttentionUpsamplingModule, self).__init__()

        # Spatial Attention for outs[2]
        self.spatial_attention = nn.Sequential(
            nn.Conv2d(in_channels, 1, kernel_size=1),
            nn.Sigmoid()
        )

        # Channel Attention for outs[2]
        self.channel_attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, in_channels // 8, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(in_channels // 8, in_channels, kernel_size=1),
            nn.Sigmoid()
        )

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        # Group Normalization
        self.group_norm = nn.GroupNorm(20, out_channels)

        # ReLU Activation
        self.relu = nn.ReLU()
        self.upscale = PixelShuffle(in_channels, 2)

    def forward(self, x):
        # Apply spatial attention
        spatial_attention = self.spatial_attention(x)
        x = x * spatial_attention

        # Apply channel attention
        channel_attention = self.channel_attention(x)
        x = x * channel_attention

        # Apply convolutional layers
        x = self.conv1(x)
        x = self.group_norm(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.group_norm(x)
        x = self.relu(x)

        # Upsample
        x = self.upscale(x)

        return x


class ConvLayer(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ConvLayer, self).__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1),
            nn.GroupNorm(20, out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        x = self.conv1(x)

        return x


class InverseMultiAttentiveFeatureRefinement(nn.Module):
    def __init__(self, in_channels_list):
        super(InverseMultiAttentiveFeatureRefinement, self).__init__()

        self.layer1 = AttentionModule(in_channels_list[0], in_channels_list[0])
        self.layer2 = AttentionDownsamplingModule(in_channels_list[0], in_channels_list[0] // 2, scale_factor=2)
        self.layer3 = ConvLayer(in_channels_list[0] // 2 + in_channels_list[1], in_channels_list[1])
        self.layer4 = AttentionDownsamplingModule(in_channels_list[1], in_channels_list[1] // 2, scale_factor=2)
        self.layer5 = ConvLayer(in_channels_list[1] // 2 + in_channels_list[2], in_channels_list[2])
        self.layer6 = AttentionDownsamplingModule(in_channels_list[2], in_channels_list[2] // 2, scale_factor=2)
        self.layer7 = ConvLayer(in_channels_list[2] // 2 + in_channels_list[3], in_channels_list[3])

        '''
        self.layer8 = AttentionUpsamplingModule(in_channels_list[3], in_channels_list[3])
        self.layer9 = ConvLayer(in_channels_list[2] + in_channels_list[3], in_channels_list[2])
        self.layer10 = AttentionUpsamplingModule(in_channels_list[2], in_channels_list[2])
        self.layer11 = ConvLayer(in_channels_list[1] + in_channels_list[2], in_channels_list[1])
        self.layer12 = AttentionUpsamplingModule(in_channels_list[1], in_channels_list[1])
        self.layer13 = ConvLayer(in_channels_list[0] + in_channels_list[1], in_channels_list[0])
        '''

    def forward(self, inputs):
        x_c4, x_c3, x_c2, x_c1 = inputs
        x_c4 = self.layer1(x_c4)
        x_c4_3 = self.layer2(x_c4)
        x_c3 = torch.cat([x_c4_3, x_c3], dim=1)
        x_c3 = self.layer3(x_c3)
        x_c3_2 = self.layer4(x_c3)
        x_c2 = torch.cat([x_c3_2, x_c2], dim=1)
        x_c2 = self.layer5(x_c2)
        x_c2_1 = self.layer6(x_c2)
        x_c1 = torch.cat([x_c2_1, x_c1], dim=1)
        x_c1 = self.layer7(x_c1)
        '''
        x_c1_2 = self.layer8(x_c1)
        x_c2 = torch.cat([x_c1_2, x_c2], dim=1)
        x_c2 = self.layer9(x_c2)
        x_c2_3 = self.layer10(x_c2)
        x_c3 = torch.cat([x_c2_3, x_c3], dim=1)
        x_c3 = self.layer11(x_c3)
        x_c3_4 = self.layer12(x_c3)
        x_c4 = torch.cat([x_c3_4, x_c4], dim=1)
        x_c4 = self.layer13(x_c4)
        '''
        return [x_c4, x_c3, x_c2, x_c1]


class EVPRefer(nn.Module):
    """Encoder Decoder segmentors.

    EncoderDecoder typically consists of backbone, decode_head, auxiliary_head.
    Note that auxiliary_head is only used for deep supervision during training,
    which could be dumped during inference.
    """

    def __init__(self,
                 sd_path=None,
                 base_size=512,
                 token_embed_dim=768,
                 neck_dim=[320, 680, 1320, 1280],
                 **args):
        super().__init__()
        config = OmegaConf.load('./v1-inference.yaml')
        config.model.params.ckpt_path = f'{sd_path}'
        sd_model = instantiate_from_config(config.model)
        self.encoder_vq = sd_model.first_stage_model
        self.unet = UNetWrapper(sd_model.model, base_size=base_size)
        del sd_model.cond_stage_model
        del self.encoder_vq.decoder
        for param in self.encoder_vq.parameters():
            param.requires_grad = True

        self.text_adapter = TextAdapterRefer(text_dim=token_embed_dim)

        self.classifier = SimpleDecoding(dims=neck_dim)

        self.gamma = nn.Parameter(torch.ones(token_embed_dim) * 1e-4)
        self.aggregation = InverseMultiAttentiveFeatureRefinement([320, 680, 1320, 1280])
        self.clip_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
        for param in self.clip_model.parameters():
            param.requires_grad = True

    def forward(self, img, sentences):
        input_shape = img.shape[-2:]

        latents = self.encoder_vq.encode(img).mode()
        latents = latents / 4.7164

        l_feats = self.clip_model(input_ids=sentences).last_hidden_state
        c_crossattn = self.text_adapter(latents, l_feats, self.gamma)  # NOTE: here the c_crossattn should be expand_dim as latents
        t = torch.ones((img.shape[0],), device=img.device).long()
        outs = self.unet(latents, t, c_crossattn=[c_crossattn])

        outs = self.aggregation(outs)

        x_c1, x_c2, x_c3, x_c4 = outs
        x = self.classifier(x_c4, x_c3, x_c2, x_c1)
        x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=True)

        return x

    def get_latent(self, x):
        return self.encoder_vq.encode(x).mode()
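A sketch of a full EVPRefer forward pass, assuming the Stable Diffusion v1-5 checkpoint and the `ldm`/`lib` dependencies are on the path (the real entry point is refer/inference.py above; the prompt and shapes here are illustrative):

import torch
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
model = EVPRefer(sd_path='../checkpoints/v1-5-pruned-emaonly.ckpt').eval()
img = torch.randn(1, 3, 512, 512)                        # normalized RGB batch
ids = tokenizer('the dog on the left', padding='max_length', max_length=40,
                truncation=True, return_tensors='pt')['input_ids']
with torch.no_grad():
    logits = model(img, ids)                             # per-pixel class logits at input resolution
mask = logits.argmax(1)                                  # binary referring mask, as in inference.py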
refer/requirements.txt
ADDED
@@ -0,0 +1,12 @@
requests
filelock
tqdm
timm
ftfy
regex
scipy
scikit-image
pycocotools==2.0.2
opencv-python==4.5.3.56
tokenizers
h5py
refer/transforms.py
ADDED
@@ -0,0 +1,126 @@
import numpy as np
from PIL import Image
import random

import torch
from torchvision import transforms as T
from torchvision.transforms import functional as F

import warnings
warnings.filterwarnings("ignore")


def pad_if_smaller(img, size, fill=0):
    min_size = min(img.size)
    if min_size < size:
        ow, oh = img.size
        padh = size - oh if oh < size else 0
        padw = size - ow if ow < size else 0
        img = F.pad(img, (0, 0, padw, padh), fill=fill)
    return img


class Compose(object):
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        for t in self.transforms:
            image, target = t(image, target)
        return image, target


class Resize(object):
    def __init__(self, h, w):
        self.h = h
        self.w = w

    def __call__(self, image, target):
        image = F.resize(image, (self.h, self.w))
        # If size is a sequence like (h, w), the output size will be matched to this.
        # If size is an int, the smaller edge of the image will be matched to this number maintaining the aspect ratio
        target = F.resize(target, (self.h, self.w))
        return image, target


class RandomResize(object):
    def __init__(self, min_size, max_size=None):
        self.min_size = min_size
        if max_size is None:
            max_size = min_size
        self.max_size = max_size

    def __call__(self, image, target):
        size = random.randint(self.min_size, self.max_size)  # Return a random integer N such that a <= N <= b. Alias for randrange(a, b+1)
        image = F.resize(image, size)
        # If size is a sequence like (h, w), the output size will be matched to this.
        # If size is an int, the smaller edge of the image will be matched to this number maintaining the aspect ratio
        target = F.resize(target, size)
        return image, target


class RandomHorizontalFlip(object):
    def __init__(self, flip_prob):
        self.flip_prob = flip_prob

    def __call__(self, image, target):
        if random.random() < self.flip_prob:
            image = F.hflip(image)
            target = F.hflip(target)
        return image, target


class RandomCrop(object):
    def __init__(self, size):
        self.size = size

    def __call__(self, image, target):
        image = pad_if_smaller(image, self.size)
        target = pad_if_smaller(target, self.size, fill=255)
        crop_params = T.RandomCrop.get_params(image, (self.size, self.size))
        image = F.crop(image, *crop_params)
        target = F.crop(target, *crop_params)
        return image, target


class CenterCrop(object):
    def __init__(self, size):
        self.size = size

    def __call__(self, image, target):
        image = F.center_crop(image, self.size)
        target = F.center_crop(target, self.size)
        return image, target


class ToTensor(object):
    def __call__(self, image, target):
        image = F.to_tensor(image)
        target = torch.as_tensor(np.asarray(target).copy(), dtype=torch.int64)
        return image, target


class RandomAffine(object):
    def __init__(self, angle, translate, scale, shear, resample=0, fillcolor=None):
        self.angle = angle
        self.translate = translate
        self.scale = scale
        self.shear = shear
        self.resample = resample
        self.fillcolor = fillcolor

    def __call__(self, image, target):
        affine_params = T.RandomAffine.get_params(self.angle, self.translate, self.scale, self.shear, image.size)
        image = F.affine(image, *affine_params)
        target = F.affine(target, *affine_params)
        return image, target


class Normalize(object):
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, image, target):
        image = F.normalize(image, mean=self.mean, std=self.std)
        return image, target
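These are paired transforms: each __call__ takes (image, target) so the mask receives exactly the same geometry as the image. A minimal sketch of a pipeline (the normalization statistics below are illustrative ImageNet values, not necessarily what the training scripts use):

from PIL import Image

trans = Compose([
    Resize(480, 480),
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
img = Image.new('RGB', (640, 360))
mask = Image.new('L', (640, 360))
img_t, mask_t = trans(img, mask)
print(img_t.shape, mask_t.shape)   # torch.Size([3, 480, 480]) torch.Size([480, 480])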
refer/utils.py
ADDED
@@ -0,0 +1,222 @@
from __future__ import print_function
from collections import defaultdict, deque
import datetime
import math
import time
import torch
import torch.distributed as dist
import torch.backends.cudnn as cudnn

import errno
import os

import sys


class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)


class MetricLogger(object):
    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
        log_msg = self.delimiter.join([
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}',
            'max mem: {memory:.0f}'
        ])
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                print(log_msg.format(
                    i, len(iterable), eta=eta_string,
                    meters=str(self),
                    time=str(iter_time), data=str(data_time),
                    memory=torch.cuda.max_memory_allocated() / MB))
                sys.stdout.flush()

            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('{} Total time: {}'.format(header, total_time_str))


def mkdir(path):
    try:
        os.makedirs(path)
    except OSError as e:
        if e.errno != errno.EEXIST:
            raise


def setup_for_distributed(is_master):
    """
    This function disables printing when not in master process
    """
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def save_on_master(*args, **kwargs):
    if is_main_process():
        torch.save(*args, **kwargs)


def init_distributed_mode(args):
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        rank = int(os.environ["RANK"])
        world_size = int(os.environ['WORLD_SIZE'])
        print(f"RANK and WORLD_SIZE in environment: {rank}/{world_size}")
    else:
        rank = -1
        world_size = -1

    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank)
    torch.distributed.barrier()
    setup_for_distributed(is_main_process())

    if args.output_dir:
        mkdir(args.output_dir)
    if args.model_id:
        mkdir(os.path.join('./models/', args.model_id))
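A minimal sketch of the CPU-safe part of these logging utilities (log_every reports CUDA memory, so it is skipped here; the loop and values are illustrative):

logger = MetricLogger(delimiter='  ')
for step in range(5):
    logger.update(loss=1.0 / (step + 1), lr=5e-5)
print(logger)           # each meter prints "median (global_avg)" per SmoothedValue.fmt
print(logger.loss.max)  # meters are reachable as attributes via __getattr__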
refer/v1-inference.yaml
ADDED
@@ -0,0 +1,70 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false   # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder