import os
import gradio as gr
import torch
from monai import bundle
from monai.transforms import (
    Compose,
    LoadImaged,
    EnsureChannelFirstd,
    Orientationd,
    NormalizeIntensityd,
    Activationsd,
    AsDiscreted,
    ScaleIntensityd,
)
# Define the bundle name and path for downloading
BUNDLE_NAME = 'brats_mri_segmentation_v0.1.0'
BUNDLE_PATH = os.path.join(torch.hub.get_dir(), 'bundle', BUNDLE_NAME)
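# torch.hub.get_dir() resolves to the Torch cache (~/.cache/torch/hub by default),
# so the bundle is downloaded once and reused across runs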
# Title and description
title = '<h1 style="text-align: center;">Segment Brain Tumors with MONAI! 🧠</h1>'
description = """
## 🚀 To run
Upload a brain MRI image file, or try out one of the examples below!
If you want to see a different slice, update the slider.
More details on the model can be found [here!](https://huggingface.co/katielink/brats_mri_segmentation_v0.1.0)
## ⚠️ Disclaimer
This is an example, not to be used for diagnostic purposes.
"""
references = """ | |
## π References | |
1. Myronenko, Andriy. "3D MRI brain tumor segmentation using autoencoder regularization." International MICCAI Brainlesion Workshop. Springer, Cham, 2018. https://arxiv.org/abs/1810.11654. | |
2. Menze BH, et al. "The Multimodal Brain Tumor Image Segmentation Benchmark (BRATS)", IEEE Transactions on Medical Imaging 34(10), 1993-2024 (2015) DOI: 10.1109/TMI.2014.2377694 | |
3. Bakas S, et al. "Advancing The Cancer Genome Atlas glioma MRI collections with expert segmentation labels and radiomic features", Nature Scientific Data, 4:170117 (2017) DOI:10.1038/sdata.2017.117 | |
""" | |
examples = [
    ['examples/BRATS_485.nii.gz', 65],
    ['examples/BRATS_486.nii.gz', 80],
]
# Load the pretrained MONAI model from the Hugging Face Hub
model, _, _ = bundle.load(
    name=BUNDLE_NAME,
    source='huggingface_hub',
    repo='katielink/brats_mri_segmentation_v0.1.0',
    load_ts_module=True,
)
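# with load_ts_module=True, bundle.load returns (TorchScript model, metadata, extra files);
# only the model is needed here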
# Use GPU if available
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Load the parser from the MONAI bundle's inference config
parser = bundle.load_bundle_config(BUNDLE_PATH, 'inference.json')
# Compose the preprocessing transforms
preproc_transforms = Compose(
    [
        LoadImaged(keys=["image"]),
        EnsureChannelFirstd(keys="image"),
        Orientationd(keys=["image"], axcodes="RAS"),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
    ]
)
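# nonzero=True normalizes only nonzero voxels and channel_wise=True computes statistics
# per MR sequence, mirroring the bundle's own preprocessing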
# Get the inferer from the bundle's inference config
inferer = parser.get_parsed_content(
    'inferer',
    lazy=True, eval_expr=True, instantiate=True,
)
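# for this bundle the configured inferer is a sliding-window inferer, which tiles the
# volume into overlapping patches and stitches the predictions back together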
# Compose the postprocessing transforms
post_transforms = Compose(
    [
        Activationsd(keys='pred', sigmoid=True),
        AsDiscreted(keys='pred', threshold=0.5),
        ScaleIntensityd(keys='image', minv=0., maxv=1.),
    ]
)
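# the BraTS targets (tumor core, whole tumor, enhancing tumor) overlap, so this is a
# multi-label problem: per-channel sigmoid + 0.5 threshold rather than a softmax across classes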
# Define the predict function for the demo
def predict(input_file, z_axis, model=model, device=device):
    # Load and preprocess the data in MONAI's dictionary format
    data = {'image': [input_file.name]}
    data = preproc_transforms(data)
    # Run inference and post-process the predicted labels
    model.to(device)
    model.eval()
    with torch.no_grad():
        inputs = data['image'].to(device)
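        # the inferer expects a batched tensor, so prepend a batch dimension with [None, ...]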
        data['pred'] = inferer(inputs=inputs[None, ...], network=model)
    data = post_transforms(data)
    # Convert the tensors back to numpy arrays
    data['image'] = data['image'].numpy()
    data['pred'] = data['pred'].detach().cpu().numpy()
    # Magnetic resonance imaging sequences
    t1c = data['image'][0, :, :, z_axis]  # T1-weighted, post-contrast
    t1 = data['image'][1, :, :, z_axis]  # T1-weighted, pre-contrast
    t2 = data['image'][2, :, :, z_axis]  # T2-weighted
    flair = data['image'][3, :, :, z_axis]  # FLAIR
    # BraTS labels
    tc = data['pred'][0, 0, :, :, z_axis]  # Tumor core
    wt = data['pred'][0, 1, :, :, z_axis]  # Whole tumor
    et = data['pred'][0, 2, :, :, z_axis]  # Enhancing tumor
    return [t1c, t1, t2, flair], [tc, wt, et]
# Use Blocks to set up a more complex demo
with gr.Blocks() as demo:
    # Show the title and description
    gr.Markdown(title)
    gr.Markdown(description)
    with gr.Row():
        # Get the input file and slice slider as inputs
        input_file = gr.File(label='input file')
        z_axis = gr.Slider(0, 200, label='slice', value=50, step=1)
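        # note: BraTS volumes are 240x240x155, so slice indices above 154 will fall outside the volume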
    with gr.Row():
        # Show the button with a custom label
        button = gr.Button("Segment Tumor!")
    with gr.Row():
        with gr.Column():
            # Show the input image with the different MR sequences
            input_image = gr.Gallery(label='input MRI sequences (T1+, T1, T2, FLAIR)')
        with gr.Column():
            # Show the segmentation labels
            output_segmentation = gr.Gallery(label='output segmentations (TC, WT, ET)')
    # Run the prediction on button click
    button.click(
        predict,
        inputs=[input_file, z_axis],
        outputs=[input_image, output_segmentation],
    )
    # Provide some examples for the user to try out
    examples = gr.Examples(
        examples=examples,
        inputs=[input_file, z_axis],
        outputs=[input_image, output_segmentation],
        fn=predict,
        cache_examples=False,
    )
    # Show the references at the bottom of the demo
    gr.Markdown(references)
# Launch the demo
demo.launch()