osbm committed on
Commit
4affc67
1 Parent(s): 7cdfd28

Update main.py

Files changed (1)
  1. main.py +237 -1
main.py CHANGED
@@ -5,10 +5,246 @@ import gradio as gr
 # run nnunet
 # export
 
+import os
+import pickle
+import subprocess
+from pathlib import Path
+from typing import Union
 
+import numpy as np
+import SimpleITK as sitk
+from evalutils import SegmentationAlgorithm
+from evalutils.validators import (UniqueImagesValidator,
+                                  UniquePathIndicesValidator)
+from picai_baseline.nnunet.softmax_export import \
+    save_softmax_nifti_from_softmax
+from picai_prep import atomic_image_write
+from picai_prep.preprocessing import (PreprocessingSettings, Sample,
+                                      resample_to_reference_scan)
 
 
-def predict(img):
+class MissingSequenceError(Exception):
+    """Exception raised when a sequence is missing."""
+
+    def __init__(self, name, folder):
+        message = f"Could not find scan for {name} in {folder} (files: {os.listdir(folder)})"
+        super().__init__(message)
+
+
+class MultipleScansSameSequencesError(Exception):
+    """Exception raised when multiple scans of the same sequences are provided."""
+
+    def __init__(self, name, folder):
+        message = f"Found multiple scans for {name} in {folder} (files: {os.listdir(folder)})"
+        super().__init__(message)
+
+
+def convert_to_original_extent(pred: np.ndarray, pkl_path: Union[Path, str], dst_path: Union[Path, str]):
+    # convert to nnUNet's internal softmax format
+    pred = np.array([1-pred, pred])
+
+    # read physical properties of current case
+    with open(pkl_path, "rb") as fp:
+        properties = pickle.load(fp)
+
+    # let nnUNet resample to original physical space
+    save_softmax_nifti_from_softmax(
+        segmentation_softmax=pred,
+        out_fname=str(dst_path),
+        properties_dict=properties,
+    )
+
+
+def strip_metadata(img: sitk.Image) -> None:
+    for key in img.GetMetaDataKeys():
+        img.EraseMetaData(key)
+
+
+def overwrite_affine(fixed_img: sitk.Image, moving_img: sitk.Image) -> sitk.Image:
+    moving_img.SetOrigin(fixed_img.GetOrigin())
+    moving_img.SetDirection(fixed_img.GetDirection())
+    moving_img.SetSpacing(fixed_img.GetSpacing())
+    return moving_img
+
+
+class ProstateSegmentationAlgorithm(SegmentationAlgorithm):
+    """
+    Wrapper to deploy trained prostate segmentation nnU-Net model from
+    https://github.com/DIAGNijmegen/picai_baseline as a
+    grand-challenge.org algorithm.
+    """
+
+    def __init__(self):
+        super().__init__(
+            validators=dict(
+                input_image=(
+                    UniqueImagesValidator(),
+                    UniquePathIndicesValidator(),
+                )
+            ),
+        )
+
+        # input / output paths for algorithm
+        self.input_dirs = [
+            "/input/images/transverse-t2-prostate-mri"
+        ]
+        self.scan_paths = []
+        self.prostate_segmentation_path_pz = Path("/output/images/softmax-prostate-peripheral-zone-segmentation/prostate_gland_sm_pz.mha")
+        self.prostate_segmentation_path_tz = Path("/output/images/softmax-prostate-central-gland-segmentation/prostate_gland_sm_tz.mha")
+        self.prostate_segmentation_path = Path("/output/images/prostate-zonal-segmentation/prostate_gland.mha")
+
+        # input / output paths for nnUNet
+        self.nnunet_inp_dir = Path("/opt/algorithm/nnunet/input")
+        self.nnunet_out_dir = Path("/opt/algorithm/nnunet/output")
+        self.nnunet_results = Path("/opt/algorithm/results")
+
+        # ensure required folders exist
+        self.nnunet_inp_dir.mkdir(exist_ok=True, parents=True)
+        self.nnunet_out_dir.mkdir(exist_ok=True, parents=True)
+        self.prostate_segmentation_path_pz.parent.mkdir(exist_ok=True, parents=True)
+
+        # input validation for multiple inputs
+        scan_glob_format = "*.mha"
+        for folder in self.input_dirs:
+            file_paths = list(Path(folder).glob(scan_glob_format))
+            if len(file_paths) == 0:
+                raise MissingSequenceError(name=folder.split("/")[-1], folder=folder)
+            elif len(file_paths) >= 2:
+                raise MultipleScansSameSequencesError(name=folder.split("/")[-1], folder=folder)
+            else:
+                # append scan path to algorithm input paths
+                self.scan_paths += [file_paths[0]]
+
+    def preprocess_input(self):
+        """Preprocess input images to nnUNet Raw Data Archive format"""
+        # set up Sample
+        sample = Sample(
+            scans=[
+                sitk.ReadImage(str(path))
+                for path in [self.scan_paths[0]]
+            ],
+            settings=PreprocessingSettings(
+                physical_size=[81.0, 192.0, 192.0],
+                crop_only=True
+            )
+        )
+
+        # perform preprocessing
+        sample.preprocess()
+
+        # write preprocessed scans to nnUNet input directory
+        for i, scan in enumerate(sample.scans):
+            path = self.nnunet_inp_dir / f"scan_{i:04d}.nii.gz"
+            atomic_image_write(scan, path)
+
+    # Note: need to overwrite process because of flexible inputs, which requires custom data loading
+    def process(self):
+        """
+        Load bpMRI scans and segment the prostate glands
+        """
+        # perform preprocessing
+        self.preprocess_input()
+
+        # perform inference using nnUNet
+        self.predict(
+            task="Task848_experiment48",
+            trainer="nnUNetTrainerV2_MMS",
+            checkpoint="model_best",
+            folds="0"
+        )
+
+        pred_path_prostate = str(self.nnunet_out_dir / "scan.npz")
+        sm_arr = np.load(pred_path_prostate)['softmax']
+        pz_arr = np.array(sm_arr[1, :, :, :]).astype('float32')
+        tz_arr = np.array(sm_arr[2, :, :, :]).astype('float32')
+
+        # read postprocessed prediction
+        pred_path = str(self.nnunet_out_dir / "scan.nii.gz")
+        pred_postprocessed: sitk.Image = sitk.ReadImage(pred_path)
+
+        # remove metadata to get rid of SimpleITK warning
+        strip_metadata(pred_postprocessed)
+
+        # save postprocessed prediction to output
+        atomic_image_write(pred_postprocessed, self.prostate_segmentation_path, mkdir=True)
+
+        for pred, save_path in [
+            (pz_arr, self.prostate_segmentation_path_pz),
+            (tz_arr, self.prostate_segmentation_path_tz),
+        ]:
+            # the prediction is currently at the size and location of the nnU-Net preprocessed
+            # scan, so we need to convert it to the original extent before we continue
+            convert_to_original_extent(
+                pred=pred,
+                pkl_path=self.nnunet_out_dir / "scan.pkl",
+                dst_path=self.nnunet_out_dir / "softmax.nii.gz",
+            )
+
+            # now each voxel in softmax.nii.gz corresponds to the same voxel in the reference scan
+            pred = sitk.ReadImage(str(self.nnunet_out_dir / "softmax.nii.gz"))
+
+            # convert prediction to a SimpleITK image and infuse the physical metadata of the reference scan
+            reference_scan_original_path = str(self.scan_paths[0])
+            reference_scan = sitk.ReadImage(reference_scan_original_path)
+            pred = resample_to_reference_scan(pred, reference_scan_original=reference_scan)
+
+            # clip small values to 0 to save disk space
+            arr = sitk.GetArrayFromImage(pred)
+            arr[arr < 1e-3] = 0
+            pred_clipped = sitk.GetImageFromArray(arr)
+            pred_clipped.CopyInformation(pred)
+
+            # remove metadata to get rid of SimpleITK warning
+            strip_metadata(pred_clipped)
+
+            # save prediction to output folder
+            atomic_image_write(pred_clipped, save_path, mkdir=True)
+
+    def predict(self, task, trainer="nnUNetTrainerV2", network="3d_fullres",
+                checkpoint="model_final_checkpoint", folds="0,1,2,3,4", store_probability_maps=True,
+                disable_augmentation=False, disable_patch_overlap=False):
+        """
+        Use trained nnUNet network to generate segmentation masks
+        """
+
+        # Set environment variables
+        os.environ['RESULTS_FOLDER'] = str(self.nnunet_results)
+
+        # Run prediction script
+        cmd = [
+            'nnUNet_predict',
+            '-t', task,
+            '-i', str(self.nnunet_inp_dir),
+            '-o', str(self.nnunet_out_dir),
+            '-m', network,
+            '-tr', trainer,
+            '--num_threads_preprocessing', '2',
+            '--num_threads_nifti_save', '1'
+        ]
+
+        if folds:
+            cmd.append('-f')
+            cmd.extend(folds.split(','))
+
+        if checkpoint:
+            cmd.append('-chk')
+            cmd.append(checkpoint)
+
+        if store_probability_maps:
+            cmd.append('--save_npz')
+
+        if disable_augmentation:
+            cmd.append('--disable_tta')
+
+        if disable_patch_overlap:
+            cmd.extend(['--step_size', '1'])
+
+        subprocess.check_call(cmd)
+
+
+
+def predict(input_file):
+    print(input_file)
     return "cat"
 
 
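
For context, here is a minimal sketch (not part of this commit) of how the ProstateSegmentationAlgorithm added above could back the Gradio predict placeholder in main.py. The helper name run_prostate_segmentation, the staged file name "scan.mha", and the assumption that the /input and /output folders used by the algorithm are writable inside the Space are all hypothetical; depending on the Gradio version, gr.File() passes either a path string or a temp-file object.

import shutil
from pathlib import Path

import gradio as gr


def run_prostate_segmentation(input_file):  # hypothetical helper, not part of the commit
    # gr.File() may hand over a path string or a temp-file object with a .name attribute
    src = input_file if isinstance(input_file, (str, Path)) else input_file.name

    # stage the uploaded scan where the algorithm's input validation looks for it
    input_dir = Path("/input/images/transverse-t2-prostate-mri")
    input_dir.mkdir(parents=True, exist_ok=True)
    shutil.copy(src, input_dir / "scan.mha")  # "scan.mha" is an arbitrary file name

    # preprocess, run nnU-Net inference, and export the segmentations
    algorithm = ProstateSegmentationAlgorithm()
    algorithm.process()

    # process() writes the whole-gland mask to this path (see __init__ above)
    return str(algorithm.prostate_segmentation_path)


demo = gr.Interface(fn=run_prostate_segmentation, inputs=gr.File(), outputs="text")
demo.launch()

Keeping the grand-challenge folder layout unchanged means the same class could be reused as-is in both the algorithm container and the Space, at the cost of hard-coded absolute paths.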