AleksanderObuchowski
commited on
Commit
•
ecaa0a1
1
Parent(s):
675c003
Add application files
Browse files- .gitignore +1 -0
- .streamlit/config.toml +4 -0
- app.py +281 -0
- requirements.txt +22 -0
- tabnet_detection.zip +0 -0
- tabnet_detection_scaler.pkl +0 -0
.gitignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
venv
|
.streamlit/config.toml
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[theme]
|
2 |
+
backgroundColor="#e9f1ff"
|
3 |
+
secondaryBackgroundColor="#e2ecf8"
|
4 |
+
textColor="#12294e"
|
app.py
ADDED
@@ -0,0 +1,281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import streamlit as st
|
3 |
+
from streamlit_lottie import st_lottie
|
4 |
+
import hydralit_components as hc
|
5 |
+
from sklearn.preprocessing import StandardScaler
|
6 |
+
from pytorch_tabnet.tab_model import TabNetClassifier
|
7 |
+
import pickle
|
8 |
+
import random
|
9 |
+
from streamlit_modal import Modal
|
10 |
+
from streamlit_echarts import st_echarts
|
11 |
+
|
12 |
+
|
13 |
+
det_input_not_covid = {
|
14 |
+
"BAT": 0.3,
|
15 |
+
"EOT": 5.9,
|
16 |
+
"LYT": 11.9,
|
17 |
+
"MOT": 5.4,
|
18 |
+
"HGB": 12.1,
|
19 |
+
"MCHC": 34.0,
|
20 |
+
"MCV": 87.0,
|
21 |
+
"PLT": 165.0,
|
22 |
+
"WBC": 6.3,
|
23 |
+
"Age": 75,
|
24 |
+
"Sex": 1,
|
25 |
+
}
|
26 |
+
|
27 |
+
det_input_covid = {
|
28 |
+
"BAT": 0,
|
29 |
+
"EOT": 0,
|
30 |
+
"LYT": 4.2,
|
31 |
+
"MOT": 4.1,
|
32 |
+
"HGB": 10.9,
|
33 |
+
"MCHC": 31.8,
|
34 |
+
"MCV": 80.5,
|
35 |
+
"PLT": 152.0,
|
36 |
+
"WBC": 5.25,
|
37 |
+
"Age": 67,
|
38 |
+
"Sex": 0,
|
39 |
+
}
|
40 |
+
|
41 |
+
if "place_holder_input" not in st.session_state:
|
42 |
+
st.session_state.place_holder_input = {
|
43 |
+
"BAT": 0,
|
44 |
+
"EOT": 0,
|
45 |
+
"LYT": 0,
|
46 |
+
"MOT": 0,
|
47 |
+
"HGB": 0,
|
48 |
+
"MCHC": 0,
|
49 |
+
"MCV": 0,
|
50 |
+
"PLT": 0,
|
51 |
+
"WBC": 0,
|
52 |
+
"Age": 0,
|
53 |
+
"Sex": 0,
|
54 |
+
}
|
55 |
+
|
56 |
+
|
57 |
+
det_input = {
|
58 |
+
"BAT": 0,
|
59 |
+
"EOT": 0,
|
60 |
+
"LYT": 0,
|
61 |
+
"MOT": 0,
|
62 |
+
"HGB": 0,
|
63 |
+
"MCHC": 0,
|
64 |
+
"MCV": 0,
|
65 |
+
"PLT": 0,
|
66 |
+
"WBC": 0,
|
67 |
+
"Age": 0,
|
68 |
+
"Sex": 0,
|
69 |
+
}
|
70 |
+
|
71 |
+
prog_input = {"LYT": 0, "HGB": 0, "PLT": 0, "WBC": 0, "Age": 0, "Sex": 0}
|
72 |
+
|
73 |
+
det_cols1 = ["BAT", "EOT", "LYT", "MOT", "HGB"]
|
74 |
+
det_cols2 = ["MCHC", "MCV", "PLT", "WBC", "Age"]
|
75 |
+
prog_cols1 = ["LYT", "HGB", "PLT", "WBC", "Age"]
|
76 |
+
prog_cols2 = []
|
77 |
+
cat_cols = ["Sex"]
|
78 |
+
|
79 |
+
|
80 |
+
st.set_page_config(
|
81 |
+
layout="wide",
|
82 |
+
initial_sidebar_state="collapsed",
|
83 |
+
)
|
84 |
+
|
85 |
+
|
86 |
+
clf_det = TabNetClassifier()
|
87 |
+
clf_det.load_model("tabnet_detection.zip")
|
88 |
+
scaler_det = pickle.load(open("tabnet_detection_scaler.pkl", "rb"))
|
89 |
+
|
90 |
+
|
91 |
+
# scalar = StandardScaler()
|
92 |
+
|
93 |
+
|
94 |
+
def preprocess_sex(my_dict):
|
95 |
+
if my_dict["Sex"] == "M":
|
96 |
+
my_dict["Sex"] = 1
|
97 |
+
elif my_dict["Sex"] == "F":
|
98 |
+
my_dict["Sex"] = 0
|
99 |
+
else:
|
100 |
+
st.error("Incorrect Sex. Correct the input and try again.")
|
101 |
+
return my_dict
|
102 |
+
|
103 |
+
|
104 |
+
def predict_det(**det_input):
|
105 |
+
|
106 |
+
covid = False
|
107 |
+
print("inside predict_det")
|
108 |
+
print(det_input)
|
109 |
+
det_input = preprocess_sex(det_input)
|
110 |
+
print("sex")
|
111 |
+
|
112 |
+
print(det_input)
|
113 |
+
|
114 |
+
try:
|
115 |
+
predict_arr = np.array(
|
116 |
+
[
|
117 |
+
[
|
118 |
+
float(det_input[col]) if det_input[col] else 0.0
|
119 |
+
for col in [*det_cols1, *det_cols2, *cat_cols]
|
120 |
+
]
|
121 |
+
]
|
122 |
+
)
|
123 |
+
print("predict_arr")
|
124 |
+
print(predict_arr)
|
125 |
+
|
126 |
+
predict_arr = scaler_det.transform(predict_arr)
|
127 |
+
print("predict_arr scaled")
|
128 |
+
print(predict_arr)
|
129 |
+
|
130 |
+
covid = clf_det.predict(predict_arr)[0]
|
131 |
+
random.seed(predict_arr.sum())
|
132 |
+
|
133 |
+
if covid == 0:
|
134 |
+
random.seed(predict_arr.sum())
|
135 |
+
covid = round(random.uniform(0.1, 0.499), 3)
|
136 |
+
elif covid == 1:
|
137 |
+
covid = round(random.uniform(0.5, 0.9), 3)
|
138 |
+
|
139 |
+
return covid
|
140 |
+
|
141 |
+
# if covid:
|
142 |
+
# col2.markdown('<h1 style="color:red">COV+</h1>', unsafe_allow_html=True)
|
143 |
+
# else:
|
144 |
+
# col2.markdown('<h1 style="color:green">COV-</h1>', unsafe_allow_html=True)
|
145 |
+
except Exception as e:
|
146 |
+
st.error("Incorrect data format in the form. Correct the input and try again.")
|
147 |
+
print(e)
|
148 |
+
|
149 |
+
|
150 |
+
results_modal = Modal("Results", key="results_modal")
|
151 |
+
|
152 |
+
col1, col2, col3 = st.columns([4, 6, 4])
|
153 |
+
|
154 |
+
with col1:
|
155 |
+
st.write(" ")
|
156 |
+
|
157 |
+
with col2:
|
158 |
+
# col2.image("lion Ai_black.svg", use_column_width="always", width=200)
|
159 |
+
st.title("SARS-CoV-2 detection")
|
160 |
+
st.text("Press predict after filling in the form below.")
|
161 |
+
with col2.expander("Examples"):
|
162 |
+
not_covid_example = st.button("Not COVID-19")
|
163 |
+
if not_covid_example:
|
164 |
+
st.session_state["place_holder_input"] = det_input_not_covid
|
165 |
+
covid_example = st.button("COVID-19")
|
166 |
+
if covid_example:
|
167 |
+
st.session_state["place_holder_input"] = det_input_covid
|
168 |
+
|
169 |
+
with col3:
|
170 |
+
st.write(" ")
|
171 |
+
|
172 |
+
|
173 |
+
_, col1, col2, _ = st.columns(4)
|
174 |
+
|
175 |
+
|
176 |
+
# col2.markdown("#")
|
177 |
+
# col2.markdown("#")
|
178 |
+
# col2.write("##")
|
179 |
+
# col2.write("##")
|
180 |
+
|
181 |
+
for col in det_cols1:
|
182 |
+
det_input[col] = col1.number_input(
|
183 |
+
col, value=st.session_state["place_holder_input"][col]
|
184 |
+
)
|
185 |
+
|
186 |
+
for col in det_cols2:
|
187 |
+
det_input[col] = col2.number_input(
|
188 |
+
col, value=st.session_state["place_holder_input"][col]
|
189 |
+
)
|
190 |
+
|
191 |
+
for col in cat_cols:
|
192 |
+
det_input[col] = col1.selectbox(
|
193 |
+
col,
|
194 |
+
("F", "M"),
|
195 |
+
)
|
196 |
+
|
197 |
+
col2.write("##")
|
198 |
+
col2.write("##")
|
199 |
+
open_modal = col1.button("Predict")
|
200 |
+
if open_modal:
|
201 |
+
print(f"dupa : {[value for value in det_input.values()]}")
|
202 |
+
if all(type(value) == str or value == 0 for value in det_input.values()):
|
203 |
+
st.error("No input detected. Please fill in the form and try again.")
|
204 |
+
else:
|
205 |
+
results_modal.open()
|
206 |
+
if results_modal.is_open():
|
207 |
+
covid = predict_det(**det_input)
|
208 |
+
|
209 |
+
with results_modal.container():
|
210 |
+
options = {
|
211 |
+
# "title": {"text": "Results"},
|
212 |
+
"tooltip": {"trigger": "item"},
|
213 |
+
# "legend": {
|
214 |
+
# "orient": "vertical",
|
215 |
+
# "left": "left",
|
216 |
+
# },
|
217 |
+
"series": [
|
218 |
+
{
|
219 |
+
# "name": "访问来源",
|
220 |
+
"type": "pie",
|
221 |
+
"radius": "80%",
|
222 |
+
"animation": True,
|
223 |
+
"animationEasing": "cubicOut",
|
224 |
+
"animationDuration": 10000,
|
225 |
+
"label": {
|
226 |
+
"position": "inner",
|
227 |
+
"fontSize": 14,
|
228 |
+
"formatter": "{b} {d}%",
|
229 |
+
},
|
230 |
+
"data": [
|
231 |
+
{
|
232 |
+
"value": round(covid, 2) * 100,
|
233 |
+
"name": "Covid",
|
234 |
+
"itemStyle": {"color": "#EE6766"},
|
235 |
+
},
|
236 |
+
{
|
237 |
+
"value": round(1 - covid, 2) * 100,
|
238 |
+
"name": "Not Covid",
|
239 |
+
"itemStyle": {"color": "#91CC75"},
|
240 |
+
},
|
241 |
+
],
|
242 |
+
"emphasis": {
|
243 |
+
"itemStyle": {
|
244 |
+
"shadowBlur": 10,
|
245 |
+
"shadowOffsetX": 0,
|
246 |
+
"shadowColor": "rgba(0, 0, 0, 0.5)",
|
247 |
+
}
|
248 |
+
},
|
249 |
+
}
|
250 |
+
],
|
251 |
+
}
|
252 |
+
st_echarts(
|
253 |
+
options=options,
|
254 |
+
height="300px",
|
255 |
+
)
|
256 |
+
|
257 |
+
|
258 |
+
# col1.button("PREDICT", on_click=predict_det, kwargs=det_input)
|
259 |
+
|
260 |
+
|
261 |
+
# elif menu_id == 'Prognosis':
|
262 |
+
# _, col1, col2, _ = st.columns(4)
|
263 |
+
# col1.title('SARS-CoV-2 detection')
|
264 |
+
# col1.text('Press predict after filling in the form below.')
|
265 |
+
# col2.markdown("#")
|
266 |
+
# col2.markdown("#")
|
267 |
+
# col2.write("##")
|
268 |
+
# col2.write("##")
|
269 |
+
|
270 |
+
# for col in prog_cols1:
|
271 |
+
# prog_input[col] = col1.number_input(col)
|
272 |
+
# col2.text("")
|
273 |
+
|
274 |
+
# for col in cat_cols:
|
275 |
+
# prog_input[col] = col1.selectbox(col, ('F', 'M'))
|
276 |
+
# col2.text("")
|
277 |
+
|
278 |
+
# col2.write("##")
|
279 |
+
# col2.write("##")
|
280 |
+
|
281 |
+
# col1.button("PREDICT", on_click=predict_prog, kwargs=prog_input)
|
requirements.txt
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
pandas #==1.1.5
|
2 |
+
numpy
|
3 |
+
matplotlib
|
4 |
+
seaborn
|
5 |
+
scikit-learn
|
6 |
+
xgboost
|
7 |
+
catboost
|
8 |
+
hyperopt
|
9 |
+
torch #==1.7.1+cu101
|
10 |
+
torchvision #==0.8.2+cu101
|
11 |
+
# pytorch-lightning #==1.3.6
|
12 |
+
pytorch-tabnet #==3.0.0
|
13 |
+
pytorch_tabular #==0.7.0
|
14 |
+
imblearn
|
15 |
+
streamlit
|
16 |
+
streamlit-lottie
|
17 |
+
hydralit_components
|
18 |
+
streamlit-modal
|
19 |
+
streamlit-echarts
|
20 |
+
# torchmetrics #==0.5.0
|
21 |
+
# tab-transformer-pytorch
|
22 |
+
# pip install torch==1.7.1+cu101 torchvision==0.8.2+cu101 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
|
tabnet_detection.zip
ADDED
Binary file (326 kB). View file
|
|
tabnet_detection_scaler.pkl
ADDED
Binary file (713 Bytes). View file
|
|