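"""Streamlit app for predicting blood-brain barrier (BBB) permeability of small
molecules with B3clf imbalanced-learning models."""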
import itertools as it
import os
import tempfile
from io import StringIO

import joblib
import numpy as np
import pandas as pd
import pkg_resources

# page set up
import streamlit as st
from b3clf.descriptor_padel import compute_descriptors
from b3clf.geometry_opt import geometry_optimize
from b3clf.utils import get_descriptors, scale_descriptors, select_descriptors

# from PIL import Image
from streamlit_extras.let_it_rain import rain
from streamlit_ketcher import st_ketcher

from utils import generate_predictions, load_all_models

st.cache_data.clear()

st.set_page_config(
    page_title="BBB Permeability Prediction with Imbalanced Learning",
    # page_icon="🧊",
    layout="wide",
    # initial_sidebar_state="expanded",
    # menu_items={
    #     "Get Help": "https://www.extremelycoolapp.com/help",
    #     "Report a bug": "https://www.extremelycoolapp.com/bug",
    #     "About": "# This is a header. This is an *extremely* cool app!"
    # }
)

keep_features = "no"
keep_sdf = "no"
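# Map the UI labels below to the internal B3clf identifiers that select the
# pretrained classifier/resampler combination passed to generate_predictions.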
classifiers_dict = {
    "decision tree": "dtree",
    "kNN": "knn",
    "logistic regression": "logreg",
    "XGBoost": "xgb",
}

resample_methods_dict = {
    "random undersampling": "classic_RandUndersampling",
    "SMOTE": "classic_SMOTE",
    "Borderline SMOTE": "borderline_SMOTE",
    "k-means SMOTE": "kmeans_SMOTE",
    "ADASYN": "classic_ADASYN",
    "no resampling": "common",
}
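# Cap how many rows of the feature and prediction tables are rendered.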
pandas_display_options = {
    "line_limit": 50,
}

mol_features = None
info_df = None
results = None
temp_file_path = None
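# Load all pretrained models once; the dict is passed to generate_predictions
# as _models_dict.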
all_models = load_all_models()

# Create the Streamlit app
st.title(":blue[BBB Permeability Prediction with Imbalanced Learning]")

info_column, upload_column = st.columns(2)
# initialize the molecule features and info dataframe session state
if "mol_features" not in st.session_state:
    st.session_state.mol_features = None
if "info_df" not in st.session_state:
    st.session_state.info_df = None
# download sample files
with info_column:
    st.subheader("About `B3clf`")
    # fmt: off
    st.markdown(
        """
`B3clf` is a Python package for predicting the blood-brain barrier (BBB) permeability of small molecules using imbalanced learning. It supports four classification algorithms (decision tree, kNN, logistic regression, and XGBoost) and five resampling strategies (random undersampling, SMOTE, Borderline SMOTE, k-means SMOTE, and ADASYN). The workflow of `B3clf` is summarized below. The source code and more details are available at https://github.com/theochem/B3clf. This project is supported by the Digital Research Alliance of Canada (formerly known as Compute Canada) and NSERC, and is maintained by the QC-Dev community. For further information and inquiries, please contact us at [email protected]."""
    )
    st.text(" \n")

    # text_body = """
    # `B3clf` is a Python package for predicting the blood-brain barrier (BBB) permeability of small molecules using imbalanced learning. It supports decision tree, XGBoost, kNN, logistical regression and 5 resampling strategies (SMOTE, Borderline SMOTE, k-means SMOTE and ADASYN). The workflow of `B3clf` is summarized as below. The Source code and more details are available at https://github.com/theochem/B3clf.
    # """
    # st.markdown(f"<p align='justify'>{text_body}</p>",
    #             unsafe_allow_html=True)

    # image = Image.open("images/b3clf_workflow.png")
    # st.image(image=image, use_column_width=True)
    # image_path = "images/b3clf_workflow.png"
    # image_width_percent = 80
    # info_column.markdown(
    #     f"<img src='{image_path}' style='max-width: {image_width_percent}%; height: auto;'>",
    #     unsafe_allow_html=True,
    # )
    # fmt: on
    sdf_col, smi_col = st.columns(2)
    with sdf_col:
        # uneven columns
        # st.columns((2, 1, 1, 1))
        # two subcolumns for sample input files
        # download sample sdf
        # st.markdown(" \n \n")
        with open("sample_input.sdf", "r") as file_sdf:
            btn = st.download_button(
                label="Download SDF sample file",
                data=file_sdf,
                file_name="sample_input.sdf",
            )
    with smi_col:
        with open("sample_input_smiles.csv", "r") as file_smi:
            btn = st.download_button(
                label="Download SMILES sample file",
                data=file_smi,
                file_name="sample_input_smiles.csv",
            )
# Create a file uploader
with upload_column:
    st.subheader("Model Selection")
    with st.container():
        algorithm_col, resampler_col = st.columns(2)
        # algorithm and resampling method selection column
        with algorithm_col:
            classifier = st.selectbox(
                label="Classification Algorithm:",
                options=("XGBoost", "kNN", "decision tree", "logistic regression"),
            )
        with resampler_col:
            resampler = st.selectbox(
                label="Resampling Method:",
                options=(
                    "ADASYN",
                    "random undersampling",
                    "Borderline SMOTE",
                    "k-means SMOTE",
                    "SMOTE",
                    "no resampling",
                ),
            )

    # horizontal line
    st.divider()
    # upload_col, submit_job_col = st.columns((2, 1))
    upload_col, _, submit_job_col, _ = st.columns((4, 0.05, 1, 0.05))

    # upload file column
    with upload_col:
        # session state tracking of the file uploader
        if "uploaded_file" not in st.session_state:
            st.session_state.uploaded_file = None
        if "uploaded_file_changed" not in st.session_state:
            st.session_state.uploaded_file_changed = False

        # def update_uploader_session_info():
        #     """Update the session state of the file uploader."""
        #     st.session_state.uploaded_file = uploaded_file

        uploaded_file = st.file_uploader(
            label="Upload a CSV, SDF, TXT or SMI file",
            type=["csv", "sdf", "txt", "smi"],
            help="Input molecule file only supports *.csv, *.sdf, *.txt and *.smi.",
            accept_multiple_files=False,
            # key="uploaded_file",
            # on_change=update_uploader_session_info,
        )
        if uploaded_file:
            # st.write(f"the uploaded file: {uploaded_file}")
            # when the newly uploaded file differs from the previous one
            if st.session_state.uploaded_file != uploaded_file:
                st.session_state.uploaded_file_changed = True
            else:
                st.session_state.uploaded_file_changed = False
            st.session_state.uploaded_file = uploaded_file
            # when the new file is the same as the previous one
            # else:
            #     st.session_state.uploaded_file_changed = False
            #     st.session_state.uploaded_file = uploaded_file

        # set session state for the file uploader
        # st.write(f"the state of uploaded file: {st.session_state.uploaded_file}")
        # st.write(f"the state of uploaded file changed: {st.session_state.uploaded_file_changed}")
    # submit job column
    with submit_job_col:
        st.text(" \n")
        st.text(" \n")
        st.markdown(
            "<div style='display: flex; justify-content: center;'>",
            unsafe_allow_html=True,
        )
        submit_job_button = st.button(
            label="Submit Job", type="secondary", key="job_button"
        )
        # submit_job_col.markdown("<div style='display: flex; justify-content: center;'>",
        #                         unsafe_allow_html=True)
        # submit_job_button = submit_job_col.button(
        #     label="Submit job", key="submit_job_button", type="secondary"
        # )
        # submit_job_col.markdown("</div>", unsafe_allow_html=True)

# st.write("The content of the file will be displayed below once uploaded.")
# if file:
#     if "csv" in file.name or "txt" in file.name:
#         st.write(file.read().decode("utf-8"))
# st.write(file)
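# Results layout: computed molecular features on the left, predictions on the right.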
feature_column, prediction_column = st.columns(2)
with feature_column:
    st.subheader("Molecular Features")
    placeholder_features = st.empty()
    # placeholder_features = pd.DataFrame(index=[1, 2, 3, 4],
    #                                     columns=["ID", "nAcid", "ALogP", "Alogp2",
    #                                              "AMR", "naAromAtom", "nH", "nN"])
    # st.dataframe(placeholder_features)
    # placeholder_features.text("molecular features")
with prediction_column:
    st.subheader("Predictions")
    # placeholder_predictions = st.empty()
    # placeholder_predictions.text("prediction")

st.write(
    f"the state of uploaded file changed before checking: {st.session_state.uploaded_file_changed}"
)
# Generate predictions when the user uploads a file
# if submit_job_button:
print(st.session_state)
# run only when the submit button was clicked and a file has been uploaded
if submit_job_button and uploaded_file is not None:
    # when new file is uploaded
    # update_uploader_session_info()
    st.write(
        f"the state of uploaded file changed after checking: {st.session_state.uploaded_file_changed}"
    )
    # if st.session_state.uploaded_file_changed:
    #     temp_dir = tempfile.mkdtemp()
    #     # Create a temporary file path for the uploaded file
    #     temp_file_path = os.path.join(temp_dir, uploaded_file.name)
    #     # Save the uploaded file to the temporary file path
    #     with open(temp_file_path, "wb") as temp_file:
    #         temp_file.write(uploaded_file.read())
    #     mol_features, info_df, results = generate_predictions(
    #         input_fname=temp_file_path,
    #         sep="\s+|\t+",
    #         clf=classifiers_dict[classifier],
    #         _models_dict=all_models,
    #         sampling=resample_methods_dict[resampler],
    #         time_per_mol=120,
    #         mol_features=None,
    #         info_df=None,
    #     )
    #     st.session_state.mol_features = mol_features
    #     st.session_state.info_df = info_df
    # else:
    #     mol_features, info_df, results = generate_predictions(
    #         input_fname=None,
    #         sep="\s+|\t+",
    #         clf=classifiers_dict[classifier],
    #         _models_dict=all_models,
    #         sampling=resample_methods_dict[resampler],
    #         time_per_mol=120,
    #         mol_features=st.session_state.mol_features,
    #         info_df=st.session_state.info_df,
    #     )
    temp_dir = tempfile.mkdtemp()
    # Create a temporary file path for the uploaded file
    temp_file_path = os.path.join(temp_dir, uploaded_file.name)
    # Save the uploaded file to the temporary file path
    with open(temp_file_path, "wb") as temp_file:
        temp_file.write(uploaded_file.read())
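    # Run the B3clf pipeline on the uploaded file: descriptor generation followed
    # by BBB permeability prediction with the selected classifier and resampler.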
    mol_features, info_df, results = generate_predictions(
        input_fname=temp_file_path,
        sep=r"\s+|\t+",
        clf=classifiers_dict[classifier],
        _models_dict=all_models,
        sampling=resample_methods_dict[resampler],
        time_per_mol=120,
        mol_features=None,
        info_df=None,
    )
# feature table
with feature_column:
    if mol_features is not None:
        selected_feature_rows = np.min(
            [mol_features.shape[0], pandas_display_options["line_limit"]]
        )
        st.dataframe(mol_features.iloc[:selected_feature_rows, :], hide_index=False)
        # placeholder_features.dataframe(mol_features, hide_index=False)
        feature_file_name = uploaded_file.name.split(".")[0] + "_b3clf_features.csv"
        features_csv = mol_features.to_csv(index=True)
        st.download_button(
            "Download features as CSV",
            data=features_csv,
            file_name=feature_file_name,
        )
# prediction table
with prediction_column:
    # st.subheader("Predictions")
    if results is not None:
        # Display the predictions in a table
        selected_result_rows = np.min(
            [results.shape[0], pandas_display_options["line_limit"]]
        )
        results_df_display = results.iloc[:selected_result_rows, :].style.format(
            {"B3clf_predicted_probability": "{:.6f}".format}
        )
        st.dataframe(results_df_display, hide_index=True)
        # Add a button to download the predictions as a CSV file
        predictions_csv = results.to_csv(index=True)
        results_file_name = (
            uploaded_file.name.split(".")[0] + "_b3clf_predictions.csv"
        )
        st.download_button(
            "Download predictions as CSV",
            data=predictions_csv,
            file_name=results_file_name,
        )
        # indicate the success of the job
        # rain(
        #     emoji="🎈",
        #     font_size=54,
        #     falling_speed=5,
        #     animation_length=10,
        # )
        st.balloons()
# hide footer
# https://github.com/streamlit/streamlit/issues/892
hide_streamlit_style = """
<style>
#MainMenu {visibility: hidden;}
footer {visibility: hidden;}
</style>
"""
st.markdown(hide_streamlit_style, unsafe_allow_html=True)
# add google analytics
st.markdown(
    """
<!-- Google tag (gtag.js) -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-WG8QYRELP9"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag(){dataLayer.push(arguments);}
gtag("js", new Date());
gtag("config", "G-WG8QYRELP9");
</script>
""",
    unsafe_allow_html=True,
)