Spaces:
Sleeping
Sleeping
legend1234
commited on
Commit
•
9992ded
1
Parent(s):
95b3113
Refactor to use caching
Browse files
app.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
import os
|
2 |
import tempfile
|
3 |
from io import StringIO
|
@@ -5,12 +6,12 @@ from io import StringIO
|
|
5 |
import joblib
|
6 |
import numpy as np
|
7 |
import pandas as pd
|
|
|
8 |
# page set up
|
9 |
import streamlit as st
|
10 |
from b3clf.descriptor_padel import compute_descriptors
|
11 |
from b3clf.geometry_opt import geometry_optimize
|
12 |
-
from b3clf.utils import
|
13 |
-
scale_descriptors, select_descriptors)
|
14 |
# from PIL import Image
|
15 |
from streamlit_extras.let_it_rain import rain
|
16 |
from streamlit_ketcher import st_ketcher
|
@@ -50,6 +51,78 @@ pandas_display_options = {
|
|
50 |
}
|
51 |
mol_features = None
|
52 |
info_df = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
|
54 |
|
55 |
# @st.cache_resource
|
@@ -258,7 +331,7 @@ with prediction_column:
|
|
258 |
|
259 |
# Generate predictions when the user uploads a file
|
260 |
if submit_job_button:
|
261 |
-
if file:
|
262 |
temp_dir = tempfile.mkdtemp()
|
263 |
# Create a temporary file path for the uploaded file
|
264 |
temp_file_path = os.path.join(temp_dir, file.name)
|
@@ -266,59 +339,60 @@ if submit_job_button:
|
|
266 |
with open(temp_file_path, "wb") as temp_file:
|
267 |
temp_file.write(file.read())
|
268 |
# mol_features, results = generate_predictions(temp_file_path)
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
|
274 |
-
|
275 |
-
|
276 |
-
|
277 |
-
|
|
|
278 |
|
279 |
-
|
280 |
-
|
281 |
-
|
282 |
-
|
283 |
-
|
284 |
-
|
285 |
-
|
286 |
-
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
293 |
|
294 |
-
# prediction table
|
295 |
-
with prediction_column:
|
296 |
-
# st.subheader("Predictions")
|
297 |
-
if results is not None:
|
298 |
-
# Display the predictions in a table
|
299 |
-
selected_result_rows = np.min(
|
300 |
-
[results.shape[0], pandas_display_options["line_limit"]]
|
301 |
-
)
|
302 |
-
results_df_display = results.iloc[
|
303 |
-
:selected_result_rows, :
|
304 |
-
].style.format({"B3clf_predicted_probability": "{:.6f}".format})
|
305 |
-
st.dataframe(results_df_display, hide_index=True)
|
306 |
-
# Add a button to download the predictions as a CSV file
|
307 |
-
predictions_csv = results.to_csv(index=True)
|
308 |
-
results_file_name = file.name.split(".")[0] + "_b3clf_predictions.csv"
|
309 |
-
st.download_button(
|
310 |
-
"Download predictions as CSV",
|
311 |
-
data=predictions_csv,
|
312 |
-
file_name=results_file_name,
|
313 |
-
)
|
314 |
-
# indicate the success of the job
|
315 |
-
# rain(
|
316 |
-
# emoji="🎈",
|
317 |
-
# font_size=54,
|
318 |
-
# falling_speed=5,
|
319 |
-
# animation_length=10,
|
320 |
-
# )
|
321 |
-
st.balloons()
|
322 |
|
323 |
# hide footer
|
324 |
# https://github.com/streamlit/streamlit/issues/892
|
|
|
1 |
+
import itertools as it
|
2 |
import os
|
3 |
import tempfile
|
4 |
from io import StringIO
|
|
|
6 |
import joblib
|
7 |
import numpy as np
|
8 |
import pandas as pd
|
9 |
+
import pkg_resources
|
10 |
# page set up
|
11 |
import streamlit as st
|
12 |
from b3clf.descriptor_padel import compute_descriptors
|
13 |
from b3clf.geometry_opt import geometry_optimize
|
14 |
+
from b3clf.utils import get_descriptors, scale_descriptors, select_descriptors
|
|
|
15 |
# from PIL import Image
|
16 |
from streamlit_extras.let_it_rain import rain
|
17 |
from streamlit_ketcher import st_ketcher
|
|
|
51 |
}
|
52 |
mol_features = None
|
53 |
info_df = None
|
54 |
+
results = None
|
55 |
+
temp_file_path = None
|
56 |
+
|
57 |
+
|
58 |
+
@st.cache_resource
def load_all_models():
    """Load every pre-trained b3clf classifier shipped inside the ``b3clf`` package.

    One ``pre_trained/b3clf_<clf>_<sampling>.joblib`` resource is read for each
    combination of the classifier and sampling names listed below.

    The result is cached with ``st.cache_resource`` rather than
    ``st.cache_data``: the fitted models are global objects that are only read
    by callers, so a single shared instance is kept instead of a serialized
    copy being produced for every call.

    Returns:
        dict: maps ``"<clf>_<sampling>"`` to the object returned by
        ``joblib.load`` for that resource.
    """
    clf_list = ("dtree", "knn", "logreg", "xgb")
    sampling_list = (
        "borderline_SMOTE",
        "classic_ADASYN",
        "classic_RandUndersampling",
        "classic_SMOTE",
        "kmeans_SMOTE",
        "common",
    )
    package_name = "b3clf"

    model_dict = {}
    for clf_str, sampling_str in it.product(clf_list, sampling_list):
        # Read through pkg_resources so the lookup is relative to the
        # installed package rather than to the current working directory.
        joblib_path_str = f"pre_trained/b3clf_{clf_str}_{sampling_str}.joblib"
        with pkg_resources.resource_stream(package_name, joblib_path_str) as f:
            model_dict[f"{clf_str}_{sampling_str}"] = joblib.load(f)

    return model_dict
|
85 |
+
|
86 |
+
|
87 |
+
@st.cache_data
def predict_permeability(clf_str, sampling_str, mol_features, info_df, threshold="none"):
    """Compute permeability prediction for given feature data.

    The function is cached with ``st.cache_data`` because it returns a
    DataFrame: each cache hit hands back a fresh copy, so callers that modify
    the result cannot corrupt the cached value.

    Args:
        clf_str: classifier name; combined with ``sampling_str`` as
            ``"<clf_str>_<sampling_str>"`` to pick a model from
            ``load_all_models()``.
        sampling_str: sampling-strategy name (see ``clf_str``).
        mol_features: feature matrix passed to the model's ``predict_proba``.
            When it is a ``pandas.DataFrame`` its index must equal the index
            of ``info_df``.
        info_df: DataFrame that receives the prediction columns. It is
            modified in place (columns added, index reset) when the function
            body actually runs, i.e. on a cache miss.
        threshold: column label looked up in ``data/B3clf_thresholds.xlsx``.

    Returns:
        ``info_df`` with the added columns ``B3clf_predicted_probability`` and
        ``B3clf_predicted_label`` and with its former index moved into a
        regular column by ``reset_index``.

    Raises:
        ValueError: if ``mol_features`` is a DataFrame whose index differs
            from that of ``info_df``.
    """
    # load the model
    pred_model = load_all_models()[clf_str + "_" + sampling_str]

    # load the threshold data
    package_name = "b3clf"
    with pkg_resources.resource_stream(
        package_name, "data/B3clf_thresholds.xlsx"
    ) as f:
        df_thres = pd.read_excel(f, index_col=0, engine="openpyxl")

    # every label starts at 0; rows reaching the threshold are set to 1 below
    label_pool = np.zeros(mol_features.shape[0], dtype=int)

    # isinstance (not an exact type comparison) so DataFrame subclasses are
    # validated as well
    if isinstance(mol_features, pd.DataFrame):
        if mol_features.index.tolist() != info_df.index.tolist():
            raise ValueError(
                "Features_df and Info_df do not have the same index."
            )

    # get predicted probabilities (column 1 of predict_proba's output)
    info_df.loc[:, "B3clf_predicted_probability"] = pred_model.predict_proba(
        mol_features
    )[:, 1]

    # get predicted label from probability using the threshold
    # NOTE(review): the threshold row is fixed to "xgb-classic_ADASYN" no
    # matter which classifier/sampling pair was requested. It presumably
    # should be clf_str + "-" + sampling_str; confirm against the index of
    # B3clf_thresholds.xlsx before changing.
    mask = np.greater_equal(
        info_df["B3clf_predicted_probability"].to_numpy(),
        df_thres.loc["xgb-classic_ADASYN", threshold],
    )
    label_pool[mask] = 1

    # save the predicted labels
    info_df["B3clf_predicted_label"] = label_pool
    info_df.reset_index(inplace=True)

    return info_df
|
126 |
|
127 |
|
128 |
# @st.cache_resource
|
|
|
331 |
|
332 |
# Generate predictions when the user uploads a file
|
333 |
if submit_job_button:
|
334 |
+
if file and mol_features is None and info_df is None:
|
335 |
temp_dir = tempfile.mkdtemp()
|
336 |
# Create a temporary file path for the uploaded file
|
337 |
temp_file_path = os.path.join(temp_dir, file.name)
|
|
|
339 |
with open(temp_file_path, "wb") as temp_file:
|
340 |
temp_file.write(file.read())
|
341 |
# mol_features, results = generate_predictions(temp_file_path)
|
342 |
+
mol_features, info_df, results = generate_predictions(
|
343 |
+
input_fname=temp_file_path,
|
344 |
+
sep="\s+|\t+",
|
345 |
+
clf=classifiers_dict[classifier],
|
346 |
+
sampling=resample_methods_dict[resampler],
|
347 |
+
time_per_mol=120,
|
348 |
+
mol_features=mol_features,
|
349 |
+
info_df=info_df,
|
350 |
+
)
|
351 |
+
st.balloons()
|
352 |
|
353 |
+
# feature table
|
354 |
+
with feature_column:
|
355 |
+
if mol_features is not None:
|
356 |
+
selected_feature_rows = np.min(
|
357 |
+
[mol_features.shape[0], pandas_display_options["line_limit"]]
|
358 |
+
)
|
359 |
+
st.dataframe(mol_features.iloc[:selected_feature_rows, :], hide_index=False)
|
360 |
+
# placeholder_features.dataframe(mol_features, hide_index=False)
|
361 |
+
feature_file_name = file.name.split(".")[0] + "_b3clf_features.csv"
|
362 |
+
features_csv = mol_features.to_csv(index=True)
|
363 |
+
st.download_button(
|
364 |
+
"Download features as CSV",
|
365 |
+
data=features_csv,
|
366 |
+
file_name=feature_file_name,
|
367 |
+
)
|
368 |
+
# prediction table
|
369 |
+
with prediction_column:
|
370 |
+
# st.subheader("Predictions")
|
371 |
+
if results is not None:
|
372 |
+
# Display the predictions in a table
|
373 |
+
selected_result_rows = np.min(
|
374 |
+
[results.shape[0], pandas_display_options["line_limit"]]
|
375 |
+
)
|
376 |
+
results_df_display = results.iloc[
|
377 |
+
:selected_result_rows, :
|
378 |
+
].style.format({"B3clf_predicted_probability": "{:.6f}".format})
|
379 |
+
st.dataframe(results_df_display, hide_index=True)
|
380 |
+
# Add a button to download the predictions as a CSV file
|
381 |
+
predictions_csv = results.to_csv(index=True)
|
382 |
+
results_file_name = file.name.split(".")[0] + "_b3clf_predictions.csv"
|
383 |
+
st.download_button(
|
384 |
+
"Download predictions as CSV",
|
385 |
+
data=predictions_csv,
|
386 |
+
file_name=results_file_name,
|
387 |
+
)
|
388 |
+
# indicate the success of the job
|
389 |
+
# rain(
|
390 |
+
# emoji="🎈",
|
391 |
+
# font_size=54,
|
392 |
+
# falling_speed=5,
|
393 |
+
# animation_length=10,
|
394 |
+
# )
|
395 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
396 |
|
397 |
# hide footer
|
398 |
# https://github.com/streamlit/streamlit/issues/892
|