|
import numpy as np |
|
import streamlit as st |
|
from streamlit_lottie import st_lottie |
|
import hydralit_components as hc |
|
from sklearn.preprocessing import StandardScaler |
|
from pytorch_tabnet.tab_model import TabNetClassifier |
|
import pickle |
|
import random |
|
from streamlit_modal import Modal |
|
from streamlit_echarts import st_echarts |
|
|
|
|
|
# Sample form values loaded by the "Not COVID-19" example button.
# Sex uses the same 0/1 encoding as preprocess_sex (1 = "M").
det_input_not_covid = dict(
    BAT=0.3,
    EOT=5.9,
    LYT=11.9,
    MOT=5.4,
    HGB=12.1,
    MCHC=34.0,
    MCV=87.0,
    PLT=165.0,
    WBC=6.3,
    Age=75,
    Sex=1,
)
|
|
|
# Sample form values loaded by the "COVID-19" example button.
# Sex uses the same 0/1 encoding as preprocess_sex (0 = "F").
det_input_covid = dict(
    BAT=0,
    EOT=0,
    LYT=4.2,
    MOT=4.1,
    HGB=10.9,
    MCHC=31.8,
    MCV=80.5,
    PLT=152.0,
    WBC=5.25,
    Age=67,
    Sex=0,
)
|
|
|
# Per-session form defaults, created only on the first run of a session.
# On later reruns the example buttons may have replaced this dict, and that
# replacement is kept.
if "place_holder_input" not in st.session_state:
    st.session_state["place_holder_input"] = dict.fromkeys(
        ("BAT", "EOT", "LYT", "MOT", "HGB", "MCHC", "MCV", "PLT", "WBC", "Age", "Sex"),
        0,
    )
|
|
|
|
|
# Column groups for the form. predict_det concatenates det_cols1, det_cols2 and
# cat_cols in this order to build the feature row passed to the scaler.
det_cols1 = ["BAT", "EOT", "LYT", "MOT", "HGB"]
det_cols2 = ["MCHC", "MCV", "PLT", "WBC", "Age"]
prog_cols1 = ["LYT", "HGB", "PLT", "WBC", "Age"]
prog_cols2 = []
cat_cols = ["Sex"]

# Form-value holders, every field starting at 0, keyed in column order.
det_input = dict.fromkeys([*det_cols1, *det_cols2, *cat_cols], 0)
prog_input = dict.fromkeys([*prog_cols1, *prog_cols2, *cat_cols], 0)
|
|
|
|
|
# Page setup: wide layout, sidebar collapsed on first load.
st.set_page_config(initial_sidebar_state="collapsed", layout="wide")
|
|
|
|
|
# Detection model and its fitted feature scaler, loaded from files resolved
# relative to the working directory.
# NOTE(review): Streamlit re-executes this script on every interaction, so both
# are reloaded on every rerun. Consider caching them (e.g. st.cache_resource) if
# the installed Streamlit version provides it.
clf_det = TabNetClassifier()
clf_det.load_model("tabnet_detection.zip")

# Use a context manager so the file handle is closed instead of leaked.
# pickle.load can execute arbitrary code, so this file must come from a trusted
# source and never from user uploads.
with open("tabnet_detection_scaler.pkl", "rb") as scaler_file:
    scaler_det = pickle.load(scaler_file)
|
|
|
|
|
|
|
|
|
|
|
def preprocess_sex(my_dict):
    """Encode ``my_dict["Sex"]`` in place as an integer and return ``my_dict``.

    ``"M"`` is mapped to ``1`` and ``"F"`` to ``0``. Any other value is left
    unchanged and a Streamlit error message is shown instead.
    """
    for label, code in (("M", 1), ("F", 0)):
        if my_dict["Sex"] == label:
            my_dict["Sex"] = code
            return my_dict

    st.error("Incorrect Sex. Correct the input and try again.")
    return my_dict
|
|
|
|
|
def predict_det(**det_input):
    """Return the detection model's probability that the sample is COVID-19.

    Keyword arguments are form values keyed by the names in ``det_cols1``,
    ``det_cols2`` and ``cat_cols``. ``Sex`` is expected as ``"M"``/``"F"`` and is
    encoded by :func:`preprocess_sex`. Falsy values (e.g. ``0`` or ``None``) are
    fed to the model as ``0.0``.

    Returns:
        A float in ``[0, 1]`` with the model's probability for the positive
        class. Returns ``None`` if the input could not be converted or scored;
        in that case an error has already been shown in the UI.
    """
    det_input = preprocess_sex(det_input)

    try:
        # One feature row, in the column order the scaler was fitted on.
        features = np.array(
            [
                [
                    float(det_input[col]) if det_input[col] else 0.0
                    for col in [*det_cols1, *det_cols2, *cat_cols]
                ]
            ]
        )
        features = scaler_det.transform(features)

        # Report the model's own class probability. This replaces a
        # pseudo-random "confidence" that was derived only from the hard 0/1
        # label and a global random seed.
        # NOTE(review): column 1 assumes the model was trained on labels {0, 1}
        # (TabNet orders probability columns by sorted class label). Confirm.
        proba = clf_det.predict_proba(features)[0]
        return float(proba[1])
    except Exception as e:
        # UI boundary: report to the user and log instead of crashing the page.
        st.error("Incorrect data format in the form. Correct the input and try again.")
        print(e)
        return None
|
|
|
|
|
# Header row. Blank outer columns centre the title in the wider middle column.
col1, col2, col3 = st.columns([4, 6, 4])

with col1:
    st.write(" ")

with col2:

    st.title("SARS-CoV-2 detection")
    st.text("Press predict after filling in the form below.")

# Example buttons. Clicking one replaces the session-state placeholder that the
# form widgets further down read their initial values from during this same
# script run.
with col2.expander("Examples"):
    not_covid_example = st.button("Not COVID-19")
    if not_covid_example:
        st.session_state["place_holder_input"] = det_input_not_covid
    covid_example = st.button("COVID-19")
    if covid_example:
        st.session_state["place_holder_input"] = det_input_covid
    # Empty slot that the results chart is rendered into once Predict is pressed.
    # NOTE(review): as written this slot sits inside the "Examples" expander, so
    # results are visible only while it is open. Confirm this is intended.
    results_container = st.empty()


with col3:
    st.write(" ")
|
|
|
|
|
# Input form: two columns of numeric fields plus the Sex selector and the
# Predict button, between two empty outer columns.
_, col1, col2, _ = st.columns(4)


def _number_input_default(name):
    """Return the pre-fill value for numeric form field *name*.

    ``st.number_input`` infers int vs float (and its step/format) from the type
    of ``value``. The placeholder dicts hold int zeros for several lab fields,
    so passing them through unchanged produced integer-only inputs where values
    such as ``4.2`` could not be entered. Lab values are therefore coerced to
    float, while ``Age`` stays an integer.
    """
    default = st.session_state["place_holder_input"][name]
    return int(default) if name == "Age" else float(default)


for col in det_cols1:
    det_input[col] = col1.number_input(col, value=_number_input_default(col))

for col in det_cols2:
    det_input[col] = col2.number_input(col, value=_number_input_default(col))

for col in cat_cols:
    det_input[col] = col1.selectbox(
        col,
        ("F", "M"),
        # The option order matches preprocess_sex's encoding (F -> 0, M -> 1),
        # so the stored 0/1 placeholder doubles as the option index. Without
        # this the example buttons could not pre-select Sex.
        index=int(st.session_state["place_holder_input"][col]),
    )

# Two "##" markdown writes in col2, presumably used as vertical spacers.
col2.write("##")
col2.write("##")
open_modal = col1.button("Predict")
|
|
|
# Second three-column row. Only the outer columns receive content (a blank
# line each); the middle column is left empty.
col1, col2, col3 = st.columns([4, 6, 4])
col1.write(" ")
col3.write(" ")
|
|
|
|
|
# When Predict is pressed, score the form and chart the result into the
# results_container slot.
if open_modal:
    # Sex always arrives as an "F"/"M" string from the selectbox, so the form is
    # treated as empty when every numeric field is still zero.
    if all(isinstance(value, str) or value == 0 for value in det_input.values()):
        st.error("No input detected. Please fill in the form and try again.")
    else:
        covid = predict_det(**det_input)

        # predict_det returns None after it has already shown an error. There is
        # nothing to chart then, and round(None, ...) would raise TypeError.
        if covid is not None:
            # Whole percentages that always add up to 100. Cast to plain Python
            # numbers so the chart options serialise cleanly.
            covid_pct = int(round(float(covid) * 100))
            normal_pct = 100 - covid_pct

            with results_container.container():
                st.markdown("### Results: ")
                options = {
                    "title": {},
                    "series": [
                        {
                            "type": "pie",
                            "radius": "80%",
                            "animation": True,
                            "animationEasing": "cubicOut",
                            "animationDuration": 10000,
                            "label": {
                                "position": "inner",
                                "fontSize": 14,
                                "formatter": "{b} {d}%",
                            },
                            "data": [
                                {
                                    "value": covid_pct,
                                    "name": "Covid",
                                    "itemStyle": {"color": "#EE6766"},
                                },
                                {
                                    "value": normal_pct,
                                    "name": "Normal",
                                    "itemStyle": {"color": "#91CC75"},
                                },
                            ],
                            "emphasis": {
                                "itemStyle": {
                                    "shadowBlur": 10,
                                    "shadowOffsetX": 0,
                                    "shadowColor": "rgba(0, 0, 0, 0.5)",
                                }
                            },
                        }
                    ],
                }
                st_echarts(
                    options=options,
                    height="300px",
                )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|