Training Code: cls/det

- training/.gitignore +250 -0
- training/Detection/README.md +47 -0
- training/Detection/configs/_base_/models/cascade_mask_rcnn_revcol_fpn.py +209 -0
- training/Detection/configs/revcol/cascade_mask_rcnn_revcol_base_3x_in1k.py +152 -0
- training/Detection/configs/revcol/cascade_mask_rcnn_revcol_base_3x_in22k.py +152 -0
- training/Detection/configs/revcol/cascade_mask_rcnn_revcol_large_3x_in22k.py +152 -0
- training/Detection/configs/revcol/cascade_mask_rcnn_revcol_small_3x_in1k.py +152 -0
- training/Detection/configs/revcol/cascade_mask_rcnn_revcol_tiny_3x_in1k.py +152 -0
- training/Detection/mmcv_custom/__init__.py +15 -0
- training/Detection/mmcv_custom/checkpoint.py +484 -0
- training/Detection/mmcv_custom/customized_text.py +130 -0
- training/Detection/mmcv_custom/layer_decay_optimizer_constructor.py +121 -0
- training/Detection/mmcv_custom/runner/checkpoint.py +85 -0
- training/Detection/mmdet/models/backbones/__init__.py +28 -0
- training/Detection/mmdet/models/backbones/revcol.py +187 -0
- training/Detection/mmdet/models/backbones/revcol_function.py +222 -0
- training/Detection/mmdet/models/backbones/revcol_module.py +85 -0
- training/Detection/mmdet/utils/__init__.py +12 -0
- training/Detection/mmdet/utils/optimizer.py +33 -0
- training/INSTRUCTIONS.md +158 -0
- training/LICENSE +190 -0
- training/README.md +79 -0
- training/config.py +243 -0
- training/configs/revcol_base_1k.yaml +48 -0
- training/configs/revcol_base_1k_224_finetune.yaml +50 -0
- training/configs/revcol_base_1k_384_finetune.yaml +50 -0
- training/configs/revcol_base_22k_pretrain.yaml +51 -0
- training/configs/revcol_large_1k_224_finetune.yaml +51 -0
- training/configs/revcol_large_1k_384_finetune.yaml +51 -0
- training/configs/revcol_large_22k_pretrain.yaml +50 -0
- training/configs/revcol_small_1k.yaml +48 -0
- training/configs/revcol_tiny_1k.yaml +48 -0
- training/configs/revcol_xlarge_1k_384_finetune.yaml +53 -0
- training/configs/revcol_xlarge_22k_pretrain.yaml +50 -0
- training/data/__init__.py +1 -0
- training/data/build_data.py +137 -0
- training/data/samplers.py +29 -0
- training/figures/title.png +0 -0
- training/logger.py +41 -0
- training/loss.py +35 -0
- training/lr_scheduler.py +96 -0
- training/main.py +422 -0
- training/models/__init__.py +1 -0
- training/models/build.py +48 -0
- training/models/modules.py +157 -0
- training/models/revcol.py +242 -0
- training/models/revcol_function.py +159 -0
- training/optimizer.py +145 -0
- training/requirements.txt +7 -0
- training/utils.py +179 -0
training/.gitignore
ADDED
@@ -0,0 +1,250 @@
```
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/en/_build/
docs/zh_cn/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

data/
data
.vscode
.idea
.DS_Store

# custom
*.pkl
*.pkl.json
*.log.json
docs/modelzoo_statistics.md
mmdet/.mim
work_dirs/

# Pytorch
*.pth
*.py~
*.sh~
.DS_
```
training/Detection/README.md
ADDED
@@ -0,0 +1,47 @@
# COCO Object detection with RevCol

## Getting started

We build the RevCol object detection model on top of [mmdetection](https://github.com/open-mmlab/mmdetection/tree/3e2693151add9b5d6db99b944da020cba837266b) commit `3e26931`, adding the RevCol model and config files to [the original repo](https://github.com/open-mmlab/mmdetection/tree/3e2693151add9b5d6db99b944da020cba837266b). Please refer to [get_started.md](https://github.com/open-mmlab/mmdetection/blob/3e2693151add9b5d6db99b944da020cba837266b/docs/en/get_started.md) for installation and dataset preparation instructions.

## Results and Fine-tuned Models

| name | Pretrained Model | Method | Lr Schd | box mAP | mask mAP | #params | FLOPs | Fine-tuned Model |
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
| RevCol-T | [ImageNet-1K]() | Cascade Mask R-CNN | 3x | 50.6 | 43.8 | 88M | 741G | [model]() |
| RevCol-S | [ImageNet-1K]() | Cascade Mask R-CNN | 3x | 52.6 | 45.5 | 118M | 833G | [model]() |
| RevCol-B | [ImageNet-1K]() | Cascade Mask R-CNN | 3x | 53.0 | 45.9 | 196M | 988G | [model]() |
| RevCol-B | [ImageNet-22K]() | Cascade Mask R-CNN | 3x | 55.0 | 47.5 | 196M | 988G | [model]() |
| RevCol-L | [ImageNet-22K]() | Cascade Mask R-CNN | 3x | 55.9 | 48.4 | 330M | 1453G | [model]() |

## Training

To train a detector with pre-trained models, run:
```
# single-gpu training
python tools/train.py <CONFIG_FILE> --cfg-options model.pretrained=<PRETRAIN_MODEL> [other optional arguments]

# multi-gpu training
tools/dist_train.sh <CONFIG_FILE> <GPU_NUM> --cfg-options model.pretrained=<PRETRAIN_MODEL> [other optional arguments]
```
For example, to train a Cascade Mask R-CNN model with a `RevCol-T` backbone on 8 GPUs, run:
```
tools/dist_train.sh configs/revcol/cascade_mask_rcnn_revcol_tiny_3x_in1k.py 8 --cfg-options pretrained=<PRETRAIN_MODEL>
```

More config files can be found at [`configs/revcol`](configs/revcol).

## Inference

```
# single-gpu testing
python tools/test.py <CONFIG_FILE> <DET_CHECKPOINT_FILE> --eval bbox segm

# multi-gpu testing
tools/dist_test.sh <CONFIG_FILE> <DET_CHECKPOINT_FILE> <GPU_NUM> --eval bbox segm
```

## Acknowledgment

This code is built on the [mmdetection](https://github.com/open-mmlab/mmdetection) and [timm](https://github.com/rwightman/pytorch-image-models) libraries and the [BeiT](https://github.com/microsoft/unilm/tree/f8f3df80c65eb5e5fc6d6d3c9bd3137621795d1e/beit) and [Swin Transformer](https://github.com/microsoft/Swin-Transformer) repositories.
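For orientation, the backbone the configs below reference as `type='RevCol'` is registered with mmdetection's `BACKBONES` registry by the files this commit adds under `training/Detection/mmdet/models/backbones/`. The following is a minimal sketch of that integration point only; the class body is a hypothetical stand-in (simple strided convolutions), not the repo's actual `revcol.py`, and the constructor arguments simply mirror the config keys used in `configs/revcol/*.py`:

```python
# Hypothetical sketch of how a custom backbone plugs into mmdetection 2.x;
# the real reversible-column implementation lives in revcol.py.
import torch.nn as nn
from mmcv.runner import BaseModule
from mmdet.models.builder import BACKBONES


@BACKBONES.register_module()
class RevColSketch(BaseModule):  # the real registered name is 'RevCol'
    def __init__(self, channels, layers, num_subnet, drop_path,
                 save_memory, out_indices, init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        # layers/num_subnet/drop_path/save_memory are accepted for config
        # compatibility; this sketch does not use them.
        self.out_indices = out_indices
        # overall strides 4/8/16/32, one feature map per pyramid level
        self.stem = nn.Conv2d(3, channels[0], kernel_size=4, stride=4)
        self.stages = nn.ModuleList(
            nn.Conv2d(channels[i], channels[i + 1], kernel_size=2, stride=2)
            for i in range(len(channels) - 1))

    def forward(self, x):
        feats = [self.stem(x)]
        for stage in self.stages:
            feats.append(stage(feats[-1]))
        # the FPN neck consumes a tuple of maps, one per out_indices entry,
        # whose widths must match the neck's in_channels
        return tuple(feats[i] for i in self.out_indices)
```

The key contract is the last line: `forward` returns one feature map per entry in `out_indices`, with channel widths matching the `in_channels` declared on the FPN neck in the configs.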
training/Detection/configs/_base_/models/cascade_mask_rcnn_revcol_fpn.py
ADDED
@@ -0,0 +1,209 @@
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# model settings
pretrained = None
model = dict(
    type='CascadeRCNN',
    backbone=dict(
        type='RevCol',
        channels=[48, 96, 192, 384],
        layers=[3, 3, 9, 3],
        num_subnet=4,
        drop_path=0.2,
        save_memory=False,
        out_indices=[0, 1, 2, 3],
        init_cfg=dict(type='Pretrained', checkpoint=pretrained)),
    neck=dict(
        type='FPN',
        in_channels=[128, 256, 512, 1024],
        out_channels=256,
        num_outs=5),
    rpn_head=dict(
        type='RPNHead',
        in_channels=256,
        feat_channels=256,
        anchor_generator=dict(
            type='AnchorGenerator',
            scales=[8],
            ratios=[0.5, 1.0, 2.0],
            strides=[4, 8, 16, 32, 64]),
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[.0, .0, .0, .0],
            target_stds=[1.0, 1.0, 1.0, 1.0]),
        loss_cls=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
        loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
    roi_head=dict(
        type='CascadeRoIHead',
        num_stages=3,
        stage_loss_weights=[1, 0.5, 0.25],
        bbox_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]),
        bbox_head=[
            dict(
                type='Shared2FCBBoxHead',
                in_channels=256,
                fc_out_channels=1024,
                roi_feat_size=7,
                num_classes=80,
                bbox_coder=dict(
                    type='DeltaXYWHBBoxCoder',
                    target_means=[0., 0., 0., 0.],
                    target_stds=[0.1, 0.1, 0.2, 0.2]),
                reg_class_agnostic=True,
                loss_cls=dict(
                    type='CrossEntropyLoss',
                    use_sigmoid=False,
                    loss_weight=1.0),
                loss_bbox=dict(type='SmoothL1Loss', beta=1.0,
                               loss_weight=1.0)),
            dict(
                type='Shared2FCBBoxHead',
                in_channels=256,
                fc_out_channels=1024,
                roi_feat_size=7,
                num_classes=80,
                bbox_coder=dict(
                    type='DeltaXYWHBBoxCoder',
                    target_means=[0., 0., 0., 0.],
                    target_stds=[0.05, 0.05, 0.1, 0.1]),
                reg_class_agnostic=True,
                loss_cls=dict(
                    type='CrossEntropyLoss',
                    use_sigmoid=False,
                    loss_weight=1.0),
                loss_bbox=dict(type='SmoothL1Loss', beta=1.0,
                               loss_weight=1.0)),
            dict(
                type='Shared2FCBBoxHead',
                in_channels=256,
                fc_out_channels=1024,
                roi_feat_size=7,
                num_classes=80,
                bbox_coder=dict(
                    type='DeltaXYWHBBoxCoder',
                    target_means=[0., 0., 0., 0.],
                    target_stds=[0.033, 0.033, 0.067, 0.067]),
                reg_class_agnostic=True,
                loss_cls=dict(
                    type='CrossEntropyLoss',
                    use_sigmoid=False,
                    loss_weight=1.0),
                loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0))
        ],
        mask_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]),
        mask_head=dict(
            type='FCNMaskHead',
            num_convs=4,
            in_channels=256,
            conv_out_channels=256,
            num_classes=80,
            loss_mask=dict(
                type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))),
    # model training and testing settings
    train_cfg=dict(
        rpn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.7,
                neg_iou_thr=0.3,
                min_pos_iou=0.3,
                match_low_quality=True,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=256,
                pos_fraction=0.5,
                neg_pos_ub=-1,
                add_gt_as_proposals=False),
            allowed_border=0,
            pos_weight=-1,
            debug=False),
        rpn_proposal=dict(
            nms_across_levels=False,
            nms_pre=2000,
            nms_post=2000,
            max_per_img=2000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=[
            dict(
                assigner=dict(
                    type='MaxIoUAssigner',
                    pos_iou_thr=0.5,
                    neg_iou_thr=0.5,
                    min_pos_iou=0.5,
                    match_low_quality=False,
                    ignore_iof_thr=-1),
                sampler=dict(
                    type='RandomSampler',
                    num=512,
                    pos_fraction=0.25,
                    neg_pos_ub=-1,
                    add_gt_as_proposals=True),
                mask_size=28,
                pos_weight=-1,
                debug=False),
            dict(
                assigner=dict(
                    type='MaxIoUAssigner',
                    pos_iou_thr=0.6,
                    neg_iou_thr=0.6,
                    min_pos_iou=0.6,
                    match_low_quality=False,
                    ignore_iof_thr=-1),
                sampler=dict(
                    type='RandomSampler',
                    num=512,
                    pos_fraction=0.25,
                    neg_pos_ub=-1,
                    add_gt_as_proposals=True),
                mask_size=28,
                pos_weight=-1,
                debug=False),
            dict(
                assigner=dict(
                    type='MaxIoUAssigner',
                    pos_iou_thr=0.7,
                    neg_iou_thr=0.7,
                    min_pos_iou=0.7,
                    match_low_quality=False,
                    ignore_iof_thr=-1),
                sampler=dict(
                    type='RandomSampler',
                    num=512,
                    pos_fraction=0.25,
                    neg_pos_ub=-1,
                    add_gt_as_proposals=True),
                mask_size=28,
                pos_weight=-1,
                debug=False)
        ]),
    test_cfg=dict(
        rpn=dict(
            nms_across_levels=False,
            nms_pre=1000,
            nms_post=1000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            score_thr=0.05,
            nms=dict(type='nms', iou_threshold=0.5),
            max_per_img=100,
            mask_thr_binary=0.5)))
```
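The variant configs that follow inherit this file through mmcv's `_base_` mechanism and override only the keys that differ (backbone widths, neck `in_channels`, drop path, optimizer settings). A quick way to inspect what a merged config resolves to, assuming you run from `training/Detection/`:

```python
# Inspect a merged mmcv config: the child file overrides base keys in place.
from mmcv import Config

cfg = Config.fromfile('configs/revcol/cascade_mask_rcnn_revcol_tiny_3x_in1k.py')
print(cfg.model.backbone.channels)  # [64, 128, 256, 512], set by the child
print(cfg.model.rpn_head.type)      # 'RPNHead', inherited from this base file
```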
training/Detection/configs/revcol/cascade_mask_rcnn_revcol_base_3x_in1k.py
ADDED
@@ -0,0 +1,152 @@
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

_base_ = [
    '../_base_/models/cascade_mask_rcnn_revcol_fpn.py',
    '../_base_/datasets/coco_instance.py',
    '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
pretrained = './cls_model/revcol_base_1k.pth'
model = dict(
    backbone=dict(
        channels=[72, 144, 288, 576],
        layers=[1, 1, 3, 2],
        num_subnet=16,
        drop_path=0.4,
        save_memory=False,
        out_indices=[0, 1, 2, 3],
        init_cfg=dict(type='Pretrained', checkpoint=pretrained)),
    neck=dict(in_channels=[72, 144, 288, 576]),
    roi_head=dict(
        bbox_head=[
            dict(
                type='ConvFCBBoxHead',
                num_shared_convs=4,
                num_shared_fcs=1,
                in_channels=256,
                conv_out_channels=256,
                fc_out_channels=1024,
                roi_feat_size=7,
                num_classes=80,
                bbox_coder=dict(
                    type='DeltaXYWHBBoxCoder',
                    target_means=[0., 0., 0., 0.],
                    target_stds=[0.1, 0.1, 0.2, 0.2]),
                reg_class_agnostic=False,
                reg_decoded_bbox=True,
                norm_cfg=dict(type='SyncBN', requires_grad=True),
                loss_cls=dict(
                    type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
                loss_bbox=dict(type='GIoULoss', loss_weight=10.0)),
            dict(
                type='ConvFCBBoxHead',
                num_shared_convs=4,
                num_shared_fcs=1,
                in_channels=256,
                conv_out_channels=256,
                fc_out_channels=1024,
                roi_feat_size=7,
                num_classes=80,
                bbox_coder=dict(
                    type='DeltaXYWHBBoxCoder',
                    target_means=[0., 0., 0., 0.],
                    target_stds=[0.05, 0.05, 0.1, 0.1]),
                reg_class_agnostic=False,
                reg_decoded_bbox=True,
                norm_cfg=dict(type='SyncBN', requires_grad=True),
                loss_cls=dict(
                    type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
                loss_bbox=dict(type='GIoULoss', loss_weight=10.0)),
            dict(
                type='ConvFCBBoxHead',
                num_shared_convs=4,
                num_shared_fcs=1,
                in_channels=256,
                conv_out_channels=256,
                fc_out_channels=1024,
                roi_feat_size=7,
                num_classes=80,
                bbox_coder=dict(
                    type='DeltaXYWHBBoxCoder',
                    target_means=[0., 0., 0., 0.],
                    target_stds=[0.033, 0.033, 0.067, 0.067]),
                reg_class_agnostic=False,
                reg_decoded_bbox=True,
                norm_cfg=dict(type='SyncBN', requires_grad=True),
                loss_cls=dict(
                    type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
                loss_bbox=dict(type='GIoULoss', loss_weight=10.0))
        ]))

img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)

# augmentation strategy originates from DETR / Sparse RCNN
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='AutoAugment',
         policies=[
             [
                 dict(type='Resize',
                      img_scale=[(480, 1333), (512, 1333), (544, 1333),
                                 (576, 1333), (608, 1333), (640, 1333),
                                 (672, 1333), (704, 1333), (736, 1333),
                                 (768, 1333), (800, 1333)],
                      multiscale_mode='value',
                      keep_ratio=True)
             ],
             [
                 dict(type='Resize',
                      img_scale=[(400, 1333), (500, 1333), (600, 1333)],
                      multiscale_mode='value',
                      keep_ratio=True),
                 dict(type='RandomCrop',
                      crop_type='absolute_range',
                      crop_size=(384, 600),
                      allow_negative_crop=True),
                 dict(type='Resize',
                      img_scale=[(480, 1333), (512, 1333), (544, 1333),
                                 (576, 1333), (608, 1333), (640, 1333),
                                 (672, 1333), (704, 1333), (736, 1333),
                                 (768, 1333), (800, 1333)],
                      multiscale_mode='value',
                      override=True,
                      keep_ratio=True)
             ]
         ]),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
data = dict(train=dict(pipeline=train_pipeline))

optimizer = dict(constructor='LearningRateDecayOptimizerConstructor',
                 _delete_=True, type='AdamW',
                 lr=0.0002, betas=(0.9, 0.999), weight_decay=0.05,
                 paramwise_cfg={'decay_rate': 0.9,
                                'decay_type': 'layer_wise',
                                'layers': [1, 1, 3, 2],
                                'num_subnet': 16})

lr_config = dict(step=[27, 33])
runner = dict(type='EpochBasedRunner', max_epochs=36)

# fp16 = None
# optimizer_config = dict(
#     type="DistOptimizerHook",
#     update_interval=1,
#     grad_clip=None,
#     coalesce=True,
#     bucket_size_mb=-1,
#     use_fp16=True,
# )
fp16 = dict(loss_scale='dynamic')
```
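The `paramwise_cfg` above drives the custom `LearningRateDecayOptimizerConstructor` added in `mmcv_custom/layer_decay_optimizer_constructor.py`. The usual layer-wise scheme scales each depth bucket's learning rate geometrically toward the input; here is a minimal sketch under the assumption that the constructor follows the common BeiT/ConvNeXt recipe (the exact mapping of RevCol sub-networks to buckets lives in the constructor itself and may differ):

```python
# Hypothetical illustration of layer-wise lr decay with the values above.
def lr_scale(layer_id: int, num_layers: int, decay_rate: float) -> float:
    # deeper buckets (larger layer_id) keep a larger fraction of the base lr
    return decay_rate ** (num_layers - 1 - layer_id)

base_lr, decay_rate = 0.0002, 0.9
num_layers = sum([1, 1, 3, 2]) + 2  # assumption: block buckets plus stem/head
for layer_id in range(num_layers):
    scaled = base_lr * lr_scale(layer_id, num_layers, decay_rate)
    print(f'layer {layer_id}: lr = {scaled:.2e}')
```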
training/Detection/configs/revcol/cascade_mask_rcnn_revcol_base_3x_in22k.py
ADDED
@@ -0,0 +1,152 @@
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

_base_ = [
    '../_base_/models/cascade_mask_rcnn_revcol_fpn.py',
    '../_base_/datasets/coco_instance.py',
    '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
pretrained = './cls_model/revcol_base_22k.pth'
model = dict(
    backbone=dict(
        channels=[72, 144, 288, 576],
        layers=[1, 1, 3, 2],
        num_subnet=16,
        drop_path=0.4,
        save_memory=False,
        out_indices=[0, 1, 2, 3],
        init_cfg=dict(type='Pretrained', checkpoint=pretrained)),
    neck=dict(in_channels=[72, 144, 288, 576]),
    roi_head=dict(
        bbox_head=[
            dict(
                type='ConvFCBBoxHead',
                num_shared_convs=4,
                num_shared_fcs=1,
                in_channels=256,
                conv_out_channels=256,
                fc_out_channels=1024,
                roi_feat_size=7,
                num_classes=80,
                bbox_coder=dict(
                    type='DeltaXYWHBBoxCoder',
                    target_means=[0., 0., 0., 0.],
                    target_stds=[0.1, 0.1, 0.2, 0.2]),
                reg_class_agnostic=False,
                reg_decoded_bbox=True,
                norm_cfg=dict(type='SyncBN', requires_grad=True),
                loss_cls=dict(
                    type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
                loss_bbox=dict(type='GIoULoss', loss_weight=10.0)),
            dict(
                type='ConvFCBBoxHead',
                num_shared_convs=4,
                num_shared_fcs=1,
                in_channels=256,
                conv_out_channels=256,
                fc_out_channels=1024,
                roi_feat_size=7,
                num_classes=80,
                bbox_coder=dict(
                    type='DeltaXYWHBBoxCoder',
                    target_means=[0., 0., 0., 0.],
                    target_stds=[0.05, 0.05, 0.1, 0.1]),
                reg_class_agnostic=False,
                reg_decoded_bbox=True,
                norm_cfg=dict(type='SyncBN', requires_grad=True),
                loss_cls=dict(
                    type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
                loss_bbox=dict(type='GIoULoss', loss_weight=10.0)),
            dict(
                type='ConvFCBBoxHead',
                num_shared_convs=4,
                num_shared_fcs=1,
                in_channels=256,
                conv_out_channels=256,
                fc_out_channels=1024,
                roi_feat_size=7,
                num_classes=80,
                bbox_coder=dict(
                    type='DeltaXYWHBBoxCoder',
                    target_means=[0., 0., 0., 0.],
                    target_stds=[0.033, 0.033, 0.067, 0.067]),
                reg_class_agnostic=False,
                reg_decoded_bbox=True,
                norm_cfg=dict(type='SyncBN', requires_grad=True),
                loss_cls=dict(
                    type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
                loss_bbox=dict(type='GIoULoss', loss_weight=10.0))
        ]))

img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)

# augmentation strategy originates from DETR / Sparse RCNN
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='AutoAugment',
         policies=[
             [
                 dict(type='Resize',
                      img_scale=[(480, 1333), (512, 1333), (544, 1333),
                                 (576, 1333), (608, 1333), (640, 1333),
                                 (672, 1333), (704, 1333), (736, 1333),
                                 (768, 1333), (800, 1333)],
                      multiscale_mode='value',
                      keep_ratio=True)
             ],
             [
                 dict(type='Resize',
                      img_scale=[(400, 1333), (500, 1333), (600, 1333)],
                      multiscale_mode='value',
                      keep_ratio=True),
                 dict(type='RandomCrop',
                      crop_type='absolute_range',
                      crop_size=(384, 600),
                      allow_negative_crop=True),
                 dict(type='Resize',
                      img_scale=[(480, 1333), (512, 1333), (544, 1333),
                                 (576, 1333), (608, 1333), (640, 1333),
                                 (672, 1333), (704, 1333), (736, 1333),
                                 (768, 1333), (800, 1333)],
                      multiscale_mode='value',
                      override=True,
                      keep_ratio=True)
             ]
         ]),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
data = dict(train=dict(pipeline=train_pipeline))

optimizer = dict(constructor='LearningRateDecayOptimizerConstructor',
                 _delete_=True, type='AdamW',
                 lr=0.0001, betas=(0.9, 0.999), weight_decay=0.05,
                 paramwise_cfg={'decay_rate': 0.9,
                                'decay_type': 'layer_wise',
                                'layers': [1, 1, 3, 2],
                                'num_subnet': 16})

lr_config = dict(step=[27, 33])
runner = dict(type='EpochBasedRunner', max_epochs=36)

# fp16 = None
# optimizer_config = dict(
#     type="DistOptimizerHook",
#     update_interval=1,
#     grad_clip=None,
#     coalesce=True,
#     bucket_size_mb=-1,
#     use_fp16=True,
# )
fp16 = dict(loss_scale='dynamic')
```
training/Detection/configs/revcol/cascade_mask_rcnn_revcol_large_3x_in22k.py
ADDED
@@ -0,0 +1,152 @@
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

_base_ = [
    '../_base_/models/cascade_mask_rcnn_revcol_fpn.py',
    '../_base_/datasets/coco_instance.py',
    '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
pretrained = './cls_model/revcol_large_22k.pth'
model = dict(
    backbone=dict(
        channels=[128, 256, 512, 1024],
        layers=[1, 2, 6, 2],
        num_subnet=8,
        drop_path=0.5,
        save_memory=False,
        out_indices=[0, 1, 2, 3],
        init_cfg=dict(type='Pretrained', checkpoint=pretrained)),
    neck=dict(in_channels=[128, 256, 512, 1024]),
    roi_head=dict(
        bbox_head=[
            dict(
                type='ConvFCBBoxHead',
                num_shared_convs=4,
                num_shared_fcs=1,
                in_channels=256,
                conv_out_channels=256,
                fc_out_channels=1024,
                roi_feat_size=7,
                num_classes=80,
                bbox_coder=dict(
                    type='DeltaXYWHBBoxCoder',
                    target_means=[0., 0., 0., 0.],
                    target_stds=[0.1, 0.1, 0.2, 0.2]),
                reg_class_agnostic=False,
                reg_decoded_bbox=True,
                norm_cfg=dict(type='SyncBN', requires_grad=True),
                loss_cls=dict(
                    type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
                loss_bbox=dict(type='GIoULoss', loss_weight=10.0)),
            dict(
                type='ConvFCBBoxHead',
                num_shared_convs=4,
                num_shared_fcs=1,
                in_channels=256,
                conv_out_channels=256,
                fc_out_channels=1024,
                roi_feat_size=7,
                num_classes=80,
                bbox_coder=dict(
                    type='DeltaXYWHBBoxCoder',
                    target_means=[0., 0., 0., 0.],
                    target_stds=[0.05, 0.05, 0.1, 0.1]),
                reg_class_agnostic=False,
                reg_decoded_bbox=True,
                norm_cfg=dict(type='SyncBN', requires_grad=True),
                loss_cls=dict(
                    type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
                loss_bbox=dict(type='GIoULoss', loss_weight=10.0)),
            dict(
                type='ConvFCBBoxHead',
                num_shared_convs=4,
                num_shared_fcs=1,
                in_channels=256,
                conv_out_channels=256,
                fc_out_channels=1024,
                roi_feat_size=7,
                num_classes=80,
                bbox_coder=dict(
                    type='DeltaXYWHBBoxCoder',
                    target_means=[0., 0., 0., 0.],
                    target_stds=[0.033, 0.033, 0.067, 0.067]),
                reg_class_agnostic=False,
                reg_decoded_bbox=True,
                norm_cfg=dict(type='SyncBN', requires_grad=True),
                loss_cls=dict(
                    type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
                loss_bbox=dict(type='GIoULoss', loss_weight=10.0))
        ]))

img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)

# augmentation strategy originates from DETR / Sparse RCNN
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='AutoAugment',
         policies=[
             [
                 dict(type='Resize',
                      img_scale=[(480, 1333), (512, 1333), (544, 1333),
                                 (576, 1333), (608, 1333), (640, 1333),
                                 (672, 1333), (704, 1333), (736, 1333),
                                 (768, 1333), (800, 1333)],
                      multiscale_mode='value',
                      keep_ratio=True)
             ],
             [
                 dict(type='Resize',
                      img_scale=[(400, 1333), (500, 1333), (600, 1333)],
                      multiscale_mode='value',
                      keep_ratio=True),
                 dict(type='RandomCrop',
                      crop_type='absolute_range',
                      crop_size=(384, 600),
                      allow_negative_crop=True),
                 dict(type='Resize',
                      img_scale=[(480, 1333), (512, 1333), (544, 1333),
                                 (576, 1333), (608, 1333), (640, 1333),
                                 (672, 1333), (704, 1333), (736, 1333),
                                 (768, 1333), (800, 1333)],
                      multiscale_mode='value',
                      override=True,
                      keep_ratio=True)
             ]
         ]),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
data = dict(train=dict(pipeline=train_pipeline))

optimizer = dict(constructor='LearningRateDecayOptimizerConstructor',
                 _delete_=True, type='AdamW',
                 lr=0.0001, betas=(0.9, 0.999), weight_decay=0.05,
                 paramwise_cfg={'decay_rate': 0.85,
                                'decay_type': 'layer_wise',
                                'layers': [1, 2, 6, 2],
                                'num_subnet': 8})

lr_config = dict(step=[27, 33])
runner = dict(type='EpochBasedRunner', max_epochs=36)

# fp16 = None
# optimizer_config = dict(
#     type="DistOptimizerHook",
#     update_interval=1,
#     grad_clip=None,
#     coalesce=True,
#     bucket_size_mb=-1,
#     use_fp16=True,
# )
fp16 = dict(loss_scale='dynamic')
```
training/Detection/configs/revcol/cascade_mask_rcnn_revcol_small_3x_in1k.py
ADDED
@@ -0,0 +1,152 @@
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

_base_ = [
    '../_base_/models/cascade_mask_rcnn_revcol_fpn.py',
    '../_base_/datasets/coco_instance.py',
    '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
pretrained = './cls_model/revcol_small_1k.pth'
model = dict(
    backbone=dict(
        channels=[64, 128, 256, 512],
        layers=[2, 2, 4, 2],
        num_subnet=8,
        drop_path=0.4,
        save_memory=False,
        out_indices=[0, 1, 2, 3],
        init_cfg=dict(type='Pretrained', checkpoint=pretrained)),
    neck=dict(in_channels=[64, 128, 256, 512]),
    roi_head=dict(
        bbox_head=[
            dict(
                type='ConvFCBBoxHead',
                num_shared_convs=4,
                num_shared_fcs=1,
                in_channels=256,
                conv_out_channels=256,
                fc_out_channels=1024,
                roi_feat_size=7,
                num_classes=80,
                bbox_coder=dict(
                    type='DeltaXYWHBBoxCoder',
                    target_means=[0., 0., 0., 0.],
                    target_stds=[0.1, 0.1, 0.2, 0.2]),
                reg_class_agnostic=False,
                reg_decoded_bbox=True,
                norm_cfg=dict(type='SyncBN', requires_grad=True),
                loss_cls=dict(
                    type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
                loss_bbox=dict(type='GIoULoss', loss_weight=10.0)),
            dict(
                type='ConvFCBBoxHead',
                num_shared_convs=4,
                num_shared_fcs=1,
                in_channels=256,
                conv_out_channels=256,
                fc_out_channels=1024,
                roi_feat_size=7,
                num_classes=80,
                bbox_coder=dict(
                    type='DeltaXYWHBBoxCoder',
                    target_means=[0., 0., 0., 0.],
                    target_stds=[0.05, 0.05, 0.1, 0.1]),
                reg_class_agnostic=False,
                reg_decoded_bbox=True,
                norm_cfg=dict(type='SyncBN', requires_grad=True),
                loss_cls=dict(
                    type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
                loss_bbox=dict(type='GIoULoss', loss_weight=10.0)),
            dict(
                type='ConvFCBBoxHead',
                num_shared_convs=4,
                num_shared_fcs=1,
                in_channels=256,
                conv_out_channels=256,
                fc_out_channels=1024,
                roi_feat_size=7,
                num_classes=80,
                bbox_coder=dict(
                    type='DeltaXYWHBBoxCoder',
                    target_means=[0., 0., 0., 0.],
                    target_stds=[0.033, 0.033, 0.067, 0.067]),
                reg_class_agnostic=False,
                reg_decoded_bbox=True,
                norm_cfg=dict(type='SyncBN', requires_grad=True),
                loss_cls=dict(
                    type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
                loss_bbox=dict(type='GIoULoss', loss_weight=10.0))
        ]))

img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)

# augmentation strategy originates from DETR / Sparse RCNN
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='AutoAugment',
         policies=[
             [
                 dict(type='Resize',
                      img_scale=[(480, 1333), (512, 1333), (544, 1333),
                                 (576, 1333), (608, 1333), (640, 1333),
                                 (672, 1333), (704, 1333), (736, 1333),
                                 (768, 1333), (800, 1333)],
                      multiscale_mode='value',
                      keep_ratio=True)
             ],
             [
                 dict(type='Resize',
                      img_scale=[(400, 1333), (500, 1333), (600, 1333)],
                      multiscale_mode='value',
                      keep_ratio=True),
                 dict(type='RandomCrop',
                      crop_type='absolute_range',
                      crop_size=(384, 600),
                      allow_negative_crop=True),
                 dict(type='Resize',
                      img_scale=[(480, 1333), (512, 1333), (544, 1333),
                                 (576, 1333), (608, 1333), (640, 1333),
                                 (672, 1333), (704, 1333), (736, 1333),
                                 (768, 1333), (800, 1333)],
                      multiscale_mode='value',
                      override=True,
                      keep_ratio=True)
             ]
         ]),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
data = dict(train=dict(pipeline=train_pipeline))

optimizer = dict(constructor='LearningRateDecayOptimizerConstructor',
                 _delete_=True, type='AdamW',
                 lr=0.0002, betas=(0.9, 0.999), weight_decay=0.05,
                 paramwise_cfg={'decay_rate': 0.85,
                                'decay_type': 'layer_wise',
                                'layers': [2, 2, 4, 2],
                                'num_subnet': 8})

lr_config = dict(step=[27, 33])
runner = dict(type='EpochBasedRunner', max_epochs=36)

# fp16 = None
# optimizer_config = dict(
#     type="DistOptimizerHook",
#     update_interval=1,
#     grad_clip=None,
#     coalesce=True,
#     bucket_size_mb=-1,
#     use_fp16=True,
# )
fp16 = dict(loss_scale='dynamic')
```
training/Detection/configs/revcol/cascade_mask_rcnn_revcol_tiny_3x_in1k.py
ADDED
@@ -0,0 +1,152 @@
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

_base_ = [
    '../_base_/models/cascade_mask_rcnn_revcol_fpn.py',
    '../_base_/datasets/coco_instance.py',
    '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
pretrained = './cls_model/revcol_tiny_1k.pth'
model = dict(
    backbone=dict(
        channels=[64, 128, 256, 512],
        layers=[2, 2, 4, 2],
        num_subnet=4,
        drop_path=0.3,
        save_memory=False,
        out_indices=[0, 1, 2, 3],
        init_cfg=dict(type='Pretrained', checkpoint=pretrained)),
    neck=dict(in_channels=[64, 128, 256, 512]),
    roi_head=dict(
        bbox_head=[
            dict(
                type='ConvFCBBoxHead',
                num_shared_convs=4,
                num_shared_fcs=1,
                in_channels=256,
                conv_out_channels=256,
                fc_out_channels=1024,
                roi_feat_size=7,
                num_classes=80,
                bbox_coder=dict(
                    type='DeltaXYWHBBoxCoder',
                    target_means=[0., 0., 0., 0.],
                    target_stds=[0.1, 0.1, 0.2, 0.2]),
                reg_class_agnostic=False,
                reg_decoded_bbox=True,
                norm_cfg=dict(type='SyncBN', requires_grad=True),
                loss_cls=dict(
                    type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
                loss_bbox=dict(type='GIoULoss', loss_weight=10.0)),
            dict(
                type='ConvFCBBoxHead',
                num_shared_convs=4,
                num_shared_fcs=1,
                in_channels=256,
                conv_out_channels=256,
                fc_out_channels=1024,
                roi_feat_size=7,
                num_classes=80,
                bbox_coder=dict(
                    type='DeltaXYWHBBoxCoder',
                    target_means=[0., 0., 0., 0.],
                    target_stds=[0.05, 0.05, 0.1, 0.1]),
                reg_class_agnostic=False,
                reg_decoded_bbox=True,
                norm_cfg=dict(type='SyncBN', requires_grad=True),
                loss_cls=dict(
                    type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
                loss_bbox=dict(type='GIoULoss', loss_weight=10.0)),
            dict(
                type='ConvFCBBoxHead',
                num_shared_convs=4,
                num_shared_fcs=1,
                in_channels=256,
                conv_out_channels=256,
                fc_out_channels=1024,
                roi_feat_size=7,
                num_classes=80,
                bbox_coder=dict(
                    type='DeltaXYWHBBoxCoder',
                    target_means=[0., 0., 0., 0.],
                    target_stds=[0.033, 0.033, 0.067, 0.067]),
                reg_class_agnostic=False,
                reg_decoded_bbox=True,
                norm_cfg=dict(type='SyncBN', requires_grad=True),
                loss_cls=dict(
                    type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
                loss_bbox=dict(type='GIoULoss', loss_weight=10.0))
        ]))

img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)

# augmentation strategy originates from DETR / Sparse RCNN
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='AutoAugment',
         policies=[
             [
                 dict(type='Resize',
                      img_scale=[(480, 1333), (512, 1333), (544, 1333),
                                 (576, 1333), (608, 1333), (640, 1333),
                                 (672, 1333), (704, 1333), (736, 1333),
                                 (768, 1333), (800, 1333)],
                      multiscale_mode='value',
                      keep_ratio=True)
             ],
             [
                 dict(type='Resize',
                      img_scale=[(400, 1333), (500, 1333), (600, 1333)],
                      multiscale_mode='value',
                      keep_ratio=True),
                 dict(type='RandomCrop',
                      crop_type='absolute_range',
                      crop_size=(384, 600),
                      allow_negative_crop=True),
                 dict(type='Resize',
                      img_scale=[(480, 1333), (512, 1333), (544, 1333),
                                 (576, 1333), (608, 1333), (640, 1333),
                                 (672, 1333), (704, 1333), (736, 1333),
                                 (768, 1333), (800, 1333)],
                      multiscale_mode='value',
                      override=True,
                      keep_ratio=True)
             ]
         ]),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
data = dict(train=dict(pipeline=train_pipeline))

optimizer = dict(constructor='LearningRateDecayOptimizerConstructor',
                 _delete_=True, type='AdamW',
                 lr=0.0002, betas=(0.9, 0.999), weight_decay=0.05,
                 paramwise_cfg={'decay_rate': 0.85,
                                'decay_type': 'layer_wise',
                                'layers': [2, 2, 4, 2],
                                'num_subnet': 4})

lr_config = dict(step=[27, 33])
runner = dict(type='EpochBasedRunner', max_epochs=36)

# fp16 = None
# optimizer_config = dict(
#     type="DistOptimizerHook",
#     update_interval=1,
#     grad_clip=None,
#     coalesce=True,
#     bucket_size_mb=-1,
#     use_fp16=True,
# )
fp16 = dict(loss_scale='dynamic')
```
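All five configs end with `fp16 = dict(loss_scale='dynamic')`, which enables mmcv's mixed-precision optimizer hook with a dynamic loss scaler. Conceptually this matches PyTorch's `GradScaler`; the sketch below shows the mechanism as an analogy, not the hook's actual code:

```python
# Dynamic loss scaling in one training step: scale the loss so fp16 gradients
# do not underflow, then unscale before the optimizer step and adapt the scale.
import torch

scaler = torch.cuda.amp.GradScaler()  # grows/shrinks the scale automatically

def train_step(model, batch, optimizer):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = model(batch)            # forward in mixed precision
    scaler.scale(loss).backward()      # backprop the scaled loss
    scaler.step(optimizer)             # unscales grads; skips step on inf/nan
    scaler.update()                    # adjust the scale for the next iteration
```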
training/Detection/mmcv_custom/__init__.py
ADDED
@@ -0,0 +1,15 @@
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# -*- coding: utf-8 -*-

from .checkpoint import load_checkpoint
from .layer_decay_optimizer_constructor import LearningRateDecayOptimizerConstructor
from .customized_text import CustomizedTextLoggerHook

__all__ = ['load_checkpoint', 'LearningRateDecayOptimizerConstructor', 'CustomizedTextLoggerHook']
```
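Importing `mmcv_custom` is what makes `constructor='LearningRateDecayOptimizerConstructor'` resolvable from the configs: the class registers itself with mmcv's optimizer-builder registry at import time. A minimal sketch of that registration pattern follows; the real implementation, including the `paramwise_cfg` handling, lives in `layer_decay_optimizer_constructor.py`:

```python
# Sketch of registering a custom optimizer constructor with mmcv 1.x.
from mmcv.runner import OPTIMIZER_BUILDERS, DefaultOptimizerConstructor


@OPTIMIZER_BUILDERS.register_module()
class LearningRateDecayOptimizerConstructor(DefaultOptimizerConstructor):
    """Sketch only: groups parameters and attaches depth-scaled lr values."""

    def add_params(self, params, module, **kwargs):
        # the real constructor builds per-bucket parameter groups from
        # paramwise_cfg here before delegating to the default behavior
        super().add_params(params, module, **kwargs)
```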
training/Detection/mmcv_custom/checkpoint.py
ADDED
@@ -0,0 +1,484 @@
```python
# Copyright (c) Open-MMLab. All rights reserved.
import io
import os
import os.path as osp
import pkgutil
import time
import warnings
from collections import OrderedDict
from importlib import import_module
from tempfile import TemporaryDirectory

import torch
import torchvision
from torch.optim import Optimizer
from torch.utils import model_zoo
from torch.nn import functional as F

import mmcv
from mmcv.fileio import FileClient
from mmcv.fileio import load as load_file
from mmcv.parallel import is_module_wrapper
from mmcv.utils import mkdir_or_exist
from mmcv.runner import get_dist_info

ENV_MMCV_HOME = 'MMCV_HOME'
ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
DEFAULT_CACHE_DIR = '~/.cache'


def _get_mmcv_home():
    mmcv_home = os.path.expanduser(
        os.getenv(
            ENV_MMCV_HOME,
            os.path.join(
                os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'mmcv')))

    mkdir_or_exist(mmcv_home)
    return mmcv_home


def load_state_dict(module, state_dict, strict=False, logger=None):
    """Load state_dict to a module.

    This method is modified from :meth:`torch.nn.Module.load_state_dict`.
    Default value for ``strict`` is set to ``False`` and the message for
    param mismatch will be shown even if strict is False.

    Args:
        module (Module): Module that receives the state_dict.
        state_dict (OrderedDict): Weights.
        strict (bool): whether to strictly enforce that the keys
            in :attr:`state_dict` match the keys returned by this module's
            :meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
        logger (:obj:`logging.Logger`, optional): Logger to log the error
            message. If not specified, print function will be used.
    """
    unexpected_keys = []
    all_missing_keys = []
    err_msg = []

    metadata = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    # use _load_from_state_dict to enable checkpoint version control
    def load(module, prefix=''):
        # recursively check parallel module in case that the model has a
        # complicated structure, e.g., nn.Module(nn.Module(DDP))
        if is_module_wrapper(module):
            module = module.module
        local_metadata = {} if metadata is None else metadata.get(
            prefix[:-1], {})
        module._load_from_state_dict(state_dict, prefix, local_metadata, True,
                                     all_missing_keys, unexpected_keys,
                                     err_msg)
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + '.')

    load(module)
    load = None  # break load->load reference cycle

    # ignore "num_batches_tracked" of BN layers
    missing_keys = [
        key for key in all_missing_keys if 'num_batches_tracked' not in key
    ]

    if unexpected_keys:
        err_msg.append('unexpected key in source '
                       f'state_dict: {", ".join(unexpected_keys)}\n')
    if missing_keys:
        err_msg.append(
            f'missing keys in source state_dict: {", ".join(missing_keys)}\n')

    rank, _ = get_dist_info()
    if len(err_msg) > 0 and rank == 0:
        err_msg.insert(
            0, 'The model and loaded state dict do not match exactly\n')
        err_msg = '\n'.join(err_msg)
        if strict:
            raise RuntimeError(err_msg)
        elif logger is not None:
            logger.warning(err_msg)
        else:
            print(err_msg)


def load_url_dist(url, model_dir=None):
    """In distributed setting, this function only download checkpoint at local
    rank 0."""
    rank, world_size = get_dist_info()
    rank = int(os.environ.get('LOCAL_RANK', rank))
    if rank == 0:
        checkpoint = model_zoo.load_url(url, model_dir=model_dir)
    if world_size > 1:
        torch.distributed.barrier()
        if rank > 0:
            checkpoint = model_zoo.load_url(url, model_dir=model_dir)
    return checkpoint


def load_pavimodel_dist(model_path, map_location=None):
    """In distributed setting, this function only download checkpoint at local
    rank 0."""
    try:
        from pavi import modelcloud
    except ImportError:
        raise ImportError(
            'Please install pavi to load checkpoint from modelcloud.')
    rank, world_size = get_dist_info()
    rank = int(os.environ.get('LOCAL_RANK', rank))
    if rank == 0:
        model = modelcloud.get(model_path)
        with TemporaryDirectory() as tmp_dir:
            downloaded_file = osp.join(tmp_dir, model.name)
            model.download(downloaded_file)
            checkpoint = torch.load(downloaded_file, map_location=map_location)
    if world_size > 1:
        torch.distributed.barrier()
        if rank > 0:
            model = modelcloud.get(model_path)
            with TemporaryDirectory() as tmp_dir:
                downloaded_file = osp.join(tmp_dir, model.name)
                model.download(downloaded_file)
                checkpoint = torch.load(
                    downloaded_file, map_location=map_location)
    return checkpoint


def load_fileclient_dist(filename, backend, map_location):
    """In distributed setting, this function only download checkpoint at local
    rank 0."""
    rank, world_size = get_dist_info()
    rank = int(os.environ.get('LOCAL_RANK', rank))
    allowed_backends = ['ceph']
    if backend not in allowed_backends:
        raise ValueError(f'Load from Backend {backend} is not supported.')
    if rank == 0:
        fileclient = FileClient(backend=backend)
        buffer = io.BytesIO(fileclient.get(filename))
        checkpoint = torch.load(buffer, map_location=map_location)
    if world_size > 1:
        torch.distributed.barrier()
        if rank > 0:
            fileclient = FileClient(backend=backend)
            buffer = io.BytesIO(fileclient.get(filename))
            checkpoint = torch.load(buffer, map_location=map_location)
    return checkpoint


def get_torchvision_models():
    model_urls = dict()
    for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__):
        if ispkg:
            continue
        _zoo = import_module(f'torchvision.models.{name}')
        if hasattr(_zoo, 'model_urls'):
            _urls = getattr(_zoo, 'model_urls')
            model_urls.update(_urls)
    return model_urls


def get_external_models():
    mmcv_home = _get_mmcv_home()
    default_json_path = osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json')
    default_urls = load_file(default_json_path)
    assert isinstance(default_urls, dict)
    external_json_path = osp.join(mmcv_home, 'open_mmlab.json')
    if osp.exists(external_json_path):
        external_urls = load_file(external_json_path)
        assert isinstance(external_urls, dict)
        default_urls.update(external_urls)

    return default_urls


def get_mmcls_models():
    mmcls_json_path = osp.join(mmcv.__path__[0], 'model_zoo/mmcls.json')
    mmcls_urls = load_file(mmcls_json_path)
```
|
199 |
+
|
200 |
+
return mmcls_urls
|
201 |
+
|
202 |
+
|
203 |
+
def get_deprecated_model_names():
|
204 |
+
deprecate_json_path = osp.join(mmcv.__path__[0],
|
205 |
+
'model_zoo/deprecated.json')
|
206 |
+
deprecate_urls = load_file(deprecate_json_path)
|
207 |
+
assert isinstance(deprecate_urls, dict)
|
208 |
+
|
209 |
+
return deprecate_urls
|
210 |
+
|
211 |
+
|
212 |
+
def _process_mmcls_checkpoint(checkpoint):
|
213 |
+
state_dict = checkpoint['state_dict']
|
214 |
+
new_state_dict = OrderedDict()
|
215 |
+
for k, v in state_dict.items():
|
216 |
+
if k.startswith('backbone.'):
|
217 |
+
new_state_dict[k[9:]] = v
|
218 |
+
new_checkpoint = dict(state_dict=new_state_dict)
|
219 |
+
|
220 |
+
return new_checkpoint
|
221 |
+
|
222 |
+
|
223 |
+
def _load_checkpoint(filename, map_location=None):
|
224 |
+
"""Load checkpoint from somewhere (modelzoo, file, url).
|
225 |
+
Args:
|
226 |
+
filename (str): Accept local filepath, URL, ``torchvision://xxx``,
|
227 |
+
``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
|
228 |
+
details.
|
229 |
+
map_location (str | None): Same as :func:`torch.load`. Default: None.
|
230 |
+
Returns:
|
231 |
+
dict | OrderedDict: The loaded checkpoint. It can be either an
|
232 |
+
OrderedDict storing model weights or a dict containing other
|
233 |
+
information, which depends on the checkpoint.
|
234 |
+
"""
|
235 |
+
if filename.startswith('modelzoo://'):
|
236 |
+
warnings.warn('The URL scheme of "modelzoo://" is deprecated, please '
|
237 |
+
'use "torchvision://" instead')
|
238 |
+
model_urls = get_torchvision_models()
|
239 |
+
model_name = filename[11:]
|
240 |
+
checkpoint = load_url_dist(model_urls[model_name])
|
241 |
+
elif filename.startswith('torchvision://'):
|
242 |
+
model_urls = get_torchvision_models()
|
243 |
+
model_name = filename[14:]
|
244 |
+
checkpoint = load_url_dist(model_urls[model_name])
|
245 |
+
elif filename.startswith('open-mmlab://'):
|
246 |
+
model_urls = get_external_models()
|
247 |
+
model_name = filename[13:]
|
248 |
+
deprecated_urls = get_deprecated_model_names()
|
249 |
+
if model_name in deprecated_urls:
|
250 |
+
warnings.warn(f'open-mmlab://{model_name} is deprecated in favor '
|
251 |
+
f'of open-mmlab://{deprecated_urls[model_name]}')
|
252 |
+
model_name = deprecated_urls[model_name]
|
253 |
+
model_url = model_urls[model_name]
|
254 |
+
# check if is url
|
255 |
+
if model_url.startswith(('http://', 'https://')):
|
256 |
+
checkpoint = load_url_dist(model_url)
|
257 |
+
else:
|
258 |
+
filename = osp.join(_get_mmcv_home(), model_url)
|
259 |
+
if not osp.isfile(filename):
|
260 |
+
raise IOError(f'{filename} is not a checkpoint file')
|
261 |
+
checkpoint = torch.load(filename, map_location=map_location)
|
262 |
+
elif filename.startswith('mmcls://'):
|
263 |
+
model_urls = get_mmcls_models()
|
264 |
+
model_name = filename[8:]
|
265 |
+
checkpoint = load_url_dist(model_urls[model_name])
|
266 |
+
checkpoint = _process_mmcls_checkpoint(checkpoint)
|
267 |
+
elif filename.startswith(('http://', 'https://')):
|
268 |
+
checkpoint = load_url_dist(filename)
|
269 |
+
elif filename.startswith('pavi://'):
|
270 |
+
model_path = filename[7:]
|
271 |
+
checkpoint = load_pavimodel_dist(model_path, map_location=map_location)
|
272 |
+
elif filename.startswith('s3://'):
|
273 |
+
checkpoint = load_fileclient_dist(
|
274 |
+
filename, backend='ceph', map_location=map_location)
|
275 |
+
else:
|
276 |
+
if not osp.isfile(filename):
|
277 |
+
raise IOError(f'{filename} is not a checkpoint file')
|
278 |
+
checkpoint = torch.load(filename, map_location=map_location)
|
279 |
+
return checkpoint
|
280 |
+
|
281 |
+
|
282 |
+
def load_checkpoint(model,
|
283 |
+
filename,
|
284 |
+
map_location='cpu',
|
285 |
+
strict=False,
|
286 |
+
logger=None):
|
287 |
+
"""Load checkpoint from a file or URI.
|
288 |
+
Args:
|
289 |
+
model (Module): Module to load checkpoint.
|
290 |
+
filename (str): Accept local filepath, URL, ``torchvision://xxx``,
|
291 |
+
``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
|
292 |
+
details.
|
293 |
+
map_location (str): Same as :func:`torch.load`.
|
294 |
+
strict (bool): Whether to allow different params for the model and
|
295 |
+
checkpoint.
|
296 |
+
logger (:mod:`logging.Logger` or None): The logger for error message.
|
297 |
+
Returns:
|
298 |
+
dict or OrderedDict: The loaded checkpoint.
|
299 |
+
"""
|
300 |
+
checkpoint = _load_checkpoint(filename, map_location)
|
301 |
+
# OrderedDict is a subclass of dict
|
302 |
+
if not isinstance(checkpoint, dict):
|
303 |
+
raise RuntimeError(
|
304 |
+
f'No state_dict found in checkpoint file {filename}')
|
305 |
+
# get state_dict from checkpoint
|
306 |
+
if 'state_dict' in checkpoint:
|
307 |
+
state_dict = checkpoint['state_dict']
|
308 |
+
elif 'model' in checkpoint:
|
309 |
+
state_dict = checkpoint['model']
|
310 |
+
else:
|
311 |
+
state_dict = checkpoint
|
312 |
+
# strip prefix of state_dict
|
313 |
+
if list(state_dict.keys())[0].startswith('module.'):
|
314 |
+
state_dict = {k[7:]: v for k, v in state_dict.items()}
|
315 |
+
|
316 |
+
# for MoBY, load model of online branch
|
317 |
+
if sorted(list(state_dict.keys()))[0].startswith('encoder'):
|
318 |
+
state_dict = {k.replace('encoder.', ''): v for k, v in state_dict.items() if k.startswith('encoder.')}
|
319 |
+
|
320 |
+
# reshape absolute position embedding
|
321 |
+
if state_dict.get('absolute_pos_embed') is not None:
|
322 |
+
absolute_pos_embed = state_dict['absolute_pos_embed']
|
323 |
+
N1, L, C1 = absolute_pos_embed.size()
|
324 |
+
N2, C2, H, W = model.absolute_pos_embed.size()
|
325 |
+
if N1 != N2 or C1 != C2 or L != H*W:
|
326 |
+
logger.warning("Error in loading absolute_pos_embed, pass")
|
327 |
+
else:
|
328 |
+
state_dict['absolute_pos_embed'] = absolute_pos_embed.view(N2, H, W, C2).permute(0, 3, 1, 2)
|
329 |
+
|
330 |
+
# interpolate position bias table if needed
|
331 |
+
relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k]
|
332 |
+
for table_key in relative_position_bias_table_keys:
|
333 |
+
table_pretrained = state_dict[table_key]
|
334 |
+
table_current = model.state_dict()[table_key]
|
335 |
+
L1, nH1 = table_pretrained.size()
|
336 |
+
L2, nH2 = table_current.size()
|
337 |
+
if nH1 != nH2:
|
338 |
+
logger.warning(f"Error in loading {table_key}, pass")
|
339 |
+
else:
|
340 |
+
if L1 != L2:
|
341 |
+
S1 = int(L1 ** 0.5)
|
342 |
+
S2 = int(L2 ** 0.5)
|
343 |
+
table_pretrained_resized = F.interpolate(
|
344 |
+
table_pretrained.permute(1, 0).view(1, nH1, S1, S1),
|
345 |
+
size=(S2, S2), mode='bicubic')
|
346 |
+
state_dict[table_key] = table_pretrained_resized.view(nH2, L2).permute(1, 0)
|
347 |
+
|
348 |
+
# load state_dict
|
349 |
+
load_state_dict(model, state_dict, strict, logger)
|
350 |
+
return checkpoint
|
351 |
+
|
352 |
+
|
353 |
+
def weights_to_cpu(state_dict):
|
354 |
+
"""Copy a model state_dict to cpu.
|
355 |
+
Args:
|
356 |
+
state_dict (OrderedDict): Model weights on GPU.
|
357 |
+
Returns:
|
358 |
+
OrderedDict: Model weights on GPU.
|
359 |
+
"""
|
360 |
+
state_dict_cpu = OrderedDict()
|
361 |
+
for key, val in state_dict.items():
|
362 |
+
state_dict_cpu[key] = val.cpu()
|
363 |
+
return state_dict_cpu
|
364 |
+
|
365 |
+
|
366 |
+
def _save_to_state_dict(module, destination, prefix, keep_vars):
|
367 |
+
"""Saves module state to `destination` dictionary.
|
368 |
+
This method is modified from :meth:`torch.nn.Module._save_to_state_dict`.
|
369 |
+
Args:
|
370 |
+
module (nn.Module): The module to generate state_dict.
|
371 |
+
destination (dict): A dict where state will be stored.
|
372 |
+
prefix (str): The prefix for parameters and buffers used in this
|
373 |
+
module.
|
374 |
+
"""
|
375 |
+
for name, param in module._parameters.items():
|
376 |
+
if param is not None:
|
377 |
+
destination[prefix + name] = param if keep_vars else param.detach()
|
378 |
+
for name, buf in module._buffers.items():
|
379 |
+
# remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d
|
380 |
+
if buf is not None:
|
381 |
+
destination[prefix + name] = buf if keep_vars else buf.detach()
|
382 |
+
|
383 |
+
|
384 |
+
def get_state_dict(module, destination=None, prefix='', keep_vars=False):
|
385 |
+
"""Returns a dictionary containing a whole state of the module.
|
386 |
+
Both parameters and persistent buffers (e.g. running averages) are
|
387 |
+
included. Keys are corresponding parameter and buffer names.
|
388 |
+
This method is modified from :meth:`torch.nn.Module.state_dict` to
|
389 |
+
recursively check parallel module in case that the model has a complicated
|
390 |
+
structure, e.g., nn.Module(nn.Module(DDP)).
|
391 |
+
Args:
|
392 |
+
module (nn.Module): The module to generate state_dict.
|
393 |
+
destination (OrderedDict): Returned dict for the state of the
|
394 |
+
module.
|
395 |
+
prefix (str): Prefix of the key.
|
396 |
+
keep_vars (bool): Whether to keep the variable property of the
|
397 |
+
parameters. Default: False.
|
398 |
+
Returns:
|
399 |
+
dict: A dictionary containing a whole state of the module.
|
400 |
+
"""
|
401 |
+
# recursively check parallel module in case that the model has a
|
402 |
+
# complicated structure, e.g., nn.Module(nn.Module(DDP))
|
403 |
+
if is_module_wrapper(module):
|
404 |
+
module = module.module
|
405 |
+
|
406 |
+
# below is the same as torch.nn.Module.state_dict()
|
407 |
+
if destination is None:
|
408 |
+
destination = OrderedDict()
|
409 |
+
destination._metadata = OrderedDict()
|
410 |
+
destination._metadata[prefix[:-1]] = local_metadata = dict(
|
411 |
+
version=module._version)
|
412 |
+
_save_to_state_dict(module, destination, prefix, keep_vars)
|
413 |
+
for name, child in module._modules.items():
|
414 |
+
if child is not None:
|
415 |
+
get_state_dict(
|
416 |
+
child, destination, prefix + name + '.', keep_vars=keep_vars)
|
417 |
+
for hook in module._state_dict_hooks.values():
|
418 |
+
hook_result = hook(module, destination, prefix, local_metadata)
|
419 |
+
if hook_result is not None:
|
420 |
+
destination = hook_result
|
421 |
+
return destination
|
422 |
+
|
423 |
+
|
424 |
+
def save_checkpoint(model, filename, optimizer=None, meta=None):
|
425 |
+
"""Save checkpoint to file.
|
426 |
+
The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
|
427 |
+
``optimizer``. By default ``meta`` will contain version and time info.
|
428 |
+
Args:
|
429 |
+
model (Module): Module whose params are to be saved.
|
430 |
+
filename (str): Checkpoint filename.
|
431 |
+
optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
|
432 |
+
meta (dict, optional): Metadata to be saved in checkpoint.
|
433 |
+
"""
|
434 |
+
if meta is None:
|
435 |
+
meta = {}
|
436 |
+
elif not isinstance(meta, dict):
|
437 |
+
raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
|
438 |
+
meta.update(mmcv_version=mmcv.__version__, time=time.asctime())
|
439 |
+
|
440 |
+
if is_module_wrapper(model):
|
441 |
+
model = model.module
|
442 |
+
|
443 |
+
if hasattr(model, 'CLASSES') and model.CLASSES is not None:
|
444 |
+
# save class name to the meta
|
445 |
+
meta.update(CLASSES=model.CLASSES)
|
446 |
+
|
447 |
+
checkpoint = {
|
448 |
+
'meta': meta,
|
449 |
+
'state_dict': weights_to_cpu(get_state_dict(model))
|
450 |
+
}
|
451 |
+
# save optimizer state dict in the checkpoint
|
452 |
+
if isinstance(optimizer, Optimizer):
|
453 |
+
checkpoint['optimizer'] = optimizer.state_dict()
|
454 |
+
elif isinstance(optimizer, dict):
|
455 |
+
checkpoint['optimizer'] = {}
|
456 |
+
for name, optim in optimizer.items():
|
457 |
+
checkpoint['optimizer'][name] = optim.state_dict()
|
458 |
+
|
459 |
+
if filename.startswith('pavi://'):
|
460 |
+
try:
|
461 |
+
from pavi import modelcloud
|
462 |
+
from pavi.exception import NodeNotFoundError
|
463 |
+
except ImportError:
|
464 |
+
raise ImportError(
|
465 |
+
'Please install pavi to load checkpoint from modelcloud.')
|
466 |
+
model_path = filename[7:]
|
467 |
+
root = modelcloud.Folder()
|
468 |
+
model_dir, model_name = osp.split(model_path)
|
469 |
+
try:
|
470 |
+
model = modelcloud.get(model_dir)
|
471 |
+
except NodeNotFoundError:
|
472 |
+
model = root.create_training_model(model_dir)
|
473 |
+
with TemporaryDirectory() as tmp_dir:
|
474 |
+
checkpoint_file = osp.join(tmp_dir, model_name)
|
475 |
+
with open(checkpoint_file, 'wb') as f:
|
476 |
+
torch.save(checkpoint, f)
|
477 |
+
f.flush()
|
478 |
+
model.create_file(checkpoint_file, name=model_name)
|
479 |
+
else:
|
480 |
+
mmcv.mkdir_or_exist(osp.dirname(filename))
|
481 |
+
# immediately flush buffer
|
482 |
+
with open(filename, 'wb') as f:
|
483 |
+
torch.save(checkpoint, f)
|
484 |
+
f.flush()
|
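For downstream use, a minimal sketch of driving this loader (assuming `load_checkpoint` is re-exported by `mmcv_custom/__init__.py`; the model and the checkpoint path are placeholders, not artifacts shipped with this commit):

```python
# Hypothetical usage sketch: load pretrained weights into a built backbone.
from mmcv_custom import load_checkpoint  # assumes package is on PYTHONPATH

checkpoint = load_checkpoint(
    model,                              # an nn.Module, e.g. a detector backbone
    'pretrained/revcol_tiny_1k.pth',    # path, URL, or scheme like 'torchvision://resnet50'
    map_location='cpu',
    strict=False,                       # mismatched keys are reported, not fatal
    logger=None)                        # falls back to print() for mismatch messages
```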
training/Detection/mmcv_custom/customized_text.py
ADDED
@@ -0,0 +1,130 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


import datetime
from collections import OrderedDict

import torch

import mmcv
from mmcv.runner import HOOKS
from mmcv.runner import TextLoggerHook


@HOOKS.register_module()
class CustomizedTextLoggerHook(TextLoggerHook):
    """Customized Text Logger hook.

    This logger prints out both lr and layer_0_lr.
    """

    def _log_info(self, log_dict, runner):
        # print exp name for users to distinguish experiments
        # at every ``interval_exp_name`` iterations and the end of each epoch
        if runner.meta is not None and 'exp_name' in runner.meta:
            if (self.every_n_iters(runner, self.interval_exp_name)) or (
                    self.by_epoch and self.end_of_epoch(runner)):
                exp_info = f'Exp name: {runner.meta["exp_name"]}'
                runner.logger.info(exp_info)

        if log_dict['mode'] == 'train':
            lr_str = {}
            for lr_type in ['lr', 'layer_0_lr']:
                if isinstance(log_dict[lr_type], dict):
                    lr_str[lr_type] = []
                    for k, val in log_dict[lr_type].items():
                        # append to the per-type list, not to the dict itself
                        lr_str[lr_type].append(f'{lr_type}_{k}: {val:.3e}')
                    lr_str[lr_type] = ' '.join(lr_str[lr_type])
                else:
                    lr_str[lr_type] = f'{lr_type}: {log_dict[lr_type]:.3e}'

            # by epoch: Epoch [4][100/1000]
            # by iter: Iter [100/100000]
            if self.by_epoch:
                log_str = f'Epoch [{log_dict["epoch"]}]' \
                          f'[{log_dict["iter"]}/{len(runner.data_loader)}]\t'
            else:
                log_str = f'Iter [{log_dict["iter"]}/{runner.max_iters}]\t'
            log_str += f'{lr_str["lr"]}, {lr_str["layer_0_lr"]}, '

            if 'time' in log_dict.keys():
                self.time_sec_tot += (log_dict['time'] * self.interval)
                time_sec_avg = self.time_sec_tot / (
                    runner.iter - self.start_iter + 1)
                eta_sec = time_sec_avg * (runner.max_iters - runner.iter - 1)
                eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
                log_str += f'eta: {eta_str}, '
                log_str += f'time: {log_dict["time"]:.3f}, ' \
                           f'data_time: {log_dict["data_time"]:.3f}, '
                # statistic memory
                if torch.cuda.is_available():
                    log_str += f'memory: {log_dict["memory"]}, '
        else:
            # val/test time
            # here 1000 is the length of the val dataloader
            # by epoch: Epoch[val] [4][1000]
            # by iter: Iter[val] [1000]
            if self.by_epoch:
                log_str = f'Epoch({log_dict["mode"]}) ' \
                          f'[{log_dict["epoch"]}][{log_dict["iter"]}]\t'
            else:
                log_str = f'Iter({log_dict["mode"]}) [{log_dict["iter"]}]\t'

        log_items = []
        for name, val in log_dict.items():
            # TODO: resolve this hack
            # these items have been in log_str
            if name in [
                    'mode', 'Epoch', 'iter', 'lr', 'layer_0_lr', 'time',
                    'data_time', 'memory', 'epoch'
            ]:
                continue
            if isinstance(val, float):
                val = f'{val:.4f}'
            log_items.append(f'{name}: {val}')
        log_str += ', '.join(log_items)

        runner.logger.info(log_str)

    def log(self, runner):
        if 'eval_iter_num' in runner.log_buffer.output:
            # this doesn't modify runner.iter and is regardless of by_epoch
            cur_iter = runner.log_buffer.output.pop('eval_iter_num')
        else:
            cur_iter = self.get_iter(runner, inner_iter=True)

        log_dict = OrderedDict(
            mode=self.get_mode(runner),
            epoch=self.get_epoch(runner),
            iter=cur_iter)

        # record lr and layer_0_lr
        cur_lr = runner.current_lr()
        if isinstance(cur_lr, list):
            log_dict['layer_0_lr'] = min(cur_lr)
            log_dict['lr'] = max(cur_lr)
        else:
            assert isinstance(cur_lr, dict)
            log_dict['lr'], log_dict['layer_0_lr'] = {}, {}
            for k, lr_ in cur_lr.items():
                assert isinstance(lr_, list)
                log_dict['layer_0_lr'].update({k: min(lr_)})
                log_dict['lr'].update({k: max(lr_)})

        if 'time' in runner.log_buffer.output:
            # statistic memory
            if torch.cuda.is_available():
                log_dict['memory'] = self._get_max_memory(runner)

        log_dict = dict(log_dict, **runner.log_buffer.output)

        self._log_info(log_dict, runner)
        self._dump_log(log_dict, runner)
        return log_dict
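To actually see the extra `layer_0_lr` field in the training log, the hook has to replace the stock text logger in the config; a typical fragment, using the standard mmcv `log_config` syntax (the `custom_imports` line assumes `mmcv_custom` is importable on PYTHONPATH):

```python
# Config fragment: swap the default TextLoggerHook for the customized one.
log_config = dict(
    interval=50,
    hooks=[
        dict(type='CustomizedTextLoggerHook'),
    ])
custom_imports = dict(imports=['mmcv_custom'], allow_failed_imports=False)
```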
training/Detection/mmcv_custom/layer_decay_optimizer_constructor.py
ADDED
@@ -0,0 +1,121 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


import json
import re
from mmcv.runner import OPTIMIZER_BUILDERS, DefaultOptimizerConstructor
from mmcv.runner import get_dist_info
import numpy as np


def cal_model_depth(depth, num_subnet):
    dp = np.zeros((depth, num_subnet))
    dp[:, 0] = np.linspace(0, depth - 1, depth)
    dp[0, :] = np.linspace(0, num_subnet - 1, num_subnet)
    for i in range(1, depth):
        for j in range(1, num_subnet):
            dp[i][j] = min(dp[i][j - 1], dp[i - 1][j]) + 1
    dp = dp.astype(int)
    dp = dp + 1  # make layer ids start from 1
    return dp


def get_num_layer_layer_wise(n, layers, num_subnet=12):
    dp = cal_model_depth(sum(layers), num_subnet)
    if n.startswith("backbone.subnet"):
        n = n[9:]  # strip the "backbone." prefix
        name_part = n.split('.')
        subnet = int(name_part[0][6:])
        if name_part[1].startswith("alpha"):
            id = dp[0][subnet]
        else:
            level = int(name_part[1][-1])
            if name_part[2].startswith("blocks"):
                sub = int(name_part[3])
                if sub > layers[level] - 1:
                    sub = layers[level] - 1
                block = sum(layers[:level]) + sub

            if name_part[2].startswith("fusion"):
                block = sum(layers[:level])
            id = dp[block][subnet]
    elif n.startswith("backbone.stem"):
        id = 0
    else:
        id = dp[-1][-1] + 1
    return id

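To make the layer-id assignment concrete, here is a small worked example (values computed by hand from `cal_model_depth` above; the tiny `depth`/`num_subnet` values are illustrative only):

```python
# cal_model_depth(depth=4, num_subnet=3) fills a DP table where each entry is
# the shortest path length (in blocks) from the input to that (block, column)
# position, then shifts the ids to start from 1:
#
#   [[1, 2, 3],
#    [2, 3, 4],
#    [3, 4, 5],
#    [4, 5, 6]]
#
# Parameters deeper in the network or in later columns get a larger layer id,
# and hence a smaller decay exponent (larger lr scale) in add_params below.
print(cal_model_depth(4, 3))
```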
@OPTIMIZER_BUILDERS.register_module()
class LearningRateDecayOptimizerConstructor(DefaultOptimizerConstructor):
    def add_params(self, params, module, prefix='', is_dcn_module=None):
        """Add all parameters of module to the params list.

        The parameters of the given module will be added to the list of param
        groups, with specific rules defined by paramwise_cfg.

        Args:
            params (list[dict]): A list of param groups, it will be modified
                in place.
            module (nn.Module): The module to be added.
            prefix (str): The prefix of the module.
            is_dcn_module (int|float|None): If the current module is a
                submodule of DCN, `is_dcn_module` will be passed to
                control conv_offset layer's learning rate. Defaults to None.
        """
        parameter_groups = {}
        print(self.paramwise_cfg)
        num_layers = cal_model_depth(
            sum(self.paramwise_cfg.get('layers')),
            self.paramwise_cfg.get('num_subnet'))[-1][-1] + 2
        decay_rate = self.paramwise_cfg.get('decay_rate')
        decay_type = self.paramwise_cfg.get('decay_type', 'layer_wise')
        print('Build LearningRateDecayOptimizerConstructor %s %f - %d' %
              (decay_type, decay_rate, num_layers))
        weight_decay = self.base_wd

        for name, param in module.named_parameters():
            if not param.requires_grad:
                continue  # frozen weights
            if len(param.shape) == 1 or name.endswith('.bias') or name in (
                    'pos_embed', 'cls_token') or re.match('(.*).alpha.$', name):
                group_name = 'no_decay'
                this_weight_decay = 0.
            else:
                group_name = 'decay'
                this_weight_decay = weight_decay

            if decay_type == 'layer_wise':
                layer_id = get_num_layer_layer_wise(
                    name, self.paramwise_cfg.get('layers'),
                    self.paramwise_cfg.get('num_subnet'))

            group_name = 'layer_%d_%s' % (layer_id, group_name)

            if group_name not in parameter_groups:
                scale = decay_rate ** (num_layers - layer_id - 1)

                parameter_groups[group_name] = {
                    'weight_decay': this_weight_decay,
                    'params': [],
                    'param_names': [],
                    'lr_scale': scale,
                    'group_name': group_name,
                    'lr': scale * self.base_lr,
                }

            parameter_groups[group_name]['params'].append(param)
            parameter_groups[group_name]['param_names'].append(name)
        rank, _ = get_dist_info()
        if rank == 0:
            to_display = {}
            for key in parameter_groups:
                to_display[key] = {
                    'param_names': parameter_groups[key]['param_names'],
                    'lr_scale': parameter_groups[key]['lr_scale'],
                    'lr': parameter_groups[key]['lr'],
                    'weight_decay': parameter_groups[key]['weight_decay'],
                }
            print('Param groups = %s' % json.dumps(to_display, indent=2))

        params.extend(parameter_groups.values())
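A config fragment of the kind that would exercise this constructor; the field values are illustrative, and `layers`/`num_subnet` must match the RevCol variant being trained:

```python
# Hypothetical mmdet config snippet (values illustrative, not from this commit):
optimizer = dict(
    constructor='LearningRateDecayOptimizerConstructor',
    type='AdamW', lr=2e-4, betas=(0.9, 0.999), weight_decay=0.05,
    paramwise_cfg=dict(
        decay_type='layer_wise',   # the only branch implemented above
        decay_rate=0.7,            # per-layer lr multiplier
        layers=[2, 2, 4, 2],       # blocks per level of the RevCol backbone
        num_subnet=4))             # number of RevCol columns
```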
training/Detection/mmcv_custom/runner/checkpoint.py
ADDED
@@ -0,0 +1,85 @@
# Copyright (c) Open-MMLab. All rights reserved.
import os.path as osp
import time
from tempfile import TemporaryDirectory

import torch
from torch.optim import Optimizer

import mmcv
from mmcv.parallel import is_module_wrapper
from mmcv.runner.checkpoint import weights_to_cpu, get_state_dict

try:
    import apex
except ImportError:
    print('apex is not installed')


def save_checkpoint(model, filename, optimizer=None, meta=None):
    """Save checkpoint to file.

    The checkpoint will have 4 fields: ``meta``, ``state_dict``,
    ``optimizer`` and ``amp``. By default ``meta`` will contain version
    and time info.

    Args:
        model (Module): Module whose params are to be saved.
        filename (str): Checkpoint filename.
        optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
        meta (dict, optional): Metadata to be saved in checkpoint.
    """
    if meta is None:
        meta = {}
    elif not isinstance(meta, dict):
        raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
    meta.update(mmcv_version=mmcv.__version__, time=time.asctime())

    if is_module_wrapper(model):
        model = model.module

    if hasattr(model, 'CLASSES') and model.CLASSES is not None:
        # save class name to the meta
        meta.update(CLASSES=model.CLASSES)

    checkpoint = {
        'meta': meta,
        'state_dict': weights_to_cpu(get_state_dict(model))
    }
    # save optimizer state dict in the checkpoint
    if isinstance(optimizer, Optimizer):
        checkpoint['optimizer'] = optimizer.state_dict()
    elif isinstance(optimizer, dict):
        checkpoint['optimizer'] = {}
        for name, optim in optimizer.items():
            checkpoint['optimizer'][name] = optim.state_dict()

    # save amp state dict in the checkpoint
    # checkpoint['amp'] = apex.amp.state_dict()

    if filename.startswith('pavi://'):
        try:
            from pavi import modelcloud
            from pavi.exception import NodeNotFoundError
        except ImportError:
            raise ImportError(
                'Please install pavi to load checkpoint from modelcloud.')
        model_path = filename[7:]
        root = modelcloud.Folder()
        model_dir, model_name = osp.split(model_path)
        try:
            model = modelcloud.get(model_dir)
        except NodeNotFoundError:
            model = root.create_training_model(model_dir)
        with TemporaryDirectory() as tmp_dir:
            checkpoint_file = osp.join(tmp_dir, model_name)
            with open(checkpoint_file, 'wb') as f:
                torch.save(checkpoint, f)
                f.flush()
            model.create_file(checkpoint_file, name=model_name)
    else:
        mmcv.mkdir_or_exist(osp.dirname(filename))
        # immediately flush buffer
        with open(filename, 'wb') as f:
            torch.save(checkpoint, f)
            f.flush()
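A minimal sketch of calling this saver directly; the path, epoch numbers, and the `model`/`optimizer` objects are placeholders for whatever the training loop built:

```python
# Hypothetical usage: save model + optimizer state after an epoch.
save_checkpoint(
    model,
    'work_dirs/revcol/epoch_12.pth',   # plain path; 'pavi://...' also works
    optimizer=optimizer,               # or a dict of named optimizers
    meta=dict(epoch=12, iter=45000))   # merged with mmcv version/time info
```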
training/Detection/mmdet/models/backbones/__init__.py
ADDED
@@ -0,0 +1,28 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .csp_darknet import CSPDarknet
from .darknet import Darknet
from .detectors_resnet import DetectoRS_ResNet
from .detectors_resnext import DetectoRS_ResNeXt
from .efficientnet import EfficientNet
from .hourglass import HourglassNet
from .hrnet import HRNet
from .mobilenet_v2 import MobileNetV2
from .pvt import PyramidVisionTransformer, PyramidVisionTransformerV2
from .regnet import RegNet
from .res2net import Res2Net
from .resnest import ResNeSt
from .resnet import ResNet, ResNetV1d
from .resnext import ResNeXt
from .ssd_vgg import SSDVGG
from .swin import SwinTransformer
from .trident_resnet import TridentResNet
from .revcol_huge import RevCol_Huge
from .revcol import RevCol

__all__ = [
    'RegNet', 'ResNet', 'ResNetV1d', 'ResNeXt', 'SSDVGG', 'HRNet',
    'MobileNetV2', 'Res2Net', 'HourglassNet', 'DetectoRS_ResNet',
    'DetectoRS_ResNeXt', 'Darknet', 'ResNeSt', 'TridentResNet', 'CSPDarknet',
    'SwinTransformer', 'PyramidVisionTransformer',
    'PyramidVisionTransformerV2', 'EfficientNet', 'RevCol'
]
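Because `RevCol` is now exported here, it can be built through the usual mmdet registry path; a sketch of that, with illustrative channel/layer values rather than values taken from this commit's configs:

```python
# Hypothetical registry usage; argument values are illustrative.
from mmdet.models import build_backbone

backbone = build_backbone(dict(
    type='RevCol',
    channels=[64, 128, 256, 512],
    layers=[2, 2, 4, 2],
    num_subnet=4,
    save_memory=True))
```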
training/Detection/mmdet/models/backbones/revcol.py
ADDED
@@ -0,0 +1,187 @@
import numpy as np
import torch
import torch.nn as nn

from .revcol_module import ConvNextBlock, LayerNorm, UpSampleConvnext
from mmdet.utils import get_root_logger
from ..builder import BACKBONES
from .revcol_function import ReverseFunction
from mmcv.cnn import constant_init, trunc_normal_init
from mmcv.runner import BaseModule, _load_checkpoint
from torch.utils.checkpoint import checkpoint


class Fusion(nn.Module):
    def __init__(self, level, channels, first_col) -> None:
        super().__init__()

        self.level = level
        self.first_col = first_col
        self.down = nn.Sequential(
            nn.Conv2d(channels[level - 1], channels[level], kernel_size=2, stride=2),
            LayerNorm(channels[level], eps=1e-6, data_format="channels_first"),
        ) if level in [1, 2, 3] else nn.Identity()
        if not first_col:
            self.up = UpSampleConvnext(1, channels[level + 1], channels[level]) \
                if level in [0, 1, 2] else nn.Identity()

    def forward(self, *args):
        c_down, c_up = args

        if self.first_col:
            x = self.down(c_down)
            return x

        if self.level == 3:
            x = self.down(c_down)
        else:
            x = self.up(c_up) + self.down(c_down)
        return x


class Level(nn.Module):
    def __init__(self, level, channels, layers, kernel_size, first_col, dp_rate=0.0) -> None:
        super().__init__()
        countlayer = sum(layers[:level])
        expansion = 4
        self.fusion = Fusion(level, channels, first_col)
        modules = [ConvNextBlock(channels[level], expansion * channels[level],
                                 channels[level], kernel_size=kernel_size,
                                 layer_scale_init_value=1e-6,
                                 drop_path=dp_rate[countlayer + i])
                   for i in range(layers[level])]
        self.blocks = nn.Sequential(*modules)

    def forward(self, *args):
        x = self.fusion(*args)
        x = self.blocks(x)
        return x


class SubNet(nn.Module):
    def __init__(self, channels, layers, kernel_size, first_col, dp_rates, save_memory) -> None:
        super().__init__()
        shortcut_scale_init_value = 0.5
        self.save_memory = save_memory
        self.alpha0 = nn.Parameter(shortcut_scale_init_value * torch.ones((1, channels[0], 1, 1)),
                                   requires_grad=True) if shortcut_scale_init_value > 0 else None
        self.alpha1 = nn.Parameter(shortcut_scale_init_value * torch.ones((1, channels[1], 1, 1)),
                                   requires_grad=True) if shortcut_scale_init_value > 0 else None
        self.alpha2 = nn.Parameter(shortcut_scale_init_value * torch.ones((1, channels[2], 1, 1)),
                                   requires_grad=True) if shortcut_scale_init_value > 0 else None
        self.alpha3 = nn.Parameter(shortcut_scale_init_value * torch.ones((1, channels[3], 1, 1)),
                                   requires_grad=True) if shortcut_scale_init_value > 0 else None

        self.level0 = Level(0, channels, layers, kernel_size, first_col, dp_rates)
        self.level1 = Level(1, channels, layers, kernel_size, first_col, dp_rates)
        self.level2 = Level(2, channels, layers, kernel_size, first_col, dp_rates)
        self.level3 = Level(3, channels, layers, kernel_size, first_col, dp_rates)

    def _forward_nonreverse(self, *args):
        x, c0, c1, c2, c3 = args

        c0 = (self.alpha0) * c0 + self.level0(x, c1)
        c1 = (self.alpha1) * c1 + self.level1(c0, c2)
        c2 = (self.alpha2) * c2 + self.level2(c1, c3)
        c3 = (self.alpha3) * c3 + self.level3(c2, None)

        return c0, c1, c2, c3

    def _forward_reverse(self, *args):
        local_funs = [self.level0, self.level1, self.level2, self.level3]
        alpha = [self.alpha0, self.alpha1, self.alpha2, self.alpha3]
        _, c0, c1, c2, c3 = ReverseFunction.apply(
            local_funs, alpha, *args)

        return c0, c1, c2, c3

    def forward(self, *args):
        self._clamp_abs(self.alpha0.data, 1e-3)
        self._clamp_abs(self.alpha1.data, 1e-3)
        self._clamp_abs(self.alpha2.data, 1e-3)
        self._clamp_abs(self.alpha3.data, 1e-3)

        if self.save_memory:
            return self._forward_reverse(*args)
        else:
            return self._forward_nonreverse(*args)

    def _clamp_abs(self, data, value):
        with torch.no_grad():
            sign = data.sign()
            data.abs_().clamp_(value)
            data *= sign


@BACKBONES.register_module()
class RevCol(BaseModule):
    def __init__(self, channels=[32, 64, 96, 128], layers=[2, 3, 6, 3], num_subnet=5,
                 kernel_size=3, num_classes=1000, drop_path=0.0, save_memory=True,
                 single_head=True, out_indices=[0, 1, 2, 3], init_cfg=None) -> None:
        super().__init__(init_cfg)
        self.num_subnet = num_subnet
        self.single_head = single_head
        self.out_indices = out_indices
        self.init_cfg = init_cfg

        self.stem = nn.Sequential(
            nn.Conv2d(3, channels[0], kernel_size=4, stride=4),
            LayerNorm(channels[0], eps=1e-6, data_format="channels_first")
        )

        dp_rate = [x.item() for x in torch.linspace(0, drop_path, sum(layers))]
        for i in range(num_subnet):
            first_col = True if i == 0 else False
            self.add_module(f'subnet{str(i)}', SubNet(
                channels, layers, kernel_size, first_col, dp_rates=dp_rate,
                save_memory=save_memory))

    def init_weights(self):
        logger = get_root_logger()
        if self.init_cfg is None:
            logger.warning(f'No pre-trained weights for '
                           f'{self.__class__.__name__}, '
                           f'training starts from scratch')
            for m in self.modules():
                if isinstance(m, nn.Linear):
                    trunc_normal_init(m, std=.02, bias=0.)
                elif isinstance(m, nn.LayerNorm):
                    constant_init(m, 1.0)
        else:
            assert 'checkpoint' in self.init_cfg, f'Only support ' \
                f'specify `Pretrained` in ' \
                f'`init_cfg` in ' \
                f'{self.__class__.__name__} '
            ckpt = _load_checkpoint(
                self.init_cfg.checkpoint, logger=logger, map_location='cpu')
            if 'state_dict' in ckpt:
                _state_dict = ckpt['state_dict']
            elif 'model' in ckpt:
                _state_dict = ckpt['model']
            else:
                _state_dict = ckpt

            state_dict = _state_dict
            # strip prefix of state_dict
            if list(state_dict.keys())[0].startswith('module.'):
                state_dict = {k[7:]: v for k, v in state_dict.items()}

            # load state_dict
            self.load_state_dict(state_dict, False)

    def forward(self, x):
        x = self.stem(x)
        c0, c1, c2, c3 = 0, 0, 0, 0
        for i in range(self.num_subnet):
            # gradient checkpointing alternative:
            # c0, c1, c2, c3 = checkpoint(getattr(self, f'subnet{str(i)}'), x, c0, c1, c2, c3)
            c0, c1, c2, c3 = getattr(self, f'subnet{str(i)}')(x, c0, c1, c2, c3)
        return c0, c1, c2, c3

    def cal_dp_rate(self, depth, num_subnet, drop_path):
        dp = np.zeros((depth, num_subnet))
        dp[:, 0] = np.linspace(0, depth - 1, depth)
        dp[0, :] = np.linspace(0, num_subnet - 1, num_subnet)
        for i in range(1, depth):
            for j in range(1, num_subnet):
                dp[i][j] = min(dp[i][j - 1], dp[i - 1][j]) + 1
        ratio = dp[-1][-1] / drop_path
        dp_matrix = dp / ratio
        return dp_matrix
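As a sanity check on the backbone contract: `forward` returns one feature map per level at strides 4/8/16/32 of the input (the stride-4 stem plus one stride-2 downsample per level). A quick smoke-test sketch, with the non-reversible path for simplicity:

```python
# Hypothetical smoke test for the backbone's output pyramid.
import torch

model = RevCol(channels=[32, 64, 96, 128], layers=[2, 3, 6, 3],
               num_subnet=2, save_memory=False)
c0, c1, c2, c3 = model(torch.randn(1, 3, 224, 224))
# Expected shapes at strides 4/8/16/32:
#   c0: (1, 32, 56, 56), c1: (1, 64, 28, 28),
#   c2: (1, 96, 14, 14), c3: (1, 128, 7, 7)
```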
training/Detection/mmdet/models/backbones/revcol_function.py
ADDED
@@ -0,0 +1,222 @@
import torch
from typing import Any, List, Tuple


def get_gpu_states(fwd_gpu_devices) -> Tuple[List[int], List[torch.Tensor]]:
    # This will not error out if "arg" is a CPU tensor or a non-tensor type
    # because the conditionals short-circuit.
    fwd_gpu_states = []
    for device in fwd_gpu_devices:
        with torch.cuda.device(device):
            fwd_gpu_states.append(torch.cuda.get_rng_state())

    return fwd_gpu_states


def get_gpu_device(*args):
    fwd_gpu_devices = list(set(arg.get_device() for arg in args
                               if isinstance(arg, torch.Tensor) and arg.is_cuda))
    return fwd_gpu_devices


def set_device_states(fwd_cpu_state, devices, states) -> None:
    torch.set_rng_state(fwd_cpu_state)
    for device, state in zip(devices, states):
        with torch.cuda.device(device):
            torch.cuda.set_rng_state(state)


def detach_and_grad(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]:
    if isinstance(inputs, tuple):
        out = []
        for inp in inputs:
            if not isinstance(inp, torch.Tensor):
                out.append(inp)
                continue

            x = inp.detach()
            x.requires_grad = True
            out.append(x)
        return tuple(out)
    else:
        raise RuntimeError(
            'Only tuple of tensors is supported. Got unsupported input type: '
            f'{type(inputs).__name__}')


def get_cpu_and_gpu_states(gpu_devices):
    return torch.get_rng_state(), get_gpu_states(gpu_devices)


class ReverseFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, run_functions, alpha, *args):
        l0, l1, l2, l3 = run_functions
        alpha0, alpha1, alpha2, alpha3 = alpha
        ctx.run_functions = run_functions
        ctx.alpha = alpha
        ctx.preserve_rng_state = True

        ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
                                   "dtype": torch.get_autocast_gpu_dtype(),
                                   "cache_enabled": torch.is_autocast_cache_enabled()}
        ctx.cpu_autocast_kwargs = {"enabled": torch.is_autocast_cpu_enabled(),
                                   "dtype": torch.get_autocast_cpu_dtype(),
                                   "cache_enabled": torch.is_autocast_cache_enabled()}

        assert len(args) == 5
        [x, c0, c1, c2, c3] = args
        if type(c0) == int:
            ctx.first_col = True
        else:
            ctx.first_col = False
        with torch.no_grad():
            if ctx.preserve_rng_state:
                gpu_devices = get_gpu_device(*args)
                ctx.gpu_devices = gpu_devices
                ctx.cpu_states_0, ctx.gpu_states_0 = get_cpu_and_gpu_states(gpu_devices)
                c0 = l0(x, c1, c3) + c0 * alpha0
                ctx.cpu_states_1, ctx.gpu_states_1 = get_cpu_and_gpu_states(gpu_devices)
                c1 = l1(c0, c2) + c1 * alpha1
                ctx.cpu_states_2, ctx.gpu_states_2 = get_cpu_and_gpu_states(gpu_devices)
                c2 = l2(c1, c3) + c2 * alpha2
                ctx.cpu_states_3, ctx.gpu_states_3 = get_cpu_and_gpu_states(gpu_devices)
                c3 = l3(c2) + c3 * alpha3
            else:
                c0 = l0(x, c1, c3) + c0 * alpha0
                c1 = l1(c0, c2) + c1 * alpha1
                c2 = l2(c1, c3) + c2 * alpha2
                c3 = l3(c2) + c3 * alpha3
        ctx.save_for_backward(x, c0, c1, c2, c3)
        return x, c0, c1, c2, c3

    @staticmethod
    def backward(ctx, *grad_outputs):
        x, c0, c1, c2, c3 = ctx.saved_tensors
        l0, l1, l2, l3 = ctx.run_functions
        alpha0, alpha1, alpha2, alpha3 = ctx.alpha
        gx_right, g0_right, g1_right, g2_right, g3_right = grad_outputs
        (x, c0, c1, c2, c3) = detach_and_grad((x, c0, c1, c2, c3))

        if ctx.preserve_rng_state:
            with torch.enable_grad(), \
                    torch.random.fork_rng(devices=ctx.gpu_devices, enabled=ctx.preserve_rng_state), \
                    torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs), \
                    torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):

                g3_up = g3_right
                g3_left = g3_up * alpha3  ## shortcut
                set_device_states(ctx.cpu_states_3, ctx.gpu_devices, ctx.gpu_states_3)
                oup3 = l3(c2)
                torch.autograd.backward(oup3, g3_up, retain_graph=True)
                with torch.no_grad():
                    c3_left = (1 / alpha3) * (c3 - oup3)  ## feature reverse
                    g2_up = g2_right + c2.grad
                    g2_left = g2_up * alpha2  ## shortcut

                (c3_left,) = detach_and_grad((c3_left,))
                set_device_states(ctx.cpu_states_2, ctx.gpu_devices, ctx.gpu_states_2)
                oup2 = l2(c1, c3_left)
                torch.autograd.backward(oup2, g2_up, retain_graph=True)
                c3_left.requires_grad = False
                cout3 = c3_left * alpha3  ## alpha3 update
                torch.autograd.backward(cout3, g3_up)

                with torch.no_grad():
                    c2_left = (1 / alpha2) * (c2 - oup2)  ## feature reverse
                    g3_left = g3_left + c3_left.grad if c3_left.grad is not None else g3_left
                    g1_up = g1_right + c1.grad
                    g1_left = g1_up * alpha1  ## shortcut

                (c2_left,) = detach_and_grad((c2_left,))
                set_device_states(ctx.cpu_states_1, ctx.gpu_devices, ctx.gpu_states_1)
                oup1 = l1(c0, c2_left)
                torch.autograd.backward(oup1, g1_up, retain_graph=True)
                c2_left.requires_grad = False
                cout2 = c2_left * alpha2  ## alpha2 update
                torch.autograd.backward(cout2, g2_up)

                with torch.no_grad():
                    c1_left = (1 / alpha1) * (c1 - oup1)  ## feature reverse
                    g0_up = g0_right + c0.grad
                    g0_left = g0_up * alpha0  ## shortcut
                    g2_left = g2_left + c2_left.grad if c2_left.grad is not None else g2_left  ## Fusion

                (c1_left, c3_left) = detach_and_grad((c1_left, c3_left))
                set_device_states(ctx.cpu_states_0, ctx.gpu_devices, ctx.gpu_states_0)
                oup0 = l0(x, c1_left, c3_left)
                torch.autograd.backward(oup0, g0_up, retain_graph=True)
                c1_left.requires_grad = False
                cout1 = c1_left * alpha1  ## alpha1 update
                torch.autograd.backward(cout1, g1_up)

                with torch.no_grad():
                    c0_left = (1 / alpha0) * (c0 - oup0)  ## feature reverse
                    gx_up = x.grad  ## Fusion
                    g1_left = g1_left + c1_left.grad if c1_left.grad is not None else g1_left  ## Fusion
                    g3_left = g3_left + c3_left.grad if c3_left.grad is not None else g3_left  ## Fusion
                c0_left.requires_grad = False
                cout0 = c0_left * alpha0  ## alpha0 update
                torch.autograd.backward(cout0, g0_up)
        else:
            with torch.enable_grad():
                g3_up = g3_right
                g3_left = g3_up * alpha3  ## shortcut
                oup3 = l3(c2)
                torch.autograd.backward(oup3, g3_up, retain_graph=True)
                with torch.no_grad():
                    c3_left = (1 / alpha3) * (c3 - oup3)  ## feature reverse
                    g2_up = g2_right + c2.grad
                    g2_left = g2_up * alpha2  ## shortcut

                (c3_left,) = detach_and_grad((c3_left,))
                oup2 = l2(c1, c3_left)
                torch.autograd.backward(oup2, g2_up, retain_graph=True)
                c3_left.requires_grad = False
                cout3 = c3_left * alpha3  ## alpha3 update
                torch.autograd.backward(cout3, g3_up)

                with torch.no_grad():
                    c2_left = (1 / alpha2) * (c2 - oup2)  ## feature reverse
                    g3_left = g3_left + c3_left.grad if c3_left.grad is not None else g3_left
                    g1_up = g1_right + c1.grad
                    g1_left = g1_up * alpha1  ## shortcut

                (c2_left,) = detach_and_grad((c2_left,))
                oup1 = l1(c0, c2_left)
                torch.autograd.backward(oup1, g1_up, retain_graph=True)
                c2_left.requires_grad = False
                cout2 = c2_left * alpha2  ## alpha2 update
                torch.autograd.backward(cout2, g2_up)

                with torch.no_grad():
                    c1_left = (1 / alpha1) * (c1 - oup1)  ## feature reverse
                    g0_up = g0_right + c0.grad
                    g0_left = g0_up * alpha0  ## shortcut
                    g2_left = g2_left + c2_left.grad if c2_left.grad is not None else g2_left  ## Fusion

                (c1_left, c3_left) = detach_and_grad((c1_left, c3_left))
                oup0 = l0(x, c1_left, c3_left)
                torch.autograd.backward(oup0, g0_up, retain_graph=True)
                c1_left.requires_grad = False
                cout1 = c1_left * alpha1  ## alpha1 update
                torch.autograd.backward(cout1, g1_up)

                with torch.no_grad():
                    c0_left = (1 / alpha0) * (c0 - oup0)  ## feature reverse
                    gx_up = x.grad  ## Fusion
                    g1_left = g1_left + c1_left.grad if c1_left.grad is not None else g1_left  ## Fusion
                    g3_left = g3_left + c3_left.grad if c3_left.grad is not None else g3_left  ## Fusion
                c0_left.requires_grad = False
                cout0 = c0_left * alpha0  ## alpha0 update
                torch.autograd.backward(cout0, g0_up)

        if ctx.first_col:
            return None, None, gx_up, None, None, None, None
        else:
            return None, None, gx_up, g0_left, g1_left, g2_left, g3_left
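The memory saving rests on the identity used in `backward`: each level computes c_out = F(inputs) + alpha * c_in elementwise, so the column input can be recomputed as c_in = (c_out - F(inputs)) / alpha instead of being stored for the whole network. A toy illustration of that reconstruction, using plain tensors rather than the autograd machinery above:

```python
# Toy check of the feature-reverse identity used above (illustrative only).
import torch

alpha = torch.full((1, 8, 1, 1), 0.5)
f = torch.nn.Conv2d(8, 8, 3, padding=1)  # stands in for one Level
c_in = torch.randn(2, 8, 16, 16)
x = torch.randn(2, 8, 16, 16)

c_out = f(x) + alpha * c_in              # forward rule of one level
c_rec = (1 / alpha) * (c_out - f(x))     # reverse rule applied in backward()
assert torch.allclose(c_rec, c_in, atol=1e-5)
```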
training/Detection/mmdet/models/backbones/revcol_module.py
ADDED
@@ -0,0 +1,85 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import DropPath


class UpSampleConvnext(nn.Module):
    def __init__(self, ratio, inchannel, outchannel):
        super().__init__()
        self.ratio = ratio
        self.channel_reschedule = nn.Sequential(
            # LayerNorm(inchannel, eps=1e-6, data_format="channels_last"),
            nn.Linear(inchannel, outchannel),
            LayerNorm(outchannel, eps=1e-6, data_format="channels_last"))
        self.upsample = nn.Upsample(scale_factor=2 ** ratio, mode='nearest')

    def forward(self, x):
        x = x.permute(0, 2, 3, 1)
        x = self.channel_reschedule(x)
        x = x.permute(0, 3, 1, 2)

        return self.upsample(x)


class LayerNorm(nn.Module):
    r""" LayerNorm that supports two data formats: channels_last (default) or
    channels_first. The ordering of the dimensions in the inputs: channels_last
    corresponds to inputs with shape (batch_size, height, width, channels) while
    channels_first corresponds to inputs with shape
    (batch_size, channels, height, width).
    """
    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_first"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape, )

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = self.weight[:, None, None] * x + self.bias[:, None, None]
            return x


class ConvNextBlock(nn.Module):
    r""" ConvNeXt Block. There are two equivalent implementations:
    (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
    (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
    We use (2) as we find it slightly faster in PyTorch.

    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
    """
    def __init__(self, in_channel, hidden_dim, out_channel, kernel_size=3,
                 layer_scale_init_value=1e-6, drop_path=0.0):
        super().__init__()
        self.dwconv = nn.Conv2d(in_channel, in_channel, kernel_size=kernel_size,
                                padding=(kernel_size - 1) // 2,
                                groups=in_channel)  # depthwise conv
        self.norm = nn.LayerNorm(in_channel, eps=1e-6)
        self.pwconv1 = nn.Linear(in_channel, hidden_dim)  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(hidden_dim, out_channel)
        self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((out_channel)),
                                  requires_grad=True) if layer_scale_init_value > 0 else None
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        input = x
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)

        x = input + self.drop_path(x)
        return x
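Both modules are shape-preserving in (N, C, H, W); a quick sketch of how they compose (channel sizes are illustrative, and the residual add requires in_channel == out_channel):

```python
# Illustrative shape check for the modules above.
import torch

blk = ConvNextBlock(in_channel=64, hidden_dim=256, out_channel=64, kernel_size=3)
ln = LayerNorm(64, data_format="channels_first")
x = torch.randn(2, 64, 32, 32)
assert blk(x).shape == x.shape  # residual block keeps (N, C, H, W)
assert ln(x).shape == x.shape   # per-position channel normalization
```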
training/Detection/mmdet/utils/__init__.py
ADDED
@@ -0,0 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .collect_env import collect_env
from .logger import get_caller_name, get_root_logger, log_img_scale
from .misc import find_latest_checkpoint, update_data_root
from .setup_env import setup_multi_processes
from .optimizer import DistOptimizerHook

__all__ = [
    'get_root_logger', 'collect_env', 'find_latest_checkpoint',
    'update_data_root', 'setup_multi_processes', 'get_caller_name',
    'log_img_scale', 'DistOptimizerHook'
]
training/Detection/mmdet/utils/optimizer.py
ADDED
@@ -0,0 +1,33 @@
from mmcv.runner import OptimizerHook, HOOKS
try:
    import apex
except ImportError:
    print('apex is not installed')


@HOOKS.register_module()
class DistOptimizerHook(OptimizerHook):
    """Optimizer hook for distributed training with gradient accumulation."""

    def __init__(self, update_interval=1, grad_clip=None, coalesce=True, bucket_size_mb=-1, use_fp16=False):
        self.grad_clip = grad_clip
        self.coalesce = coalesce
        self.bucket_size_mb = bucket_size_mb
        self.update_interval = update_interval
        self.use_fp16 = use_fp16

    def before_run(self, runner):
        runner.optimizer.zero_grad()

    def after_train_iter(self, runner):
        # Pre-divide the loss so gradients accumulated over
        # `update_interval` iterations average rather than sum.
        runner.outputs['loss'] /= self.update_interval
        if self.use_fp16:
            with apex.amp.scale_loss(runner.outputs['loss'], runner.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            runner.outputs['loss'].backward()
        if self.every_n_iters(runner, self.update_interval):
            if self.grad_clip is not None:
                self.clip_grads(runner.model.parameters())
            runner.optimizer.step()
            runner.optimizer.zero_grad()
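A usage note (not from the commit itself): this hook steps the optimizer only every `update_interval` iterations, so the effective batch size is multiplied by that interval. A hedged sketch of how it might be enabled from an mmdet config, assuming the standard `optimizer_config` hook slot:

```python
# Hypothetical mmdet config fragment (an assumption, not part of this commit):
# replaces the default OptimizerHook with the registered DistOptimizerHook.
# Effective batch size becomes per_gpu_batch * num_gpus * update_interval.
optimizer_config = dict(
    type='DistOptimizerHook',
    update_interval=2,             # accumulate gradients over 2 iterations
    grad_clip=dict(max_norm=5.0),  # forwarded to OptimizerHook.clip_grads
    use_fp16=False,                # True requires NVIDIA apex to be installed
)
```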
training/INSTRUCTIONS.md
ADDED
@@ -0,0 +1,158 @@
# Installation, Training and Evaluation Instructions for Image Classification

We provide installation, training and evaluation instructions for image classification here.

## Installation Instructions

- Clone this repo:

```bash
git clone https://github.com/megvii-research/RevCol.git
cd RevCol
```

- Create a conda virtual environment and activate it:

```bash
conda create --name revcol python=3.7 -y
conda activate revcol
```

- Install `CUDA>=11.3` with `cudnn>=8` following
the [official installation instructions](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html)
- Install `PyTorch>=1.11.0` and `torchvision>=0.12.0` with `CUDA>=11.3`:

```bash
conda install pytorch=1.11.0 torchvision=0.12.0 torchaudio=0.11.0 cudatoolkit=11.3 -c pytorch
```

- Install `timm==0.5.4`:

```bash
pip install timm==0.5.4
```

- Install other requirements:

```bash
pip install -r requirements.txt
```

## Data preparation

We use the standard ImageNet dataset, which you can download from http://image-net.org/. We provide the following two ways to load data:

- For the standard ImageNet-1k dataset, the file structure should look like:
```bash
path-to-imagenet-1k
├── train
│   ├── class1
│   │   ├── img1.jpeg
│   │   ├── img2.jpeg
│   │   └── ...
│   ├── class2
│   │   ├── img3.jpeg
│   │   └── ...
│   └── ...
└── val
    ├── class1
    │   ├── img4.jpeg
    │   ├── img5.jpeg
    │   └── ...
    ├── class2
    │   ├── img6.jpeg
    │   └── ...
    └── ...
```

- For the ImageNet-22K dataset, the file structure should look like:

```bash
path-to-imagenet-22k
├── class1
│   ├── img1.jpeg
│   ├── img2.jpeg
│   └── ...
├── class2
│   ├── img3.jpeg
│   └── ...
└── ...
```

- As ImageNet-22k has no val set, one way is to use the ImageNet-1k val set as the evaluation set for ImageNet-22k models. Remember to map the ImageNet-1k labels to their ImageNet-22k indices (a sketch of one way to build this mapping follows this file):
```bash
path-to-imagenet-22k-custom-eval-set
├── class1
│   ├── img1.jpeg
│   ├── img2.jpeg
│   └── ...
├── class2
│   ├── img3.jpeg
│   └── ...
└── ...
```

## Evaluation

To evaluate a pre-trained `RevCol` on the ImageNet validation set, run:

```bash
torchrun --nproc_per_node=<num-of-gpus-to-use> --master_port=23456 main.py --cfg <config-file.yaml> --resume <checkpoint_path> --data-path <imagenet-path> --eval
```

For example, to evaluate `RevCol-T` with 8 GPUs:

```bash
torchrun --nproc_per_node=8 --master_port=23456 main.py --cfg configs/revcol_tiny_1k.yaml --resume path_to_your_model.pth --data-path <imagenet-path> --eval
```

## Training from scratch on ImageNet-1K

To train a `RevCol` on ImageNet from scratch, run:

```bash
torchrun --nproc_per_node=<num-of-gpus-to-use> --master_port=23456 main.py \
--cfg <config-file> --data-path <imagenet-path> [--batch-size <batch-size-per-gpu> --output <output-directory> --tag <job-tag>]
```

For example, to train `RevCol` with 8 GPUs on a single node for 300 epochs, run:

`RevCol-T`:

```bash
torchrun --nproc_per_node=8 --master_port=23456 main.py --cfg configs/revcol_tiny_1k.yaml --batch-size 128 --data-path <imagenet-path>
```

`RevCol-S`:

```bash
torchrun --nproc_per_node=8 --master_port=23456 main.py --cfg configs/revcol_small_1k.yaml --batch-size 128 --data-path <imagenet-path>
```

`RevCol-B`:

```bash
torchrun --nproc_per_node=8 --master_port=23456 main.py --cfg configs/revcol_base_1k.yaml --batch-size 128 --data-path <imagenet-path>
```

## Pre-training on ImageNet-22K

For example, to pre-train a `RevCol-L` model on ImageNet-22K:

```bash
torchrun --nproc_per_node=8 --master_port=23456 main.py --cfg configs/revcol_large_22k_pretrain.yaml --batch-size 128 --data-path <imagenet-22k-path> --opt DATA.EVAL_DATA_PATH <imagenet-22k-custom-eval-path>
```


## Fine-tuning from an ImageNet-22K(21K) pre-trained model

For example, to fine-tune a `RevCol-B` model pre-trained on ImageNet-22K(21K):

```bash
torchrun --nproc_per_node=8 --master_port=23456 main.py --cfg configs/revcol_base_1k_384_finetune.yaml --batch-size 64 --data-path <imagenet-1k-path> --finetune revcol_base_22k_pretrained.pth
```
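Regarding the Data preparation note above about mapping ImageNet-1k labels to ImageNet-22k for the custom eval set: since both datasets typically name class folders by WordNet synset ID (e.g. `n01440764`), and `torchvision.datasets.ImageFolder` assigns labels from sorted folder names, one hedged way to build the mapping is by matching synset IDs. A sketch under that folder-layout assumption (this helper is not part of the commit):

```python
# Hypothetical helper (not in the commit): map ImageNet-1k class indices to
# ImageNet-22k indices by matching WordNet synset folder names.
import os

def build_1k_to_22k_label_map(imagenet1k_val_dir, imagenet22k_dir):
    # ImageFolder sorts class folder names, so sorted order == label order.
    classes_1k = sorted(os.listdir(imagenet1k_val_dir))
    classes_22k = sorted(os.listdir(imagenet22k_dir))
    idx_22k = {wnid: i for i, wnid in enumerate(classes_22k)}
    # For every 1k label, look up the index of the same synset in 22k.
    return [idx_22k[wnid] for wnid in classes_1k]
```

With such a map, predictions in the 21841-way output space can be compared against remapped ImageNet-1k val labels.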
training/LICENSE
ADDED
@@ -0,0 +1,190 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   Copyright (c) 2022 Megvii Inc. All rights reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
training/README.md
ADDED
@@ -0,0 +1,79 @@
# Reversible Column Networks
This repo is the official implementation of:

### [Reversible Column Networks](https://arxiv.org/abs/2212.11696)
[Yuxuan Cai](https://nightsnack.github.io), [Yizhuang Zhou](https://scholar.google.com/citations?user=VRSGDDEAAAAJ), [Qi Han](https://hanqer.github.io), Jianjian Sun, Xiangwen Kong, Jun Li, [Xiangyu Zhang](https://scholar.google.com/citations?user=yuB-cfoAAAAJ) \
[MEGVII Technology](https://en.megvii.com)\
International Conference on Learning Representations (ICLR) 2023\
[\[arxiv\]](https://arxiv.org/abs/2212.11696)



## Updates
***2/10/2023***\
RevCol model weights released.

***1/21/2023***\
RevCol was accepted by ICLR 2023!

***12/23/2022***\
Initial commits: code for ImageNet-1k and ImageNet-22k classification is released.


## To Do List


- [x] ImageNet-1K and 22k Training Code
- [x] ImageNet-1K and 22k Model Weights
- [ ] Cascade Mask R-CNN COCO Object Detection Code & Model Weights
- [ ] ADE20k Semantic Segmentation Code & Model Weights


## Introduction
RevCol is composed of multiple copies of subnetworks, named columns, between which multi-level reversible connections are employed. RevCol can serve as a foundation-model backbone for various tasks in computer vision, including classification, detection and segmentation.

<p align="center">
<img src="figures/title.png" width=100% height=100%
class="center">
</p>

## Main Results on ImageNet with Pre-trained Models

| name | pretrain | resolution | #params | FLOPs | acc@1 | pretrained model | finetuned model |
|:---------------------:| :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| RevCol-T | ImageNet-1K | 224x224 | 30M | 4.5G | 82.2 | [baidu](https://pan.baidu.com/s/1iGsbdmFcDpwviCHaajeUnA?pwd=h4tj)/[github](https://github.com/megvii-research/RevCol/releases/download/checkpoint/revcol_tiny_1k.pth) | - |
| RevCol-S | ImageNet-1K | 224x224 | 60M | 9.0G | 83.5 | [baidu](https://pan.baidu.com/s/1hpHfdFrTZIPB5NTwqDMLag?pwd=mxuk)/[github](https://github.com/megvii-research/RevCol/releases/download/checkpoint/revcol_small_1k.pth) | - |
| RevCol-B | ImageNet-1K | 224x224 | 138M | 16.6G | 84.1 | [baidu](https://pan.baidu.com/s/16XIJ1n8pXPD2cXwnFX6b9w?pwd=j6x9)/[github](https://github.com/megvii-research/RevCol/releases/download/checkpoint/revcol_base_1k.pth) | - |
| RevCol-B<sup>\*</sup> | ImageNet-22K | 224x224 | 138M | 16.6G | 85.6 |[baidu](https://pan.baidu.com/s/1l8zOFifgC8fZtBpHK2ZQHg?pwd=rh58)/[github](https://github.com/megvii-research/RevCol/releases/download/checkpoint/revcol_base_22k.pth)| [baidu](https://pan.baidu.com/s/1HqhDXL6OIQdn1LeM2pewYQ?pwd=1bp3)/[github](https://github.com/megvii-research/RevCol/releases/download/checkpoint/revcol_base_22k_1kft_224.pth)|
| RevCol-B<sup>\*</sup> | ImageNet-22K | 384x384 | 138M | 48.9G | 86.7 |[baidu](https://pan.baidu.com/s/1l8zOFifgC8fZtBpHK2ZQHg?pwd=rh58)/[github](https://github.com/megvii-research/RevCol/releases/download/checkpoint/revcol_base_22k.pth)| [baidu](https://pan.baidu.com/s/18G0zAUygKgu58s2AjCBpsw?pwd=rv86)/[github](https://github.com/megvii-research/RevCol/releases/download/checkpoint/revcol_base_22k_1kft_384.pth)|
| RevCol-L<sup>\*</sup> | ImageNet-22K | 224x224 | 273M | 39G | 86.6 |[baidu](https://pan.baidu.com/s/1ueKqh3lFAAgC-vVU34ChYA?pwd=qv5m)/[github](https://github.com/megvii-research/RevCol/releases/download/checkpoint/revcol_large_22k.pth)| [baidu](https://pan.baidu.com/s/1CsWmcPcwieMzXE8pVmHh7w?pwd=qd9n)/[github](https://github.com/megvii-research/RevCol/releases/download/checkpoint/revcol_large_22k_1kft_224.pth)|
| RevCol-L<sup>\*</sup> | ImageNet-22K | 384x384 | 273M | 116G | 87.6 |[baidu](https://pan.baidu.com/s/1ueKqh3lFAAgC-vVU34ChYA?pwd=qv5m)/[github](https://github.com/megvii-research/RevCol/releases/download/checkpoint/revcol_large_22k.pth)| [baidu](https://pan.baidu.com/s/1VmCE3W3Xw6-Lo4rWrj9Xzg?pwd=x69r)/[github](https://github.com/megvii-research/RevCol/releases/download/checkpoint/revcol_large_22k_1kft_384.pth)|

## Getting Started
Please refer to [INSTRUCTIONS.md](INSTRUCTIONS.md) for setup, training and evaluation details.


## Acknowledgement
This repo was inspired by several open-source projects. We are grateful for these excellent projects and list them as follows:
- [timm](https://github.com/rwightman/pytorch-image-models)
- [Swin Transformer](https://github.com/microsoft/Swin-Transformer)
- [ConvNeXt](https://github.com/facebookresearch/ConvNeXt)
- [beit](https://github.com/microsoft/unilm/tree/master/beit)

## License
RevCol is released under the [Apache 2.0 license](LICENSE).

## Contact Us
If you have any questions about this repo or the original paper, please contact Yuxuan at [email protected].


## Citation
```
@inproceedings{cai2022reversible,
  title={Reversible Column Networks},
  author={Cai, Yuxuan and Zhou, Yizhuang and Han, Qi and Sun, Jianjian and Kong, Xiangwen and Li, Jun and Zhang, Xiangyu},
  booktitle={International Conference on Learning Representations},
  year={2023},
  url={https://openreview.net/forum?id=Oc2vlWU0jFY}
}
```
training/config.py
ADDED
@@ -0,0 +1,243 @@
# --------------------------------------------------------
# Reversible Column Networks
# Copyright (c) 2022 Megvii Inc.
# Licensed under The Apache License 2.0 [see LICENSE for details]
# Written by Yuxuan Cai
# --------------------------------------------------------
import os
import yaml
from yacs.config import CfgNode as CN

_C = CN()

# Base config files
_C.BASE = ['']

# -----------------------------------------------------------------------------
# Data settings
# -----------------------------------------------------------------------------
_C.DATA = CN()
# Batch size for a single GPU, could be overwritten by command line argument
_C.DATA.BATCH_SIZE = 128
# Path to dataset, could be overwritten by command line argument
_C.DATA.DATA_PATH = 'path/to/imagenet'
# Dataset name
_C.DATA.DATASET = 'imagenet'
# Input image size
_C.DATA.IMG_SIZE = 224
# Interpolation to resize image (random, bilinear, bicubic)
_C.DATA.INTERPOLATION = 'bicubic'
# Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.
_C.DATA.PIN_MEMORY = True
# Number of data loading threads
_C.DATA.NUM_WORKERS = 8
# Path to evaluation dataset for ImageNet 22k
_C.DATA.EVAL_DATA_PATH = 'path/to/eval/data'

# -----------------------------------------------------------------------------
# Model settings
# -----------------------------------------------------------------------------
_C.MODEL = CN()
# Model type
_C.MODEL.TYPE = ''
# Model name
_C.MODEL.NAME = ''
# Checkpoint to resume, could be overwritten by command line argument
_C.MODEL.RESUME = ''
# Checkpoint to finetune, could be overwritten by command line argument
_C.MODEL.FINETUNE = ''
# Number of classes, overwritten in data preparation
_C.MODEL.NUM_CLASSES = 1000
# Label Smoothing
_C.MODEL.LABEL_SMOOTHING = 0.0

# -----------------------------------------------------------------------------
# Specific Model settings
# -----------------------------------------------------------------------------

_C.REVCOL = CN()

_C.REVCOL.INTER_SUPV = True

_C.REVCOL.SAVEMM = True

_C.REVCOL.FCOE = 4.0

_C.REVCOL.CCOE = 0.8

_C.REVCOL.KERNEL_SIZE = 3

_C.REVCOL.DROP_PATH = 0.1

_C.REVCOL.HEAD_INIT_SCALE = None

# -----------------------------------------------------------------------------
# Training settings
# -----------------------------------------------------------------------------
_C.TRAIN = CN()
_C.TRAIN.START_EPOCH = 0
_C.TRAIN.EPOCHS = 300
_C.TRAIN.WARMUP_EPOCHS = 5
_C.TRAIN.WEIGHT_DECAY = 4e-5
_C.TRAIN.BASE_LR = 0.4

_C.TRAIN.WARMUP_LR = 0.05
_C.TRAIN.MIN_LR = 1e-5
# Clip gradient norm
_C.TRAIN.CLIP_GRAD = 10.0
# Auto resume from latest checkpoint
_C.TRAIN.AUTO_RESUME = True
# Whether to use gradient checkpointing to save memory
_C.TRAIN.USE_CHECKPOINT = False

_C.TRAIN.AMP = True

# LR scheduler
_C.TRAIN.LR_SCHEDULER = CN()
_C.TRAIN.LR_SCHEDULER.NAME = 'cosine'
# Epoch interval to decay LR, used in StepLRScheduler
_C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30
# LR decay rate, used in StepLRScheduler
_C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1

# Optimizer
_C.TRAIN.OPTIMIZER = CN()
_C.TRAIN.OPTIMIZER.NAME = 'sgd'
# Optimizer epsilon for adamw
_C.TRAIN.OPTIMIZER.EPS = 1e-8
# Optimizer betas for adamw
_C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999)
# SGD momentum
_C.TRAIN.OPTIMIZER.MOMENTUM = 0.9
# Layer Decay
_C.TRAIN.OPTIMIZER.LAYER_DECAY = 1.0

# -----------------------------------------------------------------------------
# Augmentation settings
# -----------------------------------------------------------------------------
_C.AUG = CN()
# Color jitter factor
_C.AUG.COLOR_JITTER = 0.4
# Use AutoAugment policy. "v0" or "original"
_C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1'
# Random erase prob
_C.AUG.REPROB = 0.25
# Random erase mode
_C.AUG.REMODE = 'pixel'
# Random erase count
_C.AUG.RECOUNT = 1
# Mixup alpha, mixup enabled if > 0
_C.AUG.MIXUP = 0.8
# Cutmix alpha, cutmix enabled if > 0
_C.AUG.CUTMIX = 1.0
# Cutmix min/max ratio, overrides alpha and enables cutmix if set
_C.AUG.CUTMIX_MINMAX = None
# Probability of performing mixup or cutmix when either/both is enabled
_C.AUG.MIXUP_PROB = 1.0
# Probability of switching to cutmix when both mixup and cutmix enabled
_C.AUG.MIXUP_SWITCH_PROB = 0.5
# How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
_C.AUG.MIXUP_MODE = 'batch'

# -----------------------------------------------------------------------------
# Testing settings
# -----------------------------------------------------------------------------
_C.TEST = CN()
# Whether to use center crop when testing
_C.TEST.CROP = True

# -----------------------------------------------------------------------------
# Misc
# -----------------------------------------------------------------------------
# Path to output folder, overwritten by command line argument
_C.OUTPUT = 'outputs/'
# Tag of experiment, overwritten by command line argument
_C.TAG = 'default'
# Frequency to save checkpoint
_C.SAVE_FREQ = 1
# Frequency to log info
_C.PRINT_FREQ = 100
# Fixed random seed
_C.SEED = 0
# Perform evaluation only, overwritten by command line argument
_C.EVAL_MODE = False
# Test throughput only, overwritten by command line argument
_C.THROUGHPUT_MODE = False
# Local rank for DistributedDataParallel, given by command line argument
_C.LOCAL_RANK = 0



# EMA
_C.MODEL_EMA = False
_C.MODEL_EMA_DECAY = 0.9999




# Machine
_C.MACHINE = CN()
_C.MACHINE.MACHINE_WORLD_SIZE = None
_C.MACHINE.MACHINE_RANK = None

def _update_config_from_file(config, cfg_file):
    config.defrost()
    with open(cfg_file, 'r') as f:
        yaml_cfg = yaml.load(f, Loader=yaml.FullLoader)

    # Recursively merge any parent configs listed under BASE first,
    # so that keys in this file override the inherited ones.
    for cfg in yaml_cfg.setdefault('BASE', ['']):
        if cfg:
            _update_config_from_file(
                config, os.path.join(os.path.dirname(cfg_file), cfg)
            )
    print('=> merge config from {}'.format(cfg_file))
    config.merge_from_file(cfg_file)
    config.freeze()


def update_config(config, args):
    _update_config_from_file(config, args.cfg)

    config.defrost()
    if args.opts:
        config.merge_from_list(args.opts)

    # merge from specific arguments
    if args.batch_size:
        config.DATA.BATCH_SIZE = args.batch_size
    if args.data_path:
        config.DATA.DATA_PATH = args.data_path
    if args.resume:
        config.MODEL.RESUME = args.resume
    if args.finetune:
        config.MODEL.FINETUNE = args.finetune
    if args.use_checkpoint:
        config.TRAIN.USE_CHECKPOINT = True
    if args.output:
        config.OUTPUT = args.output
    if args.tag:
        config.TAG = args.tag
    if args.eval:
        config.EVAL_MODE = True
    if args.model_ema:
        config.MODEL_EMA = True

    config.dist_url = args.dist_url
    # set local rank for distributed training
    config.LOCAL_RANK = args.local_rank

    # output folder
    config.OUTPUT = os.path.join(config.OUTPUT, config.MODEL.NAME, config.TAG)

    config.freeze()


def get_config(args):
    """Get a yacs CfgNode object with default values."""
    # Return a clone so that the defaults will not be altered
    # This is for the "local variable" use pattern
    config = _C.clone()
    update_config(config, args)

    return config
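A note on `_update_config_from_file` above (not part of the commit): the `BASE` field lets a YAML config inherit from other YAMLs, and parents are merged first, recursively, so keys in the child file win. A minimal hypothetical sketch, assuming it is run in the context of config.py (the file names are made up):

```python
# Hypothetical illustration of BASE inheritance (not in the commit).
import yaml

with open('base.yaml', 'w') as f:
    yaml.dump({'TRAIN': {'EPOCHS': 300, 'BASE_LR': 1e-3}}, f)
with open('child.yaml', 'w') as f:
    yaml.dump({'BASE': ['base.yaml'], 'TRAIN': {'BASE_LR': 5e-4}}, f)

cfg = _C.clone()                           # defaults from config.py above
_update_config_from_file(cfg, 'child.yaml')
assert cfg.TRAIN.EPOCHS == 300             # inherited from base.yaml
assert cfg.TRAIN.BASE_LR == 5e-4           # overridden by child.yaml
```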
training/configs/revcol_base_1k.yaml
ADDED
@@ -0,0 +1,48 @@
PRINT_FREQ: 100
SAVE_FREQ: 1
MODEL_EMA: True
DATA:
  IMG_SIZE: 224
  NUM_WORKERS: 6
MODEL:
  TYPE: revcol_base
  NAME: revcol_base
  LABEL_SMOOTHING: 0.1
REVCOL:
  INTER_SUPV: True
  SAVEMM: True
  FCOE: 3.0
  CCOE: 0.7
  DROP_PATH: 0.4
TRAIN:
  EPOCHS: 300
  BASE_LR: 1e-3
  WARMUP_EPOCHS: 20
  WEIGHT_DECAY: 0.05
  WARMUP_LR: 1e-5
  MIN_LR: 1e-8
  OPTIMIZER:
    NAME: 'adamw'
  CLIP_GRAD: 5.0
AUG:
  COLOR_JITTER: 0.0
  # Use AutoAugment policy. "v0" or "original"
  AUTO_AUGMENT: 'rand-m9-mstd0.5-inc1'
  # Random erase prob
  REPROB: 0.25
  # Random erase mode
  REMODE: 'pixel'
  # Random erase count
  RECOUNT: 1
  # Mixup alpha, mixup enabled if > 0
  MIXUP: 0.8
  # Cutmix alpha, cutmix enabled if > 0
  CUTMIX: 1.0
  # Cutmix min/max ratio, overrides alpha and enables cutmix if set
  CUTMIX_MINMAX: None
  # Probability of performing mixup or cutmix when either/both is enabled
  MIXUP_PROB: 1.0
  # Probability of switching to cutmix when both mixup and cutmix enabled
  MIXUP_SWITCH_PROB: 0.5
  # How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
  MIXUP_MODE: 'batch'
training/configs/revcol_base_1k_224_finetune.yaml
ADDED
@@ -0,0 +1,50 @@
PRINT_FREQ: 100
SAVE_FREQ: 1
MODEL_EMA: False
DATA:
  IMG_SIZE: 224
  DATASET: imagenet
MODEL:
  TYPE: revcol_base
  NAME: revcol_base_1k_Finetune_224
  LABEL_SMOOTHING: 0.1
  NUM_CLASSES: 1000
REVCOL:
  INTER_SUPV: False
  SAVEMM: True
  FCOE: 3.0
  CCOE: 0.7
  DROP_PATH: 0.2
  HEAD_INIT_SCALE: 0.001
TRAIN:
  EPOCHS: 30
  BASE_LR: 1e-4
  WARMUP_EPOCHS: 0
  WEIGHT_DECAY: 1e-8
  WARMUP_LR: 4e-6
  MIN_LR: 2e-7
  OPTIMIZER:
    NAME: 'adamw'
    LAYER_DECAY: 0.9
AUG:
  COLOR_JITTER: 0.0
  # Use AutoAugment policy. "v0" or "original"
  AUTO_AUGMENT: 'rand-m9-mstd0.5-inc1'
  # Random erase prob
  REPROB: 0.25
  # Random erase mode
  REMODE: 'pixel'
  # Random erase count
  RECOUNT: 1
  # Mixup alpha, mixup enabled if > 0
  MIXUP: 0.0
  # Cutmix alpha, cutmix enabled if > 0
  CUTMIX: 0.0
  # Cutmix min/max ratio, overrides alpha and enables cutmix if set
  CUTMIX_MINMAX: None
  # Probability of performing mixup or cutmix when either/both is enabled
  MIXUP_PROB: 0.0
  # Probability of switching to cutmix when both mixup and cutmix enabled
  MIXUP_SWITCH_PROB: 0.0
  # How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
  MIXUP_MODE: 'batch'
training/configs/revcol_base_1k_384_finetune.yaml
ADDED
@@ -0,0 +1,50 @@
PRINT_FREQ: 100
SAVE_FREQ: 1
MODEL_EMA: False
DATA:
  IMG_SIZE: 384
  DATASET: imagenet
MODEL:
  TYPE: revcol_base
  NAME: revcol_base_1k_Finetune_384
  LABEL_SMOOTHING: 0.1
  NUM_CLASSES: 1000
REVCOL:
  INTER_SUPV: False
  SAVEMM: True
  FCOE: 3.0
  CCOE: 0.7
  DROP_PATH: 0.2
  HEAD_INIT_SCALE: 0.001
TRAIN:
  EPOCHS: 30
  BASE_LR: 1e-4
  WARMUP_EPOCHS: 0
  WEIGHT_DECAY: 1e-8
  WARMUP_LR: 4e-6
  MIN_LR: 2e-7
  OPTIMIZER:
    NAME: 'adamw'
    LAYER_DECAY: 0.9
AUG:
  COLOR_JITTER: 0.0
  # Use AutoAugment policy. "v0" or "original"
  AUTO_AUGMENT: 'rand-m9-mstd0.5-inc1'
  # Random erase prob
  REPROB: 0.25
  # Random erase mode
  REMODE: 'pixel'
  # Random erase count
  RECOUNT: 1
  # Mixup alpha, mixup enabled if > 0
  MIXUP: 0.0
  # Cutmix alpha, cutmix enabled if > 0
  CUTMIX: 0.0
  # Cutmix min/max ratio, overrides alpha and enables cutmix if set
  CUTMIX_MINMAX: None
  # Probability of performing mixup or cutmix when either/both is enabled
  MIXUP_PROB: 0.0
  # Probability of switching to cutmix when both mixup and cutmix enabled
  MIXUP_SWITCH_PROB: 0.0
  # How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
  MIXUP_MODE: 'batch'
training/configs/revcol_base_22k_pretrain.yaml
ADDED
@@ -0,0 +1,51 @@
PRINT_FREQ: 100
SAVE_FREQ: 1
MODEL_EMA: False
DATA:
  IMG_SIZE: 224
  DATASET: imagenet22k
  NUM_WORKERS: 6
MODEL:
  TYPE: revcol_base
  NAME: revcol_base_22k_Pretrain
  LABEL_SMOOTHING: 0.1
  NUM_CLASSES: 21841
REVCOL:
  INTER_SUPV: True
  SAVEMM: True
  FCOE: 3.0
  CCOE: 0.7
  DROP_PATH: 0.3
TRAIN:
  EPOCHS: 90
  BASE_LR: 1.25e-4
  WARMUP_EPOCHS: 5
  WEIGHT_DECAY: 0.1
  WARMUP_LR: 1e-5
  MIN_LR: 1e-7
  OPTIMIZER:
    NAME: 'adamw'
# BN:
#   USE_PRECISE_STATS: True
AUG:
  COLOR_JITTER: 0.4
  # Use AutoAugment policy. "v0" or "original"
  AUTO_AUGMENT: 'rand-m9-mstd0.5-inc1'
  # Random erase prob
  REPROB: 0.25
  # Random erase mode
  REMODE: 'pixel'
  # Random erase count
  RECOUNT: 1
  # Mixup alpha, mixup enabled if > 0
  MIXUP: 0.8
  # Cutmix alpha, cutmix enabled if > 0
  CUTMIX: 1.0
  # Cutmix min/max ratio, overrides alpha and enables cutmix if set
  CUTMIX_MINMAX: None
  # Probability of performing mixup or cutmix when either/both is enabled
  MIXUP_PROB: 1.0
  # Probability of switching to cutmix when both mixup and cutmix enabled
  MIXUP_SWITCH_PROB: 0.5
  # How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
  MIXUP_MODE: 'batch'
training/configs/revcol_large_1k_224_finetune.yaml
ADDED
@@ -0,0 +1,51 @@
PRINT_FREQ: 20
SAVE_FREQ: 1
MODEL_EMA: False
DATA:
  IMG_SIZE: 224
  DATASET: imagenet
MODEL:
  TYPE: revcol_large
  NAME: revcol_large_1k_Finetune_224
  LABEL_SMOOTHING: 0.1
  NUM_CLASSES: 1000
REVCOL:
  INTER_SUPV: False
  SAVEMM: True
  FCOE: 3.0
  CCOE: 0.7
  DROP_PATH: 0.3
  HEAD_INIT_SCALE: 0.001
TRAIN:
  EPOCHS: 30
  BASE_LR: 5e-5
  WARMUP_EPOCHS: 0
  WEIGHT_DECAY: 1e-8
  WARMUP_LR: 4e-6
  MIN_LR: 2e-7
  OPTIMIZER:
    NAME: 'adamw'
    LAYER_DECAY: 0.8

AUG:
  COLOR_JITTER: 0.0
  # Use AutoAugment policy. "v0" or "original"
  AUTO_AUGMENT: 'rand-m9-mstd0.5-inc1'
  # Random erase prob
  REPROB: 0.25
  # Random erase mode
  REMODE: 'pixel'
  # Random erase count
  RECOUNT: 1
  # Mixup alpha, mixup enabled if > 0
  MIXUP: 0.0
  # Cutmix alpha, cutmix enabled if > 0
  CUTMIX: 0.0
  # Cutmix min/max ratio, overrides alpha and enables cutmix if set
  CUTMIX_MINMAX: None
  # Probability of performing mixup or cutmix when either/both is enabled
  MIXUP_PROB: 0.0
  # Probability of switching to cutmix when both mixup and cutmix enabled
  MIXUP_SWITCH_PROB: 0.0
  # How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
  MIXUP_MODE: 'batch'
training/configs/revcol_large_1k_384_finetune.yaml
ADDED
@@ -0,0 +1,51 @@
PRINT_FREQ: 20
SAVE_FREQ: 1
MODEL_EMA: False
DATA:
  IMG_SIZE: 384
  DATASET: imagenet
MODEL:
  TYPE: revcol_large
  NAME: revcol_large_1k_Finetune_384
  LABEL_SMOOTHING: 0.1
  NUM_CLASSES: 1000
REVCOL:
  INTER_SUPV: False
  SAVEMM: True
  FCOE: 3.0
  CCOE: 0.7
  DROP_PATH: 0.3
  HEAD_INIT_SCALE: 0.001
TRAIN:
  EPOCHS: 30
  BASE_LR: 5e-5
  WARMUP_EPOCHS: 0
  WEIGHT_DECAY: 1e-8
  WARMUP_LR: 4e-6
  MIN_LR: 2e-7
  OPTIMIZER:
    NAME: 'adamw'
    LAYER_DECAY: 0.8

AUG:
  COLOR_JITTER: 0.0
  # Use AutoAugment policy. "v0" or "original"
  AUTO_AUGMENT: 'rand-m9-mstd0.5-inc1'
  # Random erase prob
  REPROB: 0.25
  # Random erase mode
  REMODE: 'pixel'
  # Random erase count
  RECOUNT: 1
  # Mixup alpha, mixup enabled if > 0
  MIXUP: 0.0
  # Cutmix alpha, cutmix enabled if > 0
  CUTMIX: 0.0
  # Cutmix min/max ratio, overrides alpha and enables cutmix if set
  CUTMIX_MINMAX: None
  # Probability of performing mixup or cutmix when either/both is enabled
  MIXUP_PROB: 0.0
  # Probability of switching to cutmix when both mixup and cutmix enabled
  MIXUP_SWITCH_PROB: 0.0
  # How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
  MIXUP_MODE: 'batch'
training/configs/revcol_large_22k_pretrain.yaml
ADDED
@@ -0,0 +1,50 @@
PRINT_FREQ: 100
SAVE_FREQ: 1
MODEL_EMA: False
DATA:
  IMG_SIZE: 224
  DATASET: imagenet22k
  NUM_WORKERS: 6
MODEL:
  TYPE: revcol_large
  NAME: revcol_large_22k_Pretrain
  LABEL_SMOOTHING: 0.1
  NUM_CLASSES: 21841
REVCOL:
  INTER_SUPV: True
  SAVEMM: True
  FCOE: 3.0
  CCOE: 0.7
  DROP_PATH: 0.3
TRAIN:
  EPOCHS: 90
  BASE_LR: 1.25e-4
  WARMUP_EPOCHS: 5
  WEIGHT_DECAY: 0.1
  WARMUP_LR: 1e-5
  MIN_LR: 1e-7
  OPTIMIZER:
    NAME: 'adamw'

AUG:
  COLOR_JITTER: 0.4
  # Use AutoAugment policy. "v0" or "original"
  AUTO_AUGMENT: 'rand-m9-mstd0.5-inc1'
  # Random erase prob
  REPROB: 0.25
  # Random erase mode
  REMODE: 'pixel'
  # Random erase count
  RECOUNT: 1
  # Mixup alpha, mixup enabled if > 0
  MIXUP: 0.8
  # Cutmix alpha, cutmix enabled if > 0
  CUTMIX: 1.0
  # Cutmix min/max ratio, overrides alpha and enables cutmix if set
  CUTMIX_MINMAX: None
  # Probability of performing mixup or cutmix when either/both is enabled
  MIXUP_PROB: 1.0
  # Probability of switching to cutmix when both mixup and cutmix enabled
  MIXUP_SWITCH_PROB: 0.5
  # How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
  MIXUP_MODE: 'batch'
training/configs/revcol_small_1k.yaml
ADDED
@@ -0,0 +1,48 @@
PRINT_FREQ: 100
SAVE_FREQ: 1
MODEL_EMA: True
DATA:
  IMG_SIZE: 224
  NUM_WORKERS: 6
MODEL:
  TYPE: revcol_small
  NAME: revcol_small
  LABEL_SMOOTHING: 0.1
REVCOL:
  INTER_SUPV: True
  SAVEMM: True
  FCOE: 3.0
  CCOE: 0.7
  DROP_PATH: 0.3
TRAIN:
  EPOCHS: 300
  BASE_LR: 1e-3
  WARMUP_EPOCHS: 20
  WEIGHT_DECAY: 0.05
  WARMUP_LR: 1e-5
  MIN_LR: 1e-6
  OPTIMIZER:
    NAME: 'adamw'
  CLIP_GRAD: 0.0
AUG:
  COLOR_JITTER: 0.4
  # Use AutoAugment policy. "v0" or "original"
  AUTO_AUGMENT: 'rand-m9-mstd0.5-inc1'
  # Random erase prob
  REPROB: 0.25
  # Random erase mode
  REMODE: 'pixel'
  # Random erase count
  RECOUNT: 1
  # Mixup alpha, mixup enabled if > 0
  MIXUP: 0.8
  # Cutmix alpha, cutmix enabled if > 0
  CUTMIX: 1.0
  # Cutmix min/max ratio, overrides alpha and enables cutmix if set
  CUTMIX_MINMAX: None
  # Probability of performing mixup or cutmix when either/both is enabled
  MIXUP_PROB: 1.0
  # Probability of switching to cutmix when both mixup and cutmix enabled
  MIXUP_SWITCH_PROB: 0.5
  # How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
  MIXUP_MODE: 'batch'
training/configs/revcol_tiny_1k.yaml
ADDED
@@ -0,0 +1,48 @@
PRINT_FREQ: 100
SAVE_FREQ: 1
MODEL_EMA: True
DATA:
  IMG_SIZE: 224
  NUM_WORKERS: 6
MODEL:
  TYPE: revcol_tiny
  NAME: revcol_tiny
  LABEL_SMOOTHING: 0.1
REVCOL:
  INTER_SUPV: True
  SAVEMM: True
  FCOE: 3.0
  CCOE: 0.7
  DROP_PATH: 0.1
TRAIN:
  EPOCHS: 300
  BASE_LR: 1e-3
  WARMUP_EPOCHS: 20
  WEIGHT_DECAY: 0.05
  WARMUP_LR: 1e-5
  MIN_LR: 1e-6
  OPTIMIZER:
    NAME: 'adamw'
  CLIP_GRAD: 0.0
AUG:
  COLOR_JITTER: 0.4
  # Use AutoAugment policy. "v0" or "original"
  AUTO_AUGMENT: 'rand-m9-mstd0.5-inc1'
  # Random erase prob
  REPROB: 0.25
  # Random erase mode
  REMODE: 'pixel'
  # Random erase count
  RECOUNT: 1
  # Mixup alpha, mixup enabled if > 0
  MIXUP: 0.8
  # Cutmix alpha, cutmix enabled if > 0
  CUTMIX: 1.0
  # Cutmix min/max ratio, overrides alpha and enables cutmix if set
  CUTMIX_MINMAX: None
  # Probability of performing mixup or cutmix when either/both is enabled
  MIXUP_PROB: 1.0
  # Probability of switching to cutmix when both mixup and cutmix enabled
  MIXUP_SWITCH_PROB: 0.5
  # How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
  MIXUP_MODE: 'batch'
training/configs/revcol_xlarge_1k_384_finetune.yaml
ADDED
@@ -0,0 +1,53 @@
PRINT_FREQ: 30
SAVE_FREQ: 1
MODEL_EMA: True
DATA:
  IMG_SIZE: 384
  DATASET: imagenet
  PIPE_NAME: 'dpflow.silvia.imagenet.train.rand-re-jitt.384'
  NUM_WORKERS: 8
MODEL:
  TYPE: revcol_xlarge
  NAME: revcol_xlarge_1k_Finetune
  LABEL_SMOOTHING: 0.1
  NUM_CLASSES: 1000
REVCOL:
  INTER_SUPV: False
  SAVEMM: True
  FCOE: 3.0
  CCOE: 0.7
  DROP_PATH: 0.4
  HEAD_INIT_SCALE: 0.001
TRAIN:
  EPOCHS: 30
  BASE_LR: 2e-5
  WARMUP_EPOCHS: 0
  WEIGHT_DECAY: 1e-8
  WARMUP_LR: 4e-6
  MIN_LR: 2e-7
  OPTIMIZER:
    NAME: 'adamw'
    LAYER_DECAY: 0.8

AUG:
  COLOR_JITTER: 0.0
  # Use AutoAugment policy. "v0" or "original"
  AUTO_AUGMENT: 'rand-m9-mstd0.5-inc1'
  # Random erase prob
  REPROB: 0.25
  # Random erase mode
  REMODE: 'pixel'
  # Random erase count
  RECOUNT: 1
  # Mixup alpha, mixup enabled if > 0
  MIXUP: 0.0
  # Cutmix alpha, cutmix enabled if > 0
  CUTMIX: 0.0
  # Cutmix min/max ratio, overrides alpha and enables cutmix if set
  CUTMIX_MINMAX: None
  # Probability of performing mixup or cutmix when either/both is enabled
  MIXUP_PROB: 0.0
  # Probability of switching to cutmix when both mixup and cutmix enabled
  MIXUP_SWITCH_PROB: 0.0
  # How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
  MIXUP_MODE: 'batch'
training/configs/revcol_xlarge_22k_pretrain.yaml
ADDED
@@ -0,0 +1,50 @@
PRINT_FREQ: 100
SAVE_FREQ: 1
MODEL_EMA: False
DATA:
  IMG_SIZE: 224
  DATASET: imagenet22k
  NUM_WORKERS: 8
MODEL:
  TYPE: revcol_xlarge
  NAME: revcol_xlarge_22k_Pretrain
  LABEL_SMOOTHING: 0.1
  NUM_CLASSES: 21841
REVCOL:
  INTER_SUPV: True
  SAVEMM: True
  FCOE: 3.0
  CCOE: 0.7
  DROP_PATH: 0.4
TRAIN:
  EPOCHS: 90
  BASE_LR: 1.25e-4
  WARMUP_EPOCHS: 5
  WEIGHT_DECAY: 0.1
  WARMUP_LR: 1e-5
  MIN_LR: 1e-7
  OPTIMIZER:
    NAME: 'adamw'

AUG:
  COLOR_JITTER: 0.0
  # Use AutoAugment policy. "v0" or "original"
  AUTO_AUGMENT: 'rand-m9-mstd0.5-inc1'
  # Random erase prob
  REPROB: 0.25
  # Random erase mode
  REMODE: 'pixel'
  # Random erase count
  RECOUNT: 1
  # Mixup alpha, mixup enabled if > 0
  MIXUP: 0.8
  # Cutmix alpha, cutmix enabled if > 0
  CUTMIX: 1.0
  # Cutmix min/max ratio, overrides alpha and enables cutmix if set
  CUTMIX_MINMAX: None
  # Probability of performing mixup or cutmix when either/both is enabled
  MIXUP_PROB: 1.0
  # Probability of switching to cutmix when both mixup and cutmix enabled
  MIXUP_SWITCH_PROB: 0.5
  # How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
  MIXUP_MODE: 'batch'
training/data/__init__.py
ADDED
@@ -0,0 +1 @@
from .build_data import build_loader
training/data/build_data.py
ADDED
@@ -0,0 +1,137 @@
# --------------------------------------------------------
# Reversible Column Networks
# Copyright (c) 2022 Megvii Inc.
# Licensed under The Apache License 2.0 [see LICENSE for details]
# Written by Yuxuan Cai
# --------------------------------------------------------

import queue
from typing import Dict, Sequence
import warnings
import os
import torch
import numpy as np
import torch.distributed as dist
from torchvision import datasets, transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data import Mixup
from timm.data import create_transform


from .samplers import SubsetRandomSampler

def build_loader(config):

    config.defrost()
    dataset_train, _ = build_dataset(is_train=True, config=config)
    config.freeze()
    print(f"global rank {dist.get_rank()} successfully built train dataset")


    sampler_train = torch.utils.data.DistributedSampler(
        dataset_train, shuffle=True
    )

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=config.DATA.BATCH_SIZE,
        num_workers=config.DATA.NUM_WORKERS,
        pin_memory=config.DATA.PIN_MEMORY,
        drop_last=True,
        persistent_workers=True
    )

    # ----------------------------------- Val Dataset -----------------------------------

    dataset_val, _ = build_dataset(is_train=False, config=config)
    print(f"global rank {dist.get_rank()} successfully built val dataset")

    # Shard the val set across ranks: rank r takes every world_size-th index.
    indices = np.arange(dist.get_rank(), len(dataset_val), dist.get_world_size())
    sampler_val = SubsetRandomSampler(indices)

    data_loader_val = torch.utils.data.DataLoader(
        dataset_val, sampler=sampler_val,
        batch_size=config.DATA.BATCH_SIZE,
        shuffle=False,
        num_workers=config.DATA.NUM_WORKERS,
        pin_memory=config.DATA.PIN_MEMORY,
        drop_last=False,
        persistent_workers=True
    )

    # setup mixup / cutmix
    mixup_fn = None
    mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. or config.AUG.CUTMIX_MINMAX is not None
    if mixup_active:
        mixup_fn = Mixup(
            mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX,
            prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE,
            label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES)

    return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn


def build_dataset(is_train, config):
    transform = build_transform(is_train, config)
    if config.DATA.DATASET == 'imagenet':
        prefix = 'train' if is_train else 'val'
        root = os.path.join(config.DATA.DATA_PATH, prefix)
        dataset = datasets.ImageFolder(root, transform=transform)
        nb_classes = 1000
    elif config.DATA.DATASET == 'imagenet22k':  # matches DATASET in the 22k configs
        if is_train:
            root = config.DATA.DATA_PATH
        else:
            root = config.DATA.EVAL_DATA_PATH
        dataset = datasets.ImageFolder(root, transform=transform)
        nb_classes = 21841
    else:
        raise NotImplementedError("We only support ImageNet now.")

    return dataset, nb_classes


def build_transform(is_train, config):
    resize_im = config.DATA.IMG_SIZE > 32
    if is_train:
        # this should always dispatch to transforms_imagenet_train
        transform = create_transform(
            input_size=config.DATA.IMG_SIZE,
            is_training=True,
|
101 |
+
color_jitter=config.AUG.COLOR_JITTER if config.AUG.COLOR_JITTER > 0 else None,
|
102 |
+
auto_augment=config.AUG.AUTO_AUGMENT if config.AUG.AUTO_AUGMENT != 'none' else None,
|
103 |
+
re_prob=config.AUG.REPROB,
|
104 |
+
re_mode=config.AUG.REMODE,
|
105 |
+
re_count=config.AUG.RECOUNT,
|
106 |
+
interpolation=config.DATA.INTERPOLATION,
|
107 |
+
)
|
108 |
+
if not resize_im:
|
109 |
+
# replace RandomResizedCropAndInterpolation with
|
110 |
+
# RandomCrop
|
111 |
+
transform.transforms[0] = transforms.RandomCrop(config.DATA.IMG_SIZE, padding=4)
|
112 |
+
return transform
|
113 |
+
|
114 |
+
t = []
|
115 |
+
if resize_im:
|
116 |
+
if config.DATA.IMG_SIZE > 224:
|
117 |
+
t.append(
|
118 |
+
transforms.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE),
|
119 |
+
interpolation=transforms.InterpolationMode.BICUBIC),
|
120 |
+
)
|
121 |
+
print(f"Warping {config.DATA.IMG_SIZE} size input images...")
|
122 |
+
elif config.TEST.CROP:
|
123 |
+
size = int((256 / 224) * config.DATA.IMG_SIZE)
|
124 |
+
t.append(
|
125 |
+
transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC),
|
126 |
+
# to maintain same ratio w.r.t. 224 images
|
127 |
+
)
|
128 |
+
t.append(transforms.CenterCrop(config.DATA.IMG_SIZE))
|
129 |
+
else:
|
130 |
+
t.append(
|
131 |
+
transforms.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE),
|
132 |
+
interpolation=transforms.InterpolationMode.BICUBIC)
|
133 |
+
)
|
134 |
+
|
135 |
+
t.append(transforms.ToTensor())
|
136 |
+
t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
|
137 |
+
return transforms.Compose(t)
|
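build_loader assumes an initialized torch.distributed process group; a minimal single-process sketch for local testing (the env-var setup is an assumption for a local run, not part of the repo):

# Sketch: initialize a one-process group so build_loader's dist.get_rank() calls work.
import os
import torch.distributed as dist

os.environ.setdefault('MASTER_ADDR', 'localhost')
os.environ.setdefault('MASTER_PORT', '29500')
dist.init_process_group(backend='gloo', rank=0, world_size=1)

from data import build_loader
# with a yacs `config` from config.get_config(args) (see main.py below):
# dataset_train, dataset_val, loader_train, loader_val, mixup_fn = build_loader(config)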
training/data/samplers.py
ADDED
@@ -0,0 +1,29 @@
# --------------------------------------------------------
# Reversible Column Networks
# Copyright (c) 2022 Megvii Inc.
# Licensed under The Apache License 2.0 [see LICENSE for details]
# Written by Yuxuan Cai
# --------------------------------------------------------

import torch


class SubsetRandomSampler(torch.utils.data.Sampler):
    r"""Samples elements randomly from a given list of indices, without replacement.

    Arguments:
        indices (sequence): a sequence of indices
    """

    def __init__(self, indices):
        self.epoch = 0
        self.indices = indices

    def __iter__(self):
        return (self.indices[i] for i in torch.randperm(len(self.indices)))

    def __len__(self):
        return len(self.indices)

    def set_epoch(self, epoch):
        self.epoch = epoch
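build_data.py pairs this sampler with rank-strided indices, so each process evaluates a disjoint shard of the validation set in random order; a quick sketch of the sharding:

# Sketch: rank-strided validation shards for 10 samples across 4 ranks.
import numpy as np

n, world_size = 10, 4
for rank in range(world_size):
    print(rank, np.arange(rank, n, world_size))
# 0 [0 4 8]
# 1 [1 5 9]
# 2 [2 6]
# 3 [3 7]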
training/figures/title.png
ADDED
training/logger.py
ADDED
@@ -0,0 +1,41 @@
# --------------------------------------------------------
# Reversible Column Networks
# Copyright (c) 2022 Megvii Inc.
# Licensed under The Apache License 2.0 [see LICENSE for details]
# Written by Yuxuan Cai
# --------------------------------------------------------

import os
import sys
import logging
import functools
from termcolor import colored


@functools.lru_cache()
def create_logger(output_dir, dist_rank=0, name=''):
    # create logger
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    logger.propagate = False

    # create formatters
    fmt = '[%(asctime)s %(name)s] (%(filename)s %(lineno)d): %(levelname)s %(message)s'
    color_fmt = colored('[%(asctime)s %(name)s]', 'green') + \
        colored('(%(filename)s %(lineno)d)', 'yellow') + ': %(levelname)s %(message)s'

    # create console handler for the master process only
    if dist_rank == 0:
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setLevel(logging.DEBUG)
        console_handler.setFormatter(
            logging.Formatter(fmt=color_fmt, datefmt='%Y-%m-%d %H:%M:%S'))
        logger.addHandler(console_handler)

    # create a per-rank file handler
    file_handler = logging.FileHandler(os.path.join(output_dir, f'log_rank{dist_rank}.txt'), mode='a')
    file_handler.setLevel(logging.DEBUG)
    file_handler.setFormatter(logging.Formatter(fmt=fmt, datefmt='%Y-%m-%d %H:%M:%S'))
    logger.addHandler(file_handler)

    return logger
training/loss.py
ADDED
@@ -0,0 +1,35 @@
# --------------------------------------------------------
# Reversible Column Networks
# Copyright (c) 2022 Megvii Inc.
# Licensed under The Apache License 2.0 [see LICENSE for details]
# Written by Yuxuan Cai
# --------------------------------------------------------

import torch
from torch import nn
import torch.distributed as dist
from torch.functional import Tensor
import torch.nn.functional as F


def compound_loss(coe, output_feature, image: Tensor, output_label, targets, criterion_bce, criterion_ce, epoch):
    f_coe, c_coe = coe
    # clamp the (denormalized) target image away from 0/1 for the BCE reconstruction loss
    image.clamp_(0.01, 0.99)
    multi_loss = []
    for i, feature in enumerate(output_feature):
        ratio_f = 1 - i / len(output_feature)    # reconstruction weight decays with column index
        ratio_c = (i + 1) / len(output_label)    # classification weight grows with column index

        ihx = criterion_bce(feature, image) * ratio_f * f_coe
        ihy = criterion_ce(output_label[i], targets) * ratio_c * c_coe
        multi_loss.append(ihx + ihy)
    # the final column's classification loss enters with full weight
    multi_loss.append(criterion_ce(output_label[-1], targets))
    loss = torch.sum(torch.stack(multi_loss), dim=0)
    return loss, multi_loss
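In compound_loss, the BCE reconstruction weight falls and the CE classification weight rises with the column index; a small sketch of the schedule for 3 intermediate features and 4 classifier heads (the shapes produced by revcol.py's intermediate supervision):

# Sketch: per-column loss weights before the f_coe / c_coe multipliers.
n_feat, n_label = 3, 4
for i in range(n_feat):
    ratio_f = 1 - i / n_feat      # 1.00, 0.67, 0.33  (reconstruction)
    ratio_c = (i + 1) / n_label   # 0.25, 0.50, 0.75  (classification)
    print(i, round(ratio_f, 2), ratio_c)
# The last head's CE loss is then appended with full weight 1.0.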
training/lr_scheduler.py
ADDED
@@ -0,0 +1,96 @@
# --------------------------------------------------------
# Reversible Column Networks
# Copyright (c) 2022 Megvii Inc.
# Licensed under The Apache License 2.0 [see LICENSE for details]
# Written by Yuxuan Cai
# --------------------------------------------------------

import math

import torch
from timm.scheduler.step_lr import StepLRScheduler
from timm.scheduler.scheduler import Scheduler


def build_scheduler(config, optimizer=None):
    lr_scheduler = None
    if config.TRAIN.LR_SCHEDULER.NAME == 'cosine':
        lr_scheduler = CosLRScheduler()
    elif config.TRAIN.LR_SCHEDULER.NAME == 'multistep':
        # NOTE: timm's StepLRScheduler actually requires an optimizer and a decay
        # schedule; this branch is not exercised by the provided configs.
        lr_scheduler = StepLRScheduler()
    else:
        raise NotImplementedError(f"Unknown lr scheduler: {config.TRAIN.LR_SCHEDULER.NAME}")

    return lr_scheduler


class CosLRScheduler():
    def __init__(self) -> None:
        pass

    def step_update(self, optimizer, epoch, config):
        """Decay the learning rate with half-cycle cosine after warmup."""
        if epoch < config.TRAIN.WARMUP_EPOCHS:
            lr = (config.TRAIN.BASE_LR - config.TRAIN.WARMUP_LR) * epoch / config.TRAIN.WARMUP_EPOCHS + config.TRAIN.WARMUP_LR
        else:
            lr = config.TRAIN.MIN_LR + (config.TRAIN.BASE_LR - config.TRAIN.MIN_LR) * 0.5 * \
                (1. + math.cos(math.pi * (epoch - config.TRAIN.WARMUP_EPOCHS) / (config.TRAIN.EPOCHS - config.TRAIN.WARMUP_EPOCHS)))
        for param_group in optimizer.param_groups:
            if "lr_scale" in param_group:
                param_group["lr"] = lr * param_group["lr_scale"]
            else:
                param_group["lr"] = lr
        return lr


class LinearLRScheduler(Scheduler):
    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 t_initial: int,
                 lr_min_rate: float,
                 warmup_t=0,
                 warmup_lr_init=0.,
                 t_in_epochs=True,
                 noise_range_t=None,
                 noise_pct=0.67,
                 noise_std=1.0,
                 noise_seed=42,
                 initialize=True,
                 ) -> None:
        super().__init__(
            optimizer, param_group_field="lr",
            noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
            initialize=initialize)

        self.t_initial = t_initial
        self.lr_min_rate = lr_min_rate
        self.warmup_t = warmup_t
        self.warmup_lr_init = warmup_lr_init
        self.t_in_epochs = t_in_epochs
        if self.warmup_t:
            self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
            super().update_groups(self.warmup_lr_init)
        else:
            self.warmup_steps = [1 for _ in self.base_values]

    def _get_lr(self, t):
        if t < self.warmup_t:
            lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
        else:
            t = t - self.warmup_t
            total_t = self.t_initial - self.warmup_t
            lrs = [v - ((v - v * self.lr_min_rate) * (t / total_t)) for v in self.base_values]
        return lrs

    def get_epoch_values(self, epoch: int):
        if self.t_in_epochs:
            return self._get_lr(epoch)
        else:
            return None

    def get_update_values(self, num_updates: int):
        if not self.t_in_epochs:
            return self._get_lr(num_updates)
        else:
            return None
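CosLRScheduler.step_update implements linear warmup followed by half-cycle cosine decay; a standalone numeric sketch using the 22k-pretrain hyperparameters from the config above:

# Sketch: warmup + cosine LR curve (values from revcol_xlarge_22k_pretrain.yaml).
import math

base_lr, warmup_lr, min_lr = 1.25e-4, 1e-5, 1e-7
warmup_epochs, epochs = 5, 90

def lr_at(epoch):
    if epoch < warmup_epochs:
        return (base_lr - warmup_lr) * epoch / warmup_epochs + warmup_lr
    return min_lr + (base_lr - min_lr) * 0.5 * \
        (1. + math.cos(math.pi * (epoch - warmup_epochs) / (epochs - warmup_epochs)))

for e in (0, 5, 45, 90):
    print(e, f"{lr_at(e):.2e}")  # ramps 1e-5 -> 1.25e-4, then decays toward 1e-7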
training/main.py
ADDED
@@ -0,0 +1,422 @@
# --------------------------------------------------------
# Reversible Column Networks
# Copyright (c) 2022 Megvii Inc.
# Licensed under The Apache License 2.0 [see LICENSE for details]
# Written by Yuxuan Cai
# --------------------------------------------------------

import math
import os
import time
import argparse
import datetime
import numpy as np

import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.cuda.amp as amp
from typing import Optional

from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from timm.utils import accuracy, AverageMeter
from timm.utils import ModelEma as ModelEma
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

from config import get_config
from models import *
from loss import *
from data import build_loader
from lr_scheduler import build_scheduler
from optimizer import build_optimizer
from logger import create_logger
from utils import denormalize, load_checkpoint, load_checkpoint_finetune, save_checkpoint, get_grad_norm, auto_resume_helper, reduce_tensor
from torch.utils.tensorboard import SummaryWriter

scaler = amp.GradScaler()
logger = None


def parse_option():
    parser = argparse.ArgumentParser('RevCol training and evaluation script', add_help=False)
    parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file')
    parser.add_argument(
        "--opts",
        help="Modify config options by adding 'KEY VALUE' pairs. ",
        default=None,
        nargs='+',
    )

    # easy config modification
    parser.add_argument('--batch-size', type=int, default=128, help="batch size for single GPU")
    parser.add_argument('--data-path', type=str, default='data', help='path to dataset')
    parser.add_argument('--resume', help='resume from checkpoint')
    parser.add_argument('--finetune', help='finetune from checkpoint')

    parser.add_argument('--use-checkpoint', action='store_true',
                        help="whether to use gradient checkpointing to save memory")

    parser.add_argument('--output', default='outputs/', type=str, metavar='PATH',
                        help='root of output folder, the full path is <output>/<model_name>/<tag> (default: output)')
    parser.add_argument('--tag', help='tag of experiment')
    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')

    # ema
    parser.add_argument('--model-ema', action='store_true')

    # distributed training
    parser.add_argument("--local_rank", type=int, required=False, help='local rank for DistributedDataParallel')

    parser.add_argument('--dist-url', default='env://', type=str,
                        help='url used to set up distributed training')

    args, unparsed = parser.parse_known_args()
    config = get_config(args)

    return args, config


def main(config):

    config.defrost()
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        rank = int(os.environ["RANK"])
        world_size = int(os.environ['WORLD_SIZE'])
        print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}")
    else:
        # not launched through a distributed launcher; nothing to do
        return

    # linearly scale the learning rates according to total batch size, base batch size 1024
    linear_scaled_lr = config.TRAIN.BASE_LR * config.DATA.BATCH_SIZE * world_size / 1024.0
    linear_scaled_warmup_lr = config.TRAIN.WARMUP_LR * config.DATA.BATCH_SIZE * world_size / 1024.0
    linear_scaled_min_lr = config.TRAIN.MIN_LR * config.DATA.BATCH_SIZE * world_size / 1024.0

    config.TRAIN.BASE_LR = linear_scaled_lr
    config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr
    config.TRAIN.MIN_LR = linear_scaled_min_lr
    config.freeze()

    dist.init_process_group(
        backend='nccl', init_method=config.dist_url,
        world_size=world_size, rank=rank,
    )
    seed = config.SEED + dist.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    # NOTE: assumes single-node training, where the global rank doubles as the CUDA device index
    torch.cuda.set_device(rank)
    global logger

    logger = create_logger(output_dir=config.OUTPUT, dist_rank=dist.get_rank(), name=f"{config.MODEL.NAME}")
    logger.info(config.dump())
    writer = None
    if dist.get_rank() == 0:
        writer = SummaryWriter(config.OUTPUT)

    dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(config)

    logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}")
    model = build_model(config)

    model.cuda()
    logger.info(str(model)[:10000])

    model_ema = None
    if config.MODEL_EMA:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but
        # before SyncBN and DDP wrapper
        logger.info(f"Using EMA...")
        model_ema = ModelEma(
            model,
            decay=config.MODEL_EMA_DECAY,
        )

    optimizer = build_optimizer(config, model)
    if config.TRAIN.AMP:
        logger.info(f"-------------------------------Using Pytorch AMP...--------------------------------")

    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False, find_unused_parameters=False)
    model_without_ddp = model.module

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f"number of params: {n_parameters}")

    lr_scheduler = build_scheduler(config)

    if config.AUG.MIXUP > 0.:
        # smoothing is handled with mixup label transform
        criterion = SoftTargetCrossEntropy()
    elif config.MODEL.LABEL_SMOOTHING > 0.:
        criterion = LabelSmoothingCrossEntropy(smoothing=config.MODEL.LABEL_SMOOTHING)
    else:
        criterion = torch.nn.CrossEntropyLoss()
    criterion_bce = torch.nn.BCEWithLogitsLoss()
    max_accuracy = 0.0

    if config.TRAIN.AUTO_RESUME:
        resume_file = auto_resume_helper(config.OUTPUT, logger)
        if resume_file:
            if config.MODEL.RESUME:
                logger.warning(f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}")
            config.defrost()
            config.MODEL.RESUME = resume_file
            config.freeze()
            logger.info(f'auto resuming from {resume_file}')
        else:
            logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume')

    if config.MODEL.RESUME:
        max_accuracy = load_checkpoint(config, model_without_ddp, optimizer, logger, model_ema)
        logger.info(f"Start validation")
        acc1, acc5, loss = validate(config, data_loader_val, model, writer, epoch=config.TRAIN.START_EPOCH)
        logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%, {acc5:.1f}%")

        if config.EVAL_MODE:
            return

    if config.MODEL.FINETUNE:
        load_checkpoint_finetune(config, model_without_ddp, logger)

    logger.info("Start training")
    start_time = time.time()
    for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):
        data_loader_train.sampler.set_epoch(epoch)

        train_one_epoch(config, model, criterion, criterion_bce, data_loader_train, optimizer, epoch, mixup_fn, lr_scheduler, writer, model_ema)

        acc1, acc5, _ = validate(config, data_loader_val, model, writer, epoch)
        logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.2f}%, {acc5:.2f}%")

        if config.MODEL_EMA:
            acc1_ema, acc5_ema, _ = validate_ema(config, data_loader_val, model_ema.ema, writer, epoch)
            logger.info(f"Accuracy of the EMA network on the {len(dataset_val)} test images: {acc1_ema:.1f}%, {acc5_ema:.1f}%")
            # acc1 = max(acc1, acc1_ema)

        if dist.get_rank() == 0 and epoch % config.SAVE_FREQ == 0:
            save_checkpoint(config, epoch, model_without_ddp, acc1, max_accuracy, optimizer, logger, model_ema)

        max_accuracy = max(max_accuracy, acc1)
        logger.info(f'Max accuracy: {max_accuracy:.2f}%')

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info('Training time {}'.format(total_time_str))


def train_one_epoch(config, model, criterion_ce, criterion_bce, data_loader, optimizer, epoch, mixup_fn, lr_scheduler, writer, model_ema: Optional[ModelEma] = None):
    global logger
    model.train()
    optimizer.zero_grad()

    num_steps = len(data_loader)
    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    norm_meter = AverageMeter()
    data_time = AverageMeter()

    start = time.time()
    end = time.time()

    for idx, (samples, targets) in enumerate(data_loader):

        samples = samples.cuda(non_blocking=True)
        targets = targets.cuda(non_blocking=True)

        data_time.update(time.time() - end)
        lr_scheduler.step_update(optimizer, idx / num_steps + epoch, config)
        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)
        with amp.autocast(enabled=config.TRAIN.AMP):
            output_label, output_feature = model(samples)
            if len(output_label) == 1:
                loss = criterion_ce(output_label[0], targets)
                multi_loss = []
            else:
                loss, multi_loss = compound_loss((config.REVCOL.FCOE, config.REVCOL.CCOE), output_feature, denormalize(samples, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD), output_label, targets, criterion_bce, criterion_ce, epoch)

        if not math.isfinite(loss.item()):
            print("Loss is {} in iteration {}, multiloss {}!".format(loss.item(), idx, multi_loss))

        scaler.scale(loss).backward()
        if config.TRAIN.CLIP_GRAD:
            scaler.unscale_(optimizer)
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD)
        else:
            scaler.unscale_(optimizer)
            grad_norm = get_grad_norm(model.parameters())
        scaler.step(optimizer)
        scaler.update()

        optimizer.zero_grad()

        if model_ema is not None:
            model_ema.update(model)

        loss_meter.update(loss.item(), targets.size(0))
        norm_meter.update(grad_norm)
        batch_time.update(time.time() - end)
        end = time.time()

        if dist.get_rank() == 0 and idx % 10 == 0:
            writer.add_scalar('Train/train_loss', loss_meter.val, epoch * num_steps + idx)
            writer.add_scalar('Train/grad_norm', norm_meter.val, epoch * num_steps + idx)
            for i, subloss in enumerate(multi_loss):
                writer.add_scalar(f'Train/sub_loss{i}', subloss, epoch * num_steps + idx)

        if idx % config.PRINT_FREQ == 0:
            lr = optimizer.param_groups[-1]['lr']
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            etas = batch_time.avg * (num_steps - idx)
            logger.info(
                f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t'
                f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t'
                f'datatime {data_time.val:.4f} ({data_time.avg:.4f})\t'
                f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t'
                f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t'
                f'mem {memory_used:.0f}MB')

    epoch_time = time.time() - start
    logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}")


@torch.no_grad()
def validate(config, data_loader, model, writer, epoch):
    global logger
    criterion = torch.nn.CrossEntropyLoss()
    model.eval()

    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    acc1_meter_list = []
    acc5_meter_list = []
    for i in range(4):
        acc1_meter_list.append(AverageMeter())
        acc5_meter_list.append(AverageMeter())

    end = time.time()
    for idx, (images, target) in enumerate(data_loader):

        images = images.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        # compute output
        outputs, _ = model(images)

        if len(acc1_meter_list) != len(outputs):
            acc1_meter_list = acc1_meter_list[:len(outputs)]
            acc5_meter_list = acc5_meter_list[:len(outputs)]

        output_last = outputs[-1]
        loss = criterion(output_last, target)
        loss = reduce_tensor(loss)
        loss_meter.update(loss.item(), target.size(0))

        # track the accuracy of every intermediate classifier head
        for i, subnet_out in enumerate(outputs):
            acc1, acc5 = accuracy(subnet_out, target, topk=(1, 5))

            acc1 = reduce_tensor(acc1)
            acc5 = reduce_tensor(acc5)

            acc1_meter_list[i].update(acc1.item(), target.size(0))
            acc5_meter_list[i].update(acc5.item(), target.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if idx % config.PRINT_FREQ == 0:
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            logger.info(
                f'Test: [{idx}/{len(data_loader)}]\t'
                f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                f'Acc@1 {acc1_meter_list[-1].val:.3f} ({acc1_meter_list[-1].avg:.3f})\t'
                f'Acc@5 {acc5_meter_list[-1].val:.3f} ({acc5_meter_list[-1].avg:.3f})\t'
                f'Mem {memory_used:.0f}MB')

    logger.info(f' * Acc@1 {acc1_meter_list[-1].avg:.3f} Acc@5 {acc5_meter_list[-1].avg:.3f}')
    if dist.get_rank() == 0:
        for i in range(len(acc1_meter_list)):
            writer.add_scalar(f'Val_top1/acc_{i}', acc1_meter_list[i].avg, epoch)
            writer.add_scalar(f'Val_top5/acc_{i}', acc5_meter_list[i].avg, epoch)
    return acc1_meter_list[-1].avg, acc5_meter_list[-1].avg, loss_meter.avg


@torch.no_grad()
def validate_ema(config, data_loader, model, writer, epoch):
    global logger
    criterion = torch.nn.CrossEntropyLoss()
    model.eval()

    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    acc1_meter = AverageMeter()
    acc5_meter = AverageMeter()

    end = time.time()
    for idx, (images, target) in enumerate(data_loader):
        images = images.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        outputs, _ = model(images)

        output_last = outputs[-1]
        loss = criterion(output_last, target)
        loss = reduce_tensor(loss)
        loss_meter.update(loss.item(), target.size(0))

        acc1, acc5 = accuracy(output_last, target, topk=(1, 5))

        acc1 = reduce_tensor(acc1)
        acc5 = reduce_tensor(acc5)

        acc1_meter.update(acc1.item(), target.size(0))
        acc5_meter.update(acc5.item(), target.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if idx % config.PRINT_FREQ == 0:
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            logger.info(
                f'Test: [{idx}/{len(data_loader)}]\t'
                f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t'
                f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t'
                f'Mem {memory_used:.0f}MB')

    logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}')

    return acc1_meter.avg, acc5_meter.avg, loss_meter.avg


if __name__ == '__main__':
    _, config = parse_option()

    cudnn.benchmark = True

    os.makedirs(config.OUTPUT, exist_ok=True)

    main(config)
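main() rescales all learning rates linearly with the total batch size against a base of 1024; a worked example of the rule:

# Sketch: the linear LR scaling applied at the top of main().
base_lr = 1.25e-4          # TRAIN.BASE_LR from the config
batch_size, world_size = 128, 8
scaled = base_lr * batch_size * world_size / 1024.0
print(scaled)              # 1.25e-04: 128 x 8 = 1024 leaves the LR unchanged;
                           # 16 GPUs would double it to 2.5e-04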
training/models/__init__.py
ADDED
@@ -0,0 +1 @@
from .build import build_model
training/models/build.py
ADDED
@@ -0,0 +1,48 @@
# --------------------------------------------------------
# Reversible Column Networks
# Copyright (c) 2022 Megvii Inc.
# Licensed under The Apache License 2.0 [see LICENSE for details]
# Written by Yuxuan Cai
# --------------------------------------------------------

import torch
from models.revcol import *


def build_model(config):
    model_type = config.MODEL.TYPE

    ##-------------------------------------- revcol tiny --------------------------------------#
    if model_type == "revcol_tiny":
        model = revcol_tiny(save_memory=config.REVCOL.SAVEMM, inter_supv=config.REVCOL.INTER_SUPV, drop_path=config.REVCOL.DROP_PATH, num_classes=config.MODEL.NUM_CLASSES, kernel_size=config.REVCOL.KERNEL_SIZE)

    ##-------------------------------------- revcol small -------------------------------------#
    elif model_type == "revcol_small":
        model = revcol_small(save_memory=config.REVCOL.SAVEMM, inter_supv=config.REVCOL.INTER_SUPV, drop_path=config.REVCOL.DROP_PATH, num_classes=config.MODEL.NUM_CLASSES, kernel_size=config.REVCOL.KERNEL_SIZE)

    ##-------------------------------------- revcol base --------------------------------------#
    elif model_type == "revcol_base":
        model = revcol_base(save_memory=config.REVCOL.SAVEMM, inter_supv=config.REVCOL.INTER_SUPV, drop_path=config.REVCOL.DROP_PATH, num_classes=config.MODEL.NUM_CLASSES, kernel_size=config.REVCOL.KERNEL_SIZE)

    ##-------------------------------------- revcol large -------------------------------------#
    elif model_type == "revcol_large":
        model = revcol_large(save_memory=config.REVCOL.SAVEMM, inter_supv=config.REVCOL.INTER_SUPV, drop_path=config.REVCOL.DROP_PATH, num_classes=config.MODEL.NUM_CLASSES, head_init_scale=config.REVCOL.HEAD_INIT_SCALE, kernel_size=config.REVCOL.KERNEL_SIZE)

    ##-------------------------------------- revcol xlarge ------------------------------------#
    elif model_type == "revcol_xlarge":
        model = revcol_xlarge(save_memory=config.REVCOL.SAVEMM, inter_supv=config.REVCOL.INTER_SUPV, drop_path=config.REVCOL.DROP_PATH, num_classes=config.MODEL.NUM_CLASSES, head_init_scale=config.REVCOL.HEAD_INIT_SCALE, kernel_size=config.REVCOL.KERNEL_SIZE)

    else:
        raise NotImplementedError(f"Unknown model: {model_type}")

    return model
training/models/modules.py
ADDED
@@ -0,0 +1,157 @@
# --------------------------------------------------------
# Reversible Column Networks
# Copyright (c) 2022 Megvii Inc.
# Licensed under The Apache License 2.0 [see LICENSE for details]
# Written by Yuxuan Cai
# --------------------------------------------------------

import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import DropPath


class UpSampleConvnext(nn.Module):
    def __init__(self, ratio, inchannel, outchannel):
        super().__init__()
        self.ratio = ratio
        self.channel_reschedule = nn.Sequential(
            nn.Linear(inchannel, outchannel),
            LayerNorm(outchannel, eps=1e-6, data_format="channels_last"))
        self.upsample = nn.Upsample(scale_factor=2 ** ratio, mode='nearest')

    def forward(self, x):
        x = x.permute(0, 2, 3, 1)
        x = self.channel_reschedule(x)
        x = x.permute(0, 3, 1, 2)
        return self.upsample(x)


class LayerNorm(nn.Module):
    r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
    shape (batch_size, height, width, channels) while channels_first corresponds to inputs
    with shape (batch_size, channels, height, width).
    """
    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_first", elementwise_affine=True):
        super().__init__()
        self.elementwise_affine = elementwise_affine
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(normalized_shape))
            self.bias = nn.Parameter(torch.zeros(normalized_shape))
        else:
            # F.layer_norm accepts weight=None / bias=None
            self.weight = None
            self.bias = None
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape, )

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            if self.elementwise_affine:
                x = self.weight[:, None, None] * x + self.bias[:, None, None]
            return x


class ConvNextBlock(nn.Module):
    r""" ConvNeXt Block. There are two equivalent implementations:
    (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
    (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
    We use (2) as we find it slightly faster in PyTorch

    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
    """
    def __init__(self, in_channel, hidden_dim, out_channel, kernel_size=3, layer_scale_init_value=1e-6, drop_path=0.0):
        super().__init__()
        self.dwconv = nn.Conv2d(in_channel, in_channel, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, groups=in_channel)  # depthwise conv
        self.norm = nn.LayerNorm(in_channel, eps=1e-6)
        self.pwconv1 = nn.Linear(in_channel, hidden_dim)  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(hidden_dim, out_channel)
        self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((out_channel)),
                                  requires_grad=True) if layer_scale_init_value > 0 else None
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        input = x
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)

        x = input + self.drop_path(x)
        return x


class Decoder(nn.Module):
    def __init__(self, depth=[2, 2, 2, 2], dim=[112, 72, 40, 24], block_type=None, kernel_size=3) -> None:
        super().__init__()
        self.depth = depth
        self.dim = dim
        self.block_type = block_type
        self._build_decode_layer(dim, depth, kernel_size)
        self.projback = nn.Sequential(
            nn.Conv2d(
                in_channels=dim[-1],
                out_channels=4 ** 2 * 3, kernel_size=1),
            nn.PixelShuffle(4),
        )

    def _build_decode_layer(self, dim, depth, kernel_size):
        normal_layers = nn.ModuleList()
        upsample_layers = nn.ModuleList()
        proj_layers = nn.ModuleList()

        norm_layer = LayerNorm

        for i in range(1, len(dim)):
            module = [self.block_type(dim[i], dim[i], dim[i], kernel_size) for _ in range(depth[i])]
            normal_layers.append(nn.Sequential(*module))
            upsample_layers.append(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True))
            proj_layers.append(nn.Sequential(
                nn.Conv2d(dim[i - 1], dim[i], 1, 1),
                norm_layer(dim[i]),
                nn.GELU()
            ))
        self.normal_layers = normal_layers
        self.upsample_layers = upsample_layers
        self.proj_layers = proj_layers

    def _forward_stage(self, stage, x):
        x = self.proj_layers[stage](x)
        x = self.upsample_layers[stage](x)
        return self.normal_layers[stage](x)

    def forward(self, c3):
        x = self._forward_stage(0, c3)  # 14x14
        x = self._forward_stage(1, x)   # 28x28
        x = self._forward_stage(2, x)   # 56x56
        x = self.projback(x)            # back to input resolution
        return x


class SimDecoder(nn.Module):
    def __init__(self, in_channel, encoder_stride) -> None:
        super().__init__()
        self.projback = nn.Sequential(
            LayerNorm(in_channel),
            nn.Conv2d(
                in_channels=in_channel,
                out_channels=encoder_stride ** 2 * 3, kernel_size=1),
            nn.PixelShuffle(encoder_stride),
        )

    def forward(self, c3):
        return self.projback(c3)
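The channels_first branch of LayerNorm above is numerically the same as nn.LayerNorm applied to the permuted tensor; a small equivalence check (a sketch, assuming models.modules is importable):

# Sketch: channels_first LayerNorm vs. nn.LayerNorm on a permuted NCHW tensor.
import torch
import torch.nn as nn
from models.modules import LayerNorm

x = torch.randn(2, 8, 4, 4)                          # (N, C, H, W)
ln_cf = LayerNorm(8, data_format="channels_first")
ln_ref = nn.LayerNorm(8, eps=1e-6)

y1 = ln_cf(x)
y2 = ln_ref(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
print(torch.allclose(y1, y2, atol=1e-6))             # True (both start from unit weight, zero bias)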
training/models/revcol.py
ADDED
@@ -0,0 +1,242 @@
# --------------------------------------------------------
# Reversible Column Networks
# Copyright (c) 2022 Megvii Inc.
# Licensed under The Apache License 2.0 [see LICENSE for details]
# Written by Yuxuan Cai
# --------------------------------------------------------

import torch
import torch.nn as nn
from models.modules import ConvNextBlock, Decoder, LayerNorm, SimDecoder, UpSampleConvnext
import torch.distributed as dist
from models.revcol_function import ReverseFunction
from timm.models.layers import trunc_normal_


class Fusion(nn.Module):
    def __init__(self, level, channels, first_col) -> None:
        super().__init__()

        self.level = level
        self.first_col = first_col
        self.down = nn.Sequential(
            nn.Conv2d(channels[level - 1], channels[level], kernel_size=2, stride=2),
            LayerNorm(channels[level], eps=1e-6, data_format="channels_first"),
        ) if level in [1, 2, 3] else nn.Identity()
        if not first_col:
            self.up = UpSampleConvnext(1, channels[level + 1], channels[level]) if level in [0, 1, 2] else nn.Identity()

    def forward(self, *args):
        c_down, c_up = args

        if self.first_col:
            x = self.down(c_down)
            return x

        if self.level == 3:
            x = self.down(c_down)
        else:
            x = self.up(c_up) + self.down(c_down)
        return x


class Level(nn.Module):
    def __init__(self, level, channels, layers, kernel_size, first_col, dp_rate=0.0) -> None:
        super().__init__()
        countlayer = sum(layers[:level])
        expansion = 4
        self.fusion = Fusion(level, channels, first_col)
        modules = [ConvNextBlock(channels[level], expansion * channels[level], channels[level], kernel_size=kernel_size, layer_scale_init_value=1e-6, drop_path=dp_rate[countlayer + i]) for i in range(layers[level])]
        self.blocks = nn.Sequential(*modules)

    def forward(self, *args):
        x = self.fusion(*args)
        x = self.blocks(x)
        return x


class SubNet(nn.Module):
    def __init__(self, channels, layers, kernel_size, first_col, dp_rates, save_memory) -> None:
        super().__init__()
        shortcut_scale_init_value = 0.5
        self.save_memory = save_memory
        self.alpha0 = nn.Parameter(shortcut_scale_init_value * torch.ones((1, channels[0], 1, 1)),
                                   requires_grad=True) if shortcut_scale_init_value > 0 else None
        self.alpha1 = nn.Parameter(shortcut_scale_init_value * torch.ones((1, channels[1], 1, 1)),
                                   requires_grad=True) if shortcut_scale_init_value > 0 else None
        self.alpha2 = nn.Parameter(shortcut_scale_init_value * torch.ones((1, channels[2], 1, 1)),
                                   requires_grad=True) if shortcut_scale_init_value > 0 else None
        self.alpha3 = nn.Parameter(shortcut_scale_init_value * torch.ones((1, channels[3], 1, 1)),
                                   requires_grad=True) if shortcut_scale_init_value > 0 else None

        self.level0 = Level(0, channels, layers, kernel_size, first_col, dp_rates)
        self.level1 = Level(1, channels, layers, kernel_size, first_col, dp_rates)
        self.level2 = Level(2, channels, layers, kernel_size, first_col, dp_rates)
        self.level3 = Level(3, channels, layers, kernel_size, first_col, dp_rates)

    def _forward_nonreverse(self, *args):
        x, c0, c1, c2, c3 = args

        c0 = (self.alpha0) * c0 + self.level0(x, c1)
        c1 = (self.alpha1) * c1 + self.level1(c0, c2)
        c2 = (self.alpha2) * c2 + self.level2(c1, c3)
        c3 = (self.alpha3) * c3 + self.level3(c2, None)
        return c0, c1, c2, c3

    def _forward_reverse(self, *args):
        local_funs = [self.level0, self.level1, self.level2, self.level3]
        alpha = [self.alpha0, self.alpha1, self.alpha2, self.alpha3]
        _, c0, c1, c2, c3 = ReverseFunction.apply(
            local_funs, alpha, *args)

        return c0, c1, c2, c3

    def forward(self, *args):
        self._clamp_abs(self.alpha0.data, 1e-3)
        self._clamp_abs(self.alpha1.data, 1e-3)
        self._clamp_abs(self.alpha2.data, 1e-3)
        self._clamp_abs(self.alpha3.data, 1e-3)

        if self.save_memory:
            return self._forward_reverse(*args)
        else:
            return self._forward_nonreverse(*args)

    def _clamp_abs(self, data, value):
        with torch.no_grad():
            sign = data.sign()
            data.abs_().clamp_(value)
            data *= sign


class Classifier(nn.Module):
    def __init__(self, in_channels, num_classes):
        super().__init__()

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Sequential(
            nn.LayerNorm(in_channels, eps=1e-6),  # final norm layer
            nn.Linear(in_channels, num_classes),
        )

    def forward(self, x):
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x


class FullNet(nn.Module):
    def __init__(self, channels=[32, 64, 96, 128], layers=[2, 3, 6, 3], num_subnet=5, kernel_size=3, num_classes=1000, drop_path=0.0, save_memory=True, inter_supv=True, head_init_scale=None) -> None:
        super().__init__()
        self.num_subnet = num_subnet
        self.inter_supv = inter_supv
        self.channels = channels
        self.layers = layers

        self.stem = nn.Sequential(
            nn.Conv2d(3, channels[0], kernel_size=4, stride=4),
            LayerNorm(channels[0], eps=1e-6, data_format="channels_first")
        )

        dp_rate = [x.item() for x in torch.linspace(0, drop_path, sum(layers))]
        for i in range(num_subnet):
            first_col = True if i == 0 else False
            self.add_module(f'subnet{str(i)}', SubNet(
                channels, layers, kernel_size, first_col, dp_rates=dp_rate, save_memory=save_memory))

        if not inter_supv:
            self.cls = Classifier(in_channels=channels[-1], num_classes=num_classes)
        else:
            self.cls_blocks = nn.ModuleList([Classifier(in_channels=channels[-1], num_classes=num_classes) for _ in range(4)])
            if num_classes <= 1000:
                channels.reverse()
                self.decoder_blocks = nn.ModuleList([Decoder(depth=[1, 1, 1, 1], dim=channels, block_type=ConvNextBlock, kernel_size=3) for _ in range(3)])
            else:
                self.decoder_blocks = nn.ModuleList([SimDecoder(in_channel=channels[-1], encoder_stride=32) for _ in range(3)])

        self.apply(self._init_weights)

        if head_init_scale:
            print(f'Head_init_scale: {head_init_scale}')
            self.cls.classifier._modules['1'].weight.data.mul_(head_init_scale)
            self.cls.classifier._modules['1'].bias.data.mul_(head_init_scale)

    def forward(self, x):
        if self.inter_supv:
            return self._forward_intermediate_supervision(x)
        else:
            c0, c1, c2, c3 = 0, 0, 0, 0
            x = self.stem(x)
            for i in range(self.num_subnet):
                c0, c1, c2, c3 = getattr(self, f'subnet{str(i)}')(x, c0, c1, c2, c3)
            return [self.cls(c3)], None

    def _forward_intermediate_supervision(self, x):
        x_cls_out = []
        x_img_out = []
        c0, c1, c2, c3 = 0, 0, 0, 0
        interval = self.num_subnet // 4

        x = self.stem(x)
        for i in range(self.num_subnet):
            c0, c1, c2, c3 = getattr(self, f'subnet{str(i)}')(x, c0, c1, c2, c3)
            if (i + 1) % interval == 0:
                x_cls_out.append(self.cls_blocks[i // interval](c3))
                if i != self.num_subnet - 1:
                    x_img_out.append(self.decoder_blocks[i // interval](c3))

        return x_cls_out, x_img_out

    def _init_weights(self, module):
        if isinstance(module, nn.Conv2d):
            trunc_normal_(module.weight, std=.02)
            nn.init.constant_(module.bias, 0)
        elif isinstance(module, nn.Linear):
            trunc_normal_(module.weight, std=.02)
            nn.init.constant_(module.bias, 0)

##-------------------------------------- Tiny -----------------------------------------

def revcol_tiny(save_memory, inter_supv=True, drop_path=0.1, num_classes=1000, kernel_size=3):
    channels = [64, 128, 256, 512]
    layers = [2, 2, 4, 2]
    num_subnet = 4
    return FullNet(channels, layers, num_subnet, num_classes=num_classes, drop_path=drop_path, save_memory=save_memory, inter_supv=inter_supv, kernel_size=kernel_size)

##-------------------------------------- Small -----------------------------------------

def revcol_small(save_memory, inter_supv=True, drop_path=0.3, num_classes=1000, kernel_size=3):
    channels = [64, 128, 256, 512]
    layers = [2, 2, 4, 2]
    num_subnet = 8
    return FullNet(channels, layers, num_subnet, num_classes=num_classes, drop_path=drop_path, save_memory=save_memory, inter_supv=inter_supv, kernel_size=kernel_size)

##-------------------------------------- Base -----------------------------------------

def revcol_base(save_memory, inter_supv=True, drop_path=0.4, num_classes=1000, kernel_size=3, head_init_scale=None):
    channels = [72, 144, 288, 576]
    layers = [1, 1, 3, 2]
    num_subnet = 16
    return FullNet(channels, layers, num_subnet, num_classes=num_classes, drop_path=drop_path, save_memory=save_memory, inter_supv=inter_supv, head_init_scale=head_init_scale, kernel_size=kernel_size)

##-------------------------------------- Large -----------------------------------------

def revcol_large(save_memory, inter_supv=True, drop_path=0.5, num_classes=1000, kernel_size=3, head_init_scale=None):
    channels = [128, 256, 512, 1024]
    layers = [1, 2, 6, 2]
    num_subnet = 8
    return FullNet(channels, layers, num_subnet, num_classes=num_classes, drop_path=drop_path, save_memory=save_memory, inter_supv=inter_supv, head_init_scale=head_init_scale, kernel_size=kernel_size)

##-------------------------------------- Extra-Large -----------------------------------------

def revcol_xlarge(save_memory, inter_supv=True, drop_path=0.5, num_classes=1000, kernel_size=3, head_init_scale=None):
    channels = [224, 448, 896, 1792]
    layers = [1, 2, 6, 2]
    num_subnet = 8
    return FullNet(channels, layers, num_subnet, num_classes=num_classes, drop_path=drop_path, save_memory=save_memory, inter_supv=inter_supv, head_init_scale=head_init_scale, kernel_size=kernel_size)
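A quick forward-pass sketch for the tiny variant on CPU (save_memory=False avoids the custom reversible autograd path):

# Sketch: revcol_tiny produces 4 classifier outputs and 3 image reconstructions.
import torch
from models.revcol import revcol_tiny

model = revcol_tiny(save_memory=False, inter_supv=True, num_classes=1000)
x = torch.randn(1, 3, 224, 224)
cls_out, img_out = model(x)
print(len(cls_out), cls_out[-1].shape)   # 4 torch.Size([1, 1000])
print(len(img_out), img_out[0].shape)    # 3 torch.Size([1, 3, 224, 224])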
training/models/revcol_function.py
ADDED
@@ -0,0 +1,159 @@
# --------------------------------------------------------
# Reversible Column Networks
# Copyright (c) 2022 Megvii Inc.
# Licensed under The Apache License 2.0 [see LICENSE for details]
# Written by Yuxuan Cai
# --------------------------------------------------------

import torch
from typing import Any, Iterable, List, Tuple, Callable
import torch.distributed as dist

def get_gpu_states(fwd_gpu_devices) -> List[torch.Tensor]:
    # Snapshot the CUDA RNG state of every device seen in forward, so the
    # recomputation in backward can replay the same randomness (e.g. drop path).
    fwd_gpu_states = []
    for device in fwd_gpu_devices:
        with torch.cuda.device(device):
            fwd_gpu_states.append(torch.cuda.get_rng_state())
    return fwd_gpu_states

def get_gpu_device(*args):
    # This will not error out if an arg is a CPU tensor or a non-tensor type
    # because the conditionals short-circuit.
    fwd_gpu_devices = list(set(arg.get_device() for arg in args
                               if isinstance(arg, torch.Tensor) and arg.is_cuda))
    return fwd_gpu_devices

def set_device_states(fwd_cpu_state, devices, states) -> None:
    torch.set_rng_state(fwd_cpu_state)
    for device, state in zip(devices, states):
        with torch.cuda.device(device):
            torch.cuda.set_rng_state(state)

def detach_and_grad(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]:
    if isinstance(inputs, tuple):
        out = []
        for inp in inputs:
            if not isinstance(inp, torch.Tensor):
                out.append(inp)
                continue

            x = inp.detach()
            x.requires_grad = True
            out.append(x)
        return tuple(out)
    else:
        raise RuntimeError(
            "Only tuple of tensors is supported. Got unsupported input type: ", type(inputs).__name__)

def get_cpu_and_gpu_states(gpu_devices):
    return torch.get_rng_state(), get_gpu_states(gpu_devices)

class ReverseFunction(torch.autograd.Function):
    # One reversible column: forward runs the four level functions l0..l3 under
    # no_grad, saving only the column's boundary states plus RNG/autocast state;
    # backward reconstructs each intermediate state from c_i = l_i(...) + alpha_i * c_i
    # and replays the level computations to obtain gradients.
    @staticmethod
    def forward(ctx, run_functions, alpha, *args):
        l0, l1, l2, l3 = run_functions
        alpha0, alpha1, alpha2, alpha3 = alpha
        ctx.run_functions = run_functions
        ctx.alpha = alpha
        ctx.preserve_rng_state = True

        ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
                                   "dtype": torch.get_autocast_gpu_dtype(),
                                   "cache_enabled": torch.is_autocast_cache_enabled()}
        ctx.cpu_autocast_kwargs = {"enabled": torch.is_autocast_cpu_enabled(),
                                   "dtype": torch.get_autocast_cpu_dtype(),
                                   "cache_enabled": torch.is_autocast_cache_enabled()}

        assert len(args) == 5
        [x, c0, c1, c2, c3] = args
        if isinstance(c0, int):  # the first column receives zeros, not tensors
            ctx.first_col = True
        else:
            ctx.first_col = False
        with torch.no_grad():
            gpu_devices = get_gpu_device(*args)
            ctx.gpu_devices = gpu_devices
            ctx.cpu_states_0, ctx.gpu_states_0 = get_cpu_and_gpu_states(gpu_devices)
            c0 = l0(x, c1) + c0*alpha0
            ctx.cpu_states_1, ctx.gpu_states_1 = get_cpu_and_gpu_states(gpu_devices)
            c1 = l1(c0, c2) + c1*alpha1
            ctx.cpu_states_2, ctx.gpu_states_2 = get_cpu_and_gpu_states(gpu_devices)
            c2 = l2(c1, c3) + c2*alpha2
            ctx.cpu_states_3, ctx.gpu_states_3 = get_cpu_and_gpu_states(gpu_devices)
            c3 = l3(c2, None) + c3*alpha3
            ctx.save_for_backward(x, c0, c1, c2, c3)
        return x, c0, c1, c2, c3

    @staticmethod
    def backward(ctx, *grad_outputs):
        x, c0, c1, c2, c3 = ctx.saved_tensors
        l0, l1, l2, l3 = ctx.run_functions
        alpha0, alpha1, alpha2, alpha3 = ctx.alpha
        gx_right, g0_right, g1_right, g2_right, g3_right = grad_outputs
        (x, c0, c1, c2, c3) = detach_and_grad((x, c0, c1, c2, c3))

        with torch.enable_grad(), \
                torch.random.fork_rng(devices=ctx.gpu_devices, enabled=ctx.preserve_rng_state), \
                torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs), \
                torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):

            g3_up = g3_right
            g3_left = g3_up*alpha3  ## shortcut
            set_device_states(ctx.cpu_states_3, ctx.gpu_devices, ctx.gpu_states_3)
            oup3 = l3(c2, None)
            torch.autograd.backward(oup3, g3_up, retain_graph=True)
            with torch.no_grad():
                c3_left = (1/alpha3)*(c3 - oup3)  ## feature reverse
                g2_up = g2_right + c2.grad
                g2_left = g2_up*alpha2  ## shortcut

            (c3_left,) = detach_and_grad((c3_left,))
            set_device_states(ctx.cpu_states_2, ctx.gpu_devices, ctx.gpu_states_2)
            oup2 = l2(c1, c3_left)
            torch.autograd.backward(oup2, g2_up, retain_graph=True)
            c3_left.requires_grad = False
            cout3 = c3_left*alpha3  ## alpha3 update
            torch.autograd.backward(cout3, g3_up)

            with torch.no_grad():
                c2_left = (1/alpha2)*(c2 - oup2)  ## feature reverse
                g3_left = g3_left + c3_left.grad if c3_left.grad is not None else g3_left
                g1_up = g1_right + c1.grad
                g1_left = g1_up*alpha1  ## shortcut

            (c2_left,) = detach_and_grad((c2_left,))
            set_device_states(ctx.cpu_states_1, ctx.gpu_devices, ctx.gpu_states_1)
            oup1 = l1(c0, c2_left)
            torch.autograd.backward(oup1, g1_up, retain_graph=True)
            c2_left.requires_grad = False
            cout2 = c2_left*alpha2  ## alpha2 update
            torch.autograd.backward(cout2, g2_up)

            with torch.no_grad():
                c1_left = (1/alpha1)*(c1 - oup1)  ## feature reverse
                g0_up = g0_right + c0.grad
                g0_left = g0_up*alpha0  ## shortcut
                g2_left = g2_left + c2_left.grad if c2_left.grad is not None else g2_left  ## Fusion

            (c1_left,) = detach_and_grad((c1_left,))
            set_device_states(ctx.cpu_states_0, ctx.gpu_devices, ctx.gpu_states_0)
            oup0 = l0(x, c1_left)
            torch.autograd.backward(oup0, g0_up, retain_graph=True)
            c1_left.requires_grad = False
            cout1 = c1_left*alpha1  ## alpha1 update
            torch.autograd.backward(cout1, g1_up)

            with torch.no_grad():
                c0_left = (1/alpha0)*(c0 - oup0)  ## feature reverse
                gx_up = x.grad  ## Fusion
                g1_left = g1_left + c1_left.grad if c1_left.grad is not None else g1_left  ## Fusion
            c0_left.requires_grad = False
            cout0 = c0_left*alpha0  ## alpha0 update
            torch.autograd.backward(cout0, g0_up)

        if ctx.first_col:
            return None, None, gx_up, None, None, None, None
        else:
            return None, None, gx_up, g0_left, g1_left, g2_left, g3_left
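The backward pass above works because each column update has the reversible form c_new = l(...) + alpha * c_old, so the pre-update state can be recovered as (c_new - l(...)) / alpha instead of being stored. A toy sketch of that inversion rule (illustrative names, not repo code):

    import torch

    f = torch.nn.Linear(8, 8)     # stands in for one level function l_i
    h = torch.randn(2, 8)         # the level's other input
    c_old, alpha = torch.randn(2, 8), 0.5

    c_new = f(h) + alpha * c_old              # forward, as in ReverseFunction.forward
    c_rec = (1 / alpha) * (c_new - f(h))      # "feature reverse", as in backward
    print(torch.allclose(c_rec, c_old, atol=1e-5))  # True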
training/optimizer.py
ADDED
@@ -0,0 +1,145 @@
# --------------------------------------------------------
# Reversible Column Networks
# Copyright (c) 2022 Megvii Inc.
# Licensed under The Apache License 2.0 [see LICENSE for details]
# Written by Yuxuan Cai
# --------------------------------------------------------

import numpy as np
from torch import optim as optim

def build_optimizer(config, model):
    """
    Build optimizer, set weight decay of normalization to 0 by default.
    """
    skip = {}
    skip_keywords = {}
    if hasattr(model, 'no_weight_decay'):
        skip = model.no_weight_decay()
    if hasattr(model, 'no_weight_decay_keywords'):
        skip_keywords = model.no_weight_decay_keywords()

    # RevCol models get layer-wise lr decay groups; everything else falls back
    # to the plain decay/no-decay split. (An `elif` here would leave
    # `parameters` unset whenever the model defines no_weight_decay_keywords.)
    if config.MODEL.TYPE.startswith("revcol"):
        parameters = param_groups_lrd(model, weight_decay=config.TRAIN.WEIGHT_DECAY,
                                      no_weight_decay_list=[], layer_decay=config.TRAIN.OPTIMIZER.LAYER_DECAY)
    else:
        parameters = set_weight_decay(model, skip, skip_keywords)

    opt_lower = config.TRAIN.OPTIMIZER.NAME.lower()
    optimizer = None
    if opt_lower == 'sgd':
        optimizer = optim.SGD(parameters, momentum=config.TRAIN.OPTIMIZER.MOMENTUM, nesterov=True,
                              lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY)
    elif opt_lower == 'adamw':
        optimizer = optim.AdamW(parameters, eps=config.TRAIN.OPTIMIZER.EPS, betas=config.TRAIN.OPTIMIZER.BETAS,
                                lr=config.TRAIN.BASE_LR)

    return optimizer


def set_weight_decay(model, skip_list=(), skip_keywords=()):
    has_decay = []
    no_decay = []

    for name, param in model.named_parameters():
        if not param.requires_grad or name in ["linear_eval.weight", "linear_eval.bias"]:
            continue  # frozen weights
        if len(param.shape) == 1 or name.endswith(".bias") or (name in skip_list) or \
                check_keywords_in_name(name, skip_keywords):
            no_decay.append(param)
            # print(f"{name} has no weight decay")
        else:
            has_decay.append(param)
    return [{'params': has_decay},
            {'params': no_decay, 'weight_decay': 0.}]


def check_keywords_in_name(name, keywords=()):
    isin = False
    for keyword in keywords:
        if keyword in name:
            isin = True
    return isin

def cal_model_depth(columns, layers):
    # dp[i][j] holds the "depth" of block i in column j, i.e. the length of the
    # shortest path from the network input, used to order blocks for lr decay.
    depth = sum(layers)
    dp = np.zeros((depth, columns))
    dp[:, 0] = np.linspace(0, depth-1, depth)
    dp[0, :] = np.linspace(0, columns-1, columns)
    for i in range(1, depth):
        for j in range(1, columns):
            dp[i][j] = min(dp[i][j-1], dp[i-1][j]) + 1
    dp = dp.astype(int)
    return dp


def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75):
    """
    Parameter groups for layer-wise lr decay
    Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58
    """
    param_group_names = {}
    param_groups = {}
    dp = cal_model_depth(model.num_subnet, model.layers) + 1
    num_layers = dp[-1][-1] + 1

    layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1))

    for n, p in model.named_parameters():
        if not p.requires_grad:
            continue

        # no decay: all 1D parameters and model specific ones
        if p.ndim == 1 or n in no_weight_decay_list:  # or re.match('(.*).alpha.$', n):
            g_decay = "no_decay"
            this_decay = 0.
        else:
            g_decay = "decay"
            this_decay = weight_decay

        layer_id = get_layer_id(n, dp, model.layers)
        group_name = "layer_%d_%s" % (layer_id, g_decay)

        if group_name not in param_group_names:
            this_scale = layer_scales[layer_id]

            param_group_names[group_name] = {
                "lr_scale": this_scale,
                "weight_decay": this_decay,
                "params": [],
            }
            param_groups[group_name] = {
                "lr_scale": this_scale,
                "weight_decay": this_decay,
                "params": [],
            }

        param_group_names[group_name]["params"].append(n)
        param_groups[group_name]["params"].append(p)

    # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2))

    return list(param_groups.values())

def get_layer_id(n, dp, layers):
    if n.startswith("subnet"):
        name_part = n.split('.')
        subnet = int(name_part[0][6:])
        if name_part[1].startswith("alpha"):
            id = dp[0][subnet]
        else:
            level = int(name_part[1][-1])
            if name_part[2].startswith("blocks"):
                sub = int(name_part[3])
                if sub > layers[level]-1:
                    sub = layers[level]-1
                block = sum(layers[:level]) + sub

            if name_part[2].startswith("fusion"):
                block = sum(layers[:level])
            id = dp[block][subnet]
    elif n.startswith("stem"):
        id = 0
    else:
        id = dp[-1][-1] + 1
    return id
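param_groups_lrd walks cal_model_depth's dynamic-programming grid to assign each parameter a layer id, then gives layer id i the multiplier layer_decay ** (num_layers - i); the lr_scale stored in each group is meant to be consumed when per-group learning rates are set (PyTorch optimizers preserve extra keys in param group dicts). A small sketch of the schedule with toy numbers, not the repo's defaults:

    # layer_decay=0.75 over 4 layers: earlier layers train with geometrically
    # smaller learning-rate scales, the head with the full base LR.
    layer_decay, num_layers = 0.75, 4
    layer_scales = [layer_decay ** (num_layers - i) for i in range(num_layers + 1)]
    print(layer_scales)  # [0.31640625, 0.421875, 0.5625, 0.75, 1.0]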
training/requirements.txt
ADDED
@@ -0,0 +1,7 @@
fvcore==0.1.5.post20211023
numpy==1.20.3
opencv_python==4.4.0.46
termcolor==1.1.0
timm==0.5.4
yacs==0.1.8
tensorboard
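These pins reflect the 2021/2022-era stack the code targets; in particular timm==0.5.4 provides the ModelEma and get_state_dict imports used in utils.py below. A quick sanity check, assuming the packages are installed in the current environment:

    # Minimal sketch: verify two of the pinned versions at runtime.
    import numpy, timm
    assert numpy.__version__ == "1.20.3", numpy.__version__
    assert timm.__version__ == "0.5.4", timm.__version__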
training/utils.py
ADDED
@@ -0,0 +1,179 @@
# --------------------------------------------------------
# Reversible Column Networks
# Copyright (c) 2022 Megvii Inc.
# Licensed under The Apache License 2.0 [see LICENSE for details]
# Written by Yuxuan Cai
# --------------------------------------------------------

import io
import os
import re
from typing import List
from timm.utils.model_ema import ModelEma
import torch
import torch.distributed as dist
from timm.utils import get_state_dict
import subprocess


def load_checkpoint(config, model, optimizer, logger, model_ema=None):
    logger.info(f"==============> Resuming from {config.MODEL.RESUME}....................")
    if config.MODEL.RESUME.startswith('https'):
        checkpoint = torch.hub.load_state_dict_from_url(
            config.MODEL.RESUME, map_location='cpu', check_hash=True)
    else:
        checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu')
    logger.info("Loaded checkpoint into memory..")
    msg = model.load_state_dict(checkpoint['model'], strict=False)
    logger.info(msg)
    max_accuracy = 0.0
    if config.MODEL_EMA:
        if 'state_dict_ema' in checkpoint.keys():
            model_ema.ema.load_state_dict(checkpoint['state_dict_ema'], strict=False)
            logger.info("Loaded state_dict_ema")
        else:
            model_ema.ema.load_state_dict(checkpoint['model'], strict=False)
            logger.warning("Failed to find state_dict_ema, starting from loaded model weights")

    if not config.EVAL_MODE and 'optimizer' in checkpoint and 'epoch' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer'])
        config.defrost()
        config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1
        config.freeze()

        logger.info(f"=> loaded successfully '{config.MODEL.RESUME}' (epoch {checkpoint['epoch']})")
        if 'max_accuracy' in checkpoint:
            max_accuracy = checkpoint['max_accuracy']
    # del checkpoint
    # torch.cuda.empty_cache()
    return max_accuracy

def load_checkpoint_finetune(config, model, logger, model_ema=None):
    logger.info(f"==============> Finetune {config.MODEL.FINETUNE}....................")
    checkpoint = torch.load(config.MODEL.FINETUNE, map_location='cpu')['model']
    converted_weights = {}
    keys = list(checkpoint.keys())
    for key in keys:
        if re.match(r'cls.*', key):
            # if re.match(r'cls.classifier.1.*', key):
            print(f'key: {key} is used for pretrain, discarded.')
            continue
        else:
            converted_weights[key] = checkpoint[key]
    msg = model.load_state_dict(converted_weights, strict=False)
    logger.info(msg)
    if model_ema is not None:
        ema_msg = model_ema.ema.load_state_dict(converted_weights, strict=False)
        logger.info("==============> Loaded pretrained state dict into EMA....................")
        logger.info(ema_msg)
    del checkpoint
    torch.cuda.empty_cache()


def save_checkpoint(config, epoch, model, epoch_accuracy, max_accuracy, optimizer, logger, model_ema=None):
    if model_ema is not None:
        logger.info("Model EMA is not None...")
        save_state = {'model': model.state_dict(),
                      'optimizer': optimizer.state_dict(),
                      'max_accuracy': max(max_accuracy, epoch_accuracy),
                      'epoch': epoch,
                      'state_dict_ema': get_state_dict(model_ema),
                      'config': config}
    else:
        save_state = {'model': model.state_dict(),
                      'optimizer': optimizer.state_dict(),
                      'max_accuracy': max(max_accuracy, epoch_accuracy),
                      'epoch': epoch,
                      'state_dict_ema': None,
                      'config': config}

    save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth')
    best_path = os.path.join(config.OUTPUT, 'best.pth')

    logger.info(f"{save_path} saving......")
    torch.save(save_state, save_path)
    if epoch_accuracy > max_accuracy:
        torch.save(save_state, best_path)
    logger.info(f"{save_path} saved !!!")


def get_grad_norm(parameters, norm_type=2):
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = list(filter(lambda p: p.grad is not None, parameters))
    norm_type = float(norm_type)
    total_norm = 0
    for p in parameters:
        param_norm = p.grad.data.norm(norm_type)
        total_norm += param_norm.item() ** norm_type
    total_norm = total_norm ** (1. / norm_type)
    return total_norm


def auto_resume_helper(output_dir, logger):
    checkpoints = os.listdir(output_dir)
    checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith('pth') and ckpt.startswith('ckpt_')]
    logger.info(f"All checkpoints found in {output_dir}: {checkpoints}")
    if len(checkpoints) > 0:
        latest_checkpoint = max([os.path.join(output_dir, d) for d in checkpoints], key=os.path.getmtime)
        logger.info(f"The latest checkpoint found: {latest_checkpoint}")
        resume_file = latest_checkpoint
    else:
        resume_file = None
    return resume_file


def reduce_tensor(tensor):
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= dist.get_world_size()
    return rt

def denormalize(tensor: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False) -> torch.Tensor:
    """Denormalize a float tensor image with mean and standard deviation.
    This transform does not support PIL Image.

    .. note::
        This transform acts out of place by default, i.e., it does not mutate the input tensor.

    See :class:`~torchvision.transforms.Normalize` for more details.

    Args:
        tensor (Tensor): Float tensor image of size (C, H, W) or (B, C, H, W) to be denormalized.
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.
        inplace(bool, optional): Bool to make this operation inplace.

    Returns:
        Tensor: Denormalized tensor image, clipped to [0, 1].
    """
    if not isinstance(tensor, torch.Tensor):
        raise TypeError('Input tensor should be a torch tensor. Got {}.'.format(type(tensor)))

    if not tensor.is_floating_point():
        raise TypeError('Input tensor should be a float tensor. Got {}.'.format(tensor.dtype))

    if tensor.ndim < 3:
        raise ValueError('Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = '
                         '{}.'.format(tensor.size()))

    if not inplace:
        tensor = tensor.clone()

    dtype = tensor.dtype
    mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
    std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
    if (std == 0).any():
        raise ValueError('std evaluated to zero after conversion to {}, leading to division by zero.'.format(dtype))
    if mean.ndim == 1:
        mean = mean.view(-1, 1, 1)
    if std.ndim == 1:
        std = std.view(-1, 1, 1)
    tensor.mul_(std).add_(mean).clip_(0.0, 1.0)
    return tensor
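denormalize inverts a torchvision-style Normalize and clips the result to [0, 1], e.g. for inspecting reconstructed images. A round-trip sketch, assuming training/utils.py is importable as utils (ImageNet mean/std used for illustration):

    import torch
    from utils import denormalize

    mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    img = torch.rand(3, 4, 4)                    # image already in [0, 1]
    mean_t = torch.tensor(mean).view(-1, 1, 1)
    std_t = torch.tensor(std).view(-1, 1, 1)
    norm = (img - mean_t) / std_t                # what Normalize does
    back = denormalize(norm, mean, std)          # undo it, clipped to [0, 1]
    print(torch.allclose(back, img, atol=1e-6))  # True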