legend1234 committed on
Commit
9992ded
1 Parent(s): 95b3113

Refactor to use caching

Browse files
Files changed (1) hide show
  1. app.py +128 -54
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import os
2
  import tempfile
3
  from io import StringIO
@@ -5,12 +6,12 @@ from io import StringIO
5
  import joblib
6
  import numpy as np
7
  import pandas as pd
 
8
  # page set up
9
  import streamlit as st
10
  from b3clf.descriptor_padel import compute_descriptors
11
  from b3clf.geometry_opt import geometry_optimize
12
- from b3clf.utils import (get_descriptors, predict_permeability,
13
- scale_descriptors, select_descriptors)
14
  # from PIL import Image
15
  from streamlit_extras.let_it_rain import rain
16
  from streamlit_ketcher import st_ketcher
@@ -50,6 +51,78 @@ pandas_display_options = {
50
  }
51
  mol_features = None
52
  info_df = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
 
55
  # @st.cache_resource
@@ -258,7 +331,7 @@ with prediction_column:
258
 
259
  # Generate predictions when the user uploads a file
260
  if submit_job_button:
261
- if file:
262
  temp_dir = tempfile.mkdtemp()
263
  # Create a temporary file path for the uploaded file
264
  temp_file_path = os.path.join(temp_dir, file.name)
@@ -266,59 +339,60 @@ if submit_job_button:
266
  with open(temp_file_path, "wb") as temp_file:
267
  temp_file.write(file.read())
268
  # mol_features, results = generate_predictions(temp_file_path)
269
- mol_features, info_df, results = generate_predictions(
270
- input_fname=temp_file_path,
271
- sep="\s+|\t+",
272
- clf=classifiers_dict[classifier],
273
- sampling=resample_methods_dict[resampler],
274
- time_per_mol=120,
275
- mol_features=mol_features,
276
- info_df=info_df,
277
- )
 
278
 
279
- # feature table
280
- with feature_column:
281
- selected_feature_rows = np.min(
282
- [mol_features.shape[0], pandas_display_options["line_limit"]]
283
- )
284
- st.dataframe(mol_features.iloc[:selected_feature_rows, :], hide_index=False)
285
- # placeholder_features.dataframe(mol_features, hide_index=False)
286
- feature_file_name = file.name.split(".")[0] + "_b3clf_features.csv"
287
- features_csv = mol_features.to_csv(index=True)
288
- st.download_button(
289
- "Download features as CSV",
290
- data=features_csv,
291
- file_name=feature_file_name,
292
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
293
 
294
- # prediction table
295
- with prediction_column:
296
- # st.subheader("Predictions")
297
- if results is not None:
298
- # Display the predictions in a table
299
- selected_result_rows = np.min(
300
- [results.shape[0], pandas_display_options["line_limit"]]
301
- )
302
- results_df_display = results.iloc[
303
- :selected_result_rows, :
304
- ].style.format({"B3clf_predicted_probability": "{:.6f}".format})
305
- st.dataframe(results_df_display, hide_index=True)
306
- # Add a button to download the predictions as a CSV file
307
- predictions_csv = results.to_csv(index=True)
308
- results_file_name = file.name.split(".")[0] + "_b3clf_predictions.csv"
309
- st.download_button(
310
- "Download predictions as CSV",
311
- data=predictions_csv,
312
- file_name=results_file_name,
313
- )
314
- # indicate the success of the job
315
- # rain(
316
- # emoji="🎈",
317
- # font_size=54,
318
- # falling_speed=5,
319
- # animation_length=10,
320
- # )
321
- st.balloons()
322
 
323
  # hide footer
324
  # https://github.com/streamlit/streamlit/issues/892
 
1
+ import itertools as it
2
  import os
3
  import tempfile
4
  from io import StringIO
 
6
  import joblib
7
  import numpy as np
8
  import pandas as pd
9
+ import pkg_resources
10
  # page set up
11
  import streamlit as st
12
  from b3clf.descriptor_padel import compute_descriptors
13
  from b3clf.geometry_opt import geometry_optimize
14
+ from b3clf.utils import get_descriptors, scale_descriptors, select_descriptors
 
15
  # from PIL import Image
16
  from streamlit_extras.let_it_rain import rain
17
  from streamlit_ketcher import st_ketcher
 
51
  }
52
  mol_features = None
53
  info_df = None
54
+ results = None
55
+ temp_file_path = None
56
+
57
+
58
@st.cache_resource
def load_all_models():
    """Load every pre-trained b3clf classifier bundled with the ``b3clf`` package.

    One joblib file is read per (classifier, resampling strategy) pair from
    the package's ``pre_trained`` resource directory.

    The result is cached with ``st.cache_resource`` rather than
    ``st.cache_data``: ``cache_data`` pickles the return value and hands back
    a freshly unpickled copy on every call, which is wasteful for a dictionary
    of fitted models. ``cache_resource`` keeps a single shared instance, so
    callers must treat the returned dictionary and models as read-only.

    Returns
    -------
    dict
        Maps ``"<clf>_<sampling>"`` (for example ``"xgb_classic_ADASYN"``) to
        the object returned by ``joblib.load`` for the matching file.
    """
    clf_list = ["dtree", "knn", "logreg", "xgb"]
    sampling_list = [
        "borderline_SMOTE",
        "classic_ADASYN",
        "classic_RandUndersampling",
        "classic_SMOTE",
        "kmeans_SMOTE",
        "common",
    ]
    package_name = "b3clf"

    model_dict = {}
    for clf_str, sampling_str in it.product(clf_list, sampling_list):
        # Read through the package resource API so the lookup does not depend
        # on where (or how) b3clf is installed.
        joblib_path_str = f"pre_trained/b3clf_{clf_str}_{sampling_str}.joblib"
        with pkg_resources.resource_stream(package_name, joblib_path_str) as f:
            model_dict[f"{clf_str}_{sampling_str}"] = joblib.load(f)

    return model_dict
85
+
86
+
87
@st.cache_data
def predict_permeability(clf_str, sampling_str, mol_features, info_df, threshold="none"):
    """Compute permeability predictions for the given feature data.

    The function is cached with ``st.cache_data`` because it returns a
    DataFrame; each caller then gets its own copy instead of sharing one
    mutable object. The input ``info_df`` is not modified: predictions are
    written to a copy, so calling the function again with the same frame
    (for example with a different classifier) still passes the index check.

    Parameters
    ----------
    clf_str : str
        Classifier part of the model key, e.g. ``"xgb"``.
    sampling_str : str
        Resampling part of the model key, e.g. ``"classic_ADASYN"``.
    mol_features : pandas.DataFrame or array-like
        Feature matrix passed to the model's ``predict_proba``; must expose
        ``.shape``.
    info_df : pandas.DataFrame
        Per-molecule information; when ``mol_features`` is a DataFrame its
        index must equal that of ``info_df``.
    threshold : str, optional
        Column of the thresholds spreadsheet used to turn probabilities into
        labels.

    Returns
    -------
    pandas.DataFrame
        Copy of ``info_df`` with ``B3clf_predicted_probability`` and
        ``B3clf_predicted_label`` columns added and the index reset into a
        regular column.

    Raises
    ------
    ValueError
        If ``mol_features`` is a DataFrame whose index differs from
        ``info_df``'s.
    """
    # Look up the fitted model in the cached model dictionary.
    pred_model = load_all_models()[clf_str + "_" + sampling_str]

    # Load the threshold table shipped with the b3clf package.
    package_name = "b3clf"
    with pkg_resources.resource_stream(
        package_name, "data/B3clf_thresholds.xlsx"
    ) as f:
        df_thres = pd.read_excel(f, index_col=0, engine="openpyxl")

    if isinstance(mol_features, pd.DataFrame):
        if mol_features.index.tolist() != info_df.index.tolist():
            raise ValueError(
                "Features_df and Info_df do not have the same index."
            )

    # Work on a copy so the caller's frame keeps its original index/columns.
    result_df = info_df.copy()

    # Probability of the positive class (second column of predict_proba).
    result_df.loc[:, "B3clf_predicted_probability"] = pred_model.predict_proba(
        mol_features
    )[:, 1]

    # Every molecule starts with label 0; those at or above the threshold
    # are switched to 1.
    # NOTE(review): the threshold row is fixed to "xgb-classic_ADASYN" for
    # every model instead of clf_str + "-" + sampling_str -- confirm this is
    # intentional before changing it.
    label_pool = np.zeros(mol_features.shape[0], dtype=int)
    mask = np.greater_equal(
        result_df["B3clf_predicted_probability"].to_numpy(),
        df_thres.loc["xgb-classic_ADASYN", threshold],
    )
    label_pool[mask] = 1

    # Store the predicted labels and expose the index as a column.
    result_df["B3clf_predicted_label"] = label_pool
    result_df.reset_index(inplace=True)

    return result_df
126
 
127
 
128
  # @st.cache_resource
 
331
 
332
  # Generate predictions when the user uploads a file
333
  if submit_job_button:
334
+ if file and mol_features is None and info_df is None:
335
  temp_dir = tempfile.mkdtemp()
336
  # Create a temporary file path for the uploaded file
337
  temp_file_path = os.path.join(temp_dir, file.name)
 
339
  with open(temp_file_path, "wb") as temp_file:
340
  temp_file.write(file.read())
341
  # mol_features, results = generate_predictions(temp_file_path)
342
+ mol_features, info_df, results = generate_predictions(
343
+ input_fname=temp_file_path,
344
+ sep="\s+|\t+",
345
+ clf=classifiers_dict[classifier],
346
+ sampling=resample_methods_dict[resampler],
347
+ time_per_mol=120,
348
+ mol_features=mol_features,
349
+ info_df=info_df,
350
+ )
351
+ st.balloons()
352
 
353
+ # feature table
354
+ with feature_column:
355
+ if mol_features is not None:
356
+ selected_feature_rows = np.min(
357
+ [mol_features.shape[0], pandas_display_options["line_limit"]]
358
+ )
359
+ st.dataframe(mol_features.iloc[:selected_feature_rows, :], hide_index=False)
360
+ # placeholder_features.dataframe(mol_features, hide_index=False)
361
+ feature_file_name = file.name.split(".")[0] + "_b3clf_features.csv"
362
+ features_csv = mol_features.to_csv(index=True)
363
+ st.download_button(
364
+ "Download features as CSV",
365
+ data=features_csv,
366
+ file_name=feature_file_name,
367
+ )
368
+ # prediction table
369
+ with prediction_column:
370
+ # st.subheader("Predictions")
371
+ if results is not None:
372
+ # Display the predictions in a table
373
+ selected_result_rows = np.min(
374
+ [results.shape[0], pandas_display_options["line_limit"]]
375
+ )
376
+ results_df_display = results.iloc[
377
+ :selected_result_rows, :
378
+ ].style.format({"B3clf_predicted_probability": "{:.6f}".format})
379
+ st.dataframe(results_df_display, hide_index=True)
380
+ # Add a button to download the predictions as a CSV file
381
+ predictions_csv = results.to_csv(index=True)
382
+ results_file_name = file.name.split(".")[0] + "_b3clf_predictions.csv"
383
+ st.download_button(
384
+ "Download predictions as CSV",
385
+ data=predictions_csv,
386
+ file_name=results_file_name,
387
+ )
388
+ # indicate the success of the job
389
+ # rain(
390
+ # emoji="🎈",
391
+ # font_size=54,
392
+ # falling_speed=5,
393
+ # animation_length=10,
394
+ # )
395
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
396
 
397
  # hide footer
398
  # https://github.com/streamlit/streamlit/issues/892