Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -3,27 +3,15 @@
|
|
3 |
import tempfile
|
4 |
from pathlib import Path
|
5 |
|
6 |
-
import nibabel as nib
|
7 |
-
import numpy as np
|
8 |
-
from PIL import ImageDraw
|
9 |
-
from streamlit_drawable_canvas import st_canvas
|
10 |
-
from streamlit_image_coordinates import streamlit_image_coordinates
|
11 |
-
import nibabel as nib
|
12 |
import SimpleITK as sitk
|
|
|
|
|
13 |
import streamlit as st
|
14 |
import utils
|
15 |
-
from utils import (
|
16 |
-
initial_rectangle,
|
17 |
-
make_fig,
|
18 |
-
reflect_box_into_model,
|
19 |
-
reflect_json_data_to_3D_box,
|
20 |
-
run,
|
21 |
-
)
|
22 |
-
|
23 |
-
# from viewer import BasicViewer
|
24 |
|
25 |
print("script run")
|
26 |
st.title("MRSegmentator")
|
|
|
27 |
|
28 |
#############################################
|
29 |
# init session_state
|
@@ -50,8 +38,9 @@ if "transparency" not in st.session_state:
|
|
50 |
st.session_state.transparency = 0.25
|
51 |
|
52 |
case_list = [
|
53 |
-
"
|
54 |
-
"
|
|
|
55 |
]
|
56 |
|
57 |
#############################################
|
@@ -66,8 +55,11 @@ def clear_prompts():
|
|
66 |
def reset_demo_case():
|
67 |
st.session_state.data_item = None
|
68 |
st.session_state.reset_demo_case = True
|
|
|
|
|
69 |
clear_prompts()
|
70 |
|
|
|
71 |
def clear_file():
|
72 |
st.session_state.option = None
|
73 |
reset_demo_case()
|
@@ -85,26 +77,33 @@ with arxive_col:
|
|
85 |
st.write("Paper: https://arxiv.org/abs/2405.06463")
|
86 |
|
87 |
# modify demo case here
|
88 |
-
demo_type = st.radio("Demo case source", ["Select", "Upload"], on_change=clear_file)
|
89 |
|
90 |
with tempfile.TemporaryDirectory() as tmpdirname:
|
91 |
|
92 |
# modify demo case here
|
93 |
-
if demo_type == "Select":
|
94 |
-
|
95 |
"Select a demo case",
|
96 |
case_list,
|
97 |
index=None,
|
98 |
placeholder="Select a demo case...",
|
99 |
on_change=reset_demo_case,
|
100 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
101 |
else:
|
102 |
-
uploaded_file = st.file_uploader(
|
103 |
-
"Upload demo case(nii.gz)", type="nii.gz", on_change=reset_demo_case
|
104 |
-
)
|
105 |
|
106 |
-
if
|
107 |
-
with open(tmpdirname + "/" + uploaded_file.name,
|
108 |
f.write(uploaded_file.getvalue())
|
109 |
uploaded_file = tmpdirname + "/" + uploaded_file.name
|
110 |
|
@@ -117,39 +116,46 @@ with tempfile.TemporaryDirectory() as tmpdirname:
|
|
117 |
):
|
118 |
|
119 |
st.session_state.data_item = utils.read_image(Path(__file__).parent / str(uploaded_file))
|
120 |
-
st.session_state.
|
|
|
121 |
st.session_state.reset_demo_case = False
|
122 |
-
st.session_state.preds_3D = None
|
123 |
-
st.session_state.preds_path = None
|
124 |
-
|
125 |
|
126 |
if st.session_state.option is None:
|
127 |
st.write("please select demo case first")
|
128 |
else:
|
129 |
image_3D = st.session_state.data_item
|
130 |
-
px_range = st.slider(
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
|
|
135 |
col_control1, col_control2 = st.columns(2)
|
136 |
|
137 |
with col_control1:
|
138 |
selected_index_z = st.slider(
|
139 |
-
"Axial view",
|
|
|
|
|
|
|
|
|
|
|
140 |
)
|
141 |
|
142 |
with col_control2:
|
143 |
selected_index_y = st.slider(
|
144 |
-
"Coronal view",
|
|
|
|
|
|
|
|
|
|
|
145 |
)
|
146 |
|
147 |
col_image1, col_image2 = st.columns(2)
|
148 |
|
149 |
if st.session_state.preds_3D is not None:
|
150 |
-
st.session_state.transparency = st.slider(
|
151 |
-
"Mask opacity", 0.0, 1.0, 0.5, disabled=st.session_state.running
|
152 |
-
)
|
153 |
|
154 |
with col_image1:
|
155 |
|
@@ -159,7 +165,7 @@ with tempfile.TemporaryDirectory() as tmpdirname:
|
|
159 |
if st.session_state.preds_3D is not None:
|
160 |
preds_z_array = st.session_state.preds_3D[selected_index_z]
|
161 |
|
162 |
-
image_z = make_fig(image_z_array, preds_z_array, px_range, st.session_state.transparency)
|
163 |
st.image(image_z, use_column_width=False)
|
164 |
|
165 |
with col_image2:
|
@@ -169,7 +175,7 @@ with tempfile.TemporaryDirectory() as tmpdirname:
|
|
169 |
if st.session_state.preds_3D is not None:
|
170 |
preds_y_array = st.session_state.preds_3D[:, selected_index_y, :]
|
171 |
|
172 |
-
image_y = make_fig(image_y_array, preds_y_array, px_range, st.session_state.transparency)
|
173 |
st.image(image_y, use_column_width=False)
|
174 |
|
175 |
######################################################
|
@@ -177,6 +183,9 @@ with tempfile.TemporaryDirectory() as tmpdirname:
|
|
177 |
col1, col2, col3 = st.columns(3)
|
178 |
|
179 |
with col1:
|
|
|
|
|
|
|
180 |
if st.button(
|
181 |
"Clear",
|
182 |
use_container_width=True,
|
@@ -188,19 +197,21 @@ with tempfile.TemporaryDirectory() as tmpdirname:
|
|
188 |
st.rerun()
|
189 |
|
190 |
with col2:
|
|
|
|
|
|
|
191 |
|
192 |
if st.session_state.preds_3D is not None and st.session_state.data_item is not None:
|
193 |
|
194 |
with tempfile.NamedTemporaryFile(suffix=".nii.gz") as tmpfile:
|
195 |
|
196 |
preds = st.session_state.preds_3D_ori
|
197 |
-
#result_image.CopyInformation(inputImage)
|
198 |
sitk.WriteImage(preds, tmpfile.name)
|
199 |
-
|
200 |
with open(tmpfile.name, "rb") as f:
|
201 |
bytes_data = f.read()
|
202 |
st.download_button(
|
203 |
-
label="Download result(.nii.gz)",
|
204 |
data=bytes_data,
|
205 |
file_name="segmentation.nii.gz",
|
206 |
mime="application/octet-stream",
|
@@ -208,12 +219,25 @@ with tempfile.TemporaryDirectory() as tmpdirname:
|
|
208 |
)
|
209 |
|
210 |
with col3:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
211 |
run_button_name = "Run" if not st.session_state.running else "Running"
|
212 |
if st.button(
|
213 |
run_button_name,
|
214 |
type="primary",
|
215 |
use_container_width=True,
|
216 |
-
disabled=
|
|
|
217 |
):
|
218 |
st.session_state.running = True
|
219 |
st.rerun()
|
@@ -221,5 +245,5 @@ with tempfile.TemporaryDirectory() as tmpdirname:
|
|
221 |
if st.session_state.running:
|
222 |
st.session_state.running = False
|
223 |
with st.status("Running...", expanded=False) as status:
|
224 |
-
run(tmpdirname)
|
225 |
st.rerun()
|
|
|
3 |
import tempfile
|
4 |
from pathlib import Path
|
5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
import SimpleITK as sitk
|
7 |
+
from mrsegmentator.utils import add_postfix
|
8 |
+
|
9 |
import streamlit as st
|
10 |
import utils
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
print("script run")
|
13 |
st.title("MRSegmentator")
|
14 |
+
st.write("(On-site segmentation is currently disabled, because we lack access to GPUs)")
|
15 |
|
16 |
#############################################
|
17 |
# init session_state
|
|
|
38 |
st.session_state.transparency = 0.25
|
39 |
|
40 |
case_list = [
|
41 |
+
"amos_0517_MRI.nii.gz",
|
42 |
+
"amos_0541_MRI.nii.gz",
|
43 |
+
"amos_0571_MRI.nii.gz",
|
44 |
]
|
45 |
|
46 |
#############################################
|
|
|
55 |
def reset_demo_case():
|
56 |
st.session_state.data_item = None
|
57 |
st.session_state.reset_demo_case = True
|
58 |
+
st.session_state.preds_3D = None
|
59 |
+
st.session_state.preds_3D_ori = None
|
60 |
clear_prompts()
|
61 |
|
62 |
+
|
63 |
def clear_file():
|
64 |
st.session_state.option = None
|
65 |
reset_demo_case()
|
|
|
77 |
st.write("Paper: https://arxiv.org/abs/2405.06463")
|
78 |
|
79 |
# modify demo case here
|
80 |
+
demo_type = st.radio("Demo case source", ["Select (presegmented)", "Upload"], on_change=clear_file)
|
81 |
|
82 |
with tempfile.TemporaryDirectory() as tmpdirname:
|
83 |
|
84 |
# modify demo case here
|
85 |
+
if demo_type == "Select (presegmented)":
|
86 |
+
selection = st.selectbox(
|
87 |
"Select a demo case",
|
88 |
case_list,
|
89 |
index=None,
|
90 |
placeholder="Select a demo case...",
|
91 |
on_change=reset_demo_case,
|
92 |
)
|
93 |
+
|
94 |
+
if selection:
|
95 |
+
uploaded_file = "images/" + selection
|
96 |
+
seg_path = Path(__file__).parent / ("segmentations/" + add_postfix(selection, "seg"))
|
97 |
+
st.session_state.preds_3D = utils.read_image(seg_path)
|
98 |
+
st.session_state.preds_3D_ori = sitk.ReadImage(seg_path)
|
99 |
+
else:
|
100 |
+
uploaded_file = None
|
101 |
+
|
102 |
else:
|
103 |
+
uploaded_file = st.file_uploader("Upload demo case(nii.gz)", type="nii.gz", on_change=reset_demo_case)
|
|
|
|
|
104 |
|
105 |
+
if uploaded_file is not None:
|
106 |
+
with open(tmpdirname + "/" + uploaded_file.name, "wb") as f:
|
107 |
f.write(uploaded_file.getvalue())
|
108 |
uploaded_file = tmpdirname + "/" + uploaded_file.name
|
109 |
|
|
|
116 |
):
|
117 |
|
118 |
st.session_state.data_item = utils.read_image(Path(__file__).parent / str(uploaded_file))
|
119 |
+
# st.session_state.preds_3D = None
|
120 |
+
# st.session_state.preds_3D_ori = None
|
121 |
st.session_state.reset_demo_case = False
|
|
|
|
|
|
|
122 |
|
123 |
if st.session_state.option is None:
|
124 |
st.write("please select demo case first")
|
125 |
else:
|
126 |
image_3D = st.session_state.data_item
|
127 |
+
px_range = st.slider(
|
128 |
+
"Select intensity range",
|
129 |
+
int(image_3D.min()),
|
130 |
+
int(image_3D.max()),
|
131 |
+
(int(image_3D.min()), int(image_3D.max())),
|
132 |
+
)
|
133 |
col_control1, col_control2 = st.columns(2)
|
134 |
|
135 |
with col_control1:
|
136 |
selected_index_z = st.slider(
|
137 |
+
"Axial view",
|
138 |
+
0,
|
139 |
+
image_3D.shape[0] - 1,
|
140 |
+
image_3D.shape[0] // 2,
|
141 |
+
key="xy",
|
142 |
+
disabled=st.session_state.running,
|
143 |
)
|
144 |
|
145 |
with col_control2:
|
146 |
selected_index_y = st.slider(
|
147 |
+
"Coronal view",
|
148 |
+
0,
|
149 |
+
image_3D.shape[1] - 1,
|
150 |
+
image_3D.shape[1] // 2,
|
151 |
+
key="xz",
|
152 |
+
disabled=st.session_state.running,
|
153 |
)
|
154 |
|
155 |
col_image1, col_image2 = st.columns(2)
|
156 |
|
157 |
if st.session_state.preds_3D is not None:
|
158 |
+
st.session_state.transparency = st.slider("Mask opacity", 0.0, 1.0, 0.35, disabled=st.session_state.running)
|
|
|
|
|
159 |
|
160 |
with col_image1:
|
161 |
|
|
|
165 |
if st.session_state.preds_3D is not None:
|
166 |
preds_z_array = st.session_state.preds_3D[selected_index_z]
|
167 |
|
168 |
+
image_z = utils.make_fig(image_z_array, preds_z_array, px_range, st.session_state.transparency)
|
169 |
st.image(image_z, use_column_width=False)
|
170 |
|
171 |
with col_image2:
|
|
|
175 |
if st.session_state.preds_3D is not None:
|
176 |
preds_y_array = st.session_state.preds_3D[:, selected_index_y, :]
|
177 |
|
178 |
+
image_y = utils.make_fig(image_y_array, preds_y_array, px_range, st.session_state.transparency)
|
179 |
st.image(image_y, use_column_width=False)
|
180 |
|
181 |
######################################################
|
|
|
183 |
col1, col2, col3 = st.columns(3)
|
184 |
|
185 |
with col1:
|
186 |
+
st.markdown("#")
|
187 |
+
st.markdown("####")
|
188 |
+
st.markdown("####")
|
189 |
if st.button(
|
190 |
"Clear",
|
191 |
use_container_width=True,
|
|
|
197 |
st.rerun()
|
198 |
|
199 |
with col2:
|
200 |
+
st.markdown("#")
|
201 |
+
st.markdown("####")
|
202 |
+
st.markdown("####")
|
203 |
|
204 |
if st.session_state.preds_3D is not None and st.session_state.data_item is not None:
|
205 |
|
206 |
with tempfile.NamedTemporaryFile(suffix=".nii.gz") as tmpfile:
|
207 |
|
208 |
preds = st.session_state.preds_3D_ori
|
|
|
209 |
sitk.WriteImage(preds, tmpfile.name)
|
210 |
+
|
211 |
with open(tmpfile.name, "rb") as f:
|
212 |
bytes_data = f.read()
|
213 |
st.download_button(
|
214 |
+
label="Download result (.nii.gz)",
|
215 |
data=bytes_data,
|
216 |
file_name="segmentation.nii.gz",
|
217 |
mime="application/octet-stream",
|
|
|
219 |
)
|
220 |
|
221 |
with col3:
|
222 |
+
folds = st.radio("", ["Model of Fold 1 (fast)", "Ensemble Segmentation"])
|
223 |
+
if folds == "Model of Fold 1":
|
224 |
+
st.session_state.folds = (0,)
|
225 |
+
else:
|
226 |
+
st.session_state.folds = (
|
227 |
+
0,
|
228 |
+
1,
|
229 |
+
2,
|
230 |
+
3,
|
231 |
+
4,
|
232 |
+
)
|
233 |
+
|
234 |
run_button_name = "Run" if not st.session_state.running else "Running"
|
235 |
if st.button(
|
236 |
run_button_name,
|
237 |
type="primary",
|
238 |
use_container_width=True,
|
239 |
+
disabled=True,
|
240 |
+
# disabled=(st.session_state.data_item is None or st.session_state.running),
|
241 |
):
|
242 |
st.session_state.running = True
|
243 |
st.rerun()
|
|
|
245 |
if st.session_state.running:
|
246 |
st.session_state.running = False
|
247 |
with st.status("Running...", expanded=False) as status:
|
248 |
+
utils.run(tmpdirname)
|
249 |
st.rerun()
|