# b3clf_hf / app.py
# Last commit 438a6f5 by legend1234: "Use relative path for joblib files" (file size 4.9 kB)
import os
import tempfile
from io import StringIO
import joblib
import numpy as np
import pandas as pd
import streamlit as st
from b3clf.descriptor_padel import compute_descriptors
from b3clf.geometry_opt import geometry_optimize
from b3clf.utils import (get_descriptors, predict_permeability,
scale_descriptors, select_descriptors)
from streamlit_ketcher import st_ketcher
# from geometry_opt import geometry_optimize
# Load the pre-trained model and feature scaler.
# The paths are resolved against this script's directory instead of the process
# working directory, so the app also starts when it is launched from elsewhere
# (e.g. ``streamlit run b3clf_hf/app.py`` from the repository's parent folder).
# NOTE(review): neither ``model`` nor ``scaler`` is referenced by the active
# code in this file (predictions go through ``predict_permeability`` with
# classifier/sampling names); only the commented-out draft below uses them.
_APP_DIR = os.path.dirname(os.path.abspath(__file__))
model = joblib.load(
    os.path.join(_APP_DIR, "pre_trained", "b3clf_knn_kmeans_SMOTE.joblib")
)
scaler = joblib.load(
    os.path.join(_APP_DIR, "pre_trained", "b3clf_scaler.joblib")
)
# Define a function to generate predictions
# def generate_predictions(file):
# # Read the input file
# if file.type == "text/csv":
# df = pd.read_csv(file)
# elif file.type == "chemical/x-mdl-sdfile":
# df = pd.read_sdf(file)
# else:
# st.error("Invalid file type. Please upload a CSV or SDF file.")
# return
# # Compute the molecular geometry, calculate the features, and perform the predictions
# X = df.drop("ID", axis=1)
# X_scaled = scaler.transform(X)
# y_pred_proba = model.predict_proba(X_scaled)[:, 1]
# y_pred = model.predict(X_scaled)
# # Create a DataFrame with the predictions
# results = pd.DataFrame({"ID": df["ID"], "B3clf_predicted_probability": y_pred_proba, "B3clf_predicted_label": y_pred})
# return results
# Module-level "no" flags; neither is currently read by the code in this file.
keep_features, keep_sdf = "no", "no"
def generate_predictions(
    uploaded_file: str,
    # "\\s" (escaped backslash) yields exactly the same string as the former
    # "\s" literal, without the invalid-escape-sequence warning; "\t" is a tab.
    sep: str = "\\s+|\t+",
    clf: str = "xgb",
    sampling: str = "classic_ADASYN",
    time_per_mol: int = 120,
):
    """
    Generate BBB permeability predictions for a molecule file on disk.

    The input is geometry-optimized into an intermediate SDF file, molecular
    descriptors are computed from that SDF, the descriptors are selected and
    scaled, and the requested B3clf classifier is applied.

    Parameters
    ----------
    uploaded_file : str
        Path to the input file: either an SDF file with molecular geometries
        or a text file with SMILES strings (passed to ``geometry_optimize``).
    sep : str
        Separator forwarded to ``geometry_optimize``; presumably a regex used
        when parsing a SMILES text file (TODO confirm in b3clf).
    clf : str
        Classifier name forwarded to ``predict_permeability``.
    sampling : str
        Resampling strategy name forwarded to ``predict_permeability``.
    time_per_mol : int
        Forwarded to ``compute_descriptors`` as ``time_per_molecule``.

    Returns
    -------
    tuple
        ``(X_features, result_df)``: the selected and scaled descriptors, and
        the prediction table restricted to whichever of the ID, SMILES,
        predicted-probability and predicted-label columns are present.
    """
    # Name the intermediate SDF after the input and place it next to the input
    # file (the app passes a per-upload temporary path). This keeps concurrent
    # sessions from writing to the same file in the shared working directory.
    # A bare file name still resolves to the working directory, as before.
    input_dir = os.path.dirname(uploaded_file)
    mol_tag = os.path.splitext(os.path.basename(uploaded_file))[0]
    internal_sdf = os.path.join(input_dir, f"{mol_tag}_optimized_3d.sdf")

    try:
        # Geometry optimization: accepts either an SDF file with molecular
        # geometries or a text file with SMILES strings.
        geometry_optimize(input_fname=uploaded_file, output_sdf=internal_sdf, sep=sep)
        df_features = compute_descriptors(
            sdf_file=internal_sdf,
            excel_out=None,
            output_csv=None,
            timeout=None,
            time_per_molecule=time_per_mol,
        )
    finally:
        # Always clean up the intermediate SDF, even when optimization or
        # descriptor calculation raises; it may not exist if the first step failed.
        if os.path.exists(internal_sdf):
            os.remove(internal_sdf)

    # Split the computed table into descriptor values and molecule info
    X_features, info_df = get_descriptors(df=df_features)
    # Keep only the descriptors the classifiers were trained on, then scale them
    X_features = select_descriptors(df=X_features)
    X_features = scale_descriptors(df=X_features)

    # Run the requested classifier / resampling combination
    result_df = predict_permeability(
        clf_str=clf,
        sampling_str=sampling,
        features_df=X_features,
        info_df=info_df,
        threshold="none",
    )

    # Restrict the output to the user-facing columns, in their existing order
    display_cols = [
        "ID",
        "SMILES",
        "B3clf_predicted_probability",
        "B3clf_predicted_label",
    ]
    result_df = result_df[
        [col for col in result_df.columns.to_list() if col in display_cols]
    ]
    return X_features, result_df
# Create the Streamlit app
st.title("BBB Permeability Prediction with Imbalanced Learning")

# File uploader for the input molecules
st.subheader("Input Data")
file = st.file_uploader(
    "Upload a CSV or SDF file",
    type=["csv", "sdf", "txt"],
)

# Generate predictions when the user uploads a file
if file:
    # Echo the uploaded-file object back to the page
    st.write(file)

    # generate_predictions works on a file path, so persist the upload into a
    # temporary directory that is deleted once the predictions are computed
    # (tempfile.mkdtemp() directories were never removed).
    with tempfile.TemporaryDirectory() as temp_dir:
        # basename() drops any directory components from the client-supplied name
        temp_file_path = os.path.join(temp_dir, os.path.basename(file.name))
        with open(temp_file_path, "wb") as temp_file:
            # getvalue() returns the whole upload regardless of the read position
            temp_file.write(file.getvalue())
        X_features, results = generate_predictions(temp_file_path)

    feature_column, prediction_column = st.columns(2)

    # Feature table
    with feature_column:
        st.subheader("Features")
        st.dataframe(X_features)

    # Prediction table
    with prediction_column:
        st.subheader("Predictions")
        if results is not None:
            # Display the predictions in a table
            st.dataframe(results)
            # Offer the predictions as a CSV download named after the upload
            predictions_csv = results.to_csv(index=False)
            results_file_name = (
                os.path.splitext(os.path.basename(file.name))[0]
                + "_b3clf_predictions.csv"
            )
            st.download_button(
                "Download predictions as CSV",
                data=predictions_csv,
                file_name=results_file_name,
            )