MRSegmentator / app.py
DiGuaQiu's picture
Decrease workload on CPUs
243d214 verified
raw
history blame contribute delete
No virus
4.57 kB
import spaces
import tempfile
from pathlib import Path
import SimpleITK as sitk
import torch
from mrsegmentator import inference
from mrsegmentator.utils import add_postfix
import gradio as gr
import utils
# Markdown shown under the page title inside the ``.markdown-block`` container.
# Each bold label is closed before its value so the list renders correctly.
description_markdown = """
- **GitHub:** https://github.com/hhaentze/mrsegmentator
- **Paper:** https://arxiv.org/abs/2405.06463
- **Please Note:** This tool is intended for research purposes only.
"""
# Page styling: centered title, a dark card around the description
# (``.markdown-block``), and a hidden page footer.
css = """
h1 {
    text-align: center;
    display: block;
}
.markdown-block {
    background-color: #0b0f1a; /* Dark navy background */
    color: white; /* White text */
    padding: 10px; /* Padding around the text */
    border-radius: 5px; /* Rounded corners */
    box-shadow: 0 0 10px rgba(11,15,26,1); /* Glow in the same color as the background */
    display: inline-flex; /* Use inline-flex to shrink to content size */
    flex-direction: column;
    justify-content: center; /* Vertically center content (main axis is vertical) */
    align-items: center; /* Horizontally center items within */
    margin: auto; /* Center the block */
}
.markdown-block ul, .markdown-block ol {
    background-color: #1e2936;
    border-radius: 5px;
    padding: 10px;
    box-shadow: 0 0 10px rgba(0,0,0,0.3);
    padding-left: 20px; /* Overrides the left padding above for bullet alignment */
    text-align: left; /* Ensure text within list is left-aligned */
    list-style-position: inside; /* Keep bullets/numbers inside the content flow */
}
footer {
    display: none !important;
}
"""
# File names of the bundled sample volumes. The inputs are offered from
# ``images/`` and their precomputed masks are read from ``segmentations/``.
examples = [
    "amos_0555.nii.gz",
    "amos_0517.nii.gz",
    "amos_0541.nii.gz",
    "amos_0571.nii.gz",
]
def save_file(segmentation, path):
    """Return a downloadable path for ``segmentation``.

    If the input's file name is one of the bundled ``examples``, the
    precomputed mask under ``segmentations/`` is returned and nothing is
    written. Otherwise the mask is written into the same directory as the
    input, named ``add_postfix(<input name>, "seg")``, so the uploaded image
    itself is not overwritten.

    Args:
        segmentation: SimpleITK image holding the predicted mask.
        path: Path of the input image the segmentation was computed from.

    Returns:
        str path of the segmentation file.
    """
    input_path = Path(path)
    seg_name = add_postfix(input_path.name, "seg")
    if input_path.name in examples:
        # Use only the file name, matching the lookup in infer_wrapper; the
        # full path may point into a temporary/served directory.
        return "segmentations/" + seg_name
    # Write beside the input instead of over it, so re-running on the same
    # upload still reads the original image.
    seg_path = str(input_path.with_name(seg_name))
    sitk.WriteImage(segmentation, seg_path)
    return seg_path
@spaces.GPU(duration=150)
def infer(image_path):
    """Segment one image with MRSegmentator and return the mask as a SimpleITK image.

    With CUDA available, ``[0, 1, 2, 3, 4]`` is passed to ``inference.infer``
    with ``cpu_only=False``. Otherwise only ``[0]`` is passed with
    ``cpu_only=True`` to reduce CPU workload. The list is presumably the
    model folds; confirm against mrsegmentator.

    The prediction is written to a temporary directory and read back before
    that directory is removed.
    """
    use_gpu = torch.cuda.is_available()
    folds = [0, 1, 2, 3, 4] if use_gpu else [0]
    with tempfile.TemporaryDirectory() as out_dir:
        inference.infer([image_path], out_dir, folds, cpu_only=not use_gpu, split_level=1)
        seg_file = Path(out_dir) / add_postfix(Path(image_path).name, "seg")
        # Must be read while the temporary directory still exists.
        segmentation = sitk.ReadImage(str(seg_file))
    return segmentation
def infer_wrapper(input_file, image_state, seg_state, slider=50):
    """Handle the "Run" button: segment the selected image and refresh the viewer.

    Bundled examples use their precomputed masks from ``segmentations/``.
    Any other file is segmented with ``infer``.

    Args:
        input_file: Value of the ``gr.File`` component. Its ``.name``
            attribute is used as the path when present; otherwise the value
            itself is treated as the path.
        image_state: Session list of loaded images; the last entry is displayed.
        seg_state: Session list of mask arrays; the new mask is appended.
        slider: Relative slice position forwarded to ``utils.display``.

    Returns:
        Tuple of (overlay from ``utils.display``, updated ``seg_state``,
        path of the segmentation file for the download component).

    Raises:
        gr.Error: If no image has been uploaded/loaded yet.
    """
    # Fail fast with a readable message instead of a TypeError/IndexError,
    # and before spending time on inference.
    if input_file is None or not image_state:
        raise gr.Error("Please upload an image before running the segmentation.")

    image_path = getattr(input_file, "name", input_file)
    filename = Path(image_path).name

    if filename in examples:
        # Bundled examples ship with precomputed masks; skip inference.
        segmentation = sitk.ReadImage("segmentations/" + add_postfix(filename, "seg"))
    else:
        segmentation = infer(image_path)

    seg_path = save_file(segmentation, image_path)
    seg_state.append(utils.sitk2numpy(segmentation))
    return utils.display(image_state[-1], seg_state[-1], slider), seg_state, seg_path
with gr.Blocks(css=css, title="MRSegmentator") as iface:
    gr.Markdown("# Robust Multi-Modality Segmentation of 40 Classes in MRI and CT Imaging")
    gr.Markdown(description_markdown, elem_classes="markdown-block")

    # Per-session lists. infer_wrapper appends new masks to seg_state and
    # reads the last entry of each list.
    image_state = gr.State([])
    seg_state = gr.State([])

    with gr.Row():
        with gr.Column():
            # ".nii" is listed so uncompressed NIfTI files (advertised in the
            # label) are accepted; ".gz" also admits ".nii.gz" uploads.
            input_file = gr.File(
                type="filepath",
                label="Upload an MRI Image (.nii/.nii.gz)",
                file_types=[".nii", ".nii.gz", ".gz"],
            )
            gr.Examples(["images/" + ex for ex in examples], input_file)
            with gr.Row():
                submit_button = gr.Button("Run", variant="primary")
                clear_button = gr.ClearButton()
            slider = gr.Slider(1, 100, value=50, step=2, label="Select (relative) Slice")
            download_file = gr.File(label="Download Segmentation", interactive=False)
        with gr.Column():
            overlay_image_np = gr.AnnotatedImage(label="Axial View")

    # Loading a new file resets/refreshes the viewer and session state.
    input_file.change(
        utils.read_and_display,
        inputs=[input_file, image_state, seg_state],
        outputs=[overlay_image_np, image_state, seg_state],
    )
    slider.change(utils.display, inputs=[image_state, seg_state, slider], outputs=[overlay_image_np])
    submit_button.click(
        infer_wrapper,
        inputs=[input_file, image_state, seg_state, slider],
        outputs=[overlay_image_np, seg_state, download_file],
    )
    clear_button.add([input_file, overlay_image_np, image_state, seg_state, download_file])
if __name__ == "__main__":
    # Enable the request queue on the app, then start the web server.
    iface.queue().launch()