LarryTsai committed on
Commit
b9425fd
1 Parent(s): f74a553

Training Code: cls/det

Files changed (50)
  1. training/.gitignore +250 -0
  2. training/Detection/README.md +47 -0
  3. training/Detection/configs/_base_/models/cascade_mask_rcnn_revcol_fpn.py +209 -0
  4. training/Detection/configs/revcol/cascade_mask_rcnn_revcol_base_3x_in1k.py +152 -0
  5. training/Detection/configs/revcol/cascade_mask_rcnn_revcol_base_3x_in22k.py +152 -0
  6. training/Detection/configs/revcol/cascade_mask_rcnn_revcol_large_3x_in22k.py +152 -0
  7. training/Detection/configs/revcol/cascade_mask_rcnn_revcol_small_3x_in1k.py +152 -0
  8. training/Detection/configs/revcol/cascade_mask_rcnn_revcol_tiny_3x_in1k.py +152 -0
  9. training/Detection/mmcv_custom/__init__.py +15 -0
  10. training/Detection/mmcv_custom/checkpoint.py +484 -0
  11. training/Detection/mmcv_custom/customized_text.py +130 -0
  12. training/Detection/mmcv_custom/layer_decay_optimizer_constructor.py +121 -0
  13. training/Detection/mmcv_custom/runner/checkpoint.py +85 -0
  14. training/Detection/mmdet/models/backbones/__init__.py +28 -0
  15. training/Detection/mmdet/models/backbones/revcol.py +187 -0
  16. training/Detection/mmdet/models/backbones/revcol_function.py +222 -0
  17. training/Detection/mmdet/models/backbones/revcol_module.py +85 -0
  18. training/Detection/mmdet/utils/__init__.py +12 -0
  19. training/Detection/mmdet/utils/optimizer.py +33 -0
  20. training/INSTRUCTIONS.md +158 -0
  21. training/LICENSE +190 -0
  22. training/README.md +79 -0
  23. training/config.py +243 -0
  24. training/configs/revcol_base_1k.yaml +48 -0
  25. training/configs/revcol_base_1k_224_finetune.yaml +50 -0
  26. training/configs/revcol_base_1k_384_finetune.yaml +50 -0
  27. training/configs/revcol_base_22k_pretrain.yaml +51 -0
  28. training/configs/revcol_large_1k_224_finetune.yaml +51 -0
  29. training/configs/revcol_large_1k_384_finetune.yaml +51 -0
  30. training/configs/revcol_large_22k_pretrain.yaml +50 -0
  31. training/configs/revcol_small_1k.yaml +48 -0
  32. training/configs/revcol_tiny_1k.yaml +48 -0
  33. training/configs/revcol_xlarge_1k_384_finetune.yaml +53 -0
  34. training/configs/revcol_xlarge_22k_pretrain.yaml +50 -0
  35. training/data/__init__.py +1 -0
  36. training/data/build_data.py +137 -0
  37. training/data/samplers.py +29 -0
  38. training/figures/title.png +0 -0
  39. training/logger.py +41 -0
  40. training/loss.py +35 -0
  41. training/lr_scheduler.py +96 -0
  42. training/main.py +422 -0
  43. training/models/__init__.py +1 -0
  44. training/models/build.py +48 -0
  45. training/models/modules.py +157 -0
  46. training/models/revcol.py +242 -0
  47. training/models/revcol_function.py +159 -0
  48. training/optimizer.py +145 -0
  49. training/requirements.txt +7 -0
  50. training/utils.py +179 -0
training/.gitignore ADDED
@@ -0,0 +1,250 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ .hypothesis/
+ .pytest_cache/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/en/_build/
+ docs/zh_cn/_build/
+
+ # PyBuilder
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # pyenv
+ .python-version
+
+ # celery beat schedule file
+ celerybeat-schedule
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+
+ data/
+ data
+ .vscode
+ .idea
+ .DS_Store
+
+ # custom
+ *.pkl
+ *.pkl.json
+ *.log.json
+ docs/modelzoo_statistics.md
+ mmdet/.mim
+ work_dirs/
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ .hypothesis/
+ .pytest_cache/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/en/_build/
+ docs/zh_cn/_build/
+
+ # PyBuilder
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # pyenv
+ .python-version
+
+ # celery beat schedule file
+ celerybeat-schedule
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+
+ data/
+ data
+ .vscode
+ .idea
+ .DS_Store
+
+ # custom
+ *.pkl
+ *.pkl.json
+ *.log.json
+ docs/modelzoo_statistics.md
+ mmdet/.mim
+ work_dirs/
+ .DS_Store
+
+ # Pytorch
+ *.pth
+ *.py~
+ *.sh~
+ .DS_
+ # Pytorch
+ *.pth
+ *.py~
+ *.sh~
+
training/Detection/README.md ADDED
@@ -0,0 +1,47 @@
+ # COCO Object detection with RevCol
+
+ ## Getting started
+
+ We build the RevCol object detection model based on [mmdetection](https://github.com/open-mmlab/mmdetection/tree/3e2693151add9b5d6db99b944da020cba837266b) at commit `3e26931`, and add the RevCol model and config files to [the original repo](https://github.com/open-mmlab/mmdetection/tree/3e2693151add9b5d6db99b944da020cba837266b). Please refer to [get_started.md](https://github.com/open-mmlab/mmdetection/blob/3e2693151add9b5d6db99b944da020cba837266b/docs/en/get_started.md) for installation and dataset preparation instructions.
+
+ ## Results and Fine-tuned Models
+
+ | name | Pretrained Model | Method | Lr Schd | box mAP | mask mAP | #params | FLOPs | Fine-tuned Model |
+ |:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
+ | RevCol-T | [ImageNet-1K]() | Cascade Mask R-CNN | 3x | 50.6 | 43.8 | 88M | 741G | [model]() |
+ | RevCol-S | [ImageNet-1K]() | Cascade Mask R-CNN | 3x | 52.6 | 45.5 | 118M | 833G | [model]() |
+ | RevCol-B | [ImageNet-1K]() | Cascade Mask R-CNN | 3x | 53.0 | 45.9 | 196M | 988G | [model]() |
+ | RevCol-B | [ImageNet-22K]() | Cascade Mask R-CNN | 3x | 55.0 | 47.5 | 196M | 988G | [model]() |
+ | RevCol-L | [ImageNet-22K]() | Cascade Mask R-CNN | 3x | 55.9 | 48.4 | 330M | 1453G | [model]() |
+
+ ## Training
+
+ To train a detector with pre-trained models, run:
+ ```
+ # single-gpu training
+ python tools/train.py <CONFIG_FILE> --cfg-options model.pretrained=<PRETRAIN_MODEL> [other optional arguments]
+
+ # multi-gpu training
+ tools/dist_train.sh <CONFIG_FILE> <GPU_NUM> --cfg-options model.pretrained=<PRETRAIN_MODEL> [other optional arguments]
+ ```
+ For example, to train a Cascade Mask R-CNN model with a `RevCol-T` backbone on 8 GPUs, run:
+ ```
+ tools/dist_train.sh configs/revcol/cascade_mask_rcnn_revcol_tiny_3x_in1k.py 8 --cfg-options pretrained=<PRETRAIN_MODEL>
+ ```
+
+ More config files can be found at [`configs/revcol`](configs/revcol).
+
+ ## Inference
+ ```
+ # single-gpu testing
+ python tools/test.py <CONFIG_FILE> <DET_CHECKPOINT_FILE> --eval bbox segm
+
+ # multi-gpu testing
+ tools/dist_test.sh <CONFIG_FILE> <DET_CHECKPOINT_FILE> <GPU_NUM> --eval bbox segm
+ ```
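+ For example, to evaluate the `RevCol-T` Cascade Mask R-CNN model on COCO with 8 GPUs, run:
+ ```
+ tools/dist_test.sh configs/revcol/cascade_mask_rcnn_revcol_tiny_3x_in1k.py <DET_CHECKPOINT_FILE> 8 --eval bbox segm
+ ```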
+
+ ## Acknowledgment
+
+ This code is built using the [mmdetection](https://github.com/open-mmlab/mmdetection) and [timm](https://github.com/rwightman/pytorch-image-models) libraries, and the [BeiT](https://github.com/microsoft/unilm/tree/f8f3df80c65eb5e5fc6d6d3c9bd3137621795d1e/beit) and [Swin Transformer](https://github.com/microsoft/Swin-Transformer) repositories.
training/Detection/configs/_base_/models/cascade_mask_rcnn_revcol_fpn.py ADDED
@@ -0,0 +1,209 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+
+ # All rights reserved.
+
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+
+ # model settings
+ pretrained = None
+ model = dict(
+     type='CascadeRCNN',
+     backbone=dict(
+         type='RevCol',
+         channels=[48, 96, 192, 384],
+         layers=[3, 3, 9, 3],
+         num_subnet=4,
+         drop_path=0.2,
+         save_memory=False,
+         out_indices=[0, 1, 2, 3],
+         init_cfg=dict(type='Pretrained', checkpoint=pretrained)
+     ),
+     neck=dict(
+         type='FPN',
+         in_channels=[128, 256, 512, 1024],
+         out_channels=256,
+         num_outs=5),
+     rpn_head=dict(
+         type='RPNHead',
+         in_channels=256,
+         feat_channels=256,
+         anchor_generator=dict(
+             type='AnchorGenerator',
+             scales=[8],
+             ratios=[0.5, 1.0, 2.0],
+             strides=[4, 8, 16, 32, 64]),
+         bbox_coder=dict(
+             type='DeltaXYWHBBoxCoder',
+             target_means=[.0, .0, .0, .0],
+             target_stds=[1.0, 1.0, 1.0, 1.0]),
+         loss_cls=dict(
+             type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
+         loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
+     roi_head=dict(
+         type='CascadeRoIHead',
+         num_stages=3,
+         stage_loss_weights=[1, 0.5, 0.25],
+         bbox_roi_extractor=dict(
+             type='SingleRoIExtractor',
+             roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
+             out_channels=256,
+             featmap_strides=[4, 8, 16, 32]),
+         bbox_head=[
+             dict(
+                 type='Shared2FCBBoxHead',
+                 in_channels=256,
+                 fc_out_channels=1024,
+                 roi_feat_size=7,
+                 num_classes=80,
+                 bbox_coder=dict(
+                     type='DeltaXYWHBBoxCoder',
+                     target_means=[0., 0., 0., 0.],
+                     target_stds=[0.1, 0.1, 0.2, 0.2]),
+                 reg_class_agnostic=True,
+                 loss_cls=dict(
+                     type='CrossEntropyLoss',
+                     use_sigmoid=False,
+                     loss_weight=1.0),
+                 loss_bbox=dict(type='SmoothL1Loss', beta=1.0,
+                                loss_weight=1.0)),
+             dict(
+                 type='Shared2FCBBoxHead',
+                 in_channels=256,
+                 fc_out_channels=1024,
+                 roi_feat_size=7,
+                 num_classes=80,
+                 bbox_coder=dict(
+                     type='DeltaXYWHBBoxCoder',
+                     target_means=[0., 0., 0., 0.],
+                     target_stds=[0.05, 0.05, 0.1, 0.1]),
+                 reg_class_agnostic=True,
+                 loss_cls=dict(
+                     type='CrossEntropyLoss',
+                     use_sigmoid=False,
+                     loss_weight=1.0),
+                 loss_bbox=dict(type='SmoothL1Loss', beta=1.0,
+                                loss_weight=1.0)),
+             dict(
+                 type='Shared2FCBBoxHead',
+                 in_channels=256,
+                 fc_out_channels=1024,
+                 roi_feat_size=7,
+                 num_classes=80,
+                 bbox_coder=dict(
+                     type='DeltaXYWHBBoxCoder',
+                     target_means=[0., 0., 0., 0.],
+                     target_stds=[0.033, 0.033, 0.067, 0.067]),
+                 reg_class_agnostic=True,
+                 loss_cls=dict(
+                     type='CrossEntropyLoss',
+                     use_sigmoid=False,
+                     loss_weight=1.0),
+                 loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0))
+         ],
+         mask_roi_extractor=dict(
+             type='SingleRoIExtractor',
+             roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),
+             out_channels=256,
+             featmap_strides=[4, 8, 16, 32]),
+         mask_head=dict(
+             type='FCNMaskHead',
+             num_convs=4,
+             in_channels=256,
+             conv_out_channels=256,
+             num_classes=80,
+             loss_mask=dict(
+                 type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))),
+     # model training and testing settings
+     train_cfg=dict(
+         rpn=dict(
+             assigner=dict(
+                 type='MaxIoUAssigner',
+                 pos_iou_thr=0.7,
+                 neg_iou_thr=0.3,
+                 min_pos_iou=0.3,
+                 match_low_quality=True,
+                 ignore_iof_thr=-1),
+             sampler=dict(
+                 type='RandomSampler',
+                 num=256,
+                 pos_fraction=0.5,
+                 neg_pos_ub=-1,
+                 add_gt_as_proposals=False),
+             allowed_border=0,
+             pos_weight=-1,
+             debug=False),
+         rpn_proposal=dict(
+             nms_across_levels=False,
+             nms_pre=2000,
+             nms_post=2000,
+             max_per_img=2000,
+             nms=dict(type='nms', iou_threshold=0.7),
+             min_bbox_size=0),
+         rcnn=[
+             dict(
+                 assigner=dict(
+                     type='MaxIoUAssigner',
+                     pos_iou_thr=0.5,
+                     neg_iou_thr=0.5,
+                     min_pos_iou=0.5,
+                     match_low_quality=False,
+                     ignore_iof_thr=-1),
+                 sampler=dict(
+                     type='RandomSampler',
+                     num=512,
+                     pos_fraction=0.25,
+                     neg_pos_ub=-1,
+                     add_gt_as_proposals=True),
+                 mask_size=28,
+                 pos_weight=-1,
+                 debug=False),
+             dict(
+                 assigner=dict(
+                     type='MaxIoUAssigner',
+                     pos_iou_thr=0.6,
+                     neg_iou_thr=0.6,
+                     min_pos_iou=0.6,
+                     match_low_quality=False,
+                     ignore_iof_thr=-1),
+                 sampler=dict(
+                     type='RandomSampler',
+                     num=512,
+                     pos_fraction=0.25,
+                     neg_pos_ub=-1,
+                     add_gt_as_proposals=True),
+                 mask_size=28,
+                 pos_weight=-1,
+                 debug=False),
+             dict(
+                 assigner=dict(
+                     type='MaxIoUAssigner',
+                     pos_iou_thr=0.7,
+                     neg_iou_thr=0.7,
+                     min_pos_iou=0.7,
+                     match_low_quality=False,
+                     ignore_iof_thr=-1),
+                 sampler=dict(
+                     type='RandomSampler',
+                     num=512,
+                     pos_fraction=0.25,
+                     neg_pos_ub=-1,
+                     add_gt_as_proposals=True),
+                 mask_size=28,
+                 pos_weight=-1,
+                 debug=False)
+         ]),
+     test_cfg=dict(
+         rpn=dict(
+             nms_across_levels=False,
+             nms_pre=1000,
+             nms_post=1000,
+             max_per_img=1000,
+             nms=dict(type='nms', iou_threshold=0.7),
+             min_bbox_size=0),
+         rcnn=dict(
+             score_thr=0.05,
+             nms=dict(type='nms', iou_threshold=0.5),
+             max_per_img=100,
+             mask_thr_binary=0.5)))
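Note that in this base config, `neck.in_channels=[128, 256, 512, 1024]` deliberately does not match the default `backbone.channels=[48, 96, 192, 384]`: every RevCol variant config below overrides both fields together, so the FPN input widths always track the backbone. A minimal sketch of that override pattern, using the RevCol-B values from the next config file:

```python
# Per-variant override (values from cascade_mask_rcnn_revcol_base_3x_in1k.py).
model = dict(
    backbone=dict(channels=[72, 144, 288, 576], layers=[1, 1, 3, 2], num_subnet=16),
    neck=dict(in_channels=[72, 144, 288, 576]))  # must equal backbone.channels
```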
training/Detection/configs/revcol/cascade_mask_rcnn_revcol_base_3x_in1k.py ADDED
@@ -0,0 +1,152 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+
+ # All rights reserved.
+
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+
+ _base_ = [
+     '../_base_/models/cascade_mask_rcnn_revcol_fpn.py',
+     '../_base_/datasets/coco_instance.py',
+     '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
+ ]
+ pretrained = './cls_model/revcol_base_1k.pth'
+ model = dict(
+     backbone=dict(
+         channels=[72, 144, 288, 576],
+         layers=[1, 1, 3, 2],
+         num_subnet=16,
+         drop_path=0.4,
+         save_memory=False,
+         out_indices=[0, 1, 2, 3],
+         init_cfg=dict(type='Pretrained', checkpoint=pretrained)
+     ),
+     neck=dict(in_channels=[72, 144, 288, 576]),
+     roi_head=dict(
+         bbox_head=[
+             dict(
+                 type='ConvFCBBoxHead',
+                 num_shared_convs=4,
+                 num_shared_fcs=1,
+                 in_channels=256,
+                 conv_out_channels=256,
+                 fc_out_channels=1024,
+                 roi_feat_size=7,
+                 num_classes=80,
+                 bbox_coder=dict(
+                     type='DeltaXYWHBBoxCoder',
+                     target_means=[0., 0., 0., 0.],
+                     target_stds=[0.1, 0.1, 0.2, 0.2]),
+                 reg_class_agnostic=False,
+                 reg_decoded_bbox=True,
+                 norm_cfg=dict(type='SyncBN', requires_grad=True),
+                 loss_cls=dict(
+                     type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+                 loss_bbox=dict(type='GIoULoss', loss_weight=10.0)),
+             dict(
+                 type='ConvFCBBoxHead',
+                 num_shared_convs=4,
+                 num_shared_fcs=1,
+                 in_channels=256,
+                 conv_out_channels=256,
+                 fc_out_channels=1024,
+                 roi_feat_size=7,
+                 num_classes=80,
+                 bbox_coder=dict(
+                     type='DeltaXYWHBBoxCoder',
+                     target_means=[0., 0., 0., 0.],
+                     target_stds=[0.05, 0.05, 0.1, 0.1]),
+                 reg_class_agnostic=False,
+                 reg_decoded_bbox=True,
+                 norm_cfg=dict(type='SyncBN', requires_grad=True),
+                 loss_cls=dict(
+                     type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+                 loss_bbox=dict(type='GIoULoss', loss_weight=10.0)),
+             dict(
+                 type='ConvFCBBoxHead',
+                 num_shared_convs=4,
+                 num_shared_fcs=1,
+                 in_channels=256,
+                 conv_out_channels=256,
+                 fc_out_channels=1024,
+                 roi_feat_size=7,
+                 num_classes=80,
+                 bbox_coder=dict(
+                     type='DeltaXYWHBBoxCoder',
+                     target_means=[0., 0., 0., 0.],
+                     target_stds=[0.033, 0.033, 0.067, 0.067]),
+                 reg_class_agnostic=False,
+                 reg_decoded_bbox=True,
+                 norm_cfg=dict(type='SyncBN', requires_grad=True),
+                 loss_cls=dict(
+                     type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+                 loss_bbox=dict(type='GIoULoss', loss_weight=10.0))
+         ]))
+
+ img_norm_cfg = dict(
+     mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+
+ # augmentation strategy originates from DETR / Sparse RCNN
+ train_pipeline = [
+     dict(type='LoadImageFromFile'),
+     dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
+     dict(type='RandomFlip', flip_ratio=0.5),
+     dict(type='AutoAugment',
+          policies=[
+              [
+                  dict(type='Resize',
+                       img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
+                                  (608, 1333), (640, 1333), (672, 1333), (704, 1333),
+                                  (736, 1333), (768, 1333), (800, 1333)],
+                       multiscale_mode='value',
+                       keep_ratio=True)
+              ],
+              [
+                  dict(type='Resize',
+                       img_scale=[(400, 1333), (500, 1333), (600, 1333)],
+                       multiscale_mode='value',
+                       keep_ratio=True),
+                  dict(type='RandomCrop',
+                       crop_type='absolute_range',
+                       crop_size=(384, 600),
+                       allow_negative_crop=True),
+                  dict(type='Resize',
+                       img_scale=[(480, 1333), (512, 1333), (544, 1333),
+                                  (576, 1333), (608, 1333), (640, 1333),
+                                  (672, 1333), (704, 1333), (736, 1333),
+                                  (768, 1333), (800, 1333)],
+                       multiscale_mode='value',
+                       override=True,
+                       keep_ratio=True)
+              ]
+          ]),
+     dict(type='Normalize', **img_norm_cfg),
+     dict(type='Pad', size_divisor=32),
+     dict(type='DefaultFormatBundle'),
+     dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
+ ]
+ data = dict(train=dict(pipeline=train_pipeline))
+
+
+ optimizer = dict(constructor='LearningRateDecayOptimizerConstructor', _delete_=True, type='AdamW',
+                  lr=0.0002, betas=(0.9, 0.999), weight_decay=0.05,
+                  paramwise_cfg={'decay_rate': 0.9,
+                                 'decay_type': 'layer_wise',
+                                 'layers': [1, 1, 3, 2],
+                                 'num_subnet': 16})
+
+ lr_config = dict(step=[27, 33])
+ runner = dict(type='EpochBasedRunner', max_epochs=36)
+
+
+ # fp16 = None
+ # optimizer_config = dict(
+ #     type="DistOptimizerHook",
+ #     update_interval=1,
+ #     grad_clip=None,
+ #     coalesce=True,
+ #     bucket_size_mb=-1,
+ #     use_fp16=True,
+ # )
+ fp16 = dict(loss_scale='dynamic')
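The optimizer above delegates parameter grouping to `LearningRateDecayOptimizerConstructor` (added by this commit in `mmcv_custom/layer_decay_optimizer_constructor.py`). As a rough illustration of what a `decay_type='layer_wise'` rule computes — a hypothetical standalone helper for clarity, not the repo's constructor — the learning rate shrinks geometrically with a layer's distance from the head:

```python
def layer_wise_lr(base_lr, decay_rate, layer_id, num_layers):
    """Layers closer to the input get geometrically smaller learning rates."""
    return base_lr * decay_rate ** (num_layers - 1 - layer_id)

# With this config's base lr of 2e-4 and decay_rate of 0.9 over, say, 12 layer groups:
print(layer_wise_lr(2e-4, 0.9, 0, 12))   # first group: ~6.3e-05
print(layer_wise_lr(2e-4, 0.9, 11, 12))  # last group:  2e-4
```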
training/Detection/configs/revcol/cascade_mask_rcnn_revcol_base_3x_in22k.py ADDED
@@ -0,0 +1,152 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+
+ # All rights reserved.
+
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+
+ _base_ = [
+     '../_base_/models/cascade_mask_rcnn_revcol_fpn.py',
+     '../_base_/datasets/coco_instance.py',
+     '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
+ ]
+ pretrained = './cls_model/revcol_base_22k.pth'
+ model = dict(
+     backbone=dict(
+         channels=[72, 144, 288, 576],
+         layers=[1, 1, 3, 2],
+         num_subnet=16,
+         drop_path=0.4,
+         save_memory=False,
+         out_indices=[0, 1, 2, 3],
+         init_cfg=dict(type='Pretrained', checkpoint=pretrained)
+     ),
+     neck=dict(in_channels=[72, 144, 288, 576]),
+     roi_head=dict(
+         bbox_head=[
+             dict(
+                 type='ConvFCBBoxHead',
+                 num_shared_convs=4,
+                 num_shared_fcs=1,
+                 in_channels=256,
+                 conv_out_channels=256,
+                 fc_out_channels=1024,
+                 roi_feat_size=7,
+                 num_classes=80,
+                 bbox_coder=dict(
+                     type='DeltaXYWHBBoxCoder',
+                     target_means=[0., 0., 0., 0.],
+                     target_stds=[0.1, 0.1, 0.2, 0.2]),
+                 reg_class_agnostic=False,
+                 reg_decoded_bbox=True,
+                 norm_cfg=dict(type='SyncBN', requires_grad=True),
+                 loss_cls=dict(
+                     type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+                 loss_bbox=dict(type='GIoULoss', loss_weight=10.0)),
+             dict(
+                 type='ConvFCBBoxHead',
+                 num_shared_convs=4,
+                 num_shared_fcs=1,
+                 in_channels=256,
+                 conv_out_channels=256,
+                 fc_out_channels=1024,
+                 roi_feat_size=7,
+                 num_classes=80,
+                 bbox_coder=dict(
+                     type='DeltaXYWHBBoxCoder',
+                     target_means=[0., 0., 0., 0.],
+                     target_stds=[0.05, 0.05, 0.1, 0.1]),
+                 reg_class_agnostic=False,
+                 reg_decoded_bbox=True,
+                 norm_cfg=dict(type='SyncBN', requires_grad=True),
+                 loss_cls=dict(
+                     type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+                 loss_bbox=dict(type='GIoULoss', loss_weight=10.0)),
+             dict(
+                 type='ConvFCBBoxHead',
+                 num_shared_convs=4,
+                 num_shared_fcs=1,
+                 in_channels=256,
+                 conv_out_channels=256,
+                 fc_out_channels=1024,
+                 roi_feat_size=7,
+                 num_classes=80,
+                 bbox_coder=dict(
+                     type='DeltaXYWHBBoxCoder',
+                     target_means=[0., 0., 0., 0.],
+                     target_stds=[0.033, 0.033, 0.067, 0.067]),
+                 reg_class_agnostic=False,
+                 reg_decoded_bbox=True,
+                 norm_cfg=dict(type='SyncBN', requires_grad=True),
+                 loss_cls=dict(
+                     type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+                 loss_bbox=dict(type='GIoULoss', loss_weight=10.0))
+         ]))
+
+ img_norm_cfg = dict(
+     mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+
+ # augmentation strategy originates from DETR / Sparse RCNN
+ train_pipeline = [
+     dict(type='LoadImageFromFile'),
+     dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
+     dict(type='RandomFlip', flip_ratio=0.5),
+     dict(type='AutoAugment',
+          policies=[
+              [
+                  dict(type='Resize',
+                       img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
+                                  (608, 1333), (640, 1333), (672, 1333), (704, 1333),
+                                  (736, 1333), (768, 1333), (800, 1333)],
+                       multiscale_mode='value',
+                       keep_ratio=True)
+              ],
+              [
+                  dict(type='Resize',
+                       img_scale=[(400, 1333), (500, 1333), (600, 1333)],
+                       multiscale_mode='value',
+                       keep_ratio=True),
+                  dict(type='RandomCrop',
+                       crop_type='absolute_range',
+                       crop_size=(384, 600),
+                       allow_negative_crop=True),
+                  dict(type='Resize',
+                       img_scale=[(480, 1333), (512, 1333), (544, 1333),
+                                  (576, 1333), (608, 1333), (640, 1333),
+                                  (672, 1333), (704, 1333), (736, 1333),
+                                  (768, 1333), (800, 1333)],
+                       multiscale_mode='value',
+                       override=True,
+                       keep_ratio=True)
+              ]
+          ]),
+     dict(type='Normalize', **img_norm_cfg),
+     dict(type='Pad', size_divisor=32),
+     dict(type='DefaultFormatBundle'),
+     dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
+ ]
+ data = dict(train=dict(pipeline=train_pipeline))
+
+
+ optimizer = dict(constructor='LearningRateDecayOptimizerConstructor', _delete_=True, type='AdamW',
+                  lr=0.0001, betas=(0.9, 0.999), weight_decay=0.05,
+                  paramwise_cfg={'decay_rate': 0.9,
+                                 'decay_type': 'layer_wise',
+                                 'layers': [1, 1, 3, 2],
+                                 'num_subnet': 16})
+
+ lr_config = dict(step=[27, 33])
+ runner = dict(type='EpochBasedRunner', max_epochs=36)
+
+
+ # fp16 = None
+ # optimizer_config = dict(
+ #     type="DistOptimizerHook",
+ #     update_interval=1,
+ #     grad_clip=None,
+ #     coalesce=True,
+ #     bucket_size_mb=-1,
+ #     use_fp16=True,
+ # )
+ fp16 = dict(loss_scale='dynamic')
training/Detection/configs/revcol/cascade_mask_rcnn_revcol_large_3x_in22k.py ADDED
@@ -0,0 +1,152 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+
+ # All rights reserved.
+
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+
+ _base_ = [
+     '../_base_/models/cascade_mask_rcnn_revcol_fpn.py',
+     '../_base_/datasets/coco_instance.py',
+     '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
+ ]
+ pretrained = './cls_model/revcol_large_22k.pth'
+ model = dict(
+     backbone=dict(
+         channels=[128, 256, 512, 1024],
+         layers=[1, 2, 6, 2],
+         num_subnet=8,
+         drop_path=0.5,
+         save_memory=False,
+         out_indices=[0, 1, 2, 3],
+         init_cfg=dict(type='Pretrained', checkpoint=pretrained)
+     ),
+     neck=dict(in_channels=[128, 256, 512, 1024]),
+     roi_head=dict(
+         bbox_head=[
+             dict(
+                 type='ConvFCBBoxHead',
+                 num_shared_convs=4,
+                 num_shared_fcs=1,
+                 in_channels=256,
+                 conv_out_channels=256,
+                 fc_out_channels=1024,
+                 roi_feat_size=7,
+                 num_classes=80,
+                 bbox_coder=dict(
+                     type='DeltaXYWHBBoxCoder',
+                     target_means=[0., 0., 0., 0.],
+                     target_stds=[0.1, 0.1, 0.2, 0.2]),
+                 reg_class_agnostic=False,
+                 reg_decoded_bbox=True,
+                 norm_cfg=dict(type='SyncBN', requires_grad=True),
+                 loss_cls=dict(
+                     type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+                 loss_bbox=dict(type='GIoULoss', loss_weight=10.0)),
+             dict(
+                 type='ConvFCBBoxHead',
+                 num_shared_convs=4,
+                 num_shared_fcs=1,
+                 in_channels=256,
+                 conv_out_channels=256,
+                 fc_out_channels=1024,
+                 roi_feat_size=7,
+                 num_classes=80,
+                 bbox_coder=dict(
+                     type='DeltaXYWHBBoxCoder',
+                     target_means=[0., 0., 0., 0.],
+                     target_stds=[0.05, 0.05, 0.1, 0.1]),
+                 reg_class_agnostic=False,
+                 reg_decoded_bbox=True,
+                 norm_cfg=dict(type='SyncBN', requires_grad=True),
+                 loss_cls=dict(
+                     type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+                 loss_bbox=dict(type='GIoULoss', loss_weight=10.0)),
+             dict(
+                 type='ConvFCBBoxHead',
+                 num_shared_convs=4,
+                 num_shared_fcs=1,
+                 in_channels=256,
+                 conv_out_channels=256,
+                 fc_out_channels=1024,
+                 roi_feat_size=7,
+                 num_classes=80,
+                 bbox_coder=dict(
+                     type='DeltaXYWHBBoxCoder',
+                     target_means=[0., 0., 0., 0.],
+                     target_stds=[0.033, 0.033, 0.067, 0.067]),
+                 reg_class_agnostic=False,
+                 reg_decoded_bbox=True,
+                 norm_cfg=dict(type='SyncBN', requires_grad=True),
+                 loss_cls=dict(
+                     type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+                 loss_bbox=dict(type='GIoULoss', loss_weight=10.0))
+         ]))
+
+ img_norm_cfg = dict(
+     mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+
+ # augmentation strategy originates from DETR / Sparse RCNN
+ train_pipeline = [
+     dict(type='LoadImageFromFile'),
+     dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
+     dict(type='RandomFlip', flip_ratio=0.5),
+     dict(type='AutoAugment',
+          policies=[
+              [
+                  dict(type='Resize',
+                       img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
+                                  (608, 1333), (640, 1333), (672, 1333), (704, 1333),
+                                  (736, 1333), (768, 1333), (800, 1333)],
+                       multiscale_mode='value',
+                       keep_ratio=True)
+              ],
+              [
+                  dict(type='Resize',
+                       img_scale=[(400, 1333), (500, 1333), (600, 1333)],
+                       multiscale_mode='value',
+                       keep_ratio=True),
+                  dict(type='RandomCrop',
+                       crop_type='absolute_range',
+                       crop_size=(384, 600),
+                       allow_negative_crop=True),
+                  dict(type='Resize',
+                       img_scale=[(480, 1333), (512, 1333), (544, 1333),
+                                  (576, 1333), (608, 1333), (640, 1333),
+                                  (672, 1333), (704, 1333), (736, 1333),
+                                  (768, 1333), (800, 1333)],
+                       multiscale_mode='value',
+                       override=True,
+                       keep_ratio=True)
+              ]
+          ]),
+     dict(type='Normalize', **img_norm_cfg),
+     dict(type='Pad', size_divisor=32),
+     dict(type='DefaultFormatBundle'),
+     dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
+ ]
+ data = dict(train=dict(pipeline=train_pipeline))
+
+
+ optimizer = dict(constructor='LearningRateDecayOptimizerConstructor', _delete_=True, type='AdamW',
+                  lr=0.0001, betas=(0.9, 0.999), weight_decay=0.05,
+                  paramwise_cfg={'decay_rate': 0.85,
+                                 'decay_type': 'layer_wise',
+                                 'layers': [1, 2, 6, 2],
+                                 'num_subnet': 8})
+
+ lr_config = dict(step=[27, 33])
+ runner = dict(type='EpochBasedRunner', max_epochs=36)
+
+
+ # fp16 = None
+ # optimizer_config = dict(
+ #     type="DistOptimizerHook",
+ #     update_interval=1,
+ #     grad_clip=None,
+ #     coalesce=True,
+ #     bucket_size_mb=-1,
+ #     use_fp16=True,
+ # )
+ fp16 = dict(loss_scale='dynamic')
training/Detection/configs/revcol/cascade_mask_rcnn_revcol_small_3x_in1k.py ADDED
@@ -0,0 +1,152 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+
+ # All rights reserved.
+
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+
+ _base_ = [
+     '../_base_/models/cascade_mask_rcnn_revcol_fpn.py',
+     '../_base_/datasets/coco_instance.py',
+     '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
+ ]
+ pretrained = './cls_model/revcol_small_1k.pth'
+ model = dict(
+     backbone=dict(
+         channels=[64, 128, 256, 512],
+         layers=[2, 2, 4, 2],
+         num_subnet=8,
+         drop_path=0.4,
+         save_memory=False,
+         out_indices=[0, 1, 2, 3],
+         init_cfg=dict(type='Pretrained', checkpoint=pretrained)
+     ),
+     neck=dict(in_channels=[64, 128, 256, 512]),
+     roi_head=dict(
+         bbox_head=[
+             dict(
+                 type='ConvFCBBoxHead',
+                 num_shared_convs=4,
+                 num_shared_fcs=1,
+                 in_channels=256,
+                 conv_out_channels=256,
+                 fc_out_channels=1024,
+                 roi_feat_size=7,
+                 num_classes=80,
+                 bbox_coder=dict(
+                     type='DeltaXYWHBBoxCoder',
+                     target_means=[0., 0., 0., 0.],
+                     target_stds=[0.1, 0.1, 0.2, 0.2]),
+                 reg_class_agnostic=False,
+                 reg_decoded_bbox=True,
+                 norm_cfg=dict(type='SyncBN', requires_grad=True),
+                 loss_cls=dict(
+                     type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+                 loss_bbox=dict(type='GIoULoss', loss_weight=10.0)),
+             dict(
+                 type='ConvFCBBoxHead',
+                 num_shared_convs=4,
+                 num_shared_fcs=1,
+                 in_channels=256,
+                 conv_out_channels=256,
+                 fc_out_channels=1024,
+                 roi_feat_size=7,
+                 num_classes=80,
+                 bbox_coder=dict(
+                     type='DeltaXYWHBBoxCoder',
+                     target_means=[0., 0., 0., 0.],
+                     target_stds=[0.05, 0.05, 0.1, 0.1]),
+                 reg_class_agnostic=False,
+                 reg_decoded_bbox=True,
+                 norm_cfg=dict(type='SyncBN', requires_grad=True),
+                 loss_cls=dict(
+                     type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+                 loss_bbox=dict(type='GIoULoss', loss_weight=10.0)),
+             dict(
+                 type='ConvFCBBoxHead',
+                 num_shared_convs=4,
+                 num_shared_fcs=1,
+                 in_channels=256,
+                 conv_out_channels=256,
+                 fc_out_channels=1024,
+                 roi_feat_size=7,
+                 num_classes=80,
+                 bbox_coder=dict(
+                     type='DeltaXYWHBBoxCoder',
+                     target_means=[0., 0., 0., 0.],
+                     target_stds=[0.033, 0.033, 0.067, 0.067]),
+                 reg_class_agnostic=False,
+                 reg_decoded_bbox=True,
+                 norm_cfg=dict(type='SyncBN', requires_grad=True),
+                 loss_cls=dict(
+                     type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+                 loss_bbox=dict(type='GIoULoss', loss_weight=10.0))
+         ]))
+
+ img_norm_cfg = dict(
+     mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+
+ # augmentation strategy originates from DETR / Sparse RCNN
+ train_pipeline = [
+     dict(type='LoadImageFromFile'),
+     dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
+     dict(type='RandomFlip', flip_ratio=0.5),
+     dict(type='AutoAugment',
+          policies=[
+              [
+                  dict(type='Resize',
+                       img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
+                                  (608, 1333), (640, 1333), (672, 1333), (704, 1333),
+                                  (736, 1333), (768, 1333), (800, 1333)],
+                       multiscale_mode='value',
+                       keep_ratio=True)
+              ],
+              [
+                  dict(type='Resize',
+                       img_scale=[(400, 1333), (500, 1333), (600, 1333)],
+                       multiscale_mode='value',
+                       keep_ratio=True),
+                  dict(type='RandomCrop',
+                       crop_type='absolute_range',
+                       crop_size=(384, 600),
+                       allow_negative_crop=True),
+                  dict(type='Resize',
+                       img_scale=[(480, 1333), (512, 1333), (544, 1333),
+                                  (576, 1333), (608, 1333), (640, 1333),
+                                  (672, 1333), (704, 1333), (736, 1333),
+                                  (768, 1333), (800, 1333)],
+                       multiscale_mode='value',
+                       override=True,
+                       keep_ratio=True)
+              ]
+          ]),
+     dict(type='Normalize', **img_norm_cfg),
+     dict(type='Pad', size_divisor=32),
+     dict(type='DefaultFormatBundle'),
+     dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
+ ]
+ data = dict(train=dict(pipeline=train_pipeline))
+
+
+ optimizer = dict(constructor='LearningRateDecayOptimizerConstructor', _delete_=True, type='AdamW',
+                  lr=0.0002, betas=(0.9, 0.999), weight_decay=0.05,
+                  paramwise_cfg={'decay_rate': 0.85,
+                                 'decay_type': 'layer_wise',
+                                 'layers': [2, 2, 4, 2],
+                                 'num_subnet': 8})
+
+ lr_config = dict(step=[27, 33])
+ runner = dict(type='EpochBasedRunner', max_epochs=36)
+
+
+ # fp16 = None
+ # optimizer_config = dict(
+ #     type="DistOptimizerHook",
+ #     update_interval=1,
+ #     grad_clip=None,
+ #     coalesce=True,
+ #     bucket_size_mb=-1,
+ #     use_fp16=True,
+ # )
+ fp16 = dict(loss_scale='dynamic')
training/Detection/configs/revcol/cascade_mask_rcnn_revcol_tiny_3x_in1k.py ADDED
@@ -0,0 +1,152 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+
+ # All rights reserved.
+
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+
+ _base_ = [
+     '../_base_/models/cascade_mask_rcnn_revcol_fpn.py',
+     '../_base_/datasets/coco_instance.py',
+     '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
+ ]
+ pretrained = './cls_model/revcol_tiny_1k.pth'
+ model = dict(
+     backbone=dict(
+         channels=[64, 128, 256, 512],
+         layers=[2, 2, 4, 2],
+         num_subnet=4,
+         drop_path=0.3,
+         save_memory=False,
+         out_indices=[0, 1, 2, 3],
+         init_cfg=dict(type='Pretrained', checkpoint=pretrained)
+     ),
+     neck=dict(in_channels=[64, 128, 256, 512]),
+     roi_head=dict(
+         bbox_head=[
+             dict(
+                 type='ConvFCBBoxHead',
+                 num_shared_convs=4,
+                 num_shared_fcs=1,
+                 in_channels=256,
+                 conv_out_channels=256,
+                 fc_out_channels=1024,
+                 roi_feat_size=7,
+                 num_classes=80,
+                 bbox_coder=dict(
+                     type='DeltaXYWHBBoxCoder',
+                     target_means=[0., 0., 0., 0.],
+                     target_stds=[0.1, 0.1, 0.2, 0.2]),
+                 reg_class_agnostic=False,
+                 reg_decoded_bbox=True,
+                 norm_cfg=dict(type='SyncBN', requires_grad=True),
+                 loss_cls=dict(
+                     type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+                 loss_bbox=dict(type='GIoULoss', loss_weight=10.0)),
+             dict(
+                 type='ConvFCBBoxHead',
+                 num_shared_convs=4,
+                 num_shared_fcs=1,
+                 in_channels=256,
+                 conv_out_channels=256,
+                 fc_out_channels=1024,
+                 roi_feat_size=7,
+                 num_classes=80,
+                 bbox_coder=dict(
+                     type='DeltaXYWHBBoxCoder',
+                     target_means=[0., 0., 0., 0.],
+                     target_stds=[0.05, 0.05, 0.1, 0.1]),
+                 reg_class_agnostic=False,
+                 reg_decoded_bbox=True,
+                 norm_cfg=dict(type='SyncBN', requires_grad=True),
+                 loss_cls=dict(
+                     type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+                 loss_bbox=dict(type='GIoULoss', loss_weight=10.0)),
+             dict(
+                 type='ConvFCBBoxHead',
+                 num_shared_convs=4,
+                 num_shared_fcs=1,
+                 in_channels=256,
+                 conv_out_channels=256,
+                 fc_out_channels=1024,
+                 roi_feat_size=7,
+                 num_classes=80,
+                 bbox_coder=dict(
+                     type='DeltaXYWHBBoxCoder',
+                     target_means=[0., 0., 0., 0.],
+                     target_stds=[0.033, 0.033, 0.067, 0.067]),
+                 reg_class_agnostic=False,
+                 reg_decoded_bbox=True,
+                 norm_cfg=dict(type='SyncBN', requires_grad=True),
+                 loss_cls=dict(
+                     type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+                 loss_bbox=dict(type='GIoULoss', loss_weight=10.0))
+         ]))
+
+ img_norm_cfg = dict(
+     mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+
+ # augmentation strategy originates from DETR / Sparse RCNN
+ train_pipeline = [
+     dict(type='LoadImageFromFile'),
+     dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
+     dict(type='RandomFlip', flip_ratio=0.5),
+     dict(type='AutoAugment',
+          policies=[
+              [
+                  dict(type='Resize',
+                       img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
+                                  (608, 1333), (640, 1333), (672, 1333), (704, 1333),
+                                  (736, 1333), (768, 1333), (800, 1333)],
+                       multiscale_mode='value',
+                       keep_ratio=True)
+              ],
+              [
+                  dict(type='Resize',
+                       img_scale=[(400, 1333), (500, 1333), (600, 1333)],
+                       multiscale_mode='value',
+                       keep_ratio=True),
+                  dict(type='RandomCrop',
+                       crop_type='absolute_range',
+                       crop_size=(384, 600),
+                       allow_negative_crop=True),
+                  dict(type='Resize',
+                       img_scale=[(480, 1333), (512, 1333), (544, 1333),
+                                  (576, 1333), (608, 1333), (640, 1333),
+                                  (672, 1333), (704, 1333), (736, 1333),
+                                  (768, 1333), (800, 1333)],
+                       multiscale_mode='value',
+                       override=True,
+                       keep_ratio=True)
+              ]
+          ]),
+     dict(type='Normalize', **img_norm_cfg),
+     dict(type='Pad', size_divisor=32),
+     dict(type='DefaultFormatBundle'),
+     dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
+ ]
+ data = dict(train=dict(pipeline=train_pipeline))
+
+
+ optimizer = dict(constructor='LearningRateDecayOptimizerConstructor', _delete_=True, type='AdamW',
+                  lr=0.0002, betas=(0.9, 0.999), weight_decay=0.05,
+                  paramwise_cfg={'decay_rate': 0.85,
+                                 'decay_type': 'layer_wise',
+                                 'layers': [2, 2, 4, 2],
+                                 'num_subnet': 4})
+
+ lr_config = dict(step=[27, 33])
+ runner = dict(type='EpochBasedRunner', max_epochs=36)
+
+
+ # fp16 = None
+ # optimizer_config = dict(
+ #     type="DistOptimizerHook",
+ #     update_interval=1,
+ #     grad_clip=None,
+ #     coalesce=True,
+ #     bucket_size_mb=-1,
+ #     use_fp16=True,
+ # )
+ fp16 = dict(loss_scale='dynamic')
training/Detection/mmcv_custom/__init__.py ADDED
@@ -0,0 +1,15 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+
+ # All rights reserved.
+
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+
+ # -*- coding: utf-8 -*-
+
+ from .checkpoint import load_checkpoint
+ from .layer_decay_optimizer_constructor import LearningRateDecayOptimizerConstructor
+ from .customized_text import CustomizedTextLoggerHook
+
+ __all__ = ['load_checkpoint', 'LearningRateDecayOptimizerConstructor', 'CustomizedTextLoggerHook']
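For these registrations to take effect, the package has to be imported before the config is parsed. How this repo wires that up is not shown in this commit; one common pattern (used by Swin Transformer's detection repo, and assumed here) is a bare import at the top of `tools/train.py`:

```python
# tools/train.py (sketch; assumed wiring, mirroring Swin's detection repo)
import mmcv_custom  # noqa: F401  -- importing runs the registry decorators above
```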
training/Detection/mmcv_custom/checkpoint.py ADDED
@@ -0,0 +1,484 @@
1
+ # Copyright (c) Open-MMLab. All rights reserved.
2
+ import io
3
+ import os
4
+ import os.path as osp
5
+ import pkgutil
6
+ import time
7
+ import warnings
8
+ from collections import OrderedDict
9
+ from importlib import import_module
10
+ from tempfile import TemporaryDirectory
11
+
12
+ import torch
13
+ import torchvision
14
+ from torch.optim import Optimizer
15
+ from torch.utils import model_zoo
16
+ from torch.nn import functional as F
17
+
18
+ import mmcv
19
+ from mmcv.fileio import FileClient
20
+ from mmcv.fileio import load as load_file
21
+ from mmcv.parallel import is_module_wrapper
22
+ from mmcv.utils import mkdir_or_exist
23
+ from mmcv.runner import get_dist_info
24
+
25
+ ENV_MMCV_HOME = 'MMCV_HOME'
26
+ ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
27
+ DEFAULT_CACHE_DIR = '~/.cache'
28
+
29
+
30
+ def _get_mmcv_home():
31
+ mmcv_home = os.path.expanduser(
32
+ os.getenv(
33
+ ENV_MMCV_HOME,
34
+ os.path.join(
35
+ os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'mmcv')))
36
+
37
+ mkdir_or_exist(mmcv_home)
38
+ return mmcv_home
39
+
40
+
41
+ def load_state_dict(module, state_dict, strict=False, logger=None):
42
+ """Load state_dict to a module.
43
+ This method is modified from :meth:`torch.nn.Module.load_state_dict`.
44
+ Default value for ``strict`` is set to ``False`` and the message for
45
+ param mismatch will be shown even if strict is False.
46
+ Args:
47
+ module (Module): Module that receives the state_dict.
48
+ state_dict (OrderedDict): Weights.
49
+ strict (bool): whether to strictly enforce that the keys
50
+ in :attr:`state_dict` match the keys returned by this module's
51
+ :meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
52
+ logger (:obj:`logging.Logger`, optional): Logger to log the error
53
+ message. If not specified, print function will be used.
54
+ """
55
+ unexpected_keys = []
56
+ all_missing_keys = []
57
+ err_msg = []
58
+
59
+ metadata = getattr(state_dict, '_metadata', None)
60
+ state_dict = state_dict.copy()
61
+ if metadata is not None:
62
+ state_dict._metadata = metadata
63
+
64
+ # use _load_from_state_dict to enable checkpoint version control
65
+ def load(module, prefix=''):
66
+ # recursively check parallel module in case that the model has a
67
+ # complicated structure, e.g., nn.Module(nn.Module(DDP))
68
+ if is_module_wrapper(module):
69
+ module = module.module
70
+ local_metadata = {} if metadata is None else metadata.get(
71
+ prefix[:-1], {})
72
+ module._load_from_state_dict(state_dict, prefix, local_metadata, True,
73
+ all_missing_keys, unexpected_keys,
74
+ err_msg)
75
+ for name, child in module._modules.items():
76
+ if child is not None:
77
+ load(child, prefix + name + '.')
78
+
79
+ load(module)
80
+ load = None # break load->load reference cycle
81
+
82
+ # ignore "num_batches_tracked" of BN layers
83
+ missing_keys = [
84
+ key for key in all_missing_keys if 'num_batches_tracked' not in key
85
+ ]
86
+
87
+ if unexpected_keys:
88
+ err_msg.append('unexpected key in source '
89
+ f'state_dict: {", ".join(unexpected_keys)}\n')
90
+ if missing_keys:
91
+ err_msg.append(
92
+ f'missing keys in source state_dict: {", ".join(missing_keys)}\n')
93
+
94
+ rank, _ = get_dist_info()
95
+ if len(err_msg) > 0 and rank == 0:
96
+ err_msg.insert(
97
+ 0, 'The model and loaded state dict do not match exactly\n')
98
+ err_msg = '\n'.join(err_msg)
99
+ if strict:
100
+ raise RuntimeError(err_msg)
101
+ elif logger is not None:
102
+ logger.warning(err_msg)
103
+ else:
104
+ print(err_msg)
105
+
106
+
107
+ def load_url_dist(url, model_dir=None):
108
+ """In distributed setting, this function only download checkpoint at local
109
+ rank 0."""
110
+ rank, world_size = get_dist_info()
111
+ rank = int(os.environ.get('LOCAL_RANK', rank))
112
+ if rank == 0:
113
+ checkpoint = model_zoo.load_url(url, model_dir=model_dir)
114
+ if world_size > 1:
115
+ torch.distributed.barrier()
116
+ if rank > 0:
117
+ checkpoint = model_zoo.load_url(url, model_dir=model_dir)
118
+ return checkpoint
119
+
120
+
121
+ def load_pavimodel_dist(model_path, map_location=None):
122
+ """In distributed setting, this function only download checkpoint at local
123
+ rank 0."""
124
+ try:
125
+ from pavi import modelcloud
126
+ except ImportError:
127
+ raise ImportError(
128
+ 'Please install pavi to load checkpoint from modelcloud.')
129
+ rank, world_size = get_dist_info()
130
+ rank = int(os.environ.get('LOCAL_RANK', rank))
131
+ if rank == 0:
132
+ model = modelcloud.get(model_path)
133
+ with TemporaryDirectory() as tmp_dir:
134
+ downloaded_file = osp.join(tmp_dir, model.name)
135
+ model.download(downloaded_file)
136
+ checkpoint = torch.load(downloaded_file, map_location=map_location)
137
+ if world_size > 1:
138
+ torch.distributed.barrier()
139
+ if rank > 0:
140
+ model = modelcloud.get(model_path)
141
+ with TemporaryDirectory() as tmp_dir:
142
+ downloaded_file = osp.join(tmp_dir, model.name)
143
+ model.download(downloaded_file)
144
+ checkpoint = torch.load(
145
+ downloaded_file, map_location=map_location)
146
+ return checkpoint
147
+
148
+
149
+ def load_fileclient_dist(filename, backend, map_location):
150
+ """In distributed setting, this function only download checkpoint at local
151
+ rank 0."""
152
+ rank, world_size = get_dist_info()
153
+ rank = int(os.environ.get('LOCAL_RANK', rank))
154
+ allowed_backends = ['ceph']
155
+ if backend not in allowed_backends:
156
+ raise ValueError(f'Load from Backend {backend} is not supported.')
157
+ if rank == 0:
158
+ fileclient = FileClient(backend=backend)
159
+ buffer = io.BytesIO(fileclient.get(filename))
160
+ checkpoint = torch.load(buffer, map_location=map_location)
161
+ if world_size > 1:
162
+ torch.distributed.barrier()
163
+ if rank > 0:
164
+ fileclient = FileClient(backend=backend)
165
+ buffer = io.BytesIO(fileclient.get(filename))
166
+ checkpoint = torch.load(buffer, map_location=map_location)
167
+ return checkpoint
168
+
169
+
170
+ def get_torchvision_models():
171
+ model_urls = dict()
172
+ for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__):
173
+ if ispkg:
174
+ continue
175
+ _zoo = import_module(f'torchvision.models.{name}')
176
+ if hasattr(_zoo, 'model_urls'):
177
+ _urls = getattr(_zoo, 'model_urls')
178
+ model_urls.update(_urls)
179
+ return model_urls
180
+
181
+
182
+ def get_external_models():
183
+ mmcv_home = _get_mmcv_home()
184
+ default_json_path = osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json')
185
+ default_urls = load_file(default_json_path)
186
+ assert isinstance(default_urls, dict)
187
+ external_json_path = osp.join(mmcv_home, 'open_mmlab.json')
188
+ if osp.exists(external_json_path):
189
+ external_urls = load_file(external_json_path)
190
+ assert isinstance(external_urls, dict)
191
+ default_urls.update(external_urls)
192
+
193
+ return default_urls
194
+
195
+
196
+ def get_mmcls_models():
197
+ mmcls_json_path = osp.join(mmcv.__path__[0], 'model_zoo/mmcls.json')
198
+ mmcls_urls = load_file(mmcls_json_path)
199
+
200
+ return mmcls_urls
201
+
202
+
203
+ def get_deprecated_model_names():
204
+ deprecate_json_path = osp.join(mmcv.__path__[0],
205
+ 'model_zoo/deprecated.json')
206
+ deprecate_urls = load_file(deprecate_json_path)
207
+ assert isinstance(deprecate_urls, dict)
208
+
209
+ return deprecate_urls
210
+
211
+
212
+ def _process_mmcls_checkpoint(checkpoint):
213
+ state_dict = checkpoint['state_dict']
214
+ new_state_dict = OrderedDict()
215
+ for k, v in state_dict.items():
216
+ if k.startswith('backbone.'):
217
+ new_state_dict[k[9:]] = v
218
+ new_checkpoint = dict(state_dict=new_state_dict)
219
+
220
+ return new_checkpoint
221
+
222
+
223
+ def _load_checkpoint(filename, map_location=None):
224
+ """Load checkpoint from somewhere (modelzoo, file, url).
225
+ Args:
226
+ filename (str): Accept local filepath, URL, ``torchvision://xxx``,
227
+ ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
228
+ details.
229
+ map_location (str | None): Same as :func:`torch.load`. Default: None.
230
+ Returns:
231
+ dict | OrderedDict: The loaded checkpoint. It can be either an
232
+ OrderedDict storing model weights or a dict containing other
233
+ information, which depends on the checkpoint.
234
+ """
235
+ if filename.startswith('modelzoo://'):
236
+ warnings.warn('The URL scheme of "modelzoo://" is deprecated, please '
237
+ 'use "torchvision://" instead')
238
+ model_urls = get_torchvision_models()
239
+ model_name = filename[11:]
240
+ checkpoint = load_url_dist(model_urls[model_name])
241
+ elif filename.startswith('torchvision://'):
242
+ model_urls = get_torchvision_models()
243
+ model_name = filename[14:]
244
+ checkpoint = load_url_dist(model_urls[model_name])
245
+ elif filename.startswith('open-mmlab://'):
246
+ model_urls = get_external_models()
247
+ model_name = filename[13:]
248
+ deprecated_urls = get_deprecated_model_names()
249
+ if model_name in deprecated_urls:
250
+ warnings.warn(f'open-mmlab://{model_name} is deprecated in favor '
251
+ f'of open-mmlab://{deprecated_urls[model_name]}')
252
+ model_name = deprecated_urls[model_name]
253
+ model_url = model_urls[model_name]
254
+ # check if is url
255
+ if model_url.startswith(('http://', 'https://')):
256
+ checkpoint = load_url_dist(model_url)
257
+ else:
258
+ filename = osp.join(_get_mmcv_home(), model_url)
259
+ if not osp.isfile(filename):
260
+ raise IOError(f'{filename} is not a checkpoint file')
261
+ checkpoint = torch.load(filename, map_location=map_location)
262
+ elif filename.startswith('mmcls://'):
263
+ model_urls = get_mmcls_models()
264
+ model_name = filename[8:]
265
+ checkpoint = load_url_dist(model_urls[model_name])
266
+ checkpoint = _process_mmcls_checkpoint(checkpoint)
267
+ elif filename.startswith(('http://', 'https://')):
268
+ checkpoint = load_url_dist(filename)
269
+ elif filename.startswith('pavi://'):
270
+ model_path = filename[7:]
271
+ checkpoint = load_pavimodel_dist(model_path, map_location=map_location)
272
+ elif filename.startswith('s3://'):
273
+ checkpoint = load_fileclient_dist(
274
+ filename, backend='ceph', map_location=map_location)
275
+ else:
276
+ if not osp.isfile(filename):
277
+ raise IOError(f'{filename} is not a checkpoint file')
278
+ checkpoint = torch.load(filename, map_location=map_location)
279
+ return checkpoint
280
+
281
+
282
+ def load_checkpoint(model,
283
+ filename,
284
+ map_location='cpu',
285
+ strict=False,
286
+ logger=None):
287
+ """Load checkpoint from a file or URI.
288
+ Args:
289
+ model (Module): Module to load checkpoint.
290
+ filename (str): Accept local filepath, URL, ``torchvision://xxx``,
291
+ ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
292
+ details.
293
+ map_location (str): Same as :func:`torch.load`.
294
+ strict (bool): Whether to allow different params for the model and
295
+ checkpoint.
296
+ logger (:mod:`logging.Logger` or None): The logger for error message.
297
+ Returns:
298
+ dict or OrderedDict: The loaded checkpoint.
299
+ """
300
+ checkpoint = _load_checkpoint(filename, map_location)
301
+ # OrderedDict is a subclass of dict
302
+ if not isinstance(checkpoint, dict):
303
+ raise RuntimeError(
304
+ f'No state_dict found in checkpoint file {filename}')
305
+ # get state_dict from checkpoint
306
+ if 'state_dict' in checkpoint:
307
+ state_dict = checkpoint['state_dict']
308
+ elif 'model' in checkpoint:
309
+ state_dict = checkpoint['model']
310
+ else:
311
+ state_dict = checkpoint
312
+ # strip prefix of state_dict
313
+ if list(state_dict.keys())[0].startswith('module.'):
314
+ state_dict = {k[7:]: v for k, v in state_dict.items()}
315
+
316
+ # for MoBY, load model of online branch
317
+ if sorted(list(state_dict.keys()))[0].startswith('encoder'):
318
+ state_dict = {k.replace('encoder.', ''): v for k, v in state_dict.items() if k.startswith('encoder.')}
319
+
320
+ # reshape absolute position embedding
321
+ if state_dict.get('absolute_pos_embed') is not None:
322
+ absolute_pos_embed = state_dict['absolute_pos_embed']
323
+ N1, L, C1 = absolute_pos_embed.size()
324
+ N2, C2, H, W = model.absolute_pos_embed.size()
325
+ if N1 != N2 or C1 != C2 or L != H*W:
326
+ logger.warning("Error in loading absolute_pos_embed, pass")
327
+ else:
328
+ state_dict['absolute_pos_embed'] = absolute_pos_embed.view(N2, H, W, C2).permute(0, 3, 1, 2)
329
+
330
+ # interpolate position bias table if needed
331
+ relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k]
332
+ for table_key in relative_position_bias_table_keys:
333
+ table_pretrained = state_dict[table_key]
334
+ table_current = model.state_dict()[table_key]
335
+ L1, nH1 = table_pretrained.size()
336
+ L2, nH2 = table_current.size()
337
+ if nH1 != nH2:
338
+ logger.warning(f"Error in loading {table_key}, pass")
339
+ else:
340
+ if L1 != L2:
341
+ S1 = int(L1 ** 0.5)
342
+ S2 = int(L2 ** 0.5)
343
+ table_pretrained_resized = F.interpolate(
344
+ table_pretrained.permute(1, 0).view(1, nH1, S1, S1),
345
+ size=(S2, S2), mode='bicubic')
346
+ state_dict[table_key] = table_pretrained_resized.view(nH2, L2).permute(1, 0)
347
+
348
+ # load state_dict
349
+ load_state_dict(model, state_dict, strict, logger)
350
+ return checkpoint
351
+
352
+
353
+ def weights_to_cpu(state_dict):
354
+ """Copy a model state_dict to cpu.
355
+ Args:
356
+ state_dict (OrderedDict): Model weights on GPU.
357
+ Returns:
358
+ OrderedDict: Model weights on CPU.
359
+ """
360
+ state_dict_cpu = OrderedDict()
361
+ for key, val in state_dict.items():
362
+ state_dict_cpu[key] = val.cpu()
363
+ return state_dict_cpu
364
+
365
+
366
+ def _save_to_state_dict(module, destination, prefix, keep_vars):
367
+ """Saves module state to `destination` dictionary.
368
+ This method is modified from :meth:`torch.nn.Module._save_to_state_dict`.
369
+ Args:
370
+ module (nn.Module): The module to generate state_dict.
371
+ destination (dict): A dict where state will be stored.
372
+ prefix (str): The prefix for parameters and buffers used in this
373
+ module.
+ keep_vars (bool): Whether to keep the variable property of the
+ parameters and buffers. Default: False.
374
+ """
375
+ for name, param in module._parameters.items():
376
+ if param is not None:
377
+ destination[prefix + name] = param if keep_vars else param.detach()
378
+ for name, buf in module._buffers.items():
379
+ # remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d
380
+ if buf is not None:
381
+ destination[prefix + name] = buf if keep_vars else buf.detach()
382
+
383
+
384
+ def get_state_dict(module, destination=None, prefix='', keep_vars=False):
385
+ """Returns a dictionary containing a whole state of the module.
386
+ Both parameters and persistent buffers (e.g. running averages) are
387
+ included. Keys are corresponding parameter and buffer names.
388
+ This method is modified from :meth:`torch.nn.Module.state_dict` to
389
+ recursively check parallel module in case that the model has a complicated
390
+ structure, e.g., nn.Module(nn.Module(DDP)).
391
+ Args:
392
+ module (nn.Module): The module to generate state_dict.
393
+ destination (OrderedDict): Returned dict for the state of the
394
+ module.
395
+ prefix (str): Prefix of the key.
396
+ keep_vars (bool): Whether to keep the variable property of the
397
+ parameters. Default: False.
398
+ Returns:
399
+ dict: A dictionary containing a whole state of the module.
400
+ """
401
+ # recursively check parallel module in case that the model has a
402
+ # complicated structure, e.g., nn.Module(nn.Module(DDP))
403
+ if is_module_wrapper(module):
404
+ module = module.module
405
+
406
+ # below is the same as torch.nn.Module.state_dict()
407
+ if destination is None:
408
+ destination = OrderedDict()
409
+ destination._metadata = OrderedDict()
410
+ destination._metadata[prefix[:-1]] = local_metadata = dict(
411
+ version=module._version)
412
+ _save_to_state_dict(module, destination, prefix, keep_vars)
413
+ for name, child in module._modules.items():
414
+ if child is not None:
415
+ get_state_dict(
416
+ child, destination, prefix + name + '.', keep_vars=keep_vars)
417
+ for hook in module._state_dict_hooks.values():
418
+ hook_result = hook(module, destination, prefix, local_metadata)
419
+ if hook_result is not None:
420
+ destination = hook_result
421
+ return destination
422
+
423
+
424
+ def save_checkpoint(model, filename, optimizer=None, meta=None):
425
+ """Save checkpoint to file.
426
+ The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
427
+ ``optimizer``. By default ``meta`` will contain version and time info.
428
+ Args:
429
+ model (Module): Module whose params are to be saved.
430
+ filename (str): Checkpoint filename.
431
+ optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
432
+ meta (dict, optional): Metadata to be saved in checkpoint.
433
+ """
434
+ if meta is None:
435
+ meta = {}
436
+ elif not isinstance(meta, dict):
437
+ raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
438
+ meta.update(mmcv_version=mmcv.__version__, time=time.asctime())
439
+
440
+ if is_module_wrapper(model):
441
+ model = model.module
442
+
443
+ if hasattr(model, 'CLASSES') and model.CLASSES is not None:
444
+ # save class name to the meta
445
+ meta.update(CLASSES=model.CLASSES)
446
+
447
+ checkpoint = {
448
+ 'meta': meta,
449
+ 'state_dict': weights_to_cpu(get_state_dict(model))
450
+ }
451
+ # save optimizer state dict in the checkpoint
452
+ if isinstance(optimizer, Optimizer):
453
+ checkpoint['optimizer'] = optimizer.state_dict()
454
+ elif isinstance(optimizer, dict):
455
+ checkpoint['optimizer'] = {}
456
+ for name, optim in optimizer.items():
457
+ checkpoint['optimizer'][name] = optim.state_dict()
458
+
459
+ if filename.startswith('pavi://'):
460
+ try:
461
+ from pavi import modelcloud
462
+ from pavi.exception import NodeNotFoundError
463
+ except ImportError:
464
+ raise ImportError(
465
+ 'Please install pavi to load checkpoint from modelcloud.')
466
+ model_path = filename[7:]
467
+ root = modelcloud.Folder()
468
+ model_dir, model_name = osp.split(model_path)
469
+ try:
470
+ model = modelcloud.get(model_dir)
471
+ except NodeNotFoundError:
472
+ model = root.create_training_model(model_dir)
473
+ with TemporaryDirectory() as tmp_dir:
474
+ checkpoint_file = osp.join(tmp_dir, model_name)
475
+ with open(checkpoint_file, 'wb') as f:
476
+ torch.save(checkpoint, f)
477
+ f.flush()
478
+ model.create_file(checkpoint_file, name=model_name)
479
+ else:
480
+ mmcv.mkdir_or_exist(osp.dirname(filename))
481
+ # immediately flush buffer
482
+ with open(filename, 'wb') as f:
483
+ torch.save(checkpoint, f)
484
+ f.flush()
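
The window-size interpolation above is the one non-obvious step in this loader. The following self-contained sketch reproduces just that bicubic resize in isolation; the head count and the 7-to-12 window change are made-up example values, not values from this repo:

```python
# Sketch: resizing a pretrained relative-position bias table when the
# attention window grows from 7 to 12 (example values only).
import torch
import torch.nn.functional as F

nH = 4                                 # example number of attention heads
S1, S2 = 2 * 7 - 1, 2 * 12 - 1         # 13 and 23 relative offsets per axis
table_pretrained = torch.randn(S1 * S1, nH)        # (L1, nH) = (169, 4)

table_resized = F.interpolate(
    table_pretrained.permute(1, 0).view(1, nH, S1, S1),
    size=(S2, S2), mode='bicubic').view(nH, S2 * S2).permute(1, 0)
print(table_resized.shape)             # torch.Size([529, 4])
```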
training/Detection/mmcv_custom/customized_text.py ADDED
@@ -0,0 +1,130 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ # All rights reserved.
4
+
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+
9
+ import datetime
10
+ from collections import OrderedDict
11
+
12
+ import torch
13
+
14
+ import mmcv
15
+ from mmcv.runner import HOOKS
16
+ from mmcv.runner import TextLoggerHook
17
+
18
+
19
+ @HOOKS.register_module()
20
+ class CustomizedTextLoggerHook(TextLoggerHook):
21
+ """Customized Text Logger hook.
22
+
23
+ This logger prints out both lr and layer_0_lr.
24
+
25
+ """
26
+
27
+ def _log_info(self, log_dict, runner):
28
+ # print exp name for users to distinguish experiments
29
+ # at every ``interval_exp_name`` iterations and the end of each epoch
30
+ if runner.meta is not None and 'exp_name' in runner.meta:
31
+ if (self.every_n_iters(runner, self.interval_exp_name)) or (
32
+ self.by_epoch and self.end_of_epoch(runner)):
33
+ exp_info = f'Exp name: {runner.meta["exp_name"]}'
34
+ runner.logger.info(exp_info)
35
+
36
+ if log_dict['mode'] == 'train':
37
+ lr_str = {}
38
+ for lr_type in ['lr', 'layer_0_lr']:
39
+ if isinstance(log_dict[lr_type], dict):
40
+ lr_str[lr_type] = []
41
+ for k, val in log_dict[lr_type].items():
42
+ lr_str[lr_type].append(f'{lr_type}_{k}: {val:.3e}')
43
+ lr_str[lr_type] = ' '.join(lr_str[lr_type])
44
+ else:
45
+ lr_str[lr_type] = f'{lr_type}: {log_dict[lr_type]:.3e}'
46
+
47
+ # by epoch: Epoch [4][100/1000]
48
+ # by iter: Iter [100/100000]
49
+ if self.by_epoch:
50
+ log_str = f'Epoch [{log_dict["epoch"]}]' \
51
+ f'[{log_dict["iter"]}/{len(runner.data_loader)}]\t'
52
+ else:
53
+ log_str = f'Iter [{log_dict["iter"]}/{runner.max_iters}]\t'
54
+ log_str += f'{lr_str["lr"]}, {lr_str["layer_0_lr"]}, '
55
+
56
+ if 'time' in log_dict.keys():
57
+ self.time_sec_tot += (log_dict['time'] * self.interval)
58
+ time_sec_avg = self.time_sec_tot / (
59
+ runner.iter - self.start_iter + 1)
60
+ eta_sec = time_sec_avg * (runner.max_iters - runner.iter - 1)
61
+ eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
62
+ log_str += f'eta: {eta_str}, '
63
+ log_str += f'time: {log_dict["time"]:.3f}, ' \
64
+ f'data_time: {log_dict["data_time"]:.3f}, '
65
+ # statistic memory
66
+ if torch.cuda.is_available():
67
+ log_str += f'memory: {log_dict["memory"]}, '
68
+ else:
69
+ # val/test time
70
+ # here 1000 is the length of the val dataloader
71
+ # by epoch: Epoch[val] [4][1000]
72
+ # by iter: Iter[val] [1000]
73
+ if self.by_epoch:
74
+ log_str = f'Epoch({log_dict["mode"]}) ' \
75
+ f'[{log_dict["epoch"]}][{log_dict["iter"]}]\t'
76
+ else:
77
+ log_str = f'Iter({log_dict["mode"]}) [{log_dict["iter"]}]\t'
78
+
79
+ log_items = []
80
+ for name, val in log_dict.items():
81
+ # TODO: resolve this hack
82
+ # these items have been in log_str
83
+ if name in [
84
+ 'mode', 'Epoch', 'iter', 'lr', 'layer_0_lr', 'time', 'data_time',
85
+ 'memory', 'epoch'
86
+ ]:
87
+ continue
88
+ if isinstance(val, float):
89
+ val = f'{val:.4f}'
90
+ log_items.append(f'{name}: {val}')
91
+ log_str += ', '.join(log_items)
92
+
93
+ runner.logger.info(log_str)
94
+
95
+
96
+ def log(self, runner):
97
+ if 'eval_iter_num' in runner.log_buffer.output:
98
+ # this doesn't modify runner.iter and is regardless of by_epoch
99
+ cur_iter = runner.log_buffer.output.pop('eval_iter_num')
100
+ else:
101
+ cur_iter = self.get_iter(runner, inner_iter=True)
102
+
103
+ log_dict = OrderedDict(
104
+ mode=self.get_mode(runner),
105
+ epoch=self.get_epoch(runner),
106
+ iter=cur_iter)
107
+
108
+ # record lr and layer_0_lr
109
+ cur_lr = runner.current_lr()
110
+ if isinstance(cur_lr, list):
111
+ log_dict['layer_0_lr'] = min(cur_lr)
112
+ log_dict['lr'] = max(cur_lr)
113
+ else:
114
+ assert isinstance(cur_lr, dict)
115
+ log_dict['lr'], log_dict['layer_0_lr'] = {}, {}
116
+ for k, lr_ in cur_lr.items():
117
+ assert isinstance(lr_, list)
118
+ log_dict['layer_0_lr'].update({k: min(lr_)})
119
+ log_dict['lr'].update({k: max(lr_)})
120
+
121
+ if 'time' in runner.log_buffer.output:
122
+ # statistic memory
123
+ if torch.cuda.is_available():
124
+ log_dict['memory'] = self._get_max_memory(runner)
125
+
126
+ log_dict = dict(log_dict, **runner.log_buffer.output)
127
+
128
+ self._log_info(log_dict, runner)
129
+ self._dump_log(log_dict, runner)
130
+ return log_dict
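
A hedged usage note: like any mmcv hook registered with `@HOOKS.register_module()`, this logger is selected through the training config. A minimal fragment might look like the following (the interval value is only an example):

```python
# Hypothetical mmcv config fragment selecting the hook defined above.
log_config = dict(
    interval=50,
    hooks=[dict(type='CustomizedTextLoggerHook', by_epoch=True)])
```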
training/Detection/mmcv_custom/layer_decay_optimizer_constructor.py ADDED
@@ -0,0 +1,121 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ # All rights reserved.
4
+
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+
9
+ import json
10
+ import re
11
+ from mmcv.runner import OPTIMIZER_BUILDERS, DefaultOptimizerConstructor
12
+ from mmcv.runner import get_dist_info
13
+ import numpy as np
14
+
15
+ def cal_model_depth(depth, num_subnet):
16
+ dp = np.zeros((depth, num_subnet))
17
+ dp[:,0]=np.linspace(0, depth-1, depth)
18
+ dp[0,:]=np.linspace(0, num_subnet-1, num_subnet)
19
+ for i in range(1, depth):
20
+ for j in range(1, num_subnet):
21
+ dp[i][j] = min(dp[i][j-1], dp[i-1][j])+1
22
+ dp = dp.astype(int)
23
+ # col = [x for x in np.linspace(0, sum(self.layers)-1, sum(self.layers))]
24
+ # dp = np.transpose(np.array([col]*self.num_subnet, dtype=int))
25
+ dp = dp+1 ## make layer ids start from 1
26
+ return dp
27
+
28
+ def get_num_layer_layer_wise(n, layers, num_subnet=12):
29
+ dp=cal_model_depth(sum(layers), num_subnet)
30
+ # def get_layer_id(n, dp, layers):
31
+ if n.startswith("backbone.subnet"):
32
+ n=n[9:]
33
+ name_part = n.split('.')
34
+ subnet = int(name_part[0][6:])
35
+ if name_part[1].startswith("alpha"):
36
+ id = dp[0][subnet]
37
+ else:
38
+ level = int(name_part[1][-1])
39
+ if name_part[2].startswith("blocks"):
40
+ sub = int(name_part[3])
41
+ if sub>layers[level]-1:
42
+ sub = layers[level]-1
43
+ block = sum(layers[:level])+sub
44
+
45
+ if name_part[2].startswith("fusion"):
46
+ block = sum(layers[:level])
47
+ id = dp[block][subnet]
48
+ elif n.startswith("backbone.stem"):
49
+ id = 0
50
+ else:
51
+ id = dp[-1][-1]+1
52
+ return id
53
+
54
+
55
+
56
+ @OPTIMIZER_BUILDERS.register_module()
57
+ class LearningRateDecayOptimizerConstructor(DefaultOptimizerConstructor):
58
+ def add_params(self, params, module, prefix='', is_dcn_module=None):
59
+ """Add all parameters of module to the params list.
60
+ The parameters of the given module will be added to the list of param
61
+ groups, with specific rules defined by paramwise_cfg.
62
+ Args:
63
+ params (list[dict]): A list of param groups, it will be modified
64
+ in place.
65
+ module (nn.Module): The module to be added.
66
+ prefix (str): The prefix of the module
67
+ is_dcn_module (int|float|None): If the current module is a
68
+ submodule of DCN, `is_dcn_module` will be passed to
69
+ control conv_offset layer's learning rate. Defaults to None.
70
+ """
71
+ parameter_groups = {}
72
+ print(self.paramwise_cfg)
73
+ num_layers = cal_model_depth(sum(self.paramwise_cfg.get('layers')), self.paramwise_cfg.get('num_subnet'))[-1][-1]+2
74
+ # num_layers = self.paramwise_cfg.get('num_layers') + 2
75
+ decay_rate = self.paramwise_cfg.get('decay_rate')
76
+ decay_type = self.paramwise_cfg.get('decay_type', "layer_wise")
77
+ print("Build LearningRateDecayOptimizerConstructor %s %f - %d" % (decay_type, decay_rate, num_layers))
78
+ weight_decay = self.base_wd
79
+
80
+ for name, param in module.named_parameters():
81
+ if not param.requires_grad:
82
+ continue # frozen weights
83
+ if len(param.shape) == 1 or name.endswith(".bias") or name in ('pos_embed', 'cls_token') or re.match('(.*).alpha.$', name):
84
+ group_name = "no_decay"
85
+ this_weight_decay = 0.
86
+ else:
87
+ group_name = "decay"
88
+ this_weight_decay = weight_decay
89
+
90
+ if decay_type == "layer_wise":
91
+ layer_id = get_num_layer_layer_wise(name, self.paramwise_cfg.get('layers'), self.paramwise_cfg.get('num_subnet'))
92
+
93
+ group_name = "layer_%d_%s" % (layer_id, group_name)
94
+
95
+ if group_name not in parameter_groups:
96
+ scale = decay_rate ** (num_layers - layer_id - 1)
97
+
98
+ parameter_groups[group_name] = {
99
+ "weight_decay": this_weight_decay,
100
+ "params": [],
101
+ "param_names": [],
102
+ "lr_scale": scale,
103
+ "group_name": group_name,
104
+ "lr": scale * self.base_lr,
105
+ }
106
+
107
+ parameter_groups[group_name]["params"].append(param)
108
+ parameter_groups[group_name]["param_names"].append(name)
109
+ rank, _ = get_dist_info()
110
+ if rank == 0:
111
+ to_display = {}
112
+ for key in parameter_groups:
113
+ to_display[key] = {
114
+ "param_names": parameter_groups[key]["param_names"],
115
+ "lr_scale": parameter_groups[key]["lr_scale"],
116
+ "lr": parameter_groups[key]["lr"],
117
+ "weight_decay": parameter_groups[key]["weight_decay"],
118
+ }
119
+ print("Param groups = %s" % json.dumps(to_display, indent=2))
120
+
121
+ params.extend(parameter_groups.values())
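
The DP in `cal_model_depth` has a simple closed form: the shortest monotone path to cell `(i, j)` has length `i + j`, so after the `+1` shift the layer id of block `i` in column `j` is `i + j + 1`. A standalone check (a sketch that duplicates the function above so it runs on its own):

```python
import numpy as np

def cal_model_depth(depth, num_subnet):
    dp = np.zeros((depth, num_subnet))
    dp[:, 0] = np.linspace(0, depth - 1, depth)
    dp[0, :] = np.linspace(0, num_subnet - 1, num_subnet)
    for i in range(1, depth):
        for j in range(1, num_subnet):
            dp[i][j] = min(dp[i][j - 1], dp[i - 1][j]) + 1
    return dp.astype(int) + 1  # layer ids start from 1

print(cal_model_depth(4, 3))
# [[1 2 3]
#  [2 3 4]
#  [3 4 5]
#  [4 5 6]]   i.e. layer_id = block_index + column_index + 1
```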
training/Detection/mmcv_custom/runner/checkpoint.py ADDED
@@ -0,0 +1,85 @@
1
+ # Copyright (c) Open-MMLab. All rights reserved.
2
+ import os.path as osp
3
+ import time
4
+ from tempfile import TemporaryDirectory
5
+
6
+ import torch
7
+ from torch.optim import Optimizer
8
+
9
+ import mmcv
10
+ from mmcv.parallel import is_module_wrapper
11
+ from mmcv.runner.checkpoint import weights_to_cpu, get_state_dict
12
+
13
+ try:
14
+ import apex
15
+ except ImportError:
16
+ print('apex is not installed')
17
+
18
+
19
+ def save_checkpoint(model, filename, optimizer=None, meta=None):
20
+ """Save checkpoint to file.
21
+
22
+ The checkpoint will have 4 fields: ``meta``, ``state_dict``,
23
+ ``optimizer`` and ``amp``. By default ``meta`` will contain version
24
+ and time info.
25
+
26
+ Args:
27
+ model (Module): Module whose params are to be saved.
28
+ filename (str): Checkpoint filename.
29
+ optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
30
+ meta (dict, optional): Metadata to be saved in checkpoint.
31
+ """
32
+ if meta is None:
33
+ meta = {}
34
+ elif not isinstance(meta, dict):
35
+ raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
36
+ meta.update(mmcv_version=mmcv.__version__, time=time.asctime())
37
+
38
+ if is_module_wrapper(model):
39
+ model = model.module
40
+
41
+ if hasattr(model, 'CLASSES') and model.CLASSES is not None:
42
+ # save class name to the meta
43
+ meta.update(CLASSES=model.CLASSES)
44
+
45
+ checkpoint = {
46
+ 'meta': meta,
47
+ 'state_dict': weights_to_cpu(get_state_dict(model))
48
+ }
49
+ # save optimizer state dict in the checkpoint
50
+ if isinstance(optimizer, Optimizer):
51
+ checkpoint['optimizer'] = optimizer.state_dict()
52
+ elif isinstance(optimizer, dict):
53
+ checkpoint['optimizer'] = {}
54
+ for name, optim in optimizer.items():
55
+ checkpoint['optimizer'][name] = optim.state_dict()
56
+
57
+ # save amp state dict in the checkpoint
58
+ # checkpoint['amp'] = apex.amp.state_dict()
59
+
60
+ if filename.startswith('pavi://'):
61
+ try:
62
+ from pavi import modelcloud
63
+ from pavi.exception import NodeNotFoundError
64
+ except ImportError:
65
+ raise ImportError(
66
+ 'Please install pavi to load checkpoint from modelcloud.')
67
+ model_path = filename[7:]
68
+ root = modelcloud.Folder()
69
+ model_dir, model_name = osp.split(model_path)
70
+ try:
71
+ model = modelcloud.get(model_dir)
72
+ except NodeNotFoundError:
73
+ model = root.create_training_model(model_dir)
74
+ with TemporaryDirectory() as tmp_dir:
75
+ checkpoint_file = osp.join(tmp_dir, model_name)
76
+ with open(checkpoint_file, 'wb') as f:
77
+ torch.save(checkpoint, f)
78
+ f.flush()
79
+ model.create_file(checkpoint_file, name=model_name)
80
+ else:
81
+ mmcv.mkdir_or_exist(osp.dirname(filename))
82
+ # immediately flush buffer
83
+ with open(filename, 'wb') as f:
84
+ torch.save(checkpoint, f)
85
+ f.flush()
training/Detection/mmdet/models/backbones/__init__.py ADDED
@@ -0,0 +1,28 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from .csp_darknet import CSPDarknet
3
+ from .darknet import Darknet
4
+ from .detectors_resnet import DetectoRS_ResNet
5
+ from .detectors_resnext import DetectoRS_ResNeXt
6
+ from .efficientnet import EfficientNet
7
+ from .hourglass import HourglassNet
8
+ from .hrnet import HRNet
9
+ from .mobilenet_v2 import MobileNetV2
10
+ from .pvt import PyramidVisionTransformer, PyramidVisionTransformerV2
11
+ from .regnet import RegNet
12
+ from .res2net import Res2Net
13
+ from .resnest import ResNeSt
14
+ from .resnet import ResNet, ResNetV1d
15
+ from .resnext import ResNeXt
16
+ from .ssd_vgg import SSDVGG
17
+ from .swin import SwinTransformer
18
+ from .trident_resnet import TridentResNet
19
+ from .revcol_huge import RevCol_Huge
20
+ from .revcol import RevCol
21
+
22
+ __all__ = [
23
+ 'RegNet', 'ResNet', 'ResNetV1d', 'ResNeXt', 'SSDVGG', 'HRNet',
24
+ 'MobileNetV2', 'Res2Net', 'HourglassNet', 'DetectoRS_ResNet',
25
+ 'DetectoRS_ResNeXt', 'Darknet', 'ResNeSt', 'TridentResNet', 'CSPDarknet',
26
+ 'SwinTransformer', 'PyramidVisionTransformer',
27
+ 'PyramidVisionTransformerV2', 'EfficientNet', 'RevCol', 'RevCol_Huge'
28
+ ]
training/Detection/mmdet/models/backbones/revcol.py ADDED
@@ -0,0 +1,187 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ from .revcol_module import ConvNextBlock, LayerNorm, UpSampleConvnext
6
+ from mmdet.utils import get_root_logger
7
+ from ..builder import BACKBONES
8
+ from .revcol_function import ReverseFunction
9
+ from mmcv.cnn import constant_init, trunc_normal_init
10
+ from mmcv.runner import BaseModule, _load_checkpoint
11
+ from torch.utils.checkpoint import checkpoint
12
+
13
+ class Fusion(nn.Module):
14
+ def __init__(self, level, channels, first_col) -> None:
15
+ super().__init__()
16
+
17
+ self.level = level
18
+ self.first_col = first_col
19
+ self.down = nn.Sequential(
20
+ nn.Conv2d(channels[level-1], channels[level], kernel_size=2, stride=2),
21
+ LayerNorm(channels[level], eps=1e-6, data_format="channels_first"),
22
+ ) if level in [1, 2, 3] else nn.Identity()
23
+ if not first_col:
24
+ self.up = UpSampleConvnext(1, channels[level+1], channels[level]) if level in [0, 1, 2] else nn.Identity()
25
+
26
+
27
+ def forward(self, *args):
28
+ c_down, c_up = args
29
+
30
+ if self.first_col:
31
+ x = self.down(c_down)
32
+ return x
33
+
34
+ if self.level == 3:
35
+ x = self.down(c_down)
36
+ else:
37
+ x = self.up(c_up) + self.down(c_down)
38
+ return x
39
+
40
+ class Level(nn.Module):
41
+ def __init__(self, level, channels, layers, kernel_size, first_col, dp_rate=0.0) -> None:
42
+ super().__init__()
43
+ countlayer = sum(layers[:level])
44
+ expansion = 4
45
+ self.fusion = Fusion(level, channels, first_col)
46
+ modules = [ConvNextBlock(channels[level], expansion*channels[level], channels[level], kernel_size = kernel_size, layer_scale_init_value=1e-6, drop_path=dp_rate[countlayer+i]) for i in range(layers[level])]
47
+ self.blocks = nn.Sequential(*modules)
48
+ def forward(self, *args):
49
+ x = self.fusion(*args)
50
+ x = self.blocks(x)
51
+ return x
52
+
53
+ class SubNet(nn.Module):
54
+ def __init__(self, channels, layers, kernel_size, first_col, dp_rates, save_memory) -> None:
55
+ super().__init__()
56
+ shortcut_scale_init_value = 0.5
57
+ self.save_memory = save_memory
58
+ self.alpha0 = nn.Parameter(shortcut_scale_init_value * torch.ones((1, channels[0], 1, 1)),
59
+ requires_grad=True) if shortcut_scale_init_value > 0 else None
60
+ self.alpha1 = nn.Parameter(shortcut_scale_init_value * torch.ones((1, channels[1], 1, 1)),
61
+ requires_grad=True) if shortcut_scale_init_value > 0 else None
62
+ self.alpha2 = nn.Parameter(shortcut_scale_init_value * torch.ones((1, channels[2], 1, 1)),
63
+ requires_grad=True) if shortcut_scale_init_value > 0 else None
64
+ self.alpha3 = nn.Parameter(shortcut_scale_init_value * torch.ones((1, channels[3], 1, 1)),
65
+ requires_grad=True) if shortcut_scale_init_value > 0 else None
66
+
67
+ self.level0 = Level(0, channels, layers, kernel_size, first_col, dp_rates)
68
+
69
+ self.level1 = Level(1, channels, layers, kernel_size, first_col, dp_rates)
70
+
71
+ self.level2 = Level(2, channels, layers, kernel_size,first_col, dp_rates)
72
+
73
+ self.level3 = Level(3, channels, layers, kernel_size, first_col, dp_rates)
74
+
75
+ def _forward_nonreverse(self, *args):
76
+ x, c0, c1, c2, c3= args
77
+
78
+ c0 = (self.alpha0)*c0 + self.level0(x, c1)
79
+ c1 = (self.alpha1)*c1 + self.level1(c0, c2)
80
+ c2 = (self.alpha2)*c2 + self.level2(c1, c3)
81
+ c3 = (self.alpha3)*c3 + self.level3(c2, None)
82
+
83
+ return c0, c1, c2, c3
84
+
85
+ def _forward_reverse(self, *args):
86
+
87
+ local_funs = [self.level0, self.level1, self.level2, self.level3]
88
+ alpha = [self.alpha0, self.alpha1, self.alpha2, self.alpha3]
89
+ _, c0, c1, c2, c3 = ReverseFunction.apply(
90
+ local_funs, alpha, *args)
91
+
92
+ return c0, c1, c2, c3
93
+
94
+ def forward(self, *args):
95
+
96
+ self._clamp_abs(self.alpha0.data, 1e-3)
97
+ self._clamp_abs(self.alpha1.data, 1e-3)
98
+ self._clamp_abs(self.alpha2.data, 1e-3)
99
+ self._clamp_abs(self.alpha3.data, 1e-3)
100
+
101
+ if self.save_memory:
102
+ return self._forward_reverse(*args)
103
+ else:
104
+ return self._forward_nonreverse(*args)
105
+
106
+ def _clamp_abs(self, data, value):
107
+ with torch.no_grad():
108
+ sign=data.sign()
109
+ data.abs_().clamp_(value)
110
+ data*=sign
111
+
112
+ @BACKBONES.register_module()
113
+ class RevCol(BaseModule):
114
+ def __init__(self, channels=[32, 64, 96, 128], layers=[2, 3, 6, 3], num_subnet=5, kernel_size = 3, num_classes=1000, drop_path = 0.0, save_memory=True, single_head=True, out_indices=[0, 1, 2, 3], init_cfg=None) -> None:
115
+ super().__init__(init_cfg)
116
+ self.num_subnet = num_subnet
117
+ self.single_head = single_head
118
+ self.out_indices = out_indices
119
+ self.init_cfg = init_cfg
120
+
121
+ self.stem = nn.Sequential(
122
+ nn.Conv2d(3, channels[0], kernel_size=4, stride=4),
123
+ LayerNorm(channels[0], eps=1e-6, data_format="channels_first")
124
+ )
125
+
126
+ # dp_rate = self.cal_dp_rate(sum(layers), num_subnet, drop_path)
127
+
128
+ dp_rate = [x.item() for x in torch.linspace(0, drop_path, sum(layers))]
129
+ for i in range(num_subnet):
130
+ first_col = True if i == 0 else False
131
+ self.add_module(f'subnet{str(i)}', SubNet(
132
+ channels,layers, kernel_size, first_col, dp_rates=dp_rate, save_memory=save_memory))
133
+
134
+ def init_weights(self):
135
+ logger = get_root_logger()
136
+ if self.init_cfg is None:
137
+ logger.warning(f'No pre-trained weights for '
138
+ f'{self.__class__.__name__}, '
139
+ f'training start from scratch')
140
+ for m in self.modules():
141
+ if isinstance(m, nn.Linear):
142
+ trunc_normal_init(m, std=.02, bias=0.)
143
+ elif isinstance(m, nn.LayerNorm):
144
+ constant_init(m, 1.0)
145
+ else:
146
+ assert 'checkpoint' in self.init_cfg, f'Only support ' \
147
+ f'specify `Pretrained` in ' \
148
+ f'`init_cfg` in ' \
149
+ f'{self.__class__.__name__} '
150
+ ckpt = _load_checkpoint(
151
+ self.init_cfg.checkpoint, logger=logger, map_location='cpu')
152
+ if 'state_dict' in ckpt:
153
+ _state_dict = ckpt['state_dict']
154
+ elif 'model' in ckpt:
155
+ _state_dict = ckpt['model']
156
+ else:
157
+ _state_dict = ckpt
158
+
159
+
160
+ state_dict = _state_dict
161
+ # print(state_dict.keys())
162
+ # strip prefix of state_dict
163
+ if list(state_dict.keys())[0].startswith('module.'):
164
+ state_dict = {k[7:]: v for k, v in state_dict.items()}
165
+
166
+ # load state_dict
167
+ self.load_state_dict(state_dict, False)
168
+
169
+
170
+ def forward(self, x):
171
+ x = self.stem(x)
172
+ c0, c1, c2, c3 = 0, 0, 0, 0
173
+ for i in range(self.num_subnet):
174
+ # c0, c1, c2, c3 = checkpoint(getattr(self, f'subnet{str(i)}'), x, c0, c1, c2, c3 )
175
+ c0, c1, c2, c3 = getattr(self, f'subnet{str(i)}')(x, c0, c1, c2, c3)
176
+ return c0, c1, c2, c3
177
+
178
+ def cal_dp_rate(self, depth, num_subnet, drop_path):
179
+ dp = np.zeros((depth, num_subnet))
180
+ dp[:,0]=np.linspace(0, depth-1, depth)
181
+ dp[0,:]=np.linspace(0, num_subnet-1, num_subnet)
182
+ for i in range(1, depth):
183
+ for j in range(1, num_subnet):
184
+ dp[i][j] = min(dp[i][j-1], dp[i-1][j])+1
185
+ ratio = dp[-1][-1]/drop_path
186
+ dp_matrix = dp/ratio
187
+ return dp_matrix
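
A hedged usage sketch: once this module is importable by mmdet, the backbone can be selected from a detector config. The channel, layer and column counts below are simply the class defaults above, and the checkpoint path is a placeholder:

```python
# Hypothetical mmdet config fragment (values mirror the RevCol defaults above).
model = dict(
    backbone=dict(
        type='RevCol',
        channels=[32, 64, 96, 128],
        layers=[2, 3, 6, 3],
        num_subnet=5,
        drop_path=0.0,
        save_memory=True,
        out_indices=[0, 1, 2, 3],
        init_cfg=dict(checkpoint='<path-to-pretrained-weights>.pth')))
```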
training/Detection/mmdet/models/backbones/revcol_function.py ADDED
@@ -0,0 +1,222 @@
1
+ import torch
2
+ from typing import Any, Iterable, List, Tuple, Callable
3
+ import torch.distributed as dist
4
+
5
+ def get_gpu_states(fwd_gpu_devices) -> Tuple[List[int], List[torch.Tensor]]:
6
+ # This will not error out if "arg" is a CPU tensor or a non-tensor type because
7
+ # the conditionals short-circuit.
8
+ fwd_gpu_states = []
9
+ for device in fwd_gpu_devices:
10
+ with torch.cuda.device(device):
11
+ fwd_gpu_states.append(torch.cuda.get_rng_state())
12
+
13
+ return fwd_gpu_states
14
+
15
+ def get_gpu_device(*args):
16
+
17
+ fwd_gpu_devices = list(set(arg.get_device() for arg in args
18
+ if isinstance(arg, torch.Tensor) and arg.is_cuda))
19
+ return fwd_gpu_devices
20
+
21
+ def set_device_states(fwd_cpu_state, devices, states) -> None:
22
+ torch.set_rng_state(fwd_cpu_state)
23
+ for device, state in zip(devices, states):
24
+ with torch.cuda.device(device):
25
+ torch.cuda.set_rng_state(state)
26
+
27
+ def detach_and_grad(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]:
28
+ if isinstance(inputs, tuple):
29
+ out = []
30
+ for inp in inputs:
31
+ if not isinstance(inp, torch.Tensor):
32
+ out.append(inp)
33
+ continue
34
+
35
+ x = inp.detach()
36
+ x.requires_grad = True
37
+ out.append(x)
38
+ return tuple(out)
39
+ else:
40
+ raise RuntimeError(
41
+ "Only tuple of tensors is supported. Got Unsupported input type: ", type(inputs).__name__)
42
+
43
+ def get_cpu_and_gpu_states(gpu_devices):
44
+ return torch.get_rng_state(), get_gpu_states(gpu_devices)
45
+
46
+ class ReverseFunction(torch.autograd.Function):
47
+ @staticmethod
48
+ def forward(ctx, run_functions, alpha, *args):
49
+ l0, l1, l2, l3 = run_functions
50
+ alpha0, alpha1, alpha2, alpha3 = alpha
51
+ ctx.run_functions = run_functions
52
+ ctx.alpha = alpha
53
+ ctx.preserve_rng_state = True
54
+
55
+ ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
56
+ "dtype": torch.get_autocast_gpu_dtype(),
57
+ "cache_enabled": torch.is_autocast_cache_enabled()}
58
+ ctx.cpu_autocast_kwargs = {"enabled": torch.is_autocast_cpu_enabled(),
59
+ "dtype": torch.get_autocast_cpu_dtype(),
60
+ "cache_enabled": torch.is_autocast_cache_enabled()}
61
+
62
+ assert len(args) == 5
63
+ [x, c0, c1, c2, c3] = args
64
+ if type(c0) == int:
65
+ ctx.first_col = True
66
+ else:
67
+ ctx.first_col = False
68
+ with torch.no_grad():
69
+ if ctx.preserve_rng_state:
70
+ gpu_devices = get_gpu_device(*args)
71
+ ctx.gpu_devices = gpu_devices
72
+ ctx.cpu_states_0, ctx.gpu_states_0 = get_cpu_and_gpu_states(gpu_devices)
73
+ c0 = l0(x, c1, c3) + c0*alpha0
74
+ ctx.cpu_states_1, ctx.gpu_states_1 = get_cpu_and_gpu_states(gpu_devices)
75
+ c1 = l1(c0, c2) + c1*alpha1
76
+ ctx.cpu_states_2, ctx.gpu_states_2 = get_cpu_and_gpu_states(gpu_devices)
77
+ c2 = l2(c1, c3) + c2*alpha2
78
+ ctx.cpu_states_3, ctx.gpu_states_3 = get_cpu_and_gpu_states(gpu_devices)
79
+ c3 = l3(c2) + c3*alpha3
80
+ else:
81
+ c0 = l0(x, c1, c3) + c0*alpha0
82
+ c1 = l1(c0, c2) + c1*alpha1
83
+ c2 = l2(c1, c3) + c2*alpha2
84
+ c3 = l3(c2) + c3*alpha3
85
+ ctx.save_for_backward(x, c0, c1, c2, c3)
86
+ return x, c0, c1 ,c2, c3
87
+
88
+ @staticmethod
89
+ def backward(ctx, *grad_outputs):
90
+ x, c0, c1, c2, c3 = ctx.saved_tensors
91
+ l0, l1, l2, l3 = ctx.run_functions
92
+ alpha0, alpha1, alpha2, alpha3 = ctx.alpha
93
+ gx_right, g0_right, g1_right, g2_right, g3_right = grad_outputs
94
+ (x, c0, c1, c2, c3) = detach_and_grad((x, c0, c1, c2, c3))
95
+
96
+ if ctx.preserve_rng_state:
97
+ with torch.enable_grad(), \
98
+ torch.random.fork_rng(devices=ctx.gpu_devices, enabled=ctx.preserve_rng_state), \
99
+ torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs), \
100
+ torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
101
+
102
+ g3_up = g3_right
103
+ g3_left = g3_up*alpha3 ##shortcut
104
+ set_device_states(ctx.cpu_states_3, ctx.gpu_devices, ctx.gpu_states_3)
105
+ oup3 = l3(c2)
106
+ torch.autograd.backward(oup3, g3_up, retain_graph=True)
107
+ with torch.no_grad():
108
+ c3_left = (1/alpha3)*(c3 - oup3) ## feature reverse
109
+ g2_up = g2_right+ c2.grad
110
+ g2_left = g2_up*alpha2 ##shortcut
111
+
112
+ (c3_left,) = detach_and_grad((c3_left,))
113
+ set_device_states(ctx.cpu_states_2, ctx.gpu_devices, ctx.gpu_states_2)
114
+ oup2 = l2(c1, c3_left)
115
+ torch.autograd.backward(oup2, g2_up, retain_graph=True)
116
+ c3_left.requires_grad = False
117
+ cout3 = c3_left*alpha3 ##alpha3 update
118
+ torch.autograd.backward(cout3, g3_up)
119
+
120
+ with torch.no_grad():
121
+ c2_left = (1/alpha2)*(c2 - oup2) ## feature reverse
122
+ g3_left = g3_left + c3_left.grad if c3_left.grad is not None else g3_left
123
+ g1_up = g1_right+c1.grad
124
+ g1_left = g1_up*alpha1 ##shortcut
125
+
126
+ (c2_left,) = detach_and_grad((c2_left,))
127
+ set_device_states(ctx.cpu_states_1, ctx.gpu_devices, ctx.gpu_states_1)
128
+ oup1 = l1(c0, c2_left)
129
+ torch.autograd.backward(oup1, g1_up, retain_graph=True)
130
+ c2_left.requires_grad = False
131
+ cout2 = c2_left*alpha2 ##alpha2 update
132
+ torch.autograd.backward(cout2, g2_up)
133
+
134
+ with torch.no_grad():
135
+ c1_left = (1/alpha1)*(c1 - oup1) ## feature reverse
136
+ g0_up = g0_right + c0.grad
137
+ g0_left = g0_up*alpha0 ##shortcut
138
+ g2_left = g2_left + c2_left.grad if c2_left.grad is not None else g2_left ## Fusion
139
+
140
+ (c1_left,c3_left) = detach_and_grad((c1_left,c3_left))
141
+ set_device_states(ctx.cpu_states_0, ctx.gpu_devices, ctx.gpu_states_0)
142
+ oup0 = l0(x, c1_left, c3_left)
143
+ torch.autograd.backward(oup0, g0_up, retain_graph=True)
144
+ c1_left.requires_grad = False
145
+ cout1 = c1_left*alpha1 ##alpha1 update
146
+ torch.autograd.backward(cout1, g1_up)
147
+
148
+ with torch.no_grad():
149
+ c0_left = (1/alpha0)*(c0 - oup0) ## feature reverse
150
+ gx_up = x.grad ## Fusion
151
+ g1_left = g1_left + c1_left.grad if c1_left.grad is not None else g1_left ## Fusion
152
+ g3_left = g3_left + c3_left.grad if c3_left.grad is not None else g3_left ## Fusion
153
+ c0_left.requires_grad = False
154
+ cout0 = c0_left*alpha0 ##alpha0 update
155
+ torch.autograd.backward(cout0, g0_up)
156
+ else:
157
+ with torch.enable_grad():
158
+
159
+ g3_up = g3_right
160
+ g3_left = g3_up*alpha3 ##shortcut
161
+ oup3 = l3(c2)
162
+ torch.autograd.backward(oup3, g3_up, retain_graph=True)
163
+ with torch.no_grad():
164
+ c3_left = (1/alpha3)*(c3 - oup3) ## feature reverse
165
+ g2_up = g2_right+ c2.grad
166
+ g2_left = g2_up*alpha2 ##shortcut
167
+
168
+ (c3_left,) = detach_and_grad((c3_left,))
169
+ oup2 = l2(c1, c3_left)
170
+ torch.autograd.backward(oup2, g2_up, retain_graph=True)
171
+ c3_left.requires_grad = False
172
+ cout3 = c3_left*alpha3 ##alpha3 update
173
+ torch.autograd.backward(cout3, g3_up)
174
+
175
+ with torch.no_grad():
176
+ c2_left = (1/alpha2)*(c2 - oup2) ## feature reverse
177
+ g3_left = g3_left + c3_left.grad if c3_left.grad is not None else g3_left
178
+ g1_up = g1_right+c1.grad
179
+ g1_left = g1_up*alpha1 ##shortcut
180
+
181
+ (c2_left,) = detach_and_grad((c2_left,))
182
+ oup1 = l1(c0, c2_left)
183
+ torch.autograd.backward(oup1, g1_up, retain_graph=True)
184
+ c2_left.requires_grad = False
185
+ cout2 = c2_left*alpha2 ##alpha2 update
186
+ torch.autograd.backward(cout2, g2_up)
187
+
188
+ with torch.no_grad():
189
+ c1_left = (1/alpha1)*(c1 - oup1) ## feature reverse
190
+ g0_up = g0_right + c0.grad
191
+ g0_left = g0_up*alpha0 ##shortcut
192
+ g2_left = g2_left + c2_left.grad if c2_left.grad is not None else g2_left ## Fusion
193
+
194
+ (c1_left,c3_left) = detach_and_grad((c1_left,c3_left))
195
+ oup0 = l0(x, c1_left, c3_left)
196
+ torch.autograd.backward(oup0, g0_up, retain_graph=True)
197
+ c1_left.requires_grad = False
198
+ cout1 = c1_left*alpha1 ##alpha1 update
199
+ torch.autograd.backward(cout1, g1_up)
200
+
201
+ with torch.no_grad():
202
+ c0_left = (1/alpha0)*(c0 - oup0) ## feature reverse
203
+ gx_up = x.grad ## Fusion
204
+ g1_left = g1_left + c1_left.grad if c1_left.grad is not None else g1_left ## Fusion
205
+ g3_left = g3_left + c3_left.grad if c3_left.grad is not None else g3_left ## Fusion
206
+ c0_left.requires_grad = False
207
+ cout0 = c0_left*alpha0 ##alpha0 update
208
+ torch.autograd.backward(cout0, g0_up)
209
+ # if dist.get_rank()==0:
210
+ # print(c0_left.mean().data)
211
+ # print(f'c0: {c0_left.max()}, c1: {c1_left.max()}, c2: {c2_left.max()}, c3: {c3_left.max()}')
212
+ # print(f'x.grad: {gx_up.mean()}, c0.grad: {g0_left.mean()}, c1.grad: {g1_left.mean()}, c2.grad: {g2_left.mean()}, c3.grad: {g3_left.mean()}')
213
+ # import pdb;pdb.set_trace()
214
+
215
+ if ctx.first_col:
216
+ # print(f'c0: {c0_left.max()}, c1: {c1_left.max()}, c2: {c2_left.max()}, c3: {c3_left.max()}')
217
+ return None, None, gx_up, None, None, None, None
218
+ else:
219
+ return None, None, gx_up, g0_left, g1_left, g2_left, g3_left
220
+
221
+
222
+
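The backward pass above never reads stored activations: each column input is reconstructed from the column output via the `(1/alpha)*(c - oup)` lines. A minimal numeric check of that identity, independent of the classes above:

```python
import torch

torch.manual_seed(0)
f = torch.nn.Linear(8, 8)                 # stands in for one Level's forward
alpha = 0.5
h, c_old = torch.randn(4, 8), torch.randn(4, 8)
with torch.no_grad():
    c_new = f(h) + alpha * c_old          # forward rule used in ReverseFunction
    c_rec = (1 / alpha) * (c_new - f(h))  # "feature reverse" from backward
print(torch.allclose(c_rec, c_old, atol=1e-6))  # True
```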
training/Detection/mmdet/models/backbones/revcol_module.py ADDED
@@ -0,0 +1,85 @@
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from timm.models.layers import DropPath
6
+
7
+ class UpSampleConvnext(nn.Module):
8
+ def __init__(self, ratio, inchannel, outchannel):
9
+ super().__init__()
10
+ self.ratio = ratio
11
+ self.channel_reschedule = nn.Sequential(
12
+ # LayerNorm(inchannel, eps=1e-6, data_format="channels_last"),
13
+ nn.Linear(inchannel, outchannel),
14
+ LayerNorm(outchannel, eps=1e-6, data_format="channels_last"))
15
+ self.upsample = nn.Upsample(scale_factor=2**ratio, mode='nearest')
16
+ def forward(self, x):
17
+ x = x.permute(0, 2, 3, 1)
18
+ x = self.channel_reschedule(x)
19
+ x = x.permute(0, 3, 1, 2)
20
+
21
+ return self.upsample(x)
22
+
23
+ class LayerNorm(nn.Module):
24
+ r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
25
+ The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
26
+ shape (batch_size, height, width, channels) while channels_first corresponds to inputs
27
+ with shape (batch_size, channels, height, width).
28
+ """
29
+ def __init__(self, normalized_shape, eps=1e-6, data_format="channels_first"):
30
+ super().__init__()
31
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
32
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
33
+ self.eps = eps
34
+ self.data_format = data_format
35
+ if self.data_format not in ["channels_last", "channels_first"]:
36
+ raise NotImplementedError
37
+ self.normalized_shape = (normalized_shape, )
38
+
39
+ def forward(self, x):
40
+ if self.data_format == "channels_last":
41
+ return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
42
+ elif self.data_format == "channels_first":
43
+ u = x.mean(1, keepdim=True)
44
+ s = (x - u).pow(2).mean(1, keepdim=True)
45
+ x = (x - u) / torch.sqrt(s + self.eps)
46
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
47
+ return x
48
+
49
+
50
+ class ConvNextBlock(nn.Module):
51
+ r""" ConvNeXt Block. There are two equivalent implementations:
52
+ (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
53
+ (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
54
+ We use (2) as we find it slightly faster in PyTorch
55
+
56
+ Args:
57
+ dim (int): Number of input channels.
58
+ drop_path (float): Stochastic depth rate. Default: 0.0
59
+ layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
60
+ """
61
+ def __init__(self, in_channel, hidden_dim, out_channel, kernel_size=3, layer_scale_init_value=1e-6, drop_path= 0.0):
62
+ super().__init__()
63
+ self.dwconv = nn.Conv2d(in_channel, in_channel, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, groups=in_channel) # depthwise conv
64
+ self.norm = nn.LayerNorm(in_channel, eps=1e-6)
65
+ self.pwconv1 = nn.Linear(in_channel, hidden_dim) # pointwise/1x1 convs, implemented with linear layers
66
+ self.act = nn.GELU()
67
+ self.pwconv2 = nn.Linear(hidden_dim, out_channel)
68
+ self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((out_channel)),
69
+ requires_grad=True) if layer_scale_init_value > 0 else None
70
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
71
+
72
+ def forward(self, x):
73
+ input = x
74
+ x = self.dwconv(x)
75
+ x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
76
+ x = self.norm(x)
77
+ x = self.pwconv1(x)
78
+ x = self.act(x)
79
+ x = self.pwconv2(x)
80
+ if self.gamma is not None:
81
+ x = self.gamma * x
82
+ x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
83
+
84
+ x = input + self.drop_path(x)
85
+ return x
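
A quick shape sanity check for these modules (a sketch; the sizes are arbitrary and the classes above are assumed to be in scope). `ConvNextBlock` preserves spatial size and, because of the residual add, needs `out_channel == in_channel`; `UpSampleConvnext` remaps channels and doubles the resolution for `ratio=1`:

```python
import torch

# ConvNextBlock / UpSampleConvnext are the classes defined above.
block = ConvNextBlock(in_channel=64, hidden_dim=256, out_channel=64)
up = UpSampleConvnext(ratio=1, inchannel=128, outchannel=64)
print(block(torch.randn(2, 64, 56, 56)).shape)    # torch.Size([2, 64, 56, 56])
print(up(torch.randn(2, 128, 28, 28)).shape)      # torch.Size([2, 64, 56, 56])
```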
training/Detection/mmdet/utils/__init__.py ADDED
@@ -0,0 +1,12 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from .collect_env import collect_env
3
+ from .logger import get_caller_name, get_root_logger, log_img_scale
4
+ from .misc import find_latest_checkpoint, update_data_root
5
+ from .setup_env import setup_multi_processes
6
+ from .optimizer import DistOptimizerHook
7
+
8
+ __all__ = [
9
+ 'get_root_logger', 'collect_env', 'find_latest_checkpoint',
10
+ 'update_data_root', 'setup_multi_processes', 'get_caller_name',
11
+ 'log_img_scale', 'DistOptimizerHook'
12
+ ]
training/Detection/mmdet/utils/optimizer.py ADDED
@@ -0,0 +1,33 @@
1
+ from mmcv.runner import OptimizerHook, HOOKS
2
+ try:
3
+ import apex
4
+ except ImportError:
5
+ print('apex is not installed')
6
+
7
+
8
+ @HOOKS.register_module()
9
+ class DistOptimizerHook(OptimizerHook):
10
+ """Optimizer hook for distributed training."""
11
+
12
+ def __init__(self, update_interval=1, grad_clip=None, coalesce=True, bucket_size_mb=-1, use_fp16=False):
13
+ self.grad_clip = grad_clip
14
+ self.coalesce = coalesce
15
+ self.bucket_size_mb = bucket_size_mb
16
+ self.update_interval = update_interval
17
+ self.use_fp16 = use_fp16
18
+
19
+ def before_run(self, runner):
20
+ runner.optimizer.zero_grad()
21
+
22
+ def after_train_iter(self, runner):
23
+ runner.outputs['loss'] /= self.update_interval
24
+ if self.use_fp16:
25
+ with apex.amp.scale_loss(runner.outputs['loss'], runner.optimizer) as scaled_loss:
26
+ scaled_loss.backward()
27
+ else:
28
+ runner.outputs['loss'].backward()
29
+ if self.every_n_iters(runner, self.update_interval):
30
+ if self.grad_clip is not None:
31
+ self.clip_grads(runner.model.parameters())
32
+ runner.optimizer.step()
33
+ runner.optimizer.zero_grad()
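
A hedged config sketch: `update_interval` gives gradient accumulation, and the `grad_clip` dict is forwarded by the parent `OptimizerHook` to `torch.nn.utils.clip_grad_norm_`. The values below are examples only:

```python
# Hypothetical config fragment; accumulates gradients over 2 iterations
# before each optimizer step.
optimizer_config = dict(
    type='DistOptimizerHook',
    update_interval=2,
    grad_clip=dict(max_norm=5.0),
    use_fp16=False)
```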
training/INSTRUCTIONS.md ADDED
@@ -0,0 +1,158 @@
1
+ # Installation, Training and Evaluation Instructions for Image Classification
2
+
3
+ We provide installation, training and evaluation instructions for image classification here.
4
+
5
+ ## Installation Instructions
6
+
7
+ - Clone this repo:
8
+
9
+ ```bash
10
+ git clone https://github.com/megvii-research/RevCol.git
11
+ cd RevCol
12
+ ```
13
+
14
+ - Create a conda virtual environment and activate it:
15
+
16
+ ```bash
17
+ conda create --name revcol python=3.7 -y
18
+ conda activate revcol
19
+ ```
20
+
21
+ - Install `CUDA>=11.3` with `cudnn>=8` following
22
+ the [official installation instructions](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html)
23
+ - Install `PyTorch>=1.11.0` and `torchvision>=0.12.0` with `CUDA>=11.3`:
24
+
25
+ ```bash
26
+ conda install pytorch=1.11.0 torchvision=0.12.0 torchaudio=0.11.0 cudatoolkit=11.3 -c pytorch
27
+ ```
28
+
29
+ - Install `timm==0.5.4`:
30
+
31
+ ```bash
32
+ pip install timm==0.5.4
33
+ ```
34
+
35
+ - Install other requirements:
36
+
37
+ ```bash
38
+ pip install -r requirements.txt
39
+ ```
40
+
41
+ ## Data preparation
42
+
43
+ We use the standard ImageNet dataset; you can download it from http://image-net.org/. We support the following two ways to
44
+ load data:
45
+
46
+ - For standard imagenet-1k dataset, the file structure should look like:
47
+ ```bash
48
+ path-to-imagenet-1k
49
+ ├── train
50
+ │ ├── class1
51
+ │ │ ├── img1.jpeg
52
+ │ │ ├── img2.jpeg
53
+ │ │ └── ...
54
+ │ ├── class2
55
+ │ │ ├── img3.jpeg
56
+ │ │ └── ...
57
+ │ └── ...
58
+ └── val
59
+ ├── class1
60
+ │ ├── img4.jpeg
61
+ │ ├── img5.jpeg
62
+ │ └── ...
63
+ ├── class2
64
+ │ ├── img6.jpeg
65
+ │ └── ...
66
+ └── ...
67
+
68
+ ```
69
+
70
+ - For ImageNet-22K dataset, the file structure should look like:
71
+
72
+ ```bash
73
+ path-to-imagenet-22k
74
+ ├── class1
75
+ │ ├── img1.jpeg
76
+ │ ├── img2.jpeg
77
+ │ └── ...
78
+ ├── class2
79
+ │ ├── img3.jpeg
80
+ │ └── ...
81
+ └── ...
82
+ ```
83
+
84
+ - As ImageNet-22K has no official val set, one option is to use the ImageNet-1K val set to evaluate ImageNet-22K models. Remember to map the ImageNet-1K labels to their ImageNet-22K counterparts.
85
+ ```bash
86
+ path-to-imagenet-22k-custom-eval-set
87
+ ├── class1
88
+ │ ├── img1.jpeg
89
+ │ ├── img2.jpeg
90
+ │ └── ...
91
+ ├── class2
92
+ │ ├── img3.jpeg
93
+ │ └── ...
94
+ └── ...
95
+ ```
96
+
97
+ ## Evaluation
98
+
99
+ To evaluate a pre-trained `RevCol` on ImageNet validation set, run:
100
+
101
+ ```bash
102
+ torchrun --nproc_per_node=<num-of-gpus-to-use> --master_port=23456 main.py --cfg <config-file.yaml> --resume <checkpoint_path> --data-path <imagenet-path> --eval
103
+ ```
104
+
105
+ For example, to evaluate `RevCol-T` with 8 GPUs:
106
+
107
+ ```bash
108
+ torchrun --nproc_per_node=8 --master_port=23456 main.py --cfg configs/revcol_tiny_1k.yaml --resume path_to_your_model.pth --data-path <imagenet-path> --eval
109
+ ```
110
+
111
+ ## Training from scratch on ImageNet-1K
112
+
113
+ To train a `RevCol` on ImageNet from scratch, run:
114
+
115
+ ```bash
116
+ torchrun --nproc_per_node=<num-of-gpus-to-use> --master_port=23456 main.py \
117
+ --cfg <config-file> --data-path <imagenet-path> [--batch-size <batch-size-per-gpu> --output <output-directory> --tag <job-tag>]
118
+ ```
119
+
120
+ **Notes**:
121
+
122
+ For example, to train `RevCol` with 8 GPUs on a single node for 300 epochs, run:
123
+
124
+ `RevCol-T`:
125
+
126
+ ```bash
127
+ torchrun --nproc_per_node=8 --master_port=23456 main.py --cfg configs/revcol_tiny_1k.yaml --batch-size 128 --data-path <imagenet-path>
128
+ ```
129
+
130
+ `RevCol-S`:
131
+
132
+ ```bash
133
+ torchrun --nproc_per_node=8 --master_port=23456 main.py --cfg configs/revcol_small_1k.yaml --batch-size 128 --data-path <imagenet-path>
134
+ ```
135
+
136
+ `RevCol-B`:
137
+
138
+ ```bash
139
+ torchrun --nproc_per_node=8 --master_port=23456 main.py --cfg configs/revcol_base_1k.yaml --batch-size 128 --data-path <imagenet-path>
140
+ ```
141
+
142
+ ## Pre-training on ImageNet-22K
143
+
144
+ For example, to pre-train a `RevCol-L` model on ImageNet-22K:
145
+
146
+ ```bash
147
+ torchrun --nproc_per_node=8 --master_port=23456 main.py --cfg configs/revcol_large_22k_pretrain.yaml --batch-size 128 --data-path <imagenet-22k-path> --opt DATA.EVAL_DATA_PATH <imagenet-22k-custom-eval-path>
148
+ ```
149
+
150
+
151
+ ## Fine-tuning from an ImageNet-22K(21K) pre-trained model
152
+
153
+ For example, to fine-tune a `RevCol-B` model pre-trained on ImageNet-22K(21K):
154
+
155
+ ```bash
156
+ torchrun --nproc_per_node=8 --master_port=23456 main.py --cfg configs/revcol_base_1k_384_finetune.yaml --batch-size 64 --data-path <imagenet-1k-path> --finetune revcol_base_22k_pretrained.pth
157
+ ```
158
+
training/LICENSE ADDED
@@ -0,0 +1,190 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ Copyright (c) 2022 Megvii Inc. All rights reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
training/README.md ADDED
@@ -0,0 +1,79 @@
+ # Reversible Column Networks
+ This repo is the official implementation of:
+
+ ### [Reversible Column Networks](https://arxiv.org/abs/2212.11696)
+ [Yuxuan Cai](https://nightsnack.github.io), [Yizhuang Zhou](https://scholar.google.com/citations?user=VRSGDDEAAAAJ), [Qi Han](https://hanqer.github.io), Jianjian Sun, Xiangwen Kong, Jun Li, [Xiangyu Zhang](https://scholar.google.com/citations?user=yuB-cfoAAAAJ) \
+ [MEGVII Technology](https://en.megvii.com)\
+ International Conference on Learning Representations (ICLR) 2023\
+ [\[arxiv\]](https://arxiv.org/abs/2212.11696)
+
+
+ ## Updates
+ ***2/10/2023***\
+ RevCol model weights released.
+
+ ***1/21/2023***\
+ RevCol was accepted by ICLR 2023!
+
+ ***12/23/2022***\
+ Initial commits: code for ImageNet-1k and ImageNet-22k classification is released.
+
+
+ ## To Do List
+
+ - [x] ImageNet-1K and 22k Training Code
+ - [x] ImageNet-1K and 22k Model Weights
+ - [ ] Cascade Mask R-CNN COCO Object Detection Code & Model Weights
+ - [ ] ADE20k Semantic Segmentation Code & Model Weights
+
+
+ ## Introduction
+ RevCol is composed of multiple copies of a subnetwork, named columns, between which multi-level reversible connections are employed. RevCol can serve as a foundation-model backbone for various computer vision tasks, including classification, detection and segmentation.
+
+ <p align="center">
+ <img src="figures/title.png" width=100% height=100%
+ class="center">
+ </p>
+
+ ## Main Results on ImageNet with Pre-trained Models
+
+ | name | pretrain | resolution | #params | FLOPs | acc@1 | pretrained model | finetuned model |
+ |:---------------------:| :---: | :---: | :---: | :---: | :---: | :---: | :---: |
+ | RevCol-T | ImageNet-1K | 224x224 | 30M | 4.5G | 82.2 | [baidu](https://pan.baidu.com/s/1iGsbdmFcDpwviCHaajeUnA?pwd=h4tj)/[github](https://github.com/megvii-research/RevCol/releases/download/checkpoint/revcol_tiny_1k.pth) | - |
+ | RevCol-S | ImageNet-1K | 224x224 | 60M | 9.0G | 83.5 | [baidu](https://pan.baidu.com/s/1hpHfdFrTZIPB5NTwqDMLag?pwd=mxuk)/[github](https://github.com/megvii-research/RevCol/releases/download/checkpoint/revcol_small_1k.pth) | - |
+ | RevCol-B | ImageNet-1K | 224x224 | 138M | 16.6G | 84.1 | [baidu](https://pan.baidu.com/s/16XIJ1n8pXPD2cXwnFX6b9w?pwd=j6x9)/[github](https://github.com/megvii-research/RevCol/releases/download/checkpoint/revcol_base_1k.pth) | - |
+ | RevCol-B<sup>\*</sup> | ImageNet-22K | 224x224 | 138M | 16.6G | 85.6 |[baidu](https://pan.baidu.com/s/1l8zOFifgC8fZtBpHK2ZQHg?pwd=rh58)/[github](https://github.com/megvii-research/RevCol/releases/download/checkpoint/revcol_base_22k.pth)| [baidu](https://pan.baidu.com/s/1HqhDXL6OIQdn1LeM2pewYQ?pwd=1bp3)/[github](https://github.com/megvii-research/RevCol/releases/download/checkpoint/revcol_base_22k_1kft_224.pth)|
+ | RevCol-B<sup>\*</sup> | ImageNet-22K | 384x384 | 138M | 48.9G | 86.7 |[baidu](https://pan.baidu.com/s/1l8zOFifgC8fZtBpHK2ZQHg?pwd=rh58)/[github](https://github.com/megvii-research/RevCol/releases/download/checkpoint/revcol_base_22k.pth)| [baidu](https://pan.baidu.com/s/18G0zAUygKgu58s2AjCBpsw?pwd=rv86)/[github](https://github.com/megvii-research/RevCol/releases/download/checkpoint/revcol_base_22k_1kft_384.pth)|
+ | RevCol-L<sup>\*</sup> | ImageNet-22K | 224x224 | 273M | 39G | 86.6 |[baidu](https://pan.baidu.com/s/1ueKqh3lFAAgC-vVU34ChYA?pwd=qv5m)/[github](https://github.com/megvii-research/RevCol/releases/download/checkpoint/revcol_large_22k.pth)| [baidu](https://pan.baidu.com/s/1CsWmcPcwieMzXE8pVmHh7w?pwd=qd9n)/[github](https://github.com/megvii-research/RevCol/releases/download/checkpoint/revcol_large_22k_1kft_224.pth)|
+ | RevCol-L<sup>\*</sup> | ImageNet-22K | 384x384 | 273M | 116G | 87.6 |[baidu](https://pan.baidu.com/s/1ueKqh3lFAAgC-vVU34ChYA?pwd=qv5m)/[github](https://github.com/megvii-research/RevCol/releases/download/checkpoint/revcol_large_22k.pth)| [baidu](https://pan.baidu.com/s/1VmCE3W3Xw6-Lo4rWrj9Xzg?pwd=x69r)/[github](https://github.com/megvii-research/RevCol/releases/download/checkpoint/revcol_large_22k_1kft_384.pth)|
+
+ ## Getting Started
+ Please refer to [INSTRUCTIONS.md](INSTRUCTIONS.md) for setup, training and evaluation details.
+
+
+ ## Acknowledgement
+ This repo was inspired by several open-source projects. We are grateful for these excellent projects and list them as follows:
+ - [timm](https://github.com/rwightman/pytorch-image-models)
+ - [Swin Transformer](https://github.com/microsoft/Swin-Transformer)
+ - [ConvNeXt](https://github.com/facebookresearch/ConvNeXt)
+ - [beit](https://github.com/microsoft/unilm/tree/master/beit)
+
+ ## License
+ RevCol is released under the [Apache 2.0 license](LICENSE).
+
+ ## Contact Us
+ If you have any questions about this repo or the original paper, please contact Yuxuan at [email protected].
+
+
+ ## Citation
+ ```
+ @inproceedings{cai2022reversible,
+ title={Reversible Column Networks},
+ author={Cai, Yuxuan and Zhou, Yizhuang and Han, Qi and Sun, Jianjian and Kong, Xiangwen and Li, Jun and Zhang, Xiangyu},
+ booktitle={International Conference on Learning Representations},
+ year={2023},
+ url={https://openreview.net/forum?id=Oc2vlWU0jFY}
+ }
+ ```
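[Editor's note] A minimal sketch of fetching one of the released checkpoints from the table above and inspecting it. The GitHub release URL is taken from the table; the `'model'` key is an assumption about the checkpoint layout based on how checkpoints are saved and loaded elsewhere in this commit:

```python
# Hypothetical sketch: download the RevCol-T weights and inspect them.
# Assumes the checkpoint either is a raw state_dict or stores one under 'model'.
import torch

url = ('https://github.com/megvii-research/RevCol/releases/download/'
      'checkpoint/revcol_tiny_1k.pth')
ckpt = torch.hub.load_state_dict_from_url(url, map_location='cpu', check_hash=False)
state_dict = ckpt.get('model', ckpt) if isinstance(ckpt, dict) else ckpt
print(len(state_dict), 'tensors in checkpoint')
```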
training/config.py ADDED
@@ -0,0 +1,243 @@
+ # --------------------------------------------------------
+ # Reversible Column Networks
+ # Copyright (c) 2022 Megvii Inc.
+ # Licensed under The Apache License 2.0 [see LICENSE for details]
+ # Written by Yuxuan Cai
+ # --------------------------------------------------------
+ import os
+ import yaml
+ from yacs.config import CfgNode as CN
+
+ _C = CN()
+
+ # Base config files
+ _C.BASE = ['']
+
+ # -----------------------------------------------------------------------------
+ # Data settings
+ # -----------------------------------------------------------------------------
+ _C.DATA = CN()
+ # Batch size for a single GPU, could be overwritten by command line argument
+ _C.DATA.BATCH_SIZE = 128
+ # Path to dataset, could be overwritten by command line argument
+ _C.DATA.DATA_PATH = 'path/to/imagenet'
+ # Dataset name
+ _C.DATA.DATASET = 'imagenet'
+ # Input image size
+ _C.DATA.IMG_SIZE = 224
+ # Interpolation to resize image (random, bilinear, bicubic)
+ _C.DATA.INTERPOLATION = 'bicubic'
+ # Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU
+ _C.DATA.PIN_MEMORY = True
+ # Number of data loading threads
+ _C.DATA.NUM_WORKERS = 8
+ # Path to evaluation dataset for ImageNet-22k
+ _C.DATA.EVAL_DATA_PATH = 'path/to/eval/data'
+
+ # -----------------------------------------------------------------------------
+ # Model settings
+ # -----------------------------------------------------------------------------
+ _C.MODEL = CN()
+ # Model type
+ _C.MODEL.TYPE = ''
+ # Model name
+ _C.MODEL.NAME = ''
+ # Checkpoint to resume, could be overwritten by command line argument
+ _C.MODEL.RESUME = ''
+ # Checkpoint to finetune, could be overwritten by command line argument
+ _C.MODEL.FINETUNE = ''
+ # Number of classes, overwritten in data preparation
+ _C.MODEL.NUM_CLASSES = 1000
+ # Label smoothing
+ _C.MODEL.LABEL_SMOOTHING = 0.0
+
+ # -----------------------------------------------------------------------------
+ # Specific Model settings
+ # -----------------------------------------------------------------------------
+
+ _C.REVCOL = CN()
+
+ _C.REVCOL.INTER_SUPV = True
+
+ _C.REVCOL.SAVEMM = True
+
+ _C.REVCOL.FCOE = 4.0
+
+ _C.REVCOL.CCOE = 0.8
+
+ _C.REVCOL.KERNEL_SIZE = 3
+
+ _C.REVCOL.DROP_PATH = 0.1
+
+ _C.REVCOL.HEAD_INIT_SCALE = None
+
+ # -----------------------------------------------------------------------------
+ # Training settings
+ # -----------------------------------------------------------------------------
+ _C.TRAIN = CN()
+ _C.TRAIN.START_EPOCH = 0
+ _C.TRAIN.EPOCHS = 300
+ _C.TRAIN.WARMUP_EPOCHS = 5
+ _C.TRAIN.WEIGHT_DECAY = 4e-5
+ _C.TRAIN.BASE_LR = 0.4
+
+ _C.TRAIN.WARMUP_LR = 0.05
+ _C.TRAIN.MIN_LR = 1e-5
+ # Clip gradient norm
+ _C.TRAIN.CLIP_GRAD = 10.0
+ # Auto resume from latest checkpoint
+ _C.TRAIN.AUTO_RESUME = True
+ # Whether to use gradient checkpointing to save memory
+ _C.TRAIN.USE_CHECKPOINT = False
+
+ _C.TRAIN.AMP = True
+
+ # LR scheduler
+ _C.TRAIN.LR_SCHEDULER = CN()
+ _C.TRAIN.LR_SCHEDULER.NAME = 'cosine'
+ # Epoch interval to decay LR, used in StepLRScheduler
+ _C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30
+ # LR decay rate, used in StepLRScheduler
+ _C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1
+
+ # Optimizer
+ _C.TRAIN.OPTIMIZER = CN()
+ _C.TRAIN.OPTIMIZER.NAME = 'sgd'
+ # Optimizer epsilon for AdamW
+ _C.TRAIN.OPTIMIZER.EPS = 1e-8
+ # Optimizer betas for AdamW
+ _C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999)
+ # SGD momentum
+ _C.TRAIN.OPTIMIZER.MOMENTUM = 0.9
+ # Layer decay
+ _C.TRAIN.OPTIMIZER.LAYER_DECAY = 1.0
+
+ # -----------------------------------------------------------------------------
+ # Augmentation settings
+ # -----------------------------------------------------------------------------
+ _C.AUG = CN()
+ # Color jitter factor
+ _C.AUG.COLOR_JITTER = 0.4
+ # Use AutoAugment policy. "v0" or "original"
+ _C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1'
+ # Random erase prob
+ _C.AUG.REPROB = 0.25
+ # Random erase mode
+ _C.AUG.REMODE = 'pixel'
+ # Random erase count
+ _C.AUG.RECOUNT = 1
+ # Mixup alpha, mixup enabled if > 0
+ _C.AUG.MIXUP = 0.8
+ # Cutmix alpha, cutmix enabled if > 0
+ _C.AUG.CUTMIX = 1.0
+ # Cutmix min/max ratio, overrides alpha and enables cutmix if set
+ _C.AUG.CUTMIX_MINMAX = None
+ # Probability of performing mixup or cutmix when either/both is enabled
+ _C.AUG.MIXUP_PROB = 1.0
+ # Probability of switching to cutmix when both mixup and cutmix enabled
+ _C.AUG.MIXUP_SWITCH_PROB = 0.5
+ # How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
+ _C.AUG.MIXUP_MODE = 'batch'
+
+ # -----------------------------------------------------------------------------
+ # Testing settings
+ # -----------------------------------------------------------------------------
+ _C.TEST = CN()
+ # Whether to use center crop when testing
+ _C.TEST.CROP = True
+
+ # -----------------------------------------------------------------------------
+ # Misc
+ # -----------------------------------------------------------------------------
+ # Path to output folder, overwritten by command line argument
+ _C.OUTPUT = 'outputs/'
+ # Tag of experiment, overwritten by command line argument
+ _C.TAG = 'default'
+ # Frequency to save checkpoint
+ _C.SAVE_FREQ = 1
+ # Frequency to log info
+ _C.PRINT_FREQ = 100
+ # Fixed random seed
+ _C.SEED = 0
+ # Perform evaluation only, overwritten by command line argument
+ _C.EVAL_MODE = False
+ # Test throughput only, overwritten by command line argument
+ _C.THROUGHPUT_MODE = False
+ # Local rank for DistributedDataParallel, given by command line argument
+ _C.LOCAL_RANK = 0
+
+ # EMA
+ _C.MODEL_EMA = False
+ _C.MODEL_EMA_DECAY = 0.9999
+
+ # Machine
+ _C.MACHINE = CN()
+ _C.MACHINE.MACHINE_WORLD_SIZE = None
+ _C.MACHINE.MACHINE_RANK = None
+
+
+ def _update_config_from_file(config, cfg_file):
+     config.defrost()
+     with open(cfg_file, 'r') as f:
+         yaml_cfg = yaml.load(f, Loader=yaml.FullLoader)
+
+     # recursively merge any BASE config files first, relative to this file
+     for cfg in yaml_cfg.setdefault('BASE', ['']):
+         if cfg:
+             _update_config_from_file(
+                 config, os.path.join(os.path.dirname(cfg_file), cfg)
+             )
+     print('=> merge config from {}'.format(cfg_file))
+     config.merge_from_file(cfg_file)
+     config.freeze()
+
+
+ def update_config(config, args):
+     _update_config_from_file(config, args.cfg)
+
+     config.defrost()
+     if args.opts:
+         config.merge_from_list(args.opts)
+
+     # merge from specific arguments
+     if args.batch_size:
+         config.DATA.BATCH_SIZE = args.batch_size
+     if args.data_path:
+         config.DATA.DATA_PATH = args.data_path
+     if args.resume:
+         config.MODEL.RESUME = args.resume
+     if args.finetune:
+         config.MODEL.FINETUNE = args.finetune
+     if args.use_checkpoint:
+         config.TRAIN.USE_CHECKPOINT = True
+     if args.output:
+         config.OUTPUT = args.output
+     if args.tag:
+         config.TAG = args.tag
+     if args.eval:
+         config.EVAL_MODE = True
+     if args.model_ema:
+         config.MODEL_EMA = True
+
+     config.dist_url = args.dist_url
+     # set local rank for distributed training
+     config.LOCAL_RANK = args.local_rank
+
+     # output folder
+     config.OUTPUT = os.path.join(config.OUTPUT, config.MODEL.NAME, config.TAG)
+
+     config.freeze()
+
+
+ def get_config(args):
+     """Get a yacs CfgNode object with default values."""
+     # Return a clone so that the defaults will not be altered
+     # This is for the "local variable" use pattern
+     config = _C.clone()
+     update_config(config, args)
+
+     return config
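[Editor's note] A minimal sketch of driving `get_config` from an argument parser. The argument names mirror exactly the attributes `update_config` reads; run from the `training/` directory so the relative config path resolves. The defaults below are illustrative, not part of the repo:

```python
# Hypothetical driver for get_config(); defines every attribute update_config() touches.
import argparse
from config import get_config

parser = argparse.ArgumentParser('RevCol config demo')
parser.add_argument('--cfg', default='configs/revcol_tiny_1k.yaml')
parser.add_argument('--opts', default=None, nargs='+')
parser.add_argument('--batch-size', type=int, default=None)
parser.add_argument('--data-path', default=None)
parser.add_argument('--resume', default='')
parser.add_argument('--finetune', default='')
parser.add_argument('--use-checkpoint', action='store_true')
parser.add_argument('--output', default='outputs/')
parser.add_argument('--tag', default='demo')
parser.add_argument('--eval', action='store_true')
parser.add_argument('--model-ema', action='store_true')
parser.add_argument('--dist-url', default='env://')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args([])  # use defaults; pass real CLI args in practice

config = get_config(args)
print(config.MODEL.TYPE, config.DATA.BATCH_SIZE, config.OUTPUT)
```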
training/configs/revcol_base_1k.yaml ADDED
@@ -0,0 +1,48 @@
+ PRINT_FREQ: 100
+ SAVE_FREQ: 1
+ MODEL_EMA: True
+ DATA:
+   IMG_SIZE: 224
+   NUM_WORKERS: 6
+ MODEL:
+   TYPE: revcol_base
+   NAME: revcol_base
+   LABEL_SMOOTHING: 0.1
+ REVCOL:
+   INTER_SUPV: True
+   SAVEMM: True
+   FCOE: 3.0
+   CCOE: 0.7
+   DROP_PATH: 0.4
+ TRAIN:
+   EPOCHS: 300
+   BASE_LR: 1.0e-3
+   WARMUP_EPOCHS: 20
+   WEIGHT_DECAY: 0.05
+   WARMUP_LR: 1.0e-5
+   MIN_LR: 1.0e-8
+   OPTIMIZER:
+     NAME: 'adamw'
+   CLIP_GRAD: 5.0
+ AUG:
+   COLOR_JITTER: 0.0
+   # Use AutoAugment policy. "v0" or "original"
+   AUTO_AUGMENT: 'rand-m9-mstd0.5-inc1'
+   # Random erase prob
+   REPROB: 0.25
+   # Random erase mode
+   REMODE: 'pixel'
+   # Random erase count
+   RECOUNT: 1
+   # Mixup alpha, mixup enabled if > 0
+   MIXUP: 0.8
+   # Cutmix alpha, cutmix enabled if > 0
+   CUTMIX: 1.0
+   # Cutmix min/max ratio, overrides alpha and enables cutmix if set
+   CUTMIX_MINMAX: null
+   # Probability of performing mixup or cutmix when either/both is enabled
+   MIXUP_PROB: 1.0
+   # Probability of switching to cutmix when both mixup and cutmix enabled
+   MIXUP_SWITCH_PROB: 0.5
+   # How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
+   MIXUP_MODE: 'batch'
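[Editor's note] None of the shipped configs use the `BASE` key, but `_update_config_from_file` in `config.py` above resolves it recursively relative to the including file, so a local variant can inherit a shipped config and override a few fields. A hypothetical sketch (the derived file name is invented for illustration):

```yaml
# configs/revcol_base_1k_short.yaml -- hypothetical derived config.
# Everything in revcol_base_1k.yaml is merged first, then overridden here.
BASE: ['revcol_base_1k.yaml']
TRAIN:
  EPOCHS: 100
  WARMUP_EPOCHS: 10
```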
training/configs/revcol_base_1k_224_finetune.yaml ADDED
@@ -0,0 +1,50 @@
+ PRINT_FREQ: 100
+ SAVE_FREQ: 1
+ MODEL_EMA: False
+ DATA:
+   IMG_SIZE: 224
+   DATASET: imagenet
+ MODEL:
+   TYPE: revcol_base
+   NAME: revcol_base_1k_Finetune_224
+   LABEL_SMOOTHING: 0.1
+   NUM_CLASSES: 1000
+ REVCOL:
+   INTER_SUPV: False
+   SAVEMM: True
+   FCOE: 3.0
+   CCOE: 0.7
+   DROP_PATH: 0.2
+   HEAD_INIT_SCALE: 0.001
+ TRAIN:
+   EPOCHS: 30
+   BASE_LR: 1.0e-4
+   WARMUP_EPOCHS: 0
+   WEIGHT_DECAY: 1.0e-8
+   WARMUP_LR: 4.0e-6
+   MIN_LR: 2.0e-7
+   OPTIMIZER:
+     NAME: 'adamw'
+     LAYER_DECAY: 0.9
+ AUG:
+   COLOR_JITTER: 0.0
+   # Use AutoAugment policy. "v0" or "original"
+   AUTO_AUGMENT: 'rand-m9-mstd0.5-inc1'
+   # Random erase prob
+   REPROB: 0.25
+   # Random erase mode
+   REMODE: 'pixel'
+   # Random erase count
+   RECOUNT: 1
+   # Mixup alpha, mixup enabled if > 0
+   MIXUP: 0.0
+   # Cutmix alpha, cutmix enabled if > 0
+   CUTMIX: 0.0
+   # Cutmix min/max ratio, overrides alpha and enables cutmix if set
+   CUTMIX_MINMAX: null
+   # Probability of performing mixup or cutmix when either/both is enabled
+   MIXUP_PROB: 0.0
+   # Probability of switching to cutmix when both mixup and cutmix enabled
+   MIXUP_SWITCH_PROB: 0.0
+   # How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
+   MIXUP_MODE: 'batch'
training/configs/revcol_base_1k_384_finetune.yaml ADDED
@@ -0,0 +1,50 @@
+ PRINT_FREQ: 100
+ SAVE_FREQ: 1
+ MODEL_EMA: False
+ DATA:
+   IMG_SIZE: 384
+   DATASET: imagenet
+ MODEL:
+   TYPE: revcol_base
+   NAME: revcol_base_1k_Finetune_384
+   LABEL_SMOOTHING: 0.1
+   NUM_CLASSES: 1000
+ REVCOL:
+   INTER_SUPV: False
+   SAVEMM: True
+   FCOE: 3.0
+   CCOE: 0.7
+   DROP_PATH: 0.2
+   HEAD_INIT_SCALE: 0.001
+ TRAIN:
+   EPOCHS: 30
+   BASE_LR: 1.0e-4
+   WARMUP_EPOCHS: 0
+   WEIGHT_DECAY: 1.0e-8
+   WARMUP_LR: 4.0e-6
+   MIN_LR: 2.0e-7
+   OPTIMIZER:
+     NAME: 'adamw'
+     LAYER_DECAY: 0.9
+ AUG:
+   COLOR_JITTER: 0.0
+   # Use AutoAugment policy. "v0" or "original"
+   AUTO_AUGMENT: 'rand-m9-mstd0.5-inc1'
+   # Random erase prob
+   REPROB: 0.25
+   # Random erase mode
+   REMODE: 'pixel'
+   # Random erase count
+   RECOUNT: 1
+   # Mixup alpha, mixup enabled if > 0
+   MIXUP: 0.0
+   # Cutmix alpha, cutmix enabled if > 0
+   CUTMIX: 0.0
+   # Cutmix min/max ratio, overrides alpha and enables cutmix if set
+   CUTMIX_MINMAX: null
+   # Probability of performing mixup or cutmix when either/both is enabled
+   MIXUP_PROB: 0.0
+   # Probability of switching to cutmix when both mixup and cutmix enabled
+   MIXUP_SWITCH_PROB: 0.0
+   # How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
+   MIXUP_MODE: 'batch'
training/configs/revcol_base_22k_pretrain.yaml ADDED
@@ -0,0 +1,51 @@
+ PRINT_FREQ: 100
+ SAVE_FREQ: 1
+ MODEL_EMA: False
+ DATA:
+   IMG_SIZE: 224
+   DATASET: imagenet22k
+   NUM_WORKERS: 6
+ MODEL:
+   TYPE: revcol_base
+   NAME: revcol_base_22k_Pretrain
+   LABEL_SMOOTHING: 0.1
+   NUM_CLASSES: 21841
+ REVCOL:
+   INTER_SUPV: True
+   SAVEMM: True
+   FCOE: 3.0
+   CCOE: 0.7
+   DROP_PATH: 0.3
+ TRAIN:
+   EPOCHS: 90
+   BASE_LR: 1.25e-4
+   WARMUP_EPOCHS: 5
+   WEIGHT_DECAY: 0.1
+   WARMUP_LR: 1.0e-5
+   MIN_LR: 1.0e-7
+   OPTIMIZER:
+     NAME: 'adamw'
+   # BN:
+   #   USE_PRECISE_STATS: True
+ AUG:
+   COLOR_JITTER: 0.4
+   # Use AutoAugment policy. "v0" or "original"
+   AUTO_AUGMENT: 'rand-m9-mstd0.5-inc1'
+   # Random erase prob
+   REPROB: 0.25
+   # Random erase mode
+   REMODE: 'pixel'
+   # Random erase count
+   RECOUNT: 1
+   # Mixup alpha, mixup enabled if > 0
+   MIXUP: 0.8
+   # Cutmix alpha, cutmix enabled if > 0
+   CUTMIX: 1.0
+   # Cutmix min/max ratio, overrides alpha and enables cutmix if set
+   CUTMIX_MINMAX: null
+   # Probability of performing mixup or cutmix when either/both is enabled
+   MIXUP_PROB: 1.0
+   # Probability of switching to cutmix when both mixup and cutmix enabled
+   MIXUP_SWITCH_PROB: 0.5
+   # How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
+   MIXUP_MODE: 'batch'
training/configs/revcol_large_1k_224_finetune.yaml ADDED
@@ -0,0 +1,51 @@
+ PRINT_FREQ: 20
+ SAVE_FREQ: 1
+ MODEL_EMA: False
+ DATA:
+   IMG_SIZE: 224
+   DATASET: imagenet
+ MODEL:
+   TYPE: revcol_large
+   NAME: revcol_large_1k_Finetune_224
+   LABEL_SMOOTHING: 0.1
+   NUM_CLASSES: 1000
+ REVCOL:
+   INTER_SUPV: False
+   SAVEMM: True
+   FCOE: 3.0
+   CCOE: 0.7
+   DROP_PATH: 0.3
+   HEAD_INIT_SCALE: 0.001
+ TRAIN:
+   EPOCHS: 30
+   BASE_LR: 5.0e-5
+   WARMUP_EPOCHS: 0
+   WEIGHT_DECAY: 1.0e-8
+   WARMUP_LR: 4.0e-6
+   MIN_LR: 2.0e-7
+   OPTIMIZER:
+     NAME: 'adamw'
+     LAYER_DECAY: 0.8
+
+ AUG:
+   COLOR_JITTER: 0.0
+   # Use AutoAugment policy. "v0" or "original"
+   AUTO_AUGMENT: 'rand-m9-mstd0.5-inc1'
+   # Random erase prob
+   REPROB: 0.25
+   # Random erase mode
+   REMODE: 'pixel'
+   # Random erase count
+   RECOUNT: 1
+   # Mixup alpha, mixup enabled if > 0
+   MIXUP: 0.0
+   # Cutmix alpha, cutmix enabled if > 0
+   CUTMIX: 0.0
+   # Cutmix min/max ratio, overrides alpha and enables cutmix if set
+   CUTMIX_MINMAX: null
+   # Probability of performing mixup or cutmix when either/both is enabled
+   MIXUP_PROB: 0.0
+   # Probability of switching to cutmix when both mixup and cutmix enabled
+   MIXUP_SWITCH_PROB: 0.0
+   # How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
+   MIXUP_MODE: 'batch'
training/configs/revcol_large_1k_384_finetune.yaml ADDED
@@ -0,0 +1,51 @@
+ PRINT_FREQ: 20
+ SAVE_FREQ: 1
+ MODEL_EMA: False
+ DATA:
+   IMG_SIZE: 384
+   DATASET: imagenet
+ MODEL:
+   TYPE: revcol_large
+   NAME: revcol_large_1k_Finetune_384
+   LABEL_SMOOTHING: 0.1
+   NUM_CLASSES: 1000
+ REVCOL:
+   INTER_SUPV: False
+   SAVEMM: True
+   FCOE: 3.0
+   CCOE: 0.7
+   DROP_PATH: 0.3
+   HEAD_INIT_SCALE: 0.001
+ TRAIN:
+   EPOCHS: 30
+   BASE_LR: 5.0e-5
+   WARMUP_EPOCHS: 0
+   WEIGHT_DECAY: 1.0e-8
+   WARMUP_LR: 4.0e-6
+   MIN_LR: 2.0e-7
+   OPTIMIZER:
+     NAME: 'adamw'
+     LAYER_DECAY: 0.8
+
+ AUG:
+   COLOR_JITTER: 0.0
+   # Use AutoAugment policy. "v0" or "original"
+   AUTO_AUGMENT: 'rand-m9-mstd0.5-inc1'
+   # Random erase prob
+   REPROB: 0.25
+   # Random erase mode
+   REMODE: 'pixel'
+   # Random erase count
+   RECOUNT: 1
+   # Mixup alpha, mixup enabled if > 0
+   MIXUP: 0.0
+   # Cutmix alpha, cutmix enabled if > 0
+   CUTMIX: 0.0
+   # Cutmix min/max ratio, overrides alpha and enables cutmix if set
+   CUTMIX_MINMAX: null
+   # Probability of performing mixup or cutmix when either/both is enabled
+   MIXUP_PROB: 0.0
+   # Probability of switching to cutmix when both mixup and cutmix enabled
+   MIXUP_SWITCH_PROB: 0.0
+   # How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
+   MIXUP_MODE: 'batch'
training/configs/revcol_large_22k_pretrain.yaml ADDED
@@ -0,0 +1,50 @@
+ PRINT_FREQ: 100
+ SAVE_FREQ: 1
+ MODEL_EMA: False
+ DATA:
+   IMG_SIZE: 224
+   DATASET: imagenet22k
+   NUM_WORKERS: 6
+ MODEL:
+   TYPE: revcol_large
+   NAME: revcol_large_22k_Pretrain
+   LABEL_SMOOTHING: 0.1
+   NUM_CLASSES: 21841
+ REVCOL:
+   INTER_SUPV: True
+   SAVEMM: True
+   FCOE: 3.0
+   CCOE: 0.7
+   DROP_PATH: 0.3
+ TRAIN:
+   EPOCHS: 90
+   BASE_LR: 1.25e-4
+   WARMUP_EPOCHS: 5
+   WEIGHT_DECAY: 0.1
+   WARMUP_LR: 1.0e-5
+   MIN_LR: 1.0e-7
+   OPTIMIZER:
+     NAME: 'adamw'
+
+ AUG:
+   COLOR_JITTER: 0.4
+   # Use AutoAugment policy. "v0" or "original"
+   AUTO_AUGMENT: 'rand-m9-mstd0.5-inc1'
+   # Random erase prob
+   REPROB: 0.25
+   # Random erase mode
+   REMODE: 'pixel'
+   # Random erase count
+   RECOUNT: 1
+   # Mixup alpha, mixup enabled if > 0
+   MIXUP: 0.8
+   # Cutmix alpha, cutmix enabled if > 0
+   CUTMIX: 1.0
+   # Cutmix min/max ratio, overrides alpha and enables cutmix if set
+   CUTMIX_MINMAX: null
+   # Probability of performing mixup or cutmix when either/both is enabled
+   MIXUP_PROB: 1.0
+   # Probability of switching to cutmix when both mixup and cutmix enabled
+   MIXUP_SWITCH_PROB: 0.5
+   # How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
+   MIXUP_MODE: 'batch'
training/configs/revcol_small_1k.yaml ADDED
@@ -0,0 +1,48 @@
+ PRINT_FREQ: 100
+ SAVE_FREQ: 1
+ MODEL_EMA: True
+ DATA:
+   IMG_SIZE: 224
+   NUM_WORKERS: 6
+ MODEL:
+   TYPE: revcol_small
+   NAME: revcol_small
+   LABEL_SMOOTHING: 0.1
+ REVCOL:
+   INTER_SUPV: True
+   SAVEMM: True
+   FCOE: 3.0
+   CCOE: 0.7
+   DROP_PATH: 0.3
+ TRAIN:
+   EPOCHS: 300
+   BASE_LR: 1.0e-3
+   WARMUP_EPOCHS: 20
+   WEIGHT_DECAY: 0.05
+   WARMUP_LR: 1.0e-5
+   MIN_LR: 1.0e-6
+   OPTIMIZER:
+     NAME: 'adamw'
+   CLIP_GRAD: 0.0
+ AUG:
+   COLOR_JITTER: 0.4
+   # Use AutoAugment policy. "v0" or "original"
+   AUTO_AUGMENT: 'rand-m9-mstd0.5-inc1'
+   # Random erase prob
+   REPROB: 0.25
+   # Random erase mode
+   REMODE: 'pixel'
+   # Random erase count
+   RECOUNT: 1
+   # Mixup alpha, mixup enabled if > 0
+   MIXUP: 0.8
+   # Cutmix alpha, cutmix enabled if > 0
+   CUTMIX: 1.0
+   # Cutmix min/max ratio, overrides alpha and enables cutmix if set
+   CUTMIX_MINMAX: null
+   # Probability of performing mixup or cutmix when either/both is enabled
+   MIXUP_PROB: 1.0
+   # Probability of switching to cutmix when both mixup and cutmix enabled
+   MIXUP_SWITCH_PROB: 0.5
+   # How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
+   MIXUP_MODE: 'batch'
training/configs/revcol_tiny_1k.yaml ADDED
@@ -0,0 +1,48 @@
+ PRINT_FREQ: 100
+ SAVE_FREQ: 1
+ MODEL_EMA: True
+ DATA:
+   IMG_SIZE: 224
+   NUM_WORKERS: 6
+ MODEL:
+   TYPE: revcol_tiny
+   NAME: revcol_tiny
+   LABEL_SMOOTHING: 0.1
+ REVCOL:
+   INTER_SUPV: True
+   SAVEMM: True
+   FCOE: 3.0
+   CCOE: 0.7
+   DROP_PATH: 0.1
+ TRAIN:
+   EPOCHS: 300
+   BASE_LR: 1.0e-3
+   WARMUP_EPOCHS: 20
+   WEIGHT_DECAY: 0.05
+   WARMUP_LR: 1.0e-5
+   MIN_LR: 1.0e-6
+   OPTIMIZER:
+     NAME: 'adamw'
+   CLIP_GRAD: 0.0
+ AUG:
+   COLOR_JITTER: 0.4
+   # Use AutoAugment policy. "v0" or "original"
+   AUTO_AUGMENT: 'rand-m9-mstd0.5-inc1'
+   # Random erase prob
+   REPROB: 0.25
+   # Random erase mode
+   REMODE: 'pixel'
+   # Random erase count
+   RECOUNT: 1
+   # Mixup alpha, mixup enabled if > 0
+   MIXUP: 0.8
+   # Cutmix alpha, cutmix enabled if > 0
+   CUTMIX: 1.0
+   # Cutmix min/max ratio, overrides alpha and enables cutmix if set
+   CUTMIX_MINMAX: null
+   # Probability of performing mixup or cutmix when either/both is enabled
+   MIXUP_PROB: 1.0
+   # Probability of switching to cutmix when both mixup and cutmix enabled
+   MIXUP_SWITCH_PROB: 0.5
+   # How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
+   MIXUP_MODE: 'batch'
training/configs/revcol_xlarge_1k_384_finetune.yaml ADDED
@@ -0,0 +1,53 @@
+ PRINT_FREQ: 30
+ SAVE_FREQ: 1
+ MODEL_EMA: True
+ DATA:
+   IMG_SIZE: 384
+   DATASET: imagenet
+   # PIPE_NAME: 'dpflow.silvia.imagenet.train.rand-re-jitt.384'  # internal data-pipe key; not defined in config.py
+   NUM_WORKERS: 8
+ MODEL:
+   TYPE: revcol_xlarge
+   NAME: revcol_xlarge_1k_Finetune
+   LABEL_SMOOTHING: 0.1
+   NUM_CLASSES: 1000
+ REVCOL:
+   INTER_SUPV: False
+   SAVEMM: True
+   FCOE: 3.0
+   CCOE: 0.7
+   DROP_PATH: 0.4
+   HEAD_INIT_SCALE: 0.001
+ TRAIN:
+   EPOCHS: 30
+   BASE_LR: 2.0e-5
+   WARMUP_EPOCHS: 0
+   WEIGHT_DECAY: 1.0e-8
+   WARMUP_LR: 4.0e-6
+   MIN_LR: 2.0e-7
+   OPTIMIZER:
+     NAME: 'adamw'
+     LAYER_DECAY: 0.8
+
+ AUG:
+   COLOR_JITTER: 0.0
+   # Use AutoAugment policy. "v0" or "original"
+   AUTO_AUGMENT: 'rand-m9-mstd0.5-inc1'
+   # Random erase prob
+   REPROB: 0.25
+   # Random erase mode
+   REMODE: 'pixel'
+   # Random erase count
+   RECOUNT: 1
+   # Mixup alpha, mixup enabled if > 0
+   MIXUP: 0.0
+   # Cutmix alpha, cutmix enabled if > 0
+   CUTMIX: 0.0
+   # Cutmix min/max ratio, overrides alpha and enables cutmix if set
+   CUTMIX_MINMAX: null
+   # Probability of performing mixup or cutmix when either/both is enabled
+   MIXUP_PROB: 0.0
+   # Probability of switching to cutmix when both mixup and cutmix enabled
+   MIXUP_SWITCH_PROB: 0.0
+   # How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
+   MIXUP_MODE: 'batch'
training/configs/revcol_xlarge_22k_pretrain.yaml ADDED
@@ -0,0 +1,50 @@
+ PRINT_FREQ: 100
+ SAVE_FREQ: 1
+ MODEL_EMA: False
+ DATA:
+   IMG_SIZE: 224
+   DATASET: imagenet22k
+   NUM_WORKERS: 8
+ MODEL:
+   TYPE: revcol_xlarge
+   NAME: revcol_xlarge_22k_Pretrain
+   LABEL_SMOOTHING: 0.1
+   NUM_CLASSES: 21841
+ REVCOL:
+   INTER_SUPV: True
+   SAVEMM: True
+   FCOE: 3.0
+   CCOE: 0.7
+   DROP_PATH: 0.4
+ TRAIN:
+   EPOCHS: 90
+   BASE_LR: 1.25e-4
+   WARMUP_EPOCHS: 5
+   WEIGHT_DECAY: 0.1
+   WARMUP_LR: 1.0e-5
+   MIN_LR: 1.0e-7
+   OPTIMIZER:
+     NAME: 'adamw'
+
+ AUG:
+   COLOR_JITTER: 0.0
+   # Use AutoAugment policy. "v0" or "original"
+   AUTO_AUGMENT: 'rand-m9-mstd0.5-inc1'
+   # Random erase prob
+   REPROB: 0.25
+   # Random erase mode
+   REMODE: 'pixel'
+   # Random erase count
+   RECOUNT: 1
+   # Mixup alpha, mixup enabled if > 0
+   MIXUP: 0.8
+   # Cutmix alpha, cutmix enabled if > 0
+   CUTMIX: 1.0
+   # Cutmix min/max ratio, overrides alpha and enables cutmix if set
+   CUTMIX_MINMAX: null
+   # Probability of performing mixup or cutmix when either/both is enabled
+   MIXUP_PROB: 1.0
+   # Probability of switching to cutmix when both mixup and cutmix enabled
+   MIXUP_SWITCH_PROB: 0.5
+   # How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
+   MIXUP_MODE: 'batch'
training/data/__init__.py ADDED
@@ -0,0 +1 @@
+ from .build_data import build_loader
training/data/build_data.py ADDED
@@ -0,0 +1,137 @@
+ # --------------------------------------------------------
+ # Reversible Column Networks
+ # Copyright (c) 2022 Megvii Inc.
+ # Licensed under The Apache License 2.0 [see LICENSE for details]
+ # Written by Yuxuan Cai
+ # --------------------------------------------------------
+
+ import os
+ import torch
+ import numpy as np
+ import torch.distributed as dist
+ from torchvision import datasets, transforms
+ from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+ from timm.data import Mixup
+ from timm.data import create_transform
+
+ from .samplers import SubsetRandomSampler
+
+
+ def build_loader(config):
+
+     config.defrost()
+     dataset_train, _ = build_dataset(is_train=True, config=config)
+     config.freeze()
+     print(f"global rank {dist.get_rank()} successfully built train dataset")
+
+     sampler_train = torch.utils.data.DistributedSampler(
+         dataset_train, shuffle=True
+     )
+
+     data_loader_train = torch.utils.data.DataLoader(
+         dataset_train, sampler=sampler_train,
+         batch_size=config.DATA.BATCH_SIZE,
+         num_workers=config.DATA.NUM_WORKERS,
+         pin_memory=config.DATA.PIN_MEMORY,
+         drop_last=True,
+         persistent_workers=True
+     )
+
+     # ----------------------------------- Val Dataset -----------------------------------
+
+     dataset_val, _ = build_dataset(is_train=False, config=config)
+     print(f"global rank {dist.get_rank()} successfully built val dataset")
+
+     # shard the validation set across ranks with a stride of world_size
+     indices = np.arange(dist.get_rank(), len(dataset_val), dist.get_world_size())
+     sampler_val = SubsetRandomSampler(indices)
+
+     data_loader_val = torch.utils.data.DataLoader(
+         dataset_val, sampler=sampler_val,
+         batch_size=config.DATA.BATCH_SIZE,
+         shuffle=False,
+         num_workers=config.DATA.NUM_WORKERS,
+         pin_memory=config.DATA.PIN_MEMORY,
+         drop_last=False,
+         persistent_workers=True
+     )
+
+     # setup mixup / cutmix
+     mixup_fn = None
+     mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. or config.AUG.CUTMIX_MINMAX is not None
+     if mixup_active:
+         mixup_fn = Mixup(
+             mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX,
+             prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE,
+             label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES)
+
+     return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn
+
+
+ def build_dataset(is_train, config):
+     transform = build_transform(is_train, config)
+     if config.DATA.DATASET == 'imagenet':
+         prefix = 'train' if is_train else 'val'
+         root = os.path.join(config.DATA.DATA_PATH, prefix)
+         dataset = datasets.ImageFolder(root, transform=transform)
+         nb_classes = 1000
+     elif config.DATA.DATASET == 'imagenet22k':
+         if is_train:
+             root = config.DATA.DATA_PATH
+         else:
+             root = config.DATA.EVAL_DATA_PATH
+         dataset = datasets.ImageFolder(root, transform=transform)
+         nb_classes = 21841
+     else:
+         raise NotImplementedError("Only ImageNet and ImageNet-22k are supported for now.")
+
+     return dataset, nb_classes
+
+
+ def build_transform(is_train, config):
+     resize_im = config.DATA.IMG_SIZE > 32
+     if is_train:
+         # this should always dispatch to transforms_imagenet_train
+         transform = create_transform(
+             input_size=config.DATA.IMG_SIZE,
+             is_training=True,
+             color_jitter=config.AUG.COLOR_JITTER if config.AUG.COLOR_JITTER > 0 else None,
+             auto_augment=config.AUG.AUTO_AUGMENT if config.AUG.AUTO_AUGMENT != 'none' else None,
+             re_prob=config.AUG.REPROB,
+             re_mode=config.AUG.REMODE,
+             re_count=config.AUG.RECOUNT,
+             interpolation=config.DATA.INTERPOLATION,
+         )
+         if not resize_im:
+             # replace RandomResizedCropAndInterpolation with RandomCrop
+             transform.transforms[0] = transforms.RandomCrop(config.DATA.IMG_SIZE, padding=4)
+         return transform
+
+     t = []
+     if resize_im:
+         if config.DATA.IMG_SIZE > 224:
+             t.append(
+                 transforms.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE),
+                                   interpolation=transforms.InterpolationMode.BICUBIC),
+             )
+             print(f"Warping {config.DATA.IMG_SIZE} size input images...")
+         elif config.TEST.CROP:
+             size = int((256 / 224) * config.DATA.IMG_SIZE)
+             t.append(
+                 transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC),
+                 # to maintain same ratio w.r.t. 224 images
+             )
+             t.append(transforms.CenterCrop(config.DATA.IMG_SIZE))
+         else:
+             t.append(
+                 transforms.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE),
+                                   interpolation=transforms.InterpolationMode.BICUBIC)
+             )
+
+     t.append(transforms.ToTensor())
+     t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
+     return transforms.Compose(t)
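[Editor's note] The validation set above is sharded across ranks with a stride rather than contiguous chunks, so every sample is evaluated exactly once with no padding. A quick worked example of the `np.arange(rank, len(dataset), world_size)` pattern:

```python
# Worked example of the strided validation sharding used in build_loader().
import numpy as np

world_size, n = 4, 10
shards = [np.arange(rank, n, world_size) for rank in range(world_size)]
print(shards)  # [array([0, 4, 8]), array([1, 5, 9]), array([2, 6]), array([3, 7])]
# All samples are covered exactly once across ranks:
assert sorted(np.concatenate(shards).tolist()) == list(range(n))
```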
training/data/samplers.py ADDED
@@ -0,0 +1,29 @@
+ # --------------------------------------------------------
+ # Reversible Column Networks
+ # Copyright (c) 2022 Megvii Inc.
+ # Licensed under The Apache License 2.0 [see LICENSE for details]
+ # Written by Yuxuan Cai
+ # --------------------------------------------------------
+
+ import torch
+
+
+ class SubsetRandomSampler(torch.utils.data.Sampler):
+     r"""Samples elements randomly from a given list of indices, without replacement.
+
+     Arguments:
+         indices (sequence): a sequence of indices
+     """
+
+     def __init__(self, indices):
+         self.epoch = 0
+         self.indices = indices
+
+     def __iter__(self):
+         return (self.indices[i] for i in torch.randperm(len(self.indices)))
+
+     def __len__(self):
+         return len(self.indices)
+
+     def set_epoch(self, epoch):
+         self.epoch = epoch
training/figures/title.png ADDED
training/logger.py ADDED
@@ -0,0 +1,41 @@
+ # --------------------------------------------------------
+ # Reversible Column Networks
+ # Copyright (c) 2022 Megvii Inc.
+ # Licensed under The Apache License 2.0 [see LICENSE for details]
+ # Written by Yuxuan Cai
+ # --------------------------------------------------------
+
+ import os
+ import sys
+ import logging
+ import functools
+ from termcolor import colored
+
+
+ @functools.lru_cache()
+ def create_logger(output_dir, dist_rank=0, name=''):
+     # create logger
+     logger = logging.getLogger(name)
+     logger.setLevel(logging.DEBUG)
+     logger.propagate = False
+
+     # create formatter
+     fmt = '[%(asctime)s %(name)s] (%(filename)s %(lineno)d): %(levelname)s %(message)s'
+     color_fmt = colored('[%(asctime)s %(name)s]', 'green') + \
+                 colored('(%(filename)s %(lineno)d)', 'yellow') + ': %(levelname)s %(message)s'
+
+     # create a console handler for the master process only
+     if dist_rank == 0:
+         console_handler = logging.StreamHandler(sys.stdout)
+         console_handler.setLevel(logging.DEBUG)
+         console_handler.setFormatter(
+             logging.Formatter(fmt=color_fmt, datefmt='%Y-%m-%d %H:%M:%S'))
+         logger.addHandler(console_handler)
+
+     # create a per-rank file handler
+     file_handler = logging.FileHandler(os.path.join(output_dir, f'log_rank{dist_rank}.txt'), mode='a')
+     file_handler.setLevel(logging.DEBUG)
+     file_handler.setFormatter(logging.Formatter(fmt=fmt, datefmt='%Y-%m-%d %H:%M:%S'))
+     logger.addHandler(file_handler)
+
+     return logger
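[Editor's note] A minimal usage sketch for `create_logger` above. The output directory must already exist, since a `FileHandler` is opened there unconditionally on every rank; the directory name is illustrative:

```python
# Hypothetical standalone use of training/logger.py.
import os
from logger import create_logger

os.makedirs('outputs/demo', exist_ok=True)
logger = create_logger('outputs/demo', dist_rank=0, name='revcol')
logger.info('hello from rank 0')  # console (rank 0 only) + outputs/demo/log_rank0.txt
```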
training/loss.py ADDED
@@ -0,0 +1,35 @@
+ # --------------------------------------------------------
+ # Reversible Column Networks
+ # Copyright (c) 2022 Megvii Inc.
+ # Licensed under The Apache License 2.0 [see LICENSE for details]
+ # Written by Yuxuan Cai
+ # --------------------------------------------------------
+
+ import torch
+ from torch.functional import Tensor
+
+
+ def compound_loss(coe, output_feature, image: Tensor, output_label, targets, criterion_bce, criterion_ce, epoch):
+     f_coe, c_coe = coe
+     # keep the BCE reconstruction targets strictly inside (0, 1)
+     image.clamp_(0.01, 0.99)
+     multi_loss = []
+     for i, feature in enumerate(output_feature):
+         # the feature (reconstruction) weight decays with column index,
+         # while the classification weight grows with it
+         ratio_f = 1 - i / len(output_feature)
+         ratio_c = (i + 1) / len(output_label)
+
+         ihx = criterion_bce(feature, image) * ratio_f * f_coe
+         ihy = criterion_ce(output_label[i], targets) * ratio_c * c_coe
+         multi_loss.append(ihx + ihy)
+     # the last column additionally gets a plain, unweighted CE term
+     multi_loss.append(criterion_ce(output_label[-1], targets))
+     loss = torch.sum(torch.stack(multi_loss), dim=0)
+     return loss, multi_loss
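[Editor's note] A worked example of the per-column weighting in `compound_loss`, assuming the label and feature lists have the same length (as the intermediate-supervision heads produce) and using `FCOE: 3.0`, `CCOE: 0.7` from the configs above:

```python
# Column i of n gets feature weight (1 - i/n) * f_coe and
# classification weight ((i+1)/n) * c_coe.
n, f_coe, c_coe = 4, 3.0, 0.7
for i in range(n):
    ratio_f = 1 - i / n      # 1.00, 0.75, 0.50, 0.25
    ratio_c = (i + 1) / n    # 0.25, 0.50, 0.75, 1.00
    print(i, ratio_f * f_coe, ratio_c * c_coe)
# 0: 3.0,  0.175
# 1: 2.25, 0.35
# 2: 1.5,  0.525
# 3: 0.75, 0.7   (plus the extra plain CE term on the last column)
```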
training/lr_scheduler.py ADDED
@@ -0,0 +1,96 @@
+ # --------------------------------------------------------
+ # Reversible Column Networks
+ # Copyright (c) 2022 Megvii Inc.
+ # Licensed under The Apache License 2.0 [see LICENSE for details]
+ # Written by Yuxuan Cai
+ # --------------------------------------------------------
+
+ import math
+
+ import torch
+ from timm.scheduler.step_lr import StepLRScheduler
+ from timm.scheduler.scheduler import Scheduler
+
+
+ def build_scheduler(config, optimizer=None):
+
+     lr_scheduler = None
+     if config.TRAIN.LR_SCHEDULER.NAME == 'cosine':
+         lr_scheduler = CosLRScheduler()
+     elif config.TRAIN.LR_SCHEDULER.NAME == 'multistep':
+         lr_scheduler = StepLRScheduler()
+     else:
+         raise NotImplementedError(f"Unknown lr scheduler: {config.TRAIN.LR_SCHEDULER.NAME}")
+
+     return lr_scheduler
+
+
+ class CosLRScheduler():
+     def __init__(self) -> None:
+         pass
+
+     def step_update(self, optimizer, epoch, config):
+         """Decay the learning rate with half-cycle cosine after warmup"""
+         if epoch < config.TRAIN.WARMUP_EPOCHS:
+             lr = (config.TRAIN.BASE_LR - config.TRAIN.WARMUP_LR) * epoch / config.TRAIN.WARMUP_EPOCHS + config.TRAIN.WARMUP_LR
+         else:
+             lr = config.TRAIN.MIN_LR + (config.TRAIN.BASE_LR - config.TRAIN.MIN_LR) * 0.5 * \
+                 (1. + math.cos(math.pi * (epoch - config.TRAIN.WARMUP_EPOCHS) / (config.TRAIN.EPOCHS - config.TRAIN.WARMUP_EPOCHS)))
+         for param_group in optimizer.param_groups:
+             if "lr_scale" in param_group:
+                 param_group["lr"] = lr * param_group["lr_scale"]
+             else:
+                 param_group["lr"] = lr
+         return lr
+
+
+ class LinearLRScheduler(Scheduler):
+     def __init__(self,
+                  optimizer: torch.optim.Optimizer,
+                  t_initial: int,
+                  lr_min_rate: float,
+                  warmup_t=0,
+                  warmup_lr_init=0.,
+                  t_in_epochs=True,
+                  noise_range_t=None,
+                  noise_pct=0.67,
+                  noise_std=1.0,
+                  noise_seed=42,
+                  initialize=True,
+                  ) -> None:
+         super().__init__(
+             optimizer, param_group_field="lr",
+             noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
+             initialize=initialize)
+
+         self.t_initial = t_initial
+         self.lr_min_rate = lr_min_rate
+         self.warmup_t = warmup_t
+         self.warmup_lr_init = warmup_lr_init
+         self.t_in_epochs = t_in_epochs
+         if self.warmup_t:
+             self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
+             super().update_groups(self.warmup_lr_init)
+         else:
+             self.warmup_steps = [1 for _ in self.base_values]
+
+     def _get_lr(self, t):
+         if t < self.warmup_t:
+             lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
+         else:
+             t = t - self.warmup_t
+             total_t = self.t_initial - self.warmup_t
+             lrs = [v - ((v - v * self.lr_min_rate) * (t / total_t)) for v in self.base_values]
+         return lrs
+
+     def get_epoch_values(self, epoch: int):
+         if self.t_in_epochs:
+             return self._get_lr(epoch)
+         else:
+             return None
+
+     def get_update_values(self, num_updates: int):
+         if not self.t_in_epochs:
+             return self._get_lr(num_updates)
+         else:
+             return None
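[Editor's note] `CosLRScheduler.step_update` takes a fractional epoch (main.py passes `idx / num_steps + epoch`): linear warmup, then a half-cycle cosine. A few sampled values under the revcol_tiny_1k settings (BASE_LR 1e-3, WARMUP_LR 1e-5, MIN_LR 1e-6, 20 warmup / 300 total epochs), before the batch-size rescaling done in main.py:

```python
# Standalone reimplementation of the schedule for illustration.
import math

base_lr, warmup_lr, min_lr = 1e-3, 1e-5, 1e-6
warmup_epochs, epochs = 20, 300

def lr_at(epoch):
    if epoch < warmup_epochs:
        return (base_lr - warmup_lr) * epoch / warmup_epochs + warmup_lr
    progress = (epoch - warmup_epochs) / (epochs - warmup_epochs)
    return min_lr + (base_lr - min_lr) * 0.5 * (1.0 + math.cos(math.pi * progress))

for e in (0, 10, 20, 160, 300):
    print(e, f'{lr_at(e):.2e}')
# 0 -> 1.0e-05, 10 -> 5.05e-04, 20 -> 1.0e-03, 160 -> ~5.0e-04, 300 -> 1.0e-06
```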
training/main.py ADDED
@@ -0,0 +1,422 @@
+ # --------------------------------------------------------
+ # Reversible Column Networks
+ # Copyright (c) 2022 Megvii Inc.
+ # Licensed under The Apache License 2.0 [see LICENSE for details]
+ # Written by Yuxuan Cai
+ # --------------------------------------------------------
+
+ import math
+ import os
+ import subprocess
+ import sys
+ import time
+ import argparse
+ import datetime
+ import numpy as np
+
+ import torch
+ import torch.backends.cudnn as cudnn
+ import torch.distributed as dist
+ import torch.multiprocessing as mp
+ import torchvision.transforms.functional as visionF
+ import torch.cuda.amp as amp
+ from typing import Optional
+
+ from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
+ from timm.utils import accuracy, AverageMeter
+ from timm.utils import ModelEma
+ from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+
+ from config import get_config
+ from models import *
+ from loss import *
+ from data import build_loader
+ from lr_scheduler import build_scheduler
+ from optimizer import build_optimizer
+ from logger import create_logger
+ from utils import denormalize, load_checkpoint, load_checkpoint_finetune, save_checkpoint, get_grad_norm, auto_resume_helper, reduce_tensor
+ from torch.utils.tensorboard import SummaryWriter
+
+ scaler = amp.GradScaler()
+ logger = None
+
+
+ def parse_option():
+     parser = argparse.ArgumentParser('RevCol training and evaluation script', add_help=False)
+     parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file')
+     parser.add_argument(
+         "--opts",
+         help="Modify config options by adding 'KEY VALUE' pairs. ",
+         default=None,
+         nargs='+',
+     )
+
+     # easy config modification
+     parser.add_argument('--batch-size', type=int, default=128, help="batch size for a single GPU")
+     parser.add_argument('--data-path', type=str, default='data', help='path to dataset')
+     parser.add_argument('--resume', help='resume from checkpoint')
+     parser.add_argument('--finetune', help='finetune from checkpoint')
+
+     parser.add_argument('--use-checkpoint', action='store_true',
+                         help="whether to use gradient checkpointing to save memory")
+
+     parser.add_argument('--output', default='outputs/', type=str, metavar='PATH',
+                         help='root of output folder, the full path is <output>/<model_name>/<tag> (default: output)')
+     parser.add_argument('--tag', help='tag of experiment')
+     parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
+
+     # ema
+     parser.add_argument('--model-ema', action='store_true')
+
+     # distributed training
+     parser.add_argument("--local_rank", type=int, required=False, help='local rank for DistributedDataParallel')
+
+     parser.add_argument('--dist-url', default='env://', type=str,
+                         help='url used to set up distributed training')
+
+     args, unparsed = parser.parse_known_args()
+     config = get_config(args)
+
+     return args, config
+
+
+ def main(config):
+
+     config.defrost()
+     if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
+         rank = int(os.environ["RANK"])
+         world_size = int(os.environ['WORLD_SIZE'])
+         print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}")
+     else:
+         rank = -1
+         world_size = -1
+         return
+
+     # linearly scale the learning rate with the global batch size (reference batch size 1024)
+     linear_scaled_lr = config.TRAIN.BASE_LR * config.DATA.BATCH_SIZE * world_size / 1024.0
+     linear_scaled_warmup_lr = config.TRAIN.WARMUP_LR * config.DATA.BATCH_SIZE * world_size / 1024.0
+     linear_scaled_min_lr = config.TRAIN.MIN_LR * config.DATA.BATCH_SIZE * world_size / 1024.0
+
+     config.TRAIN.BASE_LR = linear_scaled_lr
+     config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr
+     config.TRAIN.MIN_LR = linear_scaled_min_lr
+     config.freeze()
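[Editor's note] A quick check of the linear scaling rule above: with the yaml BASE_LR of 1e-3, per-GPU batch 128 and 8 GPUs, the global batch is exactly the 1024 reference, so the LR is unchanged; other setups scale proportionally.

```python
# Worked example of the scaling in main(): lr = base_lr * (global batch / 1024).
base_lr, batch_per_gpu = 1e-3, 128
for world_size in (4, 8, 16):
    scaled = base_lr * batch_per_gpu * world_size / 1024.0
    print(world_size, scaled)  # 4 -> 5e-04, 8 -> 1e-03, 16 -> 2e-03
```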
105
+
106
+ dist.init_process_group(
107
+ backend='nccl', init_method=config.dist_url,
108
+ world_size=world_size, rank=rank,
109
+ )
110
+ seed = config.SEED + dist.get_rank()
111
+ torch.manual_seed(seed)
112
+ np.random.seed(seed)
113
+ torch.cuda.set_device(rank)
114
+ global logger
115
+
116
+ logger = create_logger(output_dir=config.OUTPUT, dist_rank=dist.get_rank(), name=f"{config.MODEL.NAME}")
117
+ logger.info(config.dump())
118
+ writer = None
119
+ if dist.get_rank() == 0:
120
+ writer = SummaryWriter(config.OUTPUT)
121
+
122
+ dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(config)
123
+
124
+ logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}")
125
+ model = build_model(config)
126
+
127
+ model.cuda()
128
+ logger.info(str(model)[:10000])
129
+
130
+ model_ema = None
131
+ if config.MODEL_EMA:
132
+ # Important to create EMA model after cuda(), DP wrapper, and AMP but
133
+ # before SyncBN and DDP wrapper
134
+ logger.info(f"Using EMA...")
135
+ model_ema = ModelEma(
136
+ model,
137
+ decay=config.MODEL_EMA_DECAY,
138
+ )
139
+
140
+ optimizer = build_optimizer(config, model)
141
+ if config.TRAIN.AMP:
142
+ logger.info(f"-------------------------------Using Pytorch AMP...--------------------------------")
143
+
144
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False, find_unused_parameters=False)
145
+ # model._set_static_graph()
146
+ model_without_ddp = model.module
147
+
148
+ n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
149
+ logger.info(f"number of params: {n_parameters}")
150
+
151
+ lr_scheduler = build_scheduler(config)
152
+
153
+ if config.AUG.MIXUP > 0.:
154
+ # smoothing is handled with mixup label transform
155
+ criterion = SoftTargetCrossEntropy()
156
+ elif config.MODEL.LABEL_SMOOTHING > 0.:
157
+ criterion = LabelSmoothingCrossEntropy(smoothing=config.MODEL.LABEL_SMOOTHING)
158
+ else:
159
+ criterion = torch.nn.CrossEntropyLoss()
160
+ criterion_bce = torch.nn.BCEWithLogitsLoss()
161
+ max_accuracy = 0.0
162
+
163
+ if config.TRAIN.AUTO_RESUME:
164
+ resume_file = auto_resume_helper(config.OUTPUT, logger)
165
+ if resume_file:
166
+ if config.MODEL.RESUME:
167
+ logger.warning(f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}")
168
+ config.defrost()
169
+ config.MODEL.RESUME = resume_file
170
+ config.freeze()
171
+ logger.info(f'auto resuming from {resume_file}')
172
+ else:
173
+ logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume')
174
+
175
+ if config.MODEL.RESUME:
176
+ max_accuracy = load_checkpoint(config, model_without_ddp, optimizer, logger, model_ema)
177
+ logger.info(f"Start validation")
178
+ acc1, acc5, loss = validate(config, data_loader_val, model, writer, epoch=config.TRAIN.START_EPOCH)
179
+ logger.info(f"Accuracy of the network on the 50000 test images: {acc1:.1f}, {acc5:.1f}%")
180
+
181
+ if config.EVAL_MODE:
182
+ return
183
+
184
+ if config.MODEL.FINETUNE:
185
+ load_checkpoint_finetune(config, model_without_ddp, logger)
186
+
187
+
188
+ logger.info("Start training")
189
+ start_time = time.time()
190
+ for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):
191
+ data_loader_train.sampler.set_epoch(epoch)
192
+
193
+ train_one_epoch(config, model, criterion, criterion_bce, data_loader_train, optimizer, epoch, mixup_fn, lr_scheduler, writer, model_ema)
194
+
195
+ acc1, acc5, _ = validate(config, data_loader_val, model, writer, epoch)
196
+ logger.info(f"Accuracy of the network on the 5000 test images: {acc1:.2f}, {acc5:.2f}%")
197
+
198
+
199
+ if config.MODEL_EMA:
200
+ acc1_ema, acc5_ema, _ = validate_ema(config, data_loader_val, model_ema.ema, writer, epoch)
201
+ logger.info(f"Accuracy of the EMA network on the 5000 test images: {acc1_ema:.1f}, {acc5_ema:.1f}%")
202
+ # acc1 = max(acc1, acc1_ema)
203
+
204
+ if dist.get_rank() == 0 and epoch % config.SAVE_FREQ == 0:
205
+ save_checkpoint(config, epoch, model_without_ddp, acc1, max_accuracy, optimizer, logger, model_ema)
206
+
207
+ max_accuracy = max(max_accuracy, acc1)
208
+ logger.info(f'Max accuracy: {max_accuracy:.2f}%')
209
+
210
+ total_time = time.time() - start_time
211
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
212
+ logger.info('Training time {}'.format(total_time_str))
213
+
214
+
215
+ def train_one_epoch(config, model, criterion_ce, criterion_bce, data_loader, optimizer, epoch, mixup_fn, lr_scheduler, writer, model_ema: Optional[ModelEma] = None):
216
+ global logger
217
+ model.train()
218
+ optimizer.zero_grad()
219
+
220
+ num_steps = len(data_loader)
221
+ batch_time = AverageMeter()
222
+ loss_meter = AverageMeter()
223
+ norm_meter = AverageMeter()
224
+ data_time = AverageMeter()
225
+
226
+ start = time.time()
227
+ end = time.time()
228
+
229
+
230
+ for idx, (samples, targets) in enumerate(data_loader):
231
+
232
+ samples = samples.cuda(non_blocking=True)
233
+ targets = targets.cuda(non_blocking=True)
234
+
235
+ data_time.update(time.time()-end)
236
+ lr_scheduler.step_update(optimizer, idx / num_steps + epoch, config)
237
+ if mixup_fn is not None:
238
+ samples, targets = mixup_fn(samples, targets)
239
+ with amp.autocast(enabled=config.TRAIN.AMP):
240
+ output_label, output_feature = model(samples)
241
+ if len(output_label) == 1:
242
+ loss = criterion_ce(output_label[0], targets)
243
+ multi_loss = []
244
+ else:
245
+ loss, multi_loss = compound_loss((config.REVCOL.FCOE, config.REVCOL.CCOE), output_feature, denormalize(samples, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD), output_label, targets, criterion_bce, criterion_ce, epoch)
246
+
247
+
248
+ if not math.isfinite(loss.item()):
249
+ print("Loss is {} in iteration {}, multiloss {}, !".format(loss.item(), idx, multi_loss))
250
+
251
+ scaler.scale(loss).backward()
252
+ if config.TRAIN.CLIP_GRAD:
253
+ scaler.unscale_(optimizer)
254
+ grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD)
255
+ else:
256
+ scaler.unscale_(optimizer)
257
+ grad_norm = get_grad_norm(model.parameters())
258
+ scaler.step(optimizer)
259
+ scaler.update()
260
+
261
+ optimizer.zero_grad()
262
+
263
+ if model_ema is not None:
264
+ model_ema.update(model)
265
+
266
+ loss_meter.update(loss.item(), targets.size(0))
267
+ norm_meter.update(grad_norm)
268
+ batch_time.update(time.time() - end)
269
+ end = time.time()
270
+
271
+ if dist.get_rank() == 0 and idx%10 == 0:
272
+ writer.add_scalar('Train/train_loss', loss_meter.val, epoch * num_steps + idx)
273
+ writer.add_scalar('Train/grad_norm', norm_meter.val, epoch * num_steps + idx)
274
+ for i, subloss in enumerate(multi_loss):
275
+ writer.add_scalar(f'Train/sub_loss{i}', subloss, epoch * num_steps + idx)
276
+
277
+ if idx % config.PRINT_FREQ == 0:
278
+ lr = optimizer.param_groups[-1]['lr']
279
+ memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
280
+ etas = batch_time.avg * (num_steps - idx)
281
+ logger.info(
282
+ f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t'
283
+ f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t'
284
+ f'datatime {data_time.val:.4f} ({data_time.avg:.4f})\t'
285
+ f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t'
286
+ f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
287
+ f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t'
288
+ f'mem {memory_used:.0f}MB')
289
+
290
+ epoch_time = time.time() - start
291
+ logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}")
292
+
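Editor's note on the loss call above: `compound_loss` is defined in training/loss.py and is not part of this hunk (likewise, `scaler` is a module-level GradScaler, sketched after main() below). The following is only a hedged sketch of the compound objective's assumed shape, one BCE reconstruction term per intermediate column plus one CE classification term per column head, weighted by REVCOL.FCOE and REVCOL.CCOE; it is not the author's exact implementation.

import torch

def compound_loss_sketch(coe, features, image, logits_list, targets,
                         criterion_bce, criterion_ce, epoch):
    # Hypothetical stand-in for training/loss.py (weighting and reduction assumed).
    f_coe, c_coe = coe
    sub_losses = [f_coe * criterion_bce(feat, image) for feat in features]           # reconstruction terms
    sub_losses += [c_coe * criterion_ce(logits, targets) for logits in logits_list]  # per-column cls terms
    return sum(sub_losses), [l.detach() for l in sub_losses]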
293
+
294
+
295
+
296
+ @torch.no_grad()
297
+ def validate(config, data_loader, model, writer, epoch):
298
+ global logger
299
+ criterion = torch.nn.CrossEntropyLoss()
300
+ model.eval()
301
+
302
+ batch_time = AverageMeter()
303
+ loss_meter = AverageMeter()
304
+ acc1_meter_list = []
305
+ acc5_meter_list = []
306
+ for i in range(4):
307
+ acc1_meter_list.append(AverageMeter())
308
+ acc5_meter_list.append(AverageMeter())
309
+
310
+ end = time.time()
311
+ for idx, (images, target) in enumerate(data_loader):
312
+
313
+ images = images.cuda(non_blocking=True)
314
+ target = target.cuda(non_blocking=True)
315
+
316
+ # compute output
317
+ outputs, _ = model(images)
318
+
319
+
320
+ if len(acc1_meter_list) != len(outputs):
321
+ acc1_meter_list = acc1_meter_list[:len(outputs)]
322
+ acc5_meter_list = acc5_meter_list[:len(outputs)]
323
+
324
+ output_last = outputs[-1]
325
+ loss = criterion(output_last, target)
326
+ loss = reduce_tensor(loss)
327
+ loss_meter.update(loss.item(), target.size(0))
328
+
329
+ for i, subnet_out in enumerate(outputs):
330
+ acc1, acc5 = accuracy(subnet_out, target, topk=(1, 5))
331
+
332
+
333
+ acc1 = reduce_tensor(acc1)
334
+ acc5 = reduce_tensor(acc5)
335
+
336
+ acc1_meter_list[i].update(acc1.item(), target.size(0))
337
+ acc5_meter_list[i].update(acc5.item(), target.size(0))
338
+
339
+ # measure elapsed time
340
+ batch_time.update(time.time() - end)
341
+ end = time.time()
342
+
343
+
344
+ if idx % config.PRINT_FREQ == 0:
345
+ memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
346
+ logger.info(
347
+ f'Test: [{idx}/{len(data_loader)}]\t'
348
+ f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
349
+ f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
350
+ f'Acc@1 {acc1_meter_list[-1].val:.3f} ({acc1_meter_list[-1].avg:.3f})\t'
351
+ f'Acc@5 {acc5_meter_list[-1].val:.3f} ({acc5_meter_list[-1].avg:.3f})\t'
352
+ f'Mem {memory_used:.0f}MB')
353
+
354
+ logger.info(f' * Acc@1 {acc1_meter_list[-1].avg:.3f} Acc@5 {acc5_meter_list[-1].avg:.3f}')
355
+ if dist.get_rank() == 0:
356
+ for i in range(len(acc1_meter_list)):
357
+ writer.add_scalar(f'Val_top1/acc_{i}', acc1_meter_list[i].avg, epoch)
358
+ writer.add_scalar(f'Val_top5/acc_{i}', acc5_meter_list[i].avg, epoch)
359
+ return acc1_meter_list[-1].avg, acc5_meter_list[-1].avg, loss_meter.avg
360
+
361
+ @torch.no_grad()
362
+ def validate_ema(config, data_loader, model, writer, epoch):
363
+ global logger
364
+ criterion = torch.nn.CrossEntropyLoss()
365
+ model.eval()
366
+
367
+ batch_time = AverageMeter()
368
+ loss_meter = AverageMeter()
369
+ acc1_meter = AverageMeter()
370
+ acc5_meter = AverageMeter()
371
+
372
+ end = time.time()
373
+ for idx, (images, target) in enumerate(data_loader):
374
+ images = images.cuda(non_blocking=True)
375
+ target = target.cuda(non_blocking=True)
376
+
377
+ outputs, _ = model(images)
378
+
379
+ output_last = outputs[-1]
380
+ loss = criterion(output_last, target)
381
+ loss = reduce_tensor(loss)
382
+ loss_meter.update(loss.item(), target.size(0))
383
+
384
+
385
+ acc1, acc5 = accuracy(output_last, target, topk=(1, 5))
386
+
387
+
388
+ acc1 = reduce_tensor(acc1)
389
+ acc5 = reduce_tensor(acc5)
390
+
391
+ acc1_meter.update(acc1.item(), target.size(0))
392
+ acc5_meter.update(acc5.item(), target.size(0))
393
+
394
+ # measure elapsed time
395
+ batch_time.update(time.time() - end)
396
+ end = time.time()
397
+
398
+ if idx % config.PRINT_FREQ == 0:
399
+ memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
400
+ logger.info(
401
+ f'Test: [{idx}/{len(data_loader)}]\t'
402
+ f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
403
+ f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
404
+ f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t'
405
+ f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t'
406
+ f'Mem {memory_used:.0f}MB')
407
+
408
+ logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}')
409
+
410
+ return acc1_meter.avg, acc5_meter.avg, loss_meter.avg
411
+
412
+
413
+ if __name__ == '__main__':
414
+ _, config = parse_option()
415
+
416
+ cudnn.benchmark = True
417
+
418
+ os.makedirs(config.OUTPUT, exist_ok=True)
419
+
420
+ ngpus_per_node = torch.cuda.device_count()
421
+
422
+ main(None, config, ngpus_per_node)
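main() is shown here only from its validation section onward. The distributed preliminaries it is assumed to perform earlier, together with the module-level GradScaler that train_one_epoch references as `scaler`, would look roughly like this sketch; the real setup lives in the truncated part of training/main.py.

import torch
import torch.distributed as dist
from torch.cuda import amp

scaler = amp.GradScaler()  # assumed module-level, shared with train_one_epoch

def setup_distributed_sketch(model):
    # Assumed preliminaries: process group, device pinning, DDP wrapping.
    dist.init_process_group(backend='nccl')
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    model = model.cuda()
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[torch.cuda.current_device()])
    return model, model.module  # (model, model_without_ddp)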
training/models/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .build import build_model
training/models/build.py ADDED
@@ -0,0 +1,48 @@
1
+ # --------------------------------------------------------
2
+ # Reversible Column Networks
3
+ # Copyright (c) 2022 Megvii Inc.
4
+ # Licensed under The Apache License 2.0 [see LICENSE for details]
5
+ # Written by Yuxuan Cai
6
+ # --------------------------------------------------------
7
+
8
+ import torch
9
+ from models.revcol import *
10
+
11
+
12
+
13
+ def build_model(config):
14
+ model_type = config.MODEL.TYPE
15
+
16
+ ##-------------------------------------- revcol tiny ----------------------------------------------------------------------------------------------------------------------#
17
+
18
+ if model_type == "revcol_tiny":
19
+ model = revcol_tiny(save_memory=config.REVCOL.SAVEMM, inter_supv=config.REVCOL.INTER_SUPV, drop_path = config.REVCOL.DROP_PATH, num_classes=config.MODEL.NUM_CLASSES, kernel_size = config.REVCOL.KERNEL_SIZE)
20
+
21
+ ##-------------------------------------- revcol small ----------------------------------------------------------------------------------------------------------------------#
22
+
23
+ elif model_type == "revcol_small":
24
+ model = revcol_small(save_memory=config.REVCOL.SAVEMM, inter_supv=config.REVCOL.INTER_SUPV, drop_path = config.REVCOL.DROP_PATH, num_classes=config.MODEL.NUM_CLASSES, kernel_size = config.REVCOL.KERNEL_SIZE)
25
+
26
+ ##-------------------------------------- revcol base ----------------------------------------------------------------------------------------------------------------------#
27
+
28
+ elif model_type == "revcol_base":
29
+ model = revcol_base(save_memory=config.REVCOL.SAVEMM, inter_supv=config.REVCOL.INTER_SUPV, drop_path = config.REVCOL.DROP_PATH, num_classes=config.MODEL.NUM_CLASSES , kernel_size = config.REVCOL.KERNEL_SIZE)
30
+
31
+ ##-------------------------------------- revcol large ----------------------------------------------------------------------------------------------------------------------#
32
+
33
+ elif model_type == "revcol_large":
34
+ model = revcol_large(save_memory=config.REVCOL.SAVEMM, inter_supv=config.REVCOL.INTER_SUPV, drop_path = config.REVCOL.DROP_PATH, num_classes=config.MODEL.NUM_CLASSES , head_init_scale=config.REVCOL.HEAD_INIT_SCALE, kernel_size = config.REVCOL.KERNEL_SIZE)
35
+
36
+ ##-------------------------------------- revcol xlarge ----------------------------------------------------------------------------------------------------------------------#
37
+
38
+ elif model_type == "revcol_xlarge":
39
+ model = revcol_xlarge(save_memory=config.REVCOL.SAVEMM, inter_supv=config.REVCOL.INTER_SUPV, drop_path = config.REVCOL.DROP_PATH, num_classes=config.MODEL.NUM_CLASSES , head_init_scale=config.REVCOL.HEAD_INIT_SCALE, kernel_size = config.REVCOL.KERNEL_SIZE)
40
+
41
+ else:
42
+ raise NotImplementedError(f"Unkown model: {model_type}")
43
+
44
+ return model
45
+
46
+
47
+
48
+
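A minimal way to exercise build_model, using the yacs nodes the function reads (MODEL.TYPE, MODEL.NUM_CLASSES, REVCOL.*); the values below are illustrative defaults chosen by the editor, not a supported config from training/configs/.

from yacs.config import CfgNode as CN

cfg = CN()
cfg.MODEL = CN()
cfg.MODEL.TYPE = "revcol_tiny"
cfg.MODEL.NUM_CLASSES = 1000
cfg.REVCOL = CN()
cfg.REVCOL.SAVEMM = True       # reversible (memory-saving) forward
cfg.REVCOL.INTER_SUPV = True   # intermediate supervision heads
cfg.REVCOL.DROP_PATH = 0.1
cfg.REVCOL.KERNEL_SIZE = 3
model = build_model(cfg)       # dispatches to models.revcol.revcol_tiny(...)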
training/models/modules.py ADDED
@@ -0,0 +1,157 @@
1
+ # --------------------------------------------------------
2
+ # Reversible Column Networks
3
+ # Copyright (c) 2022 Megvii Inc.
4
+ # Licensed under The Apache License 2.0 [see LICENSE for details]
5
+ # Written by Yuxuan Cai
6
+ # --------------------------------------------------------
7
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from timm.models.layers import DropPath
13
+
14
+
15
+ class UpSampleConvnext(nn.Module):
16
+ def __init__(self, ratio, inchannel, outchannel):
17
+ super().__init__()
18
+ self.ratio = ratio
19
+ self.channel_reschedule = nn.Sequential(
20
+ # LayerNorm(inchannel, eps=1e-6, data_format="channels_last"),
21
+ nn.Linear(inchannel, outchannel),
22
+ LayerNorm(outchannel, eps=1e-6, data_format="channels_last"))
23
+ self.upsample = nn.Upsample(scale_factor=2**ratio, mode='nearest')
24
+ def forward(self, x):
25
+ x = x.permute(0, 2, 3, 1)
26
+ x = self.channel_reschedule(x)
27
+ x = x.permute(0, 3, 1, 2)
28
+
29
+ return self.upsample(x)
30
+
31
+ class LayerNorm(nn.Module):
32
+ r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
33
+ The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
34
+ shape (batch_size, height, width, channels) while channels_first corresponds to inputs
35
+ with shape (batch_size, channels, height, width).
36
+ """
37
+ def __init__(self, normalized_shape, eps=1e-6, data_format="channels_first", elementwise_affine = True):
38
+ super().__init__()
39
+ self.elementwise_affine = elementwise_affine
40
+ if elementwise_affine:
41
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
42
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
43
+ self.eps = eps
44
+ self.data_format = data_format
45
+ if self.data_format not in ["channels_last", "channels_first"]:
46
+ raise NotImplementedError
47
+ self.normalized_shape = (normalized_shape, )
48
+
49
+ def forward(self, x):
50
+ if self.data_format == "channels_last":
51
+ return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
52
+ elif self.data_format == "channels_first":
53
+ u = x.mean(1, keepdim=True)
54
+ s = (x - u).pow(2).mean(1, keepdim=True)
55
+ x = (x - u) / torch.sqrt(s + self.eps)
56
+ if self.elementwise_affine:
57
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
58
+ return x
59
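A quick check of the docstring's claim that the two data formats compute the same normalization, just over different memory layouts (an editor's sketch, assuming the LayerNorm class above):

import torch

ln_first = LayerNorm(64, data_format="channels_first")
ln_last = LayerNorm(64, data_format="channels_last")
x = torch.randn(2, 64, 56, 56)                               # (N, C, H, W)
y_first = ln_first(x)
y_last = ln_last(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)  # round-trip through (N, H, W, C)
print(torch.allclose(y_first, y_last, atol=1e-5))            # True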
+
60
+
61
+ class ConvNextBlock(nn.Module):
62
+ r""" ConvNeXt Block. There are two equivalent implementations:
63
+ (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
64
+ (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
65
+ We use (2) as we find it slightly faster in PyTorch
66
+
67
+ Args:
68
+ dim (int): Number of input channels.
69
+ drop_path (float): Stochastic depth rate. Default: 0.0
70
+ layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
71
+ """
72
+ def __init__(self, in_channel, hidden_dim, out_channel, kernel_size=3, layer_scale_init_value=1e-6, drop_path= 0.0):
73
+ super().__init__()
74
+ self.dwconv = nn.Conv2d(in_channel, in_channel, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, groups=in_channel) # depthwise conv
75
+ self.norm = nn.LayerNorm(in_channel, eps=1e-6)
76
+ self.pwconv1 = nn.Linear(in_channel, hidden_dim) # pointwise/1x1 convs, implemented with linear layers
77
+ self.act = nn.GELU()
78
+ self.pwconv2 = nn.Linear(hidden_dim, out_channel)
79
+ self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((out_channel)),
80
+ requires_grad=True) if layer_scale_init_value > 0 else None
81
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
82
+
83
+ def forward(self, x):
84
+ input = x
85
+ x = self.dwconv(x)
86
+ x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
87
+ x = self.norm(x)
88
+ x = self.pwconv1(x)
89
+ x = self.act(x)
90
+ # print(f"x min: {x.min()}, x max: {x.max()}, input min: {input.min()}, input max: {input.max()}, x mean: {x.mean()}, x var: {x.var()}, ratio: {torch.sum(x>8)/x.numel()}")
91
+ x = self.pwconv2(x)
92
+ if self.gamma is not None:
93
+ x = self.gamma * x
94
+ x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
95
+
96
+ x = input + self.drop_path(x)
97
+ return x
98
+
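One constraint worth noting in the block above: the residual `input + self.drop_path(x)` requires out_channel == in_channel. A quick shape check (a sketch with random weights):

import torch

block = ConvNextBlock(in_channel=96, hidden_dim=384, out_channel=96, kernel_size=7)
x = torch.randn(1, 96, 56, 56)
print(block(x).shape)  # torch.Size([1, 96, 56, 56]); spatial size preserved by the padded dwconv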
99
+ class Decoder(nn.Module):
100
+ def __init__(self, depth=[2,2,2,2], dim=[112, 72, 40, 24], block_type = None, kernel_size = 3) -> None:
101
+ super().__init__()
102
+ self.depth = depth
103
+ self.dim = dim
104
+ self.block_type = block_type
105
+ self._build_decode_layer(dim, depth, kernel_size)
106
+ self.projback = nn.Sequential(
107
+ nn.Conv2d(
108
+ in_channels=dim[-1],
109
+ out_channels=4 ** 2 * 3, kernel_size=1),
110
+ nn.PixelShuffle(4),
111
+ )
112
+
113
+ def _build_decode_layer(self, dim, depth, kernel_size):
114
+ normal_layers = nn.ModuleList()
115
+ upsample_layers = nn.ModuleList()
116
+ proj_layers = nn.ModuleList()
117
+
118
+ norm_layer = LayerNorm
119
+
120
+ for i in range(1, len(dim)):
121
+ module = [self.block_type(dim[i], dim[i], dim[i], kernel_size) for _ in range(depth[i])]
122
+ normal_layers.append(nn.Sequential(*module))
123
+ upsample_layers.append(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True))
124
+ proj_layers.append(nn.Sequential(
125
+ nn.Conv2d(dim[i-1], dim[i], 1, 1),
126
+ norm_layer(dim[i]),
127
+ nn.GELU()
128
+ ))
129
+ self.normal_layers = normal_layers
130
+ self.upsample_layers = upsample_layers
131
+ self.proj_layers = proj_layers
132
+
133
+ def _forward_stage(self, stage, x):
134
+ x = self.proj_layers[stage](x)
135
+ x = self.upsample_layers[stage](x)
136
+ return self.normal_layers[stage](x)
137
+
138
+ def forward(self, c3):
139
+ x = self._forward_stage(0, c3) #14
140
+ x = self._forward_stage(1, x) #28
141
+ x = self._forward_stage(2, x) #56
142
+ x = self.projback(x)
143
+ return x
144
+
145
+ class SimDecoder(nn.Module):
146
+ def __init__(self, in_channel, encoder_stride) -> None:
147
+ super().__init__()
148
+ self.projback = nn.Sequential(
149
+ LayerNorm(in_channel),
150
+ nn.Conv2d(
151
+ in_channels=in_channel,
152
+ out_channels=encoder_stride ** 2 * 3, kernel_size=1),
153
+ nn.PixelShuffle(encoder_stride),
154
+ )
155
+
156
+ def forward(self, c3):
157
+ return self.projback(c3)
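Both decoders end in a 1x1 conv plus PixelShuffle that maps a stride-s feature map back to image resolution. For SimDecoder with encoder_stride=32 (an illustrative sketch):

import torch

dec = SimDecoder(in_channel=512, encoder_stride=32)
c3 = torch.randn(2, 512, 7, 7)  # a 224x224 input after 32x downsampling
print(dec(c3).shape)            # torch.Size([2, 3, 224, 224])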
training/models/revcol.py ADDED
@@ -0,0 +1,242 @@
1
+ # --------------------------------------------------------
2
+ # Reversible Column Networks
3
+ # Copyright (c) 2022 Megvii Inc.
4
+ # Licensed under The Apache License 2.0 [see LICENSE for details]
5
+ # Written by Yuxuan Cai
6
+ # --------------------------------------------------------
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from models.modules import ConvNextBlock, Decoder, LayerNorm, SimDecoder, UpSampleConvnext
11
+ import torch.distributed as dist
12
+ from models.revcol_function import ReverseFunction
13
+ from timm.models.layers import trunc_normal_
14
+
15
+ class Fusion(nn.Module):
16
+ def __init__(self, level, channels, first_col) -> None:
17
+ super().__init__()
18
+
19
+ self.level = level
20
+ self.first_col = first_col
21
+ self.down = nn.Sequential(
22
+ nn.Conv2d(channels[level-1], channels[level], kernel_size=2, stride=2),
23
+ LayerNorm(channels[level], eps=1e-6, data_format="channels_first"),
24
+ ) if level in [1, 2, 3] else nn.Identity()
25
+ if not first_col:
26
+ self.up = UpSampleConvnext(1, channels[level+1], channels[level]) if level in [0, 1, 2] else nn.Identity()
27
+
28
+ def forward(self, *args):
29
+
30
+ c_down, c_up = args
31
+
32
+ if self.first_col:
33
+ x = self.down(c_down)
34
+ return x
35
+
36
+ if self.level == 3:
37
+ x = self.down(c_down)
38
+ else:
39
+ x = self.up(c_up) + self.down(c_down)
40
+ return x
41
+
42
+ class Level(nn.Module):
43
+ def __init__(self, level, channels, layers, kernel_size, first_col, dp_rate=0.0) -> None:
44
+ super().__init__()
45
+ countlayer = sum(layers[:level])
46
+ expansion = 4
47
+ self.fusion = Fusion(level, channels, first_col)
48
+ modules = [ConvNextBlock(channels[level], expansion*channels[level], channels[level], kernel_size = kernel_size, layer_scale_init_value=1e-6, drop_path=dp_rate[countlayer+i]) for i in range(layers[level])]
49
+ self.blocks = nn.Sequential(*modules)
50
+ def forward(self, *args):
51
+ x = self.fusion(*args)
52
+ x = self.blocks(x)
53
+ return x
54
+
55
+ class SubNet(nn.Module):
56
+ def __init__(self, channels, layers, kernel_size, first_col, dp_rates, save_memory) -> None:
57
+ super().__init__()
58
+ shortcut_scale_init_value = 0.5
59
+ self.save_memory = save_memory
60
+ self.alpha0 = nn.Parameter(shortcut_scale_init_value * torch.ones((1, channels[0], 1, 1)),
61
+ requires_grad=True) if shortcut_scale_init_value > 0 else None
62
+ self.alpha1 = nn.Parameter(shortcut_scale_init_value * torch.ones((1, channels[1], 1, 1)),
63
+ requires_grad=True) if shortcut_scale_init_value > 0 else None
64
+ self.alpha2 = nn.Parameter(shortcut_scale_init_value * torch.ones((1, channels[2], 1, 1)),
65
+ requires_grad=True) if shortcut_scale_init_value > 0 else None
66
+ self.alpha3 = nn.Parameter(shortcut_scale_init_value * torch.ones((1, channels[3], 1, 1)),
67
+ requires_grad=True) if shortcut_scale_init_value > 0 else None
68
+
69
+ self.level0 = Level(0, channels, layers, kernel_size, first_col, dp_rates)
70
+
71
+ self.level1 = Level(1, channels, layers, kernel_size, first_col, dp_rates)
72
+
73
+ self.level2 = Level(2, channels, layers, kernel_size,first_col, dp_rates)
74
+
75
+ self.level3 = Level(3, channels, layers, kernel_size, first_col, dp_rates)
76
+
77
+ def _forward_nonreverse(self, *args):
78
+ x, c0, c1, c2, c3= args
79
+
80
+ c0 = (self.alpha0)*c0 + self.level0(x, c1)
81
+ c1 = (self.alpha1)*c1 + self.level1(c0, c2)
82
+ c2 = (self.alpha2)*c2 + self.level2(c1, c3)
83
+ c3 = (self.alpha3)*c3 + self.level3(c2, None)
84
+ return c0, c1, c2, c3
85
+
86
+ def _forward_reverse(self, *args):
87
+
88
+ local_funs = [self.level0, self.level1, self.level2, self.level3]
89
+ alpha = [self.alpha0, self.alpha1, self.alpha2, self.alpha3]
90
+ _, c0, c1, c2, c3 = ReverseFunction.apply(
91
+ local_funs, alpha, *args)
92
+
93
+ return c0, c1, c2, c3
94
+
95
+ def forward(self, *args):
96
+
97
+ self._clamp_abs(self.alpha0.data, 1e-3)
98
+ self._clamp_abs(self.alpha1.data, 1e-3)
99
+ self._clamp_abs(self.alpha2.data, 1e-3)
100
+ self._clamp_abs(self.alpha3.data, 1e-3)
101
+
102
+ if self.save_memory:
103
+ return self._forward_reverse(*args)
104
+ else:
105
+ return self._forward_nonreverse(*args)
106
+
107
+ def _clamp_abs(self, data, value):
108
+ with torch.no_grad():
109
+ sign=data.sign()
110
+ data.abs_().clamp_(value)
111
+ data*=sign
112
+
113
+
114
+ class Classifier(nn.Module):
115
+ def __init__(self, in_channels, num_classes):
116
+ super().__init__()
117
+
118
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
119
+ self.classifier = nn.Sequential(
120
+ nn.LayerNorm(in_channels, eps=1e-6), # final norm layer
121
+ nn.Linear(in_channels, num_classes),
122
+ )
123
+
124
+ def forward(self, x):
125
+ x = self.avgpool(x)
126
+ x = x.view(x.size(0), -1)
127
+ x = self.classifier(x)
128
+ return x
129
+
130
+ class FullNet(nn.Module):
131
+ def __init__(self, channels=[32, 64, 96, 128], layers=[2, 3, 6, 3], num_subnet=5, kernel_size = 3, num_classes=1000, drop_path = 0.0, save_memory=True, inter_supv=True, head_init_scale=None) -> None:
132
+ super().__init__()
133
+ self.num_subnet = num_subnet
134
+ self.inter_supv = inter_supv
135
+ self.channels = channels
136
+ self.layers = layers
137
+
138
+ self.stem = nn.Sequential(
139
+ nn.Conv2d(3, channels[0], kernel_size=4, stride=4),
140
+ LayerNorm(channels[0], eps=1e-6, data_format="channels_first")
141
+ )
142
+
143
+ dp_rate = [x.item() for x in torch.linspace(0, drop_path, sum(layers))]
144
+ for i in range(num_subnet):
145
+ first_col = True if i == 0 else False
146
+ self.add_module(f'subnet{str(i)}', SubNet(
147
+ channels,layers, kernel_size, first_col, dp_rates=dp_rate, save_memory=save_memory))
148
+
149
+ if not inter_supv:
150
+ self.cls = Classifier(in_channels=channels[-1], num_classes=num_classes)
151
+ else:
152
+ self.cls_blocks = nn.ModuleList([Classifier(in_channels=channels[-1], num_classes=num_classes) for _ in range(4) ])
153
+ if num_classes<=1000:
154
+ channels.reverse()
155
+ self.decoder_blocks = nn.ModuleList([Decoder(depth=[1,1,1,1], dim=channels, block_type=ConvNextBlock, kernel_size = 3) for _ in range(3) ])
156
+ else:
157
+ self.decoder_blocks = nn.ModuleList([SimDecoder(in_channel=channels[-1], encoder_stride=32) for _ in range(3) ])
158
+
159
+ self.apply(self._init_weights)
160
+
161
+ if head_init_scale:
162
+ print(f'Head_init_scale: {head_init_scale}')
163
+ self.cls.classifier._modules['1'].weight.data.mul_(head_init_scale)
164
+ self.cls.classifier._modules['1'].bias.data.mul_(head_init_scale)
165
+
166
+
167
+ def forward(self, x):
168
+
169
+ if self.inter_supv:
170
+ return self._forward_intermediate_supervision(x)
171
+ else:
172
+ c0, c1, c2, c3 = 0, 0, 0, 0
173
+ x = self.stem(x)
174
+ for i in range(self.num_subnet):
175
+ c0, c1, c2, c3 = getattr(self, f'subnet{str(i)}')(x, c0, c1, c2, c3)
176
+ return [self.cls(c3)], None
177
+
178
+ def _forward_intermediate_supervision(self, x):
179
+ x_cls_out = []
180
+ x_img_out = []
181
+ c0, c1, c2, c3 = 0, 0, 0, 0
182
+ interval = self.num_subnet//4
183
+
184
+ x = self.stem(x)
185
+ for i in range(self.num_subnet):
186
+ c0, c1, c2, c3 = getattr(self, f'subnet{str(i)}')(x, c0, c1, c2, c3)
187
+ if (i+1) % interval == 0:
188
+ x_cls_out.append(self.cls_blocks[i//interval](c3))
189
+ if i != self.num_subnet-1:
190
+ x_img_out.append(self.decoder_blocks[i//interval](c3))
191
+
192
+ return x_cls_out, x_img_out
193
+
194
+
195
+ def _init_weights(self, module):
196
+ if isinstance(module, nn.Conv2d):
197
+ trunc_normal_(module.weight, std=.02)
198
+ nn.init.constant_(module.bias, 0)
199
+ elif isinstance(module, nn.Linear):
200
+ trunc_normal_(module.weight, std=.02)
201
+ nn.init.constant_(module.bias, 0)
202
+
203
+ ##-------------------------------------- Tiny -----------------------------------------
204
+
205
+ def revcol_tiny(save_memory, inter_supv=True, drop_path=0.1, num_classes=1000, kernel_size = 3):
206
+ channels = [64, 128, 256, 512]
207
+ layers = [2, 2, 4, 2]
208
+ num_subnet = 4
209
+ return FullNet(channels, layers, num_subnet, num_classes=num_classes, drop_path = drop_path, save_memory=save_memory, inter_supv=inter_supv, kernel_size=kernel_size)
210
+
211
+ ##-------------------------------------- Small -----------------------------------------
212
+
213
+ def revcol_small(save_memory, inter_supv=True, drop_path=0.3, num_classes=1000, kernel_size = 3):
214
+ channels = [64, 128, 256, 512]
215
+ layers = [2, 2, 4, 2]
216
+ num_subnet = 8
217
+ return FullNet(channels, layers, num_subnet, num_classes=num_classes, drop_path = drop_path, save_memory=save_memory, inter_supv=inter_supv, kernel_size=kernel_size)
218
+
219
+ ##-------------------------------------- Base -----------------------------------------
220
+
221
+ def revcol_base(save_memory, inter_supv=True, drop_path=0.4, num_classes=1000, kernel_size = 3, head_init_scale=None):
222
+ channels = [72, 144, 288, 576]
223
+ layers = [1, 1, 3, 2]
224
+ num_subnet = 16
225
+ return FullNet(channels, layers, num_subnet, num_classes=num_classes, drop_path = drop_path, save_memory=save_memory, inter_supv=inter_supv, head_init_scale=head_init_scale, kernel_size=kernel_size)
226
+
227
+
228
+ ##-------------------------------------- Large -----------------------------------------
229
+
230
+ def revcol_large(save_memory, inter_supv=True, drop_path=0.5, num_classes=1000, kernel_size = 3, head_init_scale=None):
231
+ channels = [128, 256, 512, 1024]
232
+ layers = [1, 2, 6, 2]
233
+ num_subnet = 8
234
+ return FullNet(channels, layers, num_subnet, num_classes=num_classes, drop_path = drop_path, save_memory=save_memory, inter_supv=inter_supv, head_init_scale=head_init_scale, kernel_size=kernel_size)
235
+
236
+ ##--------------------------------------Extra-Large -----------------------------------------
237
+ def revcol_xlarge(save_memory, inter_supv=True, drop_path=0.5, num_classes=1000, kernel_size = 3, head_init_scale=None):
238
+ channels = [224, 448, 896, 1792]
239
+ layers = [1, 2, 6, 2]
240
+ num_subnet = 8
241
+ return FullNet(channels, layers, num_subnet, num_classes=num_classes, drop_path = drop_path, save_memory=save_memory, inter_supv=inter_supv, head_init_scale=head_init_scale, kernel_size=kernel_size)
242
+
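A smoke test for the factory functions above (random weights; the counts follow from num_subnet=4 columns with intermediate supervision, and the shapes from the 4x patchify stem plus the decoders):

import torch

model = revcol_tiny(save_memory=False, inter_supv=True, drop_path=0.0).eval()
x = torch.randn(1, 3, 224, 224)
cls_outs, img_outs = model(x)
print(len(cls_outs), cls_outs[-1].shape)  # 4 classifier heads, each torch.Size([1, 1000])
print(len(img_outs), img_outs[0].shape)   # 3 reconstructions, each torch.Size([1, 3, 224, 224])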
training/models/revcol_function.py ADDED
@@ -0,0 +1,159 @@
1
+ # --------------------------------------------------------
2
+ # Reversible Column Networks
3
+ # Copyright (c) 2022 Megvii Inc.
4
+ # Licensed under The Apache License 2.0 [see LICENSE for details]
5
+ # Written by Yuxuan Cai
6
+ # --------------------------------------------------------
7
+
8
+ import torch
9
+ from typing import Any, Iterable, List, Tuple, Callable
10
+ import torch.distributed as dist
11
+
12
+ def get_gpu_states(fwd_gpu_devices) -> Tuple[List[int], List[torch.Tensor]]:
13
+ # Record the CUDA RNG state of every device used in the forward pass so that
14
+ # backward() can replay stochastic ops (e.g. DropPath) identically during recomputation.
15
+ fwd_gpu_states = []
16
+ for device in fwd_gpu_devices:
17
+ with torch.cuda.device(device):
18
+ fwd_gpu_states.append(torch.cuda.get_rng_state())
19
+
20
+ return fwd_gpu_states
21
+
22
+ def get_gpu_device(*args):
23
+
24
+ fwd_gpu_devices = list(set(arg.get_device() for arg in args
25
+ if isinstance(arg, torch.Tensor) and arg.is_cuda))
26
+ return fwd_gpu_devices
27
+
28
+ def set_device_states(fwd_cpu_state, devices, states) -> None:
29
+ torch.set_rng_state(fwd_cpu_state)
30
+ for device, state in zip(devices, states):
31
+ with torch.cuda.device(device):
32
+ torch.cuda.set_rng_state(state)
33
+
34
+ def detach_and_grad(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]:
35
+ if isinstance(inputs, tuple):
36
+ out = []
37
+ for inp in inputs:
38
+ if not isinstance(inp, torch.Tensor):
39
+ out.append(inp)
40
+ continue
41
+
42
+ x = inp.detach()
43
+ x.requires_grad = True
44
+ out.append(x)
45
+ return tuple(out)
46
+ else:
47
+ raise RuntimeError(
48
+ "Only tuple of tensors is supported. Got Unsupported input type: ", type(inputs).__name__)
49
+
50
+ def get_cpu_and_gpu_states(gpu_devices):
51
+ return torch.get_rng_state(), get_gpu_states(gpu_devices)
52
+
53
+ class ReverseFunction(torch.autograd.Function):
54
+ @staticmethod
55
+ def forward(ctx, run_functions, alpha, *args):
56
+ l0, l1, l2, l3 = run_functions
57
+ alpha0, alpha1, alpha2, alpha3 = alpha
58
+ ctx.run_functions = run_functions
59
+ ctx.alpha = alpha
60
+ ctx.preserve_rng_state = True
61
+
62
+ ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
63
+ "dtype": torch.get_autocast_gpu_dtype(),
64
+ "cache_enabled": torch.is_autocast_cache_enabled()}
65
+ ctx.cpu_autocast_kwargs = {"enabled": torch.is_autocast_cpu_enabled(),
66
+ "dtype": torch.get_autocast_cpu_dtype(),
67
+ "cache_enabled": torch.is_autocast_cache_enabled()}
68
+
69
+ assert len(args) == 5
70
+ [x, c0, c1, c2, c3] = args
71
+ if isinstance(c0, int):
72
+ ctx.first_col = True
73
+ else:
74
+ ctx.first_col = False
75
+ with torch.no_grad():
76
+ gpu_devices = get_gpu_device(*args)
77
+ ctx.gpu_devices = gpu_devices
78
+ ctx.cpu_states_0, ctx.gpu_states_0 = get_cpu_and_gpu_states(gpu_devices)
79
+ c0 = l0(x, c1) + c0*alpha0
80
+ ctx.cpu_states_1, ctx.gpu_states_1 = get_cpu_and_gpu_states(gpu_devices)
81
+ c1 = l1(c0, c2) + c1*alpha1
82
+ ctx.cpu_states_2, ctx.gpu_states_2 = get_cpu_and_gpu_states(gpu_devices)
83
+ c2 = l2(c1, c3) + c2*alpha2
84
+ ctx.cpu_states_3, ctx.gpu_states_3 = get_cpu_and_gpu_states(gpu_devices)
85
+ c3 = l3(c2, None) + c3*alpha3
86
+ ctx.save_for_backward(x, c0, c1, c2, c3)
87
+ return x, c0, c1, c2, c3
88
+
89
+ @staticmethod
90
+ def backward(ctx, *grad_outputs):
91
+ x, c0, c1, c2, c3 = ctx.saved_tensors
92
+ l0, l1, l2, l3 = ctx.run_functions
93
+ alpha0, alpha1, alpha2, alpha3 = ctx.alpha
94
+ gx_right, g0_right, g1_right, g2_right, g3_right = grad_outputs
95
+ (x, c0, c1, c2, c3) = detach_and_grad((x, c0, c1, c2, c3))
96
+
97
+ with torch.enable_grad(), \
98
+ torch.random.fork_rng(devices=ctx.gpu_devices, enabled=ctx.preserve_rng_state), \
99
+ torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs), \
100
+ torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
101
+
102
+ g3_up = g3_right
103
+ g3_left = g3_up*alpha3 ##shortcut
104
+ set_device_states(ctx.cpu_states_3, ctx.gpu_devices, ctx.gpu_states_3)
105
+ oup3 = l3(c2, None)
106
+ torch.autograd.backward(oup3, g3_up, retain_graph=True)
107
+ with torch.no_grad():
108
+ c3_left = (1/alpha3)*(c3 - oup3) ## feature reverse
109
+ g2_up = g2_right+ c2.grad
110
+ g2_left = g2_up*alpha2 ##shortcut
111
+
112
+ (c3_left,) = detach_and_grad((c3_left,))
113
+ set_device_states(ctx.cpu_states_2, ctx.gpu_devices, ctx.gpu_states_2)
114
+ oup2 = l2(c1, c3_left)
115
+ torch.autograd.backward(oup2, g2_up, retain_graph=True)
116
+ c3_left.requires_grad = False
117
+ cout3 = c3_left*alpha3 ##alpha3 update
118
+ torch.autograd.backward(cout3, g3_up)
119
+
120
+ with torch.no_grad():
121
+ c2_left = (1/alpha2)*(c2 - oup2) ## feature reverse
122
+ g3_left = g3_left + c3_left.grad if c3_left.grad is not None else g3_left
123
+ g1_up = g1_right+c1.grad
124
+ g1_left = g1_up*alpha1 ##shortcut
125
+
126
+ (c2_left,) = detach_and_grad((c2_left,))
127
+ set_device_states(ctx.cpu_states_1, ctx.gpu_devices, ctx.gpu_states_1)
128
+ oup1 = l1(c0, c2_left)
129
+ torch.autograd.backward(oup1, g1_up, retain_graph=True)
130
+ c2_left.requires_grad = False
131
+ cout2 = c2_left*alpha2 ##alpha2 update
132
+ torch.autograd.backward(cout2, g2_up)
133
+
134
+ with torch.no_grad():
135
+ c1_left = (1/alpha1)*(c1 - oup1) ## feature reverse
136
+ g0_up = g0_right + c0.grad
137
+ g0_left = g0_up*alpha0 ##shortcut
138
+ g2_left = g2_left + c2_left.grad if c2_left.grad is not None else g2_left ## Fusion
139
+
140
+ (c1_left,) = detach_and_grad((c1_left,))
141
+ set_device_states(ctx.cpu_states_0, ctx.gpu_devices, ctx.gpu_states_0)
142
+ oup0 = l0(x, c1_left)
143
+ torch.autograd.backward(oup0, g0_up, retain_graph=True)
144
+ c1_left.requires_grad = False
145
+ cout1 = c1_left*alpha1 ##alpha1 update
146
+ torch.autograd.backward(cout1, g1_up)
147
+
148
+ with torch.no_grad():
149
+ c0_left = (1/alpha0)*(c0 - oup0) ## feature reverse
150
+ gx_up = x.grad ## Fusion
151
+ g1_left = g1_left + c1_left.grad if c1_left.grad is not None else g1_left ## Fusion
152
+ c0_left.requires_grad = False
153
+ cout0 = c0_left*alpha0 ##alpha0 update
154
+ torch.autograd.backward(cout0, g0_up)
155
+
156
+ if ctx.first_col:
157
+ return None, None, gx_up, None, None, None, None
158
+ else:
159
+ return None, None, gx_up, g0_left, g1_left, g2_left, g3_left
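Since ReverseFunction recomputes the same forward math rather than changing it, toggling save_memory should affect the memory profile, not the numbers. A quick agreement check under that assumption (revcol_tiny from models/revcol.py; identical seeds give identical weights):

import torch
from models.revcol import revcol_tiny

torch.manual_seed(0)
m_rev = revcol_tiny(save_memory=True, inter_supv=False, drop_path=0.0).eval()
torch.manual_seed(0)
m_plain = revcol_tiny(save_memory=False, inter_supv=False, drop_path=0.0).eval()
x = torch.randn(2, 3, 224, 224)
(y_rev,), _ = m_rev(x)
(y_plain,), _ = m_plain(x)
print(torch.allclose(y_rev, y_plain, atol=1e-5))  # same forward; m_rev trades compute for memory in backward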
training/optimizer.py ADDED
@@ -0,0 +1,145 @@
1
+ # --------------------------------------------------------
2
+ # Reversible Column Networks
3
+ # Copyright (c) 2022 Megvii Inc.
4
+ # Licensed under The Apache License 2.0 [see LICENSE for details]
5
+ # Written by Yuxuan Cai
6
+ # --------------------------------------------------------
7
+
8
+ import numpy as np
9
+ from torch import optim as optim
10
+
11
+ def build_optimizer(config, model):
12
+ """
13
+ Build optimizer, set weight decay of normalization to 0 by default.
14
+ """
15
+ skip = {}
16
+ skip_keywords = {}
17
+ if hasattr(model, 'no_weight_decay'):
18
+ skip = model.no_weight_decay()
19
+ if hasattr(model, 'no_weight_decay_keywords'):
20
+ skip_keywords = model.no_weight_decay_keywords()
21
+
22
+ if config.MODEL.TYPE.startswith("revcol"):
23
+ parameters = param_groups_lrd(model, weight_decay=config.TRAIN.WEIGHT_DECAY, no_weight_decay_list=[], layer_decay=config.TRAIN.OPTIMIZER.LAYER_DECAY)
24
+ else:
25
+ parameters = set_weight_decay(model, skip, skip_keywords)
26
+
27
+
28
+ opt_lower = config.TRAIN.OPTIMIZER.NAME.lower()
29
+ optimizer = None
30
+ if opt_lower == 'sgd':
31
+ optimizer = optim.SGD(parameters, momentum=config.TRAIN.OPTIMIZER.MOMENTUM, nesterov=True,
32
+ lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY)
33
+ elif opt_lower == 'adamw':
34
+ optimizer = optim.AdamW(parameters, eps=config.TRAIN.OPTIMIZER.EPS, betas=config.TRAIN.OPTIMIZER.BETAS,
35
+ lr=config.TRAIN.BASE_LR)
+ else:
+ raise NotImplementedError(f"Unknown optimizer: {opt_lower}")
36
+
37
+ return optimizer
38
+
39
+
40
+ def set_weight_decay(model, skip_list=(), skip_keywords=()):
41
+ has_decay = []
42
+ no_decay = []
43
+
44
+ for name, param in model.named_parameters():
45
+ if not param.requires_grad or name in ["linear_eval.weight", "linear_eval.bias"]:
46
+ continue # frozen weights
47
+ if len(param.shape) == 1 or name.endswith(".bias") or (name in skip_list) or \
48
+ check_keywords_in_name(name, skip_keywords):
49
+ no_decay.append(param)
50
+ # print(f"{name} has no weight decay")
51
+ else:
52
+ has_decay.append(param)
53
+ return [{'params': has_decay},
54
+ {'params': no_decay, 'weight_decay': 0.}]
55
+
56
+
57
+ def check_keywords_in_name(name, keywords=()):
58
+ isin = False
59
+ for keyword in keywords:
60
+ if keyword in name:
61
+ isin = True
62
+ return isin
63
+
64
+ def cal_model_depth(columns, layers):
65
+ depth = sum(layers)
66
+ dp = np.zeros((depth, columns))
67
+ dp[:,0]=np.linspace(0, depth-1, depth)
68
+ dp[0,:]=np.linspace(0, columns-1, columns)
69
+ for i in range(1, depth):
70
+ for j in range(1, columns):
71
+ dp[i][j] = min(dp[i][j-1], dp[i-1][j])+1
72
+ dp = dp.astype(int)
73
+ return dp
74
+
75
+
76
+ def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75):
77
+ """
78
+ Parameter groups for layer-wise lr decay
79
+ Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58
80
+ """
81
+ param_group_names = {}
82
+ param_groups = {}
83
+ dp = cal_model_depth(model.num_subnet, model.layers)+1
84
+ num_layers = dp[-1][-1] + 1
85
+
86
+ layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1))
87
+
88
+ for n, p in model.named_parameters():
89
+ if not p.requires_grad:
90
+ continue
91
+
92
+ # no decay: all 1D parameters and model specific ones
93
+ if p.ndim == 1 or n in no_weight_decay_list:# or re.match('(.*).alpha.$', n):
94
+ g_decay = "no_decay"
95
+ this_decay = 0.
96
+ else:
97
+ g_decay = "decay"
98
+ this_decay = weight_decay
99
+
100
+ layer_id = get_layer_id(n, dp, model.layers)
101
+ group_name = "layer_%d_%s" % (layer_id, g_decay)
102
+
103
+ if group_name not in param_group_names:
104
+ this_scale = layer_scales[layer_id]
105
+
106
+ param_group_names[group_name] = {
107
+ "lr_scale": this_scale,
108
+ "weight_decay": this_decay,
109
+ "params": [],
110
+ }
111
+ param_groups[group_name] = {
112
+ "lr_scale": this_scale,
113
+ "weight_decay": this_decay,
114
+ "params": [],
115
+ }
116
+
117
+ param_group_names[group_name]["params"].append(n)
118
+ param_groups[group_name]["params"].append(p)
119
+
120
+ # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2))
121
+
122
+ return list(param_groups.values())
123
+
124
+ def get_layer_id(n, dp, layers):
125
+ if n.startswith("subnet"):
126
+ name_part = n.split('.')
127
+ subnet = int(name_part[0][6:])
128
+ if name_part[1].startswith("alpha"):
129
+ id = dp[0][subnet]
130
+ else:
131
+ level = int(name_part[1][-1])
132
+ if name_part[2].startswith("blocks"):
133
+ sub = int(name_part[3])
134
+ if sub>layers[level]-1:
135
+ sub = layers[level]-1
136
+ block = sum(layers[:level])+sub
137
+
138
+ if name_part[2].startswith("fusion"):
139
+ block = sum(layers[:level])
140
+ id = dp[block][subnet]
141
+ elif n.startswith("stem"):
142
+ id = 0
143
+ else:
144
+ id = dp[-1][-1]+1
145
+ return id
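Concretely, cal_model_depth fills a (sum(layers), columns) grid where dp[i][j] == i + j, so a parameter's layer id grows with both its block depth and its column index, and layer_decay ** (num_layers - id) therefore decays earlier, shallower parameters more. For the tiny model (assuming the functions above are importable):

dp = cal_model_depth(columns=4, layers=[2, 2, 4, 2])
print(dp.shape)  # (10, 4)
print(dp[0])     # [0 1 2 3]      -- first block across the four columns
print(dp[-1])    # [ 9 10 11 12]  -- deepest block across the four columns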
training/requirements.txt ADDED
@@ -0,0 +1,7 @@
1
+ fvcore==0.1.5.post20211023
2
+ numpy==1.20.3
3
+ opencv_python==4.4.0.46
4
+ termcolor==1.1.0
5
+ timm==0.5.4
6
+ yacs==0.1.8
7
+ tensorboard
training/utils.py ADDED
@@ -0,0 +1,179 @@
1
+ # --------------------------------------------------------
2
+ # Reversible Column Networks
3
+ # Copyright (c) 2022 Megvii Inc.
4
+ # Licensed under The Apache License 2.0 [see LICENSE for details]
5
+ # Written by Yuxuan Cai
6
+ # --------------------------------------------------------
7
+
8
+ import io
9
+ import os
10
+ import re
11
+ from typing import List
12
+ from timm.utils.model_ema import ModelEma
13
+ import torch
14
+ import torch.distributed as dist
15
+ from timm.utils import get_state_dict
16
+ import subprocess
17
+
18
+
19
+
20
+
21
+ def load_checkpoint(config, model, optimizer, logger, model_ema=None):
22
+ logger.info(f"==============> Resuming form {config.MODEL.RESUME}....................")
23
+ if config.MODEL.RESUME.startswith('https'):
24
+ checkpoint = torch.hub.load_state_dict_from_url(
25
+ config.MODEL.RESUME, map_location='cpu', check_hash=True)
26
+ else:
27
+ checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu')
28
+ logger.info("Already loaded checkpoint to memory..")
29
+ msg = model.load_state_dict(checkpoint['model'], strict=False)
30
+ logger.info(msg)
31
+ max_accuracy = 0.0
32
+ if config.MODEL_EMA:
33
+ if 'state_dict_ema' in checkpoint.keys():
34
+ model_ema.ema.load_state_dict(checkpoint['state_dict_ema'], strict=False)
35
+
36
+ logger.info("Loaded state_dict_ema")
37
+ else:
38
+ model_ema.ema.load_state_dict(checkpoint['model'], strict=False)
39
+ logger.warning("Failed to find state_dict_ema, starting from loaded model weights")
40
+
41
+ if not config.EVAL_MODE and 'optimizer' in checkpoint and 'epoch' in checkpoint:
42
+ optimizer.load_state_dict(checkpoint['optimizer'])
43
+ config.defrost()
44
+ config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1
45
+ config.freeze()
46
+
47
+ logger.info(f"=> loaded successfully '{config.MODEL.RESUME}' (epoch {checkpoint['epoch']})")
48
+ if 'max_accuracy' in checkpoint:
49
+ max_accuracy = checkpoint['max_accuracy']
50
+ # del checkpoint
51
+ # torch.cuda.empty_cache()
52
+ return max_accuracy
53
+
54
+ def load_checkpoint_finetune(config, model, logger, model_ema=None):
55
+ logger.info(f"==============> Finetune {config.MODEL.FINETUNE}....................")
56
+ checkpoint = torch.load(config.MODEL.FINETUNE, map_location='cpu')['model']
57
+ converted_weights = {}
58
+ keys = list(checkpoint.keys())
59
+ for key in keys:
60
+ if re.match(r'cls.*', key):
61
+ # if re.match(r'cls.classifier.1.*', key):
62
+ print(f'key: {key} is used for pretrain, discarded.')
63
+ continue
64
+ else:
65
+ converted_weights[key] = checkpoint[key]
66
+ msg = model.load_state_dict(converted_weights, strict=False)
67
+ logger.info(msg)
68
+ if model_ema is not None:
69
+ ema_msg = model_ema.ema.load_state_dict(converted_weights, strict=False)
70
+ logger.info(f"==============> Loaded Pretraind statedict into EMA....................")
71
+ logger.info(ema_msg)
72
+ del checkpoint
73
+ torch.cuda.empty_cache()
74
+
75
+
76
+ def save_checkpoint(config, epoch, model, epoch_accuracy, max_accuracy, optimizer, logger, model_ema=None):
77
+ if model_ema is not None:
78
+ logger.info("Model EMA is not None...")
79
+ save_state = {'model': model.state_dict(),
80
+ 'optimizer': optimizer.state_dict(),
81
+ 'max_accuracy': max(max_accuracy, epoch_accuracy),
82
+ 'epoch': epoch,
83
+ 'state_dict_ema': get_state_dict(model_ema),
84
+ 'input': input,
86
+ else:
87
+ save_state = {'model': model.state_dict(),
88
+ 'optimizer': optimizer.state_dict(),
89
+ 'max_accuracy': max(max_accuracy, epoch_accuracy),
90
+ 'epoch': epoch,
91
+ 'state_dict_ema': None,
93
+ 'config': config}
94
+
95
+ save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth')
96
+ best_path = os.path.join(config.OUTPUT, 'best.pth')
97
+
98
+ logger.info(f"{save_path} saving......")
99
+ torch.save(save_state, save_path)
100
+ if epoch_accuracy > max_accuracy:
101
+ torch.save(save_state, best_path)
102
+ logger.info(f"{save_path} saved !!!")
103
+
104
+
105
+ def get_grad_norm(parameters, norm_type=2):
106
+ if isinstance(parameters, torch.Tensor):
107
+ parameters = [parameters]
108
+ parameters = list(filter(lambda p: p.grad is not None, parameters))
109
+ norm_type = float(norm_type)
110
+ total_norm = 0
111
+ for p in parameters:
112
+ param_norm = p.grad.data.norm(norm_type)
113
+ total_norm += param_norm.item() ** norm_type
114
+ total_norm = total_norm ** (1. / norm_type)
115
+ return total_norm
116
+
117
+
118
+ def auto_resume_helper(output_dir,logger):
119
+ checkpoints = os.listdir(output_dir)
120
+ checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith('pth') and ckpt.startswith('ckpt_')]
121
+ logger.info(f"All checkpoints founded in {output_dir}: {checkpoints}")
122
+ if len(checkpoints) > 0:
123
+ latest_checkpoint = max([os.path.join(output_dir, d) for d in checkpoints], key=os.path.getmtime)
124
+ logger.info(f"The latest checkpoint founded: {latest_checkpoint}")
125
+ resume_file = latest_checkpoint
126
+ else:
127
+ resume_file = None
128
+ return resume_file
129
+
130
+
131
+ def reduce_tensor(tensor):
132
+ rt = tensor.clone()
133
+ dist.all_reduce(rt, op=dist.ReduceOp.SUM)
134
+ rt /= dist.get_world_size()
135
+ return rt
136
+
137
+ def denormalize(tensor: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False) -> torch.Tensor:
138
+ """Denormalize a float tensor image with mean and standard deviation.
139
+ This transform does not support PIL Image.
140
+
141
+ .. note::
142
+ This transform acts out of place by default, i.e., it does not mutate the input tensor.
143
+
144
+ See :class:`~torchvision.transforms.Normalize` for more details.
145
+
146
+ Args:
147
+ tensor (Tensor): Float tensor image of size (C, H, W) or (B, C, H, W) to be denormalized.
148
+ mean (sequence): Sequence of means for each channel.
149
+ std (sequence): Sequence of standard deviations for each channel.
150
+ inplace(bool,optional): Bool to make this operation inplace.
151
+
152
+ Returns:
153
+ Tensor: Denormalized Tensor image.
154
+ """
155
+ if not isinstance(tensor, torch.Tensor):
156
+ raise TypeError('Input tensor should be a torch tensor. Got {}.'.format(type(tensor)))
157
+
158
+ if not tensor.is_floating_point():
159
+ raise TypeError('Input tensor should be a float tensor. Got {}.'.format(tensor.dtype))
160
+
161
+ if tensor.ndim < 3:
162
+ raise ValueError('Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = '
163
+ '{}.'.format(tensor.size()))
164
+
165
+ if not inplace:
166
+ tensor = tensor.clone()
167
+
168
+ dtype = tensor.dtype
169
+ mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
170
+ std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
171
+ if (std == 0).any():
172
+ raise ValueError('std evaluated to zero after conversion to {}, leading to division by zero.'.format(dtype))
173
+ if mean.ndim == 1:
174
+ mean = mean.view(-1, 1, 1)
175
+ if std.ndim == 1:
176
+ std = std.view(-1, 1, 1)
177
+ tensor.mul_(std).add_(mean).clip_(0.0, 1.0)
178
+ return tensor
179
+
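Typical use of denormalize, matching the call in train_one_epoch; IMAGENET_DEFAULT_MEAN and IMAGENET_DEFAULT_STD come from timm.data.constants. A self-contained sketch:

import torch
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

batch = torch.rand(4, 3, 224, 224)  # stands in for images from the data pipeline
mean = torch.tensor(IMAGENET_DEFAULT_MEAN).view(-1, 1, 1)
std = torch.tensor(IMAGENET_DEFAULT_STD).view(-1, 1, 1)
batch = (batch - mean) / std                                        # normalized, as the loader produces
imgs = denormalize(batch, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
print(imgs.min().item() >= 0.0, imgs.max().item() <= 1.0)           # True True: clipped back into [0, 1]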