DiGuaQiu commited on
Commit
737e510
1 Parent(s): 7c8a431

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +150 -0
app.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tempfile
2
+ from pathlib import Path
3
+
4
+ import SimpleITK as sitk
5
+ import torch
6
+ from mrsegmentator import inference
7
+ from mrsegmentator.utils import add_postfix
8
+
9
+ import gradio as gr
10
+ import utils
11
+
12
+
13
+ description_markdown = """
14
+ - **GitHub: https://github.com/hhaentze/mrsegmentator
15
+ - **Paper: https://arxiv.org/abs/2405.06463"
16
+ - **Please Note:** This tool is intended for research purposes only.
17
+ """
18
+
19
+
20
+ css = """
21
+
22
+ h1 {
23
+ text-align: center;
24
+ display:block;
25
+ }
26
+ .markdown-block {
27
+ background-color: #0b0f1a; /* Light gray background */
28
+ color: white; /* Black text */
29
+ padding: 10px; /* Padding around the text */
30
+ border-radius: 5px; /* Rounded corners */
31
+ box-shadow: 0 0 10px rgba(11,15,26,1);
32
+ display: inline-flex; /* Use inline-flex to shrink to content size */
33
+ flex-direction: column;
34
+ justify-content: center; /* Vertically center content */
35
+ align-items: center; /* Horizontally center items within */
36
+ margin: auto; /* Center the block */
37
+ }
38
+
39
+ .markdown-block ul, .markdown-block ol {
40
+ background-color: #1e2936;
41
+ border-radius: 5px;
42
+ padding: 10px;
43
+ box-shadow: 0 0 10px rgba(0,0,0,0.3);
44
+ padding-left: 20px; /* Adjust padding for bullet alignment */
45
+ text-align: left; /* Ensure text within list is left-aligned */
46
+ list-style-position: inside;/* Ensures bullets/numbers are inside the content flow */
47
+ }
48
+
49
+ footer {
50
+ display:none !important
51
+ }
52
+ """
53
+
54
+ examples = ["amos_0555.nii.gz","amos_0517.nii.gz", "amos_0541.nii.gz", "amos_0571.nii.gz"]
55
+
56
+
57
+ def save_file(segmentation, path):
58
+ """If the segmentation comes from our sample files directly return the path.
59
+ Otherwise save it to the temporary file that was previously allocated by the input image"""
60
+
61
+ if Path(path).name in examples:
62
+ path = "segmentations/" + add_postfix(path, "seg")
63
+ else:
64
+ sitk.WriteImage(segmentation, path)
65
+
66
+ return path
67
+
68
+
69
+ def infer(image_path):
70
+ with tempfile.TemporaryDirectory() as tmpdirname:
71
+
72
+ inference.infer(
73
+ [image_path], tmpdirname, [0, 1, 2, 3, 4], cpu_only=False if torch.cuda.is_available() else True
74
+ )
75
+ filename = add_postfix(Path(image_path).name, "seg")
76
+ segmentation = sitk.ReadImage(tmpdirname + "/" + filename)
77
+
78
+ return segmentation
79
+
80
+
81
+ def infer_wrapper(input_file, image_state, seg_state, slider=50):
82
+
83
+ filename = Path(input_file).name
84
+
85
+ # inference
86
+ if filename in examples:
87
+ segmentation = sitk.ReadImage("segmentations/" + add_postfix(filename, "seg"))
88
+ else:
89
+ segmentation = infer(input_file.name)
90
+
91
+ # save file
92
+ seg_path = save_file(segmentation, input_file.name)
93
+ seg_state.append(utils.sitk2numpy(segmentation))
94
+
95
+ return utils.display(image_state[-1], seg_state[-1], slider), seg_state, seg_path
96
+
97
+
98
+ with gr.Blocks(css=css, title="MRSegmentator") as iface:
99
+
100
+ gr.Markdown("# Robust Multi-Modality Segmentation of 40 Classes in MRI and CT Imaging")
101
+ gr.Markdown(description_markdown, elem_classes="markdown-block")
102
+
103
+ image_state = gr.State([])
104
+ seg_state = gr.State([])
105
+
106
+ with gr.Row():
107
+ with gr.Column():
108
+
109
+ input_file = gr.File(
110
+ type="filepath", label="Upload an MRI Image (.nii/.nii.gz)", file_types=[".gz", ".nii.gz"]
111
+ )
112
+ gr.Examples(["images/" + ex for ex in examples], input_file)
113
+
114
+ with gr.Row():
115
+ submit_button = gr.Button("Run", variant="primary")
116
+ clear_button = gr.ClearButton()
117
+
118
+ slider = gr.Slider(1, 100, value=50, step=2, label="Select (relative) Slice")
119
+ download_file = gr.File(label="Download Segmentation", interactive=False)
120
+
121
+ with gr.Column():
122
+ overlay_image_np = gr.AnnotatedImage(label="Axial View")
123
+
124
+ # pred_dict = gr.Label(label="Prediction")
125
+ # explanation= gr.Textbox(label="Classification Decision")
126
+
127
+ # with gr.Accordion("Additional Information", open=False):
128
+ # gradcam = gr.Image(label="GradCAM")
129
+ # cropped_boxed_array_disp = gr.Image(label="Bounding Box")
130
+
131
+ input_file.change(
132
+ utils.read_and_display,
133
+ inputs=[input_file, image_state, seg_state],
134
+ outputs=[overlay_image_np, image_state, seg_state],
135
+ )
136
+ slider.change(utils.display, inputs=[image_state, seg_state, slider], outputs=[overlay_image_np])
137
+
138
+ submit_button.click(
139
+ infer_wrapper,
140
+ inputs=[input_file, image_state, seg_state, slider],
141
+ outputs=[overlay_image_np, seg_state, download_file],
142
+ )
143
+
144
+ clear_button.add([input_file, overlay_image_np, image_state, seg_state, download_file])
145
+
146
+
147
+ if __name__ == "__main__":
148
+ iface.queue()
149
+ # iface.launch(server_name='0.0.0.0', server_port=8080)
150
+ iface.launch()