Commited model weights and demo code
Browse files- LICENSE.md +9 -0
- README.md +51 -0
- autoencoders_demo.ipynb +0 -0
- config.json +28 -0
- data_preprocessing_recipe.py +202 -0
- data_utils.py +24 -0
- diffusion_pytorch_model.safetensors +3 -0
- example_data/mri_complex_images.npz +3 -0
- inference.py +28 -0
- metrics.py +30 -0
LICENSE.md
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) Microsoft Corporation
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
6 |
+
|
7 |
+
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
8 |
+
|
9 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
README.md
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
license: mit
|
3 |
+
language:
|
4 |
+
- en
|
5 |
+
library_name: diffusers
|
6 |
+
tags:
|
7 |
+
- MRI
|
8 |
+
- medical-imaging
|
9 |
+
- VAE
|
10 |
+
- autoencoder
|
11 |
+
---
|
12 |
+
# MRI Autoencoder v0.1
|
13 |
+
|
14 |
+
## Model
|
15 |
+
MRI autoencoder is a Variational Autoencoder (VAE) trained on the fast MRI multi-coil brain and knee datasets. The model is trained from scratch and uses the same architecture as the Stable Diffusion SDXL VAE model.
|
16 |
+
|
17 |
+
Latent Diffusion Models (LDMs) have been extremely popular in synthesizing images and videos. However, they remain relatively under-explored in the field of medical imaging. One possible reason is the lack of domain specific autoencoders that can encode and decode higher dimensional medical imaging data to their lower dimensional latent representation. MRI images, for example, are different than general domain images in that they are complex valued with magnitude and phase information. To this end, we are publishing an autoencoder that can be used to encode and decode complex valued MRI images to and from their latent representation.
|
18 |
+
|
19 |
+
## Use
|
20 |
+
|
21 |
+
```
|
22 |
+
from diffusers.models import AutoencoderKL
|
23 |
+
autoencoder = AutoencoderKL.from_pretrained("microsoft/mri-autoencoder-v0.1")
|
24 |
+
```
|
25 |
+
|
26 |
+
For more details please refer to the provided autoencoders_demo notebook. For details on how the fastmri data was preprocessed, please refer to data_preprocessing_recipe.py.
|
27 |
+
|
28 |
+
## Intended Use
|
29 |
+
|
30 |
+
The model is intended to be used solely for future research in medical imaging. Stakeholders would benefit by treating this model as a building block towards exploring latent space generative models applied to complex valued MRI images.
|
31 |
+
|
32 |
+
## Out-of-Scope Use
|
33 |
+
|
34 |
+
Any deployed use case of the model, commercial or otherwise, is out of scope. The model weights and code are not intended for clinical use.
|
35 |
+
|
36 |
+
## Evaluation
|
37 |
+
|
38 |
+
The PSNR and SSIM scores on randomly chosen 8000 slices from the fastMRI multicoil validation dataset are as follows:
|
39 |
+
|
40 |
+
| Autoencoder | Median PSNR | Mean PSNR | PSNR 95% CI | Median SSIM | Mean SSIM | SSIM 95% CI |
|
41 |
+
| ----------- | ----------- | --------- | ----------- | ----------- | --------- | ----------- |
|
42 |
+
| MRI-AUTOENCODER-v0.1 | 34.31 | 33.98 | (28.55. 37.79) | 0.91 | 0.88 | (0.54, 0.97) |
|
43 |
+
| SDXL-VAE | 31.45 | 31.51 | (27.85, 35.63) | 0.89 | 0.86 | (0.58, 0.94) |
|
44 |
+
|
45 |
+
## Data
|
46 |
+
|
47 |
+
This model was trained, with permission, using the NYU fastMRI Dataset (https://fastmri.med.nyu.edu/), which is a deidentified imaging dataset provided by NYU Langone comprised of raw k-space data in several sub-dataset groups.
|
48 |
+
|
49 |
+
## Limitations
|
50 |
+
|
51 |
+
A model trained on this dataset might likely overfit and not generalize well to new data. This model has not been evaluated for clinical use or across a range of scanner types.
|
autoencoders_demo.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
config.json
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_class_name": "AutoencoderKL",
|
3 |
+
"_diffusers_version": "0.24.0",
|
4 |
+
"act_fn": "silu",
|
5 |
+
"block_out_channels": [
|
6 |
+
128,
|
7 |
+
256,
|
8 |
+
512
|
9 |
+
],
|
10 |
+
"down_block_types": [
|
11 |
+
"DownEncoderBlock2D",
|
12 |
+
"DownEncoderBlock2D",
|
13 |
+
"DownEncoderBlock2D"
|
14 |
+
],
|
15 |
+
"force_upcast": true,
|
16 |
+
"in_channels": 2,
|
17 |
+
"latent_channels": 4,
|
18 |
+
"layers_per_block": 2,
|
19 |
+
"norm_num_groups": 32,
|
20 |
+
"out_channels": 2,
|
21 |
+
"sample_size": 256,
|
22 |
+
"scaling_factor": 0.18215,
|
23 |
+
"up_block_types": [
|
24 |
+
"UpDecoderBlock2D",
|
25 |
+
"UpDecoderBlock2D",
|
26 |
+
"UpDecoderBlock2D"
|
27 |
+
]
|
28 |
+
}
|
data_preprocessing_recipe.py
ADDED
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
''' This file contains the recipe for data preprocessing used to generate the combined coil images from the fastmri multicoil brain and knee datasets.
|
2 |
+
These combined coil images were then used to train the autoencoder. The combined coil images are generated by combining the coil images using
|
3 |
+
the sensitivity maps calculated with bart. To run this recipe, the bart toolbox needs to be installed and then follow the steps outlined in
|
4 |
+
the preprocess_recipe function.'''
|
5 |
+
|
6 |
+
|
7 |
+
# bart toolbox installation instructions - https://mrirecon.github.io/bart/installation.html
|
8 |
+
_BART_TOOLBOX_PATH = ''
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
import h5py
|
12 |
+
from tqdm import tqdm
|
13 |
+
import sys, os
|
14 |
+
|
15 |
+
os.environ["TOOLBOX_PATH"] = _BART_TOOLBOX_PATH
|
16 |
+
sys.path.append(os.path.join(_BART_TOOLBOX_PATH, 'python'))
|
17 |
+
from bart import bart
|
18 |
+
os.environ["OMP_NUM_THREADS"] = "1"
|
19 |
+
|
20 |
+
def fftc(input, axes=None, norm='ortho'):
|
21 |
+
"""
|
22 |
+
Perform a Fast Fourier Transform on the input array.
|
23 |
+
|
24 |
+
Parameters:
|
25 |
+
input (numpy.ndarray): The input array to transform.
|
26 |
+
axes (tuple, optional): Axes over which to compute the FFT. If not specified, compute over all axes.
|
27 |
+
norm (str, optional): Normalization mode. Default is 'ortho' for orthonormal transform.
|
28 |
+
|
29 |
+
Returns:
|
30 |
+
numpy.ndarray: The transformed output array.
|
31 |
+
"""
|
32 |
+
tmp = np.fft.ifftshift(input, axes=axes)
|
33 |
+
tmp = np.fft.fftn(tmp, axes=axes, norm=norm)
|
34 |
+
output = np.fft.fftshift(tmp, axes=axes)
|
35 |
+
return output
|
36 |
+
|
37 |
+
def ifftc(input, axes=None, norm='ortho'):
|
38 |
+
"""
|
39 |
+
Perform an Inverse Fast Fourier Transform on the input array.
|
40 |
+
|
41 |
+
Parameters:
|
42 |
+
input (numpy.ndarray): The input array to transform.
|
43 |
+
axes (tuple, optional): Axes over which to compute the inverse FFT. If not specified, compute over all axes.
|
44 |
+
norm (str, optional): Normalization mode. Default is 'ortho' for orthonormal transform.
|
45 |
+
|
46 |
+
Returns:
|
47 |
+
numpy.ndarray: The transformed output array.
|
48 |
+
"""
|
49 |
+
tmp = np.fft.ifftshift(input, axes=axes)
|
50 |
+
tmp = np.fft.ifftn(tmp, axes=axes, norm=norm)
|
51 |
+
output = np.fft.fftshift(tmp, axes=axes)
|
52 |
+
return output
|
53 |
+
|
54 |
+
def adjoint(ksp, maps, mask):
|
55 |
+
"""
|
56 |
+
Perform the adjoint operation on k-space data with coil sensitivity maps and a mask.
|
57 |
+
|
58 |
+
Parameters:
|
59 |
+
ksp (numpy.ndarray): The input k-space data, shape: [1, C, H, W].
|
60 |
+
maps (numpy.ndarray): The coil sensitivity maps, shape: [1, C, H, W].
|
61 |
+
mask (numpy.ndarray): The mask to apply on the k-space data, shape: [1, 1, H, W].
|
62 |
+
|
63 |
+
Returns:
|
64 |
+
numpy.ndarray: The output image after applying the adjoint operation, shape: [1, 1, H, W].
|
65 |
+
"""
|
66 |
+
masked_ksp = ksp*mask
|
67 |
+
coil_imgs = ifftc(masked_ksp,axes=(-2,-1))
|
68 |
+
img_out = np.sum(coil_imgs*np.conj(maps),axis=1)[:,None,...]
|
69 |
+
return img_out
|
70 |
+
|
71 |
+
def _expand_shapes(*shapes):
|
72 |
+
"""
|
73 |
+
Expand the dimensions of the given shapes to match the maximum dimension.
|
74 |
+
|
75 |
+
This function prepends 1s to the shapes with fewer dimensions to match the maximum number of dimensions.
|
76 |
+
|
77 |
+
Parameters:
|
78 |
+
*shapes (tuple): A variable length tuple containing shapes (as lists or tuples of integers).
|
79 |
+
|
80 |
+
Returns:
|
81 |
+
tuple: A tuple of expanded shapes, where each shape is a list of integers.
|
82 |
+
"""
|
83 |
+
|
84 |
+
shapes = [list(shape) for shape in shapes]
|
85 |
+
max_ndim = max(len(shape) for shape in shapes)
|
86 |
+
shapes_exp = [[1] * (max_ndim - len(shape)) + shape
|
87 |
+
for shape in shapes]
|
88 |
+
|
89 |
+
return tuple(shapes_exp)
|
90 |
+
|
91 |
+
def resize(input, oshape, ishift=None, oshift=None):
|
92 |
+
"""
|
93 |
+
Resize with zero-padding or cropping.
|
94 |
+
|
95 |
+
Parameters:
|
96 |
+
input (array): Input array.
|
97 |
+
oshape (tuple of ints): Output shape.
|
98 |
+
ishift (None or tuple of ints): Input shift.
|
99 |
+
oshift (None or tuple of ints): Output shift.
|
100 |
+
|
101 |
+
Returns:
|
102 |
+
array: Zero-padded or cropped result.
|
103 |
+
"""
|
104 |
+
|
105 |
+
ishape1, oshape1 = _expand_shapes(input.shape, oshape)
|
106 |
+
|
107 |
+
if ishape1 == oshape1:
|
108 |
+
return input.reshape(oshape)
|
109 |
+
|
110 |
+
if ishift is None:
|
111 |
+
ishift = [max(i // 2 - o // 2, 0) for i, o in zip(ishape1, oshape1)]
|
112 |
+
|
113 |
+
if oshift is None:
|
114 |
+
oshift = [max(o // 2 - i // 2, 0) for i, o in zip(ishape1, oshape1)]
|
115 |
+
|
116 |
+
copy_shape = [min(i - si, o - so)
|
117 |
+
for i, si, o, so in zip(ishape1, ishift, oshape1, oshift)]
|
118 |
+
islice = tuple([slice(si, si + c) for si, c in zip(ishift, copy_shape)])
|
119 |
+
oslice = tuple([slice(so, so + c) for so, c in zip(oshift, copy_shape)])
|
120 |
+
|
121 |
+
output = np.zeros(oshape1, dtype=input.dtype)
|
122 |
+
input = input.reshape(ishape1)
|
123 |
+
output[oslice] = input[islice]
|
124 |
+
|
125 |
+
return output.reshape(oshape)
|
126 |
+
|
127 |
+
def shape_data(ksp, final_res):
|
128 |
+
"""
|
129 |
+
Reshape coil k-space data to output coil images with isotropic pixels and correct FOV = origional image width and the correct square image size given by "final_res".
|
130 |
+
|
131 |
+
This function assumes that the k-space data has already been padded to make the corresponding images have isotropic pixels.
|
132 |
+
|
133 |
+
Parameters:
|
134 |
+
ksp (numpy.ndarray): The input coil k-space data, shape: [S, C, H, W].
|
135 |
+
final_res (int): The final resolution for the output image.
|
136 |
+
|
137 |
+
Returns:
|
138 |
+
numpy.ndarray: The output image after reshaping, shape: [S, C, final_res, final_res].
|
139 |
+
"""
|
140 |
+
H = ksp.shape[-2]
|
141 |
+
W = ksp.shape[-1]
|
142 |
+
S = ksp.shape[0]
|
143 |
+
C = ksp.shape[1]
|
144 |
+
# bring the coil ksp into coil image space
|
145 |
+
img1 = ifftc(ksp,axes=(-2,-1))
|
146 |
+
img1_cropped = resize(img1, oshape=(S,C,W,W))
|
147 |
+
# FOV is now the same in both directions without modifying the resolution
|
148 |
+
ksp1 = fftc(img1_cropped,axes=(-2,-1))
|
149 |
+
# crop or pad the ksp isotropically in fourier space to the correct image size while mainting the same field of view (in width direction) in the original image
|
150 |
+
ksp1_cropped = resize(ksp1, oshape=(S,C,final_res,final_res))
|
151 |
+
img_out = ifftc(ksp1_cropped,axes=(-2,-1))
|
152 |
+
|
153 |
+
return img_out
|
154 |
+
|
155 |
+
def read_fastmri_data(file_path):
|
156 |
+
"""
|
157 |
+
This function reads k-space data from a .h5 file.
|
158 |
+
|
159 |
+
Parameters:
|
160 |
+
file_path (str): The path to the .h5 file containing FastMRI data.
|
161 |
+
|
162 |
+
Returns:
|
163 |
+
numpy.ndarray: The k-space data as a numpy array.
|
164 |
+
"""
|
165 |
+
hf = h5py.File(file_path, 'r')
|
166 |
+
ksp = np.asarray(hf['kspace'])
|
167 |
+
return ksp
|
168 |
+
|
169 |
+
def combine_coils(ksp):
|
170 |
+
"""
|
171 |
+
Combine multi-coil k-space data into a single coil image.
|
172 |
+
|
173 |
+
This function reshapes the raw multi-coil k-space data, calculates sensitivity maps for the reshaped data using the BART tool's 'ecalib' command, and then uses these maps to create a single coil image via a fully sampled adjoint operation.
|
174 |
+
|
175 |
+
Parameters:
|
176 |
+
ksp (numpy.ndarray): The input multi-coil k-space data, shape: [B, C, H, W].
|
177 |
+
|
178 |
+
Returns:
|
179 |
+
numpy.ndarray: The output single coil image, shape: [B, 1, H, W].
|
180 |
+
"""
|
181 |
+
# reshape raw multi-coil kspace to desired shape (ex [B,C,256,256])
|
182 |
+
coil_img_rs = shape_data(ksp, final_res=256)
|
183 |
+
coil_ksp_rs = fftc(coil_img_rs, axes=(-2,-1))
|
184 |
+
|
185 |
+
# calculate sensitivity maps for reshaped coil ksp
|
186 |
+
ksp_rs = coil_ksp_rs.transpose((2,3,0,1))
|
187 |
+
maps = np.array(ksp_rs)
|
188 |
+
#calculate Espirit maps with bart
|
189 |
+
for j in tqdm(range(ksp_rs.shape[2])):
|
190 |
+
sens = bart(1,'ecalib -m1 -W -c0', ksp_rs[:,:,j,None,:])#requires data of the form (Row,Column,None,Coil)<-output of ecalib too, this should then be saved (slice, coil, rows, columns)
|
191 |
+
maps[:,:,j,:] = sens[:,:,0,:]
|
192 |
+
|
193 |
+
maps_rs = maps.transpose((2,3,0,1))
|
194 |
+
# use new maps to create single coil image via fully sampled adjoint operation
|
195 |
+
single_coil_rs_img = adjoint(ksp=coil_ksp_rs, maps = maps_rs, mask = np.ones_like(coil_ksp_rs))
|
196 |
+
return single_coil_rs_img
|
197 |
+
|
198 |
+
def preprocess_data_recipe():
|
199 |
+
# for each file in the fastMRI dataset
|
200 |
+
# call read_fastmri_data to get the kspace data
|
201 |
+
# call combine_coils to create the combined coil image
|
202 |
+
pass
|
data_utils.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
def complex_to_two_channel_image(complex_img: np.ndarray) -> np.ndarray:
|
4 |
+
"""Converts a complex valued image to a 2 channel image (real and imaginary channels)"""
|
5 |
+
real, imag = np.real(complex_img), np.imag(complex_img)
|
6 |
+
return np.concatenate((real, imag), axis=0)
|
7 |
+
|
8 |
+
def two_channel_to_complex_image(two_ch_img: np.ndarray) -> np.ndarray:
|
9 |
+
"""Converts a 2 channel image (real and imaginary channels) to a complex valued image"""
|
10 |
+
two_ch_img = two_ch_img[0]
|
11 |
+
real = two_ch_img[0]
|
12 |
+
imag = two_ch_img[1]
|
13 |
+
complex_image = real + 1j*imag
|
14 |
+
return complex_image[None,...]
|
15 |
+
|
16 |
+
def normalize_complex_coil_image(complex_coil_img: np.ndarray) -> np.ndarray:
|
17 |
+
"""Scales the complex valued coil image """
|
18 |
+
max_val = np.percentile(np.abs(complex_coil_img), 99.5)
|
19 |
+
return complex_coil_img / max_val
|
20 |
+
|
21 |
+
def create_three_channel_image(complex_coil_img: np.ndarray) -> np.ndarray:
|
22 |
+
"""Converts a complex valued coil image to a 3 channel image (magnitude channels repated 3 times)"""
|
23 |
+
mag = np.abs(complex_coil_img)
|
24 |
+
return np.concatenate((mag, mag, mag), axis=0)
|
diffusion_pytorch_model.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4b07cfaad692d5e60669b5bfe0432de71eb11ad6925bf9e0bd333b69d15c5e62
|
3 |
+
size 221317280
|
example_data/mri_complex_images.npz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:69a6b217303b2147957ac2f1dbcee976b5ed509621acd94ca55110a1c8f02e5c
|
3 |
+
size 2022432
|
inference.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import data_utils as du
|
3 |
+
|
4 |
+
def run_inference_two_channels(coil_complex_image, autoencoder, device="cuda"):
|
5 |
+
coil_complex_image = du.normalize_complex_coil_image(coil_complex_image)
|
6 |
+
two_channel_image = du.complex_to_two_channel_image(coil_complex_image)
|
7 |
+
two_channel_tensor = torch.from_numpy(two_channel_image)[None,...].type(torch.FloatTensor).to(device)
|
8 |
+
autoencoder = autoencoder.to(device)
|
9 |
+
with torch.no_grad():
|
10 |
+
autoencoder_output = autoencoder.encode(two_channel_tensor)
|
11 |
+
latents = autoencoder_output.latent_dist.mean
|
12 |
+
decoded_image = autoencoder.decode(latents).sample
|
13 |
+
recon = du.two_channel_to_complex_image(decoded_image.detach().cpu().numpy())
|
14 |
+
input = coil_complex_image
|
15 |
+
return input, recon
|
16 |
+
|
17 |
+
def run_inference_three_channels(coil_complex_image, autoencoder, device="cuda"):
|
18 |
+
coil_complex_image = du.normalize_complex_coil_image(coil_complex_image)
|
19 |
+
three_channel_image = du.create_three_channel_image(coil_complex_image)
|
20 |
+
three_channel_tensor = torch.from_numpy(three_channel_image)[None,...].type(torch.FloatTensor).to(device)
|
21 |
+
autoencoder = autoencoder.to(device)
|
22 |
+
with torch.no_grad():
|
23 |
+
autoencoder_output = autoencoder.encode(three_channel_tensor)
|
24 |
+
latents = autoencoder_output.latent_dist.mean
|
25 |
+
decoded_image = autoencoder.decode(latents).sample
|
26 |
+
recon = decoded_image[0].detach().cpu().numpy()
|
27 |
+
input = three_channel_image
|
28 |
+
return input, recon
|
metrics.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import numpy as np
|
3 |
+
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
|
4 |
+
from typing import Optional
|
5 |
+
|
6 |
+
def ssim(
|
7 |
+
gt: np.ndarray, pred: np.ndarray, data_range: Optional[float] = None
|
8 |
+
) -> np.ndarray:
|
9 |
+
"""Compute Structural Similarity Index Metric (SSIM)"""
|
10 |
+
if not gt.ndim == 3:
|
11 |
+
raise ValueError("Unexpected number of dimensions in ground truth.")
|
12 |
+
if not gt.ndim == pred.ndim:
|
13 |
+
raise ValueError("Ground truth dimensions does not match pred.")
|
14 |
+
|
15 |
+
data_range = gt.max() if data_range is None else data_range
|
16 |
+
|
17 |
+
ssim = np.array([0])
|
18 |
+
for slice_num in range(gt.shape[0]):
|
19 |
+
ssim = ssim + structural_similarity(
|
20 |
+
gt[slice_num], pred[slice_num], data_range=data_range
|
21 |
+
)
|
22 |
+
|
23 |
+
return ssim / gt.shape[0]
|
24 |
+
|
25 |
+
def psnr(
|
26 |
+
gt: np.ndarray, pred: np.ndarray, data_range: Optional[float] = None
|
27 |
+
) -> np.ndarray:
|
28 |
+
"""Compute Peak Signal to Noise Ratio metric (PSNR)"""
|
29 |
+
data_range = gt.max() if data_range is None else data_range
|
30 |
+
return peak_signal_noise_ratio(gt, pred, data_range=data_range)
|