Spaces:
Sleeping
Sleeping
import itertools as it | |
import os | |
import tempfile | |
from io import StringIO | |
import joblib | |
import numpy as np | |
import pandas as pd | |
import pkg_resources | |
# page set up | |
import streamlit as st | |
from b3clf.descriptor_padel import compute_descriptors | |
from b3clf.geometry_opt import geometry_optimize | |
from b3clf.utils import get_descriptors, scale_descriptors, select_descriptors | |
# from PIL import Image | |
from streamlit_extras.let_it_rain import rain | |
from streamlit_ketcher import st_ketcher | |
# Page-level configuration: browser tab title and the wide layout.
# Optional settings (page icon, sidebar state, menu items) are left at
# Streamlit's defaults.
st.set_page_config(
    layout="wide",
    page_title="BBB Permeability Prediction with Imbalanced Learning",
)
# Flags controlling whether intermediate artifacts are kept on disk
# ("no" means they are discarded).
keep_features = keep_sdf = "no"

# Display name shown in the UI -> short classifier tag used in model names.
classifiers_dict = {
    "decision tree": "dtree",
    "kNN": "knn",
    "logistic regression": "logreg",
    "XGBoost": "xgb",
}

# Display name shown in the UI -> resampling tag used in model names.
resample_methods_dict = {
    "random undersampling": "classic_RandUndersampling",
    "SMOTE": "classic_SMOTE",
    "Borderline SMOTE": "borderline_SMOTE",
    "k-means SMOTE": "kmeans_SMOTE",
    "ADASYN": "classic_ADASYN",
    "no resampling": "common",
}

# Maximum number of table rows rendered in the page.
pandas_display_options = dict(line_limit=50)

# Per-run state; stays None until a job has been submitted and processed.
mol_features = info_df = results = temp_file_path = None
def load_all_models():
    """Load every pre-trained classifier shipped inside the ``b3clf`` package.

    Returns
    -------
    dict
        Deserialized models keyed by ``"<classifier>_<resampling>"`` (for
        example ``"xgb_classic_ADASYN"``), one entry for each combination of
        the four classifiers and six resampling strategies listed below.
    """
    classifiers = ("dtree", "knn", "logreg", "xgb")
    resamplers = (
        "borderline_SMOTE",
        "classic_ADASYN",
        "classic_RandUndersampling",
        "classic_SMOTE",
        "kmeans_SMOTE",
        "common",
    )

    models = {}
    for clf_name in classifiers:
        for resampler_name in resamplers:
            tag = f"{clf_name}_{resampler_name}"
            # Models are read as package resources rather than from a
            # filesystem path relative to this script.
            resource = f"pre_trained/b3clf_{tag}.joblib"
            with pkg_resources.resource_stream("b3clf", resource) as handle:
                models[tag] = joblib.load(handle)
    return models
def predict_permeability(clf_str, sampling_str, mol_features, info_df, threshold="none"):
    """Compute permeability predictions for the given feature data.

    Parameters
    ----------
    clf_str : str
        Classifier tag used in the bundled model file name (e.g. ``"xgb"``).
    sampling_str : str
        Resampling tag used in the bundled model file name
        (e.g. ``"classic_ADASYN"``).
    mol_features : pandas.DataFrame or array-like
        Feature matrix handed to the model's ``predict_proba``.
    info_df : pandas.DataFrame
        Per-molecule information table. It is modified in place: the columns
        ``B3clf_predicted_probability`` and ``B3clf_predicted_label`` are
        added and the index is reset.
    threshold : str, optional
        Column of ``data/B3clf_thresholds.xlsx`` that supplies the
        probability cut-off. Defaults to ``"none"``.

    Returns
    -------
    pandas.DataFrame
        The same ``info_df`` object, with the prediction columns added.

    Raises
    ------
    ValueError
        If ``mol_features`` is a DataFrame whose index differs from that of
        ``info_df``.
    """
    package_name = "b3clf"

    # Fail fast on mismatched inputs before any model or spreadsheet I/O.
    if isinstance(mol_features, pd.DataFrame):
        if mol_features.index.tolist() != info_df.index.tolist():
            raise ValueError(
                "Features_df and Info_df do not have the same index."
            )

    # Load only the requested model instead of deserializing every bundled
    # classifier/resampling combination on each call.
    model_resource = f"pre_trained/b3clf_{clf_str}_{sampling_str}.joblib"
    with pkg_resources.resource_stream(package_name, model_resource) as f:
        pred_model = joblib.load(f)

    # load the threshold data
    with pkg_resources.resource_stream(
        package_name, "data/B3clf_thresholds.xlsx"
    ) as f:
        df_thres = pd.read_excel(f, index_col=0, engine="openpyxl")

    # Labels start at 0 and are switched to 1 where the probability reaches
    # the selected threshold.
    label_pool = np.zeros(mol_features.shape[0], dtype=int)

    # get predicted probabilities (second column of predict_proba)
    info_df.loc[:, "B3clf_predicted_probability"] = pred_model.predict_proba(
        mol_features
    )[:, 1]

    # NOTE(review): the threshold row is fixed to "xgb-classic_ADASYN" rather
    # than clf_str + "-" + sampling_str, so every model shares that row's
    # cut-off. Kept as-is to preserve current predictions -- confirm intent.
    cutoff = df_thres.loc["xgb-classic_ADASYN", threshold]
    mask = np.greater_equal(
        info_df["B3clf_predicted_probability"].to_numpy(),
        cutoff,
    )
    label_pool[mask] = 1

    # save the predicted labels
    info_df["B3clf_predicted_label"] = label_pool
    info_df.reset_index(inplace=True)
    return info_df
# @st.cache_resource | |
def generate_predictions(
    input_fname: str = None,
    sep: str = "\\s+|\t+",
    clf: str = "xgb",
    sampling: str = "classic_ADASYN",
    time_per_mol: int = 120,
    mol_features: pd.DataFrame = None,
    info_df: pd.DataFrame = None,
):
    """Generate BBB permeability predictions for an input molecule file.

    When neither ``mol_features`` nor ``info_df`` is supplied, the input file
    is geometry-optimized into an intermediate SDF, descriptors are computed,
    selected and scaled, and the result is fed to the classifier. When both
    are supplied, featurization is skipped and they are used directly.

    Parameters
    ----------
    input_fname : str, optional
        Path to the input file (an SDF file with molecular geometries or a
        text file with SMILES strings). Required when features are not given.
    sep : str, optional
        Separator pattern forwarded to ``geometry_optimize``.
    clf : str, optional
        Classifier tag, e.g. ``"xgb"``.
    sampling : str, optional
        Resampling tag, e.g. ``"classic_ADASYN"``.
    time_per_mol : int, optional
        Forwarded to ``compute_descriptors`` as ``time_per_molecule``.
    mol_features : pandas.DataFrame, optional
        Precomputed, already selected and scaled descriptors.
    info_df : pandas.DataFrame, optional
        Molecule information table matching ``mol_features``.

    Returns
    -------
    tuple
        ``(mol_features, info_df, result_df)`` where ``result_df`` holds the
        ID, SMILES, predicted probability and predicted label columns that
        are present in the prediction output.

    Raises
    ------
    ValueError
        If exactly one of ``mol_features`` / ``info_df`` is given, or if
        featurization is required but ``input_fname`` is missing.
    """
    if (mol_features is None) != (info_df is None):
        raise ValueError(
            "mol_features and info_df must be provided together or not at all."
        )

    if mol_features is None:
        if input_fname is None:
            raise ValueError(
                "input_fname is required when mol_features and info_df are not given."
            )

        mol_tag = os.path.basename(input_fname).split(".")[0]
        # NOTE: the intermediate SDF is written to the current working
        # directory under a name derived from the input file, so two runs on
        # identically named inputs would share this path.
        internal_sdf = f"{mol_tag}_optimized_3d.sdf"

        try:
            # Geometry optimization
            # Input: either an SDF file with molecular geometries or a text
            # file with SMILES strings
            geometry_optimize(
                input_fname=input_fname, output_sdf=internal_sdf, sep=sep
            )

            df_features = compute_descriptors(
                sdf_file=internal_sdf,
                excel_out=None,
                output_csv=None,
                timeout=None,
                time_per_molecule=time_per_mol,
            )

            # Get computed descriptors
            mol_features, info_df = get_descriptors(df=df_features)

            # Select descriptors
            mol_features = select_descriptors(df=mol_features)

            # Scale descriptors
            mol_features.iloc[:, :] = scale_descriptors(df=mol_features)
        finally:
            # Remove the intermediate SDF even when a step above fails, so a
            # failed job does not leave a stale file behind.
            if os.path.exists(internal_sdf) and keep_sdf == "no":
                os.remove(internal_sdf)

    result_df = predict_permeability(
        clf_str=clf,
        sampling_str=sampling,
        mol_features=mol_features,
        info_df=info_df,
        threshold="none",
    )

    # Keep only the columns meant for display, in their existing order.
    display_cols = [
        "ID",
        "SMILES",
        "B3clf_predicted_probability",
        "B3clf_predicted_label",
    ]
    result_df = result_df[
        [col for col in result_df.columns.to_list() if col in display_cols]
    ]

    return mol_features, info_df, result_df
# Create the Streamlit app: page title, then a two-column top row.
st.title(":blue[BBB Permeability Prediction with Imbalanced Learning]")
info_column, upload_column = st.columns(2)

# Left column: project description.
info_column.subheader("About `B3clf`")
# fmt: off
info_column.markdown(
    """
`B3clf` is a Python package for predicting the blood-brain barrier (BBB) permeability of small molecules using imbalanced learning. It supports decision tree, XGBoost, kNN, logistical regression and 5 resampling strategies (SMOTE, Borderline SMOTE, k-means SMOTE and ADASYN). The workflow of `B3clf` is summarized as below. The Source code and more details are available at https://github.com/theochem/B3clf. This project is supported by Digital Research Alliance of Canada (originally known as Compute Canada) and NSERC. This project is maintained by QC-Dev comminity. For further information and inquiries please contact us at [email protected]."""
)
info_column.text(" \n")
# fmt: on
# NOTE: rendering of the workflow figure (images/b3clf_workflow.png) is
# currently disabled, so no image follows the description.

# Two sub-columns holding the sample input download buttons.
sdf_col, smi_col = info_column.columns(2)

# download sample sdf
with open("sample_input.sdf", "r") as file_sdf:
    btn = sdf_col.download_button(
        label="Download SDF sample file",
        data=file_sdf,
        file_name="sample_input.sdf",
    )

# download sample SMILES
with open("sample_input_smiles.csv", "r") as file_smi:
    btn = smi_col.download_button(
        label="Download SMILES sample file",
        data=file_smi,
        file_name="sample_input_smiles.csv",
    )
# Right column: model selection, file upload and job submission.
with upload_column:
    st.subheader("Model Selection")

    # Classifier and resampling method pickers, side by side.
    with st.container():
        algorithm_col, resampler_col = st.columns(2)
        classifier = algorithm_col.selectbox(
            label="Classification Algorithm:",
            options=("XGBoost", "kNN", "decision tree", "logistic regression"),
        )
        resampler = resampler_col.selectbox(
            label="Resampling Method:",
            options=(
                "ADASYN",
                "random undersampling",
                "Borderline SMOTE",
                "k-means SMOTE",
                "SMOTE",
                "no resampling",
            ),
        )

    # horizontal line
    st.divider()

    # Wide uploader next to a narrow submit button, separated by thin
    # spacer columns.
    upload_col, _, submit_job_col, _ = st.columns((4, 0.05, 1, 0.05))

    # upload file column
    file = upload_col.file_uploader(
        label="Upload a CSV, SDF, TXT or SMI file",
        type=["csv", "sdf", "txt", "smi"],
        help="Input molecule file only supports *.csv, *.sdf, *.txt and *.smi.",
        accept_multiple_files=False,
    )

    # submit job column; the two blank text elements presumably act as
    # vertical spacers above the button
    submit_job_col.text(" \n")
    submit_job_col.text(" \n")
    submit_job_col.markdown(
        "<div style='display: flex; justify-content: center;'>",
        unsafe_allow_html=True,
    )
    submit_job_button = submit_job_col.button(
        label="Submit Job", key="submit_job_button", type="secondary"
    )
# Second row: molecular features on the left, predictions on the right.
feature_column, prediction_column = st.columns(2)

feature_column.subheader("Molecular Features")
# Empty slot reserved under the features heading.
placeholder_features = feature_column.empty()

prediction_column.subheader("Predictions")
# Generate predictions when the user submits a job
if submit_job_button:
    if file and mol_features is None and info_df is None:
        # generate_predictions works on a file path, so the upload is written
        # to a temporary directory that is removed once the run finishes.
        with tempfile.TemporaryDirectory() as temp_dir:
            # Use only the base name so the upload cannot escape temp_dir.
            temp_file_path = os.path.join(temp_dir, os.path.basename(file.name))
            with open(temp_file_path, "wb") as temp_file:
                # getvalue() returns the whole buffer regardless of the
                # current read position of the uploaded file.
                temp_file.write(file.getvalue())

            mol_features, info_df, results = generate_predictions(
                input_fname=temp_file_path,
                sep="\\s+|\t+",
                clf=classifiers_dict[classifier],
                sampling=resample_methods_dict[resampler],
                time_per_mol=120,
                mol_features=mol_features,
                info_df=info_df,
            )
        # indicate the success of the job
        st.balloons()
    elif not file:
        st.warning("Please upload a molecule file before submitting a job.")

    # feature table
    with feature_column:
        if mol_features is not None:
            # Only the first rows are rendered; the download holds all rows.
            selected_feature_rows = min(
                mol_features.shape[0], pandas_display_options["line_limit"]
            )
            st.dataframe(
                mol_features.iloc[:selected_feature_rows, :], hide_index=False
            )
            feature_file_name = file.name.split(".")[0] + "_b3clf_features.csv"
            features_csv = mol_features.to_csv(index=True)
            st.download_button(
                "Download features as CSV",
                data=features_csv,
                file_name=feature_file_name,
            )

    # prediction table
    with prediction_column:
        if results is not None:
            # Display the predictions in a table, probabilities shown with
            # six decimal places.
            selected_result_rows = min(
                results.shape[0], pandas_display_options["line_limit"]
            )
            results_df_display = results.iloc[
                :selected_result_rows, :
            ].style.format({"B3clf_predicted_probability": "{:.6f}".format})
            st.dataframe(results_df_display, hide_index=True)
            # Add a button to download the predictions as a CSV file
            predictions_csv = results.to_csv(index=True)
            results_file_name = file.name.split(".")[0] + "_b3clf_predictions.csv"
            st.download_button(
                "Download predictions as CSV",
                data=predictions_csv,
                file_name=results_file_name,
            )
# Hide Streamlit's main menu and footer with injected CSS.
# https://github.com/streamlit/streamlit/issues/892
hide_streamlit_style = """
<style>
#MainMenu {visibility: hidden;}
footer {visibility: hidden;}
</style>
"""
st.markdown(body=hide_streamlit_style, unsafe_allow_html=True)

# Google Analytics tag, injected as raw HTML.
# NOTE(review): confirm that <script> tags passed through st.markdown are
# actually executed by the Streamlit version in use.
google_analytics_snippet = """
<!-- Google tag (gtag.js) -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-WG8QYRELP9"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag(){dataLayer.push(arguments);}
gtag('js', new Date());
gtag('config', 'G-WG8QYRELP9');
</script>
"""
st.markdown(
    google_analytics_snippet,
    unsafe_allow_html=True,
)