Spaces:
Running
Running
path fixed
Browse files- README.md +4 -0
- app.py +9 -4
- models/fm4m.py +9 -222
- representation/esol_smi-ted.pkl +2 -2
README.md
CHANGED
@@ -8,6 +8,10 @@ sdk_version: 5.4.0
|
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: apache-2.0
|
|
|
|
|
|
|
|
|
11 |
---
|
12 |
|
13 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: apache-2.0
|
11 |
+
models:
|
12 |
+
- ibm/materials.smi-ted
|
13 |
+
- ibm/materials.selfies-ted
|
14 |
+
- ibm/materials.mhg-ged
|
15 |
---
|
16 |
|
17 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
CHANGED
@@ -35,11 +35,16 @@ theme = gr.themes.Default().set(
|
|
35 |
text_color="#FFFFFF", # テキスト色を白に設定
|
36 |
)
|
37 |
"""
|
38 |
-
|
39 |
import sys
|
40 |
sys.path.append("models")
|
41 |
sys.path.append("../models")
|
42 |
-
sys.path.append("../")
|
|
|
|
|
|
|
|
|
|
|
43 |
|
44 |
import models.fm4m as fm4m
|
45 |
|
@@ -388,8 +393,8 @@ def display_plot(plot_type):
|
|
388 |
|
389 |
# Predefined dataset paths (these should be adjusted to your file paths)
|
390 |
predefined_datasets = {
|
391 |
-
"Bace": f"data/bace/train.csv, data/bace/test.csv, smiles, Class",
|
392 |
-
"ESOL": f"data/esol/train.csv, data/esol/test.csv, smiles, prop",
|
393 |
}
|
394 |
|
395 |
|
|
|
35 |
text_color="#FFFFFF", # テキスト色を白に設定
|
36 |
)
|
37 |
"""
|
38 |
+
"""
|
39 |
import sys
|
40 |
sys.path.append("models")
|
41 |
sys.path.append("../models")
|
42 |
+
sys.path.append("../")"""
|
43 |
+
|
44 |
+
|
45 |
+
# Get the current file's directory
|
46 |
+
base_dir = os.path.dirname(__file__)
|
47 |
+
print("Base Dir : ", base_dir)
|
48 |
|
49 |
import models.fm4m as fm4m
|
50 |
|
|
|
393 |
|
394 |
# Predefined dataset paths (these should be adjusted to your file paths)
|
395 |
predefined_datasets = {
|
396 |
+
"Bace": f"./data/bace/train.csv, ./data/bace/test.csv, smiles, Class",
|
397 |
+
"ESOL": f"./data/esol/train.csv, ./data/esol/test.csv, smiles, prop",
|
398 |
}
|
399 |
|
400 |
|
models/fm4m.py
CHANGED
@@ -25,12 +25,9 @@ from sklearn.preprocessing import MinMaxScaler
|
|
25 |
import torch
|
26 |
from transformers import AutoTokenizer, AutoModel
|
27 |
|
28 |
-
import
|
29 |
-
|
30 |
-
|
31 |
-
from models.selfies_model.load import SELFIES as bart
|
32 |
-
from models.mhg_model import load as mhg
|
33 |
-
from models.smi_ted.smi_ted_light.load import load_smi_ted
|
34 |
|
35 |
datasets = {}
|
36 |
models = {}
|
@@ -181,175 +178,6 @@ def update_downstream_model_list(list_model):
|
|
181 |
|
182 |
avail_models_data()
|
183 |
|
184 |
-
def list_models():
|
185 |
-
#print(*list(models.keys()),sep='\n')
|
186 |
-
data = avail_models(raw=True)
|
187 |
-
# Convert data to a pandas DataFrame
|
188 |
-
df = pd.DataFrame(data)
|
189 |
-
|
190 |
-
# Add a column for displaying row numbers starting from 1
|
191 |
-
df.index += 1
|
192 |
-
|
193 |
-
# Create dropdown widget for sorting
|
194 |
-
sort_dropdown = widgets.Dropdown(
|
195 |
-
options=['Name', 'Timestamp'],
|
196 |
-
value='Name',
|
197 |
-
description='Sort by:',
|
198 |
-
disabled=False,
|
199 |
-
)
|
200 |
-
|
201 |
-
# Output widget to display the table
|
202 |
-
output = widgets.Output()
|
203 |
-
|
204 |
-
# Define function to update display based on sorting
|
205 |
-
def update_display(change):
|
206 |
-
with output:
|
207 |
-
output.clear_output(wait=True)
|
208 |
-
sorted_df = df.sort_values(by=sort_dropdown.value)
|
209 |
-
display(sorted_df.style.set_properties(**{
|
210 |
-
'text-align': 'left', 'border': '1px solid #ddd',
|
211 |
-
}))
|
212 |
-
|
213 |
-
# Attach the update_display function to the dropdown widget
|
214 |
-
sort_dropdown.observe(update_display, names='value')
|
215 |
-
|
216 |
-
# Display the dropdown and the table initially
|
217 |
-
display(sort_dropdown, output)
|
218 |
-
update_display(None)
|
219 |
-
|
220 |
-
def list_downstream_models():
|
221 |
-
#print(*list(models.keys()),sep='\n')
|
222 |
-
data = avail_downstream_models()
|
223 |
-
# Convert data to a pandas DataFrame
|
224 |
-
df = pd.DataFrame(data)
|
225 |
-
|
226 |
-
# Add a column for displaying row numbers starting from 1
|
227 |
-
df.index += 1
|
228 |
-
|
229 |
-
# Create dropdown widget for sorting
|
230 |
-
sort_dropdown = widgets.Dropdown(
|
231 |
-
options=['Name', 'Timestamp'],
|
232 |
-
value='Timestamp',
|
233 |
-
description='Sort by:',
|
234 |
-
disabled=False,
|
235 |
-
)
|
236 |
-
|
237 |
-
# Output widget to display the table
|
238 |
-
output = widgets.Output()
|
239 |
-
|
240 |
-
# Define function to update display based on sorting
|
241 |
-
def update_display(change):
|
242 |
-
with output:
|
243 |
-
output.clear_output(wait=True)
|
244 |
-
sorted_df = df.sort_values(by=sort_dropdown.value)
|
245 |
-
display(sorted_df.style.set_properties(**{
|
246 |
-
'text-align': 'left', 'border': '1px solid #ddd',
|
247 |
-
}))
|
248 |
-
|
249 |
-
# Attach the update_display function to the dropdown widget
|
250 |
-
sort_dropdown.observe(update_display, names='value')
|
251 |
-
|
252 |
-
# Display the dropdown and the table initially
|
253 |
-
display(sort_dropdown, output)
|
254 |
-
update_display(None)
|
255 |
-
|
256 |
-
def list_data():
|
257 |
-
|
258 |
-
#print(*list(datasets.keys()),sep='\n')
|
259 |
-
data = avail_datasets()
|
260 |
-
# Convert data to a pandas DataFrame
|
261 |
-
df = pd.DataFrame(data)
|
262 |
-
|
263 |
-
# Add a column for displaying row numbers starting from 1
|
264 |
-
df.index += 1
|
265 |
-
|
266 |
-
# Create dropdown widget for sorting
|
267 |
-
sort_dropdown = widgets.Dropdown(
|
268 |
-
options=['Dataset', 'Input', 'Output', 'Path', 'Timestamp'],
|
269 |
-
value='Input',
|
270 |
-
description='Sort by:',
|
271 |
-
disabled=False,
|
272 |
-
)
|
273 |
-
|
274 |
-
# Output widget to display the table
|
275 |
-
output = widgets.Output()
|
276 |
-
|
277 |
-
# Define function to update display based on sorting
|
278 |
-
def update_display(change):
|
279 |
-
with output:
|
280 |
-
output.clear_output(wait=True)
|
281 |
-
sorted_df = df.sort_values(by=sort_dropdown.value)
|
282 |
-
display(sorted_df.style.set_properties(**{
|
283 |
-
'text-align': 'left', 'border': '1px solid #ddd',
|
284 |
-
}))
|
285 |
-
|
286 |
-
# Attach the update_display function to the dropdown widget
|
287 |
-
sort_dropdown.observe(update_display, names='value')
|
288 |
-
|
289 |
-
# Display the dropdown and the table initially
|
290 |
-
display(sort_dropdown, output)
|
291 |
-
update_display(None)
|
292 |
-
|
293 |
-
def vizualize(roc_auc,fpr, tpr, features, labels):
|
294 |
-
#def vizualize(features, labels):
|
295 |
-
|
296 |
-
reducer = umap.UMAP(metric="jaccard", n_neighbors=20, n_components=2, low_memory=True, min_dist=0.001, verbose=False)
|
297 |
-
|
298 |
-
features_umap = reducer.fit_transform(features)
|
299 |
-
x = labels.values
|
300 |
-
index_0 = [index for index in range(len(x)) if x[index] == 0]
|
301 |
-
index_1 = [index for index in range(len(x)) if x[index] == 1]
|
302 |
-
|
303 |
-
class_0 = features_umap[index_0]
|
304 |
-
class_1 = features_umap[index_1]
|
305 |
-
|
306 |
-
|
307 |
-
# Function to create ROC AUC plot
|
308 |
-
def plot_roc_auc():
|
309 |
-
plt.figure(figsize=(8, 6))
|
310 |
-
plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc:.4f})')
|
311 |
-
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
|
312 |
-
plt.xlim([0.0, 1.0])
|
313 |
-
plt.ylim([0.0, 1.05])
|
314 |
-
plt.xlabel('False Positive Rate')
|
315 |
-
plt.ylabel('True Positive Rate')
|
316 |
-
plt.title('Receiver Operating Characteristic')
|
317 |
-
plt.legend(loc='lower right')
|
318 |
-
plt.show()
|
319 |
-
|
320 |
-
# Function to create scatter plot of the dataset distribution
|
321 |
-
def plot_distribution():
|
322 |
-
plt.figure(figsize=(8, 6))
|
323 |
-
#plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.coolwarm, edgecolors='k')
|
324 |
-
plt.scatter(class_1[:, 0], class_1[:, 1], c='red', label='Class 1')
|
325 |
-
plt.scatter(class_0[:, 0], class_0[:, 1], c='blue', label='Class 0')
|
326 |
-
|
327 |
-
plt.xlabel('Feature 1')
|
328 |
-
plt.ylabel('Feature 2')
|
329 |
-
plt.title('Dataset Distribution')
|
330 |
-
plt.show()
|
331 |
-
|
332 |
-
|
333 |
-
|
334 |
-
# Create tabs using ipywidgets
|
335 |
-
tab_contents = ['ROC AUC', 'Distribution']
|
336 |
-
children = [widgets.Output(), widgets.Output()]
|
337 |
-
|
338 |
-
tab = widgets.Tab()
|
339 |
-
tab.children = children
|
340 |
-
for i in range(len(tab_contents)):
|
341 |
-
tab.set_title(i, tab_contents[i])
|
342 |
-
|
343 |
-
# Display plots in their respective tabs
|
344 |
-
with children[0]:
|
345 |
-
plot_roc_auc()
|
346 |
-
|
347 |
-
with children[1]:
|
348 |
-
plot_distribution()
|
349 |
-
|
350 |
-
# Display the tab widget
|
351 |
-
display(tab)
|
352 |
-
|
353 |
def get_representation(train_data,test_data,model_type, return_tensor=True):
|
354 |
alias = {"MHG-GED": "mhg", "SELFIES-TED": "bart", "MolFormer": "mol-xl", "Molformer": "mol-xl", "SMI-TED": "smi-ted"}
|
355 |
if model_type in alias.keys():
|
@@ -434,7 +262,7 @@ def single_modal(model,dataset, downstream_model,params):
|
|
434 |
|
435 |
if dataset in list(df["Dataset"].values):
|
436 |
task = dataset
|
437 |
-
with open(f"representation/{task}_{model_type}.pkl", "rb") as f1:
|
438 |
x_batch, y_batch, x_batch_test, y_batch_test = pickle.load(f1)
|
439 |
print(f" Representation loaded successfully")
|
440 |
else:
|
@@ -472,7 +300,7 @@ def single_modal(model,dataset, downstream_model,params):
|
|
472 |
print(f"ROC-AUC Score: {roc_auc:.4f}")
|
473 |
|
474 |
try:
|
475 |
-
with open(f"plot_emb/{task}_{model_type}.pkl", "rb") as f1:
|
476 |
class_0,class_1 = pickle.load(f1)
|
477 |
except:
|
478 |
print("Generating latent plots")
|
@@ -505,7 +333,7 @@ def single_modal(model,dataset, downstream_model,params):
|
|
505 |
print(f"ROC-AUC Score: {roc_auc:.4f}")
|
506 |
|
507 |
try:
|
508 |
-
with open(f"plot_emb/{task}_{model_type}.pkl", "rb") as f1:
|
509 |
class_0,class_1 = pickle.load(f1)
|
510 |
except:
|
511 |
print("Generating latent plots")
|
@@ -673,7 +501,7 @@ def multi_modal(model_list,dataset, downstream_model,params):
|
|
673 |
|
674 |
if i == 0:
|
675 |
if predefined:
|
676 |
-
with open(f"representation/{task}_{model_type}.pkl", "rb") as f1:
|
677 |
x_batch, y_batch, x_batch_test, y_batch_test = pickle.load(f1)
|
678 |
print(f" Loaded representation/{task}_{model_type}.pkl")
|
679 |
else:
|
@@ -683,7 +511,7 @@ def multi_modal(model_list,dataset, downstream_model,params):
|
|
683 |
|
684 |
else:
|
685 |
if predefined:
|
686 |
-
with open(f"representation/{task}_{model_type}.pkl", "rb") as f1:
|
687 |
x_batch_1, y_batch_1, x_batch_test_1, y_batch_test_1 = pickle.load(f1)
|
688 |
print(f" Loaded representation/{task}_{model_type}.pkl")
|
689 |
else:
|
@@ -708,7 +536,7 @@ def multi_modal(model_list,dataset, downstream_model,params):
|
|
708 |
|
709 |
print(f"Representations loaded successfully")
|
710 |
try:
|
711 |
-
with open(f"plot_emb/{task}_multi.pkl", "rb") as f1:
|
712 |
class_0, class_1 = pickle.load(f1)
|
713 |
except:
|
714 |
print("Generating latent plots")
|
@@ -830,47 +658,6 @@ def multi_modal(model_list,dataset, downstream_model,params):
|
|
830 |
|
831 |
|
832 |
|
833 |
-
def finetune_optuna(x_batch,y_batch, x_batch_test, y_test ):
|
834 |
-
print(f" Finetuning with Optuna and calculating ROC AUC Score ...")
|
835 |
-
X_train = x_batch.values
|
836 |
-
y_train = y_batch.values
|
837 |
-
X_test = x_batch_test.values
|
838 |
-
y_test = y_test.values
|
839 |
-
def objective(trial):
|
840 |
-
# Define parameters to be optimized
|
841 |
-
params = {
|
842 |
-
# 'objective': 'binary:logistic',
|
843 |
-
'eval_metric': 'auc',
|
844 |
-
'verbosity': 0,
|
845 |
-
'n_estimators': trial.suggest_int('n_estimators', 1000, 10000),
|
846 |
-
# 'booster': trial.suggest_categorical('booster', ['gbtree', 'gblinear', 'dart']),
|
847 |
-
# 'lambda': trial.suggest_loguniform('lambda', 1e-8, 1.0),
|
848 |
-
'alpha': trial.suggest_loguniform('alpha', 1e-8, 1.0),
|
849 |
-
'max_depth': trial.suggest_int('max_depth', 1, 12),
|
850 |
-
# 'eta': trial.suggest_loguniform('eta', 1e-8, 1.0),
|
851 |
-
# 'gamma': trial.suggest_loguniform('gamma', 1e-8, 1.0),
|
852 |
-
# 'grow_policy': trial.suggest_categorical('grow_policy', ['depthwise', 'lossguide']),
|
853 |
-
# "subsample": trial.suggest_float("subsample", 0.05, 1.0),
|
854 |
-
# "colsample_bytree": trial.suggest_float("colsample_bytree", 0.05, 1.0),
|
855 |
-
}
|
856 |
-
|
857 |
-
# Train XGBoost model
|
858 |
-
dtrain = xgb.DMatrix(X_train, label=y_train)
|
859 |
-
dtest = xgb.DMatrix(X_test, label=y_test)
|
860 |
-
|
861 |
-
model = xgb.train(params, dtrain)
|
862 |
-
|
863 |
-
# Predict probabilities
|
864 |
-
y_pred = model.predict(dtest)
|
865 |
-
|
866 |
-
# Calculate ROC AUC score
|
867 |
-
roc_auc = roc_auc_score(y_test, y_pred)
|
868 |
-
print("ROC_AUC : ", roc_auc)
|
869 |
-
|
870 |
-
return roc_auc
|
871 |
-
|
872 |
-
|
873 |
-
|
874 |
|
875 |
|
876 |
|
|
|
25 |
import torch
|
26 |
from transformers import AutoTokenizer, AutoModel
|
27 |
|
28 |
+
from .selfies_model.load import SELFIES as bart
|
29 |
+
from .mhg_model import load as mhg
|
30 |
+
from .smi_ted.smi_ted_light.load import load_smi_ted
|
|
|
|
|
|
|
31 |
|
32 |
datasets = {}
|
33 |
models = {}
|
|
|
178 |
|
179 |
avail_models_data()
|
180 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
181 |
def get_representation(train_data,test_data,model_type, return_tensor=True):
|
182 |
alias = {"MHG-GED": "mhg", "SELFIES-TED": "bart", "MolFormer": "mol-xl", "Molformer": "mol-xl", "SMI-TED": "smi-ted"}
|
183 |
if model_type in alias.keys():
|
|
|
262 |
|
263 |
if dataset in list(df["Dataset"].values):
|
264 |
task = dataset
|
265 |
+
with open(f"./representation/{task}_{model_type}.pkl", "rb") as f1:
|
266 |
x_batch, y_batch, x_batch_test, y_batch_test = pickle.load(f1)
|
267 |
print(f" Representation loaded successfully")
|
268 |
else:
|
|
|
300 |
print(f"ROC-AUC Score: {roc_auc:.4f}")
|
301 |
|
302 |
try:
|
303 |
+
with open(f"./plot_emb/{task}_{model_type}.pkl", "rb") as f1:
|
304 |
class_0,class_1 = pickle.load(f1)
|
305 |
except:
|
306 |
print("Generating latent plots")
|
|
|
333 |
print(f"ROC-AUC Score: {roc_auc:.4f}")
|
334 |
|
335 |
try:
|
336 |
+
with open(f"./plot_emb/{task}_{model_type}.pkl", "rb") as f1:
|
337 |
class_0,class_1 = pickle.load(f1)
|
338 |
except:
|
339 |
print("Generating latent plots")
|
|
|
501 |
|
502 |
if i == 0:
|
503 |
if predefined:
|
504 |
+
with open(f"./representation/{task}_{model_type}.pkl", "rb") as f1:
|
505 |
x_batch, y_batch, x_batch_test, y_batch_test = pickle.load(f1)
|
506 |
print(f" Loaded representation/{task}_{model_type}.pkl")
|
507 |
else:
|
|
|
511 |
|
512 |
else:
|
513 |
if predefined:
|
514 |
+
with open(f"./representation/{task}_{model_type}.pkl", "rb") as f1:
|
515 |
x_batch_1, y_batch_1, x_batch_test_1, y_batch_test_1 = pickle.load(f1)
|
516 |
print(f" Loaded representation/{task}_{model_type}.pkl")
|
517 |
else:
|
|
|
536 |
|
537 |
print(f"Representations loaded successfully")
|
538 |
try:
|
539 |
+
with open(f"./plot_emb/{task}_multi.pkl", "rb") as f1:
|
540 |
class_0, class_1 = pickle.load(f1)
|
541 |
except:
|
542 |
print("Generating latent plots")
|
|
|
658 |
|
659 |
|
660 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
661 |
|
662 |
|
663 |
|
representation/esol_smi-ted.pkl
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:52cbf2c9afa3a06ed068ba7583df229ec9a1aa823b22baecac97fa891475f85a
|
3 |
+
size 2964232
|