Spaces:
Runtime error
Runtime error
File size: 4,569 Bytes
e1e89c3 737e510 e1e89c3 737e510 243d214 737e510 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
import spaces
import tempfile
from pathlib import Path
import SimpleITK as sitk
import torch
from mrsegmentator import inference
from mrsegmentator.utils import add_postfix
import gradio as gr
import utils
# Markdown rendered under the page title: project links and a usage disclaimer.
description_markdown = """
- **GitHub:** https://github.com/hhaentze/mrsegmentator
- **Paper:** https://arxiv.org/abs/2405.06463
- **Please Note:** This tool is intended for research purposes only.
"""
# Custom stylesheet passed to gr.Blocks(css=...). It centers the title, styles
# the description box (elem_classes="markdown-block") and hides Gradio's footer.
css = """
h1 {
    text-align: center;
    display: block;
}
.markdown-block {
    background-color: #0b0f1a;  /* Near-black background */
    color: white;               /* White text */
    padding: 10px;              /* Padding around the text */
    border-radius: 5px;         /* Rounded corners */
    box-shadow: 0 0 10px rgba(11,15,26,1);
    display: inline-flex;       /* Use inline-flex to shrink to content size */
    flex-direction: column;
    justify-content: center;    /* Vertically center content (column axis) */
    align-items: center;        /* Horizontally center items within */
    margin: auto;               /* Center the block */
}
.markdown-block ul, .markdown-block ol {
    background-color: #1e2936;
    border-radius: 5px;
    padding: 10px;
    box-shadow: 0 0 10px rgba(0,0,0,0.3);
    padding-left: 20px;         /* Adjust padding for bullet alignment */
    text-align: left;           /* Ensure text within list is left-aligned */
    list-style-position: inside; /* Ensures bullets/numbers are inside the content flow */
}
footer {
    display: none !important;
}
"""
# Basenames of the bundled sample volumes (served from ``images/``). Inputs whose
# basename is listed here use the precomputed masks under ``segmentations/``
# instead of running inference.
examples = [
    "amos_0555.nii.gz",
    "amos_0517.nii.gz",
    "amos_0541.nii.gz",
    "amos_0571.nii.gz",
]
def save_file(segmentation, path):
    """Return a downloadable file path for *segmentation*.

    For the bundled sample images (basename listed in ``examples``), the
    precomputed mask under ``segmentations/`` is returned and nothing is
    written. For any other input, the segmentation is written next to the
    input file under the name produced by ``add_postfix(<basename>, "seg")``,
    so the uploaded image itself is not overwritten.

    Args:
        segmentation: SimpleITK image to save (only written for non-sample inputs).
        path: Path (str or os.PathLike) of the uploaded input image.

    Returns:
        str: Path of the segmentation file to offer for download.
    """
    src = Path(path)
    seg_name = add_postfix(src.name, "seg")
    if src.name in examples:
        # Build from the basename only; the full upload path may live in an
        # arbitrary temp directory and must not be glued onto "segmentations/".
        return str(Path("segmentations") / seg_name)
    # Write beside the input rather than over it, so re-running on the same
    # upload still segments the original image.
    seg_path = src.with_name(seg_name)
    sitk.WriteImage(segmentation, str(seg_path))
    return str(seg_path)
@spaces.GPU(duration=150)
def infer(image_path):
    """Run MRSegmentator on the volume at *image_path*.

    The ``spaces.GPU(duration=150)`` decorator requests a Spaces GPU for the
    call (duration presumably in seconds — see the ``spaces`` docs). The
    model's output is written into a temporary directory and read back into
    memory before that directory is removed.

    Args:
        image_path: Path to the input volume (str or os.PathLike).

    Returns:
        SimpleITK.Image: The predicted segmentation.
    """
    # GPU: pass [0, 1, 2, 3, 4]; CPU: only [0] with cpu_only=True.
    # NOTE(review): these look like model fold indices — confirm against
    # mrsegmentator.inference.infer.
    if torch.cuda.is_available():
        folds, cpu_only = [0, 1, 2, 3, 4], False
    else:
        folds, cpu_only = [0], True

    with tempfile.TemporaryDirectory() as tmpdirname:
        inference.infer([image_path], tmpdirname, folds, cpu_only=cpu_only, split_level=1)
        seg_file = Path(tmpdirname) / add_postfix(Path(image_path).name, "seg")
        # ReadImage returns an in-memory image, so it stays usable after the
        # temporary directory is cleaned up.
        return sitk.ReadImage(str(seg_file))
def infer_wrapper(input_file, image_state, seg_state, slider=50):
    """Gradio callback for the "Run" button.

    Loads the precomputed mask for bundled sample images or runs ``infer`` on
    uploads. It then saves the result for download, appends the mask (as
    converted by ``utils.sitk2numpy``) to ``seg_state``, and renders the
    overlay for the current slider position.

    Args:
        input_file: Value of ``gr.File(type="filepath")`` — a path string. A
            legacy tempfile-like object exposing ``.name`` is also accepted.
        image_state: List of displayed images; the last entry is used.
        seg_state: List of segmentations; mutated in place (appended to).
        slider: Relative slice position forwarded to ``utils.display``.

    Returns:
        tuple: (overlay from ``utils.display``, ``seg_state``, download path).

    Raises:
        gr.Error: If no image has been uploaded/loaded yet.
    """
    # Fail fast (and visibly in the UI) instead of crashing after running the
    # expensive inference on image_state[-1].
    if input_file is None or not image_state:
        raise gr.Error("Please upload an image before running the segmentation.")

    # type="filepath" yields a plain str; older Gradio file objects carry the
    # path in ``.name``. Accept both.
    file_path = input_file if isinstance(input_file, (str, Path)) else input_file.name
    filename = Path(file_path).name

    # inference
    if filename in examples:
        # Sample images ship with precomputed masks; skip running the model.
        segmentation = sitk.ReadImage(str(Path("segmentations") / add_postfix(filename, "seg")))
    else:
        segmentation = infer(file_path)

    # save file
    seg_path = save_file(segmentation, file_path)
    seg_state.append(utils.sitk2numpy(segmentation))
    return utils.display(image_state[-1], seg_state[-1], slider), seg_state, seg_path
# UI layout and event wiring. Left column: upload, examples, controls and
# download. Right column: the axial overlay view.
with gr.Blocks(css=css, title="MRSegmentator") as iface:
    gr.Markdown("# Robust Multi-Modality Segmentation of 40 Classes in MRI and CT Imaging")
    gr.Markdown(description_markdown, elem_classes="markdown-block")

    # Per-session lists. infer_wrapper appends to seg_state and reads the
    # latest entries of both via [-1].
    image_state = gr.State([])
    seg_state = gr.State([])

    with gr.Row():
        with gr.Column():
            input_file = gr.File(
                type="filepath",
                label="Upload an MRI Image (.nii/.nii.gz)",
                # ".nii" added so the accepted types match the label; ".gz"
                # kept for backward compatibility.
                file_types=[".nii", ".nii.gz", ".gz"],
            )
            gr.Examples(["images/" + ex for ex in examples], input_file)
            with gr.Row():
                submit_button = gr.Button("Run", variant="primary")
                clear_button = gr.ClearButton()
            slider = gr.Slider(1, 100, value=50, step=2, label="Select (relative) Slice")
            download_file = gr.File(label="Download Segmentation", interactive=False)
        with gr.Column():
            overlay_image_np = gr.AnnotatedImage(label="Axial View")

    # A new upload/example is loaded and displayed immediately.
    input_file.change(
        utils.read_and_display,
        inputs=[input_file, image_state, seg_state],
        outputs=[overlay_image_np, image_state, seg_state],
    )
    # Re-render the overlay when the slice position changes.
    slider.change(utils.display, inputs=[image_state, seg_state, slider], outputs=[overlay_image_np])
    submit_button.click(
        infer_wrapper,
        inputs=[input_file, image_state, seg_state, slider],
        outputs=[overlay_image_np, seg_state, download_file],
    )
    clear_button.add([input_file, overlay_image_np, image_state, seg_state, download_file])
if __name__ == "__main__":
    demo = iface
    # Turn on Gradio's request queue, then start serving the app.
    demo.queue()
    demo.launch()
|