Spaces:
Sleeping
Sleeping
legend1234
committed on
Commit
•
3d6fbe8
1
Parent(s):
cf4c3c3
Add utils.py
Browse files
utils.py
ADDED
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import itertools as it
|
2 |
+
import os
|
3 |
+
|
4 |
+
import joblib
|
5 |
+
import numpy as np
|
6 |
+
import pandas as pd
|
7 |
+
import pkg_resources
|
8 |
+
import streamlit as st
|
9 |
+
from b3clf.descriptor_padel import compute_descriptors
|
10 |
+
from b3clf.geometry_opt import geometry_optimize
|
11 |
+
from b3clf.utils import get_descriptors, scale_descriptors, select_descriptors
|
12 |
+
|
13 |
+
|
14 |
+
@st.cache_resource()
def load_all_models():
    """Load every pre-trained classifier shipped inside the ``b3clf`` package.

    One model is read for each (classifier, sampling strategy) pair from the
    packaged resource ``pre_trained/b3clf_<classifier>_<sampling>.joblib``.

    Returns
    -------
    dict
        Maps the key ``"<classifier>_<sampling>"`` to whatever object
        ``joblib.load`` returns for the matching resource.
    """
    package_name = "b3clf"
    classifier_names = ("dtree", "knn", "logreg", "xgb")
    sampling_names = (
        "borderline_SMOTE",
        "classic_ADASYN",
        "classic_RandUndersampling",
        "classic_SMOTE",
        "kmeans_SMOTE",
        "common",
    )

    loaded_models = {}
    # Classifier is the outer loop and sampling the inner one, so the
    # insertion order of the dictionary matches a cartesian product.
    for clf_name in classifier_names:
        for sampling_name in sampling_names:
            resource_path = f"pre_trained/b3clf_{clf_name}_{sampling_name}.joblib"
            with pkg_resources.resource_stream(package_name, resource_path) as stream:
                fitted_model = joblib.load(stream)
            loaded_models[f"{clf_name}_{sampling_name}"] = fitted_model

    return loaded_models
|
41 |
+
|
42 |
+
|
43 |
+
@st.cache_resource
def predict_permeability(
    clf_str, sampling_str, _models_dict, mol_features, info_df, threshold="none"
):
    """Compute permeability probabilities and labels for the given features.

    Parameters
    ----------
    clf_str : str
        Classifier name; joined with ``sampling_str`` as
        ``"<clf_str>_<sampling_str>"`` to look the model up in ``_models_dict``.
    sampling_str : str
        Sampling-strategy name (second half of the model key).
    _models_dict : dict
        Mapping from model key to a fitted estimator exposing
        ``predict_proba``. The leading underscore tells Streamlit not to hash
        this argument when building the cache key.
    mol_features : pandas.DataFrame or array-like
        Feature matrix handed to ``predict_proba``; must expose ``shape``.
    info_df : pandas.DataFrame
        Per-molecule information. It is modified in place: the columns
        ``B3clf_predicted_probability`` and ``B3clf_predicted_label`` are
        written and the index is reset.
    threshold : str, optional
        Column of the packaged threshold table (``data/B3clf_thresholds.xlsx``)
        from which the probability cut-off is read. Defaults to ``"none"``.

    Returns
    -------
    pandas.DataFrame
        The same ``info_df`` object that was passed in, after modification.

    Raises
    ------
    ValueError
        If ``mol_features`` is a DataFrame whose index differs from the index
        of ``info_df``.
    """
    pred_model = _models_dict[clf_str + "_" + sampling_str]

    # Load the table of probability thresholds bundled with the package.
    package_name = "b3clf"
    with pkg_resources.resource_stream(package_name, "data/B3clf_thresholds.xlsx") as f:
        df_thres = pd.read_excel(f, index_col=0, engine="openpyxl")

    # Every molecule starts with label 0; entries at or above the cut-off are
    # switched to 1 further down.
    label_pool = np.zeros(mol_features.shape[0], dtype=int)

    # isinstance (rather than an exact type comparison) so DataFrame
    # subclasses are validated as well.
    if isinstance(mol_features, pd.DataFrame):
        if mol_features.index.tolist() != info_df.index.tolist():
            raise ValueError("Features_df and Info_df do not have the same index.")

    # Column 1 of predict_proba is stored as the predicted probability.
    info_df.loc[:, "B3clf_predicted_probability"] = pred_model.predict_proba(
        mol_features
    )[:, 1]

    # NOTE(review): the threshold row is fixed to "xgb-classic_ADASYN" no matter
    # which classifier/sampling pair was requested. Presumably the "none"
    # column holds the same default (0.5) for every row -- confirm before
    # switching to a per-model lookup.
    cutoff = df_thres.loc["xgb-classic_ADASYN", threshold]
    mask = np.greater_equal(
        info_df["B3clf_predicted_probability"].to_numpy(),
        cutoff,
    )
    label_pool[mask] = 1

    # Store the labels and turn the index back into a regular column.
    info_df["B3clf_predicted_label"] = label_pool
    info_df.reset_index(inplace=True)

    return info_df
|
81 |
+
|
82 |
+
|
83 |
+
@st.cache_resource
def generate_predictions(
    input_fname: str = None,
    sep: str = r"\s+|\t+",
    clf: str = "xgb",
    _models_dict: dict = None,
    keep_sdf: str = "no",
    sampling: str = "classic_ADASYN",
    time_per_mol: int = 120,
    mol_features: pd.DataFrame = None,
    info_df: pd.DataFrame = None,
):
    """Generate permeability predictions for an input file or cached features.

    When ``mol_features`` and ``info_df`` are both omitted, the molecules in
    ``input_fname`` are geometry-optimized into a temporary SDF file,
    descriptors are computed, selected and scaled, and the result is fed to
    the classifier. When both are supplied, that preprocessing is skipped.

    Parameters
    ----------
    input_fname : str, optional
        Path of the input file (an SDF file with molecular geometries or a
        text file with SMILES strings). Required when no features are given.
    sep : str, optional
        Separator regular expression forwarded to ``geometry_optimize``.
    clf : str, optional
        Classifier name, forwarded to ``predict_permeability``.
    _models_dict : dict, optional
        Mapping of model keys to fitted models, forwarded to
        ``predict_permeability``. The leading underscore tells Streamlit not
        to hash this argument when building the cache key.
    keep_sdf : str, optional
        The temporary SDF file is deleted only when this equals ``"no"``.
    sampling : str, optional
        Sampling-strategy name, forwarded to ``predict_permeability``.
    time_per_mol : int, optional
        Forwarded to ``compute_descriptors`` as ``time_per_molecule``.
    mol_features : pandas.DataFrame, optional
        Already prepared descriptors; must be given together with ``info_df``.
    info_df : pandas.DataFrame, optional
        Molecule information matching ``mol_features``.

    Returns
    -------
    tuple
        ``(mol_features, info_df, result_df)`` where ``result_df`` holds only
        the columns among ``ID``, ``SMILES``, ``B3clf_predicted_probability``
        and ``B3clf_predicted_label`` that are present in the prediction frame.

    Raises
    ------
    ValueError
        If only one of ``mol_features`` / ``info_df`` is supplied, or if
        neither features nor ``input_fname`` is supplied.
    """
    # Features and molecule info only make sense as a pair; fail early with a
    # clear message instead of an attribute error deep inside the prediction.
    if (mol_features is None) != (info_df is None):
        raise ValueError("mol_features and info_df must be provided together.")

    if mol_features is None:
        if input_fname is None:
            raise ValueError(
                "input_fname is required when mol_features and info_df are not given."
            )

        mol_tag = os.path.basename(input_fname).split(".")[0]
        internal_sdf = f"{mol_tag}_optimized_3d.sdf"

        try:
            # Geometry optimization writes the 3D structures to internal_sdf.
            geometry_optimize(input_fname=input_fname, output_sdf=internal_sdf, sep=sep)

            df_features = compute_descriptors(
                sdf_file=internal_sdf,
                excel_out=None,
                output_csv=None,
                timeout=None,
                time_per_molecule=time_per_mol,
            )

            # Split the computed table into descriptors and molecule info.
            mol_features, info_df = get_descriptors(df=df_features)

            # Select descriptors
            mol_features = select_descriptors(df=mol_features)

            # Scale descriptors, writing the values back into the same frame.
            mol_features.iloc[:, :] = scale_descriptors(df=mol_features)
        finally:
            # Remove the temporary SDF even when a step above failed, so
            # aborted runs do not leave files behind.
            # NOTE(review): the file name depends only on the input base name,
            # so runs on identically named inputs share this file.
            if keep_sdf == "no" and os.path.exists(internal_sdf):
                os.remove(internal_sdf)

    # Run the selected classifier on the prepared descriptors.
    result_df = predict_permeability(
        clf_str=clf,
        sampling_str=sampling,
        _models_dict=_models_dict,
        mol_features=mol_features,
        info_df=info_df,
        threshold="none",
    )

    # Keep only the columns meant for display, in their existing order.
    display_cols = [
        "ID",
        "SMILES",
        "B3clf_predicted_probability",
        "B3clf_predicted_label",
    ]

    result_df = result_df[
        [col for col in result_df.columns.to_list() if col in display_cols]
    ]

    return mol_features, info_df, result_df
|