Upload 13 files
- Dockerfile +62 -0
- LICENSE +202 -0
- README.md +180 -8
- main.py +814 -0
- pretraining.sh +231 -0
- src/RandAugment.py +506 -0
- src/dataset.py +367 -0
- src/loss.py +244 -0
- src/model.py +607 -0
- src/multicropdataset.py +445 -0
- src/optimizer.py +210 -0
- src/vision_transformer.py +491 -0
- utils.py +583 -0
Dockerfile
ADDED
@@ -0,0 +1,62 @@
FROM nvidia/cuda:11.0-base-ubuntu20.04

ENV DEBIAN_FRONTEND noninteractive

ENV CUDNN_VERSION=8.0.5.39-1+cuda11.1
ENV NCCL_VERSION=2.7.8-1+cuda11.1

ARG python=3.8
ENV PYTHON_VERSION=${python}

# Set default shell to /bin/bash
SHELL ["/bin/bash", "-cu"]

RUN apt-get update && apt-get install -y --allow-downgrades \
    --allow-change-held-packages --no-install-recommends \
    build-essential \
    cmake \
    git \
    curl \
    vim \
    wget \
    ca-certificates \
    libcudnn8=${CUDNN_VERSION} \
    libnccl2=${NCCL_VERSION} \
    libnccl-dev=${NCCL_VERSION} \
    libjpeg-dev \
    libpng-dev \
    python${PYTHON_VERSION} \
    python${PYTHON_VERSION}-dev \
    python${PYTHON_VERSION}-distutils \
    librdmacm1 \
    libibverbs1 \
    ibverbs-providers

RUN ln -s /usr/bin/python${PYTHON_VERSION} /usr/bin/python

RUN curl -O https://bootstrap.pypa.io/get-pip.py && \
    python get-pip.py && \
    rm get-pip.py

RUN /usr/bin/python -m pip install --upgrade pip

# Install pytorch
RUN pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio==0.7.2 \
    -f https://download.pytorch.org/whl/torch_stable.html

RUN pip install tensorboard==2.5.0
RUN pip install tensorboard-data-server==0.6.1
RUN pip install tensorboard-plugin-wit==1.8.0
RUN pip install tensorboardX==1.8

RUN pip install timm==0.4.5
RUN pip install opencv-contrib-python-headless==4.5.2.54
RUN pip install tqdm==4.61.2
RUN pip install PyYAML==5.4.1
RUN pip install Pillow==8.3.1
RUN pip install einops==0.3.0
RUN pip install scipy==1.7.1
LICENSE
ADDED
@@ -0,0 +1,202 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.

"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:

(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.

You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.

Copyright 2022 Garena Online Private Limited

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
README.md
CHANGED
@@ -1,8 +1,180 @@
# Mugs: A Multi-Granular Self-Supervised Learning Framework

This is a PyTorch implementation of **Mugs** proposed by our paper "**Mugs: A Multi-Granular Self-Supervised Learning Framework**". [![arXiv](https://img.shields.io/badge/arXiv-2203.14415-b31b1b.svg?style=flat)](http://arxiv.org/abs/2203.14415)

[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/mugs-a-multi-granular-self-supervised/self-supervised-image-classification-on)](https://paperswithcode.com/sota/self-supervised-image-classification-on?p=mugs-a-multi-granular-self-supervised)

<div align="center">
  <img width="100%" alt="Overall framework of Mugs." src="./exp_illustration/framework.png">
</div>

**<p align="center">Fig 1. Overall framework of Mugs.** In (a), two random crops of each image
are fed into the backbones of the student and teacher. Three granular supervisions, 1) instance discrimination supervision, 2) local-group discrimination
supervision, and 3) group discrimination supervision, are adopted to learn multi-granular representations. In (b), the local-group modules in the
student/teacher average all patch tokens, then find the top-k neighbors from a memory buffer and aggregate them with that average to obtain a local-group feature.</p>

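The local-group branch in Fig 1(b) is easy to misread from the figure alone. Below is a minimal, self-contained sketch of the aggregation described in the caption (average the patch tokens, retrieve the top-k nearest neighbors from a memory buffer, and average them together with the query); it is an illustration, not the repository's implementation, and the tensor shapes and buffer layout are assumptions.

```
import torch
import torch.nn.functional as F

def local_group_feature(patch_tokens, memory_bank, k=8):
    """Toy version of the local-group aggregation in Fig 1(b).

    patch_tokens: (N, D) patch tokens of one image from the backbone.
    memory_bank:  (M, D) average-pooled token features of earlier images.
    Returns a single (D,) local-group feature.
    """
    query = patch_tokens.mean(dim=0)              # average all patch tokens
    q = F.normalize(query, dim=0)
    m = F.normalize(memory_bank, dim=1)
    sims = m @ q                                  # cosine similarity to every stored feature
    topk = memory_bank[sims.topk(k).indices]      # top-k nearest neighbors
    # aggregate the query together with its neighbors
    return torch.cat([query.unsqueeze(0), topk], dim=0).mean(dim=0)

# toy usage with ViT-S/16-like shapes (196 patch tokens, 384-dim features)
feat = local_group_feature(torch.randn(196, 384), torch.randn(1024, 384), k=8)
```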
# Pretrained models on ImageNet-1K

You can choose to download only the weights of the pretrained backbone used for downstream tasks, or the full checkpoint which contains backbone and projection head weights for both the student and teacher networks.
**<p align="center">Table 1. KNN and linear probing performance with their corresponding hyper-parameters, logs and model weights.</p>**
<table>
  <tr>
    <th>arch</th>
    <th>params</th>
    <th>pretraining epochs</th>
    <th>k-nn</th>
    <th>linear</th>
    <th colspan="6">download</th>
  </tr>
  <tr>
    <td>ViT-S/16</td>
    <td>21M</td>
    <td>100</td>
    <td>72.3%</td>
    <td>76.4%</td>
    <td><a href="https://drive.google.com/file/d/1V2TyArzr7qY93UFglPBHRfYVyAMEfsHR/view?usp=sharing">backbone only</a></td>
    <td><a href="https://drive.google.com/file/d/1AePcCeUEhK0nb9syQKufqqnhpUEr9Rji/view?usp=sharing">full ckpt</a></td>
    <td><a href="https://drive.google.com/file/d/17phHQx88f4_xSqkPtIYvUoUE2U-1tovg/view?usp=sharing">args</a></td>
    <td><a href="https://drive.google.com/file/d/1UBMTB-C3BnNKT5939fhSstHc9H30Vizd/view?usp=sharing">logs</a></td>
    <td><a href="https://drive.google.com/file/d/1MkXctkgqEXjWWRs4Cz5CyTTx_IHDOP4G/view?usp=sharing">eval logs</a></td>
  </tr>
  <tr>
    <td>ViT-S/16</td>
    <td>21M</td>
    <td>300</td>
    <td>74.8%</td>
    <td>78.2%</td>
    <td><a href="https://drive.google.com/file/d/1ZAPQ0HiDZO5Uk7jVqF46H6VbGxunZkuf/view?usp=sharing">backbone only</a></td>
    <td><a href="https://drive.google.com/file/d/1EO-_kYlAt23qgFYZF2u-KLks5js9LvrZ/view?usp=sharing">full ckpt</a></td>
    <td><a href="https://drive.google.com/file/d/1b6zLZ3r_mZbk17SvhJIZF2VCoYVbJUnU/view?usp=sharing">args</a></td>
    <td><a href="https://drive.google.com/file/d/1L7VzH1rztoraBCBNVWL-Y8k7Y8PFU773/view?usp=sharing">logs</a></td>
    <td><a href="https://drive.google.com/file/d/1KgnX8ReXIVsu65_-p7NWPH8S0HEDPMUU/view?usp=sharing">eval logs</a></td>
  </tr>
  <tr>
    <td>ViT-S/16</td>
    <td>21M</td>
    <td>800</td>
    <td>75.6%</td>
    <td>78.9%</td>
    <td><a href="https://drive.google.com/file/d/1KMdhxxWc2JXAiFqVxX584V4RvlJgckGq/view?usp=sharing">backbone only</a></td>
    <td><a href="https://drive.google.com/file/d/1FBaOt0Rjxm6yyJadttOyN6hSh8ueZ0dh/view?usp=sharing">full ckpt</a></td>
    <td><a href="https://drive.google.com/file/d/19Ma-eSIgdwLoBg6wBXeFiW46zCI2EHvH/view?usp=sharing">args</a></td>
    <td><a href="https://drive.google.com/file/d/1wX4AUO5NBVZUb8jN1iGBRkS17sszb4_O/view?usp=sharing">logs</a></td>
    <td><a href="https://drive.google.com/file/d/12tiO4glWZNB044TYiPPCfbnUX_9AbqVc/view?usp=sharing">eval logs</a></td>
  </tr>
  <tr>
    <td>ViT-B/16</td>
    <td>85M</td>
    <td>400</td>
    <td>78.0%</td>
    <td>80.6%</td>
    <td><a href="https://drive.google.com/file/d/13NUziwToBXBmS7n7V_1Z5N6EG_7bcncW/view?usp=sharing">backbone only</a></td>
    <td><a href="https://drive.google.com/file/d/1M41TVVFyVRDTK5kbgLCEImrxw0AVtebb/view?usp=sharing">full ckpt</a></td>
    <td><a href="https://drive.google.com/file/d/1-5fB5ZCVQAfxTXZ6ro56AVkhb3whpaJc/view?usp=sharing">args</a></td>
    <td><a href="https://drive.google.com/file/d/11RlCx6eViRnFD6gBlr_lOOxOhu-L6l6D/view?usp=sharing">logs</a></td>
    <td><a href="https://drive.google.com/file/d/1gOR250QFLZfe40pLNPcOqaLPAnKLuE_C/view?usp=sharing">eval logs</a></td>
  </tr>
  <tr>
    <td>ViT-L/16</td>
    <td>307M</td>
    <td>250</td>
    <td>80.3%</td>
    <td>82.1%</td>
    <td><a href="https://drive.google.com/file/d/1K76a-YnFYcmDXUZ_UlYVYFrWOt2a6733/view?usp=sharing">backbone only</a></td>
    <td><a href="https://drive.google.com/file/d/1Q5Ukvucx44YawyOhMEAY13Ppb8OOWOAB/view?usp=sharing">full ckpt</a></td>
    <td><a href="https://drive.google.com/file/d/1p8XhaA2_Zbejm__UT8iNKG8r5tzS9c6c/view?usp=sharing">args</a></td>
    <td><a href="https://drive.google.com/file/d/1JLVcUNfkyBI0BcMm7OpNU_3KTxIABK0Z/view?usp=sharing">logs</a></td>
    <td><a href="https://drive.google.com/file/d/1rqWenRFN0czat_55GY9GNOu7gS6fww3g/view?usp=sharing">eval logs</a></td>
  </tr>
</table>

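The "backbone only" files in Table 1 are plain PyTorch checkpoints. As a rough sketch (not the repository's own loading code), one way to load such a file into an existing ViT module is shown below; the checkpoint key layout and any "module."/"backbone." prefixes are assumptions you should verify against the downloaded file.

```
import torch

def load_pretrained_backbone(model, ckpt_path):
    """Load a downloaded backbone checkpoint into `model` (your ViT-S/B/L module)."""
    state = torch.load(ckpt_path, map_location="cpu")
    # some checkpoints wrap the weights in a sub-dictionary; unwrap if present
    for key in ("state_dict", "model", "teacher", "student"):
        if isinstance(state, dict) and key in state:
            state = state[key]
            break
    # strip common prefixes left by DistributedDataParallel / wrapper modules
    state = {k.replace("module.", "").replace("backbone.", ""): v for k, v in state.items()}
    missing, unexpected = model.load_state_dict(state, strict=False)
    print("missing keys:", missing)
    print("unexpected keys:", unexpected)
    return model
```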
<div align="center">
  <img width="100%" alt="Comparison of linear probing accuracy on ImageNet-1K." src="./exp_illustration/comparison.png">
</div>

**<p align="center">Fig 2. Comparison of linear probing accuracy on ImageNet-1K.**</p>

## Pretraining Settings

### Environment
For reproducing, please install [PyTorch](https://pytorch.org/) and download the [ImageNet](https://imagenet.stanford.edu/) dataset.
This codebase has been developed with Python 3.8, PyTorch 1.7.1, CUDA 11.0 and torchvision 0.8.2. For the full
environment, please refer to our `Dockerfile`.

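To quickly confirm that a local environment matches the versions listed above, a small check like the following can help (a convenience sketch, not part of the repository):

```
import sys
import torch
import torchvision

print("python     :", sys.version.split()[0])       # developed with 3.8
print("torch      :", torch.__version__)            # expected: 1.7.1
print("torchvision:", torchvision.__version__)      # expected: 0.8.2
print("CUDA build :", torch.version.cuda)           # expected: 11.0
print("GPU found  :", torch.cuda.is_available())
```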
### ViT pretraining :beer:
To pretrain each model, please find the exact hyper-parameter settings in the `args` column of [Table 1](https://github.com/sail-sg/mugs). For the training log and linear probing log, please refer to the
`logs` and `eval logs` columns of [Table 1](https://github.com/sail-sg/mugs).

#### ViT-Small pretraining:
To run ViT-Small for 100 epochs, we use two nodes with a total of 8 A100 GPUs (total batch size 512) and the following command:
```
python -m torch.distributed.launch --nproc_per_node=8 main.py --data_path DATASET_ROOT --output_dir OUTPUT_ROOT --arch vit_small
--group_teacher_temp 0.04 --group_warmup_teacher_temp_epochs 0 --weight_decay_end 0.2 --norm_last_layer false --epochs 100
```
To run ViT-Small for 300 epochs, we use two nodes with a total of 16 A100 GPUs (total batch size 1024) and the following command:
```
python -m torch.distributed.launch --nproc_per_node=16 main.py --data_path DATASET_ROOT --output_dir OUTPUT_ROOT --arch vit_small
--group_teacher_temp 0.07 --group_warmup_teacher_temp_epochs 30 --weight_decay_end 0.1 --norm_last_layer false --epochs 300
```
To run ViT-Small for 800 epochs, we use two nodes with a total of 16 A100 GPUs (total batch size 1024) and the following command:
```
python -m torch.distributed.launch --nproc_per_node=16 main.py --data_path DATASET_ROOT --output_dir OUTPUT_ROOT --arch vit_small
--group_teacher_temp 0.07 --group_warmup_teacher_temp_epochs 30 --weight_decay_end 0.1 --norm_last_layer false --epochs 800
```

#### ViT-Base pretraining:
To run ViT-Base for 400 epochs, we use two nodes with a total of 24 A100 GPUs (total batch size 1024) and the following command:
```
python -m torch.distributed.launch --nproc_per_node=24 main.py --data_path DATASET_ROOT --output_dir OUTPUT_ROOT --arch vit_base
--group_teacher_temp 0.07 --group_warmup_teacher_temp_epochs 50 --min_lr 2e-06 --weight_decay_end 0.1 --freeze_last_layer 3 --norm_last_layer
false --epochs 400
```

#### ViT-Large pretraining:
To run ViT-Large for 250 epochs, we use two nodes with a total of 40 A100 GPUs (total batch size 640) and the following command:
```
python -m torch.distributed.launch --nproc_per_node=40 main.py --data_path DATASET_ROOT --output_dir OUTPUT_ROOT --arch vit_large
--lr 0.0015 --min_lr 1.5e-4 --group_teacher_temp 0.07 --group_warmup_teacher_temp_epochs 50 --weight_decay 0.025
--weight_decay_end 0.08 --norm_last_layer true --drop_path_rate 0.3 --freeze_last_layer 3 --epochs 250
```

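The commands above differ mainly in node/GPU count and hence in total batch size. In `main.py`, `--lr` is documented as being linearly scaled with the batch size relative to a reference batch size of 256, so the effective peak learning rate can be estimated as below (a worked example of that stated rule; the actual scaling is applied inside the training code):

```
def effective_lr(base_lr, batch_size_per_gpu, num_gpus, reference=256):
    """Linear learning-rate scaling rule: lr * total_batch_size / reference."""
    return base_lr * batch_size_per_gpu * num_gpus / reference

# e.g. the ViT-S/16 100-epoch run: 8 GPUs x 64 images/GPU = 512 total batch size
print(effective_lr(0.0008, 64, 8))  # 0.0016
```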
## Evaluation
We are cleaning up the evaluation code and will release it when it is ready.

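Until the official evaluation scripts are released, the k-NN numbers in Table 1 can be approximated with a standard weighted k-NN classifier on frozen features (a generic sketch in the spirit of DINO-style evaluation, not the repository's code; feature extraction and the choice of `k` are up to you):

```
import torch

@torch.no_grad()
def knn_predict(train_feats, train_labels, test_feats, k=20, T=0.07):
    """Cosine-similarity weighted k-NN vote on L2-normalized features."""
    train_labels = train_labels.long()
    train_feats = torch.nn.functional.normalize(train_feats, dim=1)
    test_feats = torch.nn.functional.normalize(test_feats, dim=1)
    sims = test_feats @ train_feats.T                 # (N_test, N_train) cosine similarities
    topk_sims, topk_idx = sims.topk(k, dim=1)
    topk_labels = train_labels[topk_idx]              # (N_test, k) labels of nearest neighbors
    weights = (topk_sims / T).exp()                   # temperature-scaled similarity weights
    num_classes = int(train_labels.max()) + 1
    votes = torch.zeros(test_feats.size(0), num_classes)
    votes.scatter_add_(1, topk_labels, weights)       # accumulate weighted votes per class
    return votes.argmax(dim=1)
```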
## Self-attention visualization
Here we provide the self-attention maps of the [CLS] token from the heads of the last layer.
<div align="center">
  <img width="100%" alt="Self-attention from a ViT-Base/16 trained with Mugs" src="./exp_illustration/attention_vis.png">
</div>

**<p align="center">Fig 3. Self-attention from a ViT-Base/16 trained with Mugs.**</p>


## T-SNE visualization
Here we provide the T-SNE visualization of the features learned by ViT-B/16.
We show the fish classes in ImageNet-1K, i.e., the first six classes,
including tench, goldfish, white shark, tiger shark, hammerhead, and electric
ray. See more examples in the Appendix.
<div align="center">
  <img width="100%" alt="T-SNE visualization of the features learned by ViT-B/16." src="./exp_illustration/TSNE.png">
</div>

**<p align="center">Fig 4. T-SNE visualization of the features learned by ViT-B/16.**</p>

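Figure 4 can be reproduced approximately with scikit-learn once features have been extracted, e.g. for the first six ImageNet classes (a plotting sketch assuming `feats` and `labels` hold the extracted ViT-B/16 features and their class ids; scikit-learn and matplotlib are not installed by the Dockerfile above):

```
import matplotlib.pyplot as plt
import torch
from sklearn.manifold import TSNE

# feats: (N, D) tensor of backbone features, labels: (N,) class ids in [0, 5]
feats = torch.randn(600, 768)            # placeholder; use real extracted features
labels = torch.randint(0, 6, (600,))

emb = TSNE(n_components=2, init="pca", random_state=0).fit_transform(feats.numpy())
plt.scatter(emb[:, 0], emb[:, 1], c=labels.numpy(), cmap="tab10", s=5)
plt.title("t-SNE of ViT-B/16 features (first six ImageNet classes)")
plt.savefig("tsne_fish_classes.png", dpi=200)
```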
## License
This repository is released under the Apache 2.0 license as found in the [LICENSE](LICENSE) file.

## Citation
If you find this repository useful, please consider giving a star :star: and citation :beer::
```
@inproceedings{mugs2022SSL,
  title={Mugs: A Multi-Granular Self-Supervised Learning Framework},
  author={Pan Zhou and Yichen Zhou and Chenyang Si and Weihao Yu and Teck Khim Ng and Shuicheng Yan},
  booktitle={arXiv preprint arXiv:2203.14415},
  year={2022}
}
```
main.py
ADDED
@@ -0,0 +1,814 @@
# Copyright 2022 Garena Online Private Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Mugs training code
"""
import argparse
import datetime
import json
import math
import os
import sys
import time
from collections import OrderedDict
from pathlib import Path

import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
from torchvision import models as torchvision_models

import utils
from src.loss import get_multi_granular_loss
from src.model import get_model
from src.multicropdataset import data_prefetcher, get_dataset
from src.optimizer import cancel_gradients_last_layer, get_optimizer, clip_gradients

torchvision_archs = sorted(
    name
    for name in torchvision_models.__dict__
    if name.islower()
    and not name.startswith("__")
    and callable(torchvision_models.__dict__[name])
)


def get_args_parser():
    parser = argparse.ArgumentParser("Mugs", add_help=False)

    ##======== Model parameters ============
    parser.add_argument(
        "--arch",
        type=str,
        default="vit_small",
        choices=["vit_small", "vit_base", "vit_large"],
        help="""Name of architecture to train.""",
    )
    parser.add_argument(
        "--patch_size",
        type=int,
        default=16,
        help="""Size in pixels
        of input square patches - default 16 (for 16x16 patches). Using smaller
        values leads to better performance but requires more memory. Applies only
        for ViTs (vit_small and vit_base). If <16, we recommend disabling
        mixed precision training (--use_fp16 false) to avoid instabilities.""",
    )

    ##======== Training/Optimization parameters ============
    parser.add_argument(
        "--momentum_teacher",
        type=float,
        default=0.996,
        help="""Base EMA
        parameter for teacher update. The value is increased to 1 during training with
        a cosine schedule. We recommend setting a higher value with small batches: for
        example use 0.9995 with a batch size of 256.""",
    )
    parser.add_argument(
        "--use_fp16",
        type=utils.bool_flag,
        default=False,
        help="""Whether or not
        to use half precision for training. Improves training time and memory requirements,
        but can provoke instability and slight decay of performance. We recommend disabling
        mixed precision if the loss is unstable, if reducing the patch size or if training
        with bigger ViTs.""",
    )
    parser.add_argument(
        "--weight_decay",
        type=float,
        default=0.04,
        help="""Initial value of the
        weight decay. With ViT, a smaller value at the beginning of training works well.""",
    )
    parser.add_argument(
        "--weight_decay_end",
        type=float,
        default=0.2,
        help="""Final value of the
        weight decay. We use a cosine schedule for WD and using a larger decay by
        the end of training improves performance for ViTs.""",
    )
    parser.add_argument(
        "--clip_grad",
        type=float,
        default=3.0,
        help="""Maximal parameter
        gradient norm if using gradient clipping. Clipping with norm .3 ~ 1.0 can
        help optimization for larger ViT architectures. 0 for disabling.""",
    )
    parser.add_argument(
        "--batch_size_per_gpu",
        type=int,
        default=64,
        help="Per-GPU batch-size : number of distinct images loaded on one GPU.",
    )
    parser.add_argument(
        "--epochs", type=int, default=100, help="Number of epochs of training."
    )
    parser.add_argument(
        "--warmup_epochs",
        default=10,
        type=int,
        help="""Number of epochs for the linear learning-rate warm-up.""",
    )
    parser.add_argument(
        "--freeze_last_layer",
        type=int,
        default=1,
        help="""Number of epochs during
        which we keep the output layer fixed for the group supervision loss. Typically doing so during
        the first epoch helps training. Try increasing this value if the loss does not decrease.""",
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=0.0008,
        help="""Learning rate at the end of
        linear warmup (highest LR used during training). The learning rate is linearly scaled
        with the batch size, and specified here for a reference batch size of 256.""",
    )
    parser.add_argument(
        "--patch_embed_lr_mult",
        type=float,
        default=0.2,
        help="""For the patch
        embedding layer, its learning rate is lr * patch_embed_lr_mult (<1.0) in most cases, which
        stabilizes training and also slightly improves the performance.""",
    )
    parser.add_argument(
        "--min_lr",
        type=float,
        default=1e-6,
        help="""Target LR at the
        end of optimization. We use a cosine LR schedule with linear warmup.""",
    )
    parser.add_argument(
        "--optimizer",
        type=str,
        default="adamw",
        choices=["adamw", "sgd", "lars"],
        help="""Type of optimizer. We recommend using adamw
        with ViTs.""",
    )
    parser.add_argument(
        "--drop_path_rate", type=float, default=0.1, help="""stochastic depth rate"""
    )

    ##======== Multi-granular supervisions (instance/local-group/group supervisions) ==========
    parser.add_argument(
        "--loss_weights",
        type=float,
        nargs="+",
        default=[1.0, 1.0, 1.0],
        help="""three loss weights for instance, local-group, group supervision losses in turn""",
    )

    parser.add_argument(
        "--use_bn_in_head",
        type=utils.bool_flag,
        default=False,
        help="Whether to use batch normalizations in the three projection heads (Default: False)",
    )
    parser.add_argument(
        "--norm_before_pred",
        type=utils.bool_flag,
        default=True,
        help="""Whether to use batch normalizations after the projection heads (namely before the
        prediction heads) in instance and local-group supervisions. (Default: True)""",
    )

    # parameters for instance discrimination supervision
    parser.add_argument(
        "--instance_out_dim",
        type=int,
        default=256,
        help="""output dimension of the projection and prediction heads.""",
    )
    parser.add_argument(
        "--instance_queue_size",
        type=int,
        default=65536,
        help="""the queue size of the memory to store the negative keys.""",
    )
    parser.add_argument(
        "--instance_temp",
        type=float,
        default=0.2,
        help="""the temperature parameter for the infoNCE loss in instance supervision.""",
    )

    # parameters for local-group discrimination supervision
    parser.add_argument(
        "--local_group_out_dim",
        type=int,
        default=256,
        help="""output dimension of the projection and prediction heads.""",
    )
    parser.add_argument(
        "--local_group_knn_top_n",
        type=int,
        default=8,
        help="how many neighbors we use to aggregate for a local-group",
    )
    parser.add_argument(
        "--local_group_queue_size",
        type=int,
        default=65536,
        help="""the queue size of the memory to store the negative keys for the infoNCE loss and
        of another memory to store the weakly augmented samples for local-group aggregation.""",
    )
    parser.add_argument(
        "--local_group_temp",
        type=float,
        default=0.2,
        help="""the temperature parameter for the infoNCE loss in local-group supervision.""",
    )

    ## parameters for group discrimination supervision
    parser.add_argument(
        "--group_out_dim",
        type=int,
        default=65536,
        help="""output dimension of the prediction heads.""",
    )
    parser.add_argument(
        "--group_bottleneck_dim",
        type=float,
        default=256,
        help="""head bottleneck dimension of the prediction heads.""",
    )
    parser.add_argument(
        "--norm_last_layer",
        type=utils.bool_flag,
        default=True,
        help="""Whether or not to weight normalize the last layer of the group supervision head.
        Not normalizing leads to better performance but can make the training unstable. We
        typically set this parameter to False with vit_small and True with vit_base and vit_large.""",
    )

    parser.add_argument(
        "--group_student_temp",
        type=float,
        default=0.1,
        help="""the temperature parameter for the clustering loss on the student output.""",
    )
    parser.add_argument(
        "--group_warmup_teacher_temp",
        default=0.04,
        type=float,
        help="""Initial value for the teacher temperature: 0.04 works well in most cases.
        Try decreasing it if the training loss does not decrease.""",
    )
    parser.add_argument(
        "--group_teacher_temp",
        default=0.04,
        type=float,
        help="""Final value
        (after linear warmup) of the teacher temperature. For most experiments, anything above
        0.07 is unstable. We recommend starting with the default value of 0.04 and increasing
        this slightly if needed.""",
    )
    parser.add_argument(
        "--group_warmup_teacher_temp_epochs",
        default=0,
        type=int,
        help="""Number of warmup epochs for the teacher temperature (Default: 0).""",
    )

    ##======== augmentation parameters ============
    # Multi-crop parameters
    parser.add_argument(
        "--global_crops_scale",
        type=float,
        nargs="+",
        default=(0.25, 1.0),
        help="""Scale range of the cropped image before resizing, relative to the original image.
        Used for large global view cropping. When disabling multi-crop (--local_crops_number 0), we
        recommend using a wider range of scale ("--global_crops_scale 0.14 1." for example)""",
    )
    parser.add_argument(
        "--local_crops_number",
        type=int,
        default=10,
        help="""Number of small
        local views to generate. Set this parameter to 0 to disable multi-crop training.
        When disabling multi-crop we recommend to use "--global_crops_scale 0.14 1." """,
    )
    parser.add_argument(
        "--local_crops_scale",
        type=float,
        nargs="+",
        default=(0.05, 0.25),
        help="""Scale range of the cropped image before resizing, relative to the original image.
        Used for small local view cropping of multi-crop.""",
    )
    # strong augmentation parameters
    parser.add_argument(
        "--timm_auto_augment_par",
        type=str,
        default="rand-m9-mstd0.5-inc1",
        help="""the parameters for the AutoAugment used in DeiT.""",
    )
    parser.add_argument(
        "--color_aug",
        type=utils.bool_flag,
        default=False,
        help="""after AutoAugment, whether we further perform color augmentation. (Default: False).""",
    )
    parser.add_argument(
        "--size_crops",
        type=int,
        default=[96],
        nargs="+",
        help="""the small crop size. Note we use a multi-crop strategy, namely two 224-sized crops +
        ten 96-sized crops. (Default: 96)""",
    )
    parser.add_argument(
        "--strong_ratio",
        type=float,
        default=0.45,
        help="""the ratio of image augmentation for the AutoAugment used in DeiT.""",
    )
    parser.add_argument(
        "--re_prob",
        type=float,
        default=0.25,
        help="""the re-prob parameter of image augmentation for the AutoAugment used in DeiT.""",
    )
    parser.add_argument(
        "--vanilla_weak_augmentation",
        type=utils.bool_flag,
        default=False,
        help="""Whether we use the same augmentation as in DINO, namely only using weak augmentation.""",
    )
    parser.add_argument(
        "--prob",
        type=float,
        default=0.5,
        help="""When we use both strong and weak augmentation, the ratio of images to
        be cropped with strong augmentation.""",
    )

    ##======== Misc ============
    parser.add_argument(
        "--data_path",
        default="/dataset/imageNet100_sicy/train/",
        type=str,
        help="""Please specify the path to the ImageNet training data.""",
    )
    parser.add_argument(
        "--output_dir",
        default="./exp/",
        type=str,
        help="""Path to save logs and checkpoints.""",
    )
    parser.add_argument(
        "--saveckp_freq",
        default=50,
        type=int,
        help="""Save checkpoint every x epochs.""",
    )
    parser.add_argument("--seed", default=0, type=int, help="""Random seed.""")
    parser.add_argument(
        "--num_workers",
        default=12,
        type=int,
        help="""Number of data loading workers per GPU.""",
    )
    parser.add_argument(
        "--dist_url",
        default="env://",
        type=str,
        help="""url used to set up
        distributed training; see https://pytorch.org/docs/stable/distributed.html""",
    )
    parser.add_argument(
        "--local_rank",
        default=0,
        type=int,
        help="""local rank for distributed training.""",
    )
    parser.add_argument(
        "--rank", default=0, type=int, help="""rank for distributed training."""
    )
    parser.add_argument(
        "--world_size",
        default=1,
        type=int,
        help="""world size for distributed training.""",
    )

    parser.add_argument(
        "--use_prefetcher",
        type=utils.bool_flag,
        default=True,
        help="""whether we use a prefetcher, which can accelerate the training speed.""",
    )
    parser.add_argument(
        "--debug",
        type=utils.bool_flag,
        default=False,
        help="""whether we debug. If yes, we only load a small fraction of the training data to reduce data reading time.""",
    )
    parser.add_argument(
        "--ddpjob",
        default=False,
        type=utils.bool_flag,
        help="""whether we use a ddp job. We suggest using it for distributed training. For a single GPU
        or node, you can disable it.""",
    )

    return parser


def train_mugs(args):
    """
    main training code for Mugs, including building the dataloader, models, losses, optimizers, etc.
    """
    ##======== prepare logger for more detailed logs ============
    logger = utils.get_logger(args.output_dir + "/train.log")
    logger.info(args)
    if args.output_dir and utils.is_main_process():
        with (Path(args.output_dir) / "log.txt").open("a") as f:
            f.write(str(args) + "\n")

    ##======== initialize distributed training ============
    if args.ddpjob is True:
        utils.init_distributed_ddpjob(args)
    else:
        utils.init_distributed_mode(args)

    ##======== fix seed for reproducibility ============
    utils.fix_random_seeds(args.seed)
    print("git:\n {}\n".format(utils.get_sha()))
    print(
        "\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items()))
    )
    cudnn.benchmark = True
    cudnn.deterministic = True

    ##======== get the training dataset/loader ============
    data_loader = get_dataset(args)
    logger.info(f"Data loaded: there are {len(data_loader.dataset)} images.")

    ##====== build student and teacher networks (vit_small, vit_base, vit_large) =========
    student, teacher, student_mem, teacher_mem = get_model(args)

    # move networks to gpu
    student, teacher = student.cuda(), teacher.cuda()
    student_mem, teacher_mem = student_mem.cuda(), teacher_mem.cuda()

    # synchronize batch norms (if any)
    if utils.has_batchnorms(student):
        student = nn.SyncBatchNorm.convert_sync_batchnorm(student)
        teacher = nn.SyncBatchNorm.convert_sync_batchnorm(teacher)
        # we need a DDP wrapper to have synchronized batch norms working...
        teacher = nn.parallel.DistributedDataParallel(teacher, device_ids=[args.gpu])
        teacher_without_ddp = teacher.module
    else:
        # teacher_without_ddp and teacher are the same thing
        teacher_without_ddp = teacher
    student = nn.parallel.DistributedDataParallel(student, device_ids=[args.gpu])
    # teacher and student start with the same weights
    teacher_without_ddp.load_state_dict(student.module.state_dict(), strict=False)

    # there is no backpropagation through the teacher, so no need for gradients
    for p in teacher.parameters():
        p.requires_grad = False
    print(f"Student and Teacher are built: they are both {args.arch} networks.")

    ##======== get multi-granular losses and their loss weights ============
    all_losses, all_weights = get_multi_granular_loss(args)

    ##======== preparing optimizer ============
    optimizer, fp16_scaler, lr_schedule, wd_schedule, momentum_schedule = get_optimizer(
        student, len(data_loader), args
    )

    ##======== optionally resume training ============
    to_restore = {"epoch": 0}
    utils.restart_from_checkpoint(
        os.path.join(args.output_dir, "checkpoint.pth"),
        run_variables=to_restore,
        student=student,
        teacher=teacher,
        optimizer=optimizer,
        fp16_scaler=fp16_scaler,
        student_mem=student_mem,
        teacher_mem=teacher_mem,
        **all_losses,
    )
    start_epoch = to_restore["epoch"]

    ##======== Starting Mugs training ============
    logger.info("Starting Mugs training !")
    start_time = time.time()
    for epoch in range(start_epoch, args.epochs):
        t1 = time.time()
        data_loader.sampler.set_epoch(epoch)

        ##======== training one epoch of Mugs ============
        train_stats = train_one_epoch(
            student,
            teacher,
            teacher_without_ddp,
            all_losses,
            all_weights,
            data_loader,
            optimizer,
            lr_schedule,
            wd_schedule,
            momentum_schedule,
            epoch,
            fp16_scaler,
            student_mem,
            teacher_mem,
            logger,
            args,
        )

        ##======== save model checkpoint ============
        save_dict = {
            "student": student.state_dict(),
            "teacher": teacher.state_dict(),
            "student_mem": student_mem.state_dict()
            if student_mem is not None
            else None,
            "teacher_mem": teacher_mem.state_dict()
            if teacher_mem is not None
            else None,
            "optimizer": optimizer.state_dict(),
            "epoch": epoch + 1,
            "args": args,
        }
        granular_loss_dicts = {}
        for name, loss in all_losses.items():
            granular_loss_dicts[name] = loss.state_dict()
        save_dict.update(granular_loss_dicts)

        if fp16_scaler is not None:
            save_dict["fp16_scaler"] = fp16_scaler.state_dict()

        utils.save_on_master(save_dict, os.path.join(args.output_dir, "checkpoint.pth"))
        if args.saveckp_freq and epoch % args.saveckp_freq == 0:
            utils.save_on_master(
                save_dict, os.path.join(args.output_dir, f"checkpoint{epoch:04}.pth")
            )

        ##======== writing logs ============
        log_stats = {**{f"{k}": v for k, v in train_stats.items()}, "epoch": epoch}
        if utils.is_main_process():
            with (Path(args.output_dir) / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")

        t2 = time.time()
        log_results = ""
        for k, v in train_stats.items():
            log_results += "%s: %.6f, " % (k, v)
        logger.info(
            "%d-epoch: %s remaining time %.2f hours"
            % (epoch, log_results, (t2 - t1) * (args.epochs - epoch) / 3600.0)
        )

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info("Training time {}".format(total_time_str))


def train_one_epoch(
    student,
    teacher,
    teacher_without_ddp,
    all_losses,
    all_weights,
    data_loader,
    optimizer,
    lr_schedule,
    wd_schedule,
    momentum_schedule,
    epoch,
    fp16_scaler,
    student_mem,
    teacher_mem,
    logger,
    args,
):
    """
    main training code for each epoch
    """
    metric_logger = utils.MetricLogger(delimiter=" ")
    prefetcher = data_prefetcher(data_loader, fp16=(fp16_scaler is not None))
    images, weak_aug_flags = prefetcher.next()
    epoch_it = 0
    while images is not None:
        # Step 1. update weight decay and learning rate according to their schedules
        it = len(data_loader) * epoch + epoch_it  # global training iteration
        for _, param_group in enumerate(optimizer.param_groups):
            lr_mult = 1.0
            if "patch_embed" in param_group["name"]:
                lr_mult = args.patch_embed_lr_mult
            param_group["lr"] = lr_schedule[it] * lr_mult
            if param_group.get("apply_wd", True):  # only the first group is regularized
                param_group["weight_decay"] = wd_schedule[it]

        granular_losses = OrderedDict()
        total_loss = 0
        with torch.cuda.amp.autocast(fp16_scaler is not None):
            ## Step 2. forward images into teacher and student to obtain the
            # features/supervisions for the three granular supervision losses
            (
                teacher_instance_target,
                teacher_local_group_target,
                teacher_group_target,
                teacher_memory_tokens,
            ) = teacher(
                images[:2],
                return_target=True,
                local_group_memory_inputs={"mem": teacher_mem},
            )

            (
                student_instance_target,
                student_local_group_target,
                student_group_target,
                student_memory_tokens,
            ) = student(
                images[2:],
                return_target=False,
                local_group_memory_inputs={"mem": student_mem},
            )

            ## Step 3. compute the three granular supervision losses, including the instance,
            # local-group, and group supervision losses
            weigts_sum, total_loss, granular_losses = 0.0, 0.0, OrderedDict()
            # instance loss
            loss_cls, loss_weight = (
                all_losses["instance-sup."],
                all_weights["instance-sup."],
            )
            if loss_weight > 0:
                instance_loss = loss_cls(
                    student_instance_target, teacher_instance_target, epoch
                )
                weigts_sum, total_loss = (
                    weigts_sum + loss_weight,
                    total_loss + instance_loss,
                )
                granular_losses["instance-sup."] = instance_loss.item()

            # local group loss
            loss_cls, loss_weight = (
                all_losses["local-group-sup."],
                all_weights["local-group-sup."],
            )
            if loss_weight > 0:
                local_group_loss = loss_cls(
                    student_local_group_target, teacher_local_group_target, epoch
                )
                weigts_sum, total_loss = (
                    weigts_sum + loss_weight,
                    total_loss + local_group_loss,
                )
                granular_losses["local-group-sup."] = local_group_loss.item()

            # group loss
            loss_cls, loss_weight = all_losses["group-sup."], all_weights["group-sup."]
            if loss_weight > 0:
                group_loss = loss_cls(student_group_target, teacher_group_target, epoch)
                weigts_sum, total_loss = (
                    weigts_sum + loss_weight,
                    total_loss + group_loss,
                )
                granular_losses["group-sup."] = group_loss.item()

            # average loss
            total_loss /= weigts_sum

            ## Step 4. update the memory buffers for the local-group supervision loss.
            # for the student, we only update the memory with the 224-sized, weakly augmented images
            student_features = (student_memory_tokens.chunk(2))[0]
            len_weak = student_mem._dequeue_and_enqueue(
                student_features,
                weak_aug_flags,
            )

            teacher_weak = (teacher_memory_tokens.chunk(2))[0]
            _ = teacher_mem._dequeue_and_enqueue(teacher_weak, None)

        if not math.isfinite(total_loss.item()):
            print("Loss is {}, stopping training".format(total_loss.item()), force=True)
            sys.exit(1)

        ## Step 5. student and teacher update
        # student update
        optimizer.zero_grad()
        if fp16_scaler is None:
            total_loss.backward()
            if args.clip_grad:
                clip_grad = args.clip_grad
                if epoch > 100 and args.arch == "vit_large":
                    clip_grad = args.clip_grad / 10.0
                _ = clip_gradients(student, clip_grad)
            cancel_gradients_last_layer(epoch, student, args.freeze_last_layer)
            optimizer.step()
        else:
            fp16_scaler.scale(total_loss).backward()
            if args.clip_grad:
                clip_grad = args.clip_grad
                if epoch > 100 and args.arch == "vit_large":
                    clip_grad = args.clip_grad / 10.0
                fp16_scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
                _ = clip_gradients(student, clip_grad)
            cancel_gradients_last_layer(epoch, student, args.freeze_last_layer)
            fp16_scaler.step(optimizer)
            fp16_scaler.update()

        # EMA update for the teacher
        with torch.no_grad():
            m = momentum_schedule[it]  # momentum parameter
            for param_q, param_k in zip(
                student.module.backbone.parameters(),
                teacher_without_ddp.backbone.parameters(),
            ):
                param_k.data.mul_(m).add_((1 - m) * param_q.detach().data)

            if teacher_without_ddp.instance_head is not None:
                for param_q, param_k in zip(
                    student.module.instance_head.parameters(),
                    teacher_without_ddp.instance_head.parameters(),
                ):
                    param_k.data.mul_(m).add_((1 - m) * param_q.detach().data)

            if teacher_without_ddp.local_group_head is not None:
                for param_q, param_k in zip(
                    student.module.local_group_head.parameters(),
                    teacher_without_ddp.local_group_head.parameters(),
                ):
                    param_k.data.mul_(m).add_((1 - m) * param_q.detach().data)

            if teacher_without_ddp.group_head is not None:
                for param_q, param_k in zip(
                    student.module.group_head.parameters(),
                    teacher_without_ddp.group_head.parameters(),
                ):
                    param_k.data.mul_(m).add_((1 - m) * param_q.detach().data)

        ## Step 6. load the next batch of images
        images, weak_aug_flags = prefetcher.next()
|
771 |
+
epoch_it += 1
|
772 |
+
|
773 |
+
## Step 7. logging
|
774 |
+
torch.cuda.synchronize()
|
775 |
+
metric_logger.update(loss=total_loss.item())
|
776 |
+
for loss_name, loss_value in granular_losses.items():
|
777 |
+
metric_logger.update(**{loss_name: loss_value})
|
778 |
+
metric_logger.update(lr=optimizer.param_groups[0]["lr"])
|
779 |
+
metric_logger.update(wd=optimizer.param_groups[0]["weight_decay"])
|
780 |
+
|
781 |
+
if epoch_it % 500 == 0 and args.rank == 0: # and epoch_it < 10:
|
782 |
+
log_results = ""
|
783 |
+
for _, loss_name in enumerate(all_losses):
|
784 |
+
if all_weights[loss_name] > 0:
|
785 |
+
log_results += "%s: %.6f," % (
|
786 |
+
loss_name,
|
787 |
+
metric_logger.meters[loss_name].global_avg,
|
788 |
+
)
|
789 |
+
logger.info(
|
790 |
+
"%d-epoch (%d/%d): total loss %.6f, %s, lr %.4e, wd %.4e, weak aug. ratio %.1f"
|
791 |
+
% (
|
792 |
+
epoch,
|
793 |
+
it,
|
794 |
+
len(data_loader),
|
795 |
+
metric_logger.meters["loss"].global_avg,
|
796 |
+
log_results,
|
797 |
+
optimizer.param_groups[0]["lr"],
|
798 |
+
optimizer.param_groups[0]["weight_decay"],
|
799 |
+
len_weak / len(weak_aug_flags) / args.world_size,
|
800 |
+
)
|
801 |
+
)
|
802 |
+
|
803 |
+
# gather the stats from all processes
|
804 |
+
metric_logger.synchronize_between_processes()
|
805 |
+
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
|
806 |
+
|
807 |
+
|
808 |
+
if __name__ == "__main__":
|
809 |
+
parser = argparse.ArgumentParser("Mugs", parents=[get_args_parser()])
|
810 |
+
args = parser.parse_args()
|
811 |
+
if not os.path.exists(args.output_dir):
|
812 |
+
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
|
813 |
+
|
814 |
+
train_mugs(args)
|
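Note: train_one_epoch indexes lr_schedule, wd_schedule, and momentum_schedule by the global iteration `it`, so each is expected to be a per-iteration array of length epochs * len(data_loader). As a rough sketch only (the repo's utils.py presumably provides its own equivalent helper; the function name below is hypothetical), such an array can be built as a linear warmup followed by cosine decay:

# Illustrative sketch, not the repo's exact utility: builds a per-iteration array
# like the lr/wd/momentum schedules that train_one_epoch() indexes with `it`.
import numpy as np

def cosine_schedule(base_value, final_value, epochs, iters_per_epoch, warmup_epochs=0):
    warmup_iters = warmup_epochs * iters_per_epoch
    # linear warmup from 0 to base_value
    warmup = np.linspace(0, base_value, warmup_iters)
    # cosine decay from base_value down to final_value for the remaining iterations
    iters = np.arange(epochs * iters_per_epoch - warmup_iters)
    cosine = final_value + 0.5 * (base_value - final_value) * (
        1 + np.cos(np.pi * iters / len(iters))
    )
    return np.concatenate((warmup, cosine))

# e.g. lr_schedule = cosine_schedule(0.0008, 1e-6, epochs=100,
#                                    iters_per_epoch=len(data_loader), warmup_epochs=10)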
pretraining.sh
ADDED
@@ -0,0 +1,231 @@
#!/bin/bash
DATASET_ROOT=/dataset/imageNet100_sicy/train/ #/raid/common/imagenet-raw/

## train ViT-small for 100 epochs
OUTPUT_ROOT=./exps/vit_small_100ep
NPROC_PER_NODE=8 # GPU numbers
BATCH_SIZE_PER_GPU=64
DEBUG=false # debug = true, then we only load subset of the whole training dataset
python -m torch.distributed.launch --nproc_per_node=$NPROC_PER_NODE main.py \
--data_path $DATASET_ROOT \
--output_dir $OUTPUT_ROOT \
--arch vit_small \
--instance_queue_size 65536 \
--local_group_queue_size 65536 \
--use_bn_in_head false \
--instance_out_dim 256 \
--instance_temp 0.2 \
--local_group_out_dim 256 \
--local_group_temp 0.2 \
--local_group_knn_top_n 8 \
--group_out_dim 65536 \
--group_student_temp 0.1 \
--group_warmup_teacher_temp 0.04 \
--group_teacher_temp 0.04 \
--group_warmup_teacher_temp_epochs 0 \
--norm_last_layer false \
--norm_before_pred true \
--batch_size_per_gpu $BATCH_SIZE_PER_GPU \
--epochs 100 \
--warmup_epochs 10 \
--clip_grad 3.0 \
--lr 0.0008 \
--min_lr 1e-06 \
--patch_embed_lr_mult 0.2 \
--drop_path_rate 0.1 \
--weight_decay 0.04 \
--weight_decay_end 0.2 \
--freeze_last_layer 1 \
--momentum_teacher 0.996 \
--use_fp16 false \
--local_crops_number 10 \
--size_crops 96 \
--global_crops_scale 0.25 1 \
--local_crops_scale 0.05 0.25 \
--timm_auto_augment_par rand-m9-mstd0.5-inc1 \
--prob 0.5 \
--use_prefetcher true \
--debug $DEBUG

## train ViT-small for 300 epochs
OUTPUT_ROOT=./exps/vit_small_300ep
NPROC_PER_NODE=16 # GPU numbers
BATCH_SIZE_PER_GPU=64
DEBUG=false # debug = true, then we only load subset of the whole training dataset
python -m torch.distributed.launch --nproc_per_node=$NPROC_PER_NODE main.py \
--data_path $DATASET_ROOT \
--output_dir $OUTPUT_ROOT \
--arch vit_small \
--instance_queue_size 65536 \
--local_group_queue_size 65536 \
--use_bn_in_head false \
--instance_out_dim 256 \
--instance_temp 0.2 \
--local_group_out_dim 256 \
--local_group_temp 0.2 \
--local_group_knn_top_n 8 \
--group_out_dim 65536 \
--group_student_temp 0.1 \
--group_warmup_teacher_temp 0.04 \
--group_teacher_temp 0.07 \
--group_warmup_teacher_temp_epochs 30 \
--norm_last_layer false \
--norm_before_pred true \
--batch_size_per_gpu $BATCH_SIZE_PER_GPU \
--epochs 300 \
--warmup_epochs 10 \
--clip_grad 3.0 \
--lr 0.0008 \
--min_lr 1e-06 \
--patch_embed_lr_mult 0.2 \
--drop_path_rate 0.1 \
--weight_decay 0.04 \
--weight_decay_end 0.1 \
--freeze_last_layer 1 \
--momentum_teacher 0.996 \
--use_fp16 false \
--local_crops_number 10 \
--size_crops 96 \
--global_crops_scale 0.25 1 \
--local_crops_scale 0.05 0.25 \
--timm_auto_augment_par rand-m9-mstd0.5-inc1 \
--prob 0.5 \
--use_prefetcher true \
--debug $DEBUG

## train ViT-small for 800 epochs
NPROC_PER_NODE=16 # GPU numbers
BATCH_SIZE_PER_GPU=64
DEBUG=false # debug = true, then we only load subset of the whole training dataset
python -m torch.distributed.launch --nproc_per_node=$NPROC_PER_NODE main.py \
--data_path $DATASET_ROOT \
--output_dir $OUTPUT_ROOT \
--arch vit_small \
--instance_queue_size 65536 \
--local_group_queue_size 65536 \
--use_bn_in_head false \
--instance_out_dim 256 \
--instance_temp 0.2 \
--local_group_out_dim 256 \
--local_group_temp 0.2 \
--local_group_knn_top_n 8 \
--group_out_dim 65536 \
--group_student_temp 0.1 \
--group_warmup_teacher_temp 0.04 \
--group_teacher_temp 0.07 \
--group_warmup_teacher_temp_epochs 30 \
--norm_last_layer false \
--norm_before_pred true \
--batch_size_per_gpu $BATCH_SIZE_PER_GPU \
--epochs 800 \
--warmup_epochs 10 \
--clip_grad 3.0 \
--lr 0.0008 \
--min_lr 1e-06 \
--patch_embed_lr_mult 0.2 \
--drop_path_rate 0.1 \
--weight_decay 0.04 \
--weight_decay_end 0.1 \
--freeze_last_layer 1 \
--momentum_teacher 0.996 \
--use_fp16 false \
--local_crops_number 10 \
--size_crops 96 \
--global_crops_scale 0.25 1 \
--local_crops_scale 0.05 0.25 \
--timm_auto_augment_par rand-m9-mstd0.5-inc1 \
--prob 0.5 \
--use_prefetcher true \
--debug $DEBUG

## train ViT-base for 400 epochs
OUTPUT_ROOT=./exps/vit_base_400ep
NPROC_PER_NODE=24 # GPU numbers
BATCH_SIZE_PER_GPU=42
DEBUG=false # debug = true, then we only load subset of the whole training dataset
python -m torch.distributed.launch --nproc_per_node=$NPROC_PER_NODE main.py \
--data_path $DATASET_ROOT \
--output_dir $OUTPUT_ROOT \
--arch vit_base \
--instance_queue_size 65536 \
--local_group_queue_size 65536 \
--use_bn_in_head false \
--instance_out_dim 256 \
--instance_temp 0.2 \
--local_group_out_dim 256 \
--local_group_temp 0.2 \
--local_group_knn_top_n 8 \
--group_out_dim 65536 \
--group_student_temp 0.1 \
--group_warmup_teacher_temp 0.04 \
--group_teacher_temp 0.07 \
--group_warmup_teacher_temp_epochs 50 \
--norm_last_layer false \
--norm_before_pred true \
--batch_size_per_gpu $BATCH_SIZE_PER_GPU \
--epochs 400 \
--warmup_epochs 10 \
--clip_grad 3.0 \
--lr 0.0008 \
--min_lr 2e-06 \
--patch_embed_lr_mult 0.2 \
--drop_path_rate 0.1 \
--weight_decay 0.04 \
--weight_decay_end 0.1 \
--freeze_last_layer 3 \
--momentum_teacher 0.996 \
--use_fp16 false \
--local_crops_number 10 \
--size_crops 96 \
--global_crops_scale 0.25 1 \
--local_crops_scale 0.05 0.25 \
--timm_auto_augment_par rand-m9-mstd0.5-inc1 \
--prob 0.5 \
--use_prefetcher true \
--debug $DEBUG

## train ViT-large for 250 epochs
OUTPUT_ROOT=./exps/vit_large_250ep
NPROC_PER_NODE=40 # GPU numbers
BATCH_SIZE_PER_GPU=16
DEBUG=false # debug = true, then we only load subset of the whole training dataset
python -m torch.distributed.launch --nproc_per_node=$NPROC_PER_NODE main.py \
--data_path $DATASET_ROOT \
--output_dir $OUTPUT_ROOT \
--arch vit_large \
--instance_queue_size 65536 \
--local_group_queue_size 65536 \
--use_bn_in_head false \
--instance_out_dim 256 \
--instance_temp 0.2 \
--local_group_out_dim 256 \
--local_group_temp 0.2 \
--local_group_knn_top_n 8 \
--group_out_dim 65536 \
--group_student_temp 0.1 \
--group_warmup_teacher_temp 0.04 \
--group_teacher_temp 0.07 \
--group_warmup_teacher_temp_epochs 50 \
--norm_last_layer true \
--norm_before_pred true \
--batch_size_per_gpu $BATCH_SIZE_PER_GPU \
--epochs 250 \
--warmup_epochs 10 \
--clip_grad 3.0 \
--lr 0.0015 \
--min_lr 1.5e-4 \
--patch_embed_lr_mult 0.2 \
--drop_path_rate 0.3 \
--weight_decay 0.025 \
--weight_decay_end 0.08 \
--freeze_last_layer 3 \
--momentum_teacher 0.996 \
--use_fp16 false \
--local_crops_number 10 \
--size_crops 96 \
--global_crops_scale 0.25 1 \
--local_crops_scale 0.05 0.25 \
--timm_auto_augment_par rand-m9-mstd0.5-inc1 \
--prob 0.5 \
--use_prefetcher true \
--debug $DEBUG
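For reference, the effective global batch size of each launch block above is NPROC_PER_NODE * BATCH_SIZE_PER_GPU. A quick check (illustrative snippet, not part of the repo):

# Illustrative only: effective global batch size = number of GPUs x per-GPU batch size.
for arch, gpus, per_gpu in [("vit_small", 8, 64), ("vit_small", 16, 64),
                            ("vit_base", 24, 42), ("vit_large", 40, 16)]:
    print(arch, gpus * per_gpu)  # 512, 1024, 1008, 640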
src/RandAugment.py
ADDED
@@ -0,0 +1,506 @@
# Copyright 2022 Garena Online Private Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Implement AutoAugment and RandAugment.
Adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/auto_augment.py and modified for token labeling
"""
import math
import random
import re

import numpy as np
import PIL
from PIL import Image, ImageEnhance, ImageOps

_PIL_VER = tuple([int(x) for x in PIL.__version__.split(".")[:2]])

_FILL = (128, 128, 128)

_MAX_LEVEL = 10.0

_HPARAMS_DEFAULT = dict(
    translate_const=250,
    img_mean=_FILL,
)

_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)


def _interpolation(kwargs):
    interpolation = kwargs.pop("resample", Image.BILINEAR)
    if isinstance(interpolation, (list, tuple)):
        return random.choice(interpolation)
    else:
        return interpolation


def _check_args_tf(kwargs):
    if "fillcolor" in kwargs and _PIL_VER < (5, 0):
        kwargs.pop("fillcolor")
    kwargs["resample"] = _interpolation(kwargs)


def shear_x(img, factor, **kwargs):
    _check_args_tf(kwargs)
    return img.transform(img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs)


def shear_y(img, factor, **kwargs):
    _check_args_tf(kwargs)
    return img.transform(img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs)


def translate_x_rel(img, pct, **kwargs):
    pixels = pct * img.size[0]
    _check_args_tf(kwargs)
    return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)


def translate_y_rel(img, pct, **kwargs):
    pixels = pct * img.size[1]
    _check_args_tf(kwargs)
    return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)


def translate_x_abs(img, pixels, **kwargs):
    _check_args_tf(kwargs)
    return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)


def translate_y_abs(img, pixels, **kwargs):
    _check_args_tf(kwargs)
    return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)


def rotate(img, degrees, **kwargs):
    _check_args_tf(kwargs)
    if _PIL_VER >= (5, 2):
        return img.rotate(degrees, **kwargs)
    elif _PIL_VER >= (5, 0):
        w, h = img.size
        post_trans = (0, 0)
        rotn_center = (w / 2.0, h / 2.0)
        angle = -math.radians(degrees)
        matrix = [
            round(math.cos(angle), 15),
            round(math.sin(angle), 15),
            0.0,
            round(-math.sin(angle), 15),
            round(math.cos(angle), 15),
            0.0,
        ]

        def transform(x, y, matrix):
            (a, b, c, d, e, f) = matrix
            return a * x + b * y + c, d * x + e * y + f

        matrix[2], matrix[5] = transform(
            -rotn_center[0] - post_trans[0], -rotn_center[1] - post_trans[1], matrix
        )
        matrix[2] += rotn_center[0]
        matrix[5] += rotn_center[1]
        return img.transform(img.size, Image.AFFINE, matrix, **kwargs)
    else:
        return img.rotate(degrees, resample=kwargs["resample"])


def auto_contrast(img, **__):
    return ImageOps.autocontrast(img)


def invert(img, **__):
    return ImageOps.invert(img)


def equalize(img, **__):
    return ImageOps.equalize(img)


def solarize(img, thresh, **__):
    return ImageOps.solarize(img, thresh)


def solarize_add(img, add, thresh=128, **__):
    lut = []
    for i in range(256):
        if i < thresh:
            lut.append(min(255, i + add))
        else:
            lut.append(i)
    if img.mode in ("L", "RGB"):
        if img.mode == "RGB" and len(lut) == 256:
            lut = lut + lut + lut
        return img.point(lut)
    else:
        return img


def posterize(img, bits_to_keep, **__):
    if bits_to_keep >= 8:
        return img
    return ImageOps.posterize(img, bits_to_keep)


def contrast(img, factor, **__):
    return ImageEnhance.Contrast(img).enhance(factor)


def color(img, factor, **__):
    return ImageEnhance.Color(img).enhance(factor)


def brightness(img, factor, **__):
    return ImageEnhance.Brightness(img).enhance(factor)


def sharpness(img, factor, **__):
    return ImageEnhance.Sharpness(img).enhance(factor)


def _randomly_negate(v):
    """With 50% prob, negate the value"""
    return -v if random.random() > 0.5 else v


def _rotate_level_to_arg(level, _hparams):
    # range [-30, 30]
    level = (level / _MAX_LEVEL) * 30.0
    level = _randomly_negate(level)
    return (level,)


def _enhance_level_to_arg(level, _hparams):
    # range [0.1, 1.9]
    return ((level / _MAX_LEVEL) * 1.8 + 0.1,)


def _enhance_increasing_level_to_arg(level, _hparams):
    # the 'no change' level is 1.0, moving away from that towards 0. or 2.0 increases the enhancement blend
    # range [0.1, 1.9]
    level = (level / _MAX_LEVEL) * 0.9
    level = 1.0 + _randomly_negate(level)
    return (level,)


def _shear_level_to_arg(level, _hparams):
    # range [-0.3, 0.3]
    level = (level / _MAX_LEVEL) * 0.3
    level = _randomly_negate(level)
    return (level,)


def _translate_abs_level_to_arg(level, hparams):
    translate_const = hparams["translate_const"]
    level = (level / _MAX_LEVEL) * float(translate_const)
    level = _randomly_negate(level)
    return (level,)


def _translate_rel_level_to_arg(level, hparams):
    # default range [-0.45, 0.45]
    translate_pct = hparams.get("translate_pct", 0.45)
    level = (level / _MAX_LEVEL) * translate_pct
    level = _randomly_negate(level)
    return (level,)


def _posterize_level_to_arg(level, _hparams):
    # As per Tensorflow TPU EfficientNet impl
    # range [0, 4], 'keep 0 up to 4 MSB of original image'
    # intensity/severity of augmentation decreases with level
    return (int((level / _MAX_LEVEL) * 4),)


def _posterize_increasing_level_to_arg(level, hparams):
    # As per Tensorflow models research and UDA impl
    # range [4, 0], 'keep 4 down to 0 MSB of original image',
    # intensity/severity of augmentation increases with level
    return (4 - _posterize_level_to_arg(level, hparams)[0],)


def _posterize_original_level_to_arg(level, _hparams):
    # As per original AutoAugment paper description
    # range [4, 8], 'keep 4 up to 8 MSB of image'
    # intensity/severity of augmentation decreases with level
    return (int((level / _MAX_LEVEL) * 4) + 4,)


def _solarize_level_to_arg(level, _hparams):
    # range [0, 256]
    # intensity/severity of augmentation decreases with level
    return (int((level / _MAX_LEVEL) * 256),)


def _solarize_increasing_level_to_arg(level, _hparams):
    # range [0, 256]
    # intensity/severity of augmentation increases with level
    return (256 - _solarize_level_to_arg(level, _hparams)[0],)


def _solarize_add_level_to_arg(level, _hparams):
    # range [0, 110]
    return (int((level / _MAX_LEVEL) * 110),)


LEVEL_TO_ARG = {
    "AutoContrast": None,
    "Equalize": None,
    "Invert": None,
    "Rotate": _rotate_level_to_arg,
    # There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers
    "Posterize": _posterize_level_to_arg,
    "PosterizeIncreasing": _posterize_increasing_level_to_arg,
    "PosterizeOriginal": _posterize_original_level_to_arg,
    "Solarize": _solarize_level_to_arg,
    "SolarizeIncreasing": _solarize_increasing_level_to_arg,
    "SolarizeAdd": _solarize_add_level_to_arg,
    "Color": _enhance_level_to_arg,
    "ColorIncreasing": _enhance_increasing_level_to_arg,
    "Contrast": _enhance_level_to_arg,
    "ContrastIncreasing": _enhance_increasing_level_to_arg,
    "Brightness": _enhance_level_to_arg,
    "BrightnessIncreasing": _enhance_increasing_level_to_arg,
    "Sharpness": _enhance_level_to_arg,
    "SharpnessIncreasing": _enhance_increasing_level_to_arg,
    "ShearX": _shear_level_to_arg,
    "ShearY": _shear_level_to_arg,
    "TranslateX": _translate_abs_level_to_arg,
    "TranslateY": _translate_abs_level_to_arg,
    "TranslateXRel": _translate_rel_level_to_arg,
    "TranslateYRel": _translate_rel_level_to_arg,
}


NAME_TO_OP = {
    "AutoContrast": auto_contrast,
    "Equalize": equalize,
    "Invert": invert,
    "Rotate": rotate,
    "Posterize": posterize,
    "PosterizeIncreasing": posterize,
    "PosterizeOriginal": posterize,
    "Solarize": solarize,
    "SolarizeIncreasing": solarize,
    "SolarizeAdd": solarize_add,
    "Color": color,
    "ColorIncreasing": color,
    "Contrast": contrast,
    "ContrastIncreasing": contrast,
    "Brightness": brightness,
    "BrightnessIncreasing": brightness,
    "Sharpness": sharpness,
    "SharpnessIncreasing": sharpness,
    "ShearX": shear_x,
    "ShearY": shear_y,
    "TranslateX": translate_x_abs,
    "TranslateY": translate_y_abs,
    "TranslateXRel": translate_x_rel,
    "TranslateYRel": translate_y_rel,
}

_RAND_TRANSFORMS = [
    "AutoContrast",
    "Equalize",
    "Invert",
    "Rotate",
    "Posterize",
    "Solarize",
    "SolarizeAdd",
    "Color",
    "Contrast",
    "Brightness",
    "Sharpness",
    "ShearX",
    "ShearY",
    "TranslateXRel",
    "TranslateYRel",
    # 'Cutout'
]


_RAND_INCREASING_TRANSFORMS = [
    "AutoContrast",
    "Equalize",
    "Invert",
    "Rotate",
    "PosterizeIncreasing",
    "SolarizeIncreasing",
    "SolarizeAdd",
    "ColorIncreasing",
    "ContrastIncreasing",
    "BrightnessIncreasing",
    "SharpnessIncreasing",
    "ShearX",
    "ShearY",
    "TranslateXRel",
    "TranslateYRel",
    # 'Cutout'
]


# These experimental weights are based loosely on the relative improvements mentioned in paper.
# They may not result in increased performance, but could likely be tuned to do so.
_RAND_CHOICE_WEIGHTS_0 = {
    "Rotate": 0.3,
    "ShearX": 0.2,
    "ShearY": 0.2,
    "TranslateXRel": 0.1,
    "TranslateYRel": 0.1,
    "Color": 0.025,
    "Sharpness": 0.025,
    "AutoContrast": 0.025,
    "Solarize": 0.005,
    "SolarizeAdd": 0.005,
    "Contrast": 0.005,
    "Brightness": 0.005,
    "Equalize": 0.005,
    "Posterize": 0,
    "Invert": 0,
}


def _select_rand_weights(weight_idx=0, transforms=None):
    transforms = transforms or _RAND_TRANSFORMS
    assert weight_idx == 0  # only one set of weights currently
    rand_weights = _RAND_CHOICE_WEIGHTS_0
    probs = [rand_weights[k] for k in transforms]
    probs /= np.sum(probs)
    return probs


class AugmentOp:
    def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
        hparams = hparams or _HPARAMS_DEFAULT
        self.name = name
        self.aug_fn = NAME_TO_OP[name]
        self.level_fn = LEVEL_TO_ARG[name]
        self.prob = prob
        self.magnitude = magnitude
        self.hparams = hparams.copy()
        self.kwargs = dict(
            fillcolor=hparams["img_mean"] if "img_mean" in hparams else _FILL,
            resample=hparams["interpolation"]
            if "interpolation" in hparams
            else _RANDOM_INTERPOLATION,
        )

        # If magnitude_std is > 0, we introduce some randomness
        # in the usually fixed policy and sample magnitude from a normal distribution
        # with mean `magnitude` and std-dev of `magnitude_std`.
        # NOTE This is my own hack, being tested, not in papers or reference impls.
        self.magnitude_std = self.hparams.get("magnitude_std", 0)

    def __call__(self, img):
        if self.prob < 1.0 and random.random() > self.prob:
            return img
        magnitude = self.magnitude
        if self.magnitude_std and self.magnitude_std > 0:
            magnitude = random.gauss(magnitude, self.magnitude_std)
        magnitude = min(_MAX_LEVEL, max(0, magnitude))  # clip to valid range
        level_args = (
            self.level_fn(magnitude, self.hparams)
            if self.level_fn is not None
            else tuple()
        )
        imgs = self.aug_fn(img, *level_args, **self.kwargs)

        return imgs


def rand_augment_ops(magnitude=10, hparams=None, transforms=None):
    hparams = hparams or _HPARAMS_DEFAULT
    transforms = transforms or _RAND_TRANSFORMS
    return [
        AugmentOp(name, prob=0.5, magnitude=magnitude, hparams=hparams)
        for name in transforms
    ]


class RandAugment:
    """
    Apply RandAug on image
    """

    def __init__(self, ops, num_layers=2, choice_weights=None):
        self.ops = ops
        self.num_layers = num_layers
        self.choice_weights = choice_weights

    def __call__(self, img):
        # no replacement when using weighted choice
        ops = np.random.choice(
            self.ops, self.num_layers, replace=False, p=self.choice_weights
        )
        for op in ops:
            img = op(img)

        return img


def rand_augment_transform(config_str, hparams):
    """
    Create a RandAugment transform
    :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by
    dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining
    sections, which are not order specific, determine
        'm' - integer magnitude of rand augment
        'n' - integer num layers (number of transform ops selected per image)
        'w' - integer probability weight index (index of a set of weights to influence choice of op)
        'mstd' - float std deviation of magnitude noise applied
        'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0)
    Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
    'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2

    :param hparams: Other hparams (kwargs) for the RandAugmentation scheme

    :return: A PyTorch compatible Transform
    """
    magnitude = _MAX_LEVEL  # default to _MAX_LEVEL for magnitude (currently 10)
    num_layers = 2  # default to 2 ops per image
    weight_idx = None  # default to no probability weights for op choice
    transforms = _RAND_TRANSFORMS
    config = config_str.split("-")
    assert config[0] == "rand"
    config = config[1:]
    for c in config:
        cs = re.split(r"(\d.*)", c)
        if len(cs) < 2:
            continue
        key, val = cs[:2]
        if key == "mstd":
            # noise param injected via hparams for now
            hparams.setdefault("magnitude_std", float(val))
        elif key == "inc":
            if bool(val):  # this path
                transforms = _RAND_INCREASING_TRANSFORMS
        elif key == "m":
            magnitude = int(val)
        elif key == "n":
            num_layers = int(val)
        elif key == "w":
            weight_idx = int(val)
        else:
            assert False, "Unknown RandAugment config section"
    # magnitude 9
    # hparams {'translate_const': 100, 'img_mean': (124, 116, 104), 'magnitude_std': 0.5}
    # transforms ['AutoContrast', 'Equalize', 'Invert', 'Rotate', 'PosterizeIncreasing', \
    #     'SolarizeIncreasing', 'SolarizeAdd', 'ColorIncreasing', 'ContrastIncreasing', \
    #     'BrightnessIncreasing', 'SharpnessIncreasing', 'ShearX', 'ShearY', 'TranslateXRel', 'TranslateYRel']
    ra_ops = rand_augment_ops(
        magnitude=magnitude, hparams=hparams, transforms=transforms
    )
    choice_weights = (
        None if weight_idx is None else _select_rand_weights(weight_idx)
    )  ## None
    return RandAugment(ra_ops, num_layers, choice_weights=choice_weights)
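A minimal usage sketch of the module above (the repo's actual call site in its multi-crop pipeline may differ; the hparams values are the illustrative ones from the comments in rand_augment_transform, and the image path is hypothetical):

# Illustrative usage only: build the RandAugment transform from the same config
# string passed as --timm_auto_augment_par in pretraining.sh and apply it to a PIL image.
from PIL import Image
from src.RandAugment import rand_augment_transform

hparams = dict(translate_const=100, img_mean=(124, 116, 104))  # example hparams, assumed
augment = rand_augment_transform("rand-m9-mstd0.5-inc1", hparams)

img = Image.open("example.jpg").convert("RGB")  # hypothetical input image
augmented = augment(img)  # PIL image with 2 randomly chosen ops applied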
src/dataset.py
ADDED
@@ -0,0 +1,367 @@
# Copyright 2022 Garena Online Private Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
ImageFolder functions.

Mostly copy-paste from torchvision references
"""
import os
import os.path
from typing import Any, Callable, Dict, List, Optional, Tuple, cast

from PIL import Image
from torchvision.datasets.vision import VisionDataset


def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bool:
    """Checks if a file is an allowed extension.

    Args:
        filename (string): path to a file
        extensions (tuple of strings): extensions to consider (lowercase)

    Returns:
        bool: True if the filename ends with one of given extensions
    """
    return filename.lower().endswith(extensions)


def is_image_file(filename: str) -> bool:
    """Checks if a file is an allowed image extension.

    Args:
        filename (string): path to a file

    Returns:
        bool: True if the filename ends with a known image extension
    """
    return has_file_allowed_extension(filename, IMG_EXTENSIONS)


def find_classes(directory: str, class_num: int) -> Tuple[List[str], Dict[str, int]]:
    """Finds the class folders in a dataset.

    See :class:`DatasetFolder` for details.
    """
    classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
    if not classes:
        raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")
    classes = classes[:class_num]
    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
    return classes, class_to_idx


def make_dataset(
    directory: str,
    class_to_idx: Optional[Dict[str, int]] = None,
    extensions: Optional[Tuple[str, ...]] = None,
    is_valid_file: Optional[Callable[[str], bool]] = None,
    class_num=10,
) -> List[Tuple[str, int]]:
    """Generates a list of samples of a form (path_to_sample, class).

    See :class:`DatasetFolder` for details.

    Note: The class_to_idx parameter is here optional and will use the logic of the ``find_classes`` function
    by default.
    """
    directory = os.path.expanduser(directory)

    if class_to_idx is None:
        _, class_to_idx = find_classes(directory, class_num)
    elif not class_to_idx:
        raise ValueError(
            "'class_to_index' must have at least one entry to collect any samples."
        )

    both_none = extensions is None and is_valid_file is None
    both_something = extensions is not None and is_valid_file is not None
    if both_none or both_something:
        raise ValueError(
            "Both extensions and is_valid_file cannot be None or not None at the same time"
        )

    if extensions is not None:

        def is_valid_file(x: str) -> bool:
            return has_file_allowed_extension(x, cast(Tuple[str, ...], extensions))

    is_valid_file = cast(Callable[[str], bool], is_valid_file)

    instances = []
    available_classes = set()
    for target_class in sorted(class_to_idx.keys()):
        class_index = class_to_idx[target_class]
        target_dir = os.path.join(directory, target_class)
        if not os.path.isdir(target_dir):
            continue
        for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
            for fname in sorted(fnames):
                path = os.path.join(root, fname)
                if is_valid_file(path):
                    item = path, class_index
                    instances.append(item)

                    if target_class not in available_classes:
                        available_classes.add(target_class)

    empty_classes = set(class_to_idx.keys()) - available_classes
    if empty_classes:
        msg = (
            f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
        )
        if extensions is not None:
            msg += f"Supported extensions are: {', '.join(extensions)}"
        raise FileNotFoundError(msg)

    return instances


class DatasetFolder(VisionDataset):
    """A generic data loader.

    This default directory structure can be customized by overriding the
    :meth:`find_classes` method.

    Args:
        root (string): Root directory path.
        loader (callable): A function to load a sample given its path.
        extensions (tuple[string]): A list of allowed extensions.
            both extensions and is_valid_file should not be passed.
        transform (callable, optional): A function/transform that takes in
            a sample and returns a transformed version.
            E.g, ``transforms.RandomCrop`` for images.
        target_transform (callable, optional): A function/transform that takes
            in the target and transforms it.
        is_valid_file (callable, optional): A function that takes path of a file
            and checks if the file is a valid file (used to check for corrupt files)
            both extensions and is_valid_file should not be passed.
        class_num: how many classes will be loaded
    Attributes:
        classes (list): List of the class names sorted alphabetically.
        class_to_idx (dict): Dict with items (class_name, class_index).
        samples (list): List of (sample path, class_index) tuples
        targets (list): The class_index value for each image in the dataset
    """

    def __init__(
        self,
        root: str,
        loader: Callable[[str], Any],
        extensions: Optional[Tuple[str, ...]] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
        class_num=10,
    ) -> None:
        super(DatasetFolder, self).__init__(
            root, transform=transform, target_transform=target_transform
        )
        classes, class_to_idx = self.find_classes(self.root, class_num=class_num)
        samples = self.make_dataset(
            self.root, class_to_idx, extensions, is_valid_file, class_num=class_num
        )

        self.loader = loader
        self.extensions = extensions

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        self.targets = [s[1] for s in samples]

    @staticmethod
    def make_dataset(
        directory: str,
        class_to_idx: Dict[str, int],
        extensions: Optional[Tuple[str, ...]] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
        class_num=10,
    ) -> List[Tuple[str, int]]:
        """Generates a list of samples of a form (path_to_sample, class).

        This can be overridden to e.g. read files from a compressed zip file instead of from the disk.

        Args:
            directory (str): root dataset directory, corresponding to ``self.root``.
            class_to_idx (Dict[str, int]): Dictionary mapping class name to class index.
            extensions (optional): A list of allowed extensions.
                Either extensions or is_valid_file should be passed. Defaults to None.
            is_valid_file (optional): A function that takes path of a file
                and checks if the file is a valid file
                (used to check for corrupt files) both extensions and
                is_valid_file should not be passed. Defaults to None.
            class_num: how many classes will be loaded
        Raises:
            ValueError: In case ``class_to_idx`` is empty.
            ValueError: In case ``extensions`` and ``is_valid_file`` are None or both are not None.
            FileNotFoundError: In case no valid file was found for any class.

        Returns:
            List[Tuple[str, int]]: samples of a form (path_to_sample, class)
        """
        if class_to_idx is None:
            # prevent potential bug since make_dataset() would use the class_to_idx logic of the
            # find_classes() function, instead of using that of the find_classes() method, which
            # is potentially overridden and thus could have a different logic.
            raise ValueError("The class_to_idx parameter cannot be None.")
        return make_dataset(
            directory,
            class_to_idx,
            extensions=extensions,
            is_valid_file=is_valid_file,
            class_num=class_num,
        )

    def find_classes(
        self, directory: str, class_num: int
    ) -> Tuple[List[str], Dict[str, int]]:
        """Find the class folders in a dataset structured as follows::

            directory/
            ├── class_x
            │   ├── xxx.ext
            │   ├── xxy.ext
            │   └── ...
            │       └── xxz.ext
            └── class_y
                ├── 123.ext
                ├── nsdf3.ext
                └── ...
                    └── asd932_.ext

        This method can be overridden to only consider
        a subset of classes, or to adapt to a different dataset directory structure.

        Args:
            directory(str): Root directory path, corresponding to ``self.root``

        Raises:
            FileNotFoundError: If ``dir`` has no class folders.

        Returns:
            (Tuple[List[str], Dict[str, int]]): List of all classes and dictionary mapping each class to an index.
        """
        return find_classes(directory, class_num=class_num)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        # if self.target_transform is not None:
        #     target = self.target_transform(target)

        return sample  # , target

    def __len__(self) -> int:
        return len(self.samples)


IMG_EXTENSIONS = (
    ".jpg",
    ".jpeg",
    ".png",
    ".ppm",
    ".bmp",
    ".pgm",
    ".tif",
    ".tiff",
    ".webp",
)


def pil_loader(path: str) -> Image.Image:
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, "rb") as f:
        img = Image.open(f)
        return img.convert("RGB")


# TODO: specify the return type
def accimage_loader(path: str) -> Any:
    import accimage

    try:
        return accimage.Image(path)
    except IOError:
        # Potentially a decoding problem, fall back to PIL.Image
        return pil_loader(path)


def default_loader(path: str) -> Any:
    from torchvision import get_image_backend

    if get_image_backend() == "accimage":
        return accimage_loader(path)
    else:
        return pil_loader(path)


class ImageFolder(DatasetFolder):
    """A generic data loader where the images are arranged in this way by default: ::

        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/[...]/xxz.png

        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/[...]/asd932_.png

    This class inherits from :class:`~torchvision.datasets.DatasetFolder` so
    the same methods can be overridden to customize the dataset.

    Args:
        root (string): Root directory path.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.
        is_valid_file (callable, optional): A function that takes path of an Image file
            and checks if the file is a valid file (used to check for corrupt files)
        class_num: how many classes will be loaded
    Attributes:
        classes (list): List of the class names sorted alphabetically.
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples
    """

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        loader: Callable[[str], Any] = default_loader,
        is_valid_file: Optional[Callable[[str], bool]] = None,
        class_num=10,
    ):
        super(ImageFolder, self).__init__(
            root,
            loader,
            IMG_EXTENSIONS if is_valid_file is None else None,
            transform=transform,
            target_transform=target_transform,
            is_valid_file=is_valid_file,
            class_num=class_num,
        )
        self.imgs = self.samples
src/loss.py
ADDED
@@ -0,0 +1,244 @@
1 |
+
# Copyright 2022 Garena Online Private Limited
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""
|
15 |
+
functions for building multi-granular losses.
|
16 |
+
"""
|
17 |
+
import numpy as np
|
18 |
+
import torch
|
19 |
+
import torch.distributed as dist
|
20 |
+
import torch.nn as nn
|
21 |
+
import torch.nn.functional as F
|
22 |
+
|
23 |
+
from utils import concat_all_gather
|
24 |
+
|
25 |
+
|
26 |
+
class InfoNCELoss(nn.Module):
|
27 |
+
"""
|
28 |
+
vanilla infoNCEloss.
|
29 |
+
--ncrops: how many crops are used in student networks
|
30 |
+
--dim: feature dimension in queue determinted by output dimention of student network
|
31 |
+
--queue_size: queue size
|
32 |
+
--temperature: temperature parameter for infoNCEloss
|
33 |
+
"""
|
34 |
+
|
35 |
+
def __init__(self, ncrops, dim=256, queue_size=65536, temperature=0.2):
|
36 |
+
super().__init__()
|
37 |
+
self.queue_size = queue_size
|
38 |
+
self.temperature = temperature
|
39 |
+
|
40 |
+
self.register_buffer("queue", torch.randn(dim, queue_size))
|
41 |
+
self.queue = nn.functional.normalize(self.queue, dim=0)
|
42 |
+
|
43 |
+
self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
|
44 |
+
self.CrossEntropyLoss = nn.CrossEntropyLoss()
|
45 |
+
self.ncrops = ncrops
|
46 |
+
|
47 |
+
@torch.no_grad()
|
48 |
+
def _dequeue_and_enqueue(self, keys):
|
49 |
+
"""
|
50 |
+
queue update
|
51 |
+
"""
|
52 |
+
keys = concat_all_gather(keys)
|
53 |
+
batch_size = keys.shape[0]
|
54 |
+
ptr = int(self.queue_ptr)
|
55 |
+
# replace the keys at ptr (dequeue and enqueue)
|
56 |
+
if ptr + batch_size <= self.queue_size:
|
57 |
+
self.queue[:, ptr : ptr + batch_size] = keys.T
|
58 |
+
ptr = (ptr + batch_size) % self.queue_size
|
59 |
+
else:
|
60 |
+
keys_t = keys.T
|
61 |
+
queue_remaining_size = self.queue_size - ptr
|
62 |
+
self.queue[:, ptr:] = keys_t[:, :queue_remaining_size]
|
63 |
+
self.queue[:, : batch_size - queue_remaining_size] = keys_t[
|
64 |
+
:, queue_remaining_size:
|
65 |
+
]
|
66 |
+
|
67 |
+
ptr = batch_size - queue_remaining_size # move pointer
|
68 |
+
|
69 |
+
self.queue_ptr[0] = ptr
|
70 |
+
|
71 |
+
# student_output, teacher_output
|
72 |
+
def forward(self, student_output, teacher_output, epoch):
|
73 |
+
"""
|
74 |
+
Cross-entropy between softmax outputs of the teacher and student networks.
|
75 |
+
"""
|
76 |
+
preds = student_output.chunk(self.ncrops)
|
77 |
+
targets = teacher_output.detach().chunk(2)
|
78 |
+
small_crop_loss, large_crop_loss = 0, 0
|
79 |
+
small_loss_terms, large_loss_terms = 0, 0
|
80 |
+
queue_feat = self.queue.clone().detach()
|
81 |
+
|
82 |
+
for t_idx, targ in enumerate(targets):
|
83 |
+
for p_idx, pred in enumerate(preds):
|
84 |
+
if t_idx == p_idx:
|
85 |
+
continue
|
86 |
+
# positive logits: Nx1
|
87 |
+
l_pos = torch.einsum("nc,nc->n", [pred, targ]).unsqueeze(-1)
|
88 |
+
# negative logits: NxK
|
89 |
+
l_neg = torch.einsum("nc,ck->nk", [pred, queue_feat])
|
90 |
+
# logits: Nx(1+K)
|
91 |
+
logits = torch.cat([l_pos, l_neg], dim=1)
|
92 |
+
# apply temperature
|
93 |
+
logits /= self.temperature
|
94 |
+
# labels: positive key indicators
|
95 |
+
labels = torch.zeros(logits.shape[0], dtype=torch.long).to(
|
96 |
+
logits.device
|
97 |
+
)
|
98 |
+
loss = self.CrossEntropyLoss(logits, labels)
|
99 |
+
if p_idx < 2: ## large crop loss, namely loss on 224-sized images
|
100 |
+
large_crop_loss += loss
|
101 |
+
large_loss_terms += 1
|
102 |
+
else: ## small crop loss, namely loss on 96-sized images
|
103 |
+
small_crop_loss += loss
|
104 |
+
small_loss_terms += 1
|
105 |
+
# dequeue and enqueue
|
106 |
+
self._dequeue_and_enqueue(targ)
|
107 |
+
|
108 |
+
large_crop_loss /= large_loss_terms
|
109 |
+
small_crop_loss /= small_loss_terms
|
110 |
+
loss = 0.5 * (large_crop_loss + small_crop_loss)
|
111 |
+
return loss
|
112 |
+
|
113 |
+
|
114 |
+
class ClusteringLoss(nn.Module):
|
115 |
+
"""
|
116 |
+
Clustering loss, which is very similar to the one in DINO
|
117 |
+
--out_dim: center dimension, determined by the output dimension of the student network
|
118 |
+
--ncrops: how many crops are used in student networks
|
119 |
+
--warmup_teacher_temp: Initial value for the teacher temperature
|
120 |
+
--teacher_temp: Final value (after linear warmup) of the teacher temperature
|
121 |
+
--warmup_teacher_temp_epochs: Number of warmup epochs for the teacher temperature
|
122 |
+
--nepochs: total number of training epochs
|
123 |
+
--student_temp: temperature parameter in student output
|
124 |
+
--center_momentum: EMA parameter for center update
|
125 |
+
"""
|
126 |
+
|
127 |
+
def __init__(
|
128 |
+
self,
|
129 |
+
out_dim,
|
130 |
+
ncrops,
|
131 |
+
warmup_teacher_temp,
|
132 |
+
teacher_temp,
|
133 |
+
warmup_teacher_temp_epochs,
|
134 |
+
nepochs,
|
135 |
+
student_temp=0.1,
|
136 |
+
center_momentum=0.9,
|
137 |
+
):
|
138 |
+
super().__init__()
|
139 |
+
self.student_temp = student_temp
|
140 |
+
self.center_momentum = center_momentum
|
141 |
+
self.ncrops = ncrops
|
142 |
+
self.register_buffer("center", torch.zeros(1, out_dim))
|
143 |
+
# we apply a warm up for the teacher temperature because
|
144 |
+
# a too-high temperature makes the training unstable at the beginning
|
145 |
+
self.teacher_temp_schedule = np.concatenate(
|
146 |
+
(
|
147 |
+
np.linspace(
|
148 |
+
warmup_teacher_temp, teacher_temp, warmup_teacher_temp_epochs
|
149 |
+
),
|
150 |
+
np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp,
|
151 |
+
)
|
152 |
+
)
|
153 |
+
|
154 |
+
def forward(self, student_output, teacher_output, epoch):
|
155 |
+
"""
|
156 |
+
Cross-entropy between softmax outputs of the teacher and student networks.
|
157 |
+
"""
|
158 |
+
student_out = student_output / self.student_temp
|
159 |
+
student_out = student_out.chunk(self.ncrops)
|
160 |
+
|
161 |
+
# teacher centering and sharpening
|
162 |
+
temp = self.teacher_temp_schedule[epoch]
|
163 |
+
teacher_out = F.softmax((teacher_output - self.center) / temp, dim=-1)
|
164 |
+
teacher_out = teacher_out.detach().chunk(2)
|
165 |
+
|
166 |
+
loss_large_crop, loss_small_crop = 0.0, 0.0
|
167 |
+
loss_terms_large_crop, loss_terms_small_crop = 0, 0
|
168 |
+
for iq, q in enumerate(teacher_out):
|
169 |
+
for v in range(len(student_out)):
|
170 |
+
if v == iq:
|
171 |
+
# we skip cases where student and teacher operate on the same view
|
172 |
+
continue
|
173 |
+
loss = torch.sum(
|
174 |
+
-q * F.log_softmax(student_out[v], dim=-1), dim=-1
|
175 |
+
).mean()
|
176 |
+
if v < 2:
|
177 |
+
loss_large_crop += loss
|
178 |
+
loss_terms_large_crop += 1
|
179 |
+
else:
|
180 |
+
loss_small_crop += loss
|
181 |
+
loss_terms_small_crop += 1
|
182 |
+
|
183 |
+
self.update_center(teacher_output)
|
184 |
+
loss_large_crop /= loss_terms_large_crop
|
185 |
+
loss_small_crop /= loss_terms_small_crop
|
186 |
+
total_loss = 0.5 * (loss_large_crop + loss_small_crop)
|
187 |
+
return total_loss
|
188 |
+
|
189 |
+
@torch.no_grad()
|
190 |
+
def update_center(self, teacher_output):
|
191 |
+
"""
|
192 |
+
Update center used for teacher output.
|
193 |
+
"""
|
194 |
+
batch_center = torch.mean(teacher_output, dim=0, keepdim=False)
|
195 |
+
dist.all_reduce(batch_center)
|
196 |
+
batch_center = batch_center / dist.get_world_size()
|
197 |
+
|
198 |
+
# ema update
|
199 |
+
self.center = self.center * self.center_momentum + batch_center * (
|
200 |
+
1 - self.center_momentum
|
201 |
+
)
|
202 |
+
|
203 |
+
|
204 |
+
def get_multi_granular_loss(args):
|
205 |
+
"""
|
206 |
+
build the multi-granular loss
|
207 |
+
"""
|
208 |
+
all_losses, all_weights = {}, {}
|
209 |
+
|
210 |
+
## build the instance discrimination loss
|
211 |
+
instance_supervision_loss = InfoNCELoss(
|
212 |
+
args.local_crops_number + 2,
|
213 |
+
dim=args.instance_out_dim,
|
214 |
+
queue_size=args.instance_queue_size,
|
215 |
+
temperature=args.instance_temp,
|
216 |
+
).cuda()
|
217 |
+
all_losses["instance-sup."] = instance_supervision_loss
|
218 |
+
all_weights["instance-sup."] = args.loss_weights[0]
|
219 |
+
|
220 |
+
## build the local group discrimination loss
|
221 |
+
local_group_supervision = InfoNCELoss(
|
222 |
+
args.local_crops_number + 2,
|
223 |
+
dim=args.local_group_out_dim,
|
224 |
+
queue_size=args.local_group_queue_size,
|
225 |
+
temperature=args.local_group_temp,
|
226 |
+
).cuda()
|
227 |
+
all_losses["local-group-sup."] = local_group_supervision
|
228 |
+
all_weights["local-group-sup."] = args.loss_weights[1]
|
229 |
+
|
230 |
+
## build the group discrimination loss
|
231 |
+
group_loss = ClusteringLoss(
|
232 |
+
args.group_out_dim,
|
233 |
+
args.local_crops_number
|
234 |
+
+ 2, # total number of crops = 2 global crops + local_crops_number
|
235 |
+
args.group_warmup_teacher_temp,
|
236 |
+
args.group_teacher_temp,
|
237 |
+
args.group_warmup_teacher_temp_epochs,
|
238 |
+
args.epochs,
|
239 |
+
student_temp=args.group_student_temp,
|
240 |
+
center_momentum=0.9,
|
241 |
+
).cuda()
|
242 |
+
all_losses["group-sup."] = group_loss
|
243 |
+
all_weights["group-sup."] = args.loss_weights[2]
|
244 |
+
return all_losses, all_weights
|
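A minimal sketch (editor's illustration, not one of the uploaded files) of how the loss dictionary returned by `get_multi_granular_loss` might be combined into a single training objective. The `student_outputs` / `teacher_outputs` dicts are hypothetical placeholders mapping the granularity names to the corresponding head outputs; only the loss/weight dictionaries come from the code above.

# combine the multi-granular losses with their weights (illustrative sketch)
all_losses, all_weights = get_multi_granular_loss(args)

def multi_granular_loss(student_outputs, teacher_outputs, epoch):
    total = 0.0
    for name, loss_fn in all_losses.items():
        term = loss_fn(student_outputs[name], teacher_outputs[name], epoch)
        total = total + all_weights[name] * term
    return total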
src/model.py
ADDED
@@ -0,0 +1,607 @@
1 |
+
# Copyright 2022 Garena Online Private Limited
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""
|
15 |
+
models and functions for building student and teacher networks for multi-granular losses.
|
16 |
+
"""
|
17 |
+
import torch
|
18 |
+
import torch.nn as nn
|
19 |
+
|
20 |
+
import src.vision_transformer as vits
|
21 |
+
from src.vision_transformer import trunc_normal_
|
22 |
+
|
23 |
+
|
24 |
+
class Instance_Superivsion_Head(nn.Module):
|
25 |
+
"""
|
26 |
+
a class to implement the Instance Supervision Head
|
27 |
+
--in_dim: input dimension of projection head
|
28 |
+
--hidden_dim: hidden dimension of projection head
|
29 |
+
--out_dim: output dimension of projection and prediction heads
|
30 |
+
--pred_hidden_dim: hidden dimension of prediction head
|
31 |
+
--nlayers: number of layers in the projection head; the prediction head has nlayers-1 layers
|
32 |
+
--proj_bn: whether we use batch normalization in projection head
|
33 |
+
--pred_bn: whether we use batch normalization in prediction head
|
34 |
+
--norm_before_pred: whether we use normalization before prediction head
|
35 |
+
"""
|
36 |
+
|
37 |
+
def __init__(
|
38 |
+
self,
|
39 |
+
in_dim,
|
40 |
+
hidden_dim=2048,
|
41 |
+
out_dim=256,
|
42 |
+
pred_hidden_dim=4096,
|
43 |
+
nlayers=3,
|
44 |
+
proj_bn=False,
|
45 |
+
pred_bn=False,
|
46 |
+
norm_before_pred=True,
|
47 |
+
):
|
48 |
+
super().__init__()
|
49 |
+
nlayers = max(nlayers, 1)
|
50 |
+
self.norm_before_pred = norm_before_pred
|
51 |
+
|
52 |
+
self.projector = self._build_mlp(
|
53 |
+
nlayers, in_dim, hidden_dim, out_dim, use_bn=proj_bn
|
54 |
+
)
|
55 |
+
|
56 |
+
self.apply(self._init_weights)
|
57 |
+
|
58 |
+
self.predictor = None
|
59 |
+
if pred_hidden_dim > 0: # teacher no, student yes
|
60 |
+
self.predictor = self._build_mlp(
|
61 |
+
nlayers - 1, out_dim, pred_hidden_dim, out_dim, use_bn=pred_bn
|
62 |
+
)
|
63 |
+
|
64 |
+
def _init_weights(self, m):
|
65 |
+
"""
|
66 |
+
initialize the parameters in the network
|
67 |
+
"""
|
68 |
+
if isinstance(m, nn.Linear):
|
69 |
+
trunc_normal_(m.weight, std=0.02)
|
70 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
71 |
+
nn.init.constant_(m.bias, 0)
|
72 |
+
|
73 |
+
def _build_mlp(self, num_layers, input_dim, hidden_dim, output_dim, use_bn=False):
|
74 |
+
"""
|
75 |
+
build a mlp
|
76 |
+
"""
|
77 |
+
mlp = []
|
78 |
+
for layer in range(num_layers):
|
79 |
+
dim1 = input_dim if layer == 0 else hidden_dim
|
80 |
+
dim2 = output_dim if layer == num_layers - 1 else hidden_dim
|
81 |
+
|
82 |
+
mlp.append(nn.Linear(dim1, dim2, bias=False))
|
83 |
+
|
84 |
+
if layer < num_layers - 1:
|
85 |
+
if use_bn:
|
86 |
+
mlp.append(nn.BatchNorm1d(dim2))
|
87 |
+
mlp.append(nn.GELU())
|
88 |
+
|
89 |
+
return nn.Sequential(*mlp)
|
90 |
+
|
91 |
+
def forward(self, x, return_target=False):
|
92 |
+
"""
|
93 |
+
forward the input through projection head for teacher and
|
94 |
+
projection/prediction heads for student
|
95 |
+
"""
|
96 |
+
feat = self.projector(x)
|
97 |
+
|
98 |
+
if return_target:
|
99 |
+
feat = nn.functional.normalize(feat, dim=-1, p=2)
|
100 |
+
return feat
|
101 |
+
## return prediction
|
102 |
+
if self.norm_before_pred:
|
103 |
+
feat = nn.functional.normalize(feat, dim=-1, p=2)
|
104 |
+
pred = self.predictor(feat)
|
105 |
+
pred = nn.functional.normalize(pred, dim=-1, p=2)
|
106 |
+
return pred
|
107 |
+
|
108 |
+
|
109 |
+
class Local_Group_Superivsion_Head(nn.Module):
|
110 |
+
"""
|
111 |
+
a class to implement the Local Group Supervision Head, which has the same structure as the Instance Supervision Head
|
112 |
+
--in_dim: input dimension of projection head
|
113 |
+
--hidden_dim: hidden dimension of projection head
|
114 |
+
--out_dim: output dimension of projection and prediction heads
|
115 |
+
--pred_hidden_dim: hidden dimension of prediction head
|
116 |
+
--nlayers: number of layers in the projection head; the prediction head has nlayers-1 layers
|
117 |
+
--proj_bn: whether we use batch normalization in projection head
|
118 |
+
--pred_bn: whether we use batch normalization in prediction head
|
119 |
+
--norm_before_pred: whether we use normalization before prediction head
|
120 |
+
"""
|
121 |
+
|
122 |
+
def __init__(
|
123 |
+
self,
|
124 |
+
in_dim,
|
125 |
+
hidden_dim=2048,
|
126 |
+
out_dim=256,
|
127 |
+
pred_hidden_dim=4096,
|
128 |
+
nlayers=3,
|
129 |
+
proj_bn=False,
|
130 |
+
pred_bn=False,
|
131 |
+
norm_before_pred=True,
|
132 |
+
):
|
133 |
+
super().__init__()
|
134 |
+
nlayers = max(nlayers, 1)
|
135 |
+
self.norm_before_pred = norm_before_pred
|
136 |
+
|
137 |
+
self.projector = self._build_mlp(
|
138 |
+
nlayers, in_dim, hidden_dim, out_dim, use_bn=proj_bn
|
139 |
+
)
|
140 |
+
|
141 |
+
self.apply(self._init_weights)
|
142 |
+
|
143 |
+
self.predictor = None
|
144 |
+
if pred_hidden_dim > 0: # teacher no, student yes
|
145 |
+
self.predictor = self._build_mlp(
|
146 |
+
nlayers - 1, out_dim, pred_hidden_dim, out_dim, use_bn=pred_bn
|
147 |
+
)
|
148 |
+
|
149 |
+
def _init_weights(self, m):
|
150 |
+
"""
|
151 |
+
initialize the parameters in the network
|
152 |
+
"""
|
153 |
+
if isinstance(m, nn.Linear):
|
154 |
+
trunc_normal_(m.weight, std=0.02)
|
155 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
156 |
+
nn.init.constant_(m.bias, 0)
|
157 |
+
|
158 |
+
def _build_mlp(self, num_layers, input_dim, hidden_dim, output_dim, use_bn=False):
|
159 |
+
"""
|
160 |
+
build a mlp
|
161 |
+
"""
|
162 |
+
mlp = []
|
163 |
+
for layer in range(num_layers):
|
164 |
+
dim1 = input_dim if layer == 0 else hidden_dim
|
165 |
+
dim2 = output_dim if layer == num_layers - 1 else hidden_dim
|
166 |
+
|
167 |
+
mlp.append(nn.Linear(dim1, dim2, bias=False))
|
168 |
+
|
169 |
+
if layer < num_layers - 1:
|
170 |
+
if use_bn:
|
171 |
+
mlp.append(nn.BatchNorm1d(dim2))
|
172 |
+
mlp.append(nn.GELU())
|
173 |
+
|
174 |
+
return nn.Sequential(*mlp)
|
175 |
+
|
176 |
+
def forward(self, x, return_target=False):
|
177 |
+
"""
|
178 |
+
forward the input through projection head for teacher and
|
179 |
+
projection/prediction heads for student
|
180 |
+
"""
|
181 |
+
feat = self.projector(x)
|
182 |
+
|
183 |
+
if return_target:
|
184 |
+
feat = nn.functional.normalize(feat, dim=-1, p=2)
|
185 |
+
return feat
|
186 |
+
## return prediction
|
187 |
+
if self.norm_before_pred:
|
188 |
+
feat = nn.functional.normalize(feat, dim=-1, p=2)
|
189 |
+
pred = self.predictor(feat)
|
190 |
+
pred = nn.functional.normalize(pred, dim=-1, p=2)
|
191 |
+
return pred
|
192 |
+
|
193 |
+
|
194 |
+
class Group_Superivsion_Head(nn.Module):
|
195 |
+
"""
|
196 |
+
a class to implement the Group Supervision Head, i.e. the projection head used for group (clustering) discrimination
|
197 |
+
--in_dim: input dimension of projection head
|
198 |
+
--hidden_dim: hidden dimension of projection head
|
199 |
+
--out_dim: output dimension of the last prediction (prototype) layer
|
200 |
+
--bottleneck_dim: bottleneck dimension before the last prediction layer
|
201 |
+
--nlayers: number of layers in the projection head
|
202 |
+
--use_bn: whether we use batch normalization in the projection head
|
203 |
+
--norm_last_layer: whether to fix the weight norm (scale) of the last prediction layer
|
204 |
+
The output of this head is used by the group (clustering) discrimination loss.
|
205 |
+
"""
|
206 |
+
|
207 |
+
def __init__(
|
208 |
+
self,
|
209 |
+
in_dim,
|
210 |
+
out_dim,
|
211 |
+
hidden_dim=2048,
|
212 |
+
bottleneck_dim=256,
|
213 |
+
nlayers=3,
|
214 |
+
use_bn=False,
|
215 |
+
norm_last_layer=True,
|
216 |
+
):
|
217 |
+
super().__init__()
|
218 |
+
nlayers = max(nlayers, 1)
|
219 |
+
|
220 |
+
self.projector = self._build_mlp(
|
221 |
+
nlayers, in_dim, hidden_dim, bottleneck_dim, use_bn=use_bn
|
222 |
+
)
|
223 |
+
self.apply(self._init_weights)
|
224 |
+
|
225 |
+
self.last_layer = nn.utils.weight_norm(
|
226 |
+
nn.Linear(bottleneck_dim, out_dim, bias=False)
|
227 |
+
)
|
228 |
+
self.last_layer.weight_g.data.fill_(1)
|
229 |
+
if norm_last_layer:
|
230 |
+
self.last_layer.weight_g.requires_grad = False
|
231 |
+
|
232 |
+
def _build_mlp(self, num_layers, in_dim, hidden_dim, output_dim, use_bn=False):
|
233 |
+
"""
|
234 |
+
build a mlp
|
235 |
+
"""
|
236 |
+
if num_layers == 1:
|
237 |
+
mlp = nn.Linear(in_dim, output_dim)
|
238 |
+
else:
|
239 |
+
layers = [nn.Linear(in_dim, hidden_dim)]
|
240 |
+
if use_bn:
|
241 |
+
layers.append(nn.BatchNorm1d(hidden_dim))
|
242 |
+
layers.append(nn.GELU())
|
243 |
+
for _ in range(num_layers - 2):
|
244 |
+
layers.append(nn.Linear(hidden_dim, hidden_dim))
|
245 |
+
if use_bn:
|
246 |
+
layers.append(nn.BatchNorm1d(hidden_dim))
|
247 |
+
layers.append(nn.GELU())
|
248 |
+
layers.append(nn.Linear(hidden_dim, output_dim))
|
249 |
+
mlp = nn.Sequential(*layers)
|
250 |
+
return mlp
|
251 |
+
|
252 |
+
def _init_weights(self, m):
|
253 |
+
"""
|
254 |
+
initialize the parameters in the network
|
255 |
+
"""
|
256 |
+
if isinstance(m, nn.Linear):
|
257 |
+
trunc_normal_(m.weight, std=0.02)
|
258 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
259 |
+
nn.init.constant_(m.bias, 0)
|
260 |
+
|
261 |
+
def forward(self, x):
|
262 |
+
"""
|
263 |
+
forward the input through the projection and last prediction layer
|
264 |
+
"""
|
265 |
+
feat = self.projector(x)
|
266 |
+
feat = nn.functional.normalize(feat, dim=-1, p=2)
|
267 |
+
feat = self.last_layer(feat)
|
268 |
+
return feat
|
269 |
+
|
270 |
+
|
271 |
+
class Block_mem(nn.Module):
|
272 |
+
"""
|
273 |
+
a class to implement a memory block for local group supervision
|
274 |
+
--dim: feature vector dimension in the memory
|
275 |
+
--K: memory size
|
276 |
+
--top_n: number of neighbors used in local group supervision
|
277 |
+
"""
|
278 |
+
|
279 |
+
def __init__(self, dim, K=2048, top_n=10):
|
280 |
+
super().__init__()
|
281 |
+
self.dim = dim
|
282 |
+
self.K = K
|
283 |
+
self.top_n = top_n
|
284 |
+
# create the queue
|
285 |
+
self.register_buffer("queue_q", torch.randn(K, dim))
|
286 |
+
self.register_buffer("queue_k", torch.randn(K, dim))
|
287 |
+
self.register_buffer("queue_v", torch.randn(K, dim))
|
288 |
+
self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
|
289 |
+
|
290 |
+
@torch.no_grad()
|
291 |
+
def _dequeue_and_enqueue(self, query, weak_aug_flags):
|
292 |
+
"""
|
293 |
+
update memory queue
|
294 |
+
"""
|
295 |
+
# import pdb
|
296 |
+
# pdb.set_trace()
|
297 |
+
len_weak = 0
|
298 |
+
query = concat_all_gather(query)
|
299 |
+
if weak_aug_flags is not None:
|
300 |
+
weak_aug_flags = weak_aug_flags.cuda()
|
301 |
+
weak_aug_flags = concat_all_gather(weak_aug_flags)
|
302 |
+
idx_weak = torch.nonzero(weak_aug_flags)
|
303 |
+
len_weak = len(idx_weak)
|
304 |
+
if len_weak > 0:
|
305 |
+
idx_weak = idx_weak.squeeze(-1)
|
306 |
+
query = query[idx_weak]
|
307 |
+
else:
|
308 |
+
return len_weak
|
309 |
+
|
310 |
+
all_size = query.shape[0]
|
311 |
+
ptr = int(self.queue_ptr)
|
312 |
+
remaining_size = ptr + all_size - self.K
|
313 |
+
if remaining_size <= 0:
|
314 |
+
self.queue_q[ptr : ptr + all_size, :] = query
|
315 |
+
self.queue_k[ptr : ptr + all_size, :] = query
|
316 |
+
self.queue_v[ptr : ptr + all_size, :] = query
|
317 |
+
ptr = ptr + all_size
|
318 |
+
self.queue_ptr[0] = ptr % self.K  # ptr was already advanced by all_size above
|
319 |
+
else:
|
320 |
+
self.queue_q[ptr : self.K, :] = query[0 : self.K - ptr, :]
|
321 |
+
self.queue_k[ptr : self.K, :] = query[0 : self.K - ptr, :]
|
322 |
+
self.queue_v[ptr : self.K, :] = query[0 : self.K - ptr, :]
|
323 |
+
|
324 |
+
self.queue_q[0:remaining_size, :] = query[self.K - ptr :, :]
|
325 |
+
self.queue_k[0:remaining_size, :] = query[self.K - ptr :, :]
|
326 |
+
self.queue_v[0:remaining_size, :] = query[self.K - ptr :, :]
|
327 |
+
self.queue_ptr[0] = remaining_size
|
328 |
+
return len_weak
|
329 |
+
|
330 |
+
@torch.no_grad()
|
331 |
+
def _get_similarity_index(self, x):
|
332 |
+
"""
|
333 |
+
compute the index of the top-n neighbors (key-value pair) in memory
|
334 |
+
"""
|
335 |
+
x = nn.functional.normalize(x, dim=-1)
|
336 |
+
queue_q = nn.functional.normalize(self.queue_q, dim=-1)
|
337 |
+
|
338 |
+
cosine = x @ queue_q.T
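# (editor's note) queries and memory keys are L2-normalized above, so this matrix
# product yields cosine similarities; topk below picks, for every query, the indices
# of its top_n nearest neighbors stored in the memory.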
|
339 |
+
_, index = torch.topk(cosine, self.top_n, dim=-1)
|
340 |
+
return index
|
341 |
+
|
342 |
+
@torch.no_grad()
|
343 |
+
def _get_similarity_samples(self, query, index=None):
|
344 |
+
"""
|
345 |
+
compute top-n neighbors (key-value pair) in memory
|
346 |
+
"""
|
347 |
+
if index is None:
|
348 |
+
index = self._get_similarity_index(query)
|
349 |
+
get_k = self.queue_k[index.view(-1)]
|
350 |
+
get_v = self.queue_v[index.view(-1)]
|
351 |
+
B, tn = index.shape
|
352 |
+
get_k = get_k.view(B, tn, self.dim)
|
353 |
+
get_v = get_v.view(B, tn, self.dim)
|
354 |
+
return get_k, get_v
|
355 |
+
|
356 |
+
def forward(self, query):
|
357 |
+
"""
|
358 |
+
forward to find the top-n neighbors (key-value pair) in memory
|
359 |
+
"""
|
360 |
+
get_k, get_v = self._get_similarity_samples(query)
|
361 |
+
return get_k, get_v
|
362 |
+
|
363 |
+
|
364 |
+
class vit_mem(nn.Module):
|
365 |
+
"""
|
366 |
+
a class to implement a memory for local group supervision
|
367 |
+
--dim: feature vector dimension in the memory
|
368 |
+
--K: memory size
|
369 |
+
--top_n: number of neighbors used in local group supervision
|
370 |
+
"""
|
371 |
+
|
372 |
+
def __init__(self, dim, K=2048, top_n=10):
|
373 |
+
super().__init__()
|
374 |
+
self.block = Block_mem(dim, K, top_n)
|
375 |
+
|
376 |
+
def _dequeue_and_enqueue(self, query, weak_aug_flags):
|
377 |
+
"""
|
378 |
+
update memory queue
|
379 |
+
"""
|
380 |
+
query = query.float()
|
381 |
+
weak_num = self.block._dequeue_and_enqueue(query, weak_aug_flags)
|
382 |
+
return weak_num
|
383 |
+
|
384 |
+
def forward(self, query):
|
385 |
+
"""
|
386 |
+
forward to find the top-n neighbors (key-value pair) in memory
|
387 |
+
"""
|
388 |
+
query = query.float()
|
389 |
+
get_k, get_v = self.block(query)
|
390 |
+
return get_k, get_v
|
391 |
+
|
392 |
+
|
393 |
+
class Mugs_Wrapper(nn.Module):
|
394 |
+
"""
|
395 |
+
a class to implement a student or teacher wrapper for Mugs
|
396 |
+
--backbone: the backbone of the student/teacher, e.g. ViT-small
|
397 |
+
--instance_head: head, including projection/prediction heads, for instance supervision
|
398 |
+
--local_group_head: head, including projection/prediction heads, for local group supervision
|
399 |
+
--group_head: projection head for group supervision
|
400 |
+
"""
|
401 |
+
|
402 |
+
def __init__(self, backbone, instance_head, local_group_head, group_head):
|
403 |
+
super(Mugs_Wrapper, self).__init__()
|
404 |
+
backbone.fc, backbone.head = nn.Identity(), nn.Identity()
|
405 |
+
self.backbone = backbone
|
406 |
+
self.instance_head = instance_head
|
407 |
+
self.local_group_head = local_group_head
|
408 |
+
self.group_head = group_head
|
409 |
+
|
410 |
+
def forward(self, x, return_target=False, local_group_memory_inputs=None):
|
411 |
+
"""
|
412 |
+
forward input to get instance/local-group/group targets or predictions
|
413 |
+
"""
|
414 |
+
# convert to list
|
415 |
+
if not isinstance(x, list):
|
416 |
+
x = [x]
|
417 |
+
idx_crops = torch.cumsum(
|
418 |
+
torch.unique_consecutive(
|
419 |
+
torch.tensor([inp.shape[-1] for inp in x]),
|
420 |
+
return_counts=True,
|
421 |
+
)[1],
|
422 |
+
0,
|
423 |
+
)
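# (editor's note) idx_crops holds cumulative crop counts per resolution, e.g.
# [2, 2 + local_crops_number] for two 224-sized and several 96-sized crops, so that
# crops of the same size are forwarded through the backbone in a single batch.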
|
424 |
+
|
425 |
+
start_idx = 0
|
426 |
+
class_tokens = torch.empty(0).to(x[0].device)
|
427 |
+
mean_patch_tokens = torch.empty(0).to(x[0].device)
|
428 |
+
memory_class_tokens = torch.empty(0).to(x[0].device)
|
429 |
+
for _, end_idx in enumerate(idx_crops):
|
430 |
+
input = torch.cat(x[start_idx:end_idx])
|
431 |
+
token_feat, memory_class_token_feat = self.backbone(
|
432 |
+
input,
|
433 |
+
return_all=True,
|
434 |
+
local_group_memory_inputs=local_group_memory_inputs,
|
435 |
+
) # [[16, 197, 384], [16, 384]] teacher
|
436 |
+
# [[16, 197, 384], [16, 384]] student [[48, 37, 384], [48, 384]]
|
437 |
+
|
438 |
+
class_token_feat = token_feat[
|
439 |
+
:, 0
|
440 |
+
] # class tokens in ViT, [16, 384] teacher [16, 384] student [48, 384]
|
441 |
+
class_tokens = torch.cat((class_tokens, class_token_feat))
|
442 |
+
|
443 |
+
start_idx = end_idx
|
444 |
+
|
445 |
+
if self.local_group_head is not None:
|
446 |
+
memory_class_tokens = torch.cat(
|
447 |
+
(memory_class_tokens, memory_class_token_feat)
|
448 |
+
)
|
449 |
+
if input.shape[-1] == 224:
|
450 |
+
mean_patch_tokens = torch.cat(
|
451 |
+
(mean_patch_tokens, token_feat[:, 1:].mean(dim=1))
|
452 |
+
)
|
453 |
+
|
454 |
+
## target [16, 256] for teacher, [64, 256] for student,
|
455 |
+
instance_feat = (
|
456 |
+
self.instance_head(class_tokens, return_target)
|
457 |
+
if self.instance_head is not None
|
458 |
+
else None
|
459 |
+
)
|
460 |
+
|
461 |
+
## target [16, 256] for teacher, [64, 256] for student
|
462 |
+
local_group_feat = (
|
463 |
+
self.local_group_head(memory_class_tokens, return_target)
|
464 |
+
if self.local_group_head is not None
|
465 |
+
else None
|
466 |
+
)
|
467 |
+
|
468 |
+
# target [16, 65536] for teacher, [64, 65536] for student
|
469 |
+
group_feat = (
|
470 |
+
self.group_head(class_tokens) if self.group_head is not None else None
|
471 |
+
)
|
472 |
+
return instance_feat, local_group_feat, group_feat, mean_patch_tokens.detach()
|
473 |
+
|
474 |
+
|
475 |
+
def get_model(args):
|
476 |
+
"""
|
477 |
+
build a student or teacher for Mugs, including backbone, instance/local-group/group heads,
|
478 |
+
and memory buffer
|
479 |
+
"""
|
480 |
+
## backbone
|
481 |
+
if args.arch in vits.__dict__.keys():
|
482 |
+
student = vits.__dict__[args.arch](
|
483 |
+
patch_size=args.patch_size,
|
484 |
+
num_relation_blocks=1,
|
485 |
+
drop_path_rate=args.drop_path_rate, # stochastic depth
|
486 |
+
)
|
487 |
+
teacher = vits.__dict__[args.arch](
|
488 |
+
patch_size=args.patch_size, num_relation_blocks=1
|
489 |
+
)
|
490 |
+
embed_dim = student.embed_dim
|
491 |
+
else:
|
492 |
+
raise NotImplementedError(f"Unknown architecture: {args.arch}")
|
493 |
+
|
494 |
+
## memory buffer for local-group loss
|
495 |
+
student_mem = vit_mem(
|
496 |
+
embed_dim, K=args.local_group_queue_size, top_n=args.local_group_knn_top_n
|
497 |
+
)
|
498 |
+
teacher_mem = vit_mem(
|
499 |
+
embed_dim, K=args.local_group_queue_size, top_n=args.local_group_knn_top_n
|
500 |
+
)
|
501 |
+
|
502 |
+
## multi-crop wrapper handles forward with inputs of different resolutions
|
503 |
+
student_instance_head, student_local_group_head, student_group_head = (
|
504 |
+
None,
|
505 |
+
None,
|
506 |
+
None,
|
507 |
+
)
|
508 |
+
teacher_instance_head, teacher_local_group_head, teacher_group_head = (
|
509 |
+
None,
|
510 |
+
None,
|
511 |
+
None,
|
512 |
+
)
|
513 |
+
|
514 |
+
# instance head
|
515 |
+
if args.loss_weights[0] > 0:
|
516 |
+
student_instance_head = Instance_Superivsion_Head(
|
517 |
+
in_dim=embed_dim,
|
518 |
+
hidden_dim=2048,
|
519 |
+
out_dim=args.instance_out_dim,
|
520 |
+
pred_hidden_dim=4096,
|
521 |
+
nlayers=3,
|
522 |
+
proj_bn=args.use_bn_in_head,
|
523 |
+
pred_bn=False,
|
524 |
+
norm_before_pred=args.norm_before_pred,
|
525 |
+
)
|
526 |
+
teacher_instance_head = Instance_Superivsion_Head(
|
527 |
+
in_dim=embed_dim,
|
528 |
+
hidden_dim=2048,
|
529 |
+
out_dim=args.instance_out_dim,
|
530 |
+
pred_hidden_dim=0,
|
531 |
+
nlayers=3,
|
532 |
+
proj_bn=args.use_bn_in_head,
|
533 |
+
pred_bn=False,
|
534 |
+
norm_before_pred=args.norm_before_pred,
|
535 |
+
)
|
536 |
+
|
537 |
+
# local group head
|
538 |
+
if args.loss_weights[1] > 0:
|
539 |
+
student_local_group_head = Local_Group_Superivsion_Head(
|
540 |
+
in_dim=embed_dim,
|
541 |
+
hidden_dim=2048,
|
542 |
+
out_dim=args.local_group_out_dim,
|
543 |
+
pred_hidden_dim=4096,
|
544 |
+
nlayers=3,
|
545 |
+
proj_bn=args.use_bn_in_head,
|
546 |
+
pred_bn=False,
|
547 |
+
norm_before_pred=args.norm_before_pred,
|
548 |
+
)
|
549 |
+
teacher_local_group_head = Local_Group_Superivsion_Head(
|
550 |
+
in_dim=embed_dim,
|
551 |
+
hidden_dim=2048,
|
552 |
+
out_dim=args.local_group_out_dim,
|
553 |
+
pred_hidden_dim=0,
|
554 |
+
nlayers=3,
|
555 |
+
proj_bn=args.use_bn_in_head,
|
556 |
+
pred_bn=False,
|
557 |
+
norm_before_pred=args.norm_before_pred,
|
558 |
+
)
|
559 |
+
|
560 |
+
# group head
|
561 |
+
if args.loss_weights[2] > 0:
|
562 |
+
student_group_head = Group_Superivsion_Head(
|
563 |
+
in_dim=embed_dim,
|
564 |
+
out_dim=args.group_out_dim,
|
565 |
+
hidden_dim=2048,
|
566 |
+
bottleneck_dim=args.group_bottleneck_dim,
|
567 |
+
nlayers=3,
|
568 |
+
use_bn=args.use_bn_in_head,
|
569 |
+
norm_last_layer=args.norm_last_layer,
|
570 |
+
)
|
571 |
+
teacher_group_head = Group_Superivsion_Head(
|
572 |
+
in_dim=embed_dim,
|
573 |
+
out_dim=args.group_out_dim,
|
574 |
+
hidden_dim=2048,
|
575 |
+
bottleneck_dim=args.group_bottleneck_dim,
|
576 |
+
nlayers=3,
|
577 |
+
use_bn=args.use_bn_in_head,
|
578 |
+
norm_last_layer=args.norm_last_layer,
|
579 |
+
)
|
580 |
+
|
581 |
+
# multi-crop wrapper
|
582 |
+
student = Mugs_Wrapper(
|
583 |
+
student, student_instance_head, student_local_group_head, student_group_head
|
584 |
+
)
|
585 |
+
|
586 |
+
teacher = Mugs_Wrapper(
|
587 |
+
teacher, teacher_instance_head, teacher_local_group_head, teacher_group_head
|
588 |
+
)
|
589 |
+
|
590 |
+
return student, teacher, student_mem, teacher_mem
|
591 |
+
|
592 |
+
|
593 |
+
# utils
|
594 |
+
@torch.no_grad()
|
595 |
+
def concat_all_gather(tensor):
|
596 |
+
"""
|
597 |
+
Performs all_gather operation on the provided tensors.
|
598 |
+
*** Warning ***: torch.distributed.all_gather has no gradient.
|
599 |
+
"""
|
600 |
+
|
601 |
+
tensors_gather = [
|
602 |
+
torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
|
603 |
+
]
|
604 |
+
torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
|
605 |
+
|
606 |
+
output = torch.cat(tensors_gather, dim=0)
|
607 |
+
return output
|
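A minimal sketch (editor's illustration, not part of the uploaded files) of how a momentum teacher built by `get_model` is typically updated from the student; `momentum_schedule` and the iteration index `it` are assumptions that mirror the cosine momentum schedule defined in src/optimizer.py.

import torch

@torch.no_grad()
def ema_update(student, teacher, m):
    """Exponential moving average: teacher = m * teacher + (1 - m) * student."""
    for ps, pt in zip(student.parameters(), teacher.parameters()):
        pt.data.mul_(m).add_(ps.detach().data, alpha=1 - m)

# assumed usage inside the training loop: ema_update(student, teacher, momentum_schedule[it])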
src/multicropdataset.py
ADDED
@@ -0,0 +1,445 @@
1 |
+
# Copyright 2022 Garena Online Private Limited
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""
|
15 |
+
multi-crop dataset that implements multi-crop augmentation as well as the dataset/dataloader
|
16 |
+
"""
|
17 |
+
import copy
|
18 |
+
import random
|
19 |
+
|
20 |
+
import torch
|
21 |
+
import torchvision.transforms as transforms
|
22 |
+
from PIL import Image, ImageFilter, ImageOps
|
23 |
+
from src.dataset import ImageFolder
|
24 |
+
from src.RandAugment import rand_augment_transform
|
25 |
+
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
26 |
+
from timm.data.random_erasing import RandomErasing
|
27 |
+
from timm.data.transforms import _pil_interp
|
28 |
+
|
29 |
+
|
30 |
+
class GaussianBlur(object):
|
31 |
+
"""
|
32 |
+
Apply Gaussian Blur to the PIL image.
|
33 |
+
"""
|
34 |
+
|
35 |
+
def __init__(self, p=0.5, radius_min=0.1, radius_max=2.0):
|
36 |
+
self.prob = p
|
37 |
+
self.radius_min = radius_min
|
38 |
+
self.radius_max = radius_max
|
39 |
+
|
40 |
+
def __call__(self, img):
|
41 |
+
do_it = random.random() <= self.prob
|
42 |
+
if not do_it:
|
43 |
+
return img
|
44 |
+
|
45 |
+
return img.filter(
|
46 |
+
ImageFilter.GaussianBlur(
|
47 |
+
radius=random.uniform(self.radius_min, self.radius_max)
|
48 |
+
)
|
49 |
+
)
|
50 |
+
|
51 |
+
|
52 |
+
class Solarization(object):
|
53 |
+
"""
|
54 |
+
Apply Solarization to the PIL image.
|
55 |
+
"""
|
56 |
+
|
57 |
+
def __init__(self, p):
|
58 |
+
self.p = p
|
59 |
+
|
60 |
+
def __call__(self, img):
|
61 |
+
if random.random() < self.p:
|
62 |
+
return ImageOps.solarize(img)
|
63 |
+
else:
|
64 |
+
return img
|
65 |
+
|
66 |
+
|
67 |
+
def strong_transforms(
|
68 |
+
img_size=224,
|
69 |
+
scale=(0.08, 1.0),
|
70 |
+
ratio=(0.75, 1.3333333333333333),
|
71 |
+
hflip=0.5,
|
72 |
+
vflip=0.0,
|
73 |
+
color_jitter=0.4,
|
74 |
+
auto_augment="rand-m9-mstd0.5-inc1",
|
75 |
+
interpolation="random",
|
76 |
+
use_prefetcher=True,
|
77 |
+
mean=IMAGENET_DEFAULT_MEAN, # (0.485, 0.456, 0.406)
|
78 |
+
std=IMAGENET_DEFAULT_STD, # (0.229, 0.224, 0.225)
|
79 |
+
re_prob=0.25,
|
80 |
+
re_mode="pixel",
|
81 |
+
re_count=1,
|
82 |
+
re_num_splits=0,
|
83 |
+
color_aug=False,
|
84 |
+
strong_ratio=0.45,
|
85 |
+
):
|
86 |
+
"""
|
87 |
+
for use in a mixing dataset that passes
|
88 |
+
* all data through the first (primary) transform, called the 'clean' data
|
89 |
+
* a portion of the data through the secondary transform
|
90 |
+
* normalizes and converts the branches above with the third, final transform
|
91 |
+
"""
|
92 |
+
|
93 |
+
scale = tuple(scale or (0.08, 1.0)) # default imagenet scale range
|
94 |
+
ratio = tuple(ratio or (3.0 / 4.0, 4.0 / 3.0)) # default imagenet ratio range
|
95 |
+
|
96 |
+
primary_tfl = []
|
97 |
+
if hflip > 0.0:
|
98 |
+
primary_tfl += [transforms.RandomHorizontalFlip(p=hflip)]
|
99 |
+
if vflip > 0.0:
|
100 |
+
primary_tfl += [transforms.RandomVerticalFlip(p=vflip)]
|
101 |
+
|
102 |
+
secondary_tfl = []
|
103 |
+
if auto_augment:
|
104 |
+
assert isinstance(auto_augment, str)
|
105 |
+
if isinstance(img_size, tuple):
|
106 |
+
img_size_min = min(img_size)
|
107 |
+
else:
|
108 |
+
img_size_min = img_size
|
109 |
+
aa_params = dict(
|
110 |
+
translate_const=int(img_size_min * strong_ratio),
|
111 |
+
img_mean=tuple([min(255, round(255 * x)) for x in mean]),
|
112 |
+
)
|
113 |
+
if interpolation and interpolation != "random":
|
114 |
+
aa_params["interpolation"] = _pil_interp(interpolation)
|
115 |
+
if auto_augment.startswith("rand"):
|
116 |
+
secondary_tfl += [rand_augment_transform(auto_augment, aa_params)]
|
117 |
+
if color_jitter is not None and color_aug:
|
118 |
+
# color jitter is enabled when not using AA
|
119 |
+
flip_and_color_jitter = [
|
120 |
+
transforms.RandomApply(
|
121 |
+
[
|
122 |
+
transforms.ColorJitter(
|
123 |
+
brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1
|
124 |
+
)
|
125 |
+
],
|
126 |
+
p=0.8,
|
127 |
+
),
|
128 |
+
transforms.RandomGrayscale(p=0.2),
|
129 |
+
]
|
130 |
+
secondary_tfl += flip_and_color_jitter
|
131 |
+
|
132 |
+
if interpolation == "random":
|
133 |
+
interpolation = (Image.BILINEAR, Image.BICUBIC)
|
134 |
+
else:
|
135 |
+
interpolation = _pil_interp(interpolation)
|
136 |
+
final_tfl = [
|
137 |
+
transforms.RandomResizedCrop(
|
138 |
+
size=img_size, scale=scale, ratio=ratio, interpolation=Image.BICUBIC
|
139 |
+
)
|
140 |
+
]
|
141 |
+
if use_prefetcher:
|
142 |
+
# prefetcher and collate will handle tensor conversion and norm
|
143 |
+
final_tfl += [transforms.ToTensor()]
|
144 |
+
else:
|
145 |
+
final_tfl += [
|
146 |
+
transforms.ToTensor(),
|
147 |
+
transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)),
|
148 |
+
]
|
149 |
+
if re_prob > 0.0:
|
150 |
+
final_tfl.append(
|
151 |
+
RandomErasing(
|
152 |
+
re_prob,
|
153 |
+
mode=re_mode,
|
154 |
+
max_count=re_count,
|
155 |
+
num_splits=re_num_splits,
|
156 |
+
device="cpu",
|
157 |
+
)
|
158 |
+
)
|
159 |
+
return transforms.Compose(primary_tfl + secondary_tfl + final_tfl)
|
160 |
+
|
161 |
+
|
162 |
+
class DataAugmentation(object):
|
163 |
+
"""
|
164 |
+
implement multi-crop data augmentation.
|
165 |
+
--global_crops_scale: scale range of the 224-sized cropped image before resizing
|
166 |
+
--local_crops_scale: scale range of the 96-sized cropped image before resizing
|
167 |
+
--local_crops_number: Number of small local views to generate
|
168 |
+
--prob: when we use strong augmentation and weak augmentation, the ratio of images to
|
169 |
+
be cropped with strong augmentation
|
170 |
+
--vanilla_weak_augmentation: whether we use the same augmentation as in DINO, namely
|
171 |
+
only using weak augmentation
|
172 |
+
--color_aug: after AutoAugment, whether we further perform color augmentation
|
173 |
+
--local_crop_size: the small crop size
|
174 |
+
--timm_auto_augment_par: the parameters for the AutoAugment used in DeiT
|
175 |
+
--strong_ratio: the ratio of image augmentation for the AutoAugment used in DeiT
|
176 |
+
--re_prob: the re-prob parameter of image augmentation for the AutoAugment used in DeiT
|
177 |
+
--use_prefetcher: whether we use the prefetcher, which can accelerate the training speed
|
178 |
+
"""
|
179 |
+
|
180 |
+
def __init__(
|
181 |
+
self,
|
182 |
+
global_crops_scale,
|
183 |
+
local_crops_scale,
|
184 |
+
local_crops_number,
|
185 |
+
prob=0.5,
|
186 |
+
vanilla_weak_augmentation=False,
|
187 |
+
color_aug=False,
|
188 |
+
local_crop_size=[96],
|
189 |
+
timm_auto_augment_par="rand-m9-mstd0.5-inc1",
|
190 |
+
strong_ratio=0.45,
|
191 |
+
re_prob=0.25,
|
192 |
+
use_prefetcher=False,
|
193 |
+
):
|
194 |
+
|
195 |
+
## probability to perform strong augmentation
|
196 |
+
self.prob = prob
|
197 |
+
## whether we use the commonly used augmentations, e.g. DINO or MoCo-V3
|
198 |
+
self.vanilla_weak_augmentation = vanilla_weak_augmentation
|
199 |
+
|
200 |
+
flip_and_color_jitter = transforms.Compose(
|
201 |
+
[
|
202 |
+
transforms.RandomHorizontalFlip(p=0.5),
|
203 |
+
transforms.RandomApply(
|
204 |
+
[
|
205 |
+
transforms.ColorJitter(
|
206 |
+
brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1
|
207 |
+
)
|
208 |
+
],
|
209 |
+
p=0.8,
|
210 |
+
),
|
211 |
+
transforms.RandomGrayscale(p=0.2),
|
212 |
+
]
|
213 |
+
)
|
214 |
+
|
215 |
+
if use_prefetcher:
|
216 |
+
normalize = transforms.Compose(
|
217 |
+
[
|
218 |
+
transforms.ToTensor(),
|
219 |
+
]
|
220 |
+
)
|
221 |
+
else:
|
222 |
+
normalize = transforms.Compose(
|
223 |
+
[
|
224 |
+
transforms.ToTensor(),
|
225 |
+
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
|
226 |
+
]
|
227 |
+
)
|
228 |
+
|
229 |
+
##====== build augmentation of global crops, i.e. 224-sized image crops =========
|
230 |
+
# first global crop, always weak augmentation
|
231 |
+
self.global_transfo1 = transforms.Compose(
|
232 |
+
[
|
233 |
+
transforms.RandomResizedCrop(
|
234 |
+
224, scale=global_crops_scale, interpolation=Image.BICUBIC
|
235 |
+
),
|
236 |
+
flip_and_color_jitter,
|
237 |
+
GaussianBlur(1.0),
|
238 |
+
normalize,
|
239 |
+
]
|
240 |
+
)
|
241 |
+
|
242 |
+
# second global crop, always weak augmentation
|
243 |
+
self.global_transfo2 = transforms.Compose(
|
244 |
+
[
|
245 |
+
transforms.RandomResizedCrop(
|
246 |
+
224, scale=global_crops_scale, interpolation=Image.BICUBIC
|
247 |
+
),
|
248 |
+
flip_and_color_jitter,
|
249 |
+
GaussianBlur(0.1),
|
250 |
+
Solarization(0.2),
|
251 |
+
normalize,
|
252 |
+
]
|
253 |
+
)
|
254 |
+
|
255 |
+
# strong augmentation, maybe used if we need to perform strong augmentation
|
256 |
+
self.global_transfo3 = strong_transforms(
|
257 |
+
img_size=224,
|
258 |
+
scale=global_crops_scale,
|
259 |
+
ratio=(0.75, 1.3333333333333333),
|
260 |
+
hflip=0.5,
|
261 |
+
vflip=0.0,
|
262 |
+
color_jitter=0.4,
|
263 |
+
auto_augment=timm_auto_augment_par, # 'rand-m9-mstd0.5-inc1'
|
264 |
+
interpolation="random",
|
265 |
+
use_prefetcher=use_prefetcher, # True
|
266 |
+
mean=IMAGENET_DEFAULT_MEAN, # (0.485, 0.456, 0.406)
|
267 |
+
std=IMAGENET_DEFAULT_STD, # (0.229, 0.224, 0.225)
|
268 |
+
re_prob=re_prob, # 0.25
|
269 |
+
re_mode="pixel",
|
270 |
+
re_count=1,
|
271 |
+
re_num_splits=0,
|
272 |
+
color_aug=color_aug,
|
273 |
+
strong_ratio=strong_ratio,
|
274 |
+
)
|
275 |
+
|
276 |
+
##====== build augmentation of local crops, i.e. 96-sized image crops =========
|
277 |
+
self.local_crops_number = (
|
278 |
+
local_crops_number # transformation for the local small crops
|
279 |
+
)
|
280 |
+
assert local_crop_size[0] == 96
|
281 |
+
# weak augmentation, maybe used if we need to perform weak augmentation
|
282 |
+
self.local_transfo = transforms.Compose(
|
283 |
+
[
|
284 |
+
transforms.RandomResizedCrop(
|
285 |
+
local_crop_size[0],
|
286 |
+
scale=local_crops_scale,
|
287 |
+
interpolation=Image.BICUBIC,
|
288 |
+
),
|
289 |
+
flip_and_color_jitter,
|
290 |
+
GaussianBlur(p=0.5),
|
291 |
+
normalize,
|
292 |
+
]
|
293 |
+
)
|
294 |
+
# strong augmentation, maybe used if we need to perform strong augmentation
|
295 |
+
self.local_transfo2 = strong_transforms(
|
296 |
+
img_size=local_crop_size[0], # (224, 224)
|
297 |
+
scale=local_crops_scale, # (0.08, 1.0)
|
298 |
+
ratio=(0.75, 1.3333333333333333), # (0.75, 1.3333333333333333)
|
299 |
+
hflip=0.5, # 0.5
|
300 |
+
vflip=0.0, # 0.0
|
301 |
+
color_jitter=0.4, # 0.4
|
302 |
+
auto_augment=timm_auto_augment_par, # 'rand-m9-mstd0.5-inc1'
|
303 |
+
interpolation="random", # 'random'
|
304 |
+
use_prefetcher=use_prefetcher, # True
|
305 |
+
mean=IMAGENET_DEFAULT_MEAN, # (0.485, 0.456, 0.406)
|
306 |
+
std=IMAGENET_DEFAULT_STD, # (0.229, 0.224, 0.225)
|
307 |
+
re_prob=re_prob, # 0.25
|
308 |
+
re_mode="pixel", # 'pixel'
|
309 |
+
re_count=1, # 1
|
310 |
+
re_num_splits=0, # 0
|
311 |
+
color_aug=color_aug,
|
312 |
+
strong_ratio=strong_ratio,
|
313 |
+
)
|
314 |
+
|
315 |
+
def __call__(self, image):
|
316 |
+
"""
|
317 |
+
implement multi-crop data augmentation. Generate two 224-sized +
|
318 |
+
"local_crops_number" 96-sized images
|
319 |
+
"""
|
320 |
+
crops = []
|
321 |
+
##====== images to be fed into teacher, two 224-sized =========
|
322 |
+
img1 = self.global_transfo1(image)
|
323 |
+
img2 = self.global_transfo2(image)
|
324 |
+
crops.append(img1)
|
325 |
+
crops.append(img2)
|
326 |
+
|
327 |
+
##====== images to be fed into student, two 224-sized + "local_crops_number" 96-sized =========
|
328 |
+
# first to generate two 224-sized
|
329 |
+
# this weak_flag indicates whether the current image is weakly augmented.
|
330 |
+
# For local group supervision, we only use weakly augmented images of size 224 to
|
331 |
+
# update the memory for local-group aggregation.
|
332 |
+
weak_flag = False
|
333 |
+
|
334 |
+
if self.vanilla_weak_augmentation is True:
|
335 |
+
## directly copy the images of weak augmentation
|
336 |
+
crops.append(copy.deepcopy(img1))
|
337 |
+
crops.append(copy.deepcopy(img2))
|
338 |
+
weak_flag = True
|
339 |
+
elif self.prob < 1.0 and random.random() > self.prob:
|
340 |
+
## perform strong augmentation
|
341 |
+
crops.append(self.global_transfo3(image))
|
342 |
+
crops.append(self.global_transfo3(image))
|
343 |
+
else:
|
344 |
+
## perform weak augmentation
|
345 |
+
crops.append(self.global_transfo1(image))
|
346 |
+
crops.append(self.global_transfo2(image))
|
347 |
+
weak_flag = True
|
348 |
+
|
349 |
+
# then to generate "local_crops_number" 96-sized
|
350 |
+
for _ in range(self.local_crops_number):
|
351 |
+
if self.prob < 1.0 and random.random() > self.prob:
|
352 |
+
## perform strong augmentation
|
353 |
+
crops.append(self.local_transfo2(image))
|
354 |
+
else:
|
355 |
+
## perform weak augmentation
|
356 |
+
crops.append(self.local_transfo(image))
|
357 |
+
|
358 |
+
return crops, weak_flag
|
359 |
+
|
360 |
+
|
361 |
+
def get_dataset(args):
|
362 |
+
"""
|
363 |
+
build a multi-crop data augmentation and a dataset/dataloader
|
364 |
+
"""
|
365 |
+
## preparing augmentations, including weak and strong augmentations
|
366 |
+
transform = DataAugmentation(
|
367 |
+
global_crops_scale=args.global_crops_scale,
|
368 |
+
local_crops_scale=args.local_crops_scale,
|
369 |
+
local_crops_number=args.local_crops_number,
|
370 |
+
vanilla_weak_augmentation=args.vanilla_weak_augmentation,
|
371 |
+
prob=args.prob,
|
372 |
+
color_aug=args.color_aug,
|
373 |
+
local_crop_size=args.size_crops,
|
374 |
+
timm_auto_augment_par=args.timm_auto_augment_par,
|
375 |
+
strong_ratio=args.strong_ratio,
|
376 |
+
re_prob=args.re_prob,
|
377 |
+
use_prefetcher=args.use_prefetcher,
|
378 |
+
)
|
379 |
+
|
380 |
+
## For debug mode, we only load the first two classes to reduce data reading time.
|
381 |
+
## otherwise, we load all training data for pretraining.
|
382 |
+
class_num = 2 if args.debug else 1000
|
383 |
+
dataset = ImageFolder(args.data_path, transform=transform, class_num=class_num)
|
384 |
+
|
385 |
+
sampler = torch.utils.data.DistributedSampler(dataset, shuffle=True)
|
386 |
+
data_loader = torch.utils.data.DataLoader(
|
387 |
+
dataset,
|
388 |
+
sampler=sampler,
|
389 |
+
batch_size=args.batch_size_per_gpu,
|
390 |
+
num_workers=args.num_workers,
|
391 |
+
pin_memory=True,
|
392 |
+
drop_last=True,
|
393 |
+
)
|
394 |
+
return data_loader
|
395 |
+
|
396 |
+
|
397 |
+
class data_prefetcher:
|
398 |
+
"""
|
399 |
+
implement a data prefetcher. We perform some augmentation on GPUs instead of CPUs
|
400 |
+
--loader: a data loader
|
401 |
+
--fp16: whether we use fp16; if yes, we need to transform the data to fp16
|
402 |
+
"""
|
403 |
+
|
404 |
+
def __init__(self, loader, fp16=True):
|
405 |
+
self.loader = iter(loader)
|
406 |
+
self.fp16 = fp16
|
407 |
+
self.stream = torch.cuda.Stream()
|
408 |
+
self.mean = torch.tensor([0.485, 0.456, 0.406]).cuda().view(1, 3, 1, 1)
|
409 |
+
self.std = torch.tensor([0.229, 0.224, 0.225]).cuda().view(1, 3, 1, 1)
|
410 |
+
if fp16:
|
411 |
+
self.mean = self.mean.half()
|
412 |
+
self.std = self.std.half()
|
413 |
+
|
414 |
+
self.preload()
|
415 |
+
|
416 |
+
def preload(self):
|
417 |
+
"""
|
418 |
+
preload the next minibatch of data
|
419 |
+
"""
|
420 |
+
try:
|
421 |
+
self.multi_crops, self.weak_flag = next(self.loader)
|
422 |
+
except StopIteration:
|
423 |
+
self.multi_crops, self.weak_flag = None, None
|
424 |
+
return
|
425 |
+
|
426 |
+
with torch.cuda.stream(self.stream):
|
427 |
+
for i in range(len(self.multi_crops)):
|
428 |
+
self.multi_crops[i] = self.multi_crops[i].cuda(non_blocking=True)
|
429 |
+
if self.fp16:
|
430 |
+
self.multi_crops[i] = (
|
431 |
+
self.multi_crops[i].half().sub_(self.mean).div_(self.std)
|
432 |
+
)
|
433 |
+
else:
|
434 |
+
self.multi_crops[i] = (
|
435 |
+
self.multi_crops[i].float().sub_(self.mean).div_(self.std)
|
436 |
+
)
|
437 |
+
|
438 |
+
def next(self):
|
439 |
+
"""
|
440 |
+
load the next minibatch of data
|
441 |
+
"""
|
442 |
+
torch.cuda.current_stream().wait_stream(self.stream)
|
443 |
+
multi_crops, weak_flags = self.multi_crops, self.weak_flag
|
444 |
+
self.preload()
|
445 |
+
return multi_crops, weak_flags
|
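A minimal usage sketch (editor's illustration) for `data_prefetcher`; `data_loader` is assumed to come from `get_dataset(args)`, and the body of the training step is a placeholder.

prefetcher = data_prefetcher(data_loader, fp16=True)
multi_crops, weak_flags = prefetcher.next()
while multi_crops is not None:
    # ... forward student/teacher on multi_crops, compute the losses, update ...
    multi_crops, weak_flags = prefetcher.next()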
src/optimizer.py
ADDED
@@ -0,0 +1,210 @@
1 |
+
# Copyright 2022 Garena Online Private Limited
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""
|
15 |
+
implement some functions for optimizers
|
16 |
+
"""
|
17 |
+
import numpy as np
|
18 |
+
import torch
|
19 |
+
|
20 |
+
import utils
|
21 |
+
|
22 |
+
|
23 |
+
def clip_gradients(model, clip):
|
24 |
+
"""
|
25 |
+
clip gradient if gradient norm > clip
|
26 |
+
"""
|
27 |
+
norms = []
|
28 |
+
for name, p in model.named_parameters():
|
29 |
+
if p.grad is not None:
|
30 |
+
param_norm = p.grad.data.norm(2)
|
31 |
+
norms.append(param_norm.item())
|
32 |
+
clip_coef = clip / (param_norm + 1e-6)
|
33 |
+
if clip_coef < 1:
|
34 |
+
p.grad.data.mul_(clip_coef)
|
35 |
+
return norms
|
36 |
+
|
37 |
+
|
38 |
+
def cancel_gradients_last_layer(epoch, model, freeze_last_layer):
|
39 |
+
"""
|
40 |
+
cancel the last-layer gradients while epoch < freeze_last_layer
|
41 |
+
"""
|
42 |
+
if epoch >= freeze_last_layer:
|
43 |
+
return
|
44 |
+
for n, p in model.named_parameters():
|
45 |
+
if "last_layer" in n:
|
46 |
+
p.grad = None
|
47 |
+
|
48 |
+
|
49 |
+
def cosine_scheduler(
|
50 |
+
base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0
|
51 |
+
):
|
52 |
+
"""
|
53 |
+
linearly warm up from start_warmup_value to base_value during the first warmup_epochs epochs;
|
54 |
+
then cosine-decay from base_value to final_value over the remaining epochs - warmup_epochs epochs
|
55 |
+
"""
|
56 |
+
warmup_schedule = np.array([])
|
57 |
+
warmup_iters = warmup_epochs * niter_per_ep
|
58 |
+
if warmup_epochs > 0:
|
59 |
+
warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
|
60 |
+
|
61 |
+
iters = np.arange(epochs * niter_per_ep - warmup_iters)
|
62 |
+
schedule = final_value + 0.5 * (base_value - final_value) * (
|
63 |
+
1 + np.cos(np.pi * iters / len(iters))
|
64 |
+
)
|
65 |
+
|
66 |
+
schedule = np.concatenate((warmup_schedule, schedule))
|
67 |
+
assert len(schedule) == epochs * niter_per_ep
|
68 |
+
return schedule
|
69 |
+
|
70 |
+
|
71 |
+
def get_params_groups(model):
|
72 |
+
"""
|
73 |
+
divide the parameters into several groups, see below
|
74 |
+
"""
|
75 |
+
regularized = []
|
76 |
+
not_regularized = []
|
77 |
+
patch_embed = []
|
78 |
+
patch_embed_not_regularized = []
|
79 |
+
for name, param in model.named_parameters():
|
80 |
+
if not param.requires_grad:
|
81 |
+
continue
|
82 |
+
# we do not regularize biases nor Norm parameters
|
83 |
+
if name.endswith(".bias") or len(param.shape) == 1:
|
84 |
+
if "patch_embed" in name:
|
85 |
+
patch_embed_not_regularized.append(param)
|
86 |
+
else:
|
87 |
+
not_regularized.append(param)
|
88 |
+
elif "patch_embed" in name:
|
89 |
+
patch_embed.append(param)
|
90 |
+
else:
|
91 |
+
regularized.append(param)
|
92 |
+
return [
|
93 |
+
{"name": "normal_params", "params": regularized},
|
94 |
+
{"name": "patch_embed", "params": patch_embed},
|
95 |
+
{
|
96 |
+
"name": "no_wd",
|
97 |
+
"params": not_regularized,
|
98 |
+
"apply_wd": False,
|
99 |
+
"weight_decay": 0.0,
|
100 |
+
},
|
101 |
+
{
|
102 |
+
"name": "patch_embed_no_wd",
|
103 |
+
"params": patch_embed_not_regularized,
|
104 |
+
"apply_wd": False,
|
105 |
+
"weight_decay": 0.0,
|
106 |
+
},
|
107 |
+
]
|
108 |
+
|
109 |
+
|
110 |
+
class LARS(torch.optim.Optimizer):
|
111 |
+
"""
|
112 |
+
Almost copy-paste from https://github.com/facebookresearch/barlowtwins/blob/main/main.py
|
113 |
+
"""
|
114 |
+
|
115 |
+
def __init__(
|
116 |
+
self,
|
117 |
+
params,
|
118 |
+
lr=0,
|
119 |
+
weight_decay=0,
|
120 |
+
momentum=0.9,
|
121 |
+
eta=0.001,
|
122 |
+
weight_decay_filter=None,
|
123 |
+
lars_adaptation_filter=None,
|
124 |
+
):
|
125 |
+
defaults = dict(
|
126 |
+
lr=lr,
|
127 |
+
weight_decay=weight_decay,
|
128 |
+
momentum=momentum,
|
129 |
+
eta=eta,
|
130 |
+
weight_decay_filter=weight_decay_filter,
|
131 |
+
lars_adaptation_filter=lars_adaptation_filter,
|
132 |
+
)
|
133 |
+
super().__init__(params, defaults)
|
134 |
+
|
135 |
+
@torch.no_grad()
|
136 |
+
def step(self):
|
137 |
+
for g in self.param_groups:
|
138 |
+
for p in g["params"]:
|
139 |
+
dp = p.grad
|
140 |
+
|
141 |
+
if dp is None:
|
142 |
+
continue
|
143 |
+
|
144 |
+
if p.ndim != 1:
|
145 |
+
dp = dp.add(p, alpha=g["weight_decay"])
|
146 |
+
|
147 |
+
if p.ndim != 1:
|
148 |
+
param_norm = torch.norm(p)
|
149 |
+
update_norm = torch.norm(dp)
|
150 |
+
one = torch.ones_like(param_norm)
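# (editor's note) LARS trust ratio: rescale the update by eta * ||w|| / ||update||
# whenever both norms are positive, otherwise fall back to 1 (no rescaling).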
|
151 |
+
q = torch.where(
|
152 |
+
param_norm > 0.0,
|
153 |
+
torch.where(
|
154 |
+
update_norm > 0, (g["eta"] * param_norm / update_norm), one
|
155 |
+
),
|
156 |
+
one,
|
157 |
+
)
|
158 |
+
dp = dp.mul(q)
|
159 |
+
|
160 |
+
param_state = self.state[p]
|
161 |
+
if "mu" not in param_state:
|
162 |
+
param_state["mu"] = torch.zeros_like(p)
|
163 |
+
mu = param_state["mu"]
|
164 |
+
mu.mul_(g["momentum"]).add_(dp)
|
165 |
+
|
166 |
+
p.add_(mu, alpha=-g["lr"])
|
167 |
+
|
168 |
+
|
169 |
+
def get_optimizer(student, len_dataloader, args):
|
170 |
+
"""
|
171 |
+
build an optimizer for training
|
172 |
+
"""
|
173 |
+
# ============ preparing optimizer ... ============
|
174 |
+
params_groups = get_params_groups(student)
|
175 |
+
if args.optimizer == "adamw":
|
176 |
+
optimizer = torch.optim.AdamW(params_groups) # to use with ViTs
|
177 |
+
elif args.optimizer == "sgd":
|
178 |
+
optimizer = torch.optim.SGD(
|
179 |
+
params_groups, lr=0, momentum=0.9
|
180 |
+
) # lr is set by scheduler
|
181 |
+
elif args.optimizer == "lars":
|
182 |
+
optimizer = LARS(params_groups) # to use with convnet and large batches
|
183 |
+
# for mixed precision training
|
184 |
+
fp16_scaler = None
|
185 |
+
if args.use_fp16:
|
186 |
+
fp16_scaler = torch.cuda.amp.GradScaler()
|
187 |
+
|
188 |
+
# ============ init schedulers ... ============
|
189 |
+
lr_schedule = cosine_scheduler(
|
190 |
+
args.lr
|
191 |
+
* (args.batch_size_per_gpu * utils.get_world_size())
|
192 |
+
/ 256.0, # linear scaling rule
|
193 |
+
args.min_lr,
|
194 |
+
args.epochs,
|
195 |
+
len_dataloader,
|
196 |
+
warmup_epochs=args.warmup_epochs,
|
197 |
+
)
|
198 |
+
wd_schedule = cosine_scheduler(
|
199 |
+
args.weight_decay,
|
200 |
+
args.weight_decay_end,
|
201 |
+
args.epochs,
|
202 |
+
len_dataloader, # len(data_loader),
|
203 |
+
)
|
204 |
+
# momentum parameter is increased to 1. during training with a cosine schedule
|
205 |
+
momentum_schedule = cosine_scheduler(
|
206 |
+
args.momentum_teacher, 1, args.epochs, len_dataloader
|
207 |
+
)
|
208 |
+
print("Loss, optimizer and schedulers ready.")
|
209 |
+
|
210 |
+
return optimizer, fp16_scaler, lr_schedule, wd_schedule, momentum_schedule
|
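A minimal sketch (editor's illustration) of how the schedules returned by `get_optimizer` are typically applied each iteration; the global iteration index `it` and the surrounding training loop are assumptions, and the `apply_wd` flag matches the parameter groups built by `get_params_groups` above.

# update learning rate and weight decay from the cosine schedules (illustrative sketch)
for param_group in optimizer.param_groups:
    param_group["lr"] = lr_schedule[it]
    if param_group.get("apply_wd", True):  # the "no_wd" groups keep weight_decay = 0
        param_group["weight_decay"] = wd_schedule[it]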
src/vision_transformer.py
ADDED
@@ -0,0 +1,491 @@
# Copyright 2022 Garena Online Private Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
ViT backbones, including ViT-small, ViT-base, ViT-large
Mostly copy-paste from timm library.
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
"""
import math
import warnings  # needed by the warnings.warn call in _no_grad_trunc_normal_
from functools import partial

import torch
import torch.nn as nn


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn(
            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
            "The distribution of values may be incorrect.",
            stacklevel=2,
        )

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        lower = norm_cdf((a - mean) / std)
        upper = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * lower - 1, 2 * upper - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.0))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    # type: (torch.tensor, float, float, float, float) -> torch.tensor
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)


def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (
        x.ndim - 1
    )  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)


class Mlp(nn.Module):
    """
    MLP module in ViT
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    """
    Attention module in ViT
    """

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape

        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
        reshaped_qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = reshaped_qkv[0], reshaped_qkv[1], reshaped_qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x, attn


class Block(nn.Module):
    """
    ViT block, including Attention, MLP, etc.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, x, return_attention=False):
        y, attn = self.attn(self.norm1(x))
        if return_attention:
            return attn
        x = x + self.drop_path(y)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class PatchEmbed(nn.Module):
    """Image to Patch Embedding"""

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        num_patches = (img_size // patch_size) * (img_size // patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x


class VisionTransformer(nn.Module):
    """Vision Transformer"""

    def __init__(
        self,
        img_size=[224, 224],
        patch_size=16,
        in_chans=3,
        num_classes=0,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=nn.LayerNorm,
        num_relation_blocks=0,
        **kwargs
    ):
        super().__init__()
        self.num_features = self.embed_dim = embed_dim
        self.patch_size = patch_size
        self.num_classes = num_classes
        self.depth = depth

        self.patch_embed = PatchEmbed(
            img_size=img_size[0],
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )

        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        trunc_normal_(self.pos_embed, std=0.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, depth)
        ]  # stochastic depth decay rule
        self.blocks = nn.ModuleList(
            [
                Block(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                )
                for i in range(depth)
            ]
        )
        self.norm = norm_layer(embed_dim)

        # Classifier head
        self.head = (
            nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        )

        self.num_relation_blocks = num_relation_blocks
        if num_relation_blocks > 0:
            self.relation_blocks = nn.ModuleList(
                [
                    Block(
                        dim=embed_dim,
                        num_heads=num_heads,
                        mlp_ratio=mlp_ratio,
                        qkv_bias=qkv_bias,
                        qk_scale=qk_scale,
                        drop=drop_rate,
                        attn_drop=attn_drop_rate,
                        drop_path=dpr[i],
                        norm_layer=norm_layer,
                    )
                    for i in range(int(num_relation_blocks))
                ]
            )

        trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights)

    def add_pos_emb_for_cls_token(self):
        pe_cls_token = torch.zeros([1, 1, self.embed_dim], dtype=torch.float32)
        self.pos_embed = nn.Parameter(torch.cat([pe_cls_token, self.pos_embed], dim=1))
        self.pos_embed.requires_grad = False

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def interpolate_pos_encoding(self, x, w, h):
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N and w == h:
            return self.pos_embed
        class_pos_embed = self.pos_embed[:, 0]
        patch_pos_embed = self.pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // self.patch_embed.patch_size
        h0 = h // self.patch_embed.patch_size
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        w0, h0 = w0 + 0.1, h0 + 0.1
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(
                1, int(math.sqrt(N)), int(math.sqrt(N)), dim
            ).permute(0, 3, 1, 2),
            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
            mode="bicubic",
        )
        assert (
            int(w0) == patch_pos_embed.shape[-2]
            and int(h0) == patch_pos_embed.shape[-1]
        )
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

    def prepare_tokens(self, x):
        B, nc, w, h = x.shape
        x = self.patch_embed(x)  # patch linear embedding

        # add the [CLS] token to the embed patch tokens
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # add positional encoding to each token
        x = x + self.interpolate_pos_encoding(x, w, h)
        return self.pos_drop(x)

    def forward(self, x, return_all=False, local_group_memory_inputs=None, **kwargs):
        x = self.prepare_tokens(x)
        for blk in self.blocks:
            x = blk(x)

        if self.num_relation_blocks > 0:
            mem = local_group_memory_inputs.get("mem")
            if mem is not None:
                m, _ = mem(x.mean(1))
                rx = torch.cat((x.mean(1).unsqueeze(1), m), dim=1)
            else:
                rx = x
            for i, blk in enumerate(self.relation_blocks):
                rx = blk(rx)
            relation_out = self.norm(rx[:, 0])

        x = self.norm(x)
        if self.num_classes > 0:
            return self.head(x[:, 0])

        if return_all:
            return x, relation_out
        else:
            return x[:, 0], relation_out

    def forward_knn(self, x):
        x = self.prepare_tokens(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return x[:, 0]

    def get_last_selfattention(self, x):
        x = self.prepare_tokens(x)
        for i, blk in enumerate(self.blocks):
            if i < len(self.blocks) - 1:
                x = blk(x)
            else:
                # return attention of the last block
                return blk(x, return_attention=True)

    def get_intermediate_layers(self, x, n=1):
        x = self.prepare_tokens(x)
        # we return the output tokens from the `n` last blocks
        output = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if len(self.blocks) - i <= n:
                output.append(self.norm(x))
        return output

    def get_num_layers(self):
        return self.depth

    @torch.jit.ignore
    def no_weight_decay(self):
        return {"pos_embed", "cls_token"}


def vit_tiny(patch_size=16, **kwargs):
    model = VisionTransformer(
        patch_size=patch_size,
        embed_dim=192,
        depth=12,
        num_heads=3,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs
    )
    return model


def vit_small(patch_size=16, **kwargs):
    model = VisionTransformer(
        patch_size=patch_size,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs
    )
    return model


def vit_base(patch_size=16, **kwargs):
    model = VisionTransformer(
        patch_size=patch_size,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs
    )
    return model


def vit_large(patch_size=16, **kwargs):
    model = VisionTransformer(
        patch_size=patch_size,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs
    )
    return model
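A quick smoke test can help when reading the forward paths above; the import path and the `{"mem": None}` placeholder are assumptions for illustration (with no external memory the relation blocks simply run on the backbone tokens), and the shapes follow the defaults in this file for ViT-S/16 at 224x224.

# hypothetical smoke test, not part of the uploaded sources
import torch
from src.vision_transformer import vit_small  # path assumed from the repo layout

model = vit_small(patch_size=16, drop_path_rate=0.1, num_relation_blocks=1)
images = torch.randn(2, 3, 224, 224)

# forward_knn returns only the normalized [CLS] feature: (2, 384) for ViT-S/16
feats = model.forward_knn(images)

# the full forward also runs the relation blocks; with mem=None it falls back to
# rx = x and returns (cls_feature, relation_out), each of shape (2, 384)
cls_feat, rel_feat = model(images, local_group_memory_inputs={"mem": None})

# attention of the last block: (batch, heads, tokens, tokens) = (2, 6, 197, 197)
attn = model.get_last_selfattention(images)
print(feats.shape, cls_feat.shape, rel_feat.shape, attn.shape)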
utils.py
ADDED
@@ -0,0 +1,583 @@
# Copyright 2022 Garena Online Private Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Misc functions.

Mostly copy-paste from torchvision references or other public repos like DETR and DINO:
https://github.com/facebookresearch/detr/blob/master/util/misc.py
https://github.com/facebookresearch/dino/blob/main/utils.py
"""
import argparse  # needed by bool_flag below
import datetime
import logging
import os
import subprocess
import sys
import time
from collections import defaultdict, deque

import numpy as np
import torch
import torch.distributed as dist
from torch import nn


def get_logger(file_path_name):
    """
    build a logger which writes both to disk and to the terminal
    """
    logger = logging.getLogger()
    logger.setLevel("INFO")
    BASIC_FORMAT = "%(levelname)s:%(message)s"
    DATE_FORMAT = ""
    formatter = logging.Formatter(BASIC_FORMAT, DATE_FORMAT)
    chlr = logging.StreamHandler()
    chlr.setFormatter(formatter)
    chlr.setLevel("INFO")
    fhlr = logging.FileHandler(file_path_name)
    fhlr.setFormatter(formatter)
    logger.addHandler(chlr)
    logger.addHandler(fhlr)

    return logger


def restart_from_checkpoint(ckp_path, run_variables=None, **kwargs):
    """
    Re-start from checkpoint
    """
    if not os.path.isfile(ckp_path):
        return
    print("Found checkpoint at {}".format(ckp_path))

    # open checkpoint file
    checkpoint = torch.load(ckp_path, map_location="cpu")
    # key is what to look for in the checkpoint file
    # value is the object to load
    # example: {'state_dict': model}
    for key, value in kwargs.items():
        if key in checkpoint and value is not None:
            try:
                msg = value.load_state_dict(checkpoint[key], strict=False)
                print(
                    "=> loaded '{}' from checkpoint '{}' with msg {}".format(
                        key, ckp_path, msg
                    )
                )
            except TypeError:
                try:
                    msg = value.load_state_dict(checkpoint[key])
                    print("=> loaded '{}' from checkpoint: '{}'".format(key, ckp_path))
                except ValueError:
                    print(
                        "=> failed to load '{}' from checkpoint: '{}'".format(
                            key, ckp_path
                        )
                    )
        else:
            print("=> key '{}' not found in checkpoint: '{}'".format(key, ckp_path))

    # reload variable important for the run
    if run_variables is not None:
        for var_name in run_variables:
            if var_name in checkpoint:
                run_variables[var_name] = checkpoint[var_name]


def bool_flag(s):
    """
    Parse boolean arguments from the command line.
    """
    FALSY_STRINGS = {"off", "false", "0"}
    TRUTHY_STRINGS = {"on", "true", "1"}
    if s.lower() in FALSY_STRINGS:
        return False
    elif s.lower() in TRUTHY_STRINGS:
        return True
    else:
        raise argparse.ArgumentTypeError("invalid value for a boolean flag")


def fix_random_seeds(seed=31):
    """
    Fix random seeds.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)


def has_batchnorms(model):
    """
    judge whether a model has batch normalization
    """
    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
    for name, module in model.named_modules():
        if isinstance(module, bn_types):
            return True
    return False


class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.6f} ({global_avg:.6f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value,
        )


class MetricLogger(object):
    """
    build a Metric Logger
    """

    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError(
            "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
        )

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append("{}: {}".format(name, str(meter)))
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        i = 0
        if not header:
            header = ""
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt="{avg:.6f}")
        data_time = SmoothedValue(fmt="{avg:.6f}")
        space_fmt = ":" + str(len(str(len(iterable)))) + "d"
        if torch.cuda.is_available():
            log_msg = self.delimiter.join(
                [
                    header,
                    "[{0" + space_fmt + "}/{1}]",
                    "eta: {eta}",
                    "{meters}",
                    "time: {time}",
                    "data: {data}",
                    "max mem: {memory:.0f}",
                ]
            )
        else:
            log_msg = self.delimiter.join(
                [
                    header,
                    "[{0" + space_fmt + "}/{1}]",
                    "eta: {eta}",
                    "{meters}",
                    "time: {time}",
                    "data: {data}",
                ]
            )
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(
                        log_msg.format(
                            i,
                            len(iterable),
                            eta=eta_string,
                            meters=str(self),
                            time=str(iter_time),
                            data=str(data_time),
                            memory=torch.cuda.max_memory_allocated() / MB,
                        )
                    )
                else:
                    print(
                        log_msg.format(
                            i,
                            len(iterable),
                            eta=eta_string,
                            meters=str(self),
                            time=str(iter_time),
                            data=str(data_time),
                        )
                    )
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print(
            "{} Total time: {} ({:.6f} s / it)".format(
                header, total_time_str, total_time / len(iterable)
            )
        )


def get_sha():
    cwd = os.path.dirname(os.path.abspath(__file__))

    def _run(command):
        return subprocess.check_output(command, cwd=cwd).decode("ascii").strip()

    sha = "N/A"
    diff = "clean"
    branch = "N/A"
    try:
        sha = _run(["git", "rev-parse", "HEAD"])
        subprocess.check_output(["git", "diff"], cwd=cwd)
        diff = _run(["git", "diff-index", "HEAD"])
        diff = "has uncommitted changes" if diff else "clean"
        branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"])
    except Exception:
        pass
    message = f"sha: {sha}, status: {diff}, branch: {branch}"
    return message


def is_dist_avail_and_initialized():
    """
    judge whether distributed training is available and well-initialized
    """
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    """
    get the world size
    """
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    """
    get the rank
    """
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    """
    judge whether the current node is the master node
    """
    return get_rank() == 0


def save_on_master(*args, **kwargs):
    """
    save checkpoint on the master node
    """
    if is_main_process():
        torch.save(*args, **kwargs)


def setup_for_distributed(is_master):
    """
    This function disables printing when not in master process
    """
    import builtins as __builtin__

    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop("force", False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


def init_distributed_ddpjob(args=None):
    """
    initialize the ddp job
    """
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size(), dist.get_rank()

    try:
        os.environ["MASTER_PORT"] = "40101"
        torch.distributed.init_process_group(backend="nccl")
    except Exception:
        world_size, rank = 1, 0
        print("distributed training not available")

    world_size = dist.get_world_size()
    rank = dist.get_rank()
    args.gpu = args.rank
    args.world_size, args.rank = world_size, rank
    return world_size, rank


def init_distributed_mode(args):
    """
    initialize the normal job
    """
    # launched with torch.distributed.launch
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ["WORLD_SIZE"])
        args.gpu = int(os.environ.get("LOCAL_RANK", 0))
        print(
            "args.rank",
            args.rank,
            "args.world_size",
            args.world_size,
            "args.gpu",
            args.gpu,
        )
        print("get_rank()", get_rank())
    # launched with submitit on a slurm cluster
    elif "SLURM_PROCID" in os.environ:
        args.rank = int(os.environ["SLURM_PROCID"])
        args.gpu = args.rank % torch.cuda.device_count()
    # launched naively with `python main_dino.py`
    # we manually add MASTER_ADDR and MASTER_PORT to env variables
    elif torch.cuda.is_available():
        print("Will run the code on one GPU.")
        args.rank, args.gpu, args.world_size = 0, 0, 1
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = "2950"
    else:
        print("Does not support training without GPU.")
        sys.exit(1)

    os.environ["MASTER_PORT"] = "6542"

    dist.init_process_group(
        backend="nccl",
        init_method=args.dist_url,
        world_size=args.world_size,
        rank=args.rank,
    )

    torch.cuda.set_device(args.gpu)
    print(
        "| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True
    )
    dist.barrier()
    setup_for_distributed(args.rank == 0)


def accuracy(output, target, topk=(1,)):
    """
    Computes the accuracy over the k top predictions for the specified values of k
    """
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.reshape(1, -1).expand_as(pred))
    return [correct[:k].reshape(-1).float().sum(0) * 100.0 / batch_size for k in topk]


def multi_scale(samples, model):
    """
    build a multi-scale feature
    """
    v = None
    for s in [1, 1 / 2 ** (1 / 2), 1 / 2]:  # we use 3 different scales
        if s == 1:
            inp = samples.clone()
        else:
            inp = nn.functional.interpolate(
                samples, scale_factor=s, mode="bilinear", align_corners=False
            )
        feats = model.forward_knn(inp).clone()
        if v is None:
            v = feats
        else:
            v += feats
    v /= 3
    v /= v.norm()
    return v


class AllGather(torch.autograd.Function):
    """
    gather the variable on different nodes together
    """

    @staticmethod
    def forward(ctx, x):
        if (
            dist.is_available()
            and dist.is_initialized()
            and (dist.get_world_size() > 1)
        ):
            outputs = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
            dist.all_gather(outputs, x)
            return torch.cat(outputs, 0)
        return x

    @staticmethod
    def backward(ctx, grads):
        if (
            dist.is_available()
            and dist.is_initialized()
            and (dist.get_world_size() > 1)
        ):
            s = (grads.shape[0] // dist.get_world_size()) * dist.get_rank()
            e = (grads.shape[0] // dist.get_world_size()) * (dist.get_rank() + 1)
            grads = grads.contiguous()
            dist.all_reduce(grads)
            return grads[s:e]
        return grads


class AllReduce(torch.autograd.Function):
    """
    reduce the variable on different nodes together
    """

    @staticmethod
    def forward(ctx, x):
        if (
            dist.is_available()
            and dist.is_initialized()
            and (dist.get_world_size() > 1)
        ):
            x = x.contiguous() / dist.get_world_size()
            dist.all_reduce(x)
        return x

    @staticmethod
    def backward(ctx, grads):
        return grads


def load_pretrained_weights(
    model, pretrained_weights, checkpoint_key, model_name, patch_size
):
    if os.path.isfile(pretrained_weights):
        state_dict = torch.load(pretrained_weights, map_location="cpu")
        if checkpoint_key is not None and checkpoint_key in state_dict:
            print(f"Take key {checkpoint_key} in provided checkpoint dict")
            state_dict = state_dict[checkpoint_key]
        # remove `module.` prefix
        state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
        # remove `backbone.` prefix induced by multicrop wrapper
        state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
        # remove `encoder.` prefix induced by MAE
        state_dict = {k.replace("encoder.", ""): v for k, v in state_dict.items()}
        msg = model.load_state_dict(state_dict, strict=False)
        print(
            "Pretrained weights found at {} and loaded with msg: {}".format(
                pretrained_weights, msg
            )
        )
    else:
        print(
            "There are no reference weights available for this model => We use random weights."
        )


@torch.no_grad()
def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    tensors_gather = [
        torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
    ]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

    output = torch.cat(tensors_gather, dim=0)
    return output
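For orientation, a single-process sketch of the logging helpers above; it needs no distributed initialization because `is_dist_avail_and_initialized()` returns False in that case. The import path, the fake evaluation data, and the `acc1` meter name are assumptions for illustration only.

# hypothetical usage sketch, not part of the uploaded sources
import torch
from utils import MetricLogger, accuracy, fix_random_seeds  # path assumed from the repo layout

fix_random_seeds(0)
metric_logger = MetricLogger(delimiter="  ")

# fake evaluation data: 20 batches of logits over 10 classes with random labels
data = [(torch.randn(8, 10), torch.randint(0, 10, (8,))) for _ in range(20)]
for logits, target in metric_logger.log_every(data, 5, header="Eval:"):
    (top1,) = accuracy(logits, target, topk=(1,))
    metric_logger.update(acc1=top1.item())

metric_logger.synchronize_between_processes()  # no-op without torch.distributed
print("Averaged stats:", metric_logger)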