File size: 5,256 Bytes
65a733b
14167ad
65a733b
 
 
 
 
 
 
 
 
 
 
 
14167ad
b5b43af
65a733b
 
 
b5b43af
4e75bca
c7af872
afcd79b
3b059ad
2aa8f6b
afcd79b
b5b43af
 
3b059ad
afcd79b
3b059ad
 
b5b43af
3b059ad
b5b43af
afcd79b
3b059ad
a2cbc95
 
 
c7af872
 
b5b43af
2aa8f6b
afcd79b
b5b43af
65a733b
b5b43af
65a733b
 
44fed4a
65a733b
 
 
 
b5b43af
65a733b
 
b5b43af
65a733b
 
b5b43af
65a733b
 
 
 
 
 
 
 
b5b43af
 
 
 
 
 
 
 
65a733b
 
 
 
 
 
 
 
73400f9
b5b43af
65a733b
b5b43af
65a733b
 
 
b5b43af
65a733b
 
 
 
 
 
 
b5b43af
 
 
65a733b
b5b43af
 
 
 
 
65a733b
b5b43af
 
 
 
65a733b
b5b43af
 
73400f9
b5b43af
 
 
 
 
 
 
5e59526
2aa8f6b
 
f80e05d
2aa8f6b
 
 
 
 
 
5e59526
 
 
2aa8f6b
 
 
 
 
4e75bca
b5b43af
 
 
 
 
 
73400f9
 
 
 
 
 
 
7d7f987
73400f9
b5b43af
 
 
14167ad
73400f9
b5b43af
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import os
import gradio as gr
import torch
from monai import bundle
from monai.transforms import (
    Compose,
    LoadImaged,
    EnsureChannelFirstd,
    Orientationd,
    NormalizeIntensityd,
    Activationsd,
    AsDiscreted,
    ScaleIntensityd,
)

# Name of the MONAI bundle served by this demo (the BraTS brain-tumour model from
# the 'katielink/brats_mri_segmentation_v0.1.0' repo). The same name is passed to
# bundle.load() and used to build BUNDLE_PATH, from which the inference config is
# read back, so the two must stay in sync. BUNDLE_PATH assumes MONAI's default
# bundle directory (<torch hub dir>/bundle/<name>) -- confirm if bundle_dir changes.
BUNDLE_NAME = 'brats_mri_segmentation_v0.1.0'
BUNDLE_PATH = os.path.join(torch.hub.get_dir(), 'bundle', BUNDLE_NAME)

# Title and description shown at the top of the demo page (rendered via gr.Markdown).
title = '<h1 style="text-align: center;">Segment Brain Tumors with MONAI! 🧠 </h1>'
# The slice slider is only read when the button is clicked, so the instructions
# tell the user to click again after moving it.
description = """
## 🚀 To run

Upload a brain MRI image file, or try out one of the examples below!
To see a different slice, move the slider and click **Segment Tumor!** again.

More details on the model can be found [here!](https://huggingface.co/katielink/brats_mri_segmentation_v0.1.0)

## ⚠️ Disclaimer

This is an example, not to be used for diagnostic purposes.
"""

# Citations rendered at the bottom of the page.
references = """
## 👀 References

1. Myronenko, Andriy. "3D MRI brain tumor segmentation using autoencoder regularization." International MICCAI Brainlesion Workshop. Springer, Cham, 2018. https://arxiv.org/abs/1810.11654.
2. Menze BH, et al. "The Multimodal Brain Tumor Image Segmentation Benchmark (BRATS)", IEEE Transactions on Medical Imaging 34(10), 1993-2024 (2015) DOI: 10.1109/TMI.2014.2377694
3. Bakas S, et al. "Advancing The Cancer Genome Atlas glioma MRI collections with expert segmentation labels and radiomic features", Nature Scientific Data, 4:170117 (2017) DOI:10.1038/sdata.2017.117
"""

# Example inputs for gr.Examples: [volume path, slice index].
examples = [
    ['examples/BRATS_485.nii.gz', 65],
    ['examples/BRATS_486.nii.gz', 80]
]

# Fetch the pretrained TorchScript network from the Hugging Face Hub. bundle.load
# returns three values here; only the first (the network) is kept.
model, _, _ = bundle.load(
    name=BUNDLE_NAME,
    source='huggingface_hub',
    repo='katielink/brats_mri_segmentation_v0.1.0',
    load_ts_module=True,
)

# Inference device: the first GPU when CUDA is present, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Parse the bundle's inference.json (the bundle files are fetched by the load above).
parser = bundle.load_bundle_config(BUNDLE_PATH, "inference.json")

# Preprocessing for an uploaded volume: load it, move channels first, reorient to
# RAS, then normalise intensities per channel using only non-zero voxels.
preproc_transforms = Compose([
    LoadImaged(keys=["image"]),
    EnsureChannelFirstd(keys="image"),
    Orientationd(keys=["image"], axcodes="RAS"),
    NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
])

# Inferer instantiated from the 'inferer' entry of the bundle's inference config.
inferer = parser.get_parsed_content("inferer", lazy=True, eval_expr=True, instantiate=True)

# Post-processing: sigmoid followed by a 0.5 threshold turns the network output into
# binary label maps; the image itself is rescaled into [0, 1].
post_transforms = Compose([
    Activationsd(keys="pred", sigmoid=True),
    AsDiscreted(keys="pred", threshold=0.5),
    ScaleIntensityd(keys="image", minv=0.0, maxv=1.0),
])


def _resolve_input_path(input_file):
    """Return the on-disk path of the volume supplied through the ``gr.File`` input.

    Accepts either a path (``str`` / ``os.PathLike``) or a file-like wrapper that
    exposes its path as ``.name``, so the callback works with either form.

    Raises:
        ValueError: if no file was provided.
    """
    if input_file is None:
        raise ValueError('No input file provided; please upload a brain MRI volume.')
    if isinstance(input_file, (str, os.PathLike)):
        return os.fspath(input_file)
    return input_file.name


def _clamp_slice_index(z_axis, depth):
    """Turn a slider value into a valid integer index along an axis of length ``depth``.

    The slider's range is fixed in the UI and independent of the uploaded volume, and
    its value may arrive as a float, so it is rounded and clamped to ``[0, depth - 1]``.
    """
    return min(max(int(round(float(z_axis))), 0), depth - 1)


# Define the predict function for the demo
def predict(input_file, z_axis, model=model, device=device):
    """Segment one MRI volume and return a single slice of inputs and predictions.

    Args:
        input_file: value from the ``gr.File`` component -- a path, or an object
            exposing the path as ``.name``.
        z_axis: slice index from the slider; rounded to an int and clamped to the
            size of the volume's last (slice) axis.
        model: network to run; defaults to the bundle model loaded at import time.
        device: torch device string used for inference.

    Returns:
        A pair of lists of 2-D numpy arrays taken at the chosen slice: the four
        input channels ``[t1c, t1, t2, flair]`` and the three predicted label
        maps ``[tc, wt, et]``.

    Raises:
        ValueError: if no input file was provided.
    """
    # Load and preprocess the volume in MONAI dictionary format.
    data = {'image': [_resolve_input_path(input_file)]}
    data = preproc_transforms(data)

    # Run inference and post-process the predicted labels.
    model.to(device)
    model.eval()
    with torch.no_grad():
        inputs = data['image'].to(device)
        # Add a leading batch axis; the prediction is indexed as [batch, channel, ...] below.
        data['pred'] = inferer(inputs=inputs[None, ...], network=model)
        data = post_transforms(data)

    # Convert tensors back to numpy arrays (the prediction may live on the GPU).
    image = data['image'].cpu().numpy()
    pred = data['pred'].cpu().detach().numpy()

    # The slider can exceed the volume depth, so clamp to the depth both arrays share.
    z = _clamp_slice_index(z_axis, min(image.shape[-1], pred.shape[-1]))

    # Magnetic resonance imaging sequences
    t1c = image[0, :, :, z]  # T1-weighted, post contrast
    t1 = image[1, :, :, z]  # T1-weighted, pre contrast
    t2 = image[2, :, :, z]  # T2-weighted
    flair = image[3, :, :, z]  # FLAIR

    # BraTS labels
    tc = pred[0, 0, :, :, z]  # Tumor core
    wt = pred[0, 1, :, :, z]  # Whole tumor
    et = pred[0, 2, :, :, z]  # Enhancing tumor

    return [t1c, t1, t2, flair], [tc, wt, et]


# Use Blocks to lay out the demo: inputs, a trigger button, two galleries, examples.
with gr.Blocks() as demo:

    # Show title and description
    gr.Markdown(title)
    gr.Markdown(description)

    with gr.Row():
        # Input volume and the slice index to display.
        input_file = gr.File(label='input file')
        # step=1 keeps the slider on whole slice indices; predict also clamps the
        # value to the actual volume depth.
        z_axis = gr.Slider(0, 200, step=1, label='slice', value=50)

    with gr.Row():
        # Button that triggers segmentation.
        button = gr.Button("Segment Tumor!")

    with gr.Row():
        with gr.Column():
            # One image per input MR sequence at the chosen slice.
            input_image = gr.Gallery(label='input MRI sequences (T1+, T1, T2, FLAIR)')

        with gr.Column():
            # One image per predicted label map at the chosen slice.
            output_segmentation = gr.Gallery(label='output segmentations (TC, WT, ET)')

    # Run prediction on button click
    button.click(
        predict,
        inputs=[input_file, z_axis],
        outputs=[input_image, output_segmentation],
    )

    # Example inputs for the user to try. The component gets its own name so the
    # module-level `examples` data list is not rebound to a gr.Examples object.
    example_inputs = gr.Examples(
        examples=examples,
        inputs=[input_file, z_axis],
        outputs=[input_image, output_segmentation],
        fn=predict,
        cache_examples=False,
    )

    # Show references at the bottom of the demo
    gr.Markdown(references)


# Launch the demo. This call runs at module level (there is no
# `if __name__ == "__main__":` guard), so importing this module also launches it.
demo.launch()