Spaces:
Runtime error
Runtime error
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tempfile
|
2 |
+
from pathlib import Path
|
3 |
+
|
4 |
+
import SimpleITK as sitk
|
5 |
+
import torch
|
6 |
+
from mrsegmentator import inference
|
7 |
+
from mrsegmentator.utils import add_postfix
|
8 |
+
|
9 |
+
import gradio as gr
|
10 |
+
import utils
|
11 |
+
|
12 |
+
|
13 |
+
description_markdown = """
|
14 |
+
- **GitHub: https://github.com/hhaentze/mrsegmentator
|
15 |
+
- **Paper: https://arxiv.org/abs/2405.06463"
|
16 |
+
- **Please Note:** This tool is intended for research purposes only.
|
17 |
+
"""
|
18 |
+
|
19 |
+
|
20 |
+
css = """
|
21 |
+
|
22 |
+
h1 {
|
23 |
+
text-align: center;
|
24 |
+
display:block;
|
25 |
+
}
|
26 |
+
.markdown-block {
|
27 |
+
background-color: #0b0f1a; /* Light gray background */
|
28 |
+
color: white; /* Black text */
|
29 |
+
padding: 10px; /* Padding around the text */
|
30 |
+
border-radius: 5px; /* Rounded corners */
|
31 |
+
box-shadow: 0 0 10px rgba(11,15,26,1);
|
32 |
+
display: inline-flex; /* Use inline-flex to shrink to content size */
|
33 |
+
flex-direction: column;
|
34 |
+
justify-content: center; /* Vertically center content */
|
35 |
+
align-items: center; /* Horizontally center items within */
|
36 |
+
margin: auto; /* Center the block */
|
37 |
+
}
|
38 |
+
|
39 |
+
.markdown-block ul, .markdown-block ol {
|
40 |
+
background-color: #1e2936;
|
41 |
+
border-radius: 5px;
|
42 |
+
padding: 10px;
|
43 |
+
box-shadow: 0 0 10px rgba(0,0,0,0.3);
|
44 |
+
padding-left: 20px; /* Adjust padding for bullet alignment */
|
45 |
+
text-align: left; /* Ensure text within list is left-aligned */
|
46 |
+
list-style-position: inside;/* Ensures bullets/numbers are inside the content flow */
|
47 |
+
}
|
48 |
+
|
49 |
+
footer {
|
50 |
+
display:none !important
|
51 |
+
}
|
52 |
+
"""
|
53 |
+
|
54 |
+
examples = ["amos_0555.nii.gz","amos_0517.nii.gz", "amos_0541.nii.gz", "amos_0571.nii.gz"]
|
55 |
+
|
56 |
+
|
57 |
+
def save_file(segmentation, path):
|
58 |
+
"""If the segmentation comes from our sample files directly return the path.
|
59 |
+
Otherwise save it to the temporary file that was previously allocated by the input image"""
|
60 |
+
|
61 |
+
if Path(path).name in examples:
|
62 |
+
path = "segmentations/" + add_postfix(path, "seg")
|
63 |
+
else:
|
64 |
+
sitk.WriteImage(segmentation, path)
|
65 |
+
|
66 |
+
return path
|
67 |
+
|
68 |
+
|
69 |
+
def infer(image_path):
|
70 |
+
with tempfile.TemporaryDirectory() as tmpdirname:
|
71 |
+
|
72 |
+
inference.infer(
|
73 |
+
[image_path], tmpdirname, [0, 1, 2, 3, 4], cpu_only=False if torch.cuda.is_available() else True
|
74 |
+
)
|
75 |
+
filename = add_postfix(Path(image_path).name, "seg")
|
76 |
+
segmentation = sitk.ReadImage(tmpdirname + "/" + filename)
|
77 |
+
|
78 |
+
return segmentation
|
79 |
+
|
80 |
+
|
81 |
+
def infer_wrapper(input_file, image_state, seg_state, slider=50):
|
82 |
+
|
83 |
+
filename = Path(input_file).name
|
84 |
+
|
85 |
+
# inference
|
86 |
+
if filename in examples:
|
87 |
+
segmentation = sitk.ReadImage("segmentations/" + add_postfix(filename, "seg"))
|
88 |
+
else:
|
89 |
+
segmentation = infer(input_file.name)
|
90 |
+
|
91 |
+
# save file
|
92 |
+
seg_path = save_file(segmentation, input_file.name)
|
93 |
+
seg_state.append(utils.sitk2numpy(segmentation))
|
94 |
+
|
95 |
+
return utils.display(image_state[-1], seg_state[-1], slider), seg_state, seg_path
|
96 |
+
|
97 |
+
|
98 |
+
with gr.Blocks(css=css, title="MRSegmentator") as iface:
|
99 |
+
|
100 |
+
gr.Markdown("# Robust Multi-Modality Segmentation of 40 Classes in MRI and CT Imaging")
|
101 |
+
gr.Markdown(description_markdown, elem_classes="markdown-block")
|
102 |
+
|
103 |
+
image_state = gr.State([])
|
104 |
+
seg_state = gr.State([])
|
105 |
+
|
106 |
+
with gr.Row():
|
107 |
+
with gr.Column():
|
108 |
+
|
109 |
+
input_file = gr.File(
|
110 |
+
type="filepath", label="Upload an MRI Image (.nii/.nii.gz)", file_types=[".gz", ".nii.gz"]
|
111 |
+
)
|
112 |
+
gr.Examples(["images/" + ex for ex in examples], input_file)
|
113 |
+
|
114 |
+
with gr.Row():
|
115 |
+
submit_button = gr.Button("Run", variant="primary")
|
116 |
+
clear_button = gr.ClearButton()
|
117 |
+
|
118 |
+
slider = gr.Slider(1, 100, value=50, step=2, label="Select (relative) Slice")
|
119 |
+
download_file = gr.File(label="Download Segmentation", interactive=False)
|
120 |
+
|
121 |
+
with gr.Column():
|
122 |
+
overlay_image_np = gr.AnnotatedImage(label="Axial View")
|
123 |
+
|
124 |
+
# pred_dict = gr.Label(label="Prediction")
|
125 |
+
# explanation= gr.Textbox(label="Classification Decision")
|
126 |
+
|
127 |
+
# with gr.Accordion("Additional Information", open=False):
|
128 |
+
# gradcam = gr.Image(label="GradCAM")
|
129 |
+
# cropped_boxed_array_disp = gr.Image(label="Bounding Box")
|
130 |
+
|
131 |
+
input_file.change(
|
132 |
+
utils.read_and_display,
|
133 |
+
inputs=[input_file, image_state, seg_state],
|
134 |
+
outputs=[overlay_image_np, image_state, seg_state],
|
135 |
+
)
|
136 |
+
slider.change(utils.display, inputs=[image_state, seg_state, slider], outputs=[overlay_image_np])
|
137 |
+
|
138 |
+
submit_button.click(
|
139 |
+
infer_wrapper,
|
140 |
+
inputs=[input_file, image_state, seg_state, slider],
|
141 |
+
outputs=[overlay_image_np, seg_state, download_file],
|
142 |
+
)
|
143 |
+
|
144 |
+
clear_button.add([input_file, overlay_image_np, image_state, seg_state, download_file])
|
145 |
+
|
146 |
+
|
147 |
+
if __name__ == "__main__":
|
148 |
+
iface.queue()
|
149 |
+
# iface.launch(server_name='0.0.0.0', server_port=8080)
|
150 |
+
iface.launch()
|