medical
AleksanderObuchowski committed on
Commit 5ceacbc
1 Parent(s): 9ee70d2

Add files using upload-large-folder tool

Files changed (43)
  1. .idea/.gitignore +8 -0
  2. .idea/inspectionProfiles/profiles_settings.xml +6 -0
  3. .idea/misc.xml +7 -0
  4. .idea/modules.xml +8 -0
  5. .idea/vcs.xml +6 -0
  6. .idea/workspace.xml +204 -0
  7. 2024.09.27/config.yaml +211 -0
  8. 2024.09.27/language_model/clip_tokenizer_4.16.2/merges.txt +0 -0
  9. 2024.09.27/language_model/clip_tokenizer_4.16.2/special_tokens_map.json +27 -0
  10. 2024.09.27/language_model/clip_tokenizer_4.16.2/tokenizer_config.json +38 -0
  11. 2024.09.27/language_model/clip_tokenizer_4.16.2/vocab.json +0 -0
  12. MedImageInsight/Distributed/Utils.py +344 -0
  13. MedImageInsight/Distributed/__init__.py +6 -0
  14. MedImageInsight/ImageDataLoader/__init__.py +8 -0
  15. MedImageInsight/ImageDataLoader/blob_storage.py +244 -0
  16. MedImageInsight/ImageDataLoader/build.py +260 -0
  17. MedImageInsight/ImageDataLoader/constants.py +85 -0
  18. MedImageInsight/ImageDataLoader/languages/__init__.py +0 -0
  19. MedImageInsight/ImageDataLoader/languages/prompt_engineering.py +101 -0
  20. MedImageInsight/ImageDataLoader/transforms/__init__.py +1 -0
  21. MedImageInsight/ImageDataLoader/transforms/autoaugment.py +447 -0
  22. MedImageInsight/ImageDataLoader/transforms/build.py +261 -0
  23. MedImageInsight/ImageDataLoader/transforms/threeaugment.py +54 -0
  24. MedImageInsight/ImageDataLoader/tsv.py +351 -0
  25. MedImageInsight/ImageDataLoader/tsv_file.py +290 -0
  26. MedImageInsight/ImageDataLoader/zipdata.py +98 -0
  27. MedImageInsight/ImageEncoder/__init__.py +8 -0
  28. MedImageInsight/ImageEncoder/build.py +13 -0
  29. MedImageInsight/ImageEncoder/coswin.py +779 -0
  30. MedImageInsight/ImageEncoder/davit_v1.py +727 -0
  31. MedImageInsight/ImageEncoder/registry.py +18 -0
  32. MedImageInsight/LangEncoder/__init__.py +13 -0
  33. MedImageInsight/LangEncoder/build.py +108 -0
  34. MedImageInsight/LangEncoder/registry.py +18 -0
  35. MedImageInsight/LangEncoder/transformer.py +210 -0
  36. MedImageInsight/UniCLModel.py +293 -0
  37. MedImageInsight/Utils/Arguments.py +134 -0
  38. MedImageInsight/Utils/GeneraUtils.py +263 -0
  39. MedImageInsight/Utils/GlobalExceptHook.py +61 -0
  40. MedImageInsight/Utils/MPIAdapter.py +147 -0
  41. MedImageInsight/Utils/Utils.py +141 -0
  42. MedImageInsight/Utils/__init__.py +7 -0
  43. MedImageInsight/__init__.py +9 -0
.idea/.gitignore ADDED
@@ -0,0 +1,8 @@
+ # Default ignored files
+ /shelf/
+ /workspace.xml
+ # Editor-based HTTP Client requests
+ /httpRequests/
+ # Datasource local storage ignored files
+ /dataSources/
+ /dataSources.local.xml
.idea/inspectionProfiles/profiles_settings.xml ADDED
@@ -0,0 +1,6 @@
+ <component name="InspectionProjectProfileManager">
+ <settings>
+ <option name="USE_PROJECT_PROFILE" value="false" />
+ <version value="1.0" />
+ </settings>
+ </component>
.idea/misc.xml ADDED
@@ -0,0 +1,7 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+ <component name="Black">
+ <option name="sdkName" value="Python 3.12" />
+ </component>
+ <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.8 virtualenv at ~/medatlas/.venv" project-jdk-type="Python SDK" />
+ </project>
.idea/modules.xml ADDED
@@ -0,0 +1,8 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+ <component name="ProjectModuleManager">
+ <modules>
+ <module fileurl="file://$PROJECT_DIR$/.idea/medatlas.iml" filepath="$PROJECT_DIR$/.idea/medatlas.iml" />
+ </modules>
+ </component>
+ </project>
.idea/vcs.xml ADDED
@@ -0,0 +1,6 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+ <component name="VcsDirectoryMappings">
+ <mapping directory="$PROJECT_DIR$" vcs="Git" />
+ </component>
+ </project>
.idea/workspace.xml ADDED
@@ -0,0 +1,204 @@
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="AutoImportSettings">
4
+ <option name="autoReloadType" value="SELECTIVE" />
5
+ </component>
6
+ <component name="ChangeListManager">
7
+ <list default="true" id="9ec92c76-0e74-4c49-9687-c62749296b88" name="Changes" comment="" />
8
+ <option name="SHOW_DIALOG" value="false" />
9
+ <option name="HIGHLIGHT_CONFLICTS" value="true" />
10
+ <option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
11
+ <option name="LAST_RESOLUTION" value="IGNORE" />
12
+ </component>
13
+ <component name="FileTemplateManagerImpl">
14
+ <option name="RECENT_TEMPLATES">
15
+ <list>
16
+ <option value="Python Script" />
17
+ </list>
18
+ </option>
19
+ </component>
20
+ <component name="FlaskConsoleOptions" custom-start-script="import sys; print('Python %s on %s' % (sys.version, sys.platform)); sys.path.extend([WORKING_DIR_AND_PYTHON_PATHS])&#10;from flask.cli import ScriptInfo, NoAppException&#10;for module in [&quot;main.py&quot;, &quot;wsgi.py&quot;, &quot;app.py&quot;]:&#10; try: locals().update(ScriptInfo(app_import_path=module, create_app=None).load_app().make_shell_context()); print(&quot;\nFlask App: %s&quot; % app.import_name); break&#10; except NoAppException: pass">
21
+ <envs>
22
+ <env key="FLASK_APP" value="app" />
23
+ </envs>
24
+ <option name="myCustomStartScript" value="import sys; print('Python %s on %s' % (sys.version, sys.platform)); sys.path.extend([WORKING_DIR_AND_PYTHON_PATHS])&#10;from flask.cli import ScriptInfo, NoAppException&#10;for module in [&quot;main.py&quot;, &quot;wsgi.py&quot;, &quot;app.py&quot;]:&#10; try: locals().update(ScriptInfo(app_import_path=module, create_app=None).load_app().make_shell_context()); print(&quot;\nFlask App: %s&quot; % app.import_name); break&#10; except NoAppException: pass" />
25
+ <option name="myEnvs">
26
+ <map>
27
+ <entry key="FLASK_APP" value="app" />
28
+ </map>
29
+ </option>
30
+ </component>
31
+ <component name="Git.Settings">
32
+ <option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$" />
33
+ </component>
34
+ <component name="HighlightingSettingsPerFile">
35
+ <setting file="file://$PROJECT_DIR$/.venv/lib/python3.8/site-packages/safetensors/torch.py" root0="SKIP_INSPECTION" />
36
+ </component>
37
+ <component name="ProjectColorInfo">{
38
+ &quot;associatedIndex&quot;: 3
39
+ }</component>
40
+ <component name="ProjectId" id="2nytZGYw1NHCwZYyKjVZHmbmsFp" />
41
+ <component name="ProjectLevelVcsManager">
42
+ <ConfirmationsSetting value="1" id="Add" />
43
+ </component>
44
+ <component name="ProjectViewState">
45
+ <option name="hideEmptyMiddlePackages" value="true" />
46
+ <option name="showLibraryContents" value="true" />
47
+ </component>
48
+ <component name="PropertiesComponent"><![CDATA[{
49
+ "keyToString": {
50
+ "Python.example.executor": "Debug",
51
+ "Python.explainability.executor": "Run",
52
+ "Python.main.executor": "Run",
53
+ "Python.medimageinsightmodel.executor": "Run",
54
+ "Python.push_to_hub.executor": "Run",
55
+ "RunOnceActivity.ShowReadmeOnStart": "true",
56
+ "RunOnceActivity.git.unshallow": "true",
57
+ "git-widget-placeholder": "master",
58
+ "last_opened_file_path": "/home/olek/medatlas/2024.09.27/vision_model",
59
+ "node.js.detected.package.eslint": "true",
60
+ "node.js.detected.package.tslint": "true",
61
+ "node.js.selected.package.eslint": "(autodetect)",
62
+ "node.js.selected.package.tslint": "(autodetect)",
63
+ "nodejs_package_manager_path": "npm",
64
+ "vue.rearranger.settings.migration": "true"
65
+ }
66
+ }]]></component>
67
+ <component name="RdControllerToolWindowsLayoutState" isNewUi="true">
68
+ <layout>
69
+ <window_info id="Bookmarks" side_tool="true" />
70
+ <window_info id="Merge Requests" />
71
+ <window_info id="Commit_Guest" show_stripe_button="false" />
72
+ <window_info id="Pull Requests" />
73
+ <window_info id="Learn" />
74
+ <window_info active="true" content_ui="combo" id="Project" order="0" visible="true" weight="0.16933593" />
75
+ <window_info id="Commit" order="1" weight="0.25" />
76
+ <window_info id="Structure" order="2" side_tool="true" weight="0.25" />
77
+ <window_info anchor="bottom" id="Database Changes" />
78
+ <window_info anchor="bottom" id="TypeScript" />
79
+ <window_info anchor="bottom" id="TODO" />
80
+ <window_info anchor="bottom" id="File Transfer" />
81
+ <window_info anchor="bottom" id="Version Control" order="0" />
82
+ <window_info anchor="bottom" id="Problems" order="1" />
83
+ <window_info anchor="bottom" id="Problems View" order="2" />
84
+ <window_info active="true" anchor="bottom" id="Terminal" order="3" visible="true" weight="0.3795139" />
85
+ <window_info anchor="bottom" id="Services" order="4" />
86
+ <window_info anchor="bottom" id="Python Packages" order="5" weight="0.1" />
87
+ <window_info anchor="bottom" id="Debug" order="6" weight="0.29618055" />
88
+ <window_info anchor="bottom" id="Python Console" order="7" weight="0.1" />
89
+ <window_info anchor="bottom" id="HfCacheToolWindow" order="8" weight="0.44131944" />
90
+ <window_info anchor="bottom" id="Run" order="9" weight="0.6490499" />
91
+ <window_info anchor="bottom" id="Find" order="10" weight="0.33020833" />
92
+ <window_info anchor="right" id="Endpoints" />
93
+ <window_info anchor="right" id="Coverage" side_tool="true" />
94
+ <window_info anchor="right" id="SciView" />
95
+ <window_info anchor="right" content_ui="combo" id="Notifications" order="0" weight="0.25" />
96
+ <window_info anchor="right" id="AIAssistant" order="1" weight="0.25" />
97
+ <window_info anchor="right" id="Database" order="2" weight="0.25" />
98
+ <window_info anchor="right" id="Gradle" order="3" weight="0.25" />
99
+ <window_info anchor="right" id="Maven" order="4" weight="0.25" />
100
+ <window_info anchor="right" id="CodeGPT" order="5" weight="0.30566406" />
101
+ <window_info anchor="right" id="Plots" order="6" weight="0.1" />
102
+ </layout>
103
+ </component>
104
+ <component name="RecentsManager">
105
+ <key name="CopyFile.RECENT_KEYS">
106
+ <recent name="$PROJECT_DIR$/2024.09.27/vision_model" />
107
+ <recent name="$PROJECT_DIR$" />
108
+ </key>
109
+ <key name="MoveFile.RECENT_KEYS">
110
+ <recent name="$PROJECT_DIR$" />
111
+ <recent name="$PROJECT_DIR$/MedImageInsight/ImageEncoder" />
112
+ <recent name="$PROJECT_DIR$/MedImageInsights" />
113
+ </key>
114
+ </component>
115
+ <component name="RunManager">
116
+ <configuration name="main" type="PythonConfigurationType" factoryName="Python" nameIsGenerated="true">
117
+ <module name="medatlas" />
118
+ <option name="ENV_FILES" value="" />
119
+ <option name="INTERPRETER_OPTIONS" value="" />
120
+ <option name="PARENT_ENVS" value="true" />
121
+ <envs>
122
+ <env name="PYTHONUNBUFFERED" value="1" />
123
+ </envs>
124
+ <option name="SDK_HOME" value="" />
125
+ <option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" />
126
+ <option name="IS_MODULE_SDK" value="true" />
127
+ <option name="ADD_CONTENT_ROOTS" value="true" />
128
+ <option name="ADD_SOURCE_ROOTS" value="true" />
129
+ <EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
130
+ <option name="SCRIPT_NAME" value="$PROJECT_DIR$/main.py" />
131
+ <option name="PARAMETERS" value="" />
132
+ <option name="SHOW_COMMAND_LINE" value="false" />
133
+ <option name="EMULATE_TERMINAL" value="false" />
134
+ <option name="MODULE_MODE" value="false" />
135
+ <option name="REDIRECT_INPUT" value="false" />
136
+ <option name="INPUT_FILE" value="" />
137
+ <method v="2" />
138
+ </configuration>
139
+ <configuration name="medimageinsightmodel" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
140
+ <module name="medatlas" />
141
+ <option name="ENV_FILES" value="" />
142
+ <option name="INTERPRETER_OPTIONS" value="" />
143
+ <option name="PARENT_ENVS" value="true" />
144
+ <envs>
145
+ <env name="PYTHONUNBUFFERED" value="1" />
146
+ </envs>
147
+ <option name="SDK_HOME" value="" />
148
+ <option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" />
149
+ <option name="IS_MODULE_SDK" value="true" />
150
+ <option name="ADD_CONTENT_ROOTS" value="true" />
151
+ <option name="ADD_SOURCE_ROOTS" value="true" />
152
+ <EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
153
+ <option name="SCRIPT_NAME" value="$PROJECT_DIR$/medimageinsightmodel.py" />
154
+ <option name="PARAMETERS" value="" />
155
+ <option name="SHOW_COMMAND_LINE" value="false" />
156
+ <option name="EMULATE_TERMINAL" value="false" />
157
+ <option name="MODULE_MODE" value="false" />
158
+ <option name="REDIRECT_INPUT" value="false" />
159
+ <option name="INPUT_FILE" value="" />
160
+ <method v="2" />
161
+ </configuration>
162
+ <recent_temporary>
163
+ <list>
164
+ <item itemvalue="Python.medimageinsightmodel" />
165
+ </list>
166
+ </recent_temporary>
167
+ </component>
168
+ <component name="SharedIndexes">
169
+ <attachedChunks>
170
+ <set>
171
+ <option value="bundled-js-predefined-d6986cc7102b-bed05e336f61-JavaScript-PY-243.21155.22" />
172
+ <option value="bundled-python-sdk-5ff8a29a62a8-ca77fbc60dd9-com.jetbrains.pycharm.pro.sharedIndexes.bundled-PY-243.21155.22" />
173
+ </set>
174
+ </attachedChunks>
175
+ </component>
176
+ <component name="SpellCheckerSettings" RuntimeDictionaries="0" Folders="0" CustomDictionaries="0" DefaultDictionary="application-level" UseSingleDictionary="true" transferred="true" />
177
+ <component name="TaskManager">
178
+ <task active="true" id="Default" summary="Default task">
179
+ <changelist id="9ec92c76-0e74-4c49-9687-c62749296b88" name="Changes" comment="" />
180
+ <created>1729957197525</created>
181
+ <option name="number" value="Default" />
182
+ <option name="presentableId" value="Default" />
183
+ <updated>1729957197525</updated>
184
+ <workItem from="1729957199944" duration="8141000" />
185
+ <workItem from="1729970018757" duration="142000" />
186
+ <workItem from="1729970174785" duration="25000" />
187
+ <workItem from="1729970270429" duration="53000" />
188
+ <workItem from="1729970419018" duration="9867000" />
189
+ <workItem from="1730030408588" duration="2251000" />
190
+ <workItem from="1730037237796" duration="27583000" />
191
+ </task>
192
+ <servers />
193
+ </component>
194
+ <component name="TypeScriptGeneratedFilesManager">
195
+ <option name="version" value="3" />
196
+ </component>
197
+ <component name="com.intellij.coverage.CoverageDataManagerImpl">
198
+ <SUITE FILE_PATH="coverage/medatlas$explainability.coverage" NAME="explainability Coverage Results" MODIFIED="1730155021389" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="false" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
199
+ <SUITE FILE_PATH="coverage/medatlas$push_to_hub.coverage" NAME="push_to_hub Coverage Results" MODIFIED="1730031227719" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="false" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
200
+ <SUITE FILE_PATH="coverage/medatlas$example.coverage" NAME="example Coverage Results" MODIFIED="1730041646094" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="false" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
201
+ <SUITE FILE_PATH="coverage/medatlas$main.coverage" NAME="main Coverage Results" MODIFIED="1730153590829" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="false" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
202
+ <SUITE FILE_PATH="coverage/medatlas$medimageinsightmodel.coverage" NAME="medimageinsightmodel Coverage Results" MODIFIED="1730037368621" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="false" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
203
+ </component>
204
+ </project>
2024.09.27/config.yaml ADDED
@@ -0,0 +1,211 @@
+ ##################
+ # Trainer settings
+ ##################
+
+
+ TASK: UniCLTask
+
+ NAME: 'Example Eval Configuration'
+ SAVE_TIMER_LOG: true
+
+ # TUTORIAL STEP 1: CHOOSE SAVE DIR
+ SAVE_DIR: ''
+ LOG_EVERY: 10
+ LOGLEVEL_OVERRIDE: INFO
+ LOG_GPU_MEM: true
+ RESUME: False
+ RESET_DATA_LOADER: false
+
+ FP16: true
+ ZERO_STAGE: 0
+ DEEPSPEED: false
+ # ZERO_STAGE: 1
+ AMP: PYTORCH
+ # USE_APEX_DDP: false
+ # USE_APEX_AMP: false
+ # USE_HIT: false
+
+ FIND_UNUSED_PARAMETERS: false
+
+ SAVE_PER_OPTIM_STEPS: 500
+ EVAL_PER_OPTIM_STEPS: 250
+ EVAL_AT_START: False
+ # SAVE_PER_UPDATE_NUM: -1
+ # EVAL_PER_UPDATE_NUM: 0 # 0: do evaluation when saving checkpoint, -1: don't do evaluation
+
+ NO_AUTO_LR_SCALING: true
+ GRAD_CLIPPING: 1.0 #0.07
+
+ SET_SAMPLER_EPOCH: true
+
+ DONT_LOAD_MODEL: true
+
+ user_dir: "./MainzVision" # lower case because it is used as such in mainz
+
+ ##################
+ # Task settings
+ ##################
+
+
+
+ VERBOSE: true
+ WORKERS: 6
+ PIN_MEMORY: true
+ IMAGE_ENCODER:
+ NAME: davit_v1
+ NUM_CLASSES: 0
+ #IMAGE_SIZE: [384, 384]
+ IMAGE_SIZE: [480, 480]
+ LOAD_PRETRAINED: true
+ PRETRAINED: ''
+ PRETRAINED_LAYERS: '*'
+ IMAGE_MEAN: [0.485, 0.456, 0.406]
+ IMAGE_STD: [0.229, 0.224, 0.225]
+ SPEC:
+ DROP_RATE: 0.1
+ DROP_PATH_RATE: 0.2
+ PATCH_SIZE: [7, 3, 3, 3]
+ PATCH_STRIDE: [4, 2, 2, 2]
+ PATCH_PADDING: [3, 1, 1, 1]
+ PATCH_PRENORM: [false, true, true, true]
+ DIM_EMBED: [256, 512, 1024, 2048]
+ NUM_HEADS: [8, 16, 32, 64]
+ NUM_GROUPS: [8, 16, 32, 64]
+ DEPTHS: [1, 1, 9, 1]
+ WINDOW_SIZE: 12
+ ENABLE_CHECKPOINT: true
+
+ LANG_ENCODER:
+ NAME: transformer
+ LOAD_PRETRAINED: false
+ PRETRAINED: ''
+ PRETRAINED_LAYERS: '*'
+ TOKENIZER: clip
+ CONTEXT_LENGTH: 77
+ WIDTH: 1024
+ HEADS: 16
+ LAYERS: 16
+ AUTOGRESSIVE: false
+
+ UNICL_MODEL:
+ DIM_PROJECTION: 1024
+ GATHER_TENSORS: true
+ LOAD_PRETRAINED: true
+
+ # TUTORIAL STEP 2: CHOOSE MODEL PATH
+ PRETRAINED: ''
+
+ PRETRAINED_LAYERS: '*'
+
+ AUG:
+ MIXUP_PROB: 0.0
+ MIXUP: 0.8
+ MIXCUT: 1.0
+ MIXCUT_MINMAX: []
+ MIXUP_SWITCH_PROB: 0.5
+ MIXUP_MODE: 'batch'
+ SCALE: [0.8, 1.0]
+ RATIO: [0.75, 1.3333333]
+ INTERPOLATION: 'bicubic'
+ TORCHVISION_AUG:
+ AUTO_AUGMENT: ta_wide
+ RE_PROB: 0.25
+ HFLIP: 0.0
+ VFLIP: 0.0
+
+ LOSS:
+ LOSS: UniCL
+ DATASET:
+ DATASET: 'image_text_pairs_v2'
+ TEXT_FORMAT: 'json'
+ ROOT: ''
+ TRAIN_SET: 'mimic_cxr_v2-chestxray14-chexpertv4-irma2009_v2-rsnaboneage-mura-bingmedicalfewshot'
+ DATA_FORMAT: 'tsv'
+ SAMPLER: 'default'
+ LOADER: 'default'
+ TOKEN_FILE: ''
+ #PROMPT_ENGINEERING: False
+ #SAMPLER: 'chunk'
+ #LOADER: 'azcopy'
+ #TOKEN_FILE: 'cliptrainingpairs.txt'
+ #TEST_SET: 'MarsAtrain'
+
+
+ # TUTORIAL STEP 3: CHOOSE ALL BELOW EVAL PATHS (THESE ARE ALL OPTIONAL EXTRA EVALS)
+ # Note how one eval is ZIP format and the other is TSV format.
+
+
+
+
+ EVALDATASET_LTCXR_S100_N100_TEXT_CLASSIFIER:
+ TEXT_FORMAT: json
+ FORMAT: 'zip'
+ SPLIT: 'NIH-CXR-LT'
+ ZIP_FILE: ''
+ ZIP_MAP_FILE: ''
+ LABEL_FILE: ''
+ IMAGE_TSV: ''
+ TEXT_TSV: ''
+ CWEIGHT_FILE: ''
+ ZS_MODE: 2
+ ZS_WEIGHT: 1.0
+ KNN: 100
+ # CLASSIFICATION_SETS: ['NIH-CXR-LT']
+ # NUM_CLASSES: [20]
+
+
+
+
+ # TUTORIAL STEP 4: SET THE DEFAULT ZEROSHOT EVAL (THIS IS THE MANDATORY EVAL)
+
+ ZEROSHOT_EVAL_DATASET:
+ FORMAT: 'zip'
+ SPLIT: 'NIH-CXR-LT'
+ ZIP_FILE: ''
+ ZIP_MAP_FILE: ''
+ LABEL_FILE: ''
+
+
+
+ EVALUATION_SPLITS: ['cls-zeroshot-eval']
+ TEST:
+ BATCH_SIZE_PER_GPU: 8
+ MODEL_FILE: ''
+ CENTER_CROP: false
+ TRAIN:
+ BATCH_SIZE_TOTAL: 1024
+ BATCH_SIZE_PER_GPU: 16
+
+ SHUFFLE: true
+
+ WEIGHT_SMOOTHING:
+ decay: 0.999
+ use_cpu: False
+ eval_smoothed_weight: True
+
+ START_LEARNING_RATE: 0.00001
+ # MAX_NUM_EPOCHS: 2
+ MAX_NUM_EPOCHS: 100
+ OPTIMIZER: AdamW # adam
+ OPTIMIZER_PARAMS:
+ weight_decay: 0.2 #0.1
+ CUSTOMIZED_PARAMS_CONF:
+ NO_WEIGHT_DECAY_MODULES: ['dw', 'norm']
+ WEIGHT_DECAY_PATTERNS:
+ "\\.bias$": 0.0
+ "logit_scale": 0.0
+ "positional_embedding": 0.0
+ "token_embedding": 0.0
+
+
+
+ LR_SCHEDULER: TimmScheduler
+ LR_SCHEDULER_PARAMS:
+ sched: cosine
+ warmup_steps: 5
+ warmup_lr: 0.000000001
+ min_lr: 0.000000001
+
+ # GRADIENT_ACCUMULATE_STEP will be updated by:
+ # BATCH_SIZE_TOTAL // (BATCH_SIZE_PER_GPU * world_size)
+ GRADIENT_ACCUMULATE_STEP: -1
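The closing comment describes how the trainer derives the gradient-accumulation count when GRADIENT_ACCUMULATE_STEP is left at -1. A minimal sketch of that relationship (the helper name below is hypothetical, not part of this repo), using the TRAIN values from this config on a single 8-GPU node:

def resolve_grad_accumulate_step(batch_size_total, batch_size_per_gpu, world_size):
    # GRADIENT_ACCUMULATE_STEP = BATCH_SIZE_TOTAL // (BATCH_SIZE_PER_GPU * world_size)
    return max(batch_size_total // (batch_size_per_gpu * world_size), 1)

# With BATCH_SIZE_TOTAL=1024 and BATCH_SIZE_PER_GPU=16 on 8 GPUs:
print(resolve_grad_accumulate_step(1024, 16, 8))  # 1024 // (16 * 8) = 8 accumulation steps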
2024.09.27/language_model/clip_tokenizer_4.16.2/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
2024.09.27/language_model/clip_tokenizer_4.16.2/special_tokens_map.json ADDED
@@ -0,0 +1,27 @@
+ {
+ "bos_token": {
+ "content": "<|startoftext|>",
+ "single_word": false,
+ "lstrip": false,
+ "rstrip": false,
+ "normalized": true,
+ "special": false
+ },
+ "eos_token": {
+ "content": "<|endoftext|>",
+ "single_word": false,
+ "lstrip": false,
+ "rstrip": false,
+ "normalized": true,
+ "special": false
+ },
+ "unk_token": {
+ "content": "<|endoftext|>",
+ "single_word": false,
+ "lstrip": false,
+ "rstrip": false,
+ "normalized": true,
+ "special": false
+ },
+ "pad_token": "<|endoftext|>"
+ }
2024.09.27/language_model/clip_tokenizer_4.16.2/tokenizer_config.json ADDED
@@ -0,0 +1,38 @@
+ {
+ "errors": "replace",
+ "unk_token": {
+ "content": "<|endoftext|>",
+ "single_word": false,
+ "lstrip": false,
+ "rstrip": false,
+ "normalized": true,
+ "special": false,
+ "__type": "AddedToken"
+ },
+ "bos_token": {
+ "content": "<|startoftext|>",
+ "single_word": false,
+ "lstrip": false,
+ "rstrip": false,
+ "normalized": true,
+ "special": false,
+ "__type": "AddedToken"
+ },
+ "eos_token": {
+ "content": "<|endoftext|>",
+ "single_word": false,
+ "lstrip": false,
+ "rstrip": false,
+ "normalized": true,
+ "special": false,
+ "__type": "AddedToken"
+ },
+ "pad_token": "<|endoftext|>",
+ "add_prefix_space": false,
+ "do_lower_case": true,
+ "name_or_path": "openai/clip-vit-base-patch32",
+ "model_max_length": 77,
+ "special_tokens_map_file": "/home/ncodella/.cache/huggingface/transformers/18a566598f286c9139f88160c99f84eec492a26bd22738fa9cb44d5b7e0a5c76.cce1206abbad28826f000510f22f354e53e66a97f7c23745a7dfe27609cc07f5",
+ "tokenizer_file": "/home/ncodella/.cache/huggingface/transformers/7811def0c53be25ba790cb67ac785669b508a8d1cf8c912b8ac046c5f08aee68.20428ea8b6821af2719b760af844a371643ff49f255c73285f6ea448e15597fe",
+ "tokenizer_class": "CLIPTokenizer"
+ }
2024.09.27/language_model/clip_tokenizer_4.16.2/vocab.json ADDED
The diff for this file is too large to render. See raw diff
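The four files under clip_tokenizer_4.16.2 (merges.txt, special_tokens_map.json, tokenizer_config.json, vocab.json) are a standard Hugging Face CLIP tokenizer snapshot exported with transformers 4.16.2, matching TOKENIZER: clip and CONTEXT_LENGTH: 77 in config.yaml. A minimal loading sketch, assuming transformers and PyTorch are installed and the relative path below points at this folder; the prompt text is an illustrative example, not from the repo:

from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("2024.09.27/language_model/clip_tokenizer_4.16.2")
batch = tokenizer(
    ["chest x-ray, no acute cardiopulmonary process"],  # example prompt
    padding="max_length", max_length=77, truncation=True, return_tensors="pt",
)
print(batch["input_ids"].shape)  # torch.Size([1, 77]), i.e. CONTEXT_LENGTH tokens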
 
MedImageInsight/Distributed/Utils.py ADDED
@@ -0,0 +1,344 @@
1
+ import logging
2
+ import os
3
+ import pickle
4
+ import requests
5
+ import tenacity
6
+ import time
7
+ import shutil
8
+
9
+ import torch
10
+ import torch.distributed as dist
11
+
12
+ from PIL import Image
13
+ from torchvision.utils import make_grid
14
+
15
+
16
+ from fvcore.nn import FlopCountAnalysis
17
+ from fvcore.nn import flop_count_table
18
+ from fvcore.nn import flop_count_str
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+ NORM_MODULES = [
23
+ torch.nn.BatchNorm1d,
24
+ torch.nn.BatchNorm2d,
25
+ torch.nn.BatchNorm3d,
26
+ torch.nn.SyncBatchNorm,
27
+ # NaiveSyncBatchNorm inherits from BatchNorm2d
28
+ torch.nn.GroupNorm,
29
+ torch.nn.InstanceNorm1d,
30
+ torch.nn.InstanceNorm2d,
31
+ torch.nn.InstanceNorm3d,
32
+ torch.nn.LayerNorm,
33
+ torch.nn.LocalResponseNorm,
34
+ ]
35
+
36
+
37
+ def register_norm_module(cls):
38
+ NORM_MODULES.append(cls)
39
+
40
+ return cls
41
+
42
+
43
+ def is_main_process():
44
+ rank = 0
45
+ if 'OMPI_COMM_WORLD_SIZE' in os.environ:
46
+ rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
47
+
48
+ return rank == 0
49
+
50
+
51
+ @torch.no_grad()
52
+ def analysis_model(model, dump_input, verbose=False):
53
+ model.eval()
54
+ flops = FlopCountAnalysis(model, dump_input)
55
+ total = flops.total()
56
+ model.train()
57
+ params_total = sum(p.numel() for p in model.parameters())
58
+ params_learned = sum(
59
+ p.numel() for p in model.parameters() if p.requires_grad
60
+ )
61
+ logger.info(f"flop count table:\n {flop_count_table(flops)}")
62
+ if verbose:
63
+ logger.info(f"flop count str:\n {flop_count_str(flops)}")
64
+ logger.info(f" Total flops: {total / 1000 / 1000:.3f}M,")
65
+ logger.info(f" Total params: {params_total / 1000 / 1000:.3f}M,")
66
+ logger.info(f" Learned params: {params_learned / 1000 / 1000:.3f}M")
67
+
68
+ return total, flop_count_table(flops), flop_count_str(flops)
69
+
70
+
71
+ def gather_tensors(tensor):
72
+ """
73
+ Performs all_gather operation on the provided tensors.
74
+ *** Warning ***: torch.distributed.all_gather has no gradient.
75
+ """
76
+ tensors_gather = [
77
+ torch.ones_like(tensor)
78
+ for _ in range(int(os.environ['WORLD_SIZE']))
79
+ ]
80
+
81
+ dist.all_gather(tensors_gather, tensor, async_op=False)
82
+ # need to do this to restore propagation of the gradients
83
+ tensors_gather[int(os.environ['RANK'])] = tensor
84
+ output = torch.cat(tensors_gather, dim=0)
85
+ return output
86
+
87
+
88
+ def is_valid_url(url):
89
+ try:
90
+ from urllib import parse
91
+ return parse.urlparse(str(url)).scheme != ''
92
+ except Exception:
93
+ return False
94
+
95
+
96
+ @tenacity.retry(stop=tenacity.stop_after_attempt(3))
97
+ def download_file(url, filepath):
98
+ logger.info(f'Downloading from {url} to {filepath.absolute()}.')
99
+ with requests.get(url, stream=True, allow_redirects=True, timeout=60) as r:
100
+ if r.status_code > 200:
101
+ raise RuntimeError(f'Failed in downloading from {url}, status code {r.status_code}.')
102
+
103
+ with open(filepath, 'wb') as f:
104
+ shutil.copyfileobj(r.raw, f, length=4194304)
105
+
106
+
107
+ class DistributionGridFactory:
108
+ """
109
+ DistributionGrid Factory for helping create, cache and share the DistributionGrid based on the usage.
110
+ The DistributionGrid can be shared across modules only when these two settings:
111
+ 1. expert parallel group size
112
+ 2. expert parallel replica group size,
113
+ are the same.
114
+ """
115
+ distribution_grid_cache = {}
116
+
117
+ @classmethod
118
+ def get_distribution_grid(cls,
119
+ expert_parallel_group_size,
120
+ expert_parallel_replica_group_size,
121
+ ddp_type):
122
+ """
123
+ Get the DistributionGrid by the conditions.
124
+ Args:
125
+ expert_parallel_group_size: expert parallel group size
126
+ expert_parallel_replica_group_size: expert parallel replica group size
127
+ ddp_type: distributed data parallel type, i.e. the "DDP" value of the recipe; only "MAINZ", "OSS" or "ShardedDDP" is allowed.
128
+
129
+ Returns: new created DistributionGrid or shared DistributionGrid.
130
+
131
+ Notes: Currently get_distribution_grid only supports "DDP" values of "MAINZ", "OSS" or "ShardedDDP".
132
+ """
133
+ # TODO: Support cases that "DDP" is "FSDP".
134
+ # For "FSDP", we use the DG of self.opt['fsdp_expert_grid'] which is initialize in DistributedTrainer directly.
135
+ ddp_type = ddp_type.upper()
136
+ assert ddp_type in ["MAINZ", "OSS", "SHARDEDDDP"], f'DistributionGrid Factory only support "DDP" is "MAINZ",' \
137
+ f' "OSS" or "ShardedDDP".' \
138
+ f' But currently "DDP" is {ddp_type}'
139
+
140
+ cached_distributed_grid = cls.distribution_grid_cache.get(
141
+ (expert_parallel_group_size, expert_parallel_replica_group_size), None)
142
+
143
+ if cached_distributed_grid is not None:
144
+ return cached_distributed_grid
145
+ else:
146
+ from ort_moe.grids import DistributionGrid
147
+ distributed_grid = DistributionGrid(expert_parallel_group_size=expert_parallel_group_size,
148
+ expert_parallel_replica_group_size=expert_parallel_replica_group_size)
149
+
150
+ cls.distribution_grid_cache[expert_parallel_group_size,
151
+ expert_parallel_replica_group_size] = distributed_grid
152
+ return distributed_grid
153
+
154
+
155
+ def get_world_size():
156
+ if not dist.is_available():
157
+ return 1
158
+ if not dist.is_initialized():
159
+ return 1
160
+ return dist.get_world_size()
161
+
162
+
163
+ def get_rank():
164
+ if not dist.is_available():
165
+ return 0
166
+ if not dist.is_initialized():
167
+ return 0
168
+ return dist.get_rank()
169
+
170
+
171
+ def synchronize():
172
+ """
173
+ Helper function to synchronize (barrier) among all processes when
174
+ using distributed training
175
+ """
176
+ if not dist.is_available():
177
+ return
178
+ if not dist.is_initialized():
179
+ return
180
+ world_size = dist.get_world_size()
181
+ rank = dist.get_rank()
182
+ if world_size == 1:
183
+ return
184
+
185
+ def _send_and_wait(r):
186
+ if rank == r:
187
+ tensor = torch.tensor(0, device="cuda")
188
+ else:
189
+ tensor = torch.tensor(1, device="cuda")
190
+ dist.broadcast(tensor, r)
191
+ while tensor.item() == 1:
192
+ time.sleep(1)
193
+
194
+ _send_and_wait(0)
195
+ # now sync on the main process
196
+ _send_and_wait(1)
197
+
198
+
199
+ def all_gather(data):
200
+ """
201
+ Run all_gather on arbitrary picklable data (not necessarily tensors)
202
+ Args:
203
+ data: any picklable object
204
+ Returns:
205
+ list[data]: list of data gathered from each rank
206
+ """
207
+ world_size = get_world_size()
208
+ if world_size == 1:
209
+ return [data]
210
+
211
+ # serialized to a Tensor
212
+ buffer = pickle.dumps(data)
213
+ storage = torch.ByteStorage.from_buffer(buffer)
214
+ tensor = torch.ByteTensor(storage).to("cuda")
215
+
216
+ # obtain Tensor size of each rank
217
+ local_size = torch.LongTensor([tensor.numel()]).to("cuda")
218
+ size_list = [torch.LongTensor([0]).to("cuda") for _ in range(world_size)]
219
+ dist.all_gather(size_list, local_size)
220
+ size_list = [int(size.item()) for size in size_list]
221
+ max_size = max(size_list)
222
+
223
+ # receiving Tensor from all ranks
224
+ # we pad the tensor because torch all_gather does not support
225
+ # gathering tensors of different shapes
226
+ tensor_list = []
227
+ for _ in size_list:
228
+ tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda"))
229
+ if local_size != max_size:
230
+ padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda")
231
+ tensor = torch.cat((tensor, padding), dim=0)
232
+ dist.all_gather(tensor_list, tensor)
233
+
234
+ data_list = []
235
+ for size, tensor in zip(size_list, tensor_list):
236
+ buffer = tensor.cpu().numpy().tobytes()[:size]
237
+ data_list.append(pickle.loads(buffer))
238
+
239
+ return data_list
240
+
241
+
242
+ def all_gather_cpu(data):
243
+ """
244
+ Run all_gather on arbitrary picklable data (not necessarily tensors).
245
+ Args:
246
+ data: any picklable object
247
+ group: a torch process group. By default, will use a group which
248
+ contains all ranks on gloo backend.
249
+ Returns:
250
+ list[data]: list of data gathered from each rank
251
+ """
252
+
253
+ def _get_global_gloo_group():
254
+ """
255
+ Return a process group based on gloo backend, containing all the ranks
256
+ The result is cached.
257
+ """
258
+ if dist.get_backend() == "nccl":
259
+ return dist.new_group(backend="gloo")
260
+ else:
261
+ return dist.group.WORLD
262
+
263
+ if get_world_size() == 1:
264
+ return [data]
265
+ group = _get_global_gloo_group() # use CPU group by default, to reduce GPU RAM usage.
266
+ world_size = dist.get_world_size(group)
267
+ if world_size == 1:
268
+ return [data]
269
+
270
+ output = [None for _ in range(world_size)]
271
+ dist.all_gather_object(output, data, group=group)
272
+ return output
273
+
274
+
275
+ def reduce_dict(input_dict, average=True):
276
+ """
277
+ Args:
278
+ input_dict (dict): all the values will be reduced
279
+ average (bool): whether to do average or sum
280
+ Reduce the values in the dictionary from all processes so that process with rank
281
+ 0 has the averaged results. Returns a dict with the same fields as
282
+ input_dict, after reduction.
283
+ """
284
+ world_size = get_world_size()
285
+ if world_size < 2:
286
+ return input_dict
287
+ with torch.no_grad():
288
+ names = []
289
+ values = []
290
+ # sort the keys so that they are consistent across processes
291
+ for k in sorted(input_dict.keys()):
292
+ names.append(k)
293
+ values.append(input_dict[k])
294
+ values = torch.stack(values, dim=0)
295
+ dist.reduce(values, dst=0)
296
+ if dist.get_rank() == 0 and average:
297
+ # only main process gets accumulated, so only divide by
298
+ # world_size in this case
299
+ values /= world_size
300
+ reduced_dict = {k: v for k, v in zip(names, values)}
301
+ return reduced_dict
302
+
303
+
304
+ def broadcast_data(data):
305
+ if not torch.distributed.is_initialized():
306
+ return data
307
+ rank = dist.get_rank()
308
+ if rank == 0:
309
+ data_tensor = torch.tensor(data + [0], device="cuda")
310
+ else:
311
+ data_tensor = torch.tensor(data + [1], device="cuda")
312
+ torch.distributed.broadcast(data_tensor, 0)
313
+ while data_tensor.cpu().numpy()[-1] == 1:
314
+ time.sleep(1)
315
+
316
+ return data_tensor.cpu().numpy().tolist()[:-1]
317
+
318
+
319
+ def reduce_sum(tensor):
320
+ if get_world_size() <= 1:
321
+ return tensor
322
+
323
+ tensor = tensor.clone()
324
+ dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
325
+ return tensor
326
+
327
+
328
+ def save_result(result, filename):
329
+ output_folder = os.path.dirname(filename)
330
+ basename = os.path.splitext(os.path.basename(filename))[0]
331
+ os.makedirs(output_folder, exist_ok=True)
332
+
333
+ if isinstance(result, torch.Tensor) and result.ndim in [3,4]:
334
+ if result.ndim==3 and result.size(0) not in [1,3]:
335
+ result = make_grid(result.unsqueeze(1))
336
+ elif result.ndim==4:
337
+ result = make_grid(result)
338
+ else:
339
+ result = make_grid([result])
340
+
341
+ im = Image.fromarray(result.clamp_(0, 255).permute(1, 2, 0).to(torch.uint8).numpy())
342
+ im.save(os.path.join(output_folder, '{}.png'.format(basename)))
343
+ else:
344
+ torch.save(result, os.path.join(output_folder, '{}.pth'.format(basename)))
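gather_tensors above works around the fact that torch.distributed.all_gather does not propagate gradients by re-inserting the rank's own tensor at its slot, which the GATHER_TENSORS: true flag in config.yaml appears to rely on so that each rank can contrast its features against negatives from all ranks. A hedged usage sketch; the function below is illustrative and not the repo's actual loss code:

from MedImageInsight.Distributed import gather_tensors

def global_contrastive_logits(image_feat, text_feat, logit_scale):
    # image_feat, text_feat: (local_batch, dim), already L2-normalized;
    # requires torch.distributed to be initialized and RANK/WORLD_SIZE to be set.
    all_image = gather_tensors(image_feat)  # (world_batch, dim), gradient kept for local rows
    all_text = gather_tensors(text_feat)
    logits_per_image = logit_scale * image_feat @ all_text.t()
    logits_per_text = logit_scale * text_feat @ all_image.t()
    return logits_per_image, logits_per_text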
MedImageInsight/Distributed/__init__.py ADDED
@@ -0,0 +1,6 @@
+ from .Utils import analysis_model
+ from .Utils import is_main_process
+ from .Utils import gather_tensors
+ from .Utils import register_norm_module
+ from .Utils import NORM_MODULES
+ from .Utils import DistributionGridFactory
MedImageInsight/ImageDataLoader/__init__.py ADDED
@@ -0,0 +1,8 @@
+ from .build import build_dataloader
+ #from .build import build_multitask_dataloader
+ from .transforms import build_transforms
+ #from .imagenet.real_labels import RealLabelsImagenet
+ from .constants import IMAGENET_CLASSES
+ from .constants import IMAGENET_DEFAULT_TEMPLATES
+ from .zipdata import ZipData
+ #from .vision_dataset import VDImageTextDataset, MultiClassTorchDatasetWrapper
MedImageInsight/ImageDataLoader/blob_storage.py ADDED
@@ -0,0 +1,244 @@
1
+ import os
2
+ import time
3
+ import shutil
4
+ import logging
5
+ import subprocess
6
+ import os.path as op
7
+ from typing import List
8
+ from collections import OrderedDict
9
+
10
+ import torch.distributed as distributed
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+ DEFAULT_AZCOPY_PATH = 'azcopy/azcopy'
15
+
16
+
17
+ def disk_usage(path: str) -> float:
18
+ stat = shutil.disk_usage(path)
19
+ return stat.used / stat.total
20
+
21
+
22
+ def is_download_successful(stdout: str) -> bool:
23
+ for line in stdout.split('\n'):
24
+ if line == "Number of Transfers Failed: 0":
25
+ return True
26
+ logger.info("Azcopy message:\n %s" % stdout)
27
+ return False
28
+
29
+
30
+ def ensure_directory(path):
31
+ """Check existence of the given directory path. If not, create a new directory.
32
+
33
+ Args:
34
+ path (str): path of a given directory.
35
+ """
36
+ if path == '' or path == '.':
37
+ return
38
+ if path is not None and len(path) > 0:
39
+ assert not op.isfile(path), '{} is a file'.format(path)
40
+ if not op.exists(path) and not op.islink(path):
41
+ os.makedirs(path, exist_ok=True)
42
+ # we should always check if it succeeds.
43
+ assert op.isdir(op.abspath(path)), path
44
+
45
+
46
+ class LRU(OrderedDict):
47
+ def __init__(self, maxsize=3):
48
+ self.maxsize = maxsize
49
+
50
+ def __getitem__(self, key):
51
+ value = super().__getitem__(key)
52
+ self.move_to_end(key)
53
+ return value
54
+
55
+ def __setitem__(self, key, value):
56
+ if key in self:
57
+ if self[key] is not None:
58
+ self[key].close()
59
+ self.move_to_end(key)
60
+
61
+ logger.debug('=> Cache {}'.format(key))
62
+ super().__setitem__(key, value)
63
+
64
+ if len(self) > self.maxsize:
65
+ oldest = next(iter(self))
66
+ if self[oldest] is not None:
67
+ self[oldest].close()
68
+ logger.debug('=> Purged {}'.format(oldest))
69
+ del self[oldest]
70
+
71
+
72
+ class BlobStorage(OrderedDict):
73
+ """ Pseudo Blob Storage manager
74
+
75
+ The registered blobs are maintained in a LRU cache.
76
+ Limit size, evicting the least recently looked-up key when full.
77
+ https://docs.python.org/3/library/collections.html#collections.OrderedDict
78
+
79
+ Input argument:
80
+ sas_token (str): path to SAS token.
81
+ """
82
+ def __init__(self,
83
+ is_train: bool,
84
+ sas_token_path: str = None,
85
+ azcopy_path: str = None,
86
+ *args, **kwds):
87
+ super().__init__(*args, **kwds)
88
+ self.maxsize = 2 if is_train else 10 # Set maxsize to a large number so that val data never gets purged.
89
+ self.is_train = is_train
90
+
91
+ if sas_token_path:
92
+ self.sas_token = BlobStorage.read_sas_token(sas_token_path)
93
+ self.base_url = self.sas_token[:self.sas_token.index("?")]
94
+ self.query_string = self.sas_token[self.sas_token.index("?"):]
95
+ self.container = BlobStorage.extract_container(self.sas_token)
96
+ else:
97
+ self.sas_token = None
98
+ self.base_url = None
99
+ self.query_string = None
100
+ self.container = None
101
+
102
+ logger.debug(
103
+ f"=> [BlobStorage] Base url: {self.base_url}"
104
+ f"=> [BlobStorage] Query string: {self.query_string}"
105
+ f"=> [BlobStorage] Container name: {self.container}"
106
+ )
107
+
108
+ self.azcopy_path = azcopy_path if azcopy_path else DEFAULT_AZCOPY_PATH
109
+ self._cached_files = LRU(3)
110
+
111
+ def __getitem__(self, key):
112
+ value = super().__getitem__(key)
113
+ self.move_to_end(key)
114
+ return value
115
+
116
+ def __setitem__(self, key, value):
117
+ if key in self:
118
+ self.move_to_end(key)
119
+ super().__setitem__(key, value)
120
+ # NOTE: purge the least recently used data if the disk usage is high.
121
+ # ITP restarts GPU clusters when disk usage reaches 80%.
122
+ if len(self) > self.maxsize:
123
+ oldest = next(iter(self))
124
+ del self[oldest]
125
+
126
+ @staticmethod
127
+ def read_sas_token(path: str) -> str:
128
+ with open(path, 'r') as f:
129
+ token = f.readline().strip()
130
+ return token
131
+
132
+ @staticmethod
133
+ def extract_container(token: str) -> str:
134
+ """
135
+ Input argument:
136
+ token (str): the full URI of Shared Access Signature (SAS) in the following format.
137
+ https://[storage_account].blob.core.windows.net/[container_name][SAS_token]
138
+ """
139
+ return os.path.basename(token.split('?')[0])
140
+
141
+ def _convert_to_blob_url(self, local_path: str):
142
+ return self.base_url + local_path.split("azcopy")[1] + self.query_string
143
+
144
+ def _convert_to_blob_folder_url(self, local_path: str):
145
+ return self.base_url + local_path.split("azcopy")[1] + "/*" + self.query_string
146
+
147
+ def fetch_blob(self, local_path: str) -> None:
148
+ if op.exists(local_path):
149
+ logger.info('=> Try to open {}'.format(local_path))
150
+ fp = open(local_path, 'r')
151
+ self._cached_files[local_path] = fp
152
+ logger.debug("=> %s downloaded. Skip." % local_path)
153
+ return
154
+ blob_url = self._convert_to_blob_url(local_path)
155
+ rank = '0' if 'RANK' not in os.environ else os.environ['RANK']
156
+ cmd = [self.azcopy_path, "copy", blob_url, local_path + rank]
157
+ curr_usage = disk_usage('/')
158
+ logger.info(
159
+ "=> Downloading %s with azcopy ... (disk usage: %.2f%%)"
160
+ % (local_path, curr_usage * 100)
161
+ )
162
+ proc = subprocess.run(cmd, stdout=subprocess.PIPE)
163
+ while not is_download_successful(proc.stdout.decode()):
164
+ logger.info("=> Azcopy failed to download {}. Retrying ...".format(blob_url))
165
+ proc = subprocess.run(cmd, stdout=subprocess.PIPE)
166
+ if not op.exists(local_path):
167
+ os.rename(local_path + rank, local_path)
168
+ else:
169
+ os.remove(local_path + rank)
170
+ logger.info(
171
+ "=> Downloaded %s with azcopy ... (disk usage: %.2f%% => %.2f%%)" %
172
+ (local_path, curr_usage * 100, disk_usage('/') * 100)
173
+ )
174
+
175
+ def fetch_blob_folder(self, local_path: str, azcopy_args: list=[]) -> None:
176
+ blob_url = self._convert_to_blob_folder_url(local_path)
177
+ cmd = [self.azcopy_path, "copy", blob_url, local_path] + azcopy_args
178
+ curr_usage = disk_usage('/')
179
+ logger.info(
180
+ "=> Downloading %s with azcopy args %s ... (disk usage: %.2f%%)"
181
+ % (local_path, ' '.join(azcopy_args), curr_usage * 100)
182
+ )
183
+ proc = subprocess.run(cmd, stdout=subprocess.PIPE)
184
+ while not is_download_successful(proc.stdout.decode()):
185
+ logger.info("=> Azcopy failed to download {} with args {}. Retrying ...".format(blob_url, ' '.join(azcopy_args)))
186
+ proc = subprocess.run(cmd, stdout=subprocess.PIPE)
187
+ logger.info(
188
+ "=> Downloaded %s with azcopy args %s ... (disk usage: %.2f%% => %.2f%%)" %
189
+ (local_path, ' '.join(azcopy_args), curr_usage * 100, disk_usage('/') * 100)
190
+ )
191
+
192
+ def register_local_tsv_paths(self, local_paths: List[str]) -> List[str]:
193
+ if self.sas_token:
194
+ tsv_paths_new = []
195
+ lineidx_paths = set()
196
+ linelist_paths = set()
197
+ for path in local_paths:
198
+ tsv_path_az = path.replace(self.container, 'azcopy')
199
+ tsv_paths_new.append(tsv_path_az)
200
+ logger.debug("=> Registering {}".format(tsv_path_az))
201
+
202
+ if not self.is_train:
203
+ logger.info('=> Downloading {}...'.format(tsv_path_az))
204
+ self.fetch_blob(tsv_path_az)
205
+ logger.info('=> Downloaded {}'.format(tsv_path_az))
206
+
207
+ lineidx = op.splitext(path)[0] + '.lineidx'
208
+ lineidx_ = lineidx.replace(self.container, 'azcopy')
209
+ if self.is_train:
210
+ if not op.isfile(lineidx_) and op.dirname(lineidx_) not in lineidx_paths:
211
+ lineidx_paths.add(op.dirname(lineidx_))
212
+ else:
213
+ if not op.isfile(lineidx_):
214
+ ensure_directory(op.dirname(lineidx_))
215
+ self.fetch_blob(lineidx_)
216
+
217
+ linelist = op.splitext(path)[0] + '.linelist'
218
+ linelist_ = linelist.replace(self.container, 'azcopy')
219
+ # .linelist does not always exist. Check existence before fetch
220
+ if self.is_train:
221
+ if op.isfile(linelist) and not op.isfile(linelist_) and op.dirname(linelist_) not in linelist_paths:
222
+ linelist_paths.add(op.dirname(linelist_))
223
+ else:
224
+ if op.isfile(linelist) and not op.isfile(linelist_):
225
+ ensure_directory(op.dirname(linelist_))
226
+ self.fetch_blob(linelist_)
227
+
228
+ if self.is_train:
229
+ for path in lineidx_paths:
230
+ self.fetch_blob_folder(path, azcopy_args=['--include-pattern', '*.lineidx'])
231
+
232
+ for path in linelist_paths:
233
+ self.fetch_blob_folder(path, azcopy_args=['--include-pattern', '*.linelist'])
234
+
235
+ return tsv_paths_new
236
+ else:
237
+ return local_paths
238
+
239
+ def open(self, local_path: str):
240
+ if self.sas_token and 'azcopy' in local_path:
241
+ while not op.exists(local_path):
242
+ time.sleep(1)
243
+ fid = open(local_path, 'r')
244
+ return fid
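The BlobStorage docstring above documents the SAS token layout (https://[storage_account].blob.core.windows.net/[container_name][SAS_token]); the constructor and extract_container just split that URI into a base URL, a query string, and a container name. A small sketch with a made-up token of that shape, equivalent to what the class computes:

sas = "https://mystorage.blob.core.windows.net/mycontainer?sv=2021-08-06&sig=abc123"  # fake token
base_url = sas[:sas.index("?")]       # https://mystorage.blob.core.windows.net/mycontainer
query_string = sas[sas.index("?"):]   # ?sv=2021-08-06&sig=abc123
container = base_url.rsplit("/", 1)[-1]
print(container)                      # mycontainer -- local paths under this name map to 'azcopy/...'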
MedImageInsight/ImageDataLoader/build.py ADDED
@@ -0,0 +1,260 @@
1
+ from __future__ import absolute_import
2
+ from __future__ import division
3
+ from __future__ import print_function
4
+
5
+ import logging
6
+ import os
7
+ import json
8
+ import pathlib
9
+ from os.path import basename
10
+
11
+ from timm.data import create_loader
12
+ import torch
13
+ import torch.utils.data
14
+ import torch.distributed as dist
15
+ import torchvision.datasets as datasets
16
+ from torchvision.io import read_image
17
+ import torch.distributed as dist
18
+ from pathlib import Path
19
+ from yacs.config import CfgNode as CN
20
+
21
+ from ..LangEncoder import build_tokenizer
22
+
23
+ from .tsv import TSVImageTextDatasetV2
24
+ from .tsv import TSVMeta
25
+ from .transforms import build_transforms
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
+
30
+ def build_dataset(cfg, is_train):
31
+ if cfg['DATASET']['DATASET'] == 'image_text_pairs_v2':
32
+ dataset = _build_pairs_dataset_v2(cfg, is_train)
33
+ else:
34
+ raise ValueError(f'Unknown dataset: {cfg["DATASET"]["DATASET"]}')
35
+ return dataset
36
+
37
+
38
+ def _get_tsv_list(cfg, is_train):
39
+ tmp_list = []
40
+ if is_train and 'TRAIN_TSV_LIST' in cfg['DATASET']:
41
+ tmp_list = cfg['DATASET']['TRAIN_TSV_LIST']
42
+ elif 'TEST_TSV_LIST' in cfg['DATASET']:
43
+ tmp_list = cfg['DATASET']['TEST_TSV_LIST']
44
+
45
+ tsv_list = []
46
+ for l in tmp_list:
47
+ if l.endswith('.list'):
48
+ with open(l, 'r') as f:
49
+ tsv_list.extend([i.strip() for i in f])
50
+ else:
51
+ tsv_list.append(l)
52
+
53
+ logger.info(f'tsv list: {tsv_list}')
54
+
55
+ return tsv_list
56
+
57
+
58
+ def _get_token_file(cfg):
59
+ num_nodes = dist.get_world_size() // torch.cuda.device_count()
60
+ if isinstance(cfg['DATASET']['TOKEN_FILE'], list):
61
+ if num_nodes == 1:
62
+ logger.warning('=> Multi token files are provided, but only one node is used for training')
63
+ sas_token_file = cfg['DATASET']['TOKEN_FILE'][0]
64
+ else:
65
+ rank = dist.get_rank()
66
+ node_idx = rank // torch.cuda.device_count()
67
+ num_token_files = len(cfg['DATASET']['TOKEN_FILE'])
68
+ sas_token_file = cfg['DATASET']['TOKEN_FILE'][node_idx % num_token_files]
69
+ else:
70
+ sas_token_file = cfg['DATASET']['TOKEN_FILE']
71
+
72
+ sas_token_file = os.path.join(cfg['DATASET']['ROOT'], sas_token_file)
73
+
74
+ if (
75
+ cfg['DATASET']['LOADER'] == 'blobfuse'
76
+ or not os.path.isfile(sas_token_file)
77
+ ):
78
+ sas_token_file = None
79
+
80
+ return sas_token_file
81
+
82
+
83
+ def _build_pairs_dataset_v2(cfg, is_train):
84
+ transforms = build_transforms(cfg, is_train)
85
+ logger.info('transforms: {}'.format(transforms))
86
+
87
+ dataset_name = cfg['DATASET']['TRAIN_SET'] \
88
+ if is_train else cfg['DATASET']['TEST_SET']
89
+
90
+ tokenobj = build_tokenizer(cfg['LANG_ENCODER'])
91
+
92
+ if cfg['DATASET']['DATA_FORMAT'] != 'tsv':
93
+ raise ValueError('Only support tsv format for pairs dataset v2')
94
+
95
+ tsv_list = _get_tsv_list(cfg, is_train)
96
+
97
+ if len(tsv_list) > 0:
98
+ tsv_filenames = sorted(
99
+ [
100
+ os.path.join(cfg['DATASET']['ROOT'], dataset_name, f)
101
+ for f in tsv_list
102
+ ]
103
+ )
104
+ else:
105
+ dataset_path = os.path.join(cfg['DATASET']['ROOT'], dataset_name)
106
+ tsv_files = Path(dataset_path).glob('**/*.tsv')
107
+
108
+ tsv_filenames = sorted(
109
+ [
110
+ str(path)
111
+ for path in tsv_files
112
+ ]
113
+ )
114
+
115
+ image_tsv_files = [
116
+ filename
117
+ for filename in tsv_filenames
118
+ if (
119
+ 'image-' in basename(filename)
120
+ or 'image_' in basename(filename)
121
+ or '_image' in basename(filename)
122
+ or '-image' in basename(filename)
123
+ or 'images-' in basename(filename)
124
+ )
125
+ ]
126
+ text_tsv_files = [
127
+ filename
128
+ for filename in tsv_filenames
129
+ if (
130
+ 'text-' in basename(filename)
131
+ or 'text_' in basename(filename)
132
+ or '_text' in basename(filename)
133
+ or '-text' in basename(filename)
134
+ or 'texts-' in basename(filename)
135
+ )
136
+ ]
137
+
138
+ logger.info(
139
+ "=> found %d/%d tsv file(s) to load.",
140
+ len(image_tsv_files), len(text_tsv_files)
141
+ )
142
+
143
+ num_captions = 1 \
144
+ if is_train else cfg['DATASET'].get('NUM_CAPTIONS', 1)
145
+ text_format = cfg['DATASET'].get('TEXT_FORMAT', 'json')
146
+
147
+ sas_token_file = _get_token_file(cfg)
148
+ logger.info("=> SAS token path: %s", sas_token_file)
149
+
150
+ metas = []
151
+ cfg_data = cfg['DATASET']
152
+ if 'CLASSIFICATION_SETS' in cfg_data and 'NUM_CLASSES' in cfg_data:
153
+ for source, num_classes in zip(cfg_data['CLASSIFICATION_SETS'], cfg_data['NUM_CLASSES']):
154
+ metas.append(
155
+ TSVMeta(
156
+ source=source,
157
+ num_classes=num_classes,
158
+ task='classification'
159
+ )
160
+ )
161
+ logger.info('=> add meta: {}'.format(metas[-1]))
162
+
163
+ if 'coco-caption' in dataset_name:
164
+ logger.info('=> coco caption data is used')
165
+ logger.info('=> update num_captions: 5, text_format: json')
166
+ logger.warning('=> set sas token to None for coco evaluation')
167
+ sas_token_file = None
168
+ num_captions = 5
169
+ text_format = 'json'
170
+
171
+ dataset = TSVImageTextDatasetV2(
172
+ image_tsv_files, text_tsv_files,
173
+ transform=transforms,
174
+ tokenize=tokenobj,
175
+ context_length=cfg['LANG_ENCODER']['CONTEXT_LENGTH'],
176
+ num_captions=num_captions,
177
+ text_format=text_format,
178
+ is_train=is_train,
179
+ sas_token_path=sas_token_file,
180
+ metas=metas,
181
+ prompt_engineering=cfg['DATASET'].get('PROMPT_ENGINEERING', True),
182
+ concat_queries=cfg['DATASET'].get('CONCAT_QUERIES', False)
183
+ )
184
+
185
+ logger.info(
186
+ "=> %s set size: %d", 'train'
187
+ if is_train else 'val', len(dataset)
188
+ )
189
+
190
+ return dataset
191
+
192
+
193
+ def build_dataloader(cfg, is_train=True, distributed=False):
194
+ dataset = build_dataset(cfg, is_train)
195
+
196
+ if (
197
+ is_train
198
+ and 'TIMM_AUG' in cfg['AUG']
199
+ and cfg['AUG']['TIMM_AUG']['USE_LOADER']
200
+ ):
201
+ logger.info('=> use timm loader for training')
202
+ timm_cfg = CN(init_dict=cfg['AUG']['TIMM_AUG'])
203
+ data_loader = create_loader(
204
+ dataset,
205
+ input_size=cfg['IMAGE_ENCODER']['IMAGE_SIZE'][0],
206
+ batch_size=cfg['TRAIN']['BATCH_SIZE_PER_GPU'],
207
+ is_training=True,
208
+ use_prefetcher=True,
209
+ no_aug=False,
210
+ re_prob=timm_cfg.RE_PROB,
211
+ re_mode=timm_cfg.RE_MODE,
212
+ re_count=timm_cfg.RE_COUNT,
213
+ re_split=timm_cfg.RE_SPLIT,
214
+ scale=cfg['AUG']['SCALE'],
215
+ ratio=cfg['AUG']['RATIO'],
216
+ hflip=timm_cfg.HFLIP,
217
+ vflip=timm_cfg.VFLIP,
218
+ color_jitter=timm_cfg.COLOR_JITTER,
219
+ auto_augment=timm_cfg.AUTO_AUGMENT,
220
+ num_aug_splits=0,
221
+ interpolation=cfg['AUG']['INTERPOLATION'],
222
+ mean=cfg['IMAGE_ENCODER']['IMAGE_MEAN'],
223
+ std=cfg['IMAGE_ENCODER']['IMAGE_STD'],
224
+ num_workers=cfg['WORKERS'],
225
+ distributed=distributed,
226
+ collate_fn=None,
227
+ pin_memory=cfg['PIN_MEMORY'],
228
+ use_multi_epochs_loader=True
229
+ )
230
+ else:
231
+ if is_train:
232
+ batch_size_per_gpu = cfg['TRAIN']['BATCH_SIZE_PER_GPU']
233
+ shuffle = cfg['TRAIN'].get('SHUFFLE', True)
234
+ else:
235
+ batch_size_per_gpu = cfg['TEST']['BATCH_SIZE_PER_GPU']
236
+ shuffle = cfg['TEST'].get('SHUFFLE', False)
237
+
238
+ if distributed or cfg.get('ALWAYS_ENABLE_SAMPLER', False):
239
+ # sampler = build_sampler(cfg, dataset, is_train, shuffle)
240
+ sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=shuffle)
241
+ shuffle = False
242
+ else:
243
+ sampler = None
244
+
245
+ data_loader = torch.utils.data.DataLoader(
246
+ dataset,
247
+ batch_size=batch_size_per_gpu,
248
+ shuffle=shuffle,
249
+ num_workers=cfg['WORKERS'],
250
+ pin_memory=cfg['PIN_MEMORY'],
251
+ sampler=sampler,
252
+ drop_last=True if is_train else False,
253
+ prefetch_factor=cfg.get('PREFETCH_FACTOR', 2)
254
+ )
255
+
256
+ return data_loader
257
+
258
+
259
+
260
+
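build_dataloader above only needs a nested dict with the keys used in config.yaml, so the YAML added in this commit can be fed to it directly. A hedged single-process usage sketch: it assumes PyYAML is installed, DATASET.ROOT points at real TSV data, and (for the token-file logic) an initialized torch.distributed process group with at least one CUDA device, none of which this commit sets up:

import yaml
from MedImageInsight.ImageDataLoader import build_dataloader

with open("2024.09.27/config.yaml") as f:
    cfg = yaml.safe_load(f)

train_loader = build_dataloader(cfg, is_train=True, distributed=False)
batch = next(iter(train_loader))  # batch layout is defined by TSVImageTextDatasetV2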
MedImageInsight/ImageDataLoader/constants.py ADDED
@@ -0,0 +1,85 @@
+ IMAGENET_CLASSES = ["tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray", "stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco", "indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper", "kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander", "smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog", "tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin", "box turtle", "banded gecko", "green iguana", "Carolina anole", "desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard", "Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile", "American alligator", "triceratops", "worm snake", "ring-necked snake", "eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake", "vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra", "green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake", "sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider", "barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider", "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl", "quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet", "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck", "red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby", "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch", "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab", "fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab", "isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron", "great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot", "bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher", "pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion", "Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel", "Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle", "Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound", "English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound", "Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound", "Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier", "Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier", "Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier", "Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier", "Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer", "Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier", "Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier", "Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever", "Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla", "English 
Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel", "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel", "Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard", "Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie", "Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann", "Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog", "Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff", "French Bulldog", "Great Dane", "St. Bernard", "husky", "Alaskan Malamute", "Siberian Husky", "Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog", "Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon", "Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle", "Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf", "red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox", "kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat", "Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger", "cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose", "meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle", "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper", "cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper", "lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly", "monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly", "starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit", "hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse", "zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison", "ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)", "gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat", "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan", "gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque", "langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin", "howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey", "ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda", "giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish", "sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown", "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance", "amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle", "backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo", "baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel", "wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel", "bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)", "beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", 
"bikini", "ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet", "bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra", "breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest", "high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe", "can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton", "car wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran", "CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw", "storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking", "church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker", "coffee mug", "coffeemaker", "spiral or coil", "combination lock", "computer keyboard", "candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot", "cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed", "Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer", "rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table", "dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig", "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar", "electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder", "feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute", "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed", "freight car", "French horn", "frying pan", "fur coat", "garbage truck", "gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola", "gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine", "hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer", "handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet", "holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar", "horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep", "T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat", "ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library", "lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion", "music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag", "mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask", "matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone", "microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile", "mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor", "moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa", "mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail", "neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina", "odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart", "oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush", 
"pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench", "parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case", "pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube", "picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball", "pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag", "plate rack", "farm plow", "plunger", "Polaroid camera", "pole", "police van", "poncho", "pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug", "printer", "prison", "projectile", "projector", "hockey puck", "punching bag", "purse", "quill", "quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel", "recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator", "remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser", "rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal", "sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard", "CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store", "shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap", "shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door", "slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock", "solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater", "space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight", "stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf", "stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa", "submarine", "suit", "sundial", "sunglasses", "dark glasses", "sunscreen", "suspension bridge", "mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe", "table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball", "thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof", "toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store", "tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod", "triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard", "umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling", "velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball", "waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink", "washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle", "hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing", "wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website", "comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu", "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette", "bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", "cabbage", "broccoli", "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber", "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange", "lemon", "fig", 
"pineapple", "banana", "jackfruit", "cherimoya (custard apple)", "pomegranate", "hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito", "red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef", "geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player", "bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn", "rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom", "earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper"]
2
+
3
+ IMAGENET_DEFAULT_TEMPLATES = [
4
+ '{}.',
5
+ 'a bad photo of a {}.',
6
+ 'a photo of many {}.',
7
+ 'a sculpture of a {}.',
8
+ 'a photo of the hard to see {}.',
9
+ 'a low resolution photo of the {}.',
10
+ 'a rendering of a {}.',
11
+ 'graffiti of a {}.',
12
+ 'a bad photo of the {}.',
13
+ 'a cropped photo of the {}.',
14
+ 'a tattoo of a {}.',
15
+ 'the embroidered {}.',
16
+ 'a photo of a hard to see {}.',
17
+ 'a bright photo of a {}.',
18
+ 'a photo of a clean {}.',
19
+ 'a photo of a dirty {}.',
20
+ 'a dark photo of the {}.',
21
+ 'a drawing of a {}.',
22
+ 'a photo of my {}.',
23
+ 'the plastic {}.',
24
+ 'a photo of the cool {}.',
25
+ 'a close-up photo of a {}.',
26
+ 'a black and white photo of the {}.',
27
+ 'a painting of the {}.',
28
+ 'a painting of a {}.',
29
+ 'a pixelated photo of the {}.',
30
+ 'a sculpture of the {}.',
31
+ 'a bright photo of the {}.',
32
+ 'a cropped photo of a {}.',
33
+ 'a plastic {}.',
34
+ 'a photo of the dirty {}.',
35
+ 'a jpeg corrupted photo of a {}.',
36
+ 'a blurry photo of the {}.',
37
+ 'a photo of the {}.',
38
+ 'a good photo of the {}.',
39
+ 'a rendering of the {}.',
40
+ 'a {} in a video game.',
41
+ 'a photo of one {}.',
42
+ 'a doodle of a {}.',
43
+ 'a close-up photo of the {}.',
44
+ 'a photo of a {}.',
45
+ 'the origami {}.',
46
+ 'the {} in a video game.',
47
+ 'a sketch of a {}.',
48
+ 'a doodle of the {}.',
49
+ 'a origami {}.',
50
+ 'a low resolution photo of a {}.',
51
+ 'the toy {}.',
52
+ 'a rendition of the {}.',
53
+ 'a photo of the clean {}.',
54
+ 'a photo of a large {}.',
55
+ 'a rendition of a {}.',
56
+ 'a photo of a nice {}.',
57
+ 'a photo of a weird {}.',
58
+ 'a blurry photo of a {}.',
59
+ 'a cartoon {}.',
60
+ 'art of a {}.',
61
+ 'a sketch of the {}.',
62
+ 'a embroidered {}.',
63
+ 'a pixelated photo of a {}.',
64
+ 'itap of the {}.',
65
+ 'a jpeg corrupted photo of the {}.',
66
+ 'a good photo of a {}.',
67
+ 'a plushie {}.',
68
+ 'a photo of the nice {}.',
69
+ 'a photo of the small {}.',
70
+ 'a photo of the weird {}.',
71
+ 'the cartoon {}.',
72
+ 'art of the {}.',
73
+ 'a drawing of the {}.',
74
+ 'a photo of the large {}.',
75
+ 'a black and white photo of a {}.',
76
+ 'the plushie {}.',
77
+ 'a dark photo of a {}.',
78
+ 'itap of a {}.',
79
+ 'graffiti of the {}.',
80
+ 'a toy {}.',
81
+ 'itap of my {}.',
82
+ 'a photo of a cool {}.',
83
+ 'a photo of a small {}.',
84
+ 'a tattoo of the {}.',
85
+ ]
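The two constants above are the standard CLIP zero-shot class names and prompt templates. Below is a minimal sketch (not part of this file) of how they would typically be combined into per-class prompt sets; the import path simply mirrors this repository layout.

from MedImageInsight.ImageDataLoader.constants import (
    IMAGENET_CLASSES,
    IMAGENET_DEFAULT_TEMPLATES,
)

def prompts_for_class(classname):
    # One prompt per template, e.g. "a photo of a tench."
    return [template.format(classname) for template in IMAGENET_DEFAULT_TEMPLATES]

all_prompts = {name: prompts_for_class(name) for name in IMAGENET_CLASSES}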
MedImageInsight/ImageDataLoader/languages/__init__.py ADDED
File without changes
MedImageInsight/ImageDataLoader/languages/prompt_engineering.py ADDED
@@ -0,0 +1,101 @@
1
+ import numpy as np
2
+ import random
3
+
4
+
5
+ def get_prompt_templates():
6
+ prompt_templates = [
7
+ '{}.',
8
+ 'a photo of a {}.',
9
+ 'a bad photo of a {}.',
10
+ 'a photo of many {}.',
11
+ 'a sculpture of a {}.',
12
+ 'a photo of the hard to see {}.',
13
+ 'a low resolution photo of the {}.',
14
+ 'a rendering of a {}.',
15
+ 'graffiti of a {}.',
16
+ 'a bad photo of the {}.',
17
+ 'a cropped photo of the {}.',
18
+ 'a tattoo of a {}.',
19
+ 'the embroidered {}.',
20
+ 'a photo of a hard to see {}.',
21
+ 'a bright photo of a {}.',
22
+ 'a photo of a clean {}.',
23
+ 'a photo of a dirty {}.',
24
+ 'a dark photo of the {}.',
25
+ 'a drawing of a {}.',
26
+ 'a photo of my {}.',
27
+ 'the plastic {}.',
28
+ 'a photo of the cool {}.',
29
+ 'a close-up photo of a {}.',
30
+ 'a black and white photo of the {}.',
31
+ 'a painting of the {}.',
32
+ 'a painting of a {}.',
33
+ 'a pixelated photo of the {}.',
34
+ 'a sculpture of the {}.',
35
+ 'a bright photo of the {}.',
36
+ 'a cropped photo of a {}.',
37
+ 'a plastic {}.',
38
+ 'a photo of the dirty {}.',
39
+ 'a jpeg corrupted photo of a {}.',
40
+ 'a blurry photo of the {}.',
41
+ 'a photo of the {}.',
42
+ 'a good photo of the {}.',
43
+ 'a rendering of the {}.',
44
+ 'a {} in a video game.',
45
+ 'a photo of one {}.',
46
+ 'a doodle of a {}.',
47
+ 'a close-up photo of the {}.',
48
+ 'the origami {}.',
49
+ 'the {} in a video game.',
50
+ 'a sketch of a {}.',
51
+ 'a doodle of the {}.',
52
+ 'a origami {}.',
53
+ 'a low resolution photo of a {}.',
54
+ 'the toy {}.',
55
+ 'a rendition of the {}.',
56
+ 'a photo of the clean {}.',
57
+ 'a photo of a large {}.',
58
+ 'a rendition of a {}.',
59
+ 'a photo of a nice {}.',
60
+ 'a photo of a weird {}.',
61
+ 'a blurry photo of a {}.',
62
+ 'a cartoon {}.',
63
+ 'art of a {}.',
64
+ 'a sketch of the {}.',
65
+ 'a embroidered {}.',
66
+ 'a pixelated photo of a {}.',
67
+ 'itap of the {}.',
68
+ 'a jpeg corrupted photo of the {}.',
69
+ 'a good photo of a {}.',
70
+ 'a plushie {}.',
71
+ 'a photo of the nice {}.',
72
+ 'a photo of the small {}.',
73
+ 'a photo of the weird {}.',
74
+ 'the cartoon {}.',
75
+ 'art of the {}.',
76
+ 'a drawing of the {}.',
77
+ 'a photo of the large {}.',
78
+ 'a black and white photo of a {}.',
79
+ 'the plushie {}.',
80
+ 'a dark photo of a {}.',
81
+ 'itap of a {}.',
82
+ 'graffiti of the {}.',
83
+ 'a toy {}.',
84
+ 'itap of my {}.',
85
+ 'a photo of a cool {}.',
86
+ 'a photo of a small {}.',
87
+ 'a tattoo of the {}.',
88
+ ]
89
+ return prompt_templates
90
+
91
+
92
+ def prompt_engineering(classnames):
93
+ prompt_templates = get_prompt_templates()
94
+ temp_idx = np.random.randint(len(prompt_templates))
95
+
96
+ if isinstance(classnames, list):
97
+ classname = random.choice(classnames)
98
+ else:
99
+ classname = classnames
100
+
101
+ return prompt_templates[temp_idx].replace('{}', classname.replace(',', '').replace('+', ' '))
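A minimal usage sketch (illustrative class names): prompt_engineering picks one template at random and substitutes the class name, sampling a random element first if a list is passed.

from MedImageInsight.ImageDataLoader.languages.prompt_engineering import prompt_engineering

print(prompt_engineering('chest x-ray'))              # e.g. "a photo of a chest x-ray."
print(prompt_engineering(['pneumonia', 'effusion']))  # one class is sampled from the list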
MedImageInsight/ImageDataLoader/transforms/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .build import build_transforms
MedImageInsight/ImageDataLoader/transforms/autoaugment.py ADDED
@@ -0,0 +1,447 @@
1
+ import math
2
+ from enum import Enum
3
+ from typing import List, Tuple, Optional, Dict
4
+
5
+ import torch
6
+ from torch import Tensor
7
+
8
+ from torchvision.transforms import functional as F
9
+ from torchvision.transforms.functional import InterpolationMode
10
+
11
+ __all__ = ["AutoAugmentPolicy", "AutoAugment", "RandAugment", "TrivialAugmentWide"]
12
+
13
+
14
+ def _apply_op(
15
+ img: Tensor, op_name: str, magnitude: float, interpolation: InterpolationMode, fill: Optional[List[float]]
16
+ ):
17
+ if op_name == "ShearX":
18
+ img = F.affine(
19
+ img,
20
+ angle=0.0,
21
+ translate=[0, 0],
22
+ scale=1.0,
23
+ shear=[math.degrees(magnitude), 0.0],
24
+ interpolation=interpolation,
25
+ fill=fill,
26
+ )
27
+ elif op_name == "ShearY":
28
+ img = F.affine(
29
+ img,
30
+ angle=0.0,
31
+ translate=[0, 0],
32
+ scale=1.0,
33
+ shear=[0.0, math.degrees(magnitude)],
34
+ interpolation=interpolation,
35
+ fill=fill,
36
+ )
37
+ elif op_name == "TranslateX":
38
+ img = F.affine(
39
+ img,
40
+ angle=0.0,
41
+ translate=[int(magnitude), 0],
42
+ scale=1.0,
43
+ interpolation=interpolation,
44
+ shear=[0.0, 0.0],
45
+ fill=fill,
46
+ )
47
+ elif op_name == "TranslateY":
48
+ img = F.affine(
49
+ img,
50
+ angle=0.0,
51
+ translate=[0, int(magnitude)],
52
+ scale=1.0,
53
+ interpolation=interpolation,
54
+ shear=[0.0, 0.0],
55
+ fill=fill,
56
+ )
57
+ elif op_name == "Rotate":
58
+ img = F.rotate(img, magnitude, interpolation=interpolation, fill=fill)
59
+ elif op_name == "Brightness":
60
+ img = F.adjust_brightness(img, 1.0 + magnitude)
61
+ elif op_name == "Color":
62
+ img = F.adjust_saturation(img, 1.0 + magnitude)
63
+ elif op_name == "Contrast":
64
+ img = F.adjust_contrast(img, 1.0 + magnitude)
65
+ elif op_name == "Sharpness":
66
+ img = F.adjust_sharpness(img, 1.0 + magnitude)
67
+ elif op_name == "Posterize":
68
+ img = F.posterize(img, int(magnitude))
69
+ elif op_name == "Solarize":
70
+ img = F.solarize(img, magnitude)
71
+ elif op_name == "AutoContrast":
72
+ img = F.autocontrast(img)
73
+ elif op_name == "Equalize":
74
+ img = F.equalize(img)
75
+ elif op_name == "Invert":
76
+ img = F.invert(img)
77
+ elif op_name == "Identity":
78
+ pass
79
+ else:
80
+ raise ValueError(f"The provided operator {op_name} is not recognized.")
81
+ return img
82
+
83
+
84
+ class AutoAugmentPolicy(Enum):
85
+ """AutoAugment policies learned on different datasets.
86
+ Available policies are IMAGENET, CIFAR10 and SVHN.
87
+ """
88
+
89
+ IMAGENET = "imagenet"
90
+ CIFAR10 = "cifar10"
91
+ SVHN = "svhn"
92
+
93
+
94
+ # FIXME: Eliminate copy-pasted code for fill standardization and _augmentation_space() by moving stuff on a base class
95
+ class AutoAugment(torch.nn.Module):
96
+ r"""AutoAugment data augmentation method based on
97
+ `"AutoAugment: Learning Augmentation Strategies from Data" <https://arxiv.org/pdf/1805.09501.pdf>`_.
98
+ If the image is torch Tensor, it should be of type torch.uint8, and it is expected
99
+ to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
100
+ If img is PIL Image, it is expected to be in mode "L" or "RGB".
101
+
102
+ Args:
103
+ policy (AutoAugmentPolicy): Desired policy enum defined by
104
+ :class:`torchvision.transforms.autoaugment.AutoAugmentPolicy`. Default is ``AutoAugmentPolicy.IMAGENET``.
105
+ interpolation (InterpolationMode): Desired interpolation enum defined by
106
+ :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
107
+ If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
108
+ fill (sequence or number, optional): Pixel fill value for the area outside the transformed
109
+ image. If given a number, the value is used for all bands respectively.
110
+ """
111
+
112
+ def __init__(
113
+ self,
114
+ policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
115
+ interpolation: InterpolationMode = InterpolationMode.NEAREST,
116
+ fill: Optional[List[float]] = None,
117
+ ) -> None:
118
+ super().__init__()
119
+ self.policy = policy
120
+ self.interpolation = interpolation
121
+ self.fill = fill
122
+ self.policies = self._get_policies(policy)
123
+
124
+ def _get_policies(
125
+ self, policy: AutoAugmentPolicy
126
+ ) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]:
127
+ if policy == AutoAugmentPolicy.IMAGENET:
128
+ return [
129
+ (("Posterize", 0.4, 8), ("Rotate", 0.6, 9)),
130
+ (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
131
+ (("Equalize", 0.8, None), ("Equalize", 0.6, None)),
132
+ (("Posterize", 0.6, 7), ("Posterize", 0.6, 6)),
133
+ (("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
134
+ (("Equalize", 0.4, None), ("Rotate", 0.8, 8)),
135
+ (("Solarize", 0.6, 3), ("Equalize", 0.6, None)),
136
+ (("Posterize", 0.8, 5), ("Equalize", 1.0, None)),
137
+ (("Rotate", 0.2, 3), ("Solarize", 0.6, 8)),
138
+ (("Equalize", 0.6, None), ("Posterize", 0.4, 6)),
139
+ (("Rotate", 0.8, 8), ("Color", 0.4, 0)),
140
+ (("Rotate", 0.4, 9), ("Equalize", 0.6, None)),
141
+ (("Equalize", 0.0, None), ("Equalize", 0.8, None)),
142
+ (("Invert", 0.6, None), ("Equalize", 1.0, None)),
143
+ (("Color", 0.6, 4), ("Contrast", 1.0, 8)),
144
+ (("Rotate", 0.8, 8), ("Color", 1.0, 2)),
145
+ (("Color", 0.8, 8), ("Solarize", 0.8, 7)),
146
+ (("Sharpness", 0.4, 7), ("Invert", 0.6, None)),
147
+ (("ShearX", 0.6, 5), ("Equalize", 1.0, None)),
148
+ (("Color", 0.4, 0), ("Equalize", 0.6, None)),
149
+ (("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
150
+ (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
151
+ (("Invert", 0.6, None), ("Equalize", 1.0, None)),
152
+ (("Color", 0.6, 4), ("Contrast", 1.0, 8)),
153
+ (("Equalize", 0.8, None), ("Equalize", 0.6, None)),
154
+ ]
155
+ elif policy == AutoAugmentPolicy.CIFAR10:
156
+ return [
157
+ (("Invert", 0.1, None), ("Contrast", 0.2, 6)),
158
+ (("Rotate", 0.7, 2), ("TranslateX", 0.3, 9)),
159
+ (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)),
160
+ (("ShearY", 0.5, 8), ("TranslateY", 0.7, 9)),
161
+ (("AutoContrast", 0.5, None), ("Equalize", 0.9, None)),
162
+ (("ShearY", 0.2, 7), ("Posterize", 0.3, 7)),
163
+ (("Color", 0.4, 3), ("Brightness", 0.6, 7)),
164
+ (("Sharpness", 0.3, 9), ("Brightness", 0.7, 9)),
165
+ (("Equalize", 0.6, None), ("Equalize", 0.5, None)),
166
+ (("Contrast", 0.6, 7), ("Sharpness", 0.6, 5)),
167
+ (("Color", 0.7, 7), ("TranslateX", 0.5, 8)),
168
+ (("Equalize", 0.3, None), ("AutoContrast", 0.4, None)),
169
+ (("TranslateY", 0.4, 3), ("Sharpness", 0.2, 6)),
170
+ (("Brightness", 0.9, 6), ("Color", 0.2, 8)),
171
+ (("Solarize", 0.5, 2), ("Invert", 0.0, None)),
172
+ (("Equalize", 0.2, None), ("AutoContrast", 0.6, None)),
173
+ (("Equalize", 0.2, None), ("Equalize", 0.6, None)),
174
+ (("Color", 0.9, 9), ("Equalize", 0.6, None)),
175
+ (("AutoContrast", 0.8, None), ("Solarize", 0.2, 8)),
176
+ (("Brightness", 0.1, 3), ("Color", 0.7, 0)),
177
+ (("Solarize", 0.4, 5), ("AutoContrast", 0.9, None)),
178
+ (("TranslateY", 0.9, 9), ("TranslateY", 0.7, 9)),
179
+ (("AutoContrast", 0.9, None), ("Solarize", 0.8, 3)),
180
+ (("Equalize", 0.8, None), ("Invert", 0.1, None)),
181
+ (("TranslateY", 0.7, 9), ("AutoContrast", 0.9, None)),
182
+ ]
183
+ elif policy == AutoAugmentPolicy.SVHN:
184
+ return [
185
+ (("ShearX", 0.9, 4), ("Invert", 0.2, None)),
186
+ (("ShearY", 0.9, 8), ("Invert", 0.7, None)),
187
+ (("Equalize", 0.6, None), ("Solarize", 0.6, 6)),
188
+ (("Invert", 0.9, None), ("Equalize", 0.6, None)),
189
+ (("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
190
+ (("ShearX", 0.9, 4), ("AutoContrast", 0.8, None)),
191
+ (("ShearY", 0.9, 8), ("Invert", 0.4, None)),
192
+ (("ShearY", 0.9, 5), ("Solarize", 0.2, 6)),
193
+ (("Invert", 0.9, None), ("AutoContrast", 0.8, None)),
194
+ (("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
195
+ (("ShearX", 0.9, 4), ("Solarize", 0.3, 3)),
196
+ (("ShearY", 0.8, 8), ("Invert", 0.7, None)),
197
+ (("Equalize", 0.9, None), ("TranslateY", 0.6, 6)),
198
+ (("Invert", 0.9, None), ("Equalize", 0.6, None)),
199
+ (("Contrast", 0.3, 3), ("Rotate", 0.8, 4)),
200
+ (("Invert", 0.8, None), ("TranslateY", 0.0, 2)),
201
+ (("ShearY", 0.7, 6), ("Solarize", 0.4, 8)),
202
+ (("Invert", 0.6, None), ("Rotate", 0.8, 4)),
203
+ (("ShearY", 0.3, 7), ("TranslateX", 0.9, 3)),
204
+ (("ShearX", 0.1, 6), ("Invert", 0.6, None)),
205
+ (("Solarize", 0.7, 2), ("TranslateY", 0.6, 7)),
206
+ (("ShearY", 0.8, 4), ("Invert", 0.8, None)),
207
+ (("ShearX", 0.7, 9), ("TranslateY", 0.8, 3)),
208
+ (("ShearY", 0.8, 5), ("AutoContrast", 0.7, None)),
209
+ (("ShearX", 0.7, 2), ("Invert", 0.1, None)),
210
+ ]
211
+ else:
212
+ raise ValueError(f"The provided policy {policy} is not recognized.")
213
+
214
+ def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]:
215
+ return {
216
+ # op_name: (magnitudes, signed)
217
+ "ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
218
+ "ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
219
+ "TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
220
+ "TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
221
+ "Rotate": (torch.linspace(0.0, 30.0, num_bins), True),
222
+ "Brightness": (torch.linspace(0.0, 0.9, num_bins), True),
223
+ "Color": (torch.linspace(0.0, 0.9, num_bins), True),
224
+ "Contrast": (torch.linspace(0.0, 0.9, num_bins), True),
225
+ "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True),
226
+ "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False),
227
+ "Solarize": (torch.linspace(255.0, 0.0, num_bins), False),
228
+ "AutoContrast": (torch.tensor(0.0), False),
229
+ "Equalize": (torch.tensor(0.0), False),
230
+ "Invert": (torch.tensor(0.0), False),
231
+ }
232
+
233
+ @staticmethod
234
+ def get_params(transform_num: int) -> Tuple[int, Tensor, Tensor]:
235
+ """Get parameters for autoaugment transformation
236
+
237
+ Returns:
238
+ params required by the autoaugment transformation
239
+ """
240
+ policy_id = int(torch.randint(transform_num, (1,)).item())
241
+ probs = torch.rand((2,))
242
+ signs = torch.randint(2, (2,))
243
+
244
+ return policy_id, probs, signs
245
+
246
+ def forward(self, img: Tensor) -> Tensor:
247
+ """
248
+ img (PIL Image or Tensor): Image to be transformed.
249
+
250
+ Returns:
251
+ PIL Image or Tensor: AutoAugmented image.
252
+ """
253
+ fill = self.fill
254
+ if isinstance(img, Tensor):
255
+ if isinstance(fill, (int, float)):
256
+ fill = [float(fill)] * F.get_image_num_channels(img)
257
+ elif fill is not None:
258
+ fill = [float(f) for f in fill]
259
+
260
+ transform_id, probs, signs = self.get_params(len(self.policies))
261
+
262
+ for i, (op_name, p, magnitude_id) in enumerate(self.policies[transform_id]):
263
+ if probs[i] <= p:
264
+ op_meta = self._augmentation_space(10, F.get_image_size(img))
265
+ magnitudes, signed = op_meta[op_name]
266
+ magnitude = float(magnitudes[magnitude_id].item()) if magnitude_id is not None else 0.0
267
+ if signed and signs[i] == 0:
268
+ magnitude *= -1.0
269
+ img = _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill)
270
+
271
+ return img
272
+
273
+ def __repr__(self) -> str:
274
+ return self.__class__.__name__ + f"(policy={self.policy}, fill={self.fill})"
275
+
276
+
277
+ class RandAugment(torch.nn.Module):
278
+ r"""RandAugment data augmentation method based on
279
+ `"RandAugment: Practical automated data augmentation with a reduced search space"
280
+ <https://arxiv.org/abs/1909.13719>`_.
281
+ If the image is torch Tensor, it should be of type torch.uint8, and it is expected
282
+ to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
283
+ If img is PIL Image, it is expected to be in mode "L" or "RGB".
284
+
285
+ Args:
286
+ num_ops (int): Number of augmentation transformations to apply sequentially.
287
+ magnitude (int): Magnitude for all the transformations.
288
+ num_magnitude_bins (int): The number of different magnitude values.
289
+ interpolation (InterpolationMode): Desired interpolation enum defined by
290
+ :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
291
+ If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
292
+ fill (sequence or number, optional): Pixel fill value for the area outside the transformed
293
+ image. If given a number, the value is used for all bands respectively.
294
+ """
295
+
296
+ def __init__(
297
+ self,
298
+ num_ops: int = 2,
299
+ magnitude: int = 9,
300
+ num_magnitude_bins: int = 31,
301
+ interpolation: InterpolationMode = InterpolationMode.NEAREST,
302
+ fill: Optional[List[float]] = None,
303
+ ) -> None:
304
+ super().__init__()
305
+ self.num_ops = num_ops
306
+ self.magnitude = magnitude
307
+ self.num_magnitude_bins = num_magnitude_bins
308
+ self.interpolation = interpolation
309
+ self.fill = fill
310
+
311
+ def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]:
312
+ return {
313
+ # op_name: (magnitudes, signed)
314
+ "Identity": (torch.tensor(0.0), False),
315
+ "ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
316
+ "ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
317
+ "TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
318
+ "TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
319
+ "Rotate": (torch.linspace(0.0, 30.0, num_bins), True),
320
+ "Brightness": (torch.linspace(0.0, 0.9, num_bins), True),
321
+ "Color": (torch.linspace(0.0, 0.9, num_bins), True),
322
+ "Contrast": (torch.linspace(0.0, 0.9, num_bins), True),
323
+ "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True),
324
+ "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False),
325
+ "Solarize": (torch.linspace(255.0, 0.0, num_bins), False),
326
+ "AutoContrast": (torch.tensor(0.0), False),
327
+ "Equalize": (torch.tensor(0.0), False),
328
+ }
329
+
330
+ def forward(self, img: Tensor) -> Tensor:
331
+ """
332
+ img (PIL Image or Tensor): Image to be transformed.
333
+
334
+ Returns:
335
+ PIL Image or Tensor: Transformed image.
336
+ """
337
+ fill = self.fill
338
+ if isinstance(img, Tensor):
339
+ if isinstance(fill, (int, float)):
340
+ fill = [float(fill)] * F.get_image_num_channels(img)
341
+ elif fill is not None:
342
+ fill = [float(f) for f in fill]
343
+
344
+ for _ in range(self.num_ops):
345
+ op_meta = self._augmentation_space(self.num_magnitude_bins, F.get_image_size(img))
346
+ op_index = int(torch.randint(len(op_meta), (1,)).item())
347
+ op_name = list(op_meta.keys())[op_index]
348
+ magnitudes, signed = op_meta[op_name]
349
+ magnitude = float(magnitudes[self.magnitude].item()) if magnitudes.ndim > 0 else 0.0
350
+ if signed and torch.randint(2, (1,)):
351
+ magnitude *= -1.0
352
+ img = _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill)
353
+
354
+ return img
355
+
356
+ def __repr__(self) -> str:
357
+ s = self.__class__.__name__ + "("
358
+ s += "num_ops={num_ops}"
359
+ s += ", magnitude={magnitude}"
360
+ s += ", num_magnitude_bins={num_magnitude_bins}"
361
+ s += ", interpolation={interpolation}"
362
+ s += ", fill={fill}"
363
+ s += ")"
364
+ return s.format(**self.__dict__)
365
+
366
+
367
+ class TrivialAugmentWide(torch.nn.Module):
368
+ r"""Dataset-independent data-augmentation with TrivialAugment Wide, as described in
369
+ `"TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation" <https://arxiv.org/abs/2103.10158>`_.
370
+ If the image is torch Tensor, it should be of type torch.uint8, and it is expected
371
+ to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
372
+ If img is PIL Image, it is expected to be in mode "L" or "RGB".
373
+
374
+ Args:
375
+ num_magnitude_bins (int): The number of different magnitude values.
376
+ interpolation (InterpolationMode): Desired interpolation enum defined by
377
+ :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
378
+ If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
379
+ fill (sequence or number, optional): Pixel fill value for the area outside the transformed
380
+ image. If given a number, the value is used for all bands respectively.
381
+ """
382
+
383
+ def __init__(
384
+ self,
385
+ num_magnitude_bins: int = 31,
386
+ interpolation: InterpolationMode = InterpolationMode.NEAREST,
387
+ fill: Optional[List[float]] = None,
388
+ ) -> None:
389
+ super().__init__()
390
+ self.num_magnitude_bins = num_magnitude_bins
391
+ self.interpolation = interpolation
392
+ self.fill = fill
393
+
394
+ def _augmentation_space(self, num_bins: int) -> Dict[str, Tuple[Tensor, bool]]:
395
+ return {
396
+ # op_name: (magnitudes, signed)
397
+ "Identity": (torch.tensor(0.0), False),
398
+ "ShearX": (torch.linspace(0.0, 0.99, num_bins), True),
399
+ "ShearY": (torch.linspace(0.0, 0.99, num_bins), True),
400
+ "TranslateX": (torch.linspace(0.0, 32.0, num_bins), True),
401
+ "TranslateY": (torch.linspace(0.0, 32.0, num_bins), True),
402
+ "Rotate": (torch.linspace(0.0, 135.0, num_bins), True),
403
+ "Brightness": (torch.linspace(0.0, 0.99, num_bins), True),
404
+ "Color": (torch.linspace(0.0, 0.99, num_bins), True),
405
+ "Contrast": (torch.linspace(0.0, 0.99, num_bins), True),
406
+ "Sharpness": (torch.linspace(0.0, 0.99, num_bins), True),
407
+ "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 6)).round().int(), False),
408
+ "Solarize": (torch.linspace(255.0, 0.0, num_bins), False),
409
+ "AutoContrast": (torch.tensor(0.0), False),
410
+ "Equalize": (torch.tensor(0.0), False),
411
+ }
412
+
413
+ def forward(self, img: Tensor) -> Tensor:
414
+ """
415
+ img (PIL Image or Tensor): Image to be transformed.
416
+
417
+ Returns:
418
+ PIL Image or Tensor: Transformed image.
419
+ """
420
+ fill = self.fill
421
+ if isinstance(img, Tensor):
422
+ if isinstance(fill, (int, float)):
423
+ fill = [float(fill)] * F.get_image_num_channels(img)
424
+ elif fill is not None:
425
+ fill = [float(f) for f in fill]
426
+
427
+ op_meta = self._augmentation_space(self.num_magnitude_bins)
428
+ op_index = int(torch.randint(len(op_meta), (1,)).item())
429
+ op_name = list(op_meta.keys())[op_index]
430
+ magnitudes, signed = op_meta[op_name]
431
+ magnitude = (
432
+ float(magnitudes[torch.randint(len(magnitudes), (1,), dtype=torch.long)].item())
433
+ if magnitudes.ndim > 0
434
+ else 0.0
435
+ )
436
+ if signed and torch.randint(2, (1,)):
437
+ magnitude *= -1.0
438
+
439
+ return _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill)
440
+
441
+ def __repr__(self) -> str:
442
+ s = self.__class__.__name__ + "("
443
+ s += "num_magnitude_bins={num_magnitude_bins}"
444
+ s += ", interpolation={interpolation}"
445
+ s += ", fill={fill}"
446
+ s += ")"
447
+ return s.format(**self.__dict__)
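A minimal usage sketch (assumed sizes and values): these policies operate on PIL images or uint8 tensors, so they are placed before ToTensor()/Normalize() in a pipeline.

from PIL import Image
import torchvision.transforms as T

train_aug = T.Compose([
    T.Resize((224, 224)),
    RandAugment(num_ops=2, magnitude=9),   # or AutoAugment(AutoAugmentPolicy.IMAGENET)
    T.ToTensor(),
])

img = Image.new('RGB', (256, 256))   # placeholder image
tensor = train_aug(img)              # float tensor of shape [3, 224, 224]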
MedImageInsight/ImageDataLoader/transforms/build.py ADDED
@@ -0,0 +1,261 @@
1
+ from __future__ import absolute_import
2
+ from __future__ import division
3
+ from __future__ import print_function
4
+
5
+ import timm
6
+ from timm.data import create_transform
7
+
8
+ from yacs.config import CfgNode as CN
9
+ from PIL import ImageFilter
10
+ import logging
11
+ import random
12
+
13
+ import torch
14
+ import torchvision.transforms as T
15
+
16
+
17
+ from .autoaugment import AutoAugmentPolicy
18
+ from .autoaugment import AutoAugment
19
+ from .autoaugment import RandAugment
20
+ from .autoaugment import TrivialAugmentWide
21
+ from .threeaugment import deitIII_Solarization
22
+ from .threeaugment import deitIII_gray_scale
23
+ from .threeaugment import deitIII_GaussianBlur
24
+
25
+ from PIL import ImageOps
26
+ from timm.data.transforms import RandomResizedCropAndInterpolation
27
+
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
+ class GaussianBlur(object):
32
+ """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709"""
33
+
34
+ def __init__(self, sigma=[.1, 2.]):
35
+ self.sigma = sigma
36
+
37
+ def __call__(self, x):
38
+ sigma = random.uniform(self.sigma[0], self.sigma[1])
39
+ x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
40
+ return x
41
+
42
+
43
+ def get_resolution(original_resolution):
44
+ """Takes (H,W) and returns (precrop, crop)."""
45
+ area = original_resolution[0] * original_resolution[1]
46
+ return (160, 128) if area < 96*96 else (512, 480)
47
+
48
+
49
+ INTERPOLATION_MODES = {
50
+ 'bilinear': T.InterpolationMode.BILINEAR,
51
+ 'bicubic': T.InterpolationMode.BICUBIC,
52
+ 'nearest': T.InterpolationMode.NEAREST,
53
+ }
54
+
55
+
56
+ def build_transforms(cfg, is_train=True):
57
+ # assert isinstance(cfg.DATASET.OUTPUT_SIZE, (list, tuple)), 'DATASET.OUTPUT_SIZE should be list or tuple'
58
+ normalize = T.Normalize(
59
+ mean=cfg['IMAGE_ENCODER']['IMAGE_MEAN'],
60
+ std=cfg['IMAGE_ENCODER']['IMAGE_STD']
61
+ )
62
+
63
+ transforms = None
64
+ if is_train:
65
+ if 'THREE_AUG' in cfg['AUG']:
66
+ img_size = cfg['IMAGE_ENCODER']['IMAGE_SIZE']
67
+ remove_random_resized_crop = cfg['AUG']['THREE_AUG']['SRC']
68
+ mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
69
+ primary_tfl = []
70
+ scale=(0.08, 1.0)
71
+ interpolation='bicubic'
72
+ if remove_random_resized_crop:
73
+ primary_tfl = [
74
+ T.Resize(img_size, interpolation=3),
75
+ T.RandomCrop(img_size, padding=4,padding_mode='reflect'),
76
+ T.RandomHorizontalFlip()
77
+ ]
78
+ else:
79
+ primary_tfl = [
80
+ RandomResizedCropAndInterpolation(
81
+ img_size, scale=scale, interpolation=interpolation),
82
+ T.RandomHorizontalFlip()
83
+ ]
84
+ secondary_tfl = [T.RandomChoice([deitIII_gray_scale(p=1.0),
85
+ deitIII_Solarization(p=1.0),
86
+ deitIII_GaussianBlur(p=1.0)])]
87
+ color_jitter = cfg['AUG']['THREE_AUG']['COLOR_JITTER']
88
+ if color_jitter is not None and not color_jitter==0:
89
+ secondary_tfl.append(T.ColorJitter(color_jitter, color_jitter, color_jitter))
90
+ final_tfl = [
91
+ T.ToTensor(),
92
+ T.Normalize(
93
+ mean=torch.tensor(mean),
94
+ std=torch.tensor(std))
95
+ ]
96
+ return T.Compose(primary_tfl+secondary_tfl+final_tfl)
97
+ elif 'TIMM_AUG' in cfg['AUG'] and cfg['AUG']['TIMM_AUG']['USE_TRANSFORM']:
98
+ logger.info('=> use timm transform for training')
99
+ timm_cfg = cfg['AUG']['TIMM_AUG']
100
+ transforms = create_transform(
101
+ input_size=cfg['IMAGE_ENCODER']['IMAGE_SIZE'][0],
102
+ is_training=True,
103
+ use_prefetcher=False,
104
+ no_aug=False,
105
+ re_prob=timm_cfg.get('RE_PROB', 0.),
106
+ re_mode=timm_cfg.get('RE_MODE', 'const'),
107
+ re_count=timm_cfg.get('RE_COUNT', 1),
108
+ re_num_splits= 0 if not timm_cfg.get('RE_SPLITS', False) else timm_cfg['RE_SPLITS'], # if false or 0, return 0
109
+ scale=cfg['AUG'].get('SCALE', None),
110
+ ratio=cfg['AUG'].get('RATIO', None),
111
+ hflip=timm_cfg.get('HFLIP', 0.5),
112
+ vflip=timm_cfg.get('VFLIP', 0.),
113
+ color_jitter=timm_cfg.get('COLOR_JITTER', 0.4),
114
+ auto_augment=timm_cfg.get('AUTO_AUGMENT', None),
115
+ interpolation=cfg['AUG']['INTERPOLATION'],
116
+ mean=cfg['IMAGE_ENCODER']['IMAGE_MEAN'],
117
+ std=cfg['IMAGE_ENCODER']['IMAGE_STD'],
118
+ )
119
+ elif 'TORCHVISION_AUG' in cfg['AUG']:
120
+ logger.info('=> use torchvision transform for training')
121
+ crop_size = cfg['IMAGE_ENCODER']['IMAGE_SIZE'][0]
122
+ interpolation = INTERPOLATION_MODES[cfg['AUG']['INTERPOLATION']]
123
+ trans = [
124
+ T.RandomResizedCrop(
125
+ crop_size, scale=cfg['AUG']['SCALE'], ratio=cfg['AUG']['RATIO'],
126
+ interpolation=interpolation
127
+ )
128
+ ]
129
+ hflip_prob = cfg['AUG']['TORCHVISION_AUG']['HFLIP']
130
+ auto_augment_policy = cfg['AUG']['TORCHVISION_AUG'].get('AUTO_AUGMENT', None)
131
+ if hflip_prob > 0:
132
+ trans.append(T.RandomHorizontalFlip(hflip_prob))
133
+ if auto_augment_policy is not None:
134
+ if auto_augment_policy == "ra":
135
+ trans.append(RandAugment(interpolation=interpolation))
136
+ elif auto_augment_policy == "ta_wide":
137
+ trans.append(TrivialAugmentWide(interpolation=interpolation))
138
+ else:
139
+ aa_policy = AutoAugmentPolicy(auto_augment_policy)
140
+ trans.append(AutoAugment(policy=aa_policy, interpolation=interpolation))
141
+ trans.extend(
142
+ [
143
+ T.ToTensor(),
144
+ normalize,
145
+ ]
146
+ )
147
+ random_erase_prob = cfg['AUG']['TORCHVISION_AUG']['RE_PROB']
148
+ random_erase_scale = cfg['AUG']['TORCHVISION_AUG'].get('RE_SCALE', 0.33)
149
+ if random_erase_prob > 0:
150
+ # NCFC (4/26/2023): Added scale parameter to random erasing for medical imaging
151
+ trans.append(T.RandomErasing(p=random_erase_prob, scale = (0.02, random_erase_scale)))
152
+
153
+ from torchvision.transforms import InterpolationMode
154
+ rotation = cfg['AUG']['TORCHVISION_AUG'].get('ROTATION', 0.0)
155
+ if (rotation > 0.0):
156
+ trans.append(T.RandomRotation(rotation, interpolation=InterpolationMode.BILINEAR))
157
+ logger.info(" TORCH AUG: Rotation: " + str(rotation))
158
+
159
+ transforms = T.Compose(trans)
160
+ elif cfg['AUG'].get('RANDOM_CENTER_CROP', False):
161
+ logger.info('=> use random center crop data augmentation')
162
+ # precrop, crop = get_resolution(cfg.TRAIN.IMAGE_SIZE)
163
+ crop = cfg['IMAGE_ENCODER']['IMAGE_SIZE'][0]
164
+ padding = cfg['AUG'].get('RANDOM_CENTER_CROP_PADDING', 32)
165
+ precrop = crop + padding
166
+ mode = INTERPOLATION_MODES[cfg['AUG']['INTERPOLATION']]
167
+ transforms = T.Compose([
168
+ T.Resize(
169
+ (precrop, precrop),
170
+ interpolation=mode
171
+ ),
172
+ T.RandomCrop((crop, crop)),
173
+ T.RandomHorizontalFlip(),
174
+ T.ToTensor(),
175
+ normalize,
176
+ ])
177
+ elif cfg['AUG'].get('MAE_FINETUNE_AUG', False):
178
+ mean = cfg['IMAGE_ENCODER']['IMAGE_MEAN']
179
+ std = cfg['IMAGE_ENCODER']['IMAGE_STD']
180
+ transforms = create_transform(
181
+ input_size=cfg['IMAGE_ENCODER']['IMAGE_SIZE'][0],
182
+ is_training=True,
183
+ color_jitter=cfg['AUG'].get('COLOR_JITTER', None),
184
+ auto_augment=cfg['AUG'].get('AUTO_AUGMENT', 'rand-m9-mstd0.5-inc1'),
185
+ interpolation='bicubic',
186
+ re_prob=cfg['AUG'].get('RE_PROB', 0.25),
187
+ re_mode=cfg['AUG'].get('RE_MODE', "pixel"),
188
+ re_count=cfg['AUG'].get('RE_COUNT', 1),
189
+ mean=mean,
190
+ std=std,
191
+ )
192
+ elif cfg['AUG'].get('MAE_PRETRAIN_AUG', False):
193
+ mean = cfg['IMAGE_ENCODER']['IMAGE_MEAN']
194
+ std = cfg['IMAGE_ENCODER']['IMAGE_STD']
195
+ transforms = T.Compose([
196
+ T.RandomResizedCrop(cfg['IMAGE_ENCODER']['IMAGE_SIZE'][0], scale=tuple(cfg['AUG']['SCALE']), interpolation=INTERPOLATION_MODES["bicubic"]), # 3 is bicubic
197
+ T.RandomHorizontalFlip(),
198
+ T.ToTensor(),
199
+ T.Normalize(mean=mean, std=std)])
200
+ elif cfg['AUG'].get('ThreeAugment', False): # from DeiT III
201
+ mean = cfg['IMAGE_ENCODER']['IMAGE_MEAN']
202
+ std = cfg['IMAGE_ENCODER']['IMAGE_STD']
203
+ img_size = cfg['IMAGE_ENCODER']['IMAGE_SIZE'][0]
204
+ remove_random_resized_crop = cfg['AUG'].get('src', False)
205
+ mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
206
+ primary_tfl = []
207
+ scale=(0.08, 1.0)
208
+ interpolation='bicubic'
209
+ if remove_random_resized_crop:
210
+ primary_tfl = [
211
+ T.Resize(img_size, interpolation=3), # bicubic
212
+ T.RandomCrop(img_size, padding=4,padding_mode='reflect'),
213
+ T.RandomHorizontalFlip()
214
+ ]
215
+ else:
216
+ primary_tfl = [
217
+ timm.data.transforms.RandomResizedCropAndInterpolation(
218
+ img_size, scale=scale, interpolation=interpolation),
219
+ T.RandomHorizontalFlip()
220
+ ]
221
+
222
+ secondary_tfl = [T.RandomChoice([deitIII_gray_scale(p=1.0),
223
+ deitIII_Solarization(p=1.0),
224
+ deitIII_GaussianBlur(p=1.0)])]
225
+ color_jitter = cfg['AUG']['COLOR_JITTER']
226
+ secondary_tfl.append(T.ColorJitter(color_jitter, color_jitter, color_jitter))
227
+ final_tfl = [
228
+ T.ToTensor(),
229
+ T.Normalize(
230
+ mean=torch.tensor(mean),
231
+ std=torch.tensor(std))
232
+ ]
233
+ transforms = T.Compose(primary_tfl+secondary_tfl+final_tfl)
234
+ logger.info('=> training transforms: {}'.format(transforms))
235
+ else:
236
+ mode = INTERPOLATION_MODES[cfg['AUG']['INTERPOLATION']]
237
+ if cfg['TEST']['CENTER_CROP']:
238
+ transforms = T.Compose([
239
+ T.Resize(
240
+ int(cfg['IMAGE_ENCODER']['IMAGE_SIZE'][0] / 0.875),
241
+ # the same behavior as in deit: size = int((256 / 224) * args.input_size)
242
+ # 224 / 256 = 0.875
243
+ interpolation=mode
244
+ ),
245
+ T.CenterCrop(cfg['IMAGE_ENCODER']['IMAGE_SIZE'][0]),
246
+ T.ToTensor(),
247
+ normalize,
248
+ ])
249
+ else:
250
+ transforms = T.Compose([
251
+ T.Resize(
252
+ (cfg['IMAGE_ENCODER']['IMAGE_SIZE'][1], cfg['IMAGE_ENCODER']['IMAGE_SIZE'][0]),
253
+ interpolation=mode
254
+ ),
255
+ T.ToTensor(),
256
+ normalize,
257
+ ])
258
+ logger.info('=> testing transforms: {}'.format(transforms))
259
+
260
+ return transforms
261
+
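A minimal evaluation-time config sketch for build_transforms; the key layout mirrors the lookups above, while the concrete values are illustrative.

cfg = {
    'IMAGE_ENCODER': {
        'IMAGE_SIZE': [224, 224],
        'IMAGE_MEAN': [0.485, 0.456, 0.406],
        'IMAGE_STD': [0.229, 0.224, 0.225],
    },
    'AUG': {'INTERPOLATION': 'bicubic'},
    'TEST': {'CENTER_CROP': True},
}

eval_transforms = build_transforms(cfg, is_train=False)   # Resize -> CenterCrop -> ToTensor -> Normalize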
MedImageInsight/ImageDataLoader/transforms/threeaugment.py ADDED
@@ -0,0 +1,54 @@
1
+ import random
2
+ from PIL import ImageFilter, ImageOps
3
+ from torchvision import transforms
4
+
5
+
6
+ class deitIII_GaussianBlur(object):
7
+ """
8
+ Apply Gaussian Blur to the PIL image.
9
+ """
10
+ def __init__(self, p=0.1, radius_min=0.1, radius_max=2.):
11
+ self.prob = p
12
+ self.radius_min = radius_min
13
+ self.radius_max = radius_max
14
+
15
+ def __call__(self, img):
16
+ do_it = random.random() <= self.prob
17
+ if not do_it:
18
+ return img
19
+
20
+ img = img.filter(
21
+ ImageFilter.GaussianBlur(
22
+ radius=random.uniform(self.radius_min, self.radius_max)
23
+ )
24
+ )
25
+ return img
26
+
27
+
28
+ class deitIII_Solarization(object):
29
+ """
30
+ Apply Solarization to the PIL image.
31
+ """
32
+ def __init__(self, p=0.2):
33
+ self.p = p
34
+
35
+ def __call__(self, img):
36
+ if random.random() < self.p:
37
+ return ImageOps.solarize(img)
38
+ else:
39
+ return img
40
+
41
+
42
+ class deitIII_gray_scale(object):
43
+ """
44
+ Apply Solarization to the PIL image.
45
+ """
46
+ def __init__(self, p=0.2):
47
+ self.p = p
48
+ self.transf = transforms.Grayscale(3)
49
+
50
+ def __call__(self, img):
51
+ if random.random() < self.p:
52
+ return self.transf(img)
53
+ else:
54
+ return img
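A minimal usage sketch mirroring the DeiT III recipe in transforms/build.py: exactly one of the three operations is applied per sample via RandomChoice.

import torchvision.transforms as T

three_augment = T.RandomChoice([
    deitIII_gray_scale(p=1.0),
    deitIII_Solarization(p=1.0),
    deitIII_GaussianBlur(p=1.0),
])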
MedImageInsight/ImageDataLoader/tsv.py ADDED
@@ -0,0 +1,351 @@
1
+ from __future__ import absolute_import
2
+ from __future__ import division
3
+ from __future__ import print_function
4
+
5
+ import os
6
+ from io import BytesIO
7
+ import json
8
+ import logging
9
+ import base64
10
+ import random
11
+ from typing import Callable, List, Tuple, Union, NamedTuple
12
+ from PIL import Image
13
+ from PIL import ImageFile
14
+ import torch.utils.data as data
15
+ from .languages.prompt_engineering import prompt_engineering
16
+ from .tsv_file import TSVFile, CompositeTSVFile
17
+
18
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
+ class TSVDataset(data.Dataset):
24
+
25
+ def __init__(self,
26
+ tsv_file: Union[str, List[str]],
27
+ transform: Callable = None,
28
+ map_file: str = None,
29
+ token_file: str = None,
30
+ is_train: bool = True,
31
+ azcopy_path: str = None):
32
+ self.transform = transform
33
+ self._chunk_sizes = None
34
+ self.label2idx = self._load_map(map_file)
35
+ self.class_selector = list(self.label2idx.keys()) if self.label2idx else None
36
+
37
+ if isinstance(tsv_file, str):
38
+ if os.path.splitext(tsv_file)[1] == '.tsv':
39
+ self.tsv_file = TSVFile(
40
+ tsv_file, class_selector=self.class_selector
41
+ )
42
+ else:
43
+ self.tsv_file = CompositeTSVFile(
44
+ tsv_file,
45
+ class_selector=self.class_selector,
46
+ is_train=is_train,
47
+ sas_token_path=token_file,
48
+ azcopy_path=azcopy_path
49
+ )
50
+ self._chunk_sizes = self.tsv_file.get_chunk_size()
51
+ elif isinstance(tsv_file, list):
52
+ self.tsv_file = CompositeTSVFile(
53
+ tsv_file,
54
+ class_selector=self.class_selector,
55
+ is_train=is_train,
56
+ sas_token_path=token_file,
57
+ azcopy_path=azcopy_path
58
+ )
59
+ self._chunk_sizes = self.tsv_file.get_chunk_size()
60
+ else:
61
+ raise ValueError("Invalid input! Please check the tsv filenames")
62
+
63
+ logger.debug('=> {}\titems: {}'.format(tsv_file, len(self.tsv_file)))
64
+
65
+ def fetch_blob(self, idx):
66
+ image_tsv = self.tsv_file.file_list[idx]
67
+ self.tsv_file.blob_storage.fetch_blob(image_tsv)
68
+
69
+ def num_classes(self):
70
+ return len(self.class_selector)
71
+
72
+ def get_chunk_sizes(self):
73
+ return self._chunk_sizes
74
+
75
+ def get_class_boundaries(self):
76
+ # The samples of each class are organized class-by-class.
77
+ # _class_boundaries stores the lower- and upper-bound of each class.
78
+ return self.tsv_file.get_class_boundaries()
79
+
80
+ def get_filenames(self):
81
+ filenames = [
82
+ self.tsv_file.get_key(i)
83
+ for i in range(self.tsv_file.num_rows())
84
+ ]
85
+
86
+ return filenames
87
+
88
+ def _load_map(self, map_file: str):
89
+ if not map_file:
90
+ return None
91
+
92
+ label2idx = {}
93
+ with open(map_file) as f:
94
+ for line in f:
95
+ items = line.strip().split('\t')
96
+ label2idx[items[0]] = int(items[1])
97
+
98
+ return label2idx
99
+
100
+ def __getitem__(self, index: Union[int, Tuple[int, int]]):
101
+ items = self.tsv_file[index]
102
+ _, target, img = self._decode_data(items)
103
+
104
+ if self.transform:
105
+ img = self.transform(img)
106
+
107
+ return img, target
108
+
109
+ def _decode_data(self, items: Tuple[str, str, str]):
110
+ key = items[0]
111
+ label = self._get_label(items[1])
112
+ image = Image.open(BytesIO(base64.b64decode(items[2]))).convert('RGB')
113
+
114
+ return key, label, image
115
+
116
+ def _get_label(self, item: str):
117
+ if not self.label2idx:
118
+ return int(item)
119
+
120
+ js = json.loads(item)
121
+ return self.label2idx[js[0]['class']]
122
+
123
+ def __len__(self):
124
+ return len(self.tsv_file)
125
+
126
+
127
+ class TSVMeta(NamedTuple):
128
+ source: str
129
+ num_classes: int
130
+ task: str
131
+
132
+
133
+ class TSVImageTextDatasetV2(data.Dataset):
134
+ """
135
+ This class is intended for encapsulating Image/Text pair data for contrastive learning described in
136
+ the following paper,
137
+ "Learning Transferable Visual Models From Natural Language Supervision" (a.k.a CLIP)
138
+ V2: support image text pairs and supervised classification data
139
+ """
140
+
141
+ def __init__(self,
142
+ image_tsv_file: Union[str, List[str]],
143
+ text_tsv_file: Union[str, List[str]],
144
+ transform: Callable = None,
145
+ tokenize: Callable = None,
146
+ context_length: int = 77,
147
+ num_captions: int = 1,
148
+ text_format: str = 'txt',
149
+ is_train: bool = True,
150
+ sas_token_path: str = None,
151
+ azcopy_path: str = None,
152
+ metas: List[NamedTuple] = None,
153
+ prompt_engineering=True,
154
+ concat_queries=False):
155
+ self.transform = transform
156
+ self.tokenize = tokenize
157
+ self._chunk_sizes = None
158
+ self.context_length = context_length
159
+ self.num_captions = num_captions
160
+ self.text_format = text_format
161
+ self.tsv_file_list = []
162
+ self.metas = metas
163
+ self.label_offsets = self.build_label_offsets()
164
+ self.prompt_engineering = prompt_engineering
165
+ self.concat_queries = concat_queries
166
+
167
+ if isinstance(image_tsv_file, str) and isinstance(text_tsv_file, str):
168
+ # single tsv file
169
+ if (
170
+ os.path.splitext(image_tsv_file)[1].lower() == '.tsv'
171
+ and os.path.splitext(text_tsv_file)[1].lower() == '.tsv'
172
+ ):
173
+ self.tsv_file_list.append((image_tsv_file, text_tsv_file))
174
+ self.image_tsv_file = TSVFile(
175
+ image_tsv_file, if_generate_lineidx=True
176
+ )
177
+ self.text_tsv_file = TSVFile(
178
+ text_tsv_file, if_generate_lineidx=True
179
+ )
180
+ else:
181
+ raise ValueError("Invalid input! Please check the tsv filenames.")
182
+ # multiple tsv files specified in a list
183
+ elif (
184
+ isinstance(image_tsv_file, list)
185
+ and isinstance(text_tsv_file, list)
186
+ ):
187
+ assert len(image_tsv_file) == len(text_tsv_file), \
188
+ "Inconsistent number of Image/Text tsv files!"
189
+ self.tsv_file_list = [
190
+ (txt, img)
191
+ for img, txt in zip(image_tsv_file, text_tsv_file)
192
+ ]
193
+ self.image_tsv_file = CompositeTSVFile(
194
+ image_tsv_file,
195
+ is_train=is_train,
196
+ sas_token_path=sas_token_path,
197
+ azcopy_path=azcopy_path
198
+ )
199
+ self.text_tsv_file = CompositeTSVFile(
200
+ text_tsv_file,
201
+ is_train=is_train,
202
+ sas_token_path=sas_token_path,
203
+ azcopy_path=azcopy_path
204
+ )
205
+ self._chunk_sizes = self.image_tsv_file.get_chunk_size()
206
+ else:
207
+ raise ValueError("Invalid input! Please check the tsv filenames.")
208
+
209
+ assert len(self.image_tsv_file) == len(self.text_tsv_file), \
210
+ "Inconsistent size of Image/Text ({}/{}) data!".format(
211
+ len(self.image_tsv_file), len(self.text_tsv_file)
212
+ )
213
+
214
+ def build_label_offsets(self):
215
+ if self.metas is None:
216
+ return None
217
+
218
+ label_offsets = {}
219
+ offset = 1
220
+ for meta in self.metas:
221
+ logger.debug(meta)
222
+ logger.debug(label_offsets)
223
+ label_offsets[meta.source] = offset
224
+ offset += meta.num_classes
225
+
226
+ return label_offsets
227
+
228
+ def fetch_blob(self, idx):
229
+ # image_tsv, text_tsv = self.tsv_file_list[idx]
230
+ image_tsv = self.image_tsv_file.file_list[idx]
231
+ text_tsv = self.text_tsv_file.file_list[idx]
232
+ self.image_tsv_file.blob_storage.fetch_blob(image_tsv)
233
+ self.text_tsv_file.blob_storage.fetch_blob(text_tsv)
234
+
235
+ def get_chunk_sizes(self):
236
+ return self._chunk_sizes
237
+
238
+ def __getitem__(self, index: Union[int, Tuple[int, int]]):
239
+ if index is None:
240
+ import torch
241
+ return torch.tensor([], dtype=torch.float32), \
242
+ torch.tensor([], dtype=torch.int64), \
243
+ torch.tensor([], dtype=torch.int64)
244
+
245
+ items_image = self.image_tsv_file[index]
246
+ items_text = self.text_tsv_file[index]
247
+
248
+ assert items_text[0] == items_image[0], \
249
+ 'keys do not match for image and text {} vs {}'.format(
250
+ items_text[0], items_image[0]
251
+ )
252
+
253
+ _, img = self._decode_image(items_image)
254
+ _, txt, label = self._decode_text(items_text)
255
+
256
+ if self.transform:
257
+ img = self.transform(img)
258
+
259
+ tokens = self.tokenize(
260
+ txt, padding='max_length', truncation=True, max_length=self.context_length,
261
+ return_tensors='pt'
262
+ ) if self.tokenize else txt
263
+
264
+ tokens['input_ids'].squeeze_()
265
+ tokens['attention_mask'].squeeze_()
266
+
267
+ return img, tokens, label
268
+
269
+ def _decode_image(self, items: Tuple[str, str]):
270
+ key = items[0]
271
+ image = Image.open(BytesIO(base64.b64decode(items[1]))).convert('RGB')
272
+
273
+ return key, image
274
+
275
+ def _decode_text(self, items: Tuple[str, Union[str, dict]]):
276
+ key = items[0]
277
+ text = ''
278
+
279
+ if self.text_format != 'json':
280
+ raise ValueError('Only support json format')
281
+
282
+ # Do some reasonable handling of occasionally bad data.
283
+ try:
284
+ js = json.loads(items[1])
285
+ except Exception as e:
286
+
287
+ # empty dictionary
288
+ js = {}
289
+
290
+ # Record the data error in the log.
291
+ logger.info("JSON parsing error on: " + items[1])
292
+ logger.info(str(e))
293
+
294
+ # do not raise the exception
295
+ # raise e
296
+
297
+ # put some text in and continue processing data (do not kill job)
298
+ sstr = items[1].find("\"")
299
+ if (sstr < 0):
300
+ sstr = 0
301
+
302
+ estr = items[1][sstr:].find("\"")
303
+ if (estr < 0):
304
+ estr = len(items[1])
305
+
306
+ text = items[1][sstr:estr]
307
+ if (len(text) < 2):
308
+ text = "A picture showing some content."
309
+
310
+ label = 0
311
+
312
+ if 'captions' in js:
313
+ captions = js['captions']
314
+ if isinstance(captions, list):
315
+ if self.num_captions == 1:
316
+ text = random.choice(captions)
317
+ else:
318
+ text = captions
319
+ if len(captions) > self.num_captions:
320
+ text = captions[:self.num_captions]
321
+ elif isinstance(captions, str):
322
+ text = captions
323
+ else:
324
+ raise ValueError('captions should be str or list')
325
+ label = 0
326
+ elif 'tags' in js:
327
+ text = prompt_engineering(js['tags'])
328
+ label = 0
329
+ elif 'task' in js and js['task'] == 'classification':
330
+ if (self.prompt_engineering):
331
+ text = prompt_engineering(js['class_name'])
332
+ else:
333
+ text = js['class_name']
334
+ label = js['class_id']
335
+
336
+ if (self.label_offsets is not None):
337
+ if (js['source'] in self.label_offsets):
338
+ label += self.label_offsets[js['source']]
339
+
340
+ if (self.concat_queries):
341
+ if ('queries' in js) and (len(js['queries']) > 0):
342
+ q = ''
343
+ for item in js['queries']:
344
+ q = q + item + ' '
345
+
346
+ text = q + ', ' + text
347
+
348
+ return key, text, label
349
+
350
+ def __len__(self):
351
+ return len(self.image_tsv_file)
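A minimal usage sketch for the dataset above (not part of the commit; the tokenizer, transform, and file names are illustrative assumptions, and the CLIP tokenizer stands in for whatever tokenizer the training config actually wires in):

# Sketch: wiring TSVImageTextDatasetV2 into a DataLoader.
from torch.utils.data import DataLoader
from torchvision import transforms
from transformers import CLIPTokenizer

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")  # illustrative choice

dataset = TSVImageTextDatasetV2(
    image_tsv_file="train.image.tsv",   # rows of (key, base64-encoded image)
    text_tsv_file="train.text.tsv",     # rows of (key, JSON with 'captions' / classification fields)
    transform=transform,
    tokenize=tokenizer,
    context_length=77,
    text_format="json",
)
loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)
img, tokens, label = next(iter(loader))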
MedImageInsight/ImageDataLoader/tsv_file.py ADDED
@@ -0,0 +1,290 @@
1
+ import logging
2
+ import gc
3
+ import os
4
+ import os.path as op
5
+ import json
6
+ from typing import List
7
+ from .blob_storage import BlobStorage, disk_usage
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
+ def generate_lineidx(filein: str, idxout: str) -> None:
13
+ idxout_tmp = idxout + '.tmp'
14
+ with open(filein, 'r') as tsvin, open(idxout_tmp, 'w') as tsvout:
15
+ fsize = os.fstat(tsvin.fileno()).st_size
16
+ fpos = 0
17
+ while fpos != fsize:
18
+ tsvout.write(str(fpos) + "\n")
19
+ tsvin.readline()
20
+ fpos = tsvin.tell()
21
+ os.rename(idxout_tmp, idxout)
22
+
23
+
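The .lineidx file written by generate_lineidx stores one byte offset per TSV row, so seek() can jump straight to a row without scanning the file. A small sketch of that relationship (file names are illustrative only):

# Sketch: how data.lineidx relates to data.tsv.
generate_lineidx("data.tsv", "data.lineidx")

with open("data.lineidx") as f:
    offsets = [int(line) for line in f]   # offsets[i] = byte position of row i

with open("data.tsv") as tsv:
    tsv.seek(offsets[2])                  # jump to the 3rd row directly
    row = tsv.readline().rstrip("\n").split("\t")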
24
+ def read_to_character(fp, c):
25
+ result = []
26
+ while True:
27
+ s = fp.read(32)
28
+ assert s != ''
29
+ if c in s:
30
+ result.append(s[: s.index(c)])
31
+ break
32
+ else:
33
+ result.append(s)
34
+ return ''.join(result)
35
+
36
+
37
+ class TSVFile(object):
38
+ def __init__(self,
39
+ tsv_file: str,
40
+ if_generate_lineidx: bool = True,
41
+ lineidx: str = None,
42
+ class_selector: List[str] = None,
43
+ blob_storage: BlobStorage = None):
44
+ self.tsv_file = tsv_file
45
+ self.lineidx = op.splitext(tsv_file)[0] + '.lineidx' \
46
+ if not lineidx else lineidx
47
+ self.linelist = op.splitext(tsv_file)[0] + '.linelist'
48
+ self.chunks = op.splitext(tsv_file)[0] + '.chunks'
49
+ self._fp = None
50
+ self._lineidx = None
51
+ self._sample_indices = None
52
+ self._class_boundaries = None
53
+ self._class_selector = class_selector
54
+ self._blob_storage = blob_storage
55
+ self._len = None
56
+ # remember the pid of the process that opened the file;
+ # if it differs from the current pid, we will re-open the file.
58
+ self.pid = None
59
+ # generate lineidx if not exist
60
+ if not op.isfile(self.lineidx) and if_generate_lineidx:
61
+ generate_lineidx(self.tsv_file, self.lineidx)
62
+
63
+ def __del__(self):
64
+ self.gcidx()
65
+ if self._fp:
66
+ self._fp.close()
67
+ # physically remove the tsv file if it is retrieved by BlobStorage
68
+ if self._blob_storage and 'azcopy' in self.tsv_file and os.path.exists(self.tsv_file):
69
+ try:
70
+ original_usage = disk_usage('/')
71
+ os.remove(self.tsv_file)
72
+ logger.info("Purged %s (disk usage: %.2f%% => %.2f%%)" %
73
+ (self.tsv_file, original_usage * 100, disk_usage('/') * 100))
74
+ except:
75
+ # Known issue: multiple threads attempting to delete the file will raise a FileNotFound error.
76
+ # TODO: try threading.Lock to better handle the race condition
77
+ pass
78
+
79
+ def __str__(self):
80
+ return "TSVFile(tsv_file='{}')".format(self.tsv_file)
81
+
82
+ def __repr__(self):
83
+ return str(self)
84
+
85
+ def gcidx(self):
86
+ logger.debug('Run gc collect')
87
+ self._lineidx = None
88
+ self._sample_indices = None
89
+ #self._class_boundaries = None
90
+ return gc.collect()
91
+
92
+ def get_class_boundaries(self):
93
+ return self._class_boundaries
94
+
95
+ def num_rows(self, gcf=False):
96
+ if (self._len is None):
97
+ self._ensure_lineidx_loaded()
98
+ retval = len(self._sample_indices)
99
+
100
+ if (gcf):
101
+ self.gcidx()
102
+
103
+ self._len = retval
104
+
105
+ return self._len
106
+
107
+ def seek(self, idx: int):
108
+ self._ensure_tsv_opened()
109
+ self._ensure_lineidx_loaded()
110
+ try:
111
+ pos = self._lineidx[self._sample_indices[idx]]
112
+ except:
113
+ logger.info('=> {}-{}'.format(self.tsv_file, idx))
114
+ raise
115
+ self._fp.seek(pos)
116
+ return [s.strip() for s in self._fp.readline().split('\t')]
117
+
118
+ def seek_first_column(self, idx: int):
119
+ self._ensure_tsv_opened()
120
+ self._ensure_lineidx_loaded()
121
+ pos = self._lineidx[idx]
122
+ self._fp.seek(pos)
123
+ return read_to_character(self._fp, '\t')
124
+
125
+ def get_key(self, idx: int):
126
+ return self.seek_first_column(idx)
127
+
128
+ def __getitem__(self, index: int):
129
+ return self.seek(index)
130
+
131
+ def __len__(self):
132
+ return self.num_rows()
133
+
134
+ def _ensure_lineidx_loaded(self):
135
+ if self._lineidx is None:
136
+ logger.debug('=> loading lineidx: {}'.format(self.lineidx))
137
+ with open(self.lineidx, 'r') as fp:
138
+ lines = fp.readlines()
139
+ lines = [line.strip() for line in lines]
140
+ self._lineidx = [int(line) for line in lines]
141
+
142
+ # read the line list if exists
143
+ linelist = None
144
+ if op.isfile(self.linelist):
145
+ with open(self.linelist, 'r') as fp:
146
+ linelist = sorted(
147
+ [
148
+ int(line.strip())
149
+ for line in fp.readlines()
150
+ ]
151
+ )
152
+
153
+ if op.isfile(self.chunks):
154
+ self._sample_indices = []
155
+ self._class_boundaries = []
156
+ class_boundaries = json.load(open(self.chunks, 'r'))
157
+ for class_name, boundary in class_boundaries.items():
158
+ start = len(self._sample_indices)
159
+ if class_name in self._class_selector:
160
+ for idx in range(boundary[0], boundary[1] + 1):
161
+ # NOTE: potentially slow when linelist is long, try to speed it up
162
+ if linelist and idx not in linelist:
163
+ continue
164
+ self._sample_indices.append(idx)
165
+ end = len(self._sample_indices)
166
+ self._class_boundaries.append((start, end))
167
+ else:
168
+ if linelist:
169
+ self._sample_indices = linelist
170
+ else:
171
+ self._sample_indices = list(range(len(self._lineidx)))
172
+
173
+ def _ensure_tsv_opened(self):
174
+ if self._fp is None:
175
+ if self._blob_storage:
176
+ self._fp = self._blob_storage.open(self.tsv_file)
177
+ else:
178
+ self._fp = open(self.tsv_file, 'r')
179
+ self.pid = os.getpid()
180
+
181
+ if self.pid != os.getpid():
182
+ logger.debug('=> re-open {} because the process id changed'.format(self.tsv_file))
183
+ self._fp = open(self.tsv_file, 'r')
184
+ self.pid = os.getpid()
185
+
186
+
187
+ class CompositeTSVFile:
188
+ def __init__(self,
189
+ file_list: List[str],
190
+ root: str = '.',
191
+ class_selector: List[str] = None,
192
+ is_train: bool = True,
193
+ sas_token_path: str = None,
194
+ azcopy_path: str = None):
195
+ self.root = root
196
+ self.tsvs = None
197
+ self.chunk_sizes = None
198
+ self.accum_chunk_sizes = None
199
+ self._class_selector = class_selector
200
+ self._class_boundaries = None
201
+ self.initialized = False
202
+ assert isinstance(file_list, list)
203
+ self.blob_storage = BlobStorage(is_train, sas_token_path, azcopy_path)
204
+ self.file_list = self.blob_storage.register_local_tsv_paths(file_list)
205
+ logger.info('=> Init CompositeTSVFile...')
206
+ self.initialize()
207
+ logger.info('=> Init CompositeTSVFile Done...')
208
+
209
+ def get_key(self, index: int):
210
+ idx_source, idx_row = self._calc_chunk_idx_row(index)
211
+ k = self.tsvs[idx_source].get_key(idx_row)
212
+ return '_'.join([self.file_list[idx_source], k])
213
+
214
+ def get_class_boundaries(self):
215
+ return self._class_boundaries
216
+
217
+ def get_chunk_size(self):
218
+ return self.chunk_sizes
219
+
220
+ def num_rows(self):
221
+ return sum(self.chunk_sizes)
222
+
223
+ def _calc_chunk_idx_row(self, index: int):
224
+ idx_chunk = 0
225
+ idx_row = index
226
+ while index >= self.accum_chunk_sizes[idx_chunk]:
227
+ idx_chunk += 1
228
+ idx_row = index - self.accum_chunk_sizes[idx_chunk-1]
229
+ return idx_chunk, idx_row
230
+
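The accumulated chunk sizes built in initialize() make the global-to-local index mapping a simple walk over shard boundaries; for example (chunk sizes are purely illustrative):

# Sketch with made-up shard sizes: three TSV shards holding 5, 3 and 4 rows.
chunk_sizes = [5, 3, 4]
accum = [5, 8, 12]   # what initialize() stores in accum_chunk_sizes

# Global index 6 falls in the second shard (accum[0] = 5 <= 6 < accum[1] = 8),
# so _calc_chunk_idx_row(6) returns (1, 6 - 5) == (1, 1).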
231
+ def __getitem__(self, index: int):
232
+ idx_source, idx_row = self._calc_chunk_idx_row(index)
233
+ if idx_source not in self.blob_storage:
234
+ self.blob_storage[idx_source] = TSVFile(
235
+ op.join(self.root, self.file_list[idx_source]),
236
+ class_selector=self._class_selector,
237
+ blob_storage=self.blob_storage,
238
+ if_generate_lineidx=True
239
+ )
240
+ return self.blob_storage[idx_source].seek(idx_row)
241
+
242
+ def __len__(self):
243
+ return sum(self.chunk_sizes)
244
+
245
+ def initialize(self):
246
+ """
247
+ This function has to be called in __init__ if cache_policy is
+ enabled. Thus, we always call it in __init__ to keep things simple.
249
+ """
250
+ if self.initialized:
251
+ return
252
+ self.tsvs = [
253
+ TSVFile(
254
+ op.join(self.root, f),
255
+ class_selector=self._class_selector
256
+ ) for f in self.file_list
257
+ ]
258
+ logger.debug("=> Calculating chunk sizes ...")
259
+ self.chunk_sizes = [tsv.num_rows(gcf=True) for tsv in self.tsvs]
260
+
261
+ self.accum_chunk_sizes = [0]
262
+ for size in self.chunk_sizes:
263
+ self.accum_chunk_sizes += [self.accum_chunk_sizes[-1] + size]
264
+ self.accum_chunk_sizes = self.accum_chunk_sizes[1:]
265
+
266
+ if (
267
+ self._class_selector
268
+ and all([tsv.get_class_boundaries() for tsv in self.tsvs])
269
+ ):
270
+ """
271
+ Note: When using CompositeTSVFile, make sure that the classes contained in each
272
+ tsv file do not overlap. Otherwise, the class boundaries won't be correct.
273
+ """
274
+ self._class_boundaries = []
275
+ offset = 0
276
+ for tsv in self.tsvs:
277
+ boundaries = tsv.get_class_boundaries()
278
+ for bound in boundaries:
279
+ self._class_boundaries.append((bound[0] + offset, bound[1] + offset))
280
+ offset += len(tsv)
281
+ self.initialized = True
282
+
283
+
284
+ def load_list_file(fname: str) -> List[str]:
285
+ with open(fname, 'r') as fp:
286
+ lines = fp.readlines()
287
+ result = [line.strip() for line in lines]
288
+ if len(result) > 0 and result[-1] == '':
289
+ result = result[:-1]
290
+ return result
MedImageInsight/ImageDataLoader/zipdata.py ADDED
@@ -0,0 +1,98 @@
1
+ import os.path as op
2
+ from zipfile import ZipFile, BadZipFile
3
+ import torch.utils.data as data
4
+ from PIL import Image
5
+ from io import BytesIO
6
+ import multiprocessing
7
+
8
+ _VALID_IMAGE_TYPES = ['.jpg', '.jpeg', '.tiff', '.bmp', '.png']
9
+
10
+
11
+ class ZipData(data.Dataset):
12
+ _IGNORE_ATTRS = {'_zip_file'}
13
+
14
+ def __init__(self, path, map_file,
15
+ transform=None, target_transform=None,
16
+ extensions=None):
17
+ self._path = path
18
+ if not extensions:
19
+ extensions = _VALID_IMAGE_TYPES
20
+ self._zip_file = ZipFile(path)
21
+ self.zip_dict = {}
22
+ self.samples = []
23
+ self.transform = transform
24
+ self.target_transform = target_transform
25
+ self.class_to_idx = {}
26
+ with open(map_file, 'r') as f:
27
+ for line in iter(f.readline, ""):
28
+ line = line.strip()
29
+ if not line:
30
+ continue
31
+ cls_idx = [l for l in line.split('\t') if l]
32
+ if not cls_idx:
33
+ continue
34
+ if (len(cls_idx) < 2):
35
+ cls_idx = [l for l in line.split(' ') if l]
36
+ if not cls_idx:
37
+ continue
38
+ assert len(cls_idx) >= 2, "invalid line: {}".format(line)
39
+ idx = int(cls_idx[1])
40
+ cls = cls_idx[0]
41
+ del cls_idx
42
+ at_idx = cls.find('@')
43
+ assert at_idx >= 0, "invalid class: {}".format(cls)
44
+ cls = cls[at_idx + 1:]
45
+ if cls.startswith('/'):
46
+ # Python ZipFile expects no root
47
+ cls = cls[1:]
48
+ assert cls, "invalid class in line {}".format(line)
49
+ prev_idx = self.class_to_idx.get(cls)
50
+ assert prev_idx is None or prev_idx == idx, "class: {} idx: {} previously had idx: {}".format(
51
+ cls, idx, prev_idx
52
+ )
53
+ self.class_to_idx[cls] = idx
54
+
55
+ for fst in self._zip_file.infolist():
56
+ fname = fst.filename
57
+ target = self.class_to_idx.get(fname)
58
+ if target is None:
59
+ continue
60
+ if fname.endswith('/') or fname.startswith('.') or fst.file_size == 0:
61
+ continue
62
+ ext = op.splitext(fname)[1].lower()
63
+ if ext in extensions:
64
+ self.samples.append((fname, target))
65
+ assert len(self), "No images found in: {} with map: {}".format(self._path, map_file)
66
+
67
+ def __repr__(self):
68
+ return 'ZipData({}, size={})'.format(self._path, len(self))
69
+
70
+ def __getstate__(self):
71
+ return {
72
+ key: val if key not in self._IGNORE_ATTRS else None
73
+ for key, val in self.__dict__.items()
74
+ }
75
+
76
+ def __getitem__(self, index):
77
+ proc = multiprocessing.current_process()
78
+ pid = proc.pid # get pid of this process.
79
+ if pid not in self.zip_dict:
80
+ self.zip_dict[pid] = ZipFile(self._path)
81
+ zip_file = self.zip_dict[pid]
82
+
83
+ if index >= len(self) or index < 0:
84
+ raise KeyError("{} is invalid".format(index))
85
+ path, target = self.samples[index]
86
+ try:
87
+ sample = Image.open(BytesIO(zip_file.read(path))).convert('RGB')
88
+ except BadZipFile:
89
+ print("bad zip file")
90
+ return None, None
91
+ if self.transform is not None:
92
+ sample = self.transform(sample)
93
+ if self.target_transform is not None:
94
+ target = self.target_transform(target)
95
+ return sample, target
96
+
97
+ def __len__(self):
98
+ return len(self.samples)
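A usage sketch for ZipData (the zip path, map file, and its contents below are illustrative assumptions; the map format is inferred from the parsing in __init__, where the zip member path is the part after '@' and the second column is the integer label):

# Sketch: each map line pairs a zip member with a label, e.g.
#   train@images/cat/001.jpg	0
#   train@images/dog/002.jpg	1
from torchvision import transforms

ds = ZipData(
    path="train_images.zip",
    map_file="train_map.txt",
    transform=transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()]),
)
img, label = ds[0]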
MedImageInsight/ImageEncoder/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ from __future__ import absolute_import
2
+ from __future__ import division
3
+ from __future__ import print_function
4
+
5
+ from .build import build_image_encoder
6
+
7
+ from .coswin import *
8
+ from .davit_v1 import *
MedImageInsight/ImageEncoder/build.py ADDED
@@ -0,0 +1,13 @@
1
+ from .registry import image_encoders
2
+ from .registry import is_image_encoder
3
+
4
+
5
+ def build_image_encoder(config_encoder, verbose, **kwargs):
6
+ model_name = config_encoder['NAME']
7
+ if model_name.startswith('cls_'):
8
+ model_name = model_name[4:]
9
+
10
+ if not is_image_encoder(model_name):
11
+ raise ValueError(f'Unknown model: {model_name}')
12
+
13
+ return image_encoders(model_name)(config_encoder, verbose, **kwargs)
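build_image_encoder resolves the encoder by name through a module-level registry. registry.py itself is not reproduced in this part of the diff, so the sketch below is an assumed minimal implementation of the pattern these helpers rely on, not the shipped code:

# Sketch of the assumed registry pattern behind register_image_encoder / image_encoders.
_image_encoders = {}

def register_image_encoder(fn):
    _image_encoders[fn.__name__] = fn   # e.g. the 'image_encoder' factory defined in coswin.py
    return fn

def image_encoders(name):
    return _image_encoders[name]

def is_image_encoder(name):
    return name in _image_encoders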
MedImageInsight/ImageEncoder/coswin.py ADDED
@@ -0,0 +1,779 @@
1
+ # --------------------------------------------------------
2
+ # CoSwin: Convolutional Swin Transformer
3
+ # Copyright (c) 2021 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Ze Liu
6
+ # Modified by Bin Xiao
7
+ # --------------------------------------------------------
8
+
9
+ import logging
10
+ import os
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.utils.checkpoint as checkpoint
14
+ import numpy as np
15
+ from einops import rearrange, repeat
16
+ from einops.layers.torch import Rearrange
17
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
18
+
19
+ from .registry import register_image_encoder
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
+
25
+ class Mlp(nn.Module):
26
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
27
+ super().__init__()
28
+ out_features = out_features or in_features
29
+ hidden_features = hidden_features or in_features
30
+ self.fc1 = nn.Linear(in_features, hidden_features)
31
+ self.act = act_layer()
32
+ self.fc2 = nn.Linear(hidden_features, out_features)
33
+ self.drop = nn.Dropout(drop)
34
+
35
+ def forward(self, x):
36
+ x = self.fc1(x)
37
+ x = self.act(x)
38
+ x = self.drop(x)
39
+ x = self.fc2(x)
40
+ x = self.drop(x)
41
+ return x
42
+
43
+
44
+ def window_partition(x, window_size):
45
+ """
46
+ Args:
47
+ x: (B, H, W, C)
48
+ window_size (int): window size
49
+
50
+ Returns:
51
+ windows: (num_windows*B, window_size, window_size, C)
52
+ """
53
+ B, H, W, C = x.shape
54
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
55
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
56
+ return windows
57
+
58
+
59
+ def window_reverse(windows, window_size, H, W):
60
+ """
61
+ Args:
62
+ windows: (num_windows*B, window_size, window_size, C)
63
+ window_size (int): Window size
64
+ H (int): Height of image
65
+ W (int): Width of image
66
+
67
+ Returns:
68
+ x: (B, H, W, C)
69
+ """
70
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
71
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
72
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
73
+ return x
74
+
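window_partition and window_reverse are exact inverses whenever H and W are divisible by the window size, which is what the attention mask and merge logic below assume. A shape-level sketch (sizes are illustrative):

# Sketch: round-trip through 7x7 windows.
import torch

x = torch.randn(2, 56, 56, 96)                 # (B, H, W, C)
windows = window_partition(x, window_size=7)   # (2 * 8 * 8, 7, 7, 96)
assert windows.shape == (128, 7, 7, 96)

y = window_reverse(windows, window_size=7, H=56, W=56)
assert torch.equal(x, y)                       # exact inverse when H, W divide evenly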
75
+
76
+ class WindowAttention(nn.Module):
77
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias.
78
+ It supports both of shifted and non-shifted window.
79
+
80
+ Args:
81
+ dim (int): Number of input channels.
82
+ window_size (tuple[int]): The height and width of the window.
83
+ num_heads (int): Number of attention heads.
84
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
85
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
86
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
87
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
88
+ """
89
+
90
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
91
+
92
+ super().__init__()
93
+ self.dim = dim
94
+ self.window_size = window_size # Wh, Ww
95
+ self.num_heads = num_heads
96
+ head_dim = dim // num_heads
97
+ self.scale = qk_scale or head_dim ** -0.5
98
+
99
+ # define a parameter table of relative position bias
100
+ self.relative_position_bias_table = nn.Parameter(
101
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
102
+
103
+ # get pair-wise relative position index for each token inside the window
104
+ coords_h = torch.arange(self.window_size[0])
105
+ coords_w = torch.arange(self.window_size[1])
106
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
107
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
108
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
109
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
110
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
111
+ relative_coords[:, :, 1] += self.window_size[1] - 1
112
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
113
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
114
+ self.register_buffer("relative_position_index", relative_position_index)
115
+
116
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
117
+ self.attn_drop = nn.Dropout(attn_drop)
118
+ self.proj = nn.Linear(dim, dim)
119
+ self.proj_drop = nn.Dropout(proj_drop)
120
+
121
+ trunc_normal_(self.relative_position_bias_table, std=.02)
122
+ self.softmax = nn.Softmax(dim=-1)
123
+
124
+ def forward(self, x, mask=None):
125
+ """
126
+ Args:
127
+ x: input features with shape of (num_windows*B, N, C)
128
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
129
+ """
130
+ B_, N, C = x.shape
131
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
132
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
133
+
134
+ q = q * self.scale
135
+ attn = (q @ k.transpose(-2, -1))
136
+
137
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
138
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
139
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
140
+ attn = attn + relative_position_bias.unsqueeze(0)
141
+
142
+ if mask is not None:
143
+ nW = mask.shape[0]
144
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
145
+ attn = attn.view(-1, self.num_heads, N, N)
146
+ attn = self.softmax(attn)
147
+ else:
148
+ attn = self.softmax(attn)
149
+
150
+ attn = self.attn_drop(attn)
151
+
152
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
153
+ x = self.proj(x)
154
+ x = self.proj_drop(x)
155
+ return x
156
+
157
+ def extra_repr(self) -> str:
158
+ return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
159
+
160
+ def flops(self, N):
161
+ # calculate flops for 1 window with token length of N
162
+ flops = 0
163
+ # qkv = self.qkv(x)
164
+ flops += N * self.dim * 3 * self.dim
165
+ # attn = (q @ k.transpose(-2, -1))
166
+ flops += self.num_heads * N * (self.dim // self.num_heads) * N
167
+ # x = (attn @ v)
168
+ flops += self.num_heads * N * N * (self.dim // self.num_heads)
169
+ # x = self.proj(x)
170
+ flops += N * self.dim * self.dim
171
+ return flops
172
+
173
+
174
+ class SwinTransformerBlock(nn.Module):
175
+ r""" Swin Transformer Block.
176
+
177
+ Args:
178
+ dim (int): Number of input channels.
179
+ input_resolution (tuple[int]): Input resolution.
180
+ num_heads (int): Number of attention heads.
181
+ window_size (int): Window size.
182
+ shift_size (int): Shift size for SW-MSA.
183
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
184
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
185
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
186
+ drop (float, optional): Dropout rate. Default: 0.0
187
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
188
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
189
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
190
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
191
+ """
192
+
193
+ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
194
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
195
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm, layer_scale=False):
196
+ super().__init__()
197
+ self.dim = dim
198
+ self.input_resolution = input_resolution
199
+ self.num_heads = num_heads
200
+ self.window_size = window_size
201
+ self.shift_size = shift_size
202
+ self.mlp_ratio = mlp_ratio
203
+ if min(self.input_resolution) <= self.window_size:
204
+ # if window size is larger than input resolution, we don't partition windows
205
+ self.shift_size = 0
206
+ self.window_size = min(self.input_resolution)
207
+ assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
208
+
209
+ self.norm1 = norm_layer(dim)
210
+ self.attn = WindowAttention(
211
+ dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
212
+ qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
213
+
214
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
215
+ self.norm2 = norm_layer(dim)
216
+ mlp_hidden_dim = int(dim * mlp_ratio)
217
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
218
+
219
+ if self.shift_size > 0:
220
+ # calculate attention mask for SW-MSA
221
+ H, W = self.input_resolution
222
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
223
+ h_slices = (slice(0, -self.window_size),
224
+ slice(-self.window_size, -self.shift_size),
225
+ slice(-self.shift_size, None))
226
+ w_slices = (slice(0, -self.window_size),
227
+ slice(-self.window_size, -self.shift_size),
228
+ slice(-self.shift_size, None))
229
+ cnt = 0
230
+ for h in h_slices:
231
+ for w in w_slices:
232
+ img_mask[:, h, w, :] = cnt
233
+ cnt += 1
234
+
235
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
236
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
237
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
238
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
239
+ else:
240
+ attn_mask = None
241
+
242
+ self.gamma = 1.0
243
+ if layer_scale:
244
+ logger.info('=> enable layer scale')
245
+ self.gamma = nn.Parameter(
246
+ 1e-4*torch.ones(dim), requires_grad=True
247
+ )
248
+
249
+ self.register_buffer("attn_mask", attn_mask)
250
+
251
+ def forward(self, x):
252
+ H, W = self.input_resolution
253
+ B, L, C = x.shape
254
+ assert L == H * W, "input feature has wrong size"
255
+
256
+ shortcut = x
257
+ x = self.norm1(x)
258
+ x = x.view(B, H, W, C)
259
+
260
+ # cyclic shift
261
+ if self.shift_size > 0:
262
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
263
+ else:
264
+ shifted_x = x
265
+
266
+ # partition windows
267
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
268
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
269
+
270
+ # W-MSA/SW-MSA
271
+ attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
272
+
273
+ # merge windows
274
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
275
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
276
+
277
+ # reverse cyclic shift
278
+ if self.shift_size > 0:
279
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
280
+ else:
281
+ x = shifted_x
282
+ x = x.view(B, H * W, C)
283
+
284
+ # FFN
285
+ x = shortcut + self.drop_path(self.gamma*x)
286
+ x = x + self.drop_path(self.gamma*self.mlp(self.norm2(x)))
287
+
288
+ return x
289
+
290
+ def extra_repr(self) -> str:
291
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
292
+ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
293
+
294
+ def flops(self):
295
+ flops = 0
296
+ H, W = self.input_resolution
297
+ # norm1
298
+ flops += self.dim * H * W
299
+ # W-MSA/SW-MSA
300
+ nW = H * W / self.window_size / self.window_size
301
+ flops += nW * self.attn.flops(self.window_size * self.window_size)
302
+ # mlp
303
+ flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
304
+ # norm2
305
+ flops += self.dim * H * W
306
+ return flops
307
+
308
+
309
+ class PatchMerging(nn.Module):
310
+ r""" Patch Merging Layer.
311
+
312
+ Args:
313
+ input_resolution (tuple[int]): Resolution of input feature.
314
+ dim (int): Number of input channels.
315
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
316
+ """
317
+
318
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
319
+ super().__init__()
320
+ self.input_resolution = input_resolution
321
+ self.dim = dim
322
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
323
+ self.norm = norm_layer(4 * dim)
324
+
325
+ def forward(self, x):
326
+ """
327
+ x: B, H*W, C
328
+ """
329
+ H, W = self.input_resolution
330
+ B, L, C = x.shape
331
+ assert L == H * W, "input feature has wrong size"
332
+ assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."
333
+
334
+ x = x.view(B, H, W, C)
335
+
336
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
337
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
338
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
339
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
340
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
341
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
342
+
343
+ x = self.norm(x)
344
+ x = self.reduction(x)
345
+
346
+ return x
347
+
348
+ def extra_repr(self) -> str:
349
+ return f"input_resolution={self.input_resolution}, dim={self.dim}"
350
+
351
+ def flops(self):
352
+ H, W = self.input_resolution
353
+ flops = H * W * self.dim
354
+ flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
355
+ return flops
356
+
357
+
358
+ class BasicLayer(nn.Module):
359
+ """ A basic Swin Transformer layer for one stage.
360
+
361
+ Args:
362
+ dim (int): Number of input channels.
363
+ input_resolution (tuple[int]): Input resolution.
364
+ depth (int): Number of blocks.
365
+ num_heads (int): Number of attention heads.
366
+ window_size (int): Local window size.
367
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
368
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
369
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
370
+ drop (float, optional): Dropout rate. Default: 0.0
371
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
372
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
373
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
374
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
375
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
376
+ """
377
+
378
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
379
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
380
+ drop_path=0., norm_layer=nn.LayerNorm, downsample=None,
381
+ use_checkpoint=False, layer_scale=False):
382
+
383
+ super().__init__()
384
+ self.dim = dim
385
+ self.input_resolution = input_resolution
386
+ self.depth = depth
387
+ self.use_checkpoint = use_checkpoint
388
+
389
+ # build blocks
390
+ self.blocks = nn.ModuleList([
391
+ SwinTransformerBlock(
392
+ dim=dim, input_resolution=input_resolution,
393
+ num_heads=num_heads, window_size=window_size,
394
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
395
+ mlp_ratio=mlp_ratio,
396
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
397
+ drop=drop, attn_drop=attn_drop,
398
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
399
+ norm_layer=norm_layer,
400
+ layer_scale=layer_scale
401
+ )
402
+ for i in range(depth)])
403
+
404
+ # patch merging layer
405
+ if downsample is not None:
406
+ # self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
407
+ self.downsample = downsample(
408
+ input_resolution=input_resolution, patch_size=3, in_chans=dim, embed_dim=dim*2,
409
+ stride=2, padding=1, norm_layer=norm_layer
410
+ )
411
+ else:
412
+ self.downsample = None
413
+
414
+ def forward(self, x):
415
+ for blk in self.blocks:
416
+ if self.use_checkpoint:
417
+ x = checkpoint.checkpoint(blk, x)
418
+ else:
419
+ x = blk(x)
420
+ if self.downsample is not None:
421
+ x = self.downsample(x)
422
+ return x
423
+
424
+ def extra_repr(self) -> str:
425
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
426
+
427
+ def flops(self):
428
+ flops = 0
429
+ for blk in self.blocks:
430
+ flops += blk.flops()
431
+ if self.downsample is not None:
432
+ flops += self.downsample.flops()
433
+ return flops
434
+
435
+
436
+ class PatchEmbed(nn.Module):
437
+ r""" Image to Patch Embedding
438
+
439
+ Args:
440
+ img_size (int): Image size. Default: 224.
441
+ patch_size (int): Patch token size. Default: 4.
442
+ in_chans (int): Number of input image channels. Default: 3.
443
+ embed_dim (int): Number of linear projection output channels. Default: 96.
444
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
445
+ """
446
+
447
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
448
+ super().__init__()
449
+ img_size = to_2tuple(img_size)
450
+ patch_size = to_2tuple(patch_size)
451
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
452
+ self.img_size = img_size
453
+ self.patch_size = patch_size
454
+ self.patches_resolution = patches_resolution
455
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
456
+
457
+ self.in_chans = in_chans
458
+ self.embed_dim = embed_dim
459
+
460
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
461
+ if norm_layer is not None:
462
+ self.norm = norm_layer(embed_dim)
463
+ else:
464
+ self.norm = None
465
+
466
+ def forward(self, x):
467
+ B, C, H, W = x.shape
468
+ # FIXME look at relaxing size constraints
469
+ assert H == self.img_size[0] and W == self.img_size[1], \
470
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
471
+ x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
472
+ if self.norm is not None:
473
+ x = self.norm(x)
474
+ return x
475
+
476
+ def flops(self):
477
+ Ho, Wo = self.patches_resolution
478
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
479
+ if self.norm is not None:
480
+ flops += Ho * Wo * self.embed_dim
481
+ return flops
482
+
483
+
484
+ class ConvEmbed(nn.Module):
485
+ """ Image to Patch Embedding
486
+ """
487
+
488
+ def __init__(
489
+ self,
490
+ input_resolution=(224,224),
491
+ patch_size=7,
492
+ in_chans=3,
493
+ embed_dim=64,
494
+ stride=4,
495
+ padding=2,
496
+ norm_layer=None
497
+ ):
498
+ super().__init__()
499
+ self.patch_size = patch_size
500
+ self.input_resolution = input_resolution
501
+
502
+ self.proj = nn.Conv2d(
503
+ in_chans, embed_dim,
504
+ kernel_size=patch_size,
505
+ stride=stride,
506
+ padding=padding
507
+ )
508
+ self.norm = norm_layer(embed_dim) if norm_layer else None
509
+
510
+ def forward(self, x):
511
+ if len(x.size()) == 3:
512
+ x = rearrange(
513
+ x, 'b (h w) c -> b c h w',
514
+ h=self.input_resolution[0],
515
+ w=self.input_resolution[1]
516
+ )
517
+
518
+ x = self.proj(x)
519
+
520
+ B, C, H, W = x.shape
521
+ x = rearrange(x, 'b c h w -> b (h w) c')
522
+ if self.norm:
523
+ x = self.norm(x)
524
+
525
+ return x
526
+
527
+
528
+ class SwinTransformer(nn.Module):
529
+ r""" Swin Transformer
530
+ A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
531
+ https://arxiv.org/pdf/2103.14030
532
+
533
+ Args:
534
+ img_size (int | tuple(int)): Input image size. Default 224
535
+ patch_size (int | tuple(int)): Patch size. Default: 4
536
+ in_chans (int): Number of input image channels. Default: 3
537
+ num_classes (int): Number of classes for classification head. Default: 1000
538
+ embed_dim (int): Patch embedding dimension. Default: 96
539
+ depths (tuple(int)): Depth of each Swin Transformer layer.
540
+ num_heads (tuple(int)): Number of attention heads in different layers.
541
+ window_size (int): Window size. Default: 7
542
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
543
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
544
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
545
+ drop_rate (float): Dropout rate. Default: 0
546
+ attn_drop_rate (float): Attention dropout rate. Default: 0
547
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
548
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
549
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
550
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
551
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
552
+ """
553
+
554
+ def __init__(self, img_size=224, patch_size=7, patch_padding=2, patch_stride=4, in_chans=3,
555
+ num_classes=1000, embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
556
+ window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
557
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
558
+ norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
559
+ use_checkpoint=False, layer_scale=False, **kwargs):
560
+ super().__init__()
561
+
562
+ self.num_classes = num_classes
563
+ self.num_layers = len(depths)
564
+ self.embed_dim = embed_dim
565
+ self.ape = ape
566
+ self.patch_norm = patch_norm
567
+ self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
568
+ self.mlp_ratio = mlp_ratio
569
+
570
+ # split image into non-overlapping patches
571
+ # self.patch_embed = PatchEmbed(
572
+ # img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
573
+ # norm_layer=norm_layer if self.patch_norm else None)
574
+
575
+ self.patch_embed = ConvEmbed(
576
+ patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, padding=patch_padding,
577
+ norm_layer=norm_layer if self.patch_norm else None
578
+ )
579
+
580
+ img_size = to_2tuple(img_size)
581
+ patches_resolution = (
582
+ int(np.floor(float(img_size[0]+2*patch_padding-patch_size)/patch_stride+1)),
583
+ int(np.floor(float(img_size[1]+2*patch_padding-patch_size)/patch_stride+1))
584
+ )
585
+ num_patches = patches_resolution[0] * patches_resolution[1]
586
+ # num_patches = self.patch_embed.num_patches
587
+ # patches_resolution = self.patch_embed.patches_resolution
588
+ self.patches_resolution = patches_resolution
589
+
590
+ # absolute position embedding
591
+ if self.ape:
592
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
593
+ trunc_normal_(self.absolute_pos_embed, std=.02)
594
+
595
+ self.pos_drop = nn.Dropout(p=drop_rate)
596
+
597
+ # stochastic depth
598
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
599
+
600
+ # build layers
601
+ self.layers = nn.ModuleList()
602
+ for i_layer in range(self.num_layers):
603
+ layer = BasicLayer(
604
+ dim=int(embed_dim * 2 ** i_layer),
605
+ input_resolution=(
606
+ patches_resolution[0] // (2 ** i_layer),
607
+ patches_resolution[1] // (2 ** i_layer)
608
+ ),
609
+ depth=depths[i_layer],
610
+ num_heads=num_heads[i_layer],
611
+ window_size=window_size,
612
+ mlp_ratio=self.mlp_ratio,
613
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
614
+ drop=drop_rate, attn_drop=attn_drop_rate,
615
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
616
+ norm_layer=norm_layer,
617
+ # downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
618
+ downsample=ConvEmbed if (i_layer < self.num_layers - 1) else None,
619
+ use_checkpoint=use_checkpoint,
620
+ layer_scale=layer_scale
621
+ )
622
+ self.layers.append(layer)
623
+
624
+ self.norm = norm_layer(self.num_features)
625
+ self.avgpool = nn.AdaptiveAvgPool1d(1)
626
+ self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
627
+
628
+ self.apply(self._init_weights)
629
+
630
+ @property
631
+ def dim_out(self):
632
+ return self.num_features
633
+
634
+ def _init_weights(self, m):
635
+ if isinstance(m, nn.Linear):
636
+ trunc_normal_(m.weight, std=.02)
637
+ if isinstance(m, nn.Linear) and m.bias is not None:
638
+ nn.init.constant_(m.bias, 0)
639
+ elif isinstance(m, nn.LayerNorm):
640
+ nn.init.constant_(m.bias, 0)
641
+ nn.init.constant_(m.weight, 1.0)
642
+
643
+ def from_pretrained(self, pretrained='', pretrained_layers=[], verbose=True):
644
+ if os.path.isfile(pretrained):
645
+ logging.info(f'=> loading pretrained model {pretrained}')
646
+ pretrained_dict = torch.load(pretrained, map_location='cpu')
647
+
648
+ self.from_state_dict(pretrained_dict, pretrained_layers, verbose)
649
+
650
+ def from_state_dict(self, pretrained_dict, pretrained_layers=[], verbose=True):
651
+ model_dict = self.state_dict()
652
+ stripped_key = lambda x: x[14:] if x.startswith('image_encoder.') else x
653
+
654
+ pretrained_dict = {
655
+ stripped_key(k): v for k, v in pretrained_dict.items()
656
+ if stripped_key(k) in model_dict.keys()
657
+ }
658
+ need_init_state_dict = {}
659
+ for k, v in pretrained_dict.items():
660
+ need_init = (
661
+ (
662
+ k.split('.')[0] in pretrained_layers
663
+ or pretrained_layers[0] == '*'
664
+ )
665
+ and 'relative_position_index' not in k
666
+ and 'attn_mask' not in k
667
+ )
668
+
669
+ if need_init:
670
+ if verbose:
671
+ logger.info(f'=> init {k} from pretrained state dict')
672
+
673
+ if 'relative_position_bias_table' in k and v.size() != model_dict[k].size():
674
+ relative_position_bias_table_pretrained = v
675
+ relative_position_bias_table_current = model_dict[k]
676
+ L1, nH1 = relative_position_bias_table_pretrained.size()
677
+ L2, nH2 = relative_position_bias_table_current.size()
678
+ if nH1 != nH2:
679
+ logger.info(f"Error in loading {k}, passing")
680
+ else:
681
+ if L1 != L2:
682
+ logger.info(
683
+ '=> load_pretrained: resized variant: {} to {}'
684
+ .format((L1, nH1), (L2, nH2))
685
+ )
686
+ S1 = int(L1 ** 0.5)
687
+ S2 = int(L2 ** 0.5)
688
+ relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate(
689
+ relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1),
690
+ size=(S2, S2),
691
+ mode='bicubic')
692
+ v = relative_position_bias_table_pretrained_resized.view(nH2, L2).permute(1, 0)
693
+
694
+ if 'absolute_pos_embed' in k and v.size() != model_dict[k].size():
695
+ absolute_pos_embed_pretrained = v
696
+ absolute_pos_embed_current = model_dict[k]
697
+ _, L1, C1 = absolute_pos_embed_pretrained.size()
698
+ _, L2, C2 = absolute_pos_embed_current.size()
699
+ if C1 != C2:
700
+ logger.info(f"Error in loading {k}, passing")
701
+ else:
702
+ if L1 != L2:
703
+ logger.info(
704
+ '=> load_pretrained: resized variant: {} to {}'
705
+ .format((1, L1, C1), (1, L2, C2))
706
+ )
707
+ S1 = int(L1 ** 0.5)
708
+ S2 = int(L2 ** 0.5)
709
+ absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.reshape(-1, S1, S1, C1)
710
+ absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.permute(0, 3, 1, 2)
711
+ absolute_pos_embed_pretrained_resized = torch.nn.functional.interpolate(
712
+ absolute_pos_embed_pretrained, size=(S2, S2), mode='bicubic')
713
+ v = absolute_pos_embed_pretrained_resized.permute(0, 2, 3, 1).flatten(1, 2)
714
+
715
+ need_init_state_dict[k] = v
716
+ self.load_state_dict(need_init_state_dict, strict=False)
717
+
718
+ @torch.jit.ignore
719
+ def no_weight_decay(self):
720
+ return {'absolute_pos_embed'}
721
+
722
+ @torch.jit.ignore
723
+ def no_weight_decay_keywords(self):
724
+ return {'relative_position_bias_table'}
725
+
726
+ def forward_features(self, x):
727
+ x = self.patch_embed(x)
728
+ if self.ape:
729
+ x = x + self.absolute_pos_embed
730
+ x = self.pos_drop(x)
731
+
732
+ for layer in self.layers:
733
+ x = layer(x)
734
+
735
+ x = self.norm(x) # B L C
736
+ x = self.avgpool(x.transpose(1, 2)) # B C 1
737
+ x = torch.flatten(x, 1)
738
+ return x
739
+
740
+ def forward(self, x):
741
+ x = self.forward_features(x)
742
+ x = self.head(x)
743
+ return x
744
+
745
+
746
+ @register_image_encoder
747
+ def image_encoder(config_encoder, verbose, **kwargs):
748
+ spec = config_encoder['SPEC']
749
+
750
+ coswin = SwinTransformer(
751
+ img_size=config_encoder['IMAGE_SIZE'],
752
+ patch_size=spec['PATCH_SIZE'],
753
+ patch_padding=spec['PATCH_PADDING'],
754
+ patch_stride=spec['PATCH_STRIDE'],
755
+ in_chans=spec['IN_CHANS'],
756
+ num_classes=0,
757
+ embed_dim=spec['EMBED_DIM'],
758
+ depths=spec['DEPTHS'],
759
+ num_heads=spec['NUM_HEADS'],
760
+ window_size=spec['WINDOW_SIZE'],
761
+ mlp_ratio=spec['MLP_RATIO'],
762
+ qkv_bias=spec['QKV_BIAS'],
763
+ qk_scale=spec.get('QK_SCALE', None),
764
+ drop_rate=spec['DROP_RATE'],
765
+ drop_path_rate=spec['DROP_PATH_RATE'],
766
+ ape=spec['APE'],
767
+ patch_norm=spec['PATCH_NORM'],
768
+ layer_scale=spec.get('LAYER_SCALE', False),
769
+ use_checkpoint=spec.get('ENABLE_CHECKPOINT', False)
770
+ )
771
+
772
+ if config_encoder['LOAD_PRETRAINED']:
773
+ coswin.from_pretrained(
774
+ config_encoder['PRETRAINED'],
775
+ config_encoder['PRETRAINED_LAYERS'],
776
+ verbose
777
+ )
778
+
779
+ return coswin
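The registered factory above reads its hyperparameters from config_encoder['SPEC']. A sketch of a config dict it would accept (the numeric values are illustrative, Swin-T-like defaults, not the shipped MedImageInsight configuration in config.yaml):

# Sketch: building the CoSwin encoder from a plain config dict.
config_encoder = {
    'NAME': 'image_encoder',
    'IMAGE_SIZE': 224,
    'LOAD_PRETRAINED': False,
    'PRETRAINED': '',
    'PRETRAINED_LAYERS': ['*'],
    'SPEC': {
        'PATCH_SIZE': 7, 'PATCH_PADDING': 2, 'PATCH_STRIDE': 4, 'IN_CHANS': 3,
        'EMBED_DIM': 96, 'DEPTHS': [2, 2, 6, 2], 'NUM_HEADS': [3, 6, 12, 24],
        'WINDOW_SIZE': 7, 'MLP_RATIO': 4.0, 'QKV_BIAS': True,
        'DROP_RATE': 0.0, 'DROP_PATH_RATE': 0.1, 'APE': False, 'PATCH_NORM': True,
    },
}
encoder = image_encoder(config_encoder, verbose=True)   # or via build_image_encoder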
MedImageInsight/ImageEncoder/davit_v1.py ADDED
@@ -0,0 +1,727 @@
1
+ import logging
2
+ import os
3
+ import copy
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import torch.utils.checkpoint as checkpoint
9
+ from collections import OrderedDict
10
+
11
+ from einops import rearrange
12
+ from timm.models.layers import DropPath, trunc_normal_
13
+
14
+ # helper methods
15
+ from .registry import register_image_encoder
16
+
17
+ import mup.init
18
+ from mup import MuReadout, set_base_shapes
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
+ class MySequential(nn.Sequential):
24
+ def forward(self, *inputs):
25
+ for module in self._modules.values():
26
+ if type(inputs) == tuple:
27
+ inputs = module(*inputs)
28
+ else:
29
+ inputs = module(inputs)
30
+ return inputs
31
+
32
+
33
+ class PreNorm(nn.Module):
34
+ def __init__(self, norm, fn, drop_path=None):
35
+ super().__init__()
36
+ self.norm = norm
37
+ self.fn = fn
38
+ self.drop_path = drop_path
39
+
40
+ def forward(self, x, *args, **kwargs):
41
+ shortcut = x
42
+ if self.norm is not None:
43
+ x, size = self.fn(self.norm(x), *args, **kwargs)
44
+ else:
45
+ x, size = self.fn(x, *args, **kwargs)
46
+
47
+ if self.drop_path:
48
+ x = self.drop_path(x)
49
+
50
+ x = shortcut + x
51
+
52
+ return x, size
53
+
54
+
55
+ class Mlp(nn.Module):
56
+ """ MLP as used in Vision Transformer, MLP-Mixer and related networks
57
+ """
58
+
59
+ def __init__(
60
+ self,
61
+ in_features,
62
+ hidden_features=None,
63
+ out_features=None,
64
+ act_layer=nn.GELU,
65
+ ):
66
+ super().__init__()
67
+ out_features = out_features or in_features
68
+ hidden_features = hidden_features or in_features
69
+ self.net = nn.Sequential(OrderedDict([
70
+ ("fc1", nn.Linear(in_features, hidden_features)),
71
+ ("act", act_layer()),
72
+ ("fc2", nn.Linear(hidden_features, out_features))
73
+ ]))
74
+
75
+ def forward(self, x, size):
76
+ return self.net(x), size
77
+
78
+
79
+ class DepthWiseConv2d(nn.Module):
80
+ def __init__(
81
+ self,
82
+ dim_in,
83
+ kernel_size,
84
+ padding,
85
+ stride,
86
+ bias=True,
87
+ ):
88
+ super().__init__()
89
+ self.dw = nn.Conv2d(
90
+ dim_in, dim_in,
91
+ kernel_size=kernel_size,
92
+ padding=padding,
93
+ groups=dim_in,
94
+ stride=stride,
95
+ bias=bias
96
+ )
97
+
98
+ def forward(self, x, size):
99
+ B, N, C = x.shape
100
+ H, W = size
101
+ assert N == H * W
102
+
103
+ x = self.dw(x.transpose(1, 2).view(B, C, H, W))
104
+ size = (x.size(-2), x.size(-1))
105
+ x = x.flatten(2).transpose(1, 2)
106
+ return x, size
107
+
108
+
109
+ class ConvEmbed(nn.Module):
110
+ """ Image to Patch Embedding
111
+ """
112
+
113
+ def __init__(
114
+ self,
115
+ patch_size=7,
116
+ in_chans=3,
117
+ embed_dim=64,
118
+ stride=4,
119
+ padding=2,
120
+ norm_layer=None,
121
+ pre_norm=True
122
+ ):
123
+ super().__init__()
124
+ self.patch_size = patch_size
125
+
126
+ self.proj = nn.Conv2d(
127
+ in_chans, embed_dim,
128
+ kernel_size=patch_size,
129
+ stride=stride,
130
+ padding=padding
131
+ )
132
+
133
+ dim_norm = in_chans if pre_norm else embed_dim
134
+ self.norm = norm_layer(dim_norm) if norm_layer else None
135
+
136
+ self.pre_norm = pre_norm
137
+
138
+ def forward(self, x, size):
139
+ H, W = size
140
+ if len(x.size()) == 3:
141
+ if self.norm and self.pre_norm:
142
+ x = self.norm(x)
143
+ x = rearrange(
144
+ x, 'b (h w) c -> b c h w',
145
+ h=H, w=W
146
+ )
147
+
148
+ x = self.proj(x)
149
+
150
+ _, _, H, W = x.shape
151
+ x = rearrange(x, 'b c h w -> b (h w) c')
152
+ if self.norm and not self.pre_norm:
153
+ x = self.norm(x)
154
+
155
+ return x, (H, W)
156
+
157
+
158
+ class ChannelAttention(nn.Module):
159
+
160
+ def __init__(self, dim, base_dim, groups=8, base_groups=8, qkv_bias=True, dynamic_scale=True, standparam=True):
161
+ super().__init__()
162
+
163
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
164
+ self.proj = nn.Linear(dim, dim)
165
+ self.dynamic_scale = dynamic_scale
166
+
167
+ self.dim = dim
168
+ self.groups = groups
169
+ self.group_dim = dim // groups
170
+
171
+ self.base_dim = base_dim
172
+ self.base_groups = base_groups
173
+ self.base_group_dim = base_dim // base_groups
174
+
175
+ self.group_wm = self.group_dim / self.base_group_dim # Width multiplier for each group.
176
+ self.standparam = standparam
177
+
178
+ def forward(self, x, size):
179
+ B, N, C = x.shape
180
+ assert C == self.dim
181
+
182
+ qkv = self.qkv(x).reshape(B, N, 3, self.groups, C // self.groups).permute(2, 0, 3, 1, 4)
183
+ q, k, v = qkv[0], qkv[1], qkv[2] # Shape: [B, groups, N, group_dim].
184
+
185
+ scale = N ** -0.5 if self.dynamic_scale else self.dim ** -0.5
186
+
187
+ # Change the scaling factor.
188
+ # Ref: examples/Transformer/model.py in muP.
189
+ # Note: We consider backward compatibility and follow https://github.com/microsoft/mup/issues/18.
190
+ if self.standparam:
191
+ scale = N ** -0.5 if self.dynamic_scale else self.dim ** -0.5
192
+ else:
193
+ assert self.dynamic_scale # Currently only support dynamic scale.
194
+ scale = N ** -0.5
195
+
196
+ q = q * scale
197
+ attention = q.transpose(-1, -2) @ k
198
+ attention = attention.softmax(dim=-1)
199
+
200
+ if not self.standparam:
201
+ # Follow https://github.com/microsoft/mup/issues/18.
202
+ attention = attention / self.group_wm
203
+
204
+ x = (attention @ v.transpose(-1, -2)).transpose(-1, -2)
205
+ x = x.transpose(1, 2).reshape(B, N, C)
206
+ x = self.proj(x)
207
+ return x, size
208
+
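When standparam is False, ChannelAttention additionally divides the post-softmax attention map by group_wm, the ratio between the current and the base group dimension, following the muP issue referenced above. A small numeric sketch of that correction (all values illustrative):

# Sketch: a model widened 2x relative to its muP base shape.
dim, groups = 512, 8              # current model:  group_dim = 64
base_dim, base_groups = 256, 8    # base model:     base_group_dim = 32
group_wm = (dim // groups) / (base_dim // base_groups)   # = 2.0
# With standparam=False the attention map is divided by this 2.0,
# keeping activations comparable across widths as prescribed by muP.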
209
+
210
+ class ChannelBlock(nn.Module):
211
+
212
+ def __init__(self, dim, base_dim, groups, base_groups, mlp_ratio=4., qkv_bias=True,
213
+ drop_path_rate=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,
214
+ conv_at_attn=True, conv_at_ffn=True, dynamic_scale=True, standparam=True):
215
+ super().__init__()
216
+
217
+ drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
218
+
219
+ self.conv1 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_attn else None
220
+ self.channel_attn = PreNorm(
221
+ norm_layer(dim),
222
+ ChannelAttention(dim, base_dim, groups=groups, base_groups=base_groups, qkv_bias=qkv_bias,
223
+ dynamic_scale=dynamic_scale, standparam=standparam),
224
+ drop_path
225
+ )
226
+ self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_ffn else None
227
+ self.ffn = PreNorm(
228
+ norm_layer(dim),
229
+ Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer),
230
+ drop_path
231
+ )
232
+
233
+ def forward(self, x, size):
234
+ if self.conv1:
235
+ x, size = self.conv1(x, size)
236
+ x, size = self.channel_attn(x, size)
237
+
238
+ if self.conv2:
239
+ x, size = self.conv2(x, size)
240
+ x, size = self.ffn(x, size)
241
+
242
+ return x, size
243
+
244
+
245
+ def window_partition(x, window_size: int):
246
+ B, H, W, C = x.shape
247
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
248
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
249
+ return windows
250
+
251
+
252
+ def window_reverse(windows, window_size: int, H: int, W: int):
253
+ B = windows.shape[0] // (H * W // window_size // window_size)
254
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
255
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
256
+ return x
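+ # Round-trip sketch for the two helpers above (illustrative only, assuming H and W are already
+ # multiples of window_size; the padded case is handled inside WindowAttention):
+ #   x = torch.randn(2, 14, 14, 96)                   # (B, H, W, C)
+ #   windows = window_partition(x, 7)                 # -> (2 * 2 * 2, 7, 7, 96)
+ #   assert torch.equal(window_reverse(windows, 7, 14, 14), x)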
257
+
258
+
259
+ class WindowAttention(nn.Module):
260
+
261
+ def __init__(self, dim, base_dim, num_heads, base_num_heads, window_size, qkv_bias=True, standparam=True):
262
+
263
+ super().__init__()
264
+
265
+ self.window_size = window_size
266
+
267
+ self.dim = dim
268
+ self.num_heads = num_heads
269
+ head_dim = dim // num_heads
270
+
271
+ self.base_dim = base_dim
272
+ self.base_num_heads = base_num_heads
273
+ base_head_dim = base_dim // base_num_heads
274
+
275
+ # Change the scaling factor.
276
+ # Ref: examples/Transformer/model.py in muP.
277
+ # Note: We consider backward compatibility and follow https://github.com/microsoft/mup/issues/17.
278
+ if standparam:
279
+ scale = float(head_dim) ** -0.5
280
+ else:
281
+ # TODO: Here we ensure backward compatibility, which may not be optimal.
282
+ # We may add an argument called backward_comp. If it is set as False, we use
283
+ # float(head_dim) ** -1 * math.sqrt(attn_mult)
284
+ # as in the Transformer example in muP.
285
+ base_scale = float(base_head_dim) ** -0.5 # The same as scaling in standard parametrization.
286
+ head_wm = head_dim / base_head_dim # Width multiplier for each head.
287
+ scale = base_scale / head_wm
288
+ # scale_1 = (float(base_head_dim) ** 0.5) * (float(head_dim) ** -1) # Equivalent implementation as shown in the muP paper.
289
+ # assert np.isclose(scale, scale_1)
290
+ self.scale = scale
291
+
292
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
293
+ self.proj = nn.Linear(dim, dim)
294
+
295
+ self.softmax = nn.Softmax(dim=-1)
296
+
297
+ def forward(self, x, size):
298
+
299
+ H, W = size
300
+ B, L, C = x.shape
301
+ assert L == H * W, "input feature has wrong size"
302
+
303
+ x = x.view(B, H, W, C)
304
+
305
+ pad_l = pad_t = 0
306
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
307
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
308
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
309
+ _, Hp, Wp, _ = x.shape
310
+
311
+ x = window_partition(x, self.window_size)
312
+ x = x.view(-1, self.window_size * self.window_size, C)
313
+
314
+ B_, N, C = x.shape
315
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
316
+ q, k, v = qkv[0], qkv[1], qkv[2]
317
+
318
+ q = q * self.scale
319
+ attn = (q @ k.transpose(-2, -1))
320
+ attn = self.softmax(attn)
321
+
322
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
323
+ x = self.proj(x)
324
+
325
+ # merge windows
326
+ x = x.view(
327
+ -1, self.window_size, self.window_size, C
328
+ )
329
+ x = window_reverse(x, self.window_size, Hp, Wp)
330
+
331
+ if pad_r > 0 or pad_b > 0:
332
+ x = x[:, :H, :W, :].contiguous()
333
+
334
+ x = x.view(B, H * W, C)
335
+
336
+ return x, size
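+ # Worked example for the muP attention scale above (a sketch with made-up widths, not part of the module):
+ # with base_dim=96, base_num_heads=3 (base_head_dim=32) and dim=192, num_heads=3 (head_dim=64),
+ #   base_scale = 32 ** -0.5, head_wm = 64 / 32 = 2, so scale = 32 ** -0.5 / 2 = (32 ** 0.5) / 64,
+ # i.e. attention logits are scaled like 1/head_dim instead of 1/sqrt(head_dim) as the width grows.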
337
+
338
+
339
+ class SpatialBlock(nn.Module):
340
+
341
+ def __init__(self, dim, base_dim, num_heads, base_num_heads, window_size,
342
+ mlp_ratio=4., qkv_bias=True, drop_path_rate=0., act_layer=nn.GELU,
343
+ norm_layer=nn.LayerNorm, conv_at_attn=True, conv_at_ffn=True, standparam=True):
344
+ super().__init__()
345
+
346
+ drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
347
+
348
+ self.conv1 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_attn else None
349
+ self.window_attn = PreNorm(
350
+ norm_layer(dim),
351
+ WindowAttention(dim, base_dim, num_heads, base_num_heads, window_size, qkv_bias=qkv_bias,
352
+ standparam=standparam),
353
+ drop_path
354
+ )
355
+ self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_ffn else None
356
+ self.ffn = PreNorm(
357
+ norm_layer(dim),
358
+ Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer),
359
+ drop_path
360
+ )
361
+
362
+ def forward(self, x, size):
363
+ if self.conv1:
364
+ x, size = self.conv1(x, size)
365
+ x, size = self.window_attn(x, size)
366
+
367
+ if self.conv2:
368
+ x, size = self.conv2(x, size)
369
+ x, size = self.ffn(x, size)
370
+ return x, size
371
+
372
+
373
+ class DaViT(nn.Module):
374
+ """ DaViT: Dual-Attention Transformer
375
+
376
+ Args:
377
+ img_size (int | tuple(int)): Input image size. Default: 224
378
379
+ in_chans (int): Number of input image channels. Default: 3
380
+ num_classes (int): Number of classes for classification head. Default: 1000
381
+ depths (tuple(int)): Number of spatial and channel blocks in different stages. Default: (1, 1, 3, 1)
382
+ patch_size (tuple(int)): Patch sizes in different stages. Default: (7, 2, 2, 2)
383
+ patch_stride (tuple(int)): Patch strides in different stages. Default: (4, 2, 2, 2)
384
+ patch_padding (tuple(int)): Patch padding sizes in different stages. Default: (3, 0, 0, 0)
385
+ patch_prenorm (tuple(bool)): Use pre-normalization or not in different stages. Default: (False, False, False, False)
386
+ embed_dims (tuple(int)): Patch embedding dimension. Default: (64, 128, 192, 256)
387
+ base_embed_dims (tuple(int)): Patch embedding dimension (base case for muP). Default: (64, 128, 192, 256)
388
+ num_heads (tuple(int)): Number of attention heads in different layers. Default: (3, 6, 12, 24)
389
+ base_num_heads (tuple(int)): Number of attention heads in different layers (base case for muP). Default: (3, 6, 12, 24)
390
+ num_groups (tuple(int)): Number of groups in channel attention in different layers. Default: (3, 6, 12, 24)
391
+ base_num_groups (tuple(int)): Number of groups in channel attention in different layers (base case for muP). Default: (3, 6, 12, 24)
392
+ window_size (int): Window size. Default: 7
393
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
394
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
395
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
396
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
397
+ enable_checkpoint (bool): If True, enabling checkpoint. Default: False
398
+ conv_at_attn (bool): If True, add convolution layer before attention. Default: True
399
+ conv_at_ffn (bool): If True, add convolution layer before ffn. Default: True
400
+ dynamic_scale (bool): If True, scale of channel attention is respect to the number of tokens. Default: True
401
+ standparam (bool): Use standard parametrization or mu-parametrization. Default: True (i.e., use standard parametrization)
402
+ """
403
+
404
+ def __init__(
405
+ self,
406
+ img_size=224,
407
+ in_chans=3,
408
+ num_classes=1000,
409
+ depths=(1, 1, 3, 1),
410
+ patch_size=(7, 2, 2, 2),
411
+ patch_stride=(4, 2, 2, 2),
412
+ patch_padding=(3, 0, 0, 0),
413
+ patch_prenorm=(False, False, False, False),
414
+ embed_dims=(64, 128, 192, 256),
415
+ base_embed_dims=(64, 128, 192, 256),
416
+ num_heads=(3, 6, 12, 24),
417
+ base_num_heads=(3, 6, 12, 24),
418
+ num_groups=(3, 6, 12, 24),
419
+ base_num_groups=(3, 6, 12, 24),
420
+ window_size=7,
421
+ mlp_ratio=4.,
422
+ qkv_bias=True,
423
+ drop_path_rate=0.1,
424
+ norm_layer=nn.LayerNorm,
425
+ enable_checkpoint=False,
426
+ conv_at_attn=True,
427
+ conv_at_ffn=True,
428
+ dynamic_scale=True,
429
+ standparam=True
430
+ ):
431
+ super().__init__()
432
+
433
+ self.num_classes = num_classes
434
+ self.embed_dims = embed_dims
435
+ self.num_heads = num_heads
436
+ self.num_groups = num_groups
437
+ self.num_stages = len(self.embed_dims)
438
+ self.enable_checkpoint = enable_checkpoint
439
+ assert self.num_stages == len(self.num_heads) == len(self.num_groups)
440
+
441
+ num_stages = len(embed_dims)
442
+ self.img_size = img_size
443
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths) * 2)]
444
+
445
+ depth_offset = 0
446
+ convs = []
447
+ blocks = []
448
+ for i in range(num_stages):
449
+ conv_embed = ConvEmbed(
450
+ patch_size=patch_size[i],
451
+ stride=patch_stride[i],
452
+ padding=patch_padding[i],
453
+ in_chans=in_chans if i == 0 else self.embed_dims[i - 1],
454
+ embed_dim=self.embed_dims[i],
455
+ norm_layer=norm_layer,
456
+ pre_norm=patch_prenorm[i]
457
+ )
458
+ convs.append(conv_embed)
459
+
460
+ logger.info(f'=> Depth offset in stage {i}: {depth_offset}')
461
+ block = MySequential(
462
+ *[
463
+ MySequential(OrderedDict([
464
+ (
465
+ 'spatial_block', SpatialBlock(
466
+ embed_dims[i],
467
+ base_embed_dims[i],
468
+ num_heads[i],
469
+ base_num_heads[i],
470
+ window_size,
471
+ drop_path_rate=dpr[depth_offset + j * 2],
472
+ qkv_bias=qkv_bias,
473
+ mlp_ratio=mlp_ratio,
474
+ conv_at_attn=conv_at_attn,
475
+ conv_at_ffn=conv_at_ffn,
476
+ standparam=standparam
477
+ )
478
+ ),
479
+ (
480
+ 'channel_block', ChannelBlock(
481
+ embed_dims[i],
482
+ base_embed_dims[i],
483
+ num_groups[i],
484
+ base_num_groups[i],
485
+ drop_path_rate=dpr[depth_offset + j * 2 + 1],
486
+ qkv_bias=qkv_bias,
487
+ mlp_ratio=mlp_ratio,
488
+ conv_at_attn=conv_at_attn,
489
+ conv_at_ffn=conv_at_ffn,
490
+ dynamic_scale=dynamic_scale,
491
+ standparam=standparam
492
+ )
493
+ )
494
+ ])) for j in range(depths[i])
495
+ ]
496
+ )
497
+ blocks.append(block)
498
+ depth_offset += depths[i] * 2
499
+
500
+ self.convs = nn.ModuleList(convs)
501
+ self.blocks = nn.ModuleList(blocks)
502
+
503
+ self.norms = norm_layer(self.embed_dims[-1])
504
+ self.avgpool = nn.AdaptiveAvgPool1d(1)
505
+
506
+ if standparam:
507
+ self.head = nn.Linear(self.embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()
508
+ else:
509
+ self.head = MuReadout(self.embed_dims[-1], num_classes,
510
+ readout_zero_init=True) # Follow examples/ResNet/resnet.py in muP.
511
+
512
+ if torch.cuda.is_available():
513
+ self.device = torch.device(type="cuda", index=0)
514
+ else:
515
+ self.device = torch.device(type="cpu")
516
+
517
+ def custom_init_weights(self, use_original_init=True):
518
+ self.use_original_init = use_original_init
519
+ logger.info('Custom init: {}'.format('original init' if self.use_original_init else 'muP init'))
520
+ self.apply(self._custom_init_weights)
521
+
522
+ @property
523
+ def dim_out(self):
524
+ return self.embed_dims[-1]
525
+
526
+ def _custom_init_weights(self, m):
527
+ # Customized initialization for weights.
528
+ if self.use_original_init:
529
+ # Original initialization.
530
+ # Note: This is not SP init. We do not implement SP init here.
531
+ custom_trunc_normal_ = trunc_normal_
532
+ custom_normal_ = nn.init.normal_
533
+ else:
534
+ # muP.
535
+ custom_trunc_normal_ = mup.init.trunc_normal_
536
+ custom_normal_ = mup.init.normal_
537
+
538
+ # These initializations will overwrite the existing initializations from the modules and those adjusted by set_base_shapes().
539
+ if isinstance(m, MuReadout):
540
+ pass # Note: MuReadout is already zero initialized due to readout_zero_init=True.
541
+ elif isinstance(m, nn.Linear):
542
+ custom_trunc_normal_(m.weight, std=0.02)
543
+ if m.bias is not None:
544
+ nn.init.constant_(m.bias, 0)
545
+ elif isinstance(m, nn.Conv2d):
546
+ custom_normal_(m.weight, std=0.02)
547
+ for name, _ in m.named_parameters():
548
+ if name in ['bias']:
549
+ nn.init.constant_(m.bias, 0)
550
+ elif isinstance(m, nn.LayerNorm): # Follow P24 Layernorm Weights and Biases.
551
+ nn.init.constant_(m.weight, 1.0)
552
+ nn.init.constant_(m.bias, 0)
553
+ elif isinstance(m, nn.BatchNorm2d): # Follow P24 Layernorm Weights and Biases.
554
+ nn.init.constant_(m.weight, 1.0)
555
+ nn.init.constant_(m.bias, 0)
556
+
557
+ def _try_remap_keys(self, pretrained_dict):
558
+ remap_keys = {
559
+ "conv_embeds": "convs",
560
+ "main_blocks": "blocks",
561
+ "0.cpe.0.proj": "spatial_block.conv1.fn.dw",
562
+ "0.attn": "spatial_block.window_attn.fn",
563
+ "0.cpe.1.proj": "spatial_block.conv2.fn.dw",
564
+ "0.mlp": "spatial_block.ffn.fn.net",
565
+ "1.cpe.0.proj": "channel_block.conv1.fn.dw",
566
+ "1.attn": "channel_block.channel_attn.fn",
567
+ "1.cpe.1.proj": "channel_block.conv2.fn.dw",
568
+ "1.mlp": "channel_block.ffn.fn.net",
569
+ "0.norm1": "spatial_block.window_attn.norm",
570
+ "0.norm2": "spatial_block.ffn.norm",
571
+ "1.norm1": "channel_block.channel_attn.norm",
572
+ "1.norm2": "channel_block.ffn.norm"
573
+ }
574
+
575
+ full_key_mappings = {}
576
+ for k in pretrained_dict.keys():
577
+ old_k = k
578
+ for remap_key in remap_keys.keys():
579
+ if remap_key in k:
580
+ logger.info(f'=> Replace {remap_key} with {remap_keys[remap_key]}')
581
+ k = k.replace(remap_key, remap_keys[remap_key])
582
+
583
+ full_key_mappings[old_k] = k
584
+
585
+ return full_key_mappings
586
+
587
+ def from_state_dict(self, pretrained_dict, pretrained_layers=[], verbose=True):
588
+ model_dict = self.state_dict()
589
+ stripped_key = lambda x: x[14:] if x.startswith('image_encoder.') else x
590
+ full_key_mappings = self._try_remap_keys(pretrained_dict)
591
+
592
+ pretrained_dict = {
593
+ stripped_key(full_key_mappings[k]): v.to(self.device) for k, v in pretrained_dict.items()
594
+ if stripped_key(full_key_mappings[k]) in model_dict.keys()
595
+ }
596
+ need_init_state_dict = {}
597
+ for k, v in pretrained_dict.items():
598
+ need_init = (
599
+ k.split('.')[0] in pretrained_layers
600
+ or pretrained_layers[0] == '*'
601
+ )
602
+ if need_init:
603
+ if verbose:
604
+ logger.info(f'=> init {k} from pretrained state dict')
605
+
606
+ need_init_state_dict[k] = v.to(self.device)
607
+ self.load_state_dict(need_init_state_dict, strict=False)
608
+
609
+ def from_pretrained(self, pretrained='', pretrained_layers=[], verbose=True):
610
+ if os.path.isfile(pretrained):
611
+ logger.info(f'=> loading pretrained model {pretrained}')
612
+ pretrained_dict = torch.load(pretrained, map_location='cpu')
613
+
614
+ self.from_state_dict(pretrained_dict, pretrained_layers, verbose)
615
+
616
+ def forward_features(self, x):
617
+ input_size = (x.size(2), x.size(3))
618
+ for conv, block in zip(self.convs, self.blocks):
619
+ x, input_size = conv(x, input_size)
620
+ if self.enable_checkpoint:
621
+ x, input_size = checkpoint.checkpoint(block, x, input_size)
622
+ else:
623
+ x, input_size = block(x, input_size)
624
+
625
+ x = self.avgpool(x.transpose(1, 2))
626
+ x = torch.flatten(x, 1)
627
+ x = self.norms(x)
628
+
629
+ return x
630
+
631
+ def forward(self, x):
632
+ x = self.forward_features(x)
633
+ x = self.head(x)
634
+ return x
635
+
636
+
637
+ def create_encoder(config_encoder):
638
+ spec = config_encoder['SPEC']
639
+ standparam = spec.get('STANDPARAM', True)
640
+
641
+ if standparam:
642
+ # Dummy values for muP parameters.
643
+ base_embed_dims = spec['DIM_EMBED']
644
+ base_num_heads = spec['NUM_HEADS']
645
+ base_num_groups = spec['NUM_GROUPS']
646
+ else:
647
+ base_embed_dims = spec['BASE_DIM_EMBED']
648
+ base_num_heads = spec['BASE_NUM_HEADS']
649
+ base_num_groups = spec['BASE_NUM_GROUPS']
650
+
651
+ davit = DaViT(
652
+ num_classes=config_encoder['NUM_CLASSES'],
653
+ depths=spec['DEPTHS'],
654
+ embed_dims=spec['DIM_EMBED'],
655
+ base_embed_dims=base_embed_dims,
656
+ num_heads=spec['NUM_HEADS'],
657
+ base_num_heads=base_num_heads,
658
+ num_groups=spec['NUM_GROUPS'],
659
+ base_num_groups=base_num_groups,
660
+ patch_size=spec['PATCH_SIZE'],
661
+ patch_stride=spec['PATCH_STRIDE'],
662
+ patch_padding=spec['PATCH_PADDING'],
663
+ patch_prenorm=spec['PATCH_PRENORM'],
664
+ drop_path_rate=spec['DROP_PATH_RATE'],
665
+ img_size=config_encoder['IMAGE_SIZE'],
666
+ window_size=spec.get('WINDOW_SIZE', 7),
667
+ enable_checkpoint=spec.get('ENABLE_CHECKPOINT', False),
668
+ conv_at_attn=spec.get('CONV_AT_ATTN', True),
669
+ conv_at_ffn=spec.get('CONV_AT_FFN', True),
670
+ dynamic_scale=spec.get('DYNAMIC_SCALE', True),
671
+ standparam=standparam,
672
+ )
673
+ return davit
674
+
675
+
676
+ def create_mup_encoder(config_encoder):
677
+ def gen_config(config, wm):
678
+ new_config = copy.deepcopy(config)
679
+ for name in ['DIM_EMBED', 'NUM_HEADS', 'NUM_GROUPS']:
680
+ base_name = 'BASE_' + name
681
+ new_values = [round(base_value * wm) for base_value in
682
+ config['SPEC'][base_name]] # New value = base value * width multiplier.
683
+ logger.info(f'config["SPEC"]["{name}"]: {new_config["SPEC"][name]} -> {new_values}')
684
+ new_config['SPEC'][name] = new_values
685
+ return new_config
686
+
687
+ logger.info('muP: Create models and set base shapes')
688
+ logger.info('=> Create model')
689
+ model = create_encoder(config_encoder)
690
+
691
+ logger.info('=> Create base model')
692
+ base_config = gen_config(config_encoder, wm=1.0)
693
+ base_model = create_encoder(base_config)
694
+
695
+ logger.info('=> Create delta model')
696
+ delta_config = gen_config(config_encoder, wm=2.0)
697
+ delta_model = create_encoder(delta_config)
698
+
699
+ logger.info('=> Set base shapes in model for training')
700
+ set_base_shapes(model, base=base_model, delta=delta_model)
701
+
702
+ return model
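+ # Minimal sketch of the SPEC entries this muP path reads (illustrative values, not a shipped config):
+ #   SPEC:
+ #     STANDPARAM: false
+ #     DIM_EMBED: [128, 256, 384, 512]        # target widths
+ #     BASE_DIM_EMBED: [64, 128, 192, 256]    # base shapes consumed by set_base_shapes()
+ #     NUM_HEADS: [6, 12, 24, 48]
+ #     BASE_NUM_HEADS: [3, 6, 12, 24]
+ #     NUM_GROUPS: [6, 12, 24, 48]
+ #     BASE_NUM_GROUPS: [3, 6, 12, 24]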
703
+
704
+
705
+ @register_image_encoder
706
+ def image_encoder(config_encoder, verbose, **kwargs):
707
+ spec = config_encoder['SPEC']
708
+ standparam = spec.get('STANDPARAM', True)
709
+
710
+ if standparam:
711
+ logger.info('Create model with standard parameterization')
712
+ model = create_encoder(config_encoder)
713
+ model.custom_init_weights(use_original_init=True)
714
+ else:
715
+ logger.info('Create model with mu parameterization')
716
+ model = create_mup_encoder(config_encoder)
717
+ model.custom_init_weights(use_original_init=False)
718
+
719
+ logger.info('Load model from pretrained checkpoint')
720
+ if config_encoder['LOAD_PRETRAINED']:
721
+ model.from_pretrained(
722
+ config_encoder['PRETRAINED'],
723
+ config_encoder['PRETRAINED_LAYERS'],
724
+ verbose
725
+ )
726
+
727
+ return model
MedImageInsight/ImageEncoder/registry.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _image_encoders = {}
2
+
3
+
4
+ def register_image_encoder(fn):
5
+ module_name_split = fn.__module__.split('.')
6
+ model_name = module_name_split[-1]
7
+
8
+ _image_encoders[model_name] = fn
9
+
10
+ return fn
11
+
12
+
13
+ def image_encoders(model_name):
14
+ return _image_encoders[model_name]
15
+
16
+
17
+ def is_image_encoder(model_name):
18
+ return model_name in _image_encoders
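+ # Usage sketch (illustrative): the registry key is the defining module's name, so an encoder
+ # defined in a hypothetical module `my_encoder.py` would be registered and looked up as:
+ #   @register_image_encoder
+ #   def image_encoder(config_encoder, verbose, **kwargs): ...
+ #   if is_image_encoder('my_encoder'):
+ #       model = image_encoders('my_encoder')(config_encoder, verbose)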
MedImageInsight/LangEncoder/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import absolute_import
2
+ from __future__ import division
3
+ from __future__ import print_function
4
+
5
+ from .build import build_lang_encoder
6
+ from .build import build_tokenizer
7
+
8
+ from .transformer import *
9
+ # from .hf_model import *
10
+ # from .zcode import *
11
+ # from .pretrain import *
12
+ # from .tulrv6 import *
13
+ # from .t5 import *
MedImageInsight/LangEncoder/build.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+
4
+ from transformers import CLIPTokenizer, CLIPTokenizerFast
5
+ from transformers import AutoTokenizer
6
+
7
+ from .registry import lang_encoders
8
+ from .registry import is_lang_encoder
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
+ def build_lang_encoder(config_encoder, tokenizer, verbose, **kwargs):
14
+ model_name = config_encoder['NAME']
15
+
16
+ if model_name.endswith('pretrain'):
17
+ model_name = 'pretrain'
18
+
19
+ if not is_lang_encoder(model_name):
20
+ raise ValueError(f'Unknown model: {model_name}')
21
+
22
+ return lang_encoders(model_name)(config_encoder, tokenizer, verbose, **kwargs)
23
+
24
+
25
+ def post_process_clip(text):
26
+ text['input_ids'].squeeze_() # torch.Size([1, 77])
27
+ text['attention_mask'].squeeze_() # torch.Size([1, 77])
28
+ return text
29
+
30
+
31
+ def build_tokenizer(config_encoder):
32
+ tokenizer = None
33
+ os.environ['TOKENIZERS_PARALLELISM'] = 'false' # default to 'false' to avoid hanging when dataloader workers fork
34
+
35
+ if config_encoder['TOKENIZER'] == 'clip':
36
+ os.environ['TOKENIZERS_PARALLELISM'] = 'true'
37
+ pretrained_tokenizer = config_encoder.get(
38
+ 'PRETRAINED_TOKENIZER', 'openai/clip-vit-base-patch32'
39
+ )
40
+ # print(pretrained_tokenizer)
41
+ tokenizer = CLIPTokenizer.from_pretrained(pretrained_tokenizer)
42
+ tokenizer.add_special_tokens({'cls_token': tokenizer.eos_token})
43
+ tokenizer.post_process = post_process_clip
44
+ elif config_encoder['TOKENIZER'] == 'clip-fast':
45
+ pretrained_tokenizer = config_encoder.get(
46
+ 'PRETRAINED_TOKENIZER', 'openai/clip-vit-base-patch32'
47
+ )
48
+ tokenizer = CLIPTokenizerFast.from_pretrained(pretrained_tokenizer, from_slow=True)
49
+ tokenizer.post_process = post_process_clip
50
+ elif config_encoder['TOKENIZER'] == 'zcodepp':
51
+ from .zcodepp import ZCodeppTokenizer
52
+ tokenizer = ZCodeppTokenizer(config_encoder)
53
+ tokenizer.post_process = lambda x: x
54
+ elif config_encoder['TOKENIZER'] == 'zcode':
55
+ from transformers import XLMRobertaTokenizer
56
+ tokenizer = XLMRobertaTokenizer.from_pretrained(config_encoder['PRETRAINED_TOKENIZER'])
57
+ elif config_encoder['TOKENIZER'] == 'tulrv6':
58
+ from .modeling_tulrv6 import TULRv6Tokenizer
59
+ os.environ['TOKENIZERS_PARALLELISM'] = 'false'
60
+ pretrained_tokenizer = config_encoder.get(
61
+ 'PRETRAINED_TOKENIZER', 'tulrv6-base'
62
+ )
63
+ tokenizer = TULRv6Tokenizer.from_pretrained(pretrained_tokenizer)
64
+ # tokenizer.post_process = post_process_clip
65
+ else:
66
+ os.environ['TOKENIZERS_PARALLELISM'] = 'false'
67
+ pretrained_tokenizer = config_encoder.get('PRETRAINED_TOKENIZER', '')
68
+ tokenizer = AutoTokenizer.from_pretrained(
69
+ pretrained_tokenizer
70
+ if pretrained_tokenizer else config_encoder['TOKENIZER']
71
+ )
72
+ tokenizer.post_process = post_process_clip
73
+
74
+ # Extra configurations.
75
+ if 'TOKENIZER_CONF' in config_encoder:
76
+ tokenizer_conf = config_encoder['TOKENIZER_CONF']
77
+
78
+ num_pretrained_tokens = len(tokenizer)
79
+
80
+ addition_special_tokens_config = tokenizer_conf.get('ADDITIONAL_SPECIAL_TOKENS', None)
81
+ if addition_special_tokens_config == 'od+cap':
82
+ # Note: We still keep the additional special tokens from original tokenizer when we add new special tokens.
83
+ # This is to make sure tokenizer.additional_special_tokens afterwards includes original additional special tokens.
84
+ special_tokens_dict = {
85
+ 'additional_special_tokens': \
86
+ tokenizer.additional_special_tokens + \
87
+ ['<od>','</od>','<cap>','</cap>'] + \
88
+ [f'<loc_{x}>' for x in range(tokenizer_conf.get('NUM_LOCATION_TOKENS', 0))]
89
+ }
90
+ tokenizer.add_special_tokens(special_tokens_dict)
91
+ elif isinstance(addition_special_tokens_config, list):
92
+ special_tokens_dict = {
93
+ 'additional_special_tokens': \
94
+ tokenizer.additional_special_tokens + \
95
+ addition_special_tokens_config + \
96
+ [f'<loc_{x}>' for x in range(tokenizer_conf.get('NUM_LOCATION_TOKENS', 0))]+
97
+ [f'<time_{x}>' for x in range(
98
+ tokenizer_conf.get('NUM_TIME_TOKENS', 0))]
99
+ }
100
+ tokenizer.add_special_tokens(special_tokens_dict)
101
+ elif addition_special_tokens_config is not None:
102
+ raise ValueError('ADDITIONAL_SPECIAL_TOKENS type error')
103
+
104
+ num_current_tokens = len(tokenizer)
105
+ logger.info(f'{num_pretrained_tokens} tokens in pretrained tokenizer => {num_current_tokens} in current tokenizer')
106
+ logger.info(f'All special tokens in tokenizer: {tokenizer.additional_special_tokens}')
107
+
108
+ return tokenizer
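+ # Illustrative config for the 'clip' branch above (values are examples, not a shipped config):
+ #   config_encoder = {
+ #       'TOKENIZER': 'clip',
+ #       'PRETRAINED_TOKENIZER': 'openai/clip-vit-base-patch32',
+ #       'TOKENIZER_CONF': {'ADDITIONAL_SPECIAL_TOKENS': 'od+cap', 'NUM_LOCATION_TOKENS': 500},
+ #   }
+ #   tokenizer = build_tokenizer(config_encoder)  # adds <od>, </od>, <cap>, </cap> and <loc_0>..<loc_499>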
MedImageInsight/LangEncoder/registry.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _lang_encoders = {}
2
+
3
+
4
+ def register_lang_encoder(fn):
5
+ module_name_split = fn.__module__.split('.')
6
+ model_name = module_name_split[-1]
7
+
8
+ _lang_encoders[model_name] = fn
9
+
10
+ return fn
11
+
12
+
13
+ def lang_encoders(model_name):
14
+ return _lang_encoders[model_name]
15
+
16
+
17
+ def is_lang_encoder(model_name):
18
+ return model_name in _lang_encoders
MedImageInsight/LangEncoder/transformer.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+ from typing import Tuple, Union
3
+ import logging
4
+ import os
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from torch import nn
10
+
11
+ from timm.models.layers import DropPath, trunc_normal_
12
+
13
+ from .registry import register_lang_encoder
14
+ from ..Utils import is_main_process
15
+ from ..Utils import register_norm_module
16
+
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+ @register_norm_module
21
+ class LayerNorm(nn.Module):
22
+ def __init__(self, hidden_size, eps=1e-12):
23
+ """Construct a layernorm module in the TF style (epsilon inside the square root).
24
+ """
25
+ super(LayerNorm, self).__init__()
26
+ self.weight = nn.Parameter(torch.ones(hidden_size))
27
+ self.bias = nn.Parameter(torch.zeros(hidden_size))
28
+ self.variance_epsilon = eps
29
+
30
+ def forward(self, x):
31
+ pdtype = x.dtype
32
+ x = x.float()
33
+ u = x.mean(-1, keepdim=True)
34
+ s = (x - u).pow(2).mean(-1, keepdim=True)
35
+ x = (x - u) / torch.sqrt(s + self.variance_epsilon)
36
+ return self.weight * x.to(pdtype) + self.bias
37
+
38
+
39
+ class QuickGELU(nn.Module):
40
+ def forward(self, x: torch.Tensor):
41
+ return x * torch.sigmoid(1.702 * x)
42
+
43
+
44
+ class ResidualAttentionBlock(nn.Module):
45
+ def __init__(self,
46
+ d_model: int,
47
+ n_head: int,
48
+ attn_mask: torch.Tensor = None,
49
+ drop_path: float = 0.0):
50
+ super().__init__()
51
+
52
+ self.attn = nn.MultiheadAttention(d_model, n_head)
53
+ self.ln_1 = LayerNorm(d_model)
54
+ self.mlp = nn.Sequential(OrderedDict([
55
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
56
+ ("gelu", QuickGELU()),
57
+ ("c_proj", nn.Linear(d_model * 4, d_model))
58
+ ]))
59
+ self.ln_2 = LayerNorm(d_model)
60
+ self.attn_mask = attn_mask
61
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
62
+
63
+ def attention(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None):
64
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) \
65
+ if self.attn_mask is not None else None
66
+
67
+
68
+ return self.attn(
69
+ x, x, x,
70
+ key_padding_mask=key_padding_mask,
71
+ need_weights=False,
72
+ attn_mask=self.attn_mask
73
+ )[0]
74
+
75
+ def forward(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None):
76
+ x = x + self.drop_path(self.attention(self.ln_1(x), key_padding_mask=key_padding_mask))
77
+ x = x + self.drop_path(self.mlp(self.ln_2(x)))
78
+ return x
79
+
80
+
81
+ class Transformer(nn.Module):
82
+ def __init__(self,
83
+ context_length: int,
84
+ vocab_size: int,
85
+ width: int,
86
+ layers: int,
87
+ heads: int,
88
+ drop_path: float = 0.0,
89
+ autogressive: bool =True,
90
+ key_padding_token: int = 0,
91
+ ):
92
+ super().__init__()
93
+
94
+ self.token_embedding = nn.Embedding(vocab_size, width)
95
+ self.key_padding_token = key_padding_token
96
+
97
+ self.context_length = context_length
98
+ self.positional_embedding = nn.Parameter(
99
+ torch.empty(self.context_length, width)
100
+ )
101
+
102
+ self.width = width
103
+ self.layers = layers
104
+ self.autogressive = autogressive
105
+ attn_mask = self.build_attention_mask() if autogressive else None
106
+ dpr = [x.item() for x in torch.linspace(0, drop_path, layers)] # stochastic depth decay rule
107
+ self.resblocks = nn.ModuleList(
108
+ [
109
+ ResidualAttentionBlock(width, heads, attn_mask, dpr[i])
110
+ for i in range(layers)
111
+ ]
112
+ )
113
+
114
+ self.ln_final = LayerNorm(width)
115
+
116
+ trunc_normal_(self.positional_embedding, std=.02)
117
+ # nn.init.normal_(self.token_embedding, std=.02)
118
+ trunc_normal_(self.token_embedding.weight, std=.02)
119
+ self.apply(self._init_weights)
120
+
121
+ @property
122
+ def dim_out(self):
123
+ return self.width
124
+
125
+ def build_attention_mask(self):
126
+ # lazily create the causal attention mask for the text tokens
127
+ # pytorch uses additive attention mask; fill with -inf
128
+ mask = torch.empty(self.context_length, self.context_length)
129
+ mask.fill_(float("-inf"))
130
+ mask.triu_(1) # zero out the lower diagonal
131
+ return mask
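+ # Worked example: for context_length = 4 the mask built above is 0 on and below the diagonal
+ # and -inf above it, so position i may only attend to positions <= i:
+ #   [[0., -inf, -inf, -inf],
+ #    [0.,   0., -inf, -inf],
+ #    [0.,   0.,   0., -inf],
+ #    [0.,   0.,   0.,   0.]]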
132
+
133
+ def _init_weights(self, m):
134
+ if isinstance(m, (nn.Linear, nn.Conv2d)):
135
+ if is_main_process():
136
+ logger.info('=> init weight of Linear/Conv2d from trunc norm')
137
+ trunc_normal_(m.weight, std=0.02)
138
+ if m.bias is not None:
139
+ if is_main_process():
140
+ logger.info('=> init bias of Linear/Conv2d to zeros')
141
+ nn.init.constant_(m.bias, 0)
142
+ elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
143
+ nn.init.constant_(m.bias, 0)
144
+
145
+ def load_pretrained(self, pretrained='', pretrained_layers=[], verbose=True):
146
+ if os.path.isfile(pretrained):
147
+ pretrained_dict = torch.load(pretrained, map_location='cpu')
148
+ logging.info(f'=> loading pretrained model {pretrained}')
149
+ model_dict = self.state_dict()
150
+ pretrained_dict = {
151
+ k: v for k, v in pretrained_dict.items()
152
+ if k in model_dict.keys()
153
+ }
154
+ need_init_state_dict = {}
155
+ for k, v in pretrained_dict.items():
156
+ need_init = (
157
+ k.split('.')[0] in pretrained_layers
158
+ or pretrained_layers[0] == '*'
159
+ )
160
+ if need_init:
161
+ if verbose:
162
+ logging.info(f'=> init {k} from {pretrained}')
163
+
164
+ need_init_state_dict[k] = v
165
+ self.load_state_dict(need_init_state_dict, strict=False)
166
+
167
+
168
+ @torch.jit.ignore
169
+ def no_weight_decay(self):
170
+ return {
171
+ 'positional_embedding',
172
+ 'token_embedding',
173
+ }
174
+
175
+ def forward(self, input_ids, attention_mask=None):
176
+ input_ids = input_ids.to(self.positional_embedding.device, non_blocking=True)
177
+ # Here we generate key_padding_mask using attention_mask instead of using
178
+ # a predefined key_padding_token (e.g., 0). This is to solve a discrepancy
179
+ # between Transformer 4.16.2 and 4.25.1, since Transformers 4.16.2 uses token id 0
180
+ # for padding but 4.25.1 uses EOS token (token id 49407) for padding.
181
+ key_padding_mask = (attention_mask == 0) if not self.autogressive else None
182
+ # a True value indicates that the corresponding key value will be ignored for the purpose of attention
183
+ x = self.token_embedding(input_ids) # [batch_size, n_ctx, d_model]
184
+ x = x + self.positional_embedding
185
+ x = x.permute(1, 0, 2) # NLD -> LND
186
+ for block in self.resblocks:
187
+ x = block(x, key_padding_mask)
188
+ x = x.permute(1, 0, 2) # LND -> NLD
189
+
190
+ x = self.ln_final(x)
191
+
192
+ return {'last_hidden_state': x}
193
+
194
+
195
+ @register_lang_encoder
196
+ def lang_encoder(config_encoder, tokenizer, verbose, **kwargs):
197
+ transformer = Transformer(
198
+ context_length=config_encoder['CONTEXT_LENGTH'],
199
+ vocab_size=tokenizer.vocab_size,
200
+ width=config_encoder['WIDTH'],
201
+ layers=config_encoder['LAYERS'],
202
+ heads=config_encoder['HEADS'],
203
+ autogressive=config_encoder.get('AUTOGRESSIVE', True),
204
+ key_padding_token=config_encoder.get('KEY_PADDING_TOKEN', 0),
205
+ )
206
+
207
+ if config_encoder['LOAD_PRETRAINED']:
208
+ # Forward the configured checkpoint path, if any (the default empty path loads nothing).
+ transformer.load_pretrained(config_encoder.get('PRETRAINED', ''), config_encoder.get('PRETRAINED_LAYERS', ['*']), verbose)
209
+
210
+ return transformer
MedImageInsight/UniCLModel.py ADDED
@@ -0,0 +1,293 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pathlib
2
+ import tempfile
3
+ import logging
4
+ import os
5
+ import copy
6
+
7
+ import torch
8
+ from torch import nn
9
+
10
+ from timm.models.layers import trunc_normal_
11
+
12
+ from .ImageEncoder import build_image_encoder
13
+ from .LangEncoder import build_lang_encoder
14
+ from .LangEncoder import build_tokenizer
15
+
16
+ import mup.init
17
+ from mup import set_base_shapes
18
+
19
+ from safetensors.torch import load_file
20
+
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ class UniCLModel(nn.Module):
26
+ def __init__(self, config: dict):
27
+ super().__init__()
28
+
29
+ self.conf_lang_encoder = config['LANG_ENCODER']
30
+ self.tokenizer = build_tokenizer(self.conf_lang_encoder)
31
+
32
+ self.lang_encoder = build_lang_encoder(self.conf_lang_encoder, self.tokenizer, config['VERBOSE'])
33
+
34
+ dim_projection = config['UNICL_MODEL']['DIM_PROJECTION']
35
+ if hasattr(self.lang_encoder, 'dim_out'):
36
+ dim_out = self.lang_encoder.dim_out
37
+ else:
38
+ with torch.no_grad():
39
+ dim_out = self.lang_encoder(
40
+ torch.zeros(1,1).type(torch.LongTensor)
41
+ )['last_hidden_state'].size(2)
42
+
43
+ self.lang_projection = nn.Parameter(torch.empty(dim_out, dim_projection))
44
+
45
+ self.conf_image_encoder = config['IMAGE_ENCODER']
46
+ self.image_encoder = build_image_encoder(self.conf_image_encoder, config['VERBOSE'])
47
+
48
+ self.image_projection = nn.Parameter(
49
+ torch.empty(self.image_encoder.dim_out, dim_projection)
50
+ )
51
+
52
+ self.logit_scale = nn.Parameter(torch.ones([]))
53
+
54
+ if torch.cuda.is_available():
55
+ self.device = torch.device(type="cuda", index=0)
56
+ else:
57
+ self.device = torch.device(type="cpu")
58
+
59
+ def custom_init_weights(self, use_original_init=True):
60
+ self.use_original_init = use_original_init
61
+ logger.info('Custom init: {}'.format('original init' if self.use_original_init else 'muP init'))
62
+
63
+ if self.use_original_init:
64
+ # Original initialization.
65
+ # Note: This is not SP init. We do not implement SP init here.
66
+ custom_trunc_normal_ = trunc_normal_ # Note: This should be the same as torch.nn.init.trunc_normal_
67
+ else:
68
+ # muP.
69
+ custom_trunc_normal_ = mup.init.trunc_normal_
70
+
71
+ custom_trunc_normal_(self.lang_projection, std=.02)
72
+ custom_trunc_normal_(self.image_projection, std=.02)
73
+
74
+ def _convert_old_weights(self, model_dict):
75
+ model_dict_updated = {}
76
+ for k, v in model_dict.items():
77
+ if k.startswith('visual.'):
78
+ model_dict_updated['image_encoder.'+k[7:]] = v
79
+ elif k.startswith('text.'):
80
+ model_dict_updated['lang_encoder.'+k[5:]] = v
81
+ elif k == 'vision_projection':
82
+ model_dict_updated['image_projection'] = v
83
+ elif k == 'text_projection':
84
+ model_dict_updated['lang_projection'] = v
85
+ else:
86
+ model_dict_updated[k] = v
87
+
88
+ return model_dict_updated
89
+
90
+ def from_pretrained(self, pretrained='', pretrained_layers=[], verbose=True):
91
+ if not os.path.isfile(pretrained):
92
+ logger.warning(f'=> Pretrained model ({pretrained}) is not a file, skip init weight')
93
+ return
94
+
95
+ ## Load SafeTensors Version of Pretrained Model
96
+ pretrained_dict = load_file(pretrained)
97
+ logger.info(f'=> Loading pretrained model {pretrained}')
98
+ model_dict = self.state_dict()
99
+ pretrained_dict = self._convert_old_weights(pretrained_dict)
100
+ ## To ensure cuda is mapped to all weights in the SafeTensors version model
101
+ pretrained_dict = {
102
+ k: v.to(self.device) for k, v in pretrained_dict.items()
103
+ }
104
+ need_init_state_dict = {}
105
+ image_encoder_state_dict = {}
106
+ for k, v in pretrained_dict.items():
107
+ need_init = (
108
+ k.split('.')[0] in pretrained_layers
109
+ or pretrained_layers[0] == '*'
110
+ )
111
+
112
+ if need_init:
113
+ if k.startswith('image_encoder.'):
114
+ image_encoder_state_dict[k] = v.to(self.device)
115
+ else:
116
+ if verbose:
117
+ logger.info(f'=> init {k} from {pretrained}')
118
+
119
+ if 'positional_embedding' in k and v.size() != model_dict[k].size():
120
+ positional_embedding_pretrained = v
121
+ positional_embedding_current = model_dict[k]
122
+ L1, nH1 = positional_embedding_pretrained.size()
123
+ L2, nH2 = positional_embedding_current.size()
124
+ if nH1 != nH2:
125
+ logger.info(f"Error in loading {k}, passing")
126
+ else:
127
+ if L1 != L2:
128
+ logger.info(
129
+ '=> load_pretrained: resized variant: {} to {}'
130
+ .format((L1, nH1), (L2, nH2))
131
+ )
132
+
133
+ posemb = positional_embedding_pretrained.float()
134
+ posemb_grid = posemb.unsqueeze(dim=0).permute(0, 2, 1)
135
+ posemb_grid = torch.nn.functional.interpolate(posemb_grid, size=L2, mode='linear')
136
+ posemb_grid = posemb_grid.permute(0, 2, 1).squeeze(dim=0)
137
+ v = posemb_grid
138
+
139
+ need_init_state_dict[k] = v.to(self.device)
140
+ self.image_encoder.from_state_dict(image_encoder_state_dict, ['*'], verbose)
141
+ self.load_state_dict(need_init_state_dict, strict=False)
142
+
143
+ @torch.jit.ignore
144
+ def no_weight_decay(self):
145
+ no_weight_decay = {'logit_scale'}
146
+ if hasattr(self.lang_encoder, 'no_weight_decay'):
147
+ for k in self.lang_encoder.no_weight_decay():
148
+ no_weight_decay.add('lang_encoder.'+k)
149
+
150
+ if hasattr(self.image_encoder, 'no_weight_decay'):
151
+ for k in self.image_encoder.no_weight_decay():
152
+ no_weight_decay.add('image_encoder.'+k)
153
+
154
+ return no_weight_decay
155
+
156
+ @property
157
+ def dtype(self):
158
+ return self.logit_scale.dtype
159
+
160
+ def encode_image(self, image, norm=True):
161
+ x = self.image_encoder.forward_features(image)
162
+ x = x @ self.image_projection
163
+
164
+ if norm:
165
+ x = x / x.norm(dim=-1, keepdim=True)
166
+
167
+ return x
168
+
169
+ def encode_text(self, text, norm=True):
170
+ x = self.lang_encoder(**text)
171
+ x = x['last_hidden_state']
172
+
173
+ if self.conf_lang_encoder['TOKENIZER'] == 'clip':
174
+ x = x[torch.arange(x.size(0)), text['input_ids'].argmax(dim=-1)]
175
+ else:
176
+ x = x[:, 0]
177
+
178
+ x = x @ self.lang_projection
179
+
180
+ if norm:
181
+ x = x / x.norm(dim=-1, keepdim=True)
182
+
183
+ return x
184
+
185
+ def forward(self, image, text):
186
+ features_image = self.encode_image(image)
187
+ features_text = self.encode_text(text)
188
+
189
+ # cosine similarity as logits
190
+ T = self.logit_scale.exp()
191
+
192
+ return features_image, features_text, T
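+ # Usage sketch (not part of the module): the returned triple is typically turned into
+ # CLIP-style logits for a contrastive loss, e.g.
+ #   image_features, text_features, T = model(image, text)
+ #   logits_per_image = T * image_features @ text_features.t()   # [num_images, num_texts]
+ #   logits_per_text = logits_per_image.t()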
193
+
194
+
195
+ def create_model(config):
196
+ model = UniCLModel(config)
197
+ return model
198
+
199
+
200
+ def create_mup_model(config):
201
+ def gen_config(config, wm):
202
+ # TODO: Currently only support the case that all UniCL, lang encoder, and image encoder use
203
+ # mu parameterization. This requirement can be relaxed.
204
+ assert (not config['UNICL_MODEL']['STANDPARAM']) and \
205
+ (not config['LANG_ENCODER']['STANDPARAM']) and \
206
+ (not config['IMAGE_ENCODER']['SPEC']['STANDPARAM'])
207
+ new_config = copy.deepcopy(config)
208
+ logger.info(f'Generate config with width mult = {wm}:')
209
+
210
+ # Generate config for UniCL head.
211
+ new_config_section = new_config['UNICL_MODEL']
212
+ new_config_section['STANDPARAM'] = True # Use standard parameterization when determining base shapes.
213
+ for name in ['DIM_PROJECTION']:
214
+ base_name = 'BASE_' + name
215
+ new_values = round(new_config_section[base_name] * wm) # New value = base value * width multiplier.
216
+ logger.info(f'config["UNICL_MODEL"]["{name}"]: {new_config_section[name]} -> {new_values}')
217
+ new_config_section[name] = new_values
218
+
219
+ # Generate config for lang encoder.
220
+ new_config_section = new_config['LANG_ENCODER']
221
+ new_config_section['STANDPARAM'] = True
222
+ for name in ['WIDTH', 'HEADS']:
223
+ base_name = 'BASE_' + name
224
+ new_values = round(new_config_section[base_name] * wm) # New value = base value * width multiplier.
225
+ logger.info(f'config["LANG_ENCODER"]["{name}"]: {new_config_section[name]} -> {new_values}')
226
+ new_config_section[name] = new_values
227
+
228
+ # Generate config for image encoder.
229
+ new_config_section = new_config['IMAGE_ENCODER']['SPEC']
230
+ new_config_section['STANDPARAM'] = True
231
+ for name in ['DIM_EMBED', 'NUM_HEADS', 'NUM_GROUPS']:
232
+ base_name = 'BASE_' + name
233
+ new_values = [round(base_value * wm) for base_value in new_config_section[base_name]] # New value = base value * width multiplier.
234
+ logger.info(f'config["IMAGE_ENCODER"]["SPEC"]["{name}"]: {new_config_section[name]} -> {new_values}')
235
+ new_config_section[name] = new_values
236
+
237
+ return new_config
238
+
239
+ logger.info('muP: Create models and set base shapes')
240
+ logger.info('=> Create model')
241
+ model = create_model(config)
242
+ # Temporarily remove the lang and image encoders from model to prevent from
243
+ # setting the base shape for these encoders again.
244
+ lang_encoder, image_encoder = model.lang_encoder, model.image_encoder
245
+ model.lang_encoder, model.image_encoder = None, None
246
+
247
+ logger.info('=> Create base model')
248
+ base_config = gen_config(config, wm=1.0)
249
+ base_model = create_model(base_config)
250
+ del base_model.lang_encoder, base_model.image_encoder
251
+
252
+ logger.info('=> Create delta model')
253
+ delta_config = gen_config(config, wm=2.0)
254
+ delta_model = create_model(delta_config)
255
+ del delta_model.lang_encoder, delta_model.image_encoder
256
+
257
+ logger.info('=> Set base shapes in model for training')
258
+ set_base_shapes(model, base=base_model, delta=delta_model)
259
+
260
+ # Restore the lang and image encoders in the model.
261
+ model.lang_encoder, model.image_encoder = lang_encoder, image_encoder
262
+
263
+ return model
264
+
265
+
266
+ def build_unicl_model(config, **kwargs):
267
+ standparam = config['UNICL_MODEL'].get('STANDPARAM', True)
268
+
269
+ if standparam:
270
+ logger.info('Create model with standard parameterization')
271
+ model = create_model(config)
272
+
273
+ use_original_init = True
274
+ else:
275
+ logger.info('Create model with mu parameterization')
276
+ model = create_mup_model(config)
277
+ use_original_init = False
278
+
279
+ # Initialize other parameters.
280
+ model.custom_init_weights(use_original_init=use_original_init)
281
+
282
+ if config['UNICL_MODEL']['LOAD_PRETRAINED']:
283
+ pretrained_path = config['UNICL_MODEL']['PRETRAINED']
284
+ from .Distributed.Utils import is_valid_url, download_file
285
+ if is_valid_url(pretrained_path):
286
+ with tempfile.TemporaryDirectory() as tmp_path:
287
+ file_local_path = pathlib.Path(tmp_path) / 'base_model.pt'
288
+ download_file(pretrained_path, file_local_path)
289
+ model.from_pretrained(str(file_local_path), config['UNICL_MODEL']['PRETRAINED_LAYERS'], config['VERBOSE'])
290
+ else:
291
+ model.from_pretrained(pretrained_path, config['UNICL_MODEL']['PRETRAINED_LAYERS'], config['VERBOSE'])
292
+
293
+ return model
MedImageInsight/Utils/Arguments.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import logging
4
+ import os
5
+ import re
6
+ import yaml
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+
11
+ def add_env_parser_to_yaml():
12
+ """
13
+ Add the ability to resolve environment variables to the yaml SafeLoader.
14
+ Environment variables in the form of "${<env_var_name>}" can be resolved as strings.
15
+ If the <env_var_name> is not in the env, <env_var_name> itself would be used.
16
+
17
+ E.g.:
18
+ config:
19
+ username: admin
20
+ password: ${SERVICE_PASSWORD}
21
+ service: https://${SERVICE_HOST}/service
22
+ """
23
+ loader = yaml.SafeLoader
24
+ env_pattern = re.compile(r".*?\${(.*?)}.*?")
25
+
26
+ def env_constructor(loader, node):
27
+ value = loader.construct_scalar(node)
28
+ for group in env_pattern.findall(value):
29
+ value = value.replace(f"${{{group}}}", os.environ.get(group, group))
30
+ return value
31
+
32
+ yaml.add_implicit_resolver("!ENV", env_pattern, Loader=loader)
33
+ yaml.add_constructor("!ENV", env_constructor, Loader=loader)
34
+
35
+
36
+ def load_config_dict_to_opt(opt, config_dict, splitter='.', log_new=False):
37
+ """
38
+ Load the key, value pairs from config_dict to opt, overriding existing values in opt
39
+ if there is any.
40
+ """
41
+ if not isinstance(config_dict, dict):
42
+ raise TypeError("Config must be a Python dictionary")
43
+ for k, v in config_dict.items():
44
+ k_parts = k.split(splitter)
45
+ pointer = opt
46
+ for k_part in k_parts[:-1]:
47
+ if '[' in k_part and ']' in k_part:
48
+ # for the format "a.b[0][1].c: d"
49
+ k_part_splits = k_part.split('[')
50
+ k_part = k_part_splits.pop(0)
51
+ pointer = pointer[k_part]
52
+ for i in k_part_splits:
53
+ assert i[-1] == ']'
54
+ pointer = pointer[int(i[:-1])]
55
+ else:
56
+ if k_part not in pointer:
57
+ pointer[k_part] = {}
58
+ pointer = pointer[k_part]
59
+ assert isinstance(pointer, dict), "Overriding key needs to be inside a Python dict."
60
+ if '[' in k_parts[-1] and ']' in k_parts[-1]:
61
+ k_part_splits = k_parts[-1].split('[')
62
+ k_part = k_part_splits.pop(0)
63
+ pointer = pointer[k_part]
64
+ for i in k_part_splits[:-1]:
65
+ assert i[-1] == ']'
66
+ pointer = pointer[int(i[:-1])]
67
+ assert k_part_splits[-1][-1] == ']'
68
+ ori_value = pointer[int(k_part_splits[-1][:-1])]
69
+ pointer[int(k_part_splits[-1][:-1])] = v
70
+ else:
71
+ ori_value = pointer.get(k_parts[-1])
72
+ pointer[k_parts[-1]] = v
73
+ if ori_value:
74
+ logger.warning(f"Overrided {k} from {ori_value} to {v}")
75
+ elif log_new:
76
+ logger.warning(f"Added {k}: {v}")
77
+
78
+
79
+ def load_opt_from_config_files(conf_files):
80
+ """
81
+ Load opt from the config files, settings in later files can override those in previous files.
82
+
83
+ Args:
84
+ conf_files (list): a list of config file paths
85
+
86
+ Returns:
87
+ dict: a dictionary of opt settings
88
+ """
89
+ opt = {}
90
+ for conf_file in conf_files:
91
+ with open(conf_file, encoding='utf-8') as f:
92
+ # config_dict = yaml.safe_load(f)
93
+ config_dict = yaml.unsafe_load(f)
94
+
95
+ load_config_dict_to_opt(opt, config_dict)
96
+
97
+ return opt
98
+
99
+
100
+ def load_opt_command(args):
101
+ parser = argparse.ArgumentParser(description='MainzTrain: Pretrain or fine-tune models for NLP tasks.')
102
+ parser.add_argument('command', help='Command: train/evaluate/train-and-evaluate')
103
+ parser.add_argument('--conf_files', nargs='+', required=True, help='Path(s) to the MainzTrain config file(s).')
104
+ parser.add_argument('--user_dir', help='Path to the user defined module for tasks (models, criteria), optimizers, and lr schedulers.')
105
+ parser.add_argument('--config_overrides', nargs='*', help='Override parameters on config with a json style string, e.g. {"<PARAM_NAME_1>": <PARAM_VALUE_1>, "<PARAM_GROUP_2>.<PARAM_SUBGROUP_2>.<PARAM_2>": <PARAM_VALUE_2>}. A key with "." updates the object in the corresponding nested dict. Remember to escape " in command line.')
106
+
107
+ cmdline_args = parser.parse_args() if not args else parser.parse_args(args)
108
+
109
+ add_env_parser_to_yaml()
110
+ opt = load_opt_from_config_files(cmdline_args.conf_files)
111
+
112
+ if cmdline_args.config_overrides:
113
+ config_overrides_string = ' '.join(cmdline_args.config_overrides)
114
+ config_overrides_string = os.path.expandvars(config_overrides_string)
115
+ logger.warning(f"Command line config overrides: {config_overrides_string}")
116
+ config_dict = yaml.safe_load(config_overrides_string)
117
+ load_config_dict_to_opt(opt, config_dict)
118
+
119
+ # combine cmdline_args into opt dictionary
120
+ for key, val in cmdline_args.__dict__.items():
121
+ if val is not None:
122
+ opt[key] = val
123
+
124
+ return opt, cmdline_args
125
+
126
+
127
+ def save_opt_to_json(opt, conf_file):
128
+ with open(conf_file, 'w', encoding='utf-8') as f:
129
+ json.dump(opt, f, indent=4)
130
+
131
+
132
+ def save_opt_to_yaml(opt, conf_file):
133
+ with open(conf_file, 'w', encoding='utf-8') as f:
134
+ yaml.dump(opt, f)
MedImageInsight/Utils/GeneraUtils.py ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import logging
3
+ import copy
4
+ import itertools
5
+ import random
6
+ from collections.abc import Iterable, Iterator
7
+ import torch
8
+ from torch._C import default_generator
9
+ import torch.distributed as dist
10
+ import time
11
+ from functools import wraps, partial
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
+ class ObjectView(object):
17
+ def __init__(self, d):
18
+ self.__dict__ = d
19
+
20
+
21
+ class AverageMeter(object):
22
+ """Computes and stores the average and current value."""
23
+ def __init__(self):
24
+ self.reset()
25
+
26
+ def reset(self):
27
+ self.val = 0
28
+ self.avg = 0
29
+ self.sum = 0
30
+ self.count = 0
31
+
32
+ def update(self, val, n=1, decay=0):
33
+ self.val = val
34
+ if decay:
35
+ alpha = math.exp(-n / decay) # exponential decay with a time constant of `decay` updates
36
+ self.sum = alpha * self.sum + (1 - alpha) * val * n
37
+ self.count = alpha * self.count + (1 - alpha) * n
38
+ else:
39
+ self.sum += val * n
40
+ self.count += n
41
+ self.avg = self.sum / self.count
42
+
43
+ def getstate(self):
44
+ return {'val': self.val,
45
+ 'avg': self.avg,
46
+ 'sum': self.sum,
47
+ 'count': self.count}
48
+
49
+ def setstate(self, state):
50
+ self.val = state['val']
51
+ self.avg = state['avg']
52
+ self.sum = state['sum']
53
+ self.count = state['count']
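+ # Usage sketch (illustrative): without `decay` the meter is a plain running mean,
+ # with `decay` it becomes an exponential moving average.
+ #   meter = AverageMeter()
+ #   meter.update(0.9, n=32)   # e.g. batch loss 0.9 averaged over 32 samples
+ #   meter.update(0.7, n=32)
+ #   meter.avg                 # -> 0.8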
54
+
55
+
56
+ def move_batch_to_device(batch, device):
57
+ """
58
+ Move the batch to the device.
59
+ It should be called before feeding the batch to the model.
60
+
61
+ Args:
62
+ batch (torch.tensor or container of torch.tensor): input batch
63
+ device (torch.device): device to move the batch to
64
+ Returns:
65
+ return_batch: same type as the input batch with internal tensors moved to device
66
+ """
67
+ if torch.is_tensor(batch):
68
+ return_batch = batch.to(device)
69
+ elif isinstance(batch, list):
70
+ return_batch = [move_batch_to_device(t, device) for t in batch]
71
+ elif isinstance(batch, tuple):
72
+ return_batch = tuple(move_batch_to_device(t, device) for t in batch)
73
+ elif isinstance(batch, dict):
74
+ return_batch = {}
75
+ for k in batch:
76
+ return_batch[k] = move_batch_to_device(batch[k], device)
77
+ else:
78
+ logger.debug(f"Can not move type {type(batch)} to device. Skipping it in the batch.")
79
+ return_batch = batch
80
+
81
+ return return_batch
82
+
83
+
84
+ def cast_batch_to_dtype(batch, dtype):
85
+ """
86
+ Cast the float32 tensors in a batch to a specified torch dtype.
87
+ It should be called before feeding the batch to the FP16 DeepSpeed model.
88
+
89
+ Args:
90
+ batch (torch.tensor or container of torch.tensor): input batch
91
+ Returns:
92
+ return_batch: same type as the input batch with internal float32 tensors casted to the specified dtype.
93
+ """
94
+ if torch.is_tensor(batch):
95
+ if torch.is_floating_point(batch):
96
+ return_batch = batch.to(dtype)
97
+ else:
98
+ return_batch = batch
99
+ elif isinstance(batch, list):
100
+ return_batch = [cast_batch_to_dtype(t, dtype) for t in batch]
101
+ elif isinstance(batch, tuple):
102
+ return_batch = tuple(cast_batch_to_dtype(t, dtype) for t in batch)
103
+ elif isinstance(batch, dict):
104
+ return_batch = {}
105
+ for k in batch:
106
+ return_batch[k] = cast_batch_to_dtype(batch[k], dtype)
107
+ else:
108
+ logger.debug(f"Can not cast type {type(batch)} to {dtype}. Skipping it in the batch.")
109
+ return_batch = batch
110
+
111
+ return return_batch
112
+
113
+
114
+ def cast_batch_to_half(batch):
115
+ """
116
+ Cast the float32 tensors in a batch to float16.
117
+ It should be called before feeding the batch to the FP16 DeepSpeed model.
118
+
119
+ Args:
120
+ batch (torch.tensor or container of torch.tensor): input batch
121
+ Returns:
122
+ return_batch: same type as the input batch with internal float32 tensors casted to float16
123
+ """
124
+ return cast_batch_to_dtype(batch, torch.float16)
125
+
126
+
127
+ def cast_batch_to_bf16(batch):
128
+ """
129
+ Cast the float32 tensors in a batch to bfloat16.
130
+ It should be called before feeding the batch to the BF16 DeepSpeed model.
131
+
132
+ Args:
133
+ batch (torch.tensor or container of torch.tensor): input batch
134
+ Returns:
135
+ return_batch: same type as the input batch with internal float32 tensors casted to bfloat16
136
+ """
137
+ return cast_batch_to_dtype(batch, torch.bfloat16)
138
+
139
+
140
+ # copied from MainzSpeech/moe_tools
141
+ def peek_first_item_from_iterator(it):
142
+ # extract first item from iterator
143
+ first_item = next(it)
144
+ # create iterator with the first item added back in
145
+ new_it = itertools.chain([copy.deepcopy(first_item)], it)
146
+ return first_item, new_it
147
+
148
+
149
+ # copied from MainzSpeech/moe_tools
150
+ def generate_dummy_batch(it):
151
+ """
152
+ Generates a dummy batch by peeking at given iterable or iterator on rank 0,
153
+ then broadcast dummy_batch to all other ranks.
154
+ """
155
+ from mpi4py import MPI
156
+ assert isinstance(it, Iterable) or isinstance(it, Iterator)
157
+ if isinstance(it, Iterable):
158
+ it = iter(it)
159
+ if MPI.COMM_WORLD.Get_rank() == 0:
160
+ dummy_batch, it = peek_first_item_from_iterator(it)
161
+ else:
162
+ dummy_batch = None
163
+ dummy_batch = MPI.COMM_WORLD.bcast(dummy_batch, root=0)
164
+ assert dummy_batch is not None
165
+ return dummy_batch, it
166
+
167
+
168
+ def retry_on_failure(func=None, *, max_retries=3, on_error_func=None, on_retry_func=None, raise_err_func=None, sleep_time=30, error_types=(Exception,)):
169
+ """
170
+ Decorator utility to retry a function; it must be used without arguments (@retry_on_failure) or with all named arguments (@retry_on_failure(max_retries=10)).
171
+ Args:
172
+ max_retries (int): The number of retries to perform, in addition to the initial retry. Defaults to 3.
173
+ sleep_time (int): The time in seconds to wait before the next retry. Defaults to 30.
174
+ error_types (Tuple[type]): a tuple of exception types which are used to except any error being retried, if the exception that is thrown is not an instance of one of these types, the function is not retried. Defaults to (Exception,) which covers all exceptions.
175
+ on_retry_func (callable(num_retries)): A function with a single argument, the number of retries done so far. This function is called just before any retry. Defaults to a function logging `num_retries`.
176
+ on_error_func (callable(num_retries)): A function with a single argument, the number of retries done in total. This function is called after `max_retries` has been tried. Defaults to a function logging `num_retries`.
177
+ raise_err_func (callable(err)): A function with a single argument, the exception that was thrown. This function is called after `max_retries` has been tried. Defaults to raising the error.
178
+ """
179
+ if on_error_func is None:
180
+ def on_error_func(retried_times):
181
+ logger.warning(f"Failed after retrying {retried_times} times")
182
+
183
+ if on_retry_func is None:
184
+ def on_retry_func(idx):
185
+ logger.warning(f"Retrying on failure {idx}")
186
+
187
+ if raise_err_func is None:
188
+ def raise_err_func(err):
189
+ raise err
190
+
191
+ if func is None:
192
+ return partial(
193
+ retry_on_failure,
194
+ max_retries=max_retries,
195
+ on_error_func=on_error_func,
196
+ on_retry_func=on_retry_func,
197
+ raise_err_func=raise_err_func,
198
+ sleep_time=sleep_time,
199
+ error_types=error_types,
200
+ )
201
+
202
+ @wraps(func)
203
+ def decorator(*args, **kwargs):
204
+ num_retries = 0
205
+ while True:
206
+ try:
207
+ return func(*args, **kwargs)
208
+ except error_types as err:
209
+ num_retries += 1
210
+ on_retry_func(num_retries)
211
+ if num_retries > max_retries:
212
+ on_error_func(num_retries)
213
+ raise_err_func(err)
214
+ time.sleep(sleep_time)
215
+
216
+ return decorator
217
+
218
+
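An illustrative usage sketch (not part of the diff; `download_blob` and `flaky_io_call` are hypothetical names): the decorator can be applied bare or with keyword arguments only.

    @retry_on_failure                      # bare form: up to 3 retries, 30 s apart
    def download_blob(url):
        ...

    @retry_on_failure(max_retries=10, sleep_time=5, error_types=(IOError,))
    def flaky_io_call():
        ...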
219
+ class TemporaryRngState:
220
+ '''
221
+ Context manager for working with a temporary random number generator (RNG) state.
222
+ The constructor gets a random number from the Python RNG that is used as
223
+ (part of) the seed for the temporary RNG
224
+ and then stores the current RNG state so it can be restored later on.
225
+ If add_rank_to_seed=True, the GPU rank is added to the seed.
226
+ This is useful to initialize MoE models
227
+ where the experts on different GPUs should be initialized independently.
228
+ Note that this feature requires torch.distributed to be initialized.
229
+ On enter, the context manager sets the RNG state to the random seed created in the constructor
230
+ to establish a temporary RNG state.
231
+ On exit, the context manager resets the RNG state to the previously remembered state.
232
+ Thereby, any RNG operations executed with this context manager
233
+ do not affect the global, non-temporary RNG state.
234
+ However, the usage of this context manager does advance the Python RNG
235
+ since it uses that RNG to generate the random seed in the constructor.
236
+ The context manager resets the Python RNG state and
237
+ the PyTorch RNG state for CPU and GPU (if cuda is initialized).
238
+ It does not currently reset the numpy RNG state.
239
+ '''
240
+ def __init__(self, add_rank_to_seed=False):
241
+ self.seed = random.randrange(2**32)
242
+ if add_rank_to_seed and dist.is_initialized():
243
+ self.seed += dist.get_rank()
244
+ self.python_rng_state = random.getstate()
245
+ self.torch_rng_state = torch.get_rng_state()
246
+ if torch.cuda.is_initialized():
247
+ self.torch_rng_state_cuda = torch.cuda.get_rng_state()
248
+
249
+ def __enter__(self):
250
+ # increment seed for different RNGs to avoid correlation
251
+ # in the (very unlikely) case that the different RNGs
252
+ # use the exact same algorithm
253
+ random.seed(self.seed)
254
+ # do not call torch.manual_seed here, because that sets the seed of all GPUs
255
+ default_generator.manual_seed(self.seed + 1)
256
+ if torch.cuda.is_initialized():
257
+ torch.cuda.manual_seed(self.seed + 2) # only set seed of default cuda device
258
+
259
+ def __exit__(self, exc_type, exc_value, exc_traceback):
260
+ random.setstate(self.python_rng_state)
261
+ torch.set_rng_state(self.torch_rng_state)
262
+ if torch.cuda.is_initialized():
263
+ torch.cuda.set_rng_state(self.torch_rng_state_cuda)
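A sketch of the usage pattern the docstring describes (`init_expert_parameters` is a hypothetical per-rank initialization step):

    with TemporaryRngState(add_rank_to_seed=True):
        init_expert_parameters(model)      # rank-dependent randomness, e.g. MoE experts
    # on exit the Python and torch RNG states are restored, so later code is unaffected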
MedImageInsight/Utils/GlobalExceptHook.py ADDED
@@ -0,0 +1,61 @@
1
+ import sys
2
+ import logging
3
+
4
+ logger = logging.getLogger(__name__)
5
+
6
+ _orig_except_hook = None
7
+
8
+
9
+ def _global_except_hook(exctype, value, traceback):
10
+ """Catches an unhandled exception and calls MPI_Abort()."""
11
+ try:
12
+ if _orig_except_hook:
13
+ _orig_except_hook(exctype, value, traceback)
14
+ else:
15
+ sys.__excepthook__(exctype, value, traceback)
16
+
17
+ finally:
18
+ import mpi4py.MPI
19
+ rank = mpi4py.MPI.COMM_WORLD.Get_rank()
20
+ logger.warning("******************************************")
21
+ logger.warning("MainzTrainer:")
22
+ logger.warning(f" Uncaught exception on rank {rank}.")
23
+ logger.warning(" Calling MPI_Abort() to shut down MPI...")
24
+ logger.warning("******************************************")
25
+ logging.shutdown()
26
+
27
+ try:
28
+ import mpi4py.MPI
29
+ mpi4py.MPI.COMM_WORLD.Abort(1)
30
+ except Exception as e:
31
+ # Something is completely broken...
32
+ # There's nothing we can do any more
33
+ sys.stderr.write("Sorry, failed to stop MPI and the process may hang.\n")
34
+ sys.stderr.flush()
35
+ raise e
36
+
37
+
38
+ def add_hook():
39
+ """
40
+ Add a global hook function that captures all unhandled exceptions.
41
+ The function calls MPI_Abort() to force all processes abort.
42
+
43
+ An MPI runtime is expected to kill all of its child processes
44
+ if one of them exits abnormally or without calling `MPI_Finalize()`.
45
+ However, when a Python program runs on `mpi4py`, the MPI runtime
46
+ often fails to detect a process failure, and the rest of the processes
47
+ hang indefinitely.
48
+
49
+ See https://github.com/chainer/chainermn/issues/236 and
50
+ https://mpi4py.readthedocs.io/en/stable/mpi4py.run.html for more
51
+ information.
52
+ """
53
+ global _orig_except_hook
54
+
55
+ if _orig_except_hook is not None:
56
+ logger.warning("GlobalExceptHook.add_hook() seems to be called multiple times. Ignoring.")
57
+ return
58
+
59
+ logger.info("Adding global except hook for the distributed job to shutdown MPI if unhandled exception is raised on some of the ranks.")
60
+ _orig_except_hook = sys.excepthook
61
+ sys.excepthook = _global_except_hook
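Typical usage is a single call near the top of the training entry point (a sketch, assuming the package layout of this repository):

    from MedImageInsight.Utils import GlobalExceptHook
    GlobalExceptHook.add_hook()    # an uncaught exception on any rank now triggers MPI_Abort()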
MedImageInsight/Utils/MPIAdapter.py ADDED
@@ -0,0 +1,147 @@
1
+ import logging
2
+ from mpi4py import MPI
3
+ import os
4
+ import re
5
+ import subprocess
6
+ import torch
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+
11
+ class MPIAdapter:
12
+ """
13
+ MPIAdapter automatically detects and analyzes the training environment for distributed training
14
+ and offers methods to set up distributed training jobs.
15
+
16
+ For example, it determines whether training happens on AML, Philly, or locally.
17
+ It also determines variables such as the world size and the rank of each GPU.
18
+ """
19
+
20
+ def __init__(self, set_env_vars=True, master_address=None, port='55551'):
21
+ local_address = '127.0.0.1'
22
+ default_torch_distributed_port = str(port) # chosen arbitrarily
23
+
24
+ if 'OMPI_COMM_WORLD_SIZE' not in os.environ:
25
+ # application was started without MPI
26
+ # default to single node with single process
27
+ self.env_info = 'no MPI'
28
+ self.world_size = 1
29
+ self.local_size = 1
30
+ self.rank = 0
31
+ self.local_rank = 0
32
+ self.master_address = local_address
33
+ self.master_port = default_torch_distributed_port
34
+ else:
35
+ # application was started with MPI
36
+ # get MPI parameters
37
+ self.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
38
+ self.local_size = int(os.environ['OMPI_COMM_WORLD_LOCAL_SIZE'])
39
+ self.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
40
+ self.local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
41
+
42
+ if master_address is not None:
43
+ self.master_address = master_address
44
+ self.master_port = default_torch_distributed_port
45
+ self.env_info = 'manually set master ip'
46
+ elif 'PHILLY_CONTAINER_IP' in os.environ:
47
+ # application is running on Philly
48
+ # read environment variables on master node and broadcast via MPI
49
+ self.env_info = 'philly'
50
+ if self.rank == 0:
51
+ self.master_address = os.environ['PHILLY_CONTAINER_IP']
52
+ self.master_port = os.environ['PHILLY_CONTAINER_PORT_RANGE_START']
53
+ else:
54
+ self.master_address = None
55
+ self.master_port = None
56
+ self.master_address = MPI.COMM_WORLD.bcast(self.master_address, root=0)
57
+ self.master_port = MPI.COMM_WORLD.bcast(self.master_port, root=0)
58
+ elif "AMLK8S_NUM_WORKER" in os.environ or "AZ_CMK8S_JOB_WORK_DIR" in os.environ:
59
+ # application is running on AMLK8S (ITP)
60
+ # read master address from a specific file.
61
+ self.env_info = 'AMLK8S (ITP)'
62
+ # from: https://k8s-wiki.azureml.com/faq.html
63
+ regexp = r"[\s\S]*export[\s]*DLTS_SD_worker0_IP=([0-9.]+)[\s|s]*"
64
+ with open("/dlts-runtime/env/init.env", 'r') as f:
65
+ line = f.read()
66
+ match = re.match(regexp, line)
67
+ if match:
68
+ self.master_address = str(match.group(1))
69
+ else:
70
+ # Did not find master node ip in file. It must be a single-node
71
+ # debugging job with custom "mpirun" command
72
+ assert self.world_size == self.local_size, \
73
+ "It's not a single-node debugging job on AMLK8S (ITP), but no master ip is found in file."
74
+ self.env_info = 'single-node AMLK8S (ITP) debugging job'
75
+ self.master_address = local_address
76
+ self.master_port = default_torch_distributed_port
77
+ elif 'AZ_BATCH_MASTER_NODE' in os.environ:
78
+ # application is running on multiple nodes on AML
79
+ self.env_info = 'multi-node AML'
80
+ master_node_params = os.environ['AZ_BATCH_MASTER_NODE'].split(':')
81
+ self.master_address = master_node_params[0]
82
+ self.master_port = default_torch_distributed_port
83
+ elif self.world_size == self.local_size:
84
+ # application is running with MPI on single node
85
+ self.env_info = 'single-node AML or other MPI environment'
86
+ self.master_address = local_address
87
+ self.master_port = default_torch_distributed_port
88
+ else:
89
+ # multi-node MPI environment, but not Philly or AML
90
+ # we use "hostname -I" command on rank 0 to get the master address
91
+ self.env_info = 'multi-node other MPI environment'
92
+ if self.rank == 0:
93
+ hostname_cmd = ["hostname -I"]
94
+ result = subprocess.check_output(hostname_cmd, shell=True)
95
+ self.master_address = result.decode('utf-8').split()[0]
96
+ self.master_port = default_torch_distributed_port
97
+ else:
98
+ self.master_address = None
99
+ self.master_port = None
100
+ self.master_address = MPI.COMM_WORLD.bcast(self.master_address, root=0)
101
+ self.master_port = MPI.COMM_WORLD.bcast(self.master_port, root=0)
102
+
103
+ self.init_method_url = f'tcp://{self.master_address}:{self.master_port}'
104
+ if set_env_vars:
105
+ self._set_env_vars()
106
+
107
+ def log_info(self):
108
+ """
109
+ Logs information about distributed training environment.
110
+ """
111
+ # use logger.warning because MainzTrain has a hidden convention
112
+ # of not printing logger.info messages on processes with rank > 0
113
+ logger.warning('----------------')
114
+ logger.warning('MPI Adapter data')
115
+ logger.warning('----------------')
116
+ logger.warning(f'environment info: {self.env_info}')
117
+ logger.warning(f'init method url: {self.init_method_url}')
118
+ logger.warning(f'world size: {self.world_size}')
119
+ logger.warning(f'local size: {self.local_size}')
120
+ logger.warning(f'rank: {self.rank}')
121
+ logger.warning(f'local rank: {self.local_rank}')
122
+ logger.warning(f'master address: {self.master_address}')
123
+ logger.warning(f'master port: {self.master_port}')
124
+ logger.warning('----------------')
125
+
126
+ def init_process_group(self, backend):
127
+ """
128
+ Initializes the default PyTorch distributed process group.
129
+ """
130
+ # use logger.warning because MainzTrain has a hidden convention
131
+ # of not printing logger.info messages on processes with rank > 0
132
+ logger.warning('trying to initialize process group ...')
133
+ torch.distributed.init_process_group(backend=backend,
134
+ init_method=self.init_method_url,
135
+ world_size=self.world_size,
136
+ rank=self.rank)
137
+ logger.warning('process group initialized')
138
+
139
+ def _set_env_vars(self):
140
+ """
141
+ Sets environment variables for world size, rank, local rank, master addr, and master port.
142
+ """
143
+ os.environ['WORLD_SIZE'] = str(self.world_size)
144
+ os.environ['RANK'] = str(self.rank)
145
+ os.environ["LOCAL_RANK"] = str(self.local_rank)
146
+ os.environ['MASTER_ADDR'] = self.master_address
147
+ os.environ['MASTER_PORT'] = self.master_port
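A sketch of a typical call sequence (the 'nccl' backend and the torch.cuda.set_device call are assumptions, not requirements of the class):

    adapter = MPIAdapter()                       # detect environment, set WORLD_SIZE/RANK/... env vars
    adapter.log_info()
    torch.cuda.set_device(adapter.local_rank)
    adapter.init_process_group(backend='nccl')   # wraps torch.distributed.init_process_group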
MedImageInsight/Utils/Utils.py ADDED
@@ -0,0 +1,141 @@
1
+ import logging
2
+ import os
3
+ import torch
4
+ import torch.distributed as dist
5
+ import yaml
6
+
7
+ from fvcore.nn import FlopCountAnalysis
8
+ from fvcore.nn import flop_count_table
9
+ from fvcore.nn import flop_count_str
10
+
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
+ NORM_MODULES = [
16
+ torch.nn.BatchNorm1d,
17
+ torch.nn.BatchNorm2d,
18
+ torch.nn.BatchNorm3d,
19
+ torch.nn.SyncBatchNorm,
20
+ # NaiveSyncBatchNorm inherits from BatchNorm2d
21
+ torch.nn.GroupNorm,
22
+ torch.nn.InstanceNorm1d,
23
+ torch.nn.InstanceNorm2d,
24
+ torch.nn.InstanceNorm3d,
25
+ torch.nn.LayerNorm,
26
+ torch.nn.LocalResponseNorm,
27
+ ]
28
+
29
+ def register_norm_module(cls):
30
+ NORM_MODULES.append(cls)
31
+
32
+ return cls
33
+
34
+
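register_norm_module is intended as a class decorator, so custom normalization layers can be added to NORM_MODULES alongside the built-in ones (a sketch; MyRMSNorm is hypothetical):

    @register_norm_module
    class MyRMSNorm(torch.nn.Module):
        ...    # now included in NORM_MODULES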
35
+ def is_main_process():
36
+ rank = 0
37
+ if 'OMPI_COMM_WORLD_SIZE' in os.environ:
38
+ rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
39
+
40
+ return rank == 0
41
+
42
+
43
+ @torch.no_grad()
44
+ def analysis_model(model, dump_input, verbose=False):
45
+ model.eval()
46
+ flops = FlopCountAnalysis(model, dump_input)
47
+ total = flops.total()
48
+ model.train()
49
+ params_total = sum(p.numel() for p in model.parameters())
50
+ params_learned = sum(
51
+ p.numel() for p in model.parameters() if p.requires_grad
52
+ )
53
+ logger.info(f"flop count table:\n {flop_count_table(flops)}")
54
+ if verbose:
55
+ logger.info(f"flop count str:\n {flop_count_str(flops)}")
56
+ logger.info(f" Total flops: {total/1000/1000:.3f}M,")
57
+ logger.info(f" Total params: {params_total/1000/1000:.3f}M,")
58
+ logger.info(f" Learned params: {params_learned/1000/1000:.3f}M")
59
+
60
+ return total, flop_count_table(flops), flop_count_str(flops)
61
+
62
+
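Example invocation (a sketch; the 1x3x224x224 input shape is an arbitrary choice and `model` is any nn.Module):

    dummy_input = torch.randn(1, 3, 224, 224)
    total_flops, flop_table, _ = analysis_model(model, dummy_input)
    # logs the fvcore flop-count table plus total/learned parameter counts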
63
+ def load_config_dict_to_opt(opt, config_dict, splitter='.'):
64
+ """
65
+ Load the key, value pairs from config_dict to opt, overriding existing values in opt
66
+ if there are any.
67
+ """
68
+ if not isinstance(config_dict, dict):
69
+ raise TypeError("Config must be a Python dictionary")
70
+ for k, v in config_dict.items():
71
+ k_parts = k.split(splitter)
72
+ pointer = opt
73
+ for k_part in k_parts[:-1]:
74
+ if k_part not in pointer:
75
+ pointer[k_part] = {}
76
+ pointer = pointer[k_part]
77
+ assert isinstance(pointer, dict), "Overriding key needs to be inside a Python dict."
78
+ ori_value = pointer.get(k_parts[-1])
79
+ pointer[k_parts[-1]] = v
80
+ if ori_value:
81
+ print(f"Overrode {k} from {ori_value} to {pointer[k_parts[-1]]}")
82
+
83
+
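For example, dotted keys are expanded into nested dictionaries (a minimal sketch of the behaviour above):

    opt = {'TRAIN': {'LR': 0.1}}
    load_config_dict_to_opt(opt, {'TRAIN.LR': 0.01, 'TRAIN.EPOCHS': 10})
    # opt is now {'TRAIN': {'LR': 0.01, 'EPOCHS': 10}}; the LR override is printed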
84
+ def load_opt_from_config_file(conf_file):
85
+ """
86
+ Load opt from the config file.
87
+
88
+ Args:
89
+ conf_file: config file path
90
+
91
+ Returns:
92
+ dict: a dictionary of opt settings
93
+ """
94
+ opt = {}
95
+ with open(conf_file, encoding='utf-8') as f:
96
+ config_dict = yaml.safe_load(f)
97
+ load_config_dict_to_opt(opt, config_dict)
98
+
99
+ return opt
100
+
101
+ def cast_batch_to_dtype(batch, dtype):
102
+ """
103
+ Cast the float32 tensors in a batch to a specified torch dtype.
104
+ It should be called before feeding the batch to a reduced-precision (e.g. FP16 or bf16) DeepSpeed model.
105
+
106
+ Args:
107
+ batch (torch.tensor or container of torch.tensor): input batch
108
+ Returns:
109
+ return_batch: same type as the input batch with internal float32 tensors cast to the specified dtype.
110
+ """
111
+ if torch.is_tensor(batch):
112
+ if torch.is_floating_point(batch):
113
+ return_batch = batch.to(dtype)
114
+ else:
115
+ return_batch = batch
116
+ elif isinstance(batch, list):
117
+ return_batch = [cast_batch_to_dtype(t, dtype) for t in batch]
118
+ elif isinstance(batch, tuple):
119
+ return_batch = tuple(cast_batch_to_dtype(t, dtype) for t in batch)
120
+ elif isinstance(batch, dict):
121
+ return_batch = {}
122
+ for k in batch:
123
+ return_batch[k] = cast_batch_to_dtype(batch[k], dtype)
124
+ else:
125
+ logger.debug(f"Cannot cast type {type(batch)} to {dtype}. Skipping it in the batch.")
126
+ return_batch = batch
127
+
128
+ return return_batch
129
+
130
+
131
+ def cast_batch_to_half(batch):
132
+ """
133
+ Cast the float32 tensors in a batch to float16.
134
+ It should be called before feeding the batch to the FP16 DeepSpeed model.
135
+
136
+ Args:
137
+ batch (torch.tensor or container of torch.tensor): input batch
138
+ Returns:
139
+ return_batch: same type as the input batch with internal float32 tensors cast to float16
140
+ """
141
+ return cast_batch_to_dtype(batch, torch.float16)
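A small sanity check of the recursive casting behaviour (a sketch; the batch layout is arbitrary):

    batch = {'pixels': torch.randn(2, 3), 'labels': torch.tensor([0, 1]), 'ids': ['a', 'b']}
    half_batch = cast_batch_to_half(batch)
    # half_batch['pixels'].dtype == torch.float16; 'labels' keeps torch.int64;
    # the list of strings is passed through unchanged (with a debug log message)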
MedImageInsight/Utils/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ from .Utils import analysis_model
2
+ from .Utils import is_main_process
3
+ from .Utils import register_norm_module
4
+ from .Utils import NORM_MODULES
5
+ from .Utils import load_config_dict_to_opt
6
+ from .Utils import load_opt_from_config_file
7
+ from .Utils import cast_batch_to_half
MedImageInsight/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ from .UniCLModel import build_unicl_model
2
+
3
+ __all__ = [
4
+ # 'build_od_model',  # not imported in this package; commented out so `from MedImageInsight import *` works
5
+ 'build_unicl_model',
6
+ # 'build_tokenizer_from_name',  # not imported in this package
7
+ # 'get_image_preprocess',  # not imported in this package
8
+ # 'build_unicl_matching_model'  # not imported in this package
9
+ ]