This view is limited to 50 files because the diff contains too many changes. See the raw diff for the complete change set.
Files changed (50)
  1. README.md +4 -5
  2. SegVol_v1.pth +0 -3
  3. __pycache__/utils.cpython-39.pyc +0 -0
  4. app.py +0 -339
  5. model/LICENSE +0 -21
  6. model/README.md +0 -74
  7. model/__pycache__/inference_cpu.cpython-39.pyc +0 -0
  8. model/asset/FLARE22_Tr_0002_0000.nii.gz +0 -3
  9. model/asset/FLARE22_Tr_0005_0000.nii.gz +0 -3
  10. model/asset/FLARE22_Tr_0034_0000.nii.gz +0 -3
  11. model/asset/FLARE22_Tr_0045_0000.nii.gz +0 -3
  12. model/asset/model.png +0 -0
  13. model/asset/overview back.png +0 -0
  14. model/asset/overview.png +0 -0
  15. model/config/clip/config.json +0 -157
  16. model/config/clip/special_tokens_map.json +0 -1
  17. model/config/clip/tokenizer.json +0 -0
  18. model/config/clip/tokenizer_config.json +0 -1
  19. model/config/clip/vocab.json +0 -0
  20. model/config/config_demo.json +0 -8
  21. model/data_process/__pycache__/demo_data_process.cpython-39.pyc +0 -0
  22. model/data_process/demo_data_process.py +0 -95
  23. model/inference_cpu.py +0 -172
  24. model/inference_demo.py +0 -219
  25. model/network/__pycache__/model.cpython-39.pyc +0 -0
  26. model/network/model.py +0 -91
  27. model/script/inference_demo.sh +0 -8
  28. model/segment_anything_volumetric/.ipynb_checkpoints/build_sam-checkpoint.py +0 -172
  29. model/segment_anything_volumetric/__init__.py +0 -12
  30. model/segment_anything_volumetric/__pycache__/__init__.cpython-310.pyc +0 -0
  31. model/segment_anything_volumetric/__pycache__/__init__.cpython-39.pyc +0 -0
  32. model/segment_anything_volumetric/__pycache__/automatic_mask_generator.cpython-310.pyc +0 -0
  33. model/segment_anything_volumetric/__pycache__/automatic_mask_generator.cpython-39.pyc +0 -0
  34. model/segment_anything_volumetric/__pycache__/build_sam.cpython-310.pyc +0 -0
  35. model/segment_anything_volumetric/__pycache__/build_sam.cpython-39.pyc +0 -0
  36. model/segment_anything_volumetric/__pycache__/predictor.cpython-310.pyc +0 -0
  37. model/segment_anything_volumetric/__pycache__/predictor.cpython-39.pyc +0 -0
  38. model/segment_anything_volumetric/automatic_mask_generator.py +0 -372
  39. model/segment_anything_volumetric/build_sam.py +0 -111
  40. model/segment_anything_volumetric/modeling/.ipynb_checkpoints/image_encoder_swin-checkpoint.py +0 -709
  41. model/segment_anything_volumetric/modeling/.ipynb_checkpoints/prompt_encoder-checkpoint.py +0 -232
  42. model/segment_anything_volumetric/modeling/__init__.py +0 -11
  43. model/segment_anything_volumetric/modeling/__pycache__/__init__.cpython-310.pyc +0 -0
  44. model/segment_anything_volumetric/modeling/__pycache__/__init__.cpython-39.pyc +0 -0
  45. model/segment_anything_volumetric/modeling/__pycache__/common.cpython-310.pyc +0 -0
  46. model/segment_anything_volumetric/modeling/__pycache__/common.cpython-39.pyc +0 -0
  47. model/segment_anything_volumetric/modeling/__pycache__/image_encoder.cpython-310.pyc +0 -0
  48. model/segment_anything_volumetric/modeling/__pycache__/image_encoder.cpython-39.pyc +0 -0
  49. model/segment_anything_volumetric/modeling/__pycache__/image_encoder_swin.cpython-39.pyc +0 -0
  50. model/segment_anything_volumetric/modeling/__pycache__/mask_decoder.cpython-310.pyc +0 -0
README.md CHANGED
@@ -1,13 +1,12 @@
  ---
  title: SegVol
- emoji: 🏢
- colorFrom: indigo
- colorTo: blue
+ emoji: 📈
+ colorFrom: gray
+ colorTo: red
  sdk: streamlit
- sdk_version: 1.28.2
+ sdk_version: 1.29.0
  app_file: app.py
  pinned: false
- license: mit
  ---
 
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
SegVol_v1.pth DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:b751dc95f1a0c0c6086c1e6fa7f8a17bbb87635e5226e15f5d156fbd364dbb85
- size 1660308695
 
__pycache__/utils.cpython-39.pyc DELETED
Binary file (3.88 kB)
 
app.py DELETED
@@ -1,339 +0,0 @@
1
- import streamlit as st
2
- from streamlit_drawable_canvas import st_canvas
3
- from streamlit_image_coordinates import streamlit_image_coordinates
4
-
5
-
6
- from model.data_process.demo_data_process import process_ct_gt
7
- import numpy as np
8
- import matplotlib.pyplot as plt
9
- from PIL import Image, ImageDraw
10
- import monai.transforms as transforms
11
- from utils import show_points, make_fig, reflect_points_into_model, initial_rectangle, reflect_json_data_to_3D_box, reflect_box_into_model, run
12
- import nibabel as nib
13
- import tempfile
14
-
15
- print('script run')
16
-
17
- #############################################
18
- # init session_state
19
- if 'option' not in st.session_state:
20
- st.session_state.option = None
21
- if 'text_prompt' not in st.session_state:
22
- st.session_state.text_prompt = None
23
-
24
- if 'reset_demo_case' not in st.session_state:
25
- st.session_state.reset_demo_case = False
26
-
27
- if 'preds_3D' not in st.session_state:
28
- st.session_state.preds_3D = None
29
- st.session_state.preds_3D_ori = None
30
-
31
- if 'data_item' not in st.session_state:
32
- st.session_state.data_item = None
33
-
34
- if 'points' not in st.session_state:
35
- st.session_state.points = []
36
-
37
- if 'use_text_prompt' not in st.session_state:
38
- st.session_state.use_text_prompt = False
39
-
40
- if 'use_point_prompt' not in st.session_state:
41
- st.session_state.use_point_prompt = False
42
-
43
- if 'use_box_prompt' not in st.session_state:
44
- st.session_state.use_box_prompt = False
45
-
46
- if 'rectangle_3Dbox' not in st.session_state:
47
- st.session_state.rectangle_3Dbox = [0,0,0,0,0,0]
48
-
49
- if 'irregular_box' not in st.session_state:
50
- st.session_state.irregular_box = False
51
-
52
- if 'running' not in st.session_state:
53
- st.session_state.running = False
54
-
55
- if 'transparency' not in st.session_state:
56
- st.session_state.transparency = 0.25
57
-
58
- case_list = [
59
- 'model/asset/FLARE22_Tr_0002_0000.nii.gz',
60
- 'model/asset/FLARE22_Tr_0005_0000.nii.gz',
61
- 'model/asset/FLARE22_Tr_0034_0000.nii.gz',
62
- 'model/asset/FLARE22_Tr_0045_0000.nii.gz'
63
- ]
64
-
65
- #############################################
66
-
67
- #############################################
68
- # reset functions
69
- def clear_prompts():
70
- st.session_state.points = []
71
- st.session_state.rectangle_3Dbox = [0,0,0,0,0,0]
72
-
73
- def reset_demo_case():
74
- st.session_state.data_item = None
75
- st.session_state.reset_demo_case = True
76
- clear_prompts()
77
-
78
- def clear_file():
79
- st.session_state.option = None
80
- process_ct_gt.clear()
81
- reset_demo_case()
82
- clear_prompts()
83
-
84
- #############################################
85
-
86
- st.image(Image.open('model/asset/overview back.png'), use_column_width=True)
87
-
88
- github_col, arxive_col = st.columns(2)
89
-
90
- with github_col:
91
- st.write('GitHub repo:https://github.com/BAAI-DCAI/SegVol')
92
-
93
- with arxive_col:
94
- st.write('Paper:https://arxiv.org/abs/2311.13385')
95
-
96
-
97
- # modify demo case here
98
- demo_type = st.radio(
99
- "Demo case source",
100
- ["Select", "Upload"],
101
- on_change=clear_file
102
- )
103
-
104
- if demo_type=="Select":
105
- uploaded_file = st.selectbox(
106
- "Select a demo case",
107
- case_list,
108
- index=None,
109
- placeholder="Select a demo case...",
110
- on_change=reset_demo_case
111
- )
112
- else:
113
- uploaded_file = st.file_uploader("Upload demo case(nii.gz)", type='nii.gz', on_change=reset_demo_case)
114
-
115
- st.session_state.option = uploaded_file
116
-
117
- if st.session_state.option is not None and \
118
- st.session_state.reset_demo_case or (st.session_state.data_item is None and st.session_state.option is not None):
119
-
120
- st.session_state.data_item = process_ct_gt(st.session_state.option)
121
- st.session_state.reset_demo_case = False
122
- st.session_state.preds_3D = None
123
- st.session_state.preds_3D_ori = None
124
-
125
- prompt_col1, prompt_col2 = st.columns(2)
126
-
127
- with prompt_col1:
128
- st.session_state.use_text_prompt = st.toggle('Sematic prompt')
129
- text_prompt_type = st.radio(
130
- "Sematic prompt type",
131
- ["Predefined", "Custom"],
132
- disabled=(not st.session_state.use_text_prompt)
133
- )
134
- if text_prompt_type == "Predefined":
135
- pre_text = st.selectbox(
136
- "Predefined anatomical category:",
137
- ['liver', 'right kidney', 'spleen', 'pancreas', 'aorta', 'inferior vena cava', 'right adrenal gland', 'left adrenal gland', 'gallbladder', 'esophagus', 'stomach', 'duodenum', 'left kidney'],
138
- index=None,
139
- disabled=(not st.session_state.use_text_prompt)
140
- )
141
- else:
142
- pre_text = st.text_input('Enter an Anatomical word or phrase:', None, max_chars=20,
143
- disabled=(not st.session_state.use_text_prompt))
144
- if pre_text is None or len(pre_text) > 0:
145
- st.session_state.text_prompt = pre_text
146
- else:
147
- st.session_state.text_prompt = None
148
-
149
-
150
- with prompt_col2:
151
- spatial_prompt_on = st.toggle('Spatial prompt', on_change=clear_prompts)
152
- spatial_prompt = st.radio(
153
- "Spatial prompt type",
154
- ["Point prompt", "Box prompt"],
155
- on_change=clear_prompts,
156
- disabled=(not spatial_prompt_on))
157
- st.session_state.enforce_zoom = st.checkbox('Enforce zoom-out-zoom-in')
158
-
159
- if spatial_prompt == "Point prompt":
160
- st.session_state.use_point_prompt = True
161
- st.session_state.use_box_prompt = False
162
- elif spatial_prompt == "Box prompt":
163
- st.session_state.use_box_prompt = True
164
- st.session_state.use_point_prompt = False
165
- else:
166
- st.session_state.use_point_prompt = False
167
- st.session_state.use_box_prompt = False
168
-
169
- if not spatial_prompt_on:
170
- st.session_state.use_point_prompt = False
171
- st.session_state.use_box_prompt = False
172
-
173
- if not st.session_state.use_text_prompt:
174
- st.session_state.text_prompt = None
175
-
176
- if st.session_state.option is None:
177
- st.write('please select demo case first')
178
- else:
179
- image_3D = st.session_state.data_item['z_image'][0].numpy()
180
- col_control1, col_control2 = st.columns(2)
181
-
182
- with col_control1:
183
- selected_index_z = st.slider('X-Y view', 0, image_3D.shape[0] - 1, 162, key='xy', disabled=st.session_state.running)
184
-
185
- with col_control2:
186
- selected_index_y = st.slider('X-Z view', 0, image_3D.shape[1] - 1, 162, key='xz', disabled=st.session_state.running)
187
- if st.session_state.use_box_prompt:
188
- top, bottom = st.select_slider(
189
- 'Top and bottom of box',
190
- options=range(0, 325),
191
- value=(0, 324),
192
- disabled=st.session_state.running
193
- )
194
- st.session_state.rectangle_3Dbox[0] = top
195
- st.session_state.rectangle_3Dbox[3] = bottom
196
- col_image1, col_image2 = st.columns(2)
197
-
198
- if st.session_state.preds_3D is not None:
199
- st.session_state.transparency = st.slider('Mask opacity', 0.0, 1.0, 0.25, disabled=st.session_state.running)
200
-
201
- with col_image1:
202
-
203
- image_z_array = image_3D[selected_index_z]
204
-
205
- preds_z_array = None
206
- if st.session_state.preds_3D is not None:
207
- preds_z_array = st.session_state.preds_3D[selected_index_z]
208
-
209
- image_z = make_fig(image_z_array, preds_z_array, st.session_state.points, selected_index_z, 'xy')
210
-
211
-
212
- if st.session_state.use_point_prompt:
213
- value_xy = streamlit_image_coordinates(image_z, width=325)
214
-
215
- if value_xy is not None:
216
- point_ax_xy = (selected_index_z, value_xy['y'], value_xy['x'])
217
- if len(st.session_state.points) >= 3:
218
- st.warning('Max point num is 3', icon="⚠️")
219
- elif point_ax_xy not in st.session_state.points:
220
- st.session_state.points.append(point_ax_xy)
221
- print('point_ax_xy add rerun')
222
- st.rerun()
223
- elif st.session_state.use_box_prompt:
224
- canvas_result_xy = st_canvas(
225
- fill_color="rgba(255, 165, 0, 0.3)", # Fixed fill color with some opacity
226
- stroke_width=3,
227
- stroke_color='#2909F1',
228
- background_image=image_z,
229
- update_streamlit=True,
230
- height=325,
231
- width=325,
232
- drawing_mode='transform',
233
- point_display_radius=0,
234
- key="canvas_xy",
235
- initial_drawing=initial_rectangle,
236
- display_toolbar=True
237
- )
238
- try:
239
- print(canvas_result_xy.json_data['objects'][0]['angle'])
240
- if canvas_result_xy.json_data['objects'][0]['angle'] != 0:
241
- st.warning('Rotating is undefined behavior', icon="⚠️")
242
- st.session_state.irregular_box = True
243
- else:
244
- st.session_state.irregular_box = False
245
- reflect_json_data_to_3D_box(canvas_result_xy.json_data, view='xy')
246
- except:
247
- print('exception')
248
- pass
249
- else:
250
- st.image(image_z, use_column_width=False)
251
-
252
- with col_image2:
253
- image_y_array = image_3D[:, selected_index_y, :]
254
-
255
- preds_y_array = None
256
- if st.session_state.preds_3D is not None:
257
- preds_y_array = st.session_state.preds_3D[:, selected_index_y, :]
258
-
259
- image_y = make_fig(image_y_array, preds_y_array, st.session_state.points, selected_index_y, 'xz')
260
-
261
- if st.session_state.use_point_prompt:
262
- value_yz = streamlit_image_coordinates(image_y, width=325)
263
-
264
- if value_yz is not None:
265
- point_ax_xz = (value_yz['y'], selected_index_y, value_yz['x'])
266
- if len(st.session_state.points) >= 3:
267
- st.warning('Max point num is 3', icon="⚠️")
268
- elif point_ax_xz not in st.session_state.points:
269
- st.session_state.points.append(point_ax_xz)
270
- print('point_ax_xz add rerun')
271
- st.rerun()
272
- elif st.session_state.use_box_prompt:
273
- if st.session_state.rectangle_3Dbox[1] <= selected_index_y and selected_index_y <= st.session_state.rectangle_3Dbox[4]:
274
- draw = ImageDraw.Draw(image_y)
275
- #rectangle xz view (upper-left and lower-right)
276
- rectangle_coords = [(st.session_state.rectangle_3Dbox[2], st.session_state.rectangle_3Dbox[0]),
277
- (st.session_state.rectangle_3Dbox[5], st.session_state.rectangle_3Dbox[3])]
278
- # Draw the rectangle on the image
279
- draw.rectangle(rectangle_coords, outline='#2909F1', width=3)
280
- st.image(image_y, use_column_width=False)
281
- else:
282
- st.image(image_y, use_column_width=False)
283
-
284
-
285
- col1, col2, col3 = st.columns(3)
286
-
287
- with col1:
288
- if st.button("Clear", use_container_width=True,
289
- disabled=(st.session_state.option is None or (len(st.session_state.points)==0 and not st.session_state.use_box_prompt and st.session_state.preds_3D is None))):
290
- clear_prompts()
291
- st.session_state.preds_3D = None
292
- st.session_state.preds_3D_ori = None
293
- st.rerun()
294
-
295
- with col2:
296
- img_nii = None
297
- if st.session_state.preds_3D_ori is not None and st.session_state.data_item is not None:
298
- meta_dict = st.session_state.data_item['meta']
299
- foreground_start_coord = st.session_state.data_item['foreground_start_coord']
300
- foreground_end_coord = st.session_state.data_item['foreground_end_coord']
301
- original_shape = st.session_state.data_item['ori_shape']
302
- pred_array = st.session_state.preds_3D_ori
303
- original_array = np.zeros(original_shape)
304
- original_array[foreground_start_coord[0]:foreground_end_coord[0],
305
- foreground_start_coord[1]:foreground_end_coord[1],
306
- foreground_start_coord[2]:foreground_end_coord[2]] = pred_array
307
-
308
- original_array = original_array.transpose(2, 1, 0)
309
- img_nii = nib.Nifti1Image(original_array, affine=meta_dict['affine'])
310
-
311
- with tempfile.NamedTemporaryFile(suffix=".nii.gz") as tmpfile:
312
- nib.save(img_nii, tmpfile.name)
313
- with open(tmpfile.name, "rb") as f:
314
- bytes_data = f.read()
315
- st.download_button(
316
- label="Download result(.nii.gz)",
317
- data=bytes_data,
318
- file_name="segvol_preds.nii.gz",
319
- mime="application/octet-stream",
320
- disabled=img_nii is None
321
- )
322
-
323
- with col3:
324
- run_button_name = 'Run'if not st.session_state.running else 'Running'
325
- if st.button(run_button_name, type="primary", use_container_width=True,
326
- disabled=(
327
- st.session_state.data_item is None or
328
- (st.session_state.text_prompt is None and len(st.session_state.points) == 0 and st.session_state.use_box_prompt is False) or
329
- st.session_state.irregular_box or
330
- st.session_state.running
331
- )):
332
- st.session_state.running = True
333
- st.rerun()
334
-
335
- if st.session_state.running:
336
- st.session_state.running = False
337
- with st.status("Running...", expanded=False) as status:
338
- run()
339
- st.rerun()
 
model/LICENSE DELETED
@@ -1,21 +0,0 @@
- MIT License
-
- Copyright (c) 2023 BAAI-DCAI
-
- Permission is hereby granted, free of charge, to any person obtaining a copy
- of this software and associated documentation files (the "Software"), to deal
- in the Software without restriction, including without limitation the rights
- to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
- copies of the Software, and to permit persons to whom the Software is
- furnished to do so, subject to the following conditions:
-
- The above copyright notice and this permission notice shall be included in all
- copies or substantial portions of the Software.
-
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- SOFTWARE.
 
model/README.md DELETED
@@ -1,74 +0,0 @@
1
- # SegVol: Universal and Interactive Volumetric Medical Image Segmentation
2
- This repo is the official implementation of [SegVol: Universal and Interactive Volumetric Medical Image Segmentation](https://arxiv.org/abs/2311.13385).
3
-
4
- ## News🚀
5
- (2023.11.24) *You can download weight files of SegVol and ViT(CTs pre-train) [here](https://drive.google.com/drive/folders/1TEJtgctH534Ko5r4i79usJvqmXVuLf54?usp=drive_link).* 🔥
6
-
7
- (2023.11.23) *The brief introduction and instruction have been uploaded.*
8
-
9
- (2023.11.23) *The inference demo code has been uploaded.*
10
-
11
- (2023.11.22) *The first edition of our paper has been uploaded to arXiv.* 📃
12
-
13
- ## Introduction
14
- <img src="https://github.com/BAAI-DCAI/SegVol/blob/main/asset/overview.png" width="60%" height="60%">
15
-
16
- The SegVol is a universal and interactive model for volumetric medical image segmentation. SegVol accepts **point**, **box** and **text** prompt while output volumetric segmentation. By training on 90k unlabeled Computed Tomography (CT) volumes and 6k labeled CTs, this foundation model supports the segmentation of over 200 anatomical categories.
17
-
18
- We will release SegVol's **inference code**, **training code**, **model params** and **ViT pre-training params** (pre-training is performed over 2,000 epochs on 96k CTs).
19
-
20
- ## Usage
21
- ### Requirements
22
- The [pytorch v1.11.0](https://pytorch.org/get-started/previous-versions/) (or higher virsion) is needed first. Following install key requirements using commands:
23
-
24
- ```
25
- pip install 'monai[all]==0.9.0'
26
- pip install einops==0.6.1
27
- pip install transformers==4.18.0
28
- pip install matplotlib
29
- ```
30
- ### Config and run demo script
31
- 1. You can download the demo case [here](https://drive.google.com/drive/folders/1TEJtgctH534Ko5r4i79usJvqmXVuLf54?usp=drive_link), or download the whole demo dataset [AbdomenCT-1K](https://github.com/JunMa11/AbdomenCT-1K) and choose any demo case you want.
32
- 2. Please set CT path and Ground Truth path of the case in the [config_demo.json](https://github.com/BAAI-DCAI/SegVol/blob/main/config/config_demo.json).
33
- 3. After that, config the [inference_demo.sh](https://github.com/BAAI-DCAI/SegVol/blob/main/script/inference_demo.sh) file for execution:
34
-
35
- - `$segvol_ckpt`: the path of SegVol's checkpoint (Download from [here](https://drive.google.com/drive/folders/1TEJtgctH534Ko5r4i79usJvqmXVuLf54?usp=drive_link)).
36
-
37
- - `$work_dir`: any path of folder you want to save the log files and visualizaion results.
38
-
39
- 4. Finally, you can control the **prompt type**, **zoom-in-zoom-out mechanism** and **visualizaion switch** [here](https://github.com/BAAI-DCAI/SegVol/blob/35f3ff9c943a74f630e6948051a1fe21aaba91bc/inference_demo.py#L208C11-L208C11).
40
- 5. Now, just run `bash script/inference_demo.sh` to infer your demo case.
41
-
42
- ## Citation
43
- If you find this repository helpful, please consider citing:
44
- ```
45
- @misc{du2023segvol,
46
- title={SegVol: Universal and Interactive Volumetric Medical Image Segmentation},
47
- author={Yuxin Du and Fan Bai and Tiejun Huang and Bo Zhao},
48
- year={2023},
49
- eprint={2311.13385},
50
- archivePrefix={arXiv},
51
- primaryClass={cs.CV}
52
- }
53
- ```
54
-
55
- ## Acknowledgement
56
- Thanks for the following amazing works:
57
-
58
- [HuggingFace](https://huggingface.co/).
59
-
60
- [CLIP](https://github.com/openai/CLIP).
61
-
62
- [MONAI](https://github.com/Project-MONAI/MONAI).
63
-
64
- [Image by brgfx](https://www.freepik.com/free-vector/anatomical-structure-human-bodies_26353260.htm) on Freepik.
65
-
66
- [Image by muammark](https://www.freepik.com/free-vector/people-icon-collection_1157380.htm#query=user&position=2&from_view=search&track=sph) on Freepik.
67
-
68
- [Image by pch.vector](https://www.freepik.com/free-vector/different-phone-hand-gestures-set_9649376.htm#query=Vector%20touch%20screen%20hand%20gestures&position=4&from_view=search&track=ais) on Freepik.
69
-
70
- [Image by starline](https://www.freepik.com/free-vector/set-three-light-bulb-represent-effective-business-idea-concept_37588597.htm#query=idea&position=0&from_view=search&track=sph) on Freepik.
71
-
72
-
73
-
74
-
 
model/__pycache__/inference_cpu.cpython-39.pyc DELETED
Binary file (4.77 kB)
 
model/asset/FLARE22_Tr_0002_0000.nii.gz DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:eb16eced003524fa005e28b2822c0b53503f1223d758cdf72528fad359aa10ba
- size 30611274
 
model/asset/FLARE22_Tr_0005_0000.nii.gz DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:2be5019bfc7e805d5e24785bcd44ffe7720e13e38b2a3124ad25b454811b221c
- size 26615527
 
model/asset/FLARE22_Tr_0034_0000.nii.gz DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:023c5d06ea2a6c8866c1e214ecee06a4447a8d0c50225142cdfdbbccc2bf8c66
- size 28821917
 
model/asset/FLARE22_Tr_0045_0000.nii.gz DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:336b3719af673fd6fafe89d7d5d95d5f18239a9faccde9753703fc1465f43736
- size 32885093
 
model/asset/model.png DELETED
Binary file (888 kB)
 
model/asset/overview back.png DELETED
Binary file (237 kB)
 
model/asset/overview.png DELETED
Binary file (226 kB)
 
model/config/clip/config.json DELETED
@@ -1,157 +0,0 @@
1
- {
2
- "_name_or_path": "openai/clip-vit-base-patch32",
3
- "architectures": [
4
- "CLIPModel"
5
- ],
6
- "initializer_factor": 1.0,
7
- "logit_scale_init_value": 2.6592,
8
- "model_type": "clip",
9
- "projection_dim": 512,
10
- "text_config": {
11
- "_name_or_path": "",
12
- "add_cross_attention": false,
13
- "architectures": null,
14
- "attention_dropout": 0.0,
15
- "bad_words_ids": null,
16
- "bos_token_id": 0,
17
- "chunk_size_feed_forward": 0,
18
- "cross_attention_hidden_size": null,
19
- "decoder_start_token_id": null,
20
- "diversity_penalty": 0.0,
21
- "do_sample": false,
22
- "dropout": 0.0,
23
- "early_stopping": false,
24
- "encoder_no_repeat_ngram_size": 0,
25
- "eos_token_id": 2,
26
- "finetuning_task": null,
27
- "forced_bos_token_id": null,
28
- "forced_eos_token_id": null,
29
- "hidden_act": "quick_gelu",
30
- "hidden_size": 512,
31
- "id2label": {
32
- "0": "LABEL_0",
33
- "1": "LABEL_1"
34
- },
35
- "initializer_factor": 1.0,
36
- "initializer_range": 0.02,
37
- "intermediate_size": 2048,
38
- "is_decoder": false,
39
- "is_encoder_decoder": false,
40
- "label2id": {
41
- "LABEL_0": 0,
42
- "LABEL_1": 1
43
- },
44
- "layer_norm_eps": 1e-05,
45
- "length_penalty": 1.0,
46
- "max_length": 20,
47
- "max_position_embeddings": 77,
48
- "min_length": 0,
49
- "model_type": "clip_text_model",
50
- "no_repeat_ngram_size": 0,
51
- "num_attention_heads": 8,
52
- "num_beam_groups": 1,
53
- "num_beams": 1,
54
- "num_hidden_layers": 12,
55
- "num_return_sequences": 1,
56
- "output_attentions": false,
57
- "output_hidden_states": false,
58
- "output_scores": false,
59
- "pad_token_id": 1,
60
- "prefix": null,
61
- "projection_dim": 512,
62
- "problem_type": null,
63
- "pruned_heads": {},
64
- "remove_invalid_values": false,
65
- "repetition_penalty": 1.0,
66
- "return_dict": true,
67
- "return_dict_in_generate": false,
68
- "sep_token_id": null,
69
- "task_specific_params": null,
70
- "temperature": 1.0,
71
- "tie_encoder_decoder": false,
72
- "tie_word_embeddings": true,
73
- "tokenizer_class": null,
74
- "top_k": 50,
75
- "top_p": 1.0,
76
- "torch_dtype": null,
77
- "torchscript": false,
78
- "transformers_version": "4.16.0.dev0",
79
- "use_bfloat16": false,
80
- "vocab_size": 49408
81
- },
82
- "text_config_dict": null,
83
- "transformers_version": null,
84
- "vision_config": {
85
- "_name_or_path": "",
86
- "add_cross_attention": false,
87
- "architectures": null,
88
- "attention_dropout": 0.0,
89
- "bad_words_ids": null,
90
- "bos_token_id": null,
91
- "chunk_size_feed_forward": 0,
92
- "cross_attention_hidden_size": null,
93
- "decoder_start_token_id": null,
94
- "diversity_penalty": 0.0,
95
- "do_sample": false,
96
- "dropout": 0.0,
97
- "early_stopping": false,
98
- "encoder_no_repeat_ngram_size": 0,
99
- "eos_token_id": null,
100
- "finetuning_task": null,
101
- "forced_bos_token_id": null,
102
- "forced_eos_token_id": null,
103
- "hidden_act": "quick_gelu",
104
- "hidden_size": 768,
105
- "id2label": {
106
- "0": "LABEL_0",
107
- "1": "LABEL_1"
108
- },
109
- "image_size": 224,
110
- "initializer_factor": 1.0,
111
- "initializer_range": 0.02,
112
- "intermediate_size": 3072,
113
- "is_decoder": false,
114
- "is_encoder_decoder": false,
115
- "label2id": {
116
- "LABEL_0": 0,
117
- "LABEL_1": 1
118
- },
119
- "layer_norm_eps": 1e-05,
120
- "length_penalty": 1.0,
121
- "max_length": 20,
122
- "min_length": 0,
123
- "model_type": "clip_vision_model",
124
- "no_repeat_ngram_size": 0,
125
- "num_attention_heads": 12,
126
- "num_beam_groups": 1,
127
- "num_beams": 1,
128
- "num_hidden_layers": 12,
129
- "num_return_sequences": 1,
130
- "output_attentions": false,
131
- "output_hidden_states": false,
132
- "output_scores": false,
133
- "pad_token_id": null,
134
- "patch_size": 32,
135
- "prefix": null,
136
- "projection_dim" : 512,
137
- "problem_type": null,
138
- "pruned_heads": {},
139
- "remove_invalid_values": false,
140
- "repetition_penalty": 1.0,
141
- "return_dict": true,
142
- "return_dict_in_generate": false,
143
- "sep_token_id": null,
144
- "task_specific_params": null,
145
- "temperature": 1.0,
146
- "tie_encoder_decoder": false,
147
- "tie_word_embeddings": true,
148
- "tokenizer_class": null,
149
- "top_k": 50,
150
- "top_p": 1.0,
151
- "torch_dtype": null,
152
- "torchscript": false,
153
- "transformers_version": "4.16.0.dev0",
154
- "use_bfloat16": false
155
- },
156
- "vision_config_dict": null
157
- }
 
model/config/clip/special_tokens_map.json DELETED
@@ -1 +0,0 @@
- {"bos_token": {"content": "<|startoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "pad_token": "<|endoftext|>"}
 
model/config/clip/tokenizer.json DELETED
The diff for this file is too large to render. See raw diff
 
model/config/clip/tokenizer_config.json DELETED
@@ -1 +0,0 @@
- {"unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "bos_token": {"content": "<|startoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "pad_token": "<|endoftext|>", "add_prefix_space": false, "errors": "replace", "do_lower_case": true, "name_or_path": "./clip_ViT_B_32/"}
 
model/config/clip/vocab.json DELETED
The diff for this file is too large to render. See raw diff
 
model/config/config_demo.json DELETED
@@ -1,8 +0,0 @@
- {
- "dataset_name": "AbdomenCT-1k",
- "categories": ["liver", "kidney", "spleen", "pancreas"],
- "demo_case": {
- "ct_path": "path/to/Case_image",
- "gt_path": "path/to/Case_label"
- }
- }
 
model/data_process/__pycache__/demo_data_process.cpython-39.pyc DELETED
Binary file (3.4 kB)
 
model/data_process/demo_data_process.py DELETED
@@ -1,95 +0,0 @@
1
- import numpy as np
2
- import monai.transforms as transforms
3
- import streamlit as st
4
- import tempfile
5
-
6
- class MinMaxNormalization(transforms.Transform):
7
- def __call__(self, data):
8
- d = dict(data)
9
- k = "image"
10
- d[k] = d[k] - d[k].min()
11
- d[k] = d[k] / np.clip(d[k].max(), a_min=1e-8, a_max=None)
12
- return d
13
-
14
- class DimTranspose(transforms.Transform):
15
- def __init__(self, keys):
16
- self.keys = keys
17
-
18
- def __call__(self, data):
19
- d = dict(data)
20
- for key in self.keys:
21
- d[key] = np.swapaxes(d[key], -1, -3)
22
- return d
23
-
24
- class ForegroundNormalization(transforms.Transform):
25
- def __init__(self, keys):
26
- self.keys = keys
27
-
28
- def __call__(self, data):
29
- d = dict(data)
30
-
31
- for key in self.keys:
32
- d[key] = self.normalize(d[key])
33
- return d
34
-
35
- def normalize(self, ct_narray):
36
- ct_voxel_ndarray = ct_narray.copy()
37
- ct_voxel_ndarray = ct_voxel_ndarray.flatten()
38
- thred = np.mean(ct_voxel_ndarray)
39
- voxel_filtered = ct_voxel_ndarray[(ct_voxel_ndarray > thred)]
40
- upper_bound = np.percentile(voxel_filtered, 99.95)
41
- lower_bound = np.percentile(voxel_filtered, 00.05)
42
- mean = np.mean(voxel_filtered)
43
- std = np.std(voxel_filtered)
44
- ### transform ###
45
- ct_narray = np.clip(ct_narray, lower_bound, upper_bound)
46
- ct_narray = (ct_narray - mean) / max(std, 1e-8)
47
- return ct_narray
48
-
49
- @st.cache_data
50
- def process_ct_gt(case_path, spatial_size=(32,256,256)):
51
- if case_path is None:
52
- return None
53
- print('Data preprocessing...')
54
- # transform
55
- img_loader = transforms.LoadImage(dtype=np.float32)
56
- transform = transforms.Compose(
57
- [
58
- transforms.Orientationd(keys=["image"], axcodes="RAS"),
59
- ForegroundNormalization(keys=["image"]),
60
- DimTranspose(keys=["image"]),
61
- MinMaxNormalization(),
62
- transforms.SpatialPadd(keys=["image"], spatial_size=spatial_size, mode='constant'),
63
- transforms.CropForegroundd(keys=["image"], source_key="image"),
64
- transforms.ToTensord(keys=["image"]),
65
- ]
66
- )
67
- zoom_out_transform = transforms.Resized(keys=["image"], spatial_size=spatial_size, mode='nearest-exact')
68
- z_transform = transforms.Resized(keys=["image"], spatial_size=(325,325,325), mode='nearest-exact')
69
- ###
70
- item = {}
71
- # generate ct_voxel_ndarray
72
- if type(case_path) is str:
73
- ct_voxel_ndarray, meta_tensor_dict = img_loader(case_path)
74
- else:
75
- bytes_data = case_path.read()
76
- with tempfile.NamedTemporaryFile(suffix='.nii.gz') as tmp:
77
- tmp.write(bytes_data)
78
- tmp.seek(0)
79
- ct_voxel_ndarray, meta_tensor_dict = img_loader(tmp.name)
80
-
81
- ct_voxel_ndarray = np.array(ct_voxel_ndarray).squeeze()
82
- ct_voxel_ndarray = np.expand_dims(ct_voxel_ndarray, axis=0)
83
- item['image'] = ct_voxel_ndarray
84
- ori_shape = np.swapaxes(ct_voxel_ndarray, -1, -3).shape[1:]
85
-
86
- # transform
87
- item = transform(item)
88
- item_zoom_out = zoom_out_transform(item)
89
- item['zoom_out_image'] = item_zoom_out['image']
90
- item['ori_shape'] = ori_shape
91
-
92
- item_z = z_transform(item)
93
- item['z_image'] = item_z['image']
94
- item['meta'] = meta_tensor_dict
95
- return item
 
model/inference_cpu.py DELETED
@@ -1,172 +0,0 @@
1
- import argparse
2
- import os
3
- import torch
4
- import torch.nn.functional as F
5
- import json
6
- import monai.transforms as transforms
7
-
8
- from model.segment_anything_volumetric import sam_model_registry
9
- from model.network.model import SegVol
10
- from model.data_process.demo_data_process import process_ct_gt
11
- from model.utils.monai_inferers_utils import sliding_window_inference, generate_box, select_points, build_binary_cube, build_binary_points, logits2roi_coor
12
- from model.utils.visualize import draw_result
13
- import streamlit as st
14
-
15
- def set_parse():
16
- # %% set up parser
17
- parser = argparse.ArgumentParser()
18
- parser.add_argument("--test_mode", default=True, type=bool)
19
- parser.add_argument("--resume", type = str, default = 'SegVol_v1.pth')
20
- parser.add_argument("-infer_overlap", default=0.0, type=float, help="sliding window inference overlap")
21
- parser.add_argument("-spatial_size", default=(32, 256, 256), type=tuple)
22
- parser.add_argument("-patch_size", default=(4, 16, 16), type=tuple)
23
- parser.add_argument('-work_dir', type=str, default='./work_dir')
24
- ### demo
25
- parser.add_argument("--clip_ckpt", type = str, default = 'model/config/clip')
26
- args = parser.parse_args()
27
- return args
28
-
29
- def zoom_in_zoom_out(args, segvol_model, image, image_resize, text_prompt, point_prompt, box_prompt):
30
- image_single_resize = image_resize
31
- image_single = image[0,0]
32
- ori_shape = image_single.shape
33
- resize_shape = image_single_resize.shape[2:]
34
-
35
- # generate prompts
36
- text_single = None if text_prompt is None else [text_prompt]
37
- points_single = None
38
- box_single = None
39
-
40
- if args.use_point_prompt:
41
- point, point_label = point_prompt
42
- points_single = (point.unsqueeze(0).float(), point_label.unsqueeze(0).float())
43
- binary_points_resize = build_binary_points(point, point_label, resize_shape)
44
- if args.use_box_prompt:
45
- box_single = box_prompt.unsqueeze(0).float()
46
- binary_cube_resize = build_binary_cube(box_single, binary_cube_shape=resize_shape)
47
-
48
- ####################
49
- # zoom-out inference:
50
- print('--- zoom out inference ---')
51
- print(text_single)
52
- print(f'use text-prompt [{text_single!=None}], use box-prompt [{box_single!=None}], use point-prompt [{points_single!=None}]')
53
- with torch.no_grad():
54
- logits_global_single = segvol_model(image_single_resize,
55
- text=text_single,
56
- boxes=box_single,
57
- points=points_single)
58
-
59
- # resize back global logits
60
- logits_global_single = F.interpolate(
61
- logits_global_single.cpu(),
62
- size=ori_shape, mode='nearest')[0][0]
63
-
64
- # build prompt reflection for zoom-in
65
- if args.use_point_prompt:
66
- binary_points = F.interpolate(
67
- binary_points_resize.unsqueeze(0).unsqueeze(0).float(),
68
- size=ori_shape, mode='nearest')[0][0]
69
- if args.use_box_prompt:
70
- binary_cube = F.interpolate(
71
- binary_cube_resize.unsqueeze(0).unsqueeze(0).float(),
72
- size=ori_shape, mode='nearest')[0][0]
73
- # draw_result('unknow', image_single_resize, None, point_prompt, logits_global_single, logits_global_single)
74
- if not args.use_zoom_in:
75
- return logits_global_single
76
-
77
- ####################
78
- # zoom-in inference:
79
- min_d, min_h, min_w, max_d, max_h, max_w = logits2roi_coor(args.spatial_size, logits_global_single)
80
- if min_d is None:
81
- print('Fail to detect foreground!')
82
- return logits_global_single
83
-
84
- # Crop roi
85
- image_single_cropped = image_single[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1].unsqueeze(0).unsqueeze(0)
86
- global_preds = (torch.sigmoid(logits_global_single[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1])>0.5).long()
87
-
88
- assert not (args.use_box_prompt and args.use_point_prompt)
89
- # label_single_cropped = label_single[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1].unsqueeze(0).unsqueeze(0)
90
- prompt_reflection = None
91
- if args.use_box_prompt:
92
- binary_cube_cropped = binary_cube[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1]
93
- prompt_reflection = (
94
- binary_cube_cropped.unsqueeze(0).unsqueeze(0),
95
- global_preds.unsqueeze(0).unsqueeze(0)
96
- )
97
- if args.use_point_prompt:
98
- binary_points_cropped = binary_points[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1]
99
- prompt_reflection = (
100
- binary_points_cropped.unsqueeze(0).unsqueeze(0),
101
- global_preds.unsqueeze(0).unsqueeze(0)
102
- )
103
-
104
- ## inference
105
- with torch.no_grad():
106
- logits_single_cropped = sliding_window_inference(
107
- image_single_cropped, prompt_reflection,
108
- args.spatial_size, 1, segvol_model, args.infer_overlap,
109
- text=text_single,
110
- use_box=args.use_box_prompt,
111
- use_point=args.use_point_prompt,
112
- logits_global_single=logits_global_single,
113
- )
114
- logits_single_cropped = logits_single_cropped.cpu().squeeze()
115
- if logits_single_cropped.shape != logits_global_single.shape:
116
- logits_global_single[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1] = logits_single_cropped
117
-
118
- return logits_global_single
119
-
120
- @st.cache_resource
121
- def build_model():
122
- # build model
123
- st.write('building model')
124
- clip_ckpt = 'model/config/clip'
125
- resume = 'SegVol_v1.pth'
126
- sam_model = sam_model_registry['vit']()
127
- segvol_model = SegVol(
128
- image_encoder=sam_model.image_encoder,
129
- mask_decoder=sam_model.mask_decoder,
130
- prompt_encoder=sam_model.prompt_encoder,
131
- clip_ckpt=clip_ckpt,
132
- roi_size=(32,256,256),
133
- patch_size=(4,16,16),
134
- test_mode=True,
135
- )
136
- segvol_model = torch.nn.DataParallel(segvol_model)
137
- segvol_model.eval()
138
- # load param
139
- if os.path.isfile(resume):
140
- ## Map model to be loaded to specified single GPU
141
- loc = 'cpu'
142
- checkpoint = torch.load(resume, map_location=loc)
143
- segvol_model.load_state_dict(checkpoint['model'], strict=True)
144
- print("loaded checkpoint '{}' (epoch {})".format(resume, checkpoint['epoch']))
145
- print('model build done!')
146
- return segvol_model
147
-
148
- @st.cache_data
149
- def inference_case(_image, _image_zoom_out, _point_prompt, text_prompt, _box_prompt):
150
- # seg config
151
- args = set_parse()
152
- args.use_zoom_in = True
153
- args.use_text_prompt = text_prompt is not None
154
- args.use_box_prompt = _box_prompt is not None
155
- args.use_point_prompt = _point_prompt is not None
156
-
157
- segvol_model = build_model()
158
-
159
- # run inference
160
- logits = zoom_in_zoom_out(
161
- args, segvol_model,
162
- _image.unsqueeze(0), _image_zoom_out.unsqueeze(0),
163
- text_prompt, _point_prompt, _box_prompt)
164
- print(logits.shape)
165
- resize_transform = transforms.Compose([
166
- transforms.AddChannel(),
167
- transforms.Resize((325,325,325), mode='trilinear')
168
- ]
169
- )
170
- logits_resize = resize_transform(logits)[0]
171
- return (torch.sigmoid(logits_resize) > 0.5).int().numpy(), (torch.sigmoid(logits) > 0.5).int().numpy()
172
-
 
model/inference_demo.py DELETED
@@ -1,219 +0,0 @@
1
- import argparse
2
- import os
3
- import torch
4
- import torch.nn.functional as F
5
- import json
6
- from segment_anything_volumetric import sam_model_registry
7
- from network.model import SegVol
8
- from data_process.demo_data_process import process_ct_gt
9
- import monai.transforms as transforms
10
- from utils.monai_inferers_utils import sliding_window_inference, generate_box, select_points, build_binary_cube, build_binary_points, logits2roi_coor
11
- from utils.visualize import draw_result
12
-
13
- def set_parse():
14
- # %% set up parser
15
- parser = argparse.ArgumentParser()
16
- parser.add_argument("--test_mode", default=True, type=bool)
17
- parser.add_argument("--resume", type = str, default = '')
18
- parser.add_argument("-infer_overlap", default=0.5, type=float, help="sliding window inference overlap")
19
- parser.add_argument("-spatial_size", default=(32, 256, 256), type=tuple)
20
- parser.add_argument("-patch_size", default=(4, 16, 16), type=tuple)
21
- parser.add_argument('-work_dir', type=str, default='./work_dir')
22
- ### demo
23
- parser.add_argument('--demo_config', type=str, required=True)
24
- parser.add_argument("--clip_ckpt", type = str, default = './config/clip')
25
- args = parser.parse_args()
26
- return args
27
-
28
- def dice_score(preds, labels): # on GPU
29
- assert preds.shape[0] == labels.shape[0], "predict & target batch size don't match\n" + str(preds.shape) + str(labels.shape)
30
- predict = preds.view(1, -1)
31
- target = labels.view(1, -1)
32
- if target.shape[1] < 1e8:
33
- predict = predict.cuda()
34
- target = target.cuda()
35
- predict = torch.sigmoid(predict)
36
- predict = torch.where(predict > 0.5, 1., 0.)
37
-
38
- tp = torch.sum(torch.mul(predict, target))
39
- den = torch.sum(predict) + torch.sum(target) + 1
40
- dice = 2 * tp / den
41
-
42
- if target.shape[1] < 1e8:
43
- predict = predict.cpu()
44
- target = target.cpu()
45
- return dice
46
-
47
- def zoom_in_zoom_out(args, segvol_model, image, image_resize, gt3D, gt3D_resize, categories=None):
48
- logits_labels_record = {}
49
- image_single_resize = image_resize
50
- image_single = image[0,0]
51
- ori_shape = image_single.shape
52
- for item_idx in range(len(categories)):
53
- # get label to generate prompts
54
- label_single = gt3D[0][item_idx]
55
- label_single_resize = gt3D_resize[0][item_idx]
56
- # skip meaningless categories
57
- if torch.sum(label_single) == 0:
58
- print('No object, skip')
59
- continue
60
- # generate prompts
61
- text_single = categories[item_idx] if args.use_text_prompt else None
62
- if categories is not None: print(f'inference |{categories[item_idx]}| target...')
63
- points_single = None
64
- box_single = None
65
- if args.use_point_prompt:
66
- point, point_label = select_points(label_single_resize, num_positive_extra=3, num_negative_extra=3)
67
- points_single = (point.unsqueeze(0).float().cuda(), point_label.unsqueeze(0).float().cuda())
68
- binary_points_resize = build_binary_points(point, point_label, label_single_resize.shape)
69
- if args.use_box_prompt:
70
- box_single = generate_box(label_single_resize).unsqueeze(0).float().cuda()
71
- binary_cube_resize = build_binary_cube(box_single, binary_cube_shape=label_single_resize.shape)
72
-
73
- ####################
74
- # zoom-out inference:
75
- print('--- zoom out inference ---')
76
- print(f'use text-prompt [{text_single!=None}], use box-prompt [{box_single!=None}], use point-prompt [{points_single!=None}]')
77
- with torch.no_grad():
78
- logits_global_single = segvol_model(image_single_resize.cuda(),
79
- text=text_single,
80
- boxes=box_single,
81
- points=points_single)
82
-
83
- # resize back global logits
84
- logits_global_single = F.interpolate(
85
- logits_global_single.cpu(),
86
- size=ori_shape, mode='nearest')[0][0]
87
-
88
- # build prompt reflection for zoom-in
89
- if args.use_point_prompt:
90
- binary_points = F.interpolate(
91
- binary_points_resize.unsqueeze(0).unsqueeze(0).float(),
92
- size=ori_shape, mode='nearest')[0][0]
93
- if args.use_box_prompt:
94
- binary_cube = F.interpolate(
95
- binary_cube_resize.unsqueeze(0).unsqueeze(0).float(),
96
- size=ori_shape, mode='nearest')[0][0]
97
- zoom_out_dice = dice_score(logits_global_single.squeeze(), label_single.squeeze())
98
- logits_labels_record[categories[item_idx]] = (
99
- zoom_out_dice,
100
- image_single,
101
- points_single,
102
- box_single,
103
- logits_global_single,
104
- label_single)
105
- print(f'zoom out inference done with zoom_out_dice: {zoom_out_dice:.4f}')
106
- if not args.use_zoom_in:
107
- continue
108
-
109
- ####################
110
- # zoom-in inference:
111
- min_d, min_h, min_w, max_d, max_h, max_w = logits2roi_coor(args.spatial_size, logits_global_single)
112
- if min_d is None:
113
- print('Fail to detect foreground!')
114
- continue
115
-
116
- # Crop roi
117
- image_single_cropped = image_single[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1].unsqueeze(0).unsqueeze(0)
118
- global_preds = (torch.sigmoid(logits_global_single[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1])>0.5).long()
119
-
120
- assert not (args.use_box_prompt and args.use_point_prompt)
121
- # label_single_cropped = label_single[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1].unsqueeze(0).unsqueeze(0)
122
- prompt_reflection = None
123
- if args.use_box_prompt:
124
- binary_cube_cropped = binary_cube[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1]
125
- prompt_reflection = (
126
- binary_cube_cropped.unsqueeze(0).unsqueeze(0),
127
- global_preds.unsqueeze(0).unsqueeze(0)
128
- )
129
- if args.use_point_prompt:
130
- binary_points_cropped = binary_points[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1]
131
- prompt_reflection = (
132
- binary_points_cropped.unsqueeze(0).unsqueeze(0),
133
- global_preds.unsqueeze(0).unsqueeze(0)
134
- )
135
-
136
- ## inference
137
- with torch.no_grad():
138
- logits_single_cropped = sliding_window_inference(
139
- image_single_cropped.cuda(), prompt_reflection,
140
- args.spatial_size, 1, segvol_model, args.infer_overlap,
141
- text=text_single,
142
- use_box=args.use_box_prompt,
143
- use_point=args.use_point_prompt,
144
- )
145
- logits_single_cropped = logits_single_cropped.cpu().squeeze()
146
- logits_global_single[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1] = logits_single_cropped
147
- zoom_in_dice = dice_score(logits_global_single.squeeze(), label_single.squeeze())
148
- logits_labels_record[categories[item_idx]] = (
149
- zoom_in_dice,
150
- image_single,
151
- points_single,
152
- box_single,
153
- logits_global_single,
154
- label_single)
155
- print(f'===> zoom out dice {zoom_out_dice:.4f} -> zoom-out-zoom-in dice {zoom_in_dice:.4f} <===')
156
- return logits_labels_record
157
-
158
- def inference_single_ct(args, segvol_model, data_item, categories):
159
- segvol_model.eval()
160
- image, gt3D = data_item["image"].float(), data_item["label"]
161
- image_zoom_out, gt3D__zoom_out = data_item["zoom_out_image"].float(), data_item['zoom_out_label']
162
-
163
- logits_labels_record = zoom_in_zoom_out(
164
- args, segvol_model,
165
- image.unsqueeze(0), image_zoom_out.unsqueeze(0),
166
- gt3D.unsqueeze(0), gt3D__zoom_out.unsqueeze(0), # add batch dim
167
- categories=categories)
168
-
169
- # visualize
170
- if args.visualize:
171
- for target, values in logits_labels_record.items():
172
- dice_score, image, point_prompt, box_prompt, logits, labels = values
173
- print(f'{target} result with Dice score {dice_score:.4f} visualizing')
174
- draw_result(target + f"-Dice {dice_score:.4f}", image, box_prompt, point_prompt, logits, labels, args.spatial_size, args.work_dir)
175
-
176
- def main(args):
177
- gpu = 0
178
- torch.cuda.set_device(gpu)
179
- # build model
180
- sam_model = sam_model_registry['vit'](args=args)
181
- segvol_model = SegVol(
182
- image_encoder=sam_model.image_encoder,
183
- mask_decoder=sam_model.mask_decoder,
184
- prompt_encoder=sam_model.prompt_encoder,
185
- clip_ckpt=args.clip_ckpt,
186
- roi_size=args.spatial_size,
187
- patch_size=args.patch_size,
188
- test_mode=args.test_mode,
189
- ).cuda()
190
- segvol_model = torch.nn.DataParallel(segvol_model, device_ids=[gpu])
191
-
192
- # load param
193
- if os.path.isfile(args.resume):
194
- ## Map model to be loaded to specified single GPU
195
- loc = 'cuda:{}'.format(gpu)
196
- checkpoint = torch.load(args.resume, map_location=loc)
197
- segvol_model.load_state_dict(checkpoint['model'], strict=True)
198
- print("loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
199
-
200
- # load demo config
201
- with open(args.demo_config, 'r') as file:
202
- config_dict = json.load(file)
203
- ct_path, gt_path, categories = config_dict['demo_case']['ct_path'], config_dict['demo_case']['gt_path'], config_dict['categories']
204
-
205
- # preprocess for data
206
- data_item = process_ct_gt(ct_path, gt_path, categories, args.spatial_size) # keys: image, label
207
-
208
- # seg config for prompt & zoom-in-zoom-out
209
- args.use_zoom_in = True
210
- args.use_text_prompt = True
211
- args.use_box_prompt = True
212
- args.use_point_prompt = False
213
- args.visualize = False
214
-
215
- inference_single_ct(args, segvol_model, data_item, categories)
216
-
217
- if __name__ == "__main__":
218
- args = set_parse()
219
- main(args)
 
model/network/__pycache__/model.cpython-39.pyc DELETED
Binary file (3.29 kB)
 
model/network/model.py DELETED
@@ -1,91 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- import numpy as np
5
- from transformers import AutoTokenizer, CLIPTextModel, CLIPTextConfig
6
-
7
- #%% set up model
8
- class SegVol(nn.Module):
9
- def __init__(self,
10
- image_encoder,
11
- mask_decoder,
12
- prompt_encoder,
13
- clip_ckpt,
14
- roi_size,
15
- patch_size,
16
- test_mode=False,
17
- ):
18
- super().__init__()
19
- self.image_encoder = image_encoder
20
- self.mask_decoder = mask_decoder
21
- self.prompt_encoder = prompt_encoder
22
- self.text_encoder = TextEncoder(clip_ckpt)
23
- self.feat_shape = np.array(roi_size)/np.array(patch_size)
24
- self.test_mode = test_mode
25
-
26
- def forward(self, image, text=None, boxes=None, points=None, **kwargs):
27
- bs = image.shape[0]
28
- img_shape = (image.shape[2], image.shape[3], image.shape[4])
29
- image_embedding, _ = self.image_encoder(image)
30
- image_embedding = image_embedding.transpose(1, 2).view(bs, -1,
31
- int(self.feat_shape[0]), int(self.feat_shape[1]), int(self.feat_shape[2]))
32
- # test mode
33
- if self.test_mode:
34
- return self.forward_decoder(image_embedding, img_shape, text, boxes, points)
35
- # train mode
36
- # future release
37
-
38
- def forward_decoder(self, image_embedding, img_shape, text=None, boxes=None, points=None):
39
- with torch.no_grad():
40
- if boxes is not None:
41
- if len(boxes.shape) == 2:
42
- boxes = boxes[:, None, :] # (B, 1, 6)
43
- if text is not None:
44
- text_embedding = self.text_encoder(text) # (B, 768)
45
- else:
46
- text_embedding = None
47
- sparse_embeddings, dense_embeddings = self.prompt_encoder(
48
- points=points,
49
- boxes=boxes,
50
- masks=None,
51
- text_embedding=text_embedding,
52
- )
53
-
54
- dense_pe = self.prompt_encoder.get_dense_pe()
55
- low_res_masks, _ = self.mask_decoder(
56
- image_embeddings=image_embedding,
57
- text_embedding = text_embedding,
58
- image_pe=dense_pe,
59
- sparse_prompt_embeddings=sparse_embeddings,
60
- dense_prompt_embeddings=dense_embeddings,
61
- multimask_output=False,
62
- )
63
- logits = F.interpolate(low_res_masks, size=img_shape, mode='trilinear', align_corners=False)
64
- return logits
65
-
66
- class TextEncoder(nn.Module):
67
- def __init__(self, clip_ckpt):
68
- super().__init__()
69
- config = CLIPTextConfig()
70
- self.clip_text_model = CLIPTextModel(config)
71
- self.tokenizer = AutoTokenizer.from_pretrained(clip_ckpt)
72
- self.dim_align = nn.Linear(512, 768)
73
- # freeze text encoder
74
- for param in self.clip_text_model.parameters():
75
- param.requires_grad = False
76
-
77
- def organ2tokens(self, organ_names):
78
- text_list = ['A computerized tomography of a {}.'.format(organ_name) for organ_name in organ_names]
79
- tokens = self.tokenizer(text_list, padding=True, return_tensors="pt")
80
- return tokens
81
-
82
- def forward(self, text):
83
- if text is None:
84
- return None
85
- if type(text) is str:
86
- text = [text]
87
- tokens = self.organ2tokens(text)
88
- clip_outputs = self.clip_text_model(**tokens)
89
- text_embedding = clip_outputs.pooler_output
90
- text_embedding = self.dim_align(text_embedding)
91
- return text_embedding
 
model/script/inference_demo.sh DELETED
@@ -1,8 +0,0 @@
- export segvol_ckpt="path/to/SegVol_v1.pth"
- export work_dir="path/to/work_dir"
- export demo_config_path="./config/config_demo.json"
-
- CUDA_VISIBLE_DEVICES=0 python inference_demo.py \
- --resume $segvol_ckpt \
- -work_dir $work_dir \
- --demo_config $demo_config_path
 
model/segment_anything_volumetric/.ipynb_checkpoints/build_sam-checkpoint.py DELETED
@@ -1,172 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
-
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
- from functools import partial
7
- from pathlib import Path
8
- import urllib.request
9
- import torch
10
-
11
- from .modeling import (
12
- ImageEncoderViT,
13
- MaskDecoder,
14
- PromptEncoder,
15
- Sam,
16
- TwoWayTransformer,
17
- )
18
-
19
- from .modeling.image_encoder_swin import SwinTransformer
20
-
21
- from monai.utils import ensure_tuple_rep, optional_import
22
-
23
- def build_sam_vit_h(checkpoint=None, image_size=1024):
24
- return _build_sam(
25
- encoder_embed_dim=1280,
26
- encoder_depth=32,
27
- encoder_num_heads=16,
28
- encoder_global_attn_indexes=[7, 15, 23, 31],
29
- checkpoint=checkpoint,
30
- image_size=image_size,
31
- )
32
-
33
-
34
- build_sam = build_sam_vit_h
35
-
36
-
37
- def build_sam_vit_l(checkpoint=None, image_size=1024):
38
- return _build_sam(
39
- encoder_embed_dim=1024,
40
- encoder_depth=24,
41
- encoder_num_heads=16,
42
- encoder_global_attn_indexes=[5, 11, 17, 23],
43
- checkpoint=checkpoint,
44
- image_size=image_size,
45
- )
46
-
47
-
48
- def build_sam_vit_b(checkpoint=None, image_size=1024):
49
- return _build_sam(
50
- encoder_embed_dim=768,
51
- encoder_depth=12,
52
- encoder_num_heads=12,
53
- encoder_global_attn_indexes=[2, 5, 8, 11],
54
- checkpoint=checkpoint,
55
- image_size=image_size,
56
- )
57
- """
58
- Examples::
59
- # for 3D single channel input with size (96,96,96), 4-channel output and feature size of 48.
60
- >>> net = SwinUNETR(img_size=(96,96,96), in_channels=1, out_channels=4, feature_size=48)
61
- # for 3D 4-channel input with size (128,128,128), 3-channel output and (2,4,2,2) layers in each stage.
62
- >>> net = SwinUNETR(img_size=(128,128,128), in_channels=4, out_channels=3, depths=(2,4,2,2))
63
- # for 2D single channel input with size (96,96), 2-channel output and gradient checkpointing.
64
- >>> net = SwinUNETR(img_size=(96,96), in_channels=3, out_channels=2, use_checkpoint=True, spatial_dims=2)
65
- """
66
-
67
- def build_sam_vit_swin(checkpoint=None, image_size=96):
68
- print('==> build_sam_vit_swin')
69
- return _build_sam(
70
- encoder_embed_dim=48,
71
- encoder_depth=12,
72
- encoder_num_heads=12,
73
- encoder_global_attn_indexes=[2, 5, 8, 11],
74
- checkpoint=checkpoint,
75
- image_size=image_size,
76
- )
77
-
78
- sam_model_registry = {
79
- "default": build_sam_vit_h,
80
- "vit_h": build_sam_vit_h,
81
- "vit_l": build_sam_vit_l,
82
- "vit_b": build_sam_vit_b,
83
- "swin_vit": build_sam_vit_swin,
84
- }
85
-
86
-
87
- def _build_sam(
88
- encoder_embed_dim,
89
- encoder_depth,
90
- encoder_num_heads,
91
- encoder_global_attn_indexes,
92
- checkpoint=None,
93
- image_size=None,
94
- spatial_dims=3,
95
- ):
96
- prompt_embed_dim = 768
97
- patch_size = ensure_tuple_rep(2, spatial_dims)
98
- window_size = ensure_tuple_rep(7, spatial_dims)
99
- image_embedding_size = [size // 32 for size in image_size]
100
- sam = Sam(
101
- image_encoder=SwinTransformer(
102
- in_chans=1,
103
- embed_dim=encoder_embed_dim,
104
- window_size=window_size,
105
- patch_size=patch_size,
106
- depths=(2, 2, 6, 2), #(2, 2, 6, 2),
107
- num_heads=(3, 6, 12, 24),
108
- mlp_ratio=4.0,
109
- qkv_bias=True,
110
- spatial_dims=spatial_dims,
111
- ),
112
- prompt_encoder=PromptEncoder(
113
- embed_dim=prompt_embed_dim,
114
- image_embedding_size=image_embedding_size,
115
- input_image_size=image_size,
116
- mask_in_chans=16,
117
- ),
118
- mask_decoder=MaskDecoder(
119
- num_multimask_outputs=3,
120
- transformer=TwoWayTransformer(
121
- depth=2,
122
- embedding_dim=prompt_embed_dim,
123
- mlp_dim=2048,
124
- num_heads=8,
125
- ),
126
- transformer_dim=prompt_embed_dim,
127
- iou_head_depth=3,
128
- iou_head_hidden_dim=256,
129
- ),
130
- pixel_mean=[123.675, 116.28, 103.53],
131
- pixel_std=[58.395, 57.12, 57.375],
132
- )
133
- sam.eval()
134
- if checkpoint is not None:
135
- checkpoint = Path(checkpoint)
136
- if checkpoint.name == "sam_vit_b_01ec64.pth" and not checkpoint.exists():
137
- cmd = input("Download sam_vit_b_01ec64.pth from facebook AI? [y]/n: ")
138
- if len(cmd) == 0 or cmd.lower() == 'y':
139
- checkpoint.parent.mkdir(parents=True, exist_ok=True)
140
- print("Downloading SAM ViT-B checkpoint...")
141
- urllib.request.urlretrieve(
142
- "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
143
- checkpoint,
144
- )
145
- print(checkpoint.name, " is downloaded!")
146
- elif checkpoint.name == "sam_vit_h_4b8939.pth" and not checkpoint.exists():
147
- cmd = input("Download sam_vit_h_4b8939.pth from facebook AI? [y]/n: ")
148
- if len(cmd) == 0 or cmd.lower() == 'y':
149
- checkpoint.parent.mkdir(parents=True, exist_ok=True)
150
- print("Downloading SAM ViT-H checkpoint...")
151
- urllib.request.urlretrieve(
152
- "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
153
- checkpoint,
154
- )
155
- print(checkpoint.name, " is downloaded!")
156
- elif checkpoint.name == "sam_vit_l_0b3195.pth" and not checkpoint.exists():
157
- cmd = input("Download sam_vit_l_0b3195.pth from facebook AI? [y]/n: ")
158
- if len(cmd) == 0 or cmd.lower() == 'y':
159
- checkpoint.parent.mkdir(parents=True, exist_ok=True)
160
- print("Downloading SAM ViT-L checkpoint...")
161
- urllib.request.urlretrieve(
162
- "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
163
- checkpoint,
164
- )
165
- print(checkpoint.name, " is downloaded!")
166
-
167
-
168
- if checkpoint is not None:
169
- with open(checkpoint, "rb") as f:
170
- state_dict = torch.load(f)
171
- sam.load_state_dict(state_dict)
172
- return sam
 
model/segment_anything_volumetric/__init__.py DELETED
@@ -1,12 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
-
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
-
- from .build_sam import (
-     build_sam_vit_3d,
-     sam_model_registry,
- )
- from .predictor import SamPredictor
- from .automatic_mask_generator import SamAutomaticMaskGenerator
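For reference, the public surface of the removed package is exactly what this `__init__.py` re-exports: the model registry plus the two prompting wrappers. A hedged usage sketch, assuming the package is importable as `model.segment_anything_volumetric` and that MONAI and torch are installed:

```python
# Minimal sketch; the import path is an assumption based on the repo layout shown in this diff.
from model.segment_anything_volumetric import (
    SamAutomaticMaskGenerator,
    SamPredictor,
    sam_model_registry,
)

sam = sam_model_registry["vit"](checkpoint=None)  # 3D ViT variant defined in build_sam.py below
predictor = SamPredictor(sam)                     # prompt-driven prediction on one preprocessed volume
auto_masker = SamAutomaticMaskGenerator(sam)      # grid-prompted mask proposals over the whole input
```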
 
model/segment_anything_volumetric/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (407 Bytes)
 
model/segment_anything_volumetric/__pycache__/__init__.cpython-39.pyc DELETED
Binary file (386 Bytes)
 
model/segment_anything_volumetric/__pycache__/automatic_mask_generator.cpython-310.pyc DELETED
Binary file (11.4 kB)
 
model/segment_anything_volumetric/__pycache__/automatic_mask_generator.cpython-39.pyc DELETED
Binary file (11.4 kB)
 
model/segment_anything_volumetric/__pycache__/build_sam.cpython-310.pyc DELETED
Binary file (3.3 kB)
 
model/segment_anything_volumetric/__pycache__/build_sam.cpython-39.pyc DELETED
Binary file (2.63 kB)
 
model/segment_anything_volumetric/__pycache__/predictor.cpython-310.pyc DELETED
Binary file (9.96 kB)
 
model/segment_anything_volumetric/__pycache__/predictor.cpython-39.pyc DELETED
Binary file (9.99 kB)
 
model/segment_anything_volumetric/automatic_mask_generator.py DELETED
@@ -1,372 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
-
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- import numpy as np
8
- import torch
9
- from torchvision.ops.boxes import batched_nms, box_area # type: ignore
10
-
11
- from typing import Any, Dict, List, Optional, Tuple
12
-
13
- from .modeling import Sam
14
- from .predictor import SamPredictor
15
- from .utils.amg import (
16
- MaskData,
17
- area_from_rle,
18
- batch_iterator,
19
- batched_mask_to_box,
20
- box_xyxy_to_xywh,
21
- build_all_layer_point_grids,
22
- calculate_stability_score,
23
- coco_encode_rle,
24
- generate_crop_boxes,
25
- is_box_near_crop_edge,
26
- mask_to_rle_pytorch,
27
- remove_small_regions,
28
- rle_to_mask,
29
- uncrop_boxes_xyxy,
30
- uncrop_masks,
31
- uncrop_points,
32
- )
33
-
34
-
35
- class SamAutomaticMaskGenerator:
36
- def __init__(
37
- self,
38
- model: Sam,
39
- points_per_side: Optional[int] = 32,
40
- points_per_batch: int = 64,
41
- pred_iou_thresh: float = 0.88,
42
- stability_score_thresh: float = 0.95,
43
- stability_score_offset: float = 1.0,
44
- box_nms_thresh: float = 0.7,
45
- crop_n_layers: int = 0,
46
- crop_nms_thresh: float = 0.7,
47
- crop_overlap_ratio: float = 512 / 1500,
48
- crop_n_points_downscale_factor: int = 1,
49
- point_grids: Optional[List[np.ndarray]] = None,
50
- min_mask_region_area: int = 0,
51
- output_mode: str = "binary_mask",
52
- ) -> None:
53
- """
54
- Using a SAM model, generates masks for the entire image.
55
- Generates a grid of point prompts over the image, then filters
56
- low quality and duplicate masks. The default settings are chosen
57
- for SAM with a ViT-H backbone.
58
-
59
- Arguments:
60
- model (Sam): The SAM model to use for mask prediction.
61
- points_per_side (int or None): The number of points to be sampled
62
- along one side of the image. The total number of points is
63
- points_per_side**2. If None, 'point_grids' must provide explicit
64
- point sampling.
65
- points_per_batch (int): Sets the number of points run simultaneously
66
- by the model. Higher numbers may be faster but use more GPU memory.
67
- pred_iou_thresh (float): A filtering threshold in [0,1], using the
68
- model's predicted mask quality.
69
- stability_score_thresh (float): A filtering threshold in [0,1], using
70
- the stability of the mask under changes to the cutoff used to binarize
71
- the model's mask predictions.
72
- stability_score_offset (float): The amount to shift the cutoff when
73
- calculating the stability score.
74
- box_nms_thresh (float): The box IoU cutoff used by non-maximal
75
- suppression to filter duplicate masks.
76
- crop_n_layers (int): If >0, mask prediction will be run again on
77
- crops of the image. Sets the number of layers to run, where each
78
- layer has 2**i_layer number of image crops.
79
- crop_nms_thresh (float): The box IoU cutoff used by non-maximal
80
- suppression to filter duplicate masks between different crops.
81
- crop_overlap_ratio (float): Sets the degree to which crops overlap.
82
- In the first crop layer, crops will overlap by this fraction of
83
- the image length. Later layers with more crops scale down this overlap.
84
- crop_n_points_downscale_factor (int): The number of points-per-side
85
- sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
86
- point_grids (list(np.ndarray) or None): A list over explicit grids
87
- of points used for sampling, normalized to [0,1]. The nth grid in the
88
- list is used in the nth crop layer. Exclusive with points_per_side.
89
- min_mask_region_area (int): If >0, postprocessing will be applied
90
- to remove disconnected regions and holes in masks with area smaller
91
- than min_mask_region_area. Requires opencv.
92
- output_mode (str): The form masks are returned in. Can be 'binary_mask',
93
- 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
94
- For large resolutions, 'binary_mask' may consume large amounts of
95
- memory.
96
- """
97
-
98
- assert (points_per_side is None) != (
99
- point_grids is None
100
- ), "Exactly one of points_per_side or point_grid must be provided."
101
- if points_per_side is not None:
102
- self.point_grids = build_all_layer_point_grids(
103
- points_per_side,
104
- crop_n_layers,
105
- crop_n_points_downscale_factor,
106
- )
107
- elif point_grids is not None:
108
- self.point_grids = point_grids
109
- else:
110
- raise ValueError("Can't have both points_per_side and point_grid be None.")
111
-
112
- assert output_mode in [
113
- "binary_mask",
114
- "uncompressed_rle",
115
- "coco_rle",
116
- ], f"Unknown output_mode {output_mode}."
117
- if output_mode == "coco_rle":
118
- from pycocotools import mask as mask_utils # type: ignore # noqa: F401
119
-
120
- if min_mask_region_area > 0:
121
- import cv2 # type: ignore # noqa: F401
122
-
123
- self.predictor = SamPredictor(model)
124
- self.points_per_batch = points_per_batch
125
- self.pred_iou_thresh = pred_iou_thresh
126
- self.stability_score_thresh = stability_score_thresh
127
- self.stability_score_offset = stability_score_offset
128
- self.box_nms_thresh = box_nms_thresh
129
- self.crop_n_layers = crop_n_layers
130
- self.crop_nms_thresh = crop_nms_thresh
131
- self.crop_overlap_ratio = crop_overlap_ratio
132
- self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
133
- self.min_mask_region_area = min_mask_region_area
134
- self.output_mode = output_mode
135
-
136
- @torch.no_grad()
137
- def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
138
- """
139
- Generates masks for the given image.
140
-
141
- Arguments:
142
- image (np.ndarray): The image to generate masks for, in HWC uint8 format.
143
-
144
- Returns:
145
- list(dict(str, any)): A list over records for masks. Each record is
146
- a dict containing the following keys:
147
- segmentation (dict(str, any) or np.ndarray): The mask. If
148
- output_mode='binary_mask', is an array of shape HW. Otherwise,
149
- is a dictionary containing the RLE.
150
- bbox (list(float)): The box around the mask, in XYWH format.
151
- area (int): The area in pixels of the mask.
152
- predicted_iou (float): The model's own prediction of the mask's
153
- quality. This is filtered by the pred_iou_thresh parameter.
154
- point_coords (list(list(float))): The point coordinates input
155
- to the model to generate this mask.
156
- stability_score (float): A measure of the mask's quality. This
157
- is filtered on using the stability_score_thresh parameter.
158
- crop_box (list(float)): The crop of the image used to generate
159
- the mask, given in XYWH format.
160
- """
161
-
162
- # Generate masks
163
- mask_data = self._generate_masks(image)
164
-
165
- # Filter small disconnected regions and holes in masks
166
- if self.min_mask_region_area > 0:
167
- mask_data = self.postprocess_small_regions(
168
- mask_data,
169
- self.min_mask_region_area,
170
- max(self.box_nms_thresh, self.crop_nms_thresh),
171
- )
172
-
173
- # Encode masks
174
- if self.output_mode == "coco_rle":
175
- mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]]
176
- elif self.output_mode == "binary_mask":
177
- mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
178
- else:
179
- mask_data["segmentations"] = mask_data["rles"]
180
-
181
- # Write mask records
182
- curr_anns = []
183
- for idx in range(len(mask_data["segmentations"])):
184
- ann = {
185
- "segmentation": mask_data["segmentations"][idx],
186
- "area": area_from_rle(mask_data["rles"][idx]),
187
- "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
188
- "predicted_iou": mask_data["iou_preds"][idx].item(),
189
- "point_coords": [mask_data["points"][idx].tolist()],
190
- "stability_score": mask_data["stability_score"][idx].item(),
191
- "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
192
- }
193
- curr_anns.append(ann)
194
-
195
- return curr_anns
196
-
197
- def _generate_masks(self, image: np.ndarray) -> MaskData:
198
- orig_size = image.shape[:2]
199
- crop_boxes, layer_idxs = generate_crop_boxes(
200
- orig_size, self.crop_n_layers, self.crop_overlap_ratio
201
- )
202
-
203
- # Iterate over image crops
204
- data = MaskData()
205
- for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
206
- crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
207
- data.cat(crop_data)
208
-
209
- # Remove duplicate masks between crops
210
- if len(crop_boxes) > 1:
211
- # Prefer masks from smaller crops
212
- scores = 1 / box_area(data["crop_boxes"])
213
- scores = scores.to(data["boxes"].device)
214
- keep_by_nms = batched_nms(
215
- data["boxes"].float(),
216
- scores,
217
- torch.zeros_like(data["boxes"][:, 0]), # categories
218
- iou_threshold=self.crop_nms_thresh,
219
- )
220
- data.filter(keep_by_nms)
221
-
222
- data.to_numpy()
223
- return data
224
-
225
- def _process_crop(
226
- self,
227
- image: np.ndarray,
228
- crop_box: List[int],
229
- crop_layer_idx: int,
230
- orig_size: Tuple[int, ...],
231
- ) -> MaskData:
232
- # Crop the image and calculate embeddings
233
- x0, y0, x1, y1 = crop_box
234
- cropped_im = image[y0:y1, x0:x1, :]
235
- cropped_im_size = cropped_im.shape[:2]
236
- self.predictor.set_image(cropped_im)
237
-
238
- # Get points for this crop
239
- points_scale = np.array(cropped_im_size)[None, ::-1]
240
- points_for_image = self.point_grids[crop_layer_idx] * points_scale
241
-
242
- # Generate masks for this crop in batches
243
- data = MaskData()
244
- for (points,) in batch_iterator(self.points_per_batch, points_for_image):
245
- batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size)
246
- data.cat(batch_data)
247
- del batch_data
248
- self.predictor.reset_image()
249
-
250
- # Remove duplicates within this crop.
251
- keep_by_nms = batched_nms(
252
- data["boxes"].float(),
253
- data["iou_preds"],
254
- torch.zeros_like(data["boxes"][:, 0]), # categories
255
- iou_threshold=self.box_nms_thresh,
256
- )
257
- data.filter(keep_by_nms)
258
-
259
- # Return to the original image frame
260
- data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
261
- data["points"] = uncrop_points(data["points"], crop_box)
262
- data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
263
-
264
- return data
265
-
266
- def _process_batch(
267
- self,
268
- points: np.ndarray,
269
- im_size: Tuple[int, ...],
270
- crop_box: List[int],
271
- orig_size: Tuple[int, ...],
272
- ) -> MaskData:
273
- orig_h, orig_w = orig_size
274
-
275
- # Run model on this batch
276
- transformed_points = self.predictor.transform.apply_coords(points, im_size)
277
- in_points = torch.as_tensor(transformed_points, device=self.predictor.device)
278
- in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
279
- masks, iou_preds, _ = self.predictor.predict_torch(
280
- in_points[:, None, :],
281
- in_labels[:, None],
282
- multimask_output=True,
283
- return_logits=True,
284
- )
285
-
286
- # Serialize predictions and store in MaskData
287
- data = MaskData(
288
- masks=masks.flatten(0, 1),
289
- iou_preds=iou_preds.flatten(0, 1),
290
- points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)),
291
- )
292
- del masks
293
-
294
- # Filter by predicted IoU
295
- if self.pred_iou_thresh > 0.0:
296
- keep_mask = data["iou_preds"] > self.pred_iou_thresh
297
- data.filter(keep_mask)
298
-
299
- # Calculate stability score
300
- data["stability_score"] = calculate_stability_score(
301
- data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset
302
- )
303
- if self.stability_score_thresh > 0.0:
304
- keep_mask = data["stability_score"] >= self.stability_score_thresh
305
- data.filter(keep_mask)
306
-
307
- # Threshold masks and calculate boxes
308
- data["masks"] = data["masks"] > self.predictor.model.mask_threshold
309
- data["boxes"] = batched_mask_to_box(data["masks"])
310
-
311
- # Filter boxes that touch crop boundaries
312
- keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
313
- if not torch.all(keep_mask):
314
- data.filter(keep_mask)
315
-
316
- # Compress to RLE
317
- data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
318
- data["rles"] = mask_to_rle_pytorch(data["masks"])
319
- del data["masks"]
320
-
321
- return data
322
-
323
- @staticmethod
324
- def postprocess_small_regions(
325
- mask_data: MaskData, min_area: int, nms_thresh: float
326
- ) -> MaskData:
327
- """
328
- Removes small disconnected regions and holes in masks, then reruns
329
- box NMS to remove any new duplicates.
330
-
331
- Edits mask_data in place.
332
-
333
- Requires open-cv as a dependency.
334
- """
335
- if len(mask_data["rles"]) == 0:
336
- return mask_data
337
-
338
- # Filter small disconnected regions and holes
339
- new_masks = []
340
- scores = []
341
- for rle in mask_data["rles"]:
342
- mask = rle_to_mask(rle)
343
-
344
- mask, changed = remove_small_regions(mask, min_area, mode="holes")
345
- unchanged = not changed
346
- mask, changed = remove_small_regions(mask, min_area, mode="islands")
347
- unchanged = unchanged and not changed
348
-
349
- new_masks.append(torch.as_tensor(mask).unsqueeze(0))
350
- # Give score=0 to changed masks and score=1 to unchanged masks
351
- # so NMS will prefer ones that didn't need postprocessing
352
- scores.append(float(unchanged))
353
-
354
- # Recalculate boxes and remove any new duplicates
355
- masks = torch.cat(new_masks, dim=0)
356
- boxes = batched_mask_to_box(masks)
357
- keep_by_nms = batched_nms(
358
- boxes.float(),
359
- torch.as_tensor(scores),
360
- torch.zeros_like(boxes[:, 0]), # categories
361
- iou_threshold=nms_thresh,
362
- )
363
-
364
- # Only recalculate RLEs for masks that have changed
365
- for i_mask in keep_by_nms:
366
- if scores[i_mask] == 0.0:
367
- mask_torch = masks[i_mask].unsqueeze(0)
368
- mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
369
- mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly
370
- mask_data.filter(keep_by_nms)
371
-
372
- return mask_data
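Pulling the documented interface together, here is a minimal sketch of how the removed generator is driven (assumptions: `sam` comes from `sam_model_registry` as in the earlier snippet, and the input follows the docstring's HWC uint8 contract; the zero-filled array is a placeholder, not repo data):

```python
# Minimal sketch; `sam` and the dummy input are assumptions for illustration only.
import numpy as np

mask_generator = SamAutomaticMaskGenerator(
    sam,
    points_per_side=32,           # 32 x 32 grid of point prompts
    pred_iou_thresh=0.88,         # drop masks the model itself scores poorly
    stability_score_thresh=0.95,  # drop masks that are unstable under threshold shifts
    output_mode="binary_mask",
)
image = np.zeros((512, 512, 3), dtype=np.uint8)  # placeholder HWC uint8 input
records = mask_generator.generate(image)
for rec in records[:3]:
    print(rec["bbox"], rec["area"], rec["predicted_iou"], rec["stability_score"])
```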
 
model/segment_anything_volumetric/build_sam.py DELETED
@@ -1,111 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
-
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
- from functools import partial
- from pathlib import Path
- import urllib.request
- import torch
-
- from .modeling import (
-     ImageEncoderViT,
-     MaskDecoder,
-     PromptEncoder,
-     Sam,
-     TwoWayTransformer,
- )
- import numpy as np
- from .modeling.image_encoder_swin import SwinTransformer
- from monai.networks.nets import ViT
- from monai.networks.nets.swin_unetr import SwinTransformer as SwinViT
-
- from monai.utils import ensure_tuple_rep, optional_import
-
-
- """
- Examples::
-     # for 3D single channel input with size (96,96,96), 4-channel output and feature size of 48.
-     >>> net = SwinUNETR(img_size=(96,96,96), in_channels=1, out_channels=4, feature_size=48)
-     # for 3D 4-channel input with size (128,128,128), 3-channel output and (2,4,2,2) layers in each stage.
-     >>> net = SwinUNETR(img_size=(128,128,128), in_channels=4, out_channels=3, depths=(2,4,2,2))
-     # for 2D single channel input with size (96,96), 2-channel output and gradient checkpointing.
-     >>> net = SwinUNETR(img_size=(96,96), in_channels=3, out_channels=2, use_checkpoint=True, spatial_dims=2)
- """
-
- def build_sam_vit_3d(checkpoint=None):
-     print('build_sam_vit_3d...')
-     return _build_sam(
-         image_encoder_type='vit',
-         embed_dim = 768,
-         patch_size=[4,16,16],
-         checkpoint=checkpoint,
-         image_size=[32,256,256],
-     )
-
- sam_model_registry = {
-     "vit": build_sam_vit_3d,
- }
-
-
- def _build_sam(
-     image_encoder_type,
-     embed_dim,
-     patch_size,
-     checkpoint,
-     image_size,
- ):
-     mlp_dim = 3072
-     num_layers = 12
-     num_heads = 12
-     pos_embed = 'perceptron'
-     dropout_rate = 0.0
-
-     image_encoder=ViT(
-         in_channels=1,
-         img_size=image_size,
-         patch_size=patch_size,
-         hidden_size=embed_dim,
-         mlp_dim=mlp_dim,
-         num_layers=num_layers,
-         num_heads=num_heads,
-         pos_embed=pos_embed,
-         classification=False,
-         dropout_rate=dropout_rate,
-     )
-     image_embedding_size = [int(item) for item in (np.array(image_size) / np.array(patch_size))]
-
-     if checkpoint is not None:
-         with open(checkpoint, "rb") as f:
-             state_dict = torch.load(f, map_location='cpu')['state_dict']
-             encoder_dict = {k.replace('model.encoder.', ''): v for k, v in state_dict.items() if 'model.encoder.' in k}
-         image_encoder.load_state_dict(encoder_dict)
-         print(f'===> image_encoder.load_param: {checkpoint}')
-     sam = Sam(
-         image_encoder=image_encoder,
-         prompt_encoder=PromptEncoder(
-             embed_dim=embed_dim,
-             image_embedding_size=image_embedding_size,
-             input_image_size=image_size,
-             mask_in_chans=16,
-         ),
-         mask_decoder=MaskDecoder(
-             image_encoder_type=image_encoder_type,
-             num_multimask_outputs=3,
-             transformer=TwoWayTransformer(
-                 depth=2,
-                 embedding_dim=embed_dim,
-                 mlp_dim=2048,
-                 num_heads=8,
-             ),
-             transformer_dim=embed_dim,
-             iou_head_depth=3,
-             iou_head_hidden_dim=256,
-             image_size=np.array(image_size),
-             patch_size=np.array(patch_size),
-         ),
-         pixel_mean=[123.675, 116.28, 103.53],
-         pixel_std=[58.395, 57.12, 57.375],
-     )
-     sam.eval()
-     return sam
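Note that `_build_sam` above restores only the image encoder from a checkpoint, and it expects the weights under `ckpt['state_dict']` with a `model.encoder.` key prefix; the remaining SAM weights are restored elsewhere (e.g. via the demo's `--resume` flag). A hedged sketch of both call patterns, with a hypothetical checkpoint path:

```python
# Minimal sketch; the checkpoint path is hypothetical and the import path is an
# assumption based on the repo layout shown in this diff.
from model.segment_anything_volumetric.build_sam import sam_model_registry

# Randomly initialised 3D ViT encoder:
sam = sam_model_registry["vit"](checkpoint=None)

# Encoder-only warm start: the builder strips the 'model.encoder.' prefix from
# ckpt['state_dict'] and loads just those tensors into the MONAI ViT.
sam = sam_model_registry["vit"](checkpoint="path/to/pretrained_vit_encoder.pth")
sam.eval()  # _build_sam already calls eval(); repeated here only for clarity
```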
 
model/segment_anything_volumetric/modeling/.ipynb_checkpoints/image_encoder_swin-checkpoint.py DELETED
@@ -1,709 +0,0 @@
1
- from typing import Sequence, Tuple, Type, Union
2
-
3
- import numpy as np
4
- import torch
5
- import torch.nn as nn
6
- import torch.nn.functional as F
7
- import torch.utils.checkpoint as checkpoint
8
- from torch.nn import LayerNorm
9
-
10
- from monai.networks.blocks import MLPBlock as Mlp
11
- from monai.networks.blocks import PatchEmbed, UnetOutBlock, UnetrBasicBlock, UnetrUpBlock
12
- from monai.networks.layers import DropPath, trunc_normal_
13
- from monai.utils import ensure_tuple_rep, optional_import
14
-
15
- rearrange, _ = optional_import("einops", name="rearrange")
16
-
17
- def window_partition(x, window_size):
18
- """window partition operation based on: "Liu et al.,
19
- Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
20
- <https://arxiv.org/abs/2103.14030>"
21
- https://github.com/microsoft/Swin-Transformer
22
- Args:
23
- x: input tensor.
24
- window_size: local window size.
25
- """
26
- x_shape = x.size()
27
- if len(x_shape) == 5:
28
- b, d, h, w, c = x_shape
29
- x = x.view(
30
- b,
31
- d // window_size[0],
32
- window_size[0],
33
- h // window_size[1],
34
- window_size[1],
35
- w // window_size[2],
36
- window_size[2],
37
- c,
38
- )
39
- windows = (
40
- x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, window_size[0] * window_size[1] * window_size[2], c)
41
- )
42
- elif len(x_shape) == 4:
43
- b, h, w, c = x.shape
44
- x = x.view(b, h // window_size[0], window_size[0], w // window_size[1], window_size[1], c)
45
- windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0] * window_size[1], c)
46
- return windows
47
-
48
-
49
- def window_reverse(windows, window_size, dims):
50
- """window reverse operation based on: "Liu et al.,
51
- Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
52
- <https://arxiv.org/abs/2103.14030>"
53
- https://github.com/microsoft/Swin-Transformer
54
- Args:
55
- windows: windows tensor.
56
- window_size: local window size.
57
- dims: dimension values.
58
- """
59
- if len(dims) == 4:
60
- b, d, h, w = dims
61
- x = windows.view(
62
- b,
63
- d // window_size[0],
64
- h // window_size[1],
65
- w // window_size[2],
66
- window_size[0],
67
- window_size[1],
68
- window_size[2],
69
- -1,
70
- )
71
- x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(b, d, h, w, -1)
72
-
73
- elif len(dims) == 3:
74
- b, h, w = dims
75
- x = windows.view(b, h // window_size[0], w // window_size[0], window_size[0], window_size[1], -1)
76
- x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(b, h, w, -1)
77
- return x
78
-
79
-
80
- def get_window_size(x_size, window_size, shift_size=None):
81
- """Computing window size based on: "Liu et al.,
82
- Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
83
- <https://arxiv.org/abs/2103.14030>"
84
- https://github.com/microsoft/Swin-Transformer
85
- Args:
86
- x_size: input size.
87
- window_size: local window size.
88
- shift_size: window shifting size.
89
- """
90
-
91
- use_window_size = list(window_size)
92
- if shift_size is not None:
93
- use_shift_size = list(shift_size)
94
- for i in range(len(x_size)):
95
- if x_size[i] <= window_size[i]:
96
- use_window_size[i] = x_size[i]
97
- if shift_size is not None:
98
- use_shift_size[i] = 0
99
-
100
- if shift_size is None:
101
- return tuple(use_window_size)
102
- else:
103
- return tuple(use_window_size), tuple(use_shift_size)
104
-
105
-
106
- class WindowAttention(nn.Module):
107
- """
108
- Window based multi-head self attention module with relative position bias based on: "Liu et al.,
109
- Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
110
- <https://arxiv.org/abs/2103.14030>"
111
- https://github.com/microsoft/Swin-Transformer
112
- """
113
-
114
- def __init__(
115
- self,
116
- dim: int,
117
- num_heads: int,
118
- window_size: Sequence[int],
119
- qkv_bias: bool = False,
120
- attn_drop: float = 0.0,
121
- proj_drop: float = 0.0,
122
- ) -> None:
123
- """
124
- Args:
125
- dim: number of feature channels.
126
- num_heads: number of attention heads.
127
- window_size: local window size.
128
- qkv_bias: add a learnable bias to query, key, value.
129
- attn_drop: attention dropout rate.
130
- proj_drop: dropout rate of output.
131
- """
132
-
133
- super().__init__()
134
- self.dim = dim
135
- self.window_size = window_size
136
- self.num_heads = num_heads
137
- head_dim = dim // num_heads
138
- self.scale = head_dim**-0.5
139
- mesh_args = torch.meshgrid.__kwdefaults__
140
-
141
- if len(self.window_size) == 3:
142
- self.relative_position_bias_table = nn.Parameter(
143
- torch.zeros(
144
- (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1),
145
- num_heads,
146
- )
147
- )
148
- coords_d = torch.arange(self.window_size[0])
149
- coords_h = torch.arange(self.window_size[1])
150
- coords_w = torch.arange(self.window_size[2])
151
- if mesh_args is not None:
152
- coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w, indexing="ij"))
153
- else:
154
- coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w))
155
- coords_flatten = torch.flatten(coords, 1)
156
- relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
157
- relative_coords = relative_coords.permute(1, 2, 0).contiguous()
158
- relative_coords[:, :, 0] += self.window_size[0] - 1
159
- relative_coords[:, :, 1] += self.window_size[1] - 1
160
- relative_coords[:, :, 2] += self.window_size[2] - 1
161
- relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1)
162
- relative_coords[:, :, 1] *= 2 * self.window_size[2] - 1
163
- elif len(self.window_size) == 2:
164
- self.relative_position_bias_table = nn.Parameter(
165
- torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
166
- )
167
- coords_h = torch.arange(self.window_size[0])
168
- coords_w = torch.arange(self.window_size[1])
169
- if mesh_args is not None:
170
- coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"))
171
- else:
172
- coords = torch.stack(torch.meshgrid(coords_h, coords_w))
173
- coords_flatten = torch.flatten(coords, 1)
174
- relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
175
- relative_coords = relative_coords.permute(1, 2, 0).contiguous()
176
- relative_coords[:, :, 0] += self.window_size[0] - 1
177
- relative_coords[:, :, 1] += self.window_size[1] - 1
178
- relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
179
-
180
- relative_position_index = relative_coords.sum(-1)
181
- self.register_buffer("relative_position_index", relative_position_index)
182
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
183
- self.attn_drop = nn.Dropout(attn_drop)
184
- self.proj = nn.Linear(dim, dim)
185
- self.proj_drop = nn.Dropout(proj_drop)
186
- trunc_normal_(self.relative_position_bias_table, std=0.02)
187
- self.softmax = nn.Softmax(dim=-1)
188
-
189
- def forward(self, x, mask):
190
- b, n, c = x.shape
191
- qkv = self.qkv(x).reshape(b, n, 3, self.num_heads, c // self.num_heads).permute(2, 0, 3, 1, 4)
192
- q, k, v = qkv[0], qkv[1], qkv[2]
193
- q = q * self.scale
194
- attn = q @ k.transpose(-2, -1)
195
- relative_position_bias = self.relative_position_bias_table[
196
- self.relative_position_index.clone()[:n, :n].reshape(-1)
197
- ].reshape(n, n, -1)
198
- relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
199
- attn = attn + relative_position_bias.unsqueeze(0)
200
- if mask is not None:
201
- nw = mask.shape[0]
202
- attn = attn.view(b // nw, nw, self.num_heads, n, n) + mask.unsqueeze(1).unsqueeze(0)
203
- attn = attn.view(-1, self.num_heads, n, n)
204
- attn = self.softmax(attn)
205
- else:
206
- attn = self.softmax(attn)
207
-
208
- attn = self.attn_drop(attn)
209
- x = (attn @ v).transpose(1, 2).reshape(b, n, c)
210
- x = self.proj(x)
211
- x = self.proj_drop(x)
212
- return x
213
-
214
-
215
- class SwinTransformerBlock(nn.Module):
216
- """
217
- Swin Transformer block based on: "Liu et al.,
218
- Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
219
- <https://arxiv.org/abs/2103.14030>"
220
- https://github.com/microsoft/Swin-Transformer
221
- """
222
-
223
- def __init__(
224
- self,
225
- dim: int,
226
- num_heads: int,
227
- window_size: Sequence[int],
228
- shift_size: Sequence[int],
229
- mlp_ratio: float = 4.0,
230
- qkv_bias: bool = True,
231
- drop: float = 0.0,
232
- attn_drop: float = 0.0,
233
- drop_path: float = 0.0,
234
- act_layer: str = "GELU",
235
- norm_layer: Type[LayerNorm] = nn.LayerNorm, # type: ignore
236
- use_checkpoint: bool = False,
237
- ) -> None:
238
- """
239
- Args:
240
- dim: number of feature channels.
241
- num_heads: number of attention heads.
242
- window_size: local window size.
243
- shift_size: window shift size.
244
- mlp_ratio: ratio of mlp hidden dim to embedding dim.
245
- qkv_bias: add a learnable bias to query, key, value.
246
- drop: dropout rate.
247
- attn_drop: attention dropout rate.
248
- drop_path: stochastic depth rate.
249
- act_layer: activation layer.
250
- norm_layer: normalization layer.
251
- use_checkpoint: use gradient checkpointing for reduced memory usage.
252
- """
253
-
254
- super().__init__()
255
- self.dim = dim
256
- self.num_heads = num_heads
257
- self.window_size = window_size
258
- self.shift_size = shift_size
259
- self.mlp_ratio = mlp_ratio
260
- self.use_checkpoint = use_checkpoint
261
- self.norm1 = norm_layer(dim)
262
- self.attn = WindowAttention(
263
- dim,
264
- window_size=self.window_size,
265
- num_heads=num_heads,
266
- qkv_bias=qkv_bias,
267
- attn_drop=attn_drop,
268
- proj_drop=drop,
269
- )
270
-
271
- self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
272
- self.norm2 = norm_layer(dim)
273
- mlp_hidden_dim = int(dim * mlp_ratio)
274
- self.mlp = Mlp(hidden_size=dim, mlp_dim=mlp_hidden_dim, act=act_layer, dropout_rate=drop, dropout_mode="swin")
275
-
276
- def forward_part1(self, x, mask_matrix):
277
- x_shape = x.size()
278
- x = self.norm1(x)
279
- if len(x_shape) == 5:
280
- b, d, h, w, c = x.shape
281
- window_size, shift_size = get_window_size((d, h, w), self.window_size, self.shift_size)
282
- pad_l = pad_t = pad_d0 = 0
283
- pad_d1 = (window_size[0] - d % window_size[0]) % window_size[0]
284
- pad_b = (window_size[1] - h % window_size[1]) % window_size[1]
285
- pad_r = (window_size[2] - w % window_size[2]) % window_size[2]
286
- x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_d0, pad_d1))
287
- _, dp, hp, wp, _ = x.shape
288
- dims = [b, dp, hp, wp]
289
-
290
- elif len(x_shape) == 4:
291
- b, h, w, c = x.shape
292
- window_size, shift_size = get_window_size((h, w), self.window_size, self.shift_size)
293
- pad_l = pad_t = 0
294
- pad_r = (window_size[0] - h % window_size[0]) % window_size[0]
295
- pad_b = (window_size[1] - w % window_size[1]) % window_size[1]
296
- x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
297
- _, hp, wp, _ = x.shape
298
- dims = [b, hp, wp]
299
-
300
- if any(i > 0 for i in shift_size):
301
- if len(x_shape) == 5:
302
- shifted_x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), dims=(1, 2, 3))
303
- elif len(x_shape) == 4:
304
- shifted_x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2))
305
- attn_mask = mask_matrix
306
- else:
307
- shifted_x = x
308
- attn_mask = None
309
- x_windows = window_partition(shifted_x, window_size)
310
- attn_windows = self.attn(x_windows, mask=attn_mask)
311
- attn_windows = attn_windows.view(-1, *(window_size + (c,)))
312
- shifted_x = window_reverse(attn_windows, window_size, dims)
313
- if any(i > 0 for i in shift_size):
314
- if len(x_shape) == 5:
315
- x = torch.roll(shifted_x, shifts=(shift_size[0], shift_size[1], shift_size[2]), dims=(1, 2, 3))
316
- elif len(x_shape) == 4:
317
- x = torch.roll(shifted_x, shifts=(shift_size[0], shift_size[1]), dims=(1, 2))
318
- else:
319
- x = shifted_x
320
-
321
- if len(x_shape) == 5:
322
- if pad_d1 > 0 or pad_r > 0 or pad_b > 0:
323
- x = x[:, :d, :h, :w, :].contiguous()
324
- elif len(x_shape) == 4:
325
- if pad_r > 0 or pad_b > 0:
326
- x = x[:, :h, :w, :].contiguous()
327
-
328
- return x
329
-
330
- def forward_part2(self, x):
331
- return self.drop_path(self.mlp(self.norm2(x)))
332
-
333
- def load_from(self, weights, n_block, layer):
334
- root = f"module.{layer}.0.blocks.{n_block}."
335
- block_names = [
336
- "norm1.weight",
337
- "norm1.bias",
338
- "attn.relative_position_bias_table",
339
- "attn.relative_position_index",
340
- "attn.qkv.weight",
341
- "attn.qkv.bias",
342
- "attn.proj.weight",
343
- "attn.proj.bias",
344
- "norm2.weight",
345
- "norm2.bias",
346
- "mlp.fc1.weight",
347
- "mlp.fc1.bias",
348
- "mlp.fc2.weight",
349
- "mlp.fc2.bias",
350
- ]
351
- with torch.no_grad():
352
- self.norm1.weight.copy_(weights["state_dict"][root + block_names[0]])
353
- self.norm1.bias.copy_(weights["state_dict"][root + block_names[1]])
354
- self.attn.relative_position_bias_table.copy_(weights["state_dict"][root + block_names[2]])
355
- self.attn.relative_position_index.copy_(weights["state_dict"][root + block_names[3]])
356
- self.attn.qkv.weight.copy_(weights["state_dict"][root + block_names[4]])
357
- self.attn.qkv.bias.copy_(weights["state_dict"][root + block_names[5]])
358
- self.attn.proj.weight.copy_(weights["state_dict"][root + block_names[6]])
359
- self.attn.proj.bias.copy_(weights["state_dict"][root + block_names[7]])
360
- self.norm2.weight.copy_(weights["state_dict"][root + block_names[8]])
361
- self.norm2.bias.copy_(weights["state_dict"][root + block_names[9]])
362
- self.mlp.linear1.weight.copy_(weights["state_dict"][root + block_names[10]])
363
- self.mlp.linear1.bias.copy_(weights["state_dict"][root + block_names[11]])
364
- self.mlp.linear2.weight.copy_(weights["state_dict"][root + block_names[12]])
365
- self.mlp.linear2.bias.copy_(weights["state_dict"][root + block_names[13]])
366
-
367
- def forward(self, x, mask_matrix):
368
- shortcut = x
369
- if self.use_checkpoint:
370
- x = checkpoint.checkpoint(self.forward_part1, x, mask_matrix)
371
- else:
372
- x = self.forward_part1(x, mask_matrix)
373
- x = shortcut + self.drop_path(x)
374
- if self.use_checkpoint:
375
- x = x + checkpoint.checkpoint(self.forward_part2, x)
376
- else:
377
- x = x + self.forward_part2(x)
378
- return x
379
-
380
-
381
- class PatchMerging(nn.Module):
382
- """
383
- Patch merging layer based on: "Liu et al.,
384
- Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
385
- <https://arxiv.org/abs/2103.14030>"
386
- https://github.com/microsoft/Swin-Transformer
387
- """
388
-
389
- def __init__(
390
- self, dim: int, norm_layer: Type[LayerNorm] = nn.LayerNorm, spatial_dims: int = 3
391
- ) -> None: # type: ignore
392
- """
393
- Args:
394
- dim: number of feature channels.
395
- norm_layer: normalization layer.
396
- spatial_dims: number of spatial dims.
397
- """
398
-
399
- super().__init__()
400
- self.dim = dim
401
- if spatial_dims == 3:
402
- self.reduction = nn.Linear(8 * dim, 2 * dim, bias=False)
403
- self.norm = norm_layer(8 * dim)
404
- elif spatial_dims == 2:
405
- self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
406
- self.norm = norm_layer(4 * dim)
407
-
408
- def forward(self, x):
409
-
410
- x_shape = x.size()
411
- if len(x_shape) == 5:
412
- b, d, h, w, c = x_shape
413
- pad_input = (h % 2 == 1) or (w % 2 == 1) or (d % 2 == 1)
414
- if pad_input:
415
- x = F.pad(x, (0, 0, 0, d % 2, 0, w % 2, 0, h % 2))
416
- x0 = x[:, 0::2, 0::2, 0::2, :]
417
- x1 = x[:, 1::2, 0::2, 0::2, :]
418
- x2 = x[:, 0::2, 1::2, 0::2, :]
419
- x3 = x[:, 0::2, 0::2, 1::2, :]
420
- x4 = x[:, 1::2, 0::2, 1::2, :]
421
- x5 = x[:, 0::2, 1::2, 0::2, :]
422
- x6 = x[:, 0::2, 0::2, 1::2, :]
423
- x7 = x[:, 1::2, 1::2, 1::2, :]
424
- x = torch.cat([x0, x1, x2, x3, x4, x5, x6, x7], -1)
425
-
426
- elif len(x_shape) == 4:
427
- b, h, w, c = x_shape
428
- pad_input = (h % 2 == 1) or (w % 2 == 1)
429
- if pad_input:
430
- x = F.pad(x, (0, 0, 0, w % 2, 0, h % 2))
431
- x0 = x[:, 0::2, 0::2, :]
432
- x1 = x[:, 1::2, 0::2, :]
433
- x2 = x[:, 0::2, 1::2, :]
434
- x3 = x[:, 1::2, 1::2, :]
435
- x = torch.cat([x0, x1, x2, x3], -1)
436
-
437
- x = self.norm(x)
438
- x = self.reduction(x)
439
- return x
440
-
441
-
442
- def compute_mask(dims, window_size, shift_size, device):
443
- """Computing region masks based on: "Liu et al.,
444
- Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
445
- <https://arxiv.org/abs/2103.14030>"
446
- https://github.com/microsoft/Swin-Transformer
447
- Args:
448
- dims: dimension values.
449
- window_size: local window size.
450
- shift_size: shift size.
451
- device: device.
452
- """
453
-
454
- cnt = 0
455
-
456
- if len(dims) == 3:
457
- d, h, w = dims
458
- img_mask = torch.zeros((1, d, h, w, 1), device=device)
459
- for d in slice(-window_size[0]), slice(-window_size[0], -shift_size[0]), slice(-shift_size[0], None):
460
- for h in slice(-window_size[1]), slice(-window_size[1], -shift_size[1]), slice(-shift_size[1], None):
461
- for w in slice(-window_size[2]), slice(-window_size[2], -shift_size[2]), slice(-shift_size[2], None):
462
- img_mask[:, d, h, w, :] = cnt
463
- cnt += 1
464
-
465
- elif len(dims) == 2:
466
- h, w = dims
467
- img_mask = torch.zeros((1, h, w, 1), device=device)
468
- for h in slice(-window_size[0]), slice(-window_size[0], -shift_size[0]), slice(-shift_size[0], None):
469
- for w in slice(-window_size[1]), slice(-window_size[1], -shift_size[1]), slice(-shift_size[1], None):
470
- img_mask[:, h, w, :] = cnt
471
- cnt += 1
472
-
473
- mask_windows = window_partition(img_mask, window_size)
474
- mask_windows = mask_windows.squeeze(-1)
475
- attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
476
- attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
477
-
478
- return attn_mask
479
-
480
-
481
- class BasicLayer(nn.Module):
482
- """
483
- Basic Swin Transformer layer in one stage based on: "Liu et al.,
484
- Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
485
- <https://arxiv.org/abs/2103.14030>"
486
- https://github.com/microsoft/Swin-Transformer
487
- """
488
-
489
- def __init__(
490
- self,
491
- dim: int,
492
- depth: int,
493
- num_heads: int,
494
- window_size: Sequence[int],
495
- drop_path: list,
496
- mlp_ratio: float = 4.0,
497
- qkv_bias: bool = False,
498
- drop: float = 0.0,
499
- attn_drop: float = 0.0,
500
- norm_layer: Type[LayerNorm] = nn.LayerNorm, # type: ignore
501
- downsample: isinstance = None, # type: ignore
502
- use_checkpoint: bool = False,
503
- ) -> None:
504
- """
505
- Args:
506
- dim: number of feature channels.
507
- depths: number of layers in each stage.
508
- num_heads: number of attention heads.
509
- window_size: local window size.
510
- drop_path: stochastic depth rate.
511
- mlp_ratio: ratio of mlp hidden dim to embedding dim.
512
- qkv_bias: add a learnable bias to query, key, value.
513
- drop: dropout rate.
514
- attn_drop: attention dropout rate.
515
- norm_layer: normalization layer.
516
- downsample: downsample layer at the end of the layer.
517
- use_checkpoint: use gradient checkpointing for reduced memory usage.
518
- """
519
-
520
- super().__init__()
521
- self.window_size = window_size
522
- self.shift_size = tuple(i // 2 for i in window_size)
523
- self.no_shift = tuple(0 for i in window_size)
524
- self.depth = depth
525
- self.use_checkpoint = use_checkpoint
526
- self.blocks = nn.ModuleList(
527
- [
528
- SwinTransformerBlock(
529
- dim=dim,
530
- num_heads=num_heads,
531
- window_size=self.window_size,
532
- shift_size=self.no_shift if (i % 2 == 0) else self.shift_size,
533
- mlp_ratio=mlp_ratio,
534
- qkv_bias=qkv_bias,
535
- drop=drop,
536
- attn_drop=attn_drop,
537
- drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
538
- norm_layer=norm_layer,
539
- use_checkpoint=use_checkpoint,
540
- )
541
- for i in range(depth)
542
- ]
543
- )
544
- self.downsample = downsample
545
- if self.downsample is not None:
546
- self.downsample = downsample(dim=dim, norm_layer=norm_layer, spatial_dims=len(self.window_size))
547
-
548
- def forward(self, x):
549
- x_shape = x.size()
550
- if len(x_shape) == 5:
551
- b, c, d, h, w = x_shape
552
- window_size, shift_size = get_window_size((d, h, w), self.window_size, self.shift_size)
553
- x = rearrange(x, "b c d h w -> b d h w c")
554
- dp = int(np.ceil(d / window_size[0])) * window_size[0]
555
- hp = int(np.ceil(h / window_size[1])) * window_size[1]
556
- wp = int(np.ceil(w / window_size[2])) * window_size[2]
557
- attn_mask = compute_mask([dp, hp, wp], window_size, shift_size, x.device)
558
- for blk in self.blocks:
559
- x = blk(x, attn_mask)
560
- x = x.view(b, d, h, w, -1)
561
- if self.downsample is not None:
562
- x = self.downsample(x)
563
- x = rearrange(x, "b d h w c -> b c d h w")
564
-
565
- elif len(x_shape) == 4:
566
- b, c, h, w = x_shape
567
- window_size, shift_size = get_window_size((h, w), self.window_size, self.shift_size)
568
- x = rearrange(x, "b c h w -> b h w c")
569
- hp = int(np.ceil(h / window_size[0])) * window_size[0]
570
- wp = int(np.ceil(w / window_size[1])) * window_size[1]
571
- attn_mask = compute_mask([hp, wp], window_size, shift_size, x.device)
572
- for blk in self.blocks:
573
- x = blk(x, attn_mask)
574
- x = x.view(b, h, w, -1)
575
- if self.downsample is not None:
576
- x = self.downsample(x)
577
- x = rearrange(x, "b h w c -> b c h w")
578
- return x
579
-
580
-
581
- class SwinTransformer(nn.Module):
582
- """
583
- Swin Transformer based on: "Liu et al.,
584
- Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
585
- <https://arxiv.org/abs/2103.14030>"
586
- https://github.com/microsoft/Swin-Transformer
587
- """
588
-
589
- def __init__(
590
- self,
591
- in_chans: int,
592
- embed_dim: int,
593
- window_size: Sequence[int],
594
- patch_size: Sequence[int],
595
- depths: Sequence[int],
596
- num_heads: Sequence[int],
597
- mlp_ratio: float = 4.0,
598
- qkv_bias: bool = True,
599
- drop_rate: float = 0.0,
600
- attn_drop_rate: float = 0.0,
601
- drop_path_rate: float = 0.0,
602
- norm_layer: Type[LayerNorm] = nn.LayerNorm, # type: ignore
603
- patch_norm: bool = False,
604
- use_checkpoint: bool = False,
605
- spatial_dims: int = 3,
606
- ) -> None:
607
- """
608
- Args:
609
- in_chans: dimension of input channels.
610
- embed_dim: number of linear projection output channels.
611
- window_size: local window size.
612
- patch_size: patch size.
613
- depths: number of layers in each stage.
614
- num_heads: number of attention heads.
615
- mlp_ratio: ratio of mlp hidden dim to embedding dim.
616
- qkv_bias: add a learnable bias to query, key, value.
617
- drop_rate: dropout rate.
618
- attn_drop_rate: attention dropout rate.
619
- drop_path_rate: stochastic depth rate.
620
- norm_layer: normalization layer.
621
- patch_norm: add normalization after patch embedding.
622
- use_checkpoint: use gradient checkpointing for reduced memory usage.
623
- spatial_dims: spatial dimension.
624
- """
625
-
626
- super().__init__()
627
- self.num_layers = len(depths)
628
- self.embed_dim = embed_dim
629
- self.patch_norm = patch_norm
630
- self.window_size = window_size
631
- self.patch_size = patch_size
632
- self.patch_embed = PatchEmbed(
633
- patch_size=self.patch_size,
634
- in_chans=in_chans,
635
- embed_dim=embed_dim,
636
- norm_layer=norm_layer if self.patch_norm else None, # type: ignore
637
- spatial_dims=spatial_dims,
638
- )
639
- self.pos_drop = nn.Dropout(p=drop_rate)
640
- dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
641
- # self.layers1 = nn.ModuleList()
642
- # self.layers2 = nn.ModuleList()
643
- # self.layers3 = nn.ModuleList()
644
- # self.layers4 = nn.ModuleList()
645
- self.layers = nn.ModuleList()
646
- for i_layer in range(self.num_layers):
647
- layer = BasicLayer(
648
- dim=int(embed_dim * 2**i_layer),
649
- depth=depths[i_layer],
650
- num_heads=num_heads[i_layer],
651
- window_size=self.window_size,
652
- drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
653
- mlp_ratio=mlp_ratio,
654
- qkv_bias=qkv_bias,
655
- drop=drop_rate,
656
- attn_drop=attn_drop_rate,
657
- norm_layer=norm_layer,
658
- downsample=PatchMerging,
659
- use_checkpoint=use_checkpoint,
660
- )
661
- self.layers.append(layer)
662
- # if i_layer == 0:
663
- # self.layers1.append(layer)
664
- # elif i_layer == 1:
665
- # self.layers2.append(layer)
666
- # elif i_layer == 2:
667
- # self.layers3.append(layer)
668
- # elif i_layer == 3:
669
- # self.layers4.append(layer)
670
- self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
671
-
672
- def proj_out(self, x, normalize=False):
673
- if normalize:
674
- x_shape = x.size()
675
- if len(x_shape) == 5:
676
- n, ch, d, h, w = x_shape
677
- x = rearrange(x, "n c d h w -> n d h w c")
678
- x = F.layer_norm(x, [ch])
679
- x = rearrange(x, "n d h w c -> n c d h w")
680
- elif len(x_shape) == 4:
681
- n, ch, h, w = x_shape
682
- x = rearrange(x, "n c h w -> n h w c")
683
- x = F.layer_norm(x, [ch])
684
- x = rearrange(x, "n h w c -> n c h w")
685
- return x
686
-
687
- def forward(self, x, normalize=True):
688
- # x input: [B*sample, C(1), H, W, D]
689
- # x = rearrange(x, "b c h w d -> b c d h w")
690
- # print('>> input: ', x.shape)
691
- x = self.patch_embed(x)
692
- # print('>> patch_embed: ', x.shape)
693
- x = self.pos_drop(x)
694
- for layer in self.layers:
695
- x = layer(x.contiguous())
696
- # print('>> layer: ', x.shape)
697
- return x
698
- # # x0_out = self.proj_out(x0, normalize)
699
- # x1 = self.layers1[0](x0.contiguous())
700
- # # x1_out = self.proj_out(x1, normalize)
701
- # x2 = self.layers2[0](x1.contiguous())
702
- # # x2_out = self.proj_out(x2, normalize)
703
- # x3 = self.layers3[0](x2.contiguous())
704
- # # x3_out = self.proj_out(x3, normalize)
705
- # x4 = self.layers4[0](x3.contiguous())
706
- # # x4_out = self.proj_out(x4, normalize)
707
- # # return [x0_out, x1_out, x2_out, x3_out, x4_out]
708
-
709
-
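As a shape sanity check on the volumetric Swin encoder defined in this checkpoint copy of the module, the sketch below pushes a dummy volume through it. Assumptions: MONAI and einops are installed, the `SwinTransformer` class above is in scope, and the sizes are illustrative; with a stride-2 patch embedding and four stages that each end in `PatchMerging`, the output sits at 1/32 of the input resolution with `embed_dim * 2**4` channels.

```python
# Minimal shape-check sketch, not part of the deleted file; sizes are illustrative only.
import torch
from monai.utils import ensure_tuple_rep

encoder = SwinTransformer(
    in_chans=1,
    embed_dim=48,
    window_size=ensure_tuple_rep(7, 3),
    patch_size=ensure_tuple_rep(2, 3),
    depths=(2, 2, 6, 2),
    num_heads=(3, 6, 12, 24),
    mlp_ratio=4.0,
    qkv_bias=True,
    spatial_dims=3,
)
x = torch.randn(1, 1, 64, 64, 64)  # [B, C, D, H, W]
with torch.no_grad():
    feat = encoder(x)
print(feat.shape)  # expected: torch.Size([1, 768, 2, 2, 2]) = 48 * 2**4 channels at 1/32 resolution
```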
 
model/segment_anything_volumetric/modeling/.ipynb_checkpoints/prompt_encoder-checkpoint.py DELETED
@@ -1,232 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
-
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- import numpy as np
8
- import torch
9
- from torch import nn
10
-
11
- from typing import Any, Optional, Tuple, Type
12
-
13
- from .common import LayerNorm2d
14
- import os
15
-
16
- class PromptEncoder(nn.Module):
17
- def __init__(
18
- self,
19
- embed_dim: int,
20
- image_embedding_size: Tuple[int, int, int],
21
- input_image_size: Tuple[int, int, int],
22
- mask_in_chans: int,
23
- activation: Type[nn.Module] = nn.GELU,
24
- ) -> None:
25
- """
26
- Encodes prompts for input to SAM's mask decoder.
27
-
28
- Arguments:
29
- embed_dim (int): The prompts' embedding dimension
30
- image_embedding_size (tuple(int, int, int)): The spatial size of the
31
- image embedding, as (D, H, W).
32
- input_image_size (tuple(int, int, int)): The padded size of the image as input
33
- to the image encoder, as (D, H, W).
34
- mask_in_chans (int): The number of hidden channels used for
35
- encoding input masks.
36
- activation (nn.Module): The activation to use when encoding
37
- input masks.
38
- """
39
- super().__init__()
40
- self.embed_dim = embed_dim
41
- self.input_image_size = input_image_size
42
- self.image_embedding_size = image_embedding_size
43
- self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
44
-
45
- self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
46
- point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)]
47
- self.point_embeddings = nn.ModuleList(point_embeddings)
48
- self.not_a_point_embed = nn.Embedding(1, embed_dim)
49
-
50
- self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1], 4 * image_embedding_size[2])
51
- self.mask_downscaling = nn.Sequential(
52
- nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
53
- LayerNorm2d(mask_in_chans // 4),
54
- activation(),
55
- nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
56
- LayerNorm2d(mask_in_chans),
57
- activation(),
58
- nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
59
- )
60
- self.no_mask_embed = nn.Embedding(1, embed_dim)
61
-
62
- def get_dense_pe(self) -> torch.Tensor:
63
- """
64
- Returns the positional encoding used to encode point prompts,
65
- applied to a dense set of points the shape of the image encoding.
66
-
67
- Returns:
68
- torch.Tensor: Positional encoding with shape
69
- 1x(embed_dim)x(embedding_h)x(embedding_w)
70
- """
71
- return self.pe_layer(self.image_embedding_size).unsqueeze(0)
72
-
73
- def _embed_points(
74
- self,
75
- points: torch.Tensor,
76
- labels: torch.Tensor,
77
- pad: bool,
78
- ) -> torch.Tensor:
79
- """Embeds point prompts."""
80
- points = points + 0.5 # Shift to center of pixel
81
- if pad:
82
- padding_point = torch.zeros((points.shape[0], 1, 3), device=points.device)
83
- padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
84
- points = torch.cat([points, padding_point], dim=1)
85
- labels = torch.cat([labels, padding_label], dim=1)
86
- point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
87
- point_embedding[labels == -1] = 0.0
88
- point_embedding[labels == -1] += self.not_a_point_embed.weight
89
- point_embedding[labels == 0] += self.point_embeddings[0].weight
90
- point_embedding[labels == 1] += self.point_embeddings[1].weight
91
- return point_embedding
92
-
93
- def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
94
- """Embeds box prompts."""
95
- boxes = boxes + 0.5 # Shift to center of pixel
96
- coords = boxes.reshape(-1, 2, 3)
97
- corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
98
- corner_embedding[:, 0, :] += self.point_embeddings[2].weight
99
- corner_embedding[:, 1, :] += self.point_embeddings[3].weight
100
- return corner_embedding
101
-
102
- def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
103
- """Embeds mask inputs."""
104
- mask_embedding = self.mask_downscaling(masks)
105
- return mask_embedding
106
-
107
- def _get_batch_size(
108
- self,
109
- points: Optional[Tuple[torch.Tensor, torch.Tensor]],
110
- boxes: Optional[torch.Tensor],
111
- masks: Optional[torch.Tensor],
112
- text_embedding: Optional[torch.Tensor],
113
- ) -> int:
114
- """
115
- Gets the batch size of the output given the batch size of the input prompts.
116
- """
117
- if points is not None:
118
- return points[0].shape[0]
119
- elif boxes is not None:
120
- return boxes.shape[0]
121
- elif masks is not None:
122
- return masks.shape[0]
123
- elif text_embedding is not None:
124
- return text_embedding.shape[0]
125
- else:
126
- return 1
127
-
128
- def _get_device(self) -> torch.device:
129
- return self.point_embeddings[0].weight.device
130
-
131
- def forward(
132
- self,
133
- points: Optional[Tuple[torch.Tensor, torch.Tensor]],
134
- boxes: Optional[torch.Tensor],
135
- masks: Optional[torch.Tensor],
136
- text_embedding: Optional[torch.Tensor],
137
- ) -> Tuple[torch.Tensor, torch.Tensor]:
138
- """
139
- Embeds different types of prompts, returning both sparse and dense
140
- embeddings.
141
-
142
- Arguments:
143
- points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
144
- and labels to embed.
145
- boxes (torch.Tensor or none): boxes to embed
146
- masks (torch.Tensor or none): masks to embed
147
- text: test prompt (B, 768)
148
-
149
- Returns:
150
- torch.Tensor: sparse embeddings for the points and boxes, with shape
151
- BxNx(embed_dim), where N is determined by the number of input points
152
- and boxes.
153
- torch.Tensor: dense embeddings for the masks, in the shape
154
- Bx(embed_dim)x(embed_H)x(embed_W)
155
- """
156
- # print('prompt encoder here...')
157
-
158
- bs = self._get_batch_size(points, boxes, masks, text_embedding)
159
- sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
160
- # print('sparse_embeddings ', sparse_embeddings.shape)
161
- if points is not None:
162
- coords, labels = points
163
- point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
164
- sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
165
-
166
- if boxes is not None:
167
- box_embeddings = self._embed_boxes(boxes)
168
- sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
169
-
170
- if text_embedding is not None:
171
- sparse_embeddings = torch.cat([sparse_embeddings, text_embedding.unsqueeze(dim=1)], dim=1)
172
-
173
- # print('box_embeddings ', box_embeddings.shape)
174
- # print('sparse_embeddings after box/point/text', sparse_embeddings.shape)
175
-
176
- if masks is not None:
177
- dense_embeddings = self._embed_masks(masks)
178
- else:
179
- dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1, 1).expand(
180
- bs, -1, self.image_embedding_size[0], self.image_embedding_size[1], self.image_embedding_size[2]
181
- )
182
- # print('dense_embeddings ', dense_embeddings.shape)
183
- return sparse_embeddings, dense_embeddings
184
-
185
-
186
- class PositionEmbeddingRandom(nn.Module):
187
- """
188
- Positional encoding using random spatial frequencies.
189
- """
190
-
191
- def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
192
- super().__init__()
193
- if scale is None or scale <= 0.0:
194
- scale = 1.0
195
- self.register_buffer(
196
- "positional_encoding_gaussian_matrix",
197
- scale * torch.randn((3, num_pos_feats)),
198
- )
199
-
200
- def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
201
- """Positionally encode points that are normalized to [0,1]."""
202
- # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
203
- coords = 2 * coords - 1
204
- coords = coords @ self.positional_encoding_gaussian_matrix
205
- coords = 2 * np.pi * coords
206
- # outputs d_1 x ... x d_n x C shape
207
- return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
208
-
209
- def forward(self, size: Tuple[int, int, int]) -> torch.Tensor:
210
- """Generate positional encoding for a grid of the specified size."""
211
- h, w, d = size
212
- device: Any = self.positional_encoding_gaussian_matrix.device
213
- grid = torch.ones((h, w, d), device=device, dtype=torch.float32)
214
- y_embed = grid.cumsum(dim=0) - 0.5
215
- x_embed = grid.cumsum(dim=1) - 0.5
216
- z_embed = grid.cumsum(dim=2) - 0.5
217
- y_embed = y_embed / h
218
- x_embed = x_embed / w
219
- z_embed = z_embed / d
220
-
221
- pe = self._pe_encoding(torch.stack([x_embed, y_embed, z_embed], dim=-1))
222
- return pe.permute(3, 0, 1, 2) # C x H x W x D
223
-
224
- def forward_with_coords(
225
- self, coords_input: torch.Tensor, image_size: Tuple[int, int]
226
- ) -> torch.Tensor:
227
- """Positionally encode points that are not normalized to [0,1]."""
228
- coords = coords_input.clone()
229
- coords[:, :, 0] = coords[:, :, 0] / image_size[1]
230
- coords[:, :, 1] = coords[:, :, 1] / image_size[0]
231
- coords[:, :, 2] = coords[:, :, 2] / image_size[2]
232
- return self._pe_encoding(coords.to(torch.float)) # B x N x C
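
For orientation: the deleted `PositionEmbeddingRandom` maps normalized 3D coordinates through a random Gaussian projection followed by a sin/cos pair, and `get_dense_pe()` evaluates that mapping on a voxel grid to produce a dense C x H x W x D positional map. The following is a minimal standalone sketch of that computation, not the removed SegVol module itself; the 4x4x4 grid, the 64 random frequencies, and the simplified grid construction are illustrative assumptions.

```python
# Sketch of the random-Fourier-feature positional encoding, for illustration only.
import numpy as np
import torch

num_pos_feats = 64                            # half of the output channels (assumed)
gaussian = torch.randn((3, num_pos_feats))    # one row per spatial axis

def encode(coords: torch.Tensor) -> torch.Tensor:
    """coords: (..., 3) in [0, 1]; returns (..., 2 * num_pos_feats)."""
    coords = 2 * coords - 1                   # map to [-1, 1]
    coords = coords @ gaussian                # project onto random frequencies
    coords = 2 * np.pi * coords
    return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)

# Voxel centers of a small 4x4x4 grid, normalized by the grid size (assumed shape).
h = w = d = 4
grid = torch.stack(
    torch.meshgrid(
        (torch.arange(h) + 0.5) / h,
        (torch.arange(w) + 0.5) / w,
        (torch.arange(d) + 0.5) / d,
        indexing="ij",
    ),
    dim=-1,
)
pe = encode(grid).permute(3, 0, 1, 2)         # channels first, analogous to get_dense_pe()
print(pe.shape)                               # torch.Size([128, 4, 4, 4])
```

The same `encode` step is what `forward_with_coords` applies to point and box-corner coordinates after dividing them by the input volume size, which is why point, box, and dense grid embeddings all live in the same positional space.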
model/segment_anything_volumetric/modeling/__init__.py DELETED
@@ -1,11 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
-
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
-
- from .sam import Sam
- from .image_encoder import ImageEncoderViT
- from .mask_decoder import MaskDecoder
- from .prompt_encoder import PromptEncoder
- from .transformer import TwoWayTransformer
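
The removed `__init__.py` above re-exported the volumetric SAM building blocks at the package level. Under the old layout, a caller would have imported them roughly as below; this is a sketch assuming the repository root was on `sys.path`, and the path no longer exists after this deletion.

```python
# Hypothetical usage under the pre-deletion layout (path removed by this change).
from model.segment_anything_volumetric.modeling import (
    Sam,
    ImageEncoderViT,
    MaskDecoder,
    PromptEncoder,
    TwoWayTransformer,
)
```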
model/segment_anything_volumetric/modeling/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (394 Bytes)
 
model/segment_anything_volumetric/modeling/__pycache__/__init__.cpython-39.pyc DELETED
Binary file (433 Bytes)
 
model/segment_anything_volumetric/modeling/__pycache__/common.cpython-310.pyc DELETED
Binary file (1.75 kB)
 
model/segment_anything_volumetric/modeling/__pycache__/common.cpython-39.pyc DELETED
Binary file (1.78 kB)
 
model/segment_anything_volumetric/modeling/__pycache__/image_encoder.cpython-310.pyc DELETED
Binary file (12.6 kB)
 
model/segment_anything_volumetric/modeling/__pycache__/image_encoder.cpython-39.pyc DELETED
Binary file (11.5 kB)
 
model/segment_anything_volumetric/modeling/__pycache__/image_encoder_swin.cpython-39.pyc DELETED
Binary file (21.5 kB)
 
model/segment_anything_volumetric/modeling/__pycache__/mask_decoder.cpython-310.pyc DELETED
Binary file (5.5 kB)