Spaces:
Running
Running
Takashi Itoh
commited on
Commit
•
96fd466
1
Parent(s):
b804866
Change globals to state. Change transformers version. Delete code history comments.
Browse files- Dockerfile +5 -0
- app.py +20 -109
- requirements.txt +1 -1
Dockerfile
CHANGED
@@ -3,6 +3,11 @@ FROM python:3.9.7
|
|
3 |
WORKDIR /app
|
4 |
COPY requirements.txt .
|
5 |
RUN pip install -r requirements.txt
|
|
|
|
|
|
|
|
|
|
|
6 |
COPY . .
|
7 |
|
8 |
CMD ["python", "app.py"]
|
|
|
3 |
WORKDIR /app
|
4 |
COPY requirements.txt .
|
5 |
RUN pip install -r requirements.txt
|
6 |
+
# preload models
|
7 |
+
RUN python -c '\
|
8 |
+
from transformers import BartForConditionalGeneration, AutoTokenizer;\
|
9 |
+
AutoTokenizer.from_pretrained("ibm/materials.selfies-ted");\
|
10 |
+
BartForConditionalGeneration.from_pretrained("ibm/materials.selfies-ted")'
|
11 |
COPY . .
|
12 |
|
13 |
CMD ["python", "app.py"]
|
app.py
CHANGED
@@ -1,5 +1,4 @@
|
|
1 |
import gradio as gr
|
2 |
-
from huggingface_hub import InferenceClient
|
3 |
import matplotlib.pyplot as plt
|
4 |
from PIL import Image
|
5 |
from rdkit.Chem import Descriptors, QED, Draw
|
@@ -25,23 +24,6 @@ import os
|
|
25 |
|
26 |
os.environ["OMP_MAX_ACTIVE_LEVELS"] = "1"
|
27 |
|
28 |
-
# my_theme = gr.Theme.from_hub("ysharma/steampunk")
|
29 |
-
# my_theme = gr.themes.Glass()
|
30 |
-
|
31 |
-
"""
|
32 |
-
# カスタムテーマ設定
|
33 |
-
theme = gr.themes.Default().set(
|
34 |
-
body_background_fill="#000000", # 背景色を黒に設定
|
35 |
-
text_color="#FFFFFF", # テキスト色を白に設定
|
36 |
-
)
|
37 |
-
"""
|
38 |
-
"""
|
39 |
-
import sys
|
40 |
-
sys.path.append("models")
|
41 |
-
sys.path.append("../models")
|
42 |
-
sys.path.append("../")"""
|
43 |
-
|
44 |
-
|
45 |
# Get the current file's directory
|
46 |
base_dir = os.path.dirname(__file__)
|
47 |
print("Base Dir : ", base_dir)
|
@@ -139,7 +121,6 @@ def load_image(path):
|
|
139 |
pass
|
140 |
|
141 |
|
142 |
-
|
143 |
# Function to handle image selection
|
144 |
def handle_image_selection(image_key):
|
145 |
smiles = smiles_image_mapping[image_key]["smiles"]
|
@@ -171,9 +152,6 @@ def calculate_tanimoto(smiles1, smiles2):
|
|
171 |
return None
|
172 |
|
173 |
|
174 |
-
#with open("models/selfies_model/bart-2908.pickle", "rb") as input_file:
|
175 |
-
# gen_model, gen_tokenizer = pickle.load(input_file)
|
176 |
-
|
177 |
gen_tokenizer = AutoTokenizer.from_pretrained("ibm/materials.selfies-ted")
|
178 |
gen_model = BartForConditionalGeneration.from_pretrained("ibm/materials.selfies-ted")
|
179 |
|
@@ -201,11 +179,6 @@ def encode(selfies):
|
|
201 |
attention_mask = encoding['attention_mask']
|
202 |
outputs = gen_model.model.encoder(input_ids=input_ids, attention_mask=attention_mask)
|
203 |
model_output = outputs.last_hidden_state
|
204 |
-
|
205 |
-
"""input_mask_expanded = attention_mask.unsqueeze(-1).expand(model_output.size()).float()
|
206 |
-
sum_embeddings = torch.sum(model_output * input_mask_expanded, 1)
|
207 |
-
sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
|
208 |
-
model_output = sum_embeddings / sum_mask"""
|
209 |
return model_output, attention_mask
|
210 |
|
211 |
|
@@ -226,16 +199,6 @@ def generate_canonical(smiles):
|
|
226 |
if gen_mol != Chem.MolToSmiles(Chem.MolFromSmiles(smiles)): break
|
227 |
else:
|
228 |
print('Abnormal molecule:', gen[0])
|
229 |
-
gen_mols = []
|
230 |
-
for sel in gen[0].split('.'):
|
231 |
-
mol = Chem.MolFromSmiles(sel)
|
232 |
-
if mol:
|
233 |
-
mol = Chem.MolToSmiles(mol)
|
234 |
-
if mol != Chem.MolToSmiles(Chem.MolFromSmiles(smiles)):
|
235 |
-
gen_mols.append(mol)
|
236 |
-
if len(gen_mols) > 0:
|
237 |
-
gen_mol = '.'.join(gen_mols)
|
238 |
-
break
|
239 |
|
240 |
if gen_mol:
|
241 |
# Calculate properties for ref and gen molecules
|
@@ -262,7 +225,7 @@ def generate_canonical(smiles):
|
|
262 |
|
263 |
|
264 |
# Function to display evaluation score
|
265 |
-
def display_eval(selected_models, dataset, task_type, downstream, fusion_type):
|
266 |
result = None
|
267 |
|
268 |
try:
|
@@ -278,68 +241,47 @@ def display_eval(selected_models, dataset, task_type, downstream, fusion_type):
|
|
278 |
params = None
|
279 |
|
280 |
|
281 |
-
|
282 |
-
|
283 |
try:
|
284 |
if not selected_models:
|
285 |
-
return "Please select at least one enabled model."
|
286 |
-
|
287 |
-
if task_type == "Classification":
|
288 |
-
global roc_auc, fpr, tpr, x_batch, y_batch
|
289 |
-
elif task_type == "Regression":
|
290 |
-
global RMSE, y_batch_test, y_prob
|
291 |
|
292 |
if len(selected_models) > 1:
|
293 |
if task_type == "Classification":
|
294 |
-
#result, roc_auc, fpr, tpr, x_batch, y_batch = fm4m.multi_modal(model_list=selected_models,
|
295 |
-
# downstream_model="XGBClassifier",
|
296 |
-
# dataset=dataset.lower())
|
297 |
if downstream_model == "Default Settings":
|
298 |
downstream_model = "DefaultClassifier"
|
299 |
params = None
|
300 |
-
result, roc_auc, fpr, tpr, x_batch, y_batch = fm4m.multi_modal(model_list=selected_models,
|
301 |
downstream_model=downstream_model,
|
302 |
params = params,
|
303 |
dataset=dataset)
|
304 |
|
305 |
elif task_type == "Regression":
|
306 |
-
#result, RMSE, y_batch_test, y_prob = fm4m.multi_modal(model_list=selected_models,
|
307 |
-
# downstream_model="XGBRegressor",
|
308 |
-
# dataset=dataset.lower())
|
309 |
-
|
310 |
if downstream_model == "Default Settings":
|
311 |
downstream_model = "DefaultRegressor"
|
312 |
params = None
|
313 |
|
314 |
-
result, RMSE, y_batch_test, y_prob, x_batch, y_batch = fm4m.multi_modal(model_list=selected_models,
|
315 |
downstream_model=downstream_model,
|
316 |
params=params,
|
317 |
dataset=dataset)
|
318 |
|
319 |
else:
|
320 |
if task_type == "Classification":
|
321 |
-
#result, roc_auc, fpr, tpr, x_batch, y_batch = fm4m.single_modal(model=selected_models[0],
|
322 |
-
# downstream_model="XGBClassifier",
|
323 |
-
# dataset=dataset.lower())
|
324 |
if downstream_model == "Default Settings":
|
325 |
downstream_model = "DefaultClassifier"
|
326 |
params = None
|
327 |
|
328 |
-
result, roc_auc, fpr, tpr, x_batch, y_batch = fm4m.single_modal(model=selected_models[0],
|
329 |
downstream_model=downstream_model,
|
330 |
params=params,
|
331 |
dataset=dataset)
|
332 |
|
333 |
elif task_type == "Regression":
|
334 |
-
#result, RMSE, y_batch_test, y_prob = fm4m.single_modal(model=selected_models[0],
|
335 |
-
# downstream_model="XGBRegressor",
|
336 |
-
# dataset=dataset.lower())
|
337 |
-
|
338 |
if downstream_model == "Default Settings":
|
339 |
downstream_model = "DefaultRegressor"
|
340 |
params = None
|
341 |
|
342 |
-
result, RMSE, y_batch_test, y_prob, x_batch, y_batch = fm4m.single_modal(model=selected_models[0],
|
343 |
downstream_model=downstream_model,
|
344 |
params=params,
|
345 |
dataset=dataset)
|
@@ -347,28 +289,20 @@ def display_eval(selected_models, dataset, task_type, downstream, fusion_type):
|
|
347 |
if result == None:
|
348 |
result = "Data & Model Setting is incorrect"
|
349 |
except Exception as e:
|
350 |
-
return f"An error occurred: {e}"
|
351 |
-
return f"{result}"
|
352 |
|
353 |
|
354 |
# Function to handle plot display
|
355 |
-
def display_plot(plot_type):
|
356 |
fig, ax = plt.subplots()
|
357 |
|
358 |
if plot_type == "Latent Space":
|
359 |
-
|
360 |
ax.set_title("T-SNE Plot")
|
361 |
-
|
362 |
-
|
363 |
-
|
364 |
-
# index_0 = [index for index in range(len(x)) if x[index] == 0]
|
365 |
-
# index_1 = [index for index in range(len(x)) if x[index] == 1]
|
366 |
-
class_0 = x_batch # features_umap[index_0]
|
367 |
-
class_1 = y_batch # features_umap[index_1]
|
368 |
-
|
369 |
-
"""with open("latent_multi_bace.pkl", "rb") as f:
|
370 |
-
class_0, class_1 = pickle.load(f)
|
371 |
-
"""
|
372 |
plt.scatter(class_1[:, 0], class_1[:, 1], c='red', label='Class 1')
|
373 |
plt.scatter(class_0[:, 0], class_0[:, 1], c='blue', label='Class 0')
|
374 |
|
@@ -377,7 +311,7 @@ def display_plot(plot_type):
|
|
377 |
ax.set_title('Dataset Distribution')
|
378 |
|
379 |
elif plot_type == "ROC-AUC":
|
380 |
-
|
381 |
ax.set_title("ROC-AUC Curve")
|
382 |
try:
|
383 |
ax.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc:.4f})')
|
@@ -392,7 +326,7 @@ def display_plot(plot_type):
|
|
392 |
ax.legend(loc='lower right')
|
393 |
|
394 |
elif plot_type == "Parity Plot":
|
395 |
-
|
396 |
ax.set_title("Parity plot")
|
397 |
|
398 |
# change format
|
@@ -415,9 +349,6 @@ def display_plot(plot_type):
|
|
415 |
print(y_prob)
|
416 |
|
417 |
|
418 |
-
|
419 |
-
|
420 |
-
|
421 |
ax.set_xlabel('Actual Values')
|
422 |
ax.set_ylabel('Predicted Values')
|
423 |
|
@@ -493,6 +424,7 @@ def create_model(model_name, max_depth=None, n_estimators=None, alpha=None, degr
|
|
493 |
return "Model not supported."
|
494 |
|
495 |
return f"{model_name} * {model.get_params()}"
|
|
|
496 |
def model_selector(model_name):
|
497 |
# Dynamically return the appropriate hyperparameter components based on the selected model
|
498 |
if model_name == "XGBClassifier":
|
@@ -518,10 +450,10 @@ def model_selector(model_name):
|
|
518 |
return ()
|
519 |
|
520 |
|
521 |
-
|
522 |
# Define the Gradio layout
|
523 |
# with gr.Blocks(theme=my_theme) as demo:
|
524 |
with gr.Blocks() as demo:
|
|
|
525 |
with gr.Row():
|
526 |
# Left Column
|
527 |
with gr.Column():
|
@@ -530,9 +462,6 @@ with gr.Blocks() as demo:
|
|
530 |
<h3 style="color: #FFFFFF; margin: 0;font-size: 20px;"> Data & Model Setting</h3>
|
531 |
</div>
|
532 |
''')
|
533 |
-
# gr.Markdown("## Data & Model Setting")
|
534 |
-
#dataset_dropdown = gr.Dropdown(choices=datasets, label="Select Dat")
|
535 |
-
|
536 |
# Dropdown menu for predefined datasets including "Custom Dataset" option
|
537 |
dataset_selector = gr.Dropdown(label="Select Dataset",
|
538 |
choices=list(predefined_datasets.keys()) + ["Custom Dataset"])
|
@@ -553,13 +482,10 @@ with gr.Blocks() as demo:
|
|
553 |
interactive=False)
|
554 |
|
555 |
|
556 |
-
|
557 |
# Dropdowns for selecting input and output columns for the custom dataset
|
558 |
input_column_selector = gr.Dropdown(label="Select Input Column", choices=[], visible=False)
|
559 |
output_column_selector = gr.Dropdown(label="Select Output Column", choices=[], visible=False)
|
560 |
|
561 |
-
#selected_columns_message = gr.Textbox(label="Selected Columns Info", visible=True)
|
562 |
-
|
563 |
# When a dataset is selected, show either file upload fields (for custom) or load predefined datasets
|
564 |
dataset_selector.change(handle_dataset_selection,
|
565 |
inputs=dataset_selector,
|
@@ -594,10 +520,6 @@ with gr.Blocks() as demo:
|
|
594 |
|
595 |
model_checkbox = gr.CheckboxGroup(choices=models_enabled, label="Select Model")
|
596 |
|
597 |
-
# Add disabled checkboxes for GNN and FNN
|
598 |
-
# gnn_checkbox = gr.Checkbox(label="GNN (Disabled)", value=False, interactive=False)
|
599 |
-
# fnn_checkbox = gr.Checkbox(label="FNN (Disabled)", value=False, interactive=False)
|
600 |
-
|
601 |
task_radiobutton = gr.Radio(choices=["Classification", "Regression"], label="Task Type")
|
602 |
|
603 |
####### adding hyper parameter tuning ###########
|
@@ -662,9 +584,7 @@ with gr.Blocks() as demo:
|
|
662 |
fusion_radiobutton = gr.Radio(choices=fusion_available, label="Fusion Type")
|
663 |
|
664 |
|
665 |
-
|
666 |
eval_button = gr.Button("Train downstream model")
|
667 |
-
#eval_button.style(css_class="custom-button-left")
|
668 |
|
669 |
# Middle Column
|
670 |
with gr.Column():
|
@@ -673,23 +593,20 @@ with gr.Blocks() as demo:
|
|
673 |
<h3 style="color: #FFFFFF; margin: 0;font-size: 20px;"> Downstream Task 1: Property Prediction</h3>
|
674 |
</div>
|
675 |
''')
|
676 |
-
# gr.Markdown("## Downstream task Result")
|
677 |
eval_output = gr.Textbox(label="Train downstream model")
|
678 |
|
679 |
plot_radio = gr.Radio(choices=["ROC-AUC", "Parity Plot", "Latent Space"], label="Select Plot Type")
|
680 |
plot_output = gr.Plot(label="Visualization")#, height=250, width=250)
|
681 |
|
682 |
-
#download_rep = gr.Button("Download representation")
|
683 |
-
|
684 |
create_log = gr.Button("Store log")
|
685 |
|
686 |
log_table = gr.Dataframe(value=log_df, label="Log of Selections and Results", interactive=False)
|
687 |
|
688 |
eval_button.click(display_eval,
|
689 |
-
inputs=[model_checkbox, selected_columns_message, task_radiobutton, output, fusion_radiobutton],
|
690 |
-
outputs=eval_output)
|
691 |
|
692 |
-
plot_radio.change(display_plot, inputs=plot_radio, outputs=plot_output)
|
693 |
|
694 |
|
695 |
# Function to gather selected models
|
@@ -700,9 +617,6 @@ with gr.Blocks() as demo:
|
|
700 |
|
701 |
create_log.click(evaluate_and_log, inputs=[model_checkbox, dataset_name, task_radiobutton, eval_output],
|
702 |
outputs=log_table)
|
703 |
-
#download_rep.click(save_rep, inputs=[model_checkbox, dataset_name, task_radiobutton, eval_output],
|
704 |
-
# outputs=None)
|
705 |
-
|
706 |
# Right Column
|
707 |
with gr.Column():
|
708 |
gr.HTML('''
|
@@ -710,7 +624,6 @@ with gr.Blocks() as demo:
|
|
710 |
<h3 style="color: #FFFFFF; margin: 0;font-size: 20px;"> Downstream Task 2: Molecule Generation</h3>
|
711 |
</div>
|
712 |
''')
|
713 |
-
# gr.Markdown("## Molecular Generation")
|
714 |
smiles_input = gr.Textbox(label="Input SMILES String")
|
715 |
image_display = gr.Image(label="Molecule Image", height=250, width=250)
|
716 |
# Show images for selection
|
@@ -719,7 +632,6 @@ with gr.Blocks() as demo:
|
|
719 |
choices=list(smiles_image_mapping.keys()),
|
720 |
label="Select from sample molecules",
|
721 |
value=None,
|
722 |
-
#item_images=[load_image(smiles_image_mapping[key]["image"]) for key in smiles_image_mapping.keys()]
|
723 |
)
|
724 |
image_selector.change(load_image, image_selector, image_display)
|
725 |
generate_button = gr.Button("Generate")
|
@@ -728,7 +640,6 @@ with gr.Blocks() as demo:
|
|
728 |
property_table = gr.Dataframe(label="Molecular Properties Comparison")
|
729 |
|
730 |
|
731 |
-
|
732 |
# Handle image selection
|
733 |
image_selector.change(handle_image_selection, inputs=image_selector, outputs=[smiles_input, image_display])
|
734 |
smiles_input.change(smiles_to_image, inputs=smiles_input, outputs=image_display)
|
|
|
1 |
import gradio as gr
|
|
|
2 |
import matplotlib.pyplot as plt
|
3 |
from PIL import Image
|
4 |
from rdkit.Chem import Descriptors, QED, Draw
|
|
|
24 |
|
25 |
os.environ["OMP_MAX_ACTIVE_LEVELS"] = "1"
|
26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
# Get the current file's directory
|
28 |
base_dir = os.path.dirname(__file__)
|
29 |
print("Base Dir : ", base_dir)
|
|
|
121 |
pass
|
122 |
|
123 |
|
|
|
124 |
# Function to handle image selection
|
125 |
def handle_image_selection(image_key):
|
126 |
smiles = smiles_image_mapping[image_key]["smiles"]
|
|
|
152 |
return None
|
153 |
|
154 |
|
|
|
|
|
|
|
155 |
gen_tokenizer = AutoTokenizer.from_pretrained("ibm/materials.selfies-ted")
|
156 |
gen_model = BartForConditionalGeneration.from_pretrained("ibm/materials.selfies-ted")
|
157 |
|
|
|
179 |
attention_mask = encoding['attention_mask']
|
180 |
outputs = gen_model.model.encoder(input_ids=input_ids, attention_mask=attention_mask)
|
181 |
model_output = outputs.last_hidden_state
|
|
|
|
|
|
|
|
|
|
|
182 |
return model_output, attention_mask
|
183 |
|
184 |
|
|
|
199 |
if gen_mol != Chem.MolToSmiles(Chem.MolFromSmiles(smiles)): break
|
200 |
else:
|
201 |
print('Abnormal molecule:', gen[0])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
202 |
|
203 |
if gen_mol:
|
204 |
# Calculate properties for ref and gen molecules
|
|
|
225 |
|
226 |
|
227 |
# Function to display evaluation score
|
228 |
+
def display_eval(selected_models, dataset, task_type, downstream, fusion_type, state):
|
229 |
result = None
|
230 |
|
231 |
try:
|
|
|
241 |
params = None
|
242 |
|
243 |
|
|
|
|
|
244 |
try:
|
245 |
if not selected_models:
|
246 |
+
return "Please select at least one enabled model.", state
|
|
|
|
|
|
|
|
|
|
|
247 |
|
248 |
if len(selected_models) > 1:
|
249 |
if task_type == "Classification":
|
|
|
|
|
|
|
250 |
if downstream_model == "Default Settings":
|
251 |
downstream_model = "DefaultClassifier"
|
252 |
params = None
|
253 |
+
result, state["roc_auc"], state["fpr"], state["tpr"], state["x_batch"], state["y_batch"] = fm4m.multi_modal(model_list=selected_models,
|
254 |
downstream_model=downstream_model,
|
255 |
params = params,
|
256 |
dataset=dataset)
|
257 |
|
258 |
elif task_type == "Regression":
|
|
|
|
|
|
|
|
|
259 |
if downstream_model == "Default Settings":
|
260 |
downstream_model = "DefaultRegressor"
|
261 |
params = None
|
262 |
|
263 |
+
result, state["RMSE"], state["y_batch_test"], state["y_prob"], state["x_batch"], state["y_batch"] = fm4m.multi_modal(model_list=selected_models,
|
264 |
downstream_model=downstream_model,
|
265 |
params=params,
|
266 |
dataset=dataset)
|
267 |
|
268 |
else:
|
269 |
if task_type == "Classification":
|
|
|
|
|
|
|
270 |
if downstream_model == "Default Settings":
|
271 |
downstream_model = "DefaultClassifier"
|
272 |
params = None
|
273 |
|
274 |
+
result, state["roc_auc"], state["fpr"], state["tpr"], state["x_batch"], state["y_batch"] = fm4m.single_modal(model=selected_models[0],
|
275 |
downstream_model=downstream_model,
|
276 |
params=params,
|
277 |
dataset=dataset)
|
278 |
|
279 |
elif task_type == "Regression":
|
|
|
|
|
|
|
|
|
280 |
if downstream_model == "Default Settings":
|
281 |
downstream_model = "DefaultRegressor"
|
282 |
params = None
|
283 |
|
284 |
+
result, state["RMSE"], state["y_batch_test"], state["y_prob"], state["x_batch"], state["y_batch"] = fm4m.single_modal(model=selected_models[0],
|
285 |
downstream_model=downstream_model,
|
286 |
params=params,
|
287 |
dataset=dataset)
|
|
|
289 |
if result == None:
|
290 |
result = "Data & Model Setting is incorrect"
|
291 |
except Exception as e:
|
292 |
+
return f"An error occurred: {e}", state
|
293 |
+
return f"{result}", state
|
294 |
|
295 |
|
296 |
# Function to handle plot display
|
297 |
+
def display_plot(plot_type, state):
|
298 |
fig, ax = plt.subplots()
|
299 |
|
300 |
if plot_type == "Latent Space":
|
301 |
+
x_batch, y_batch = state.get("x_batch"), state.get("y_batch")
|
302 |
ax.set_title("T-SNE Plot")
|
303 |
+
class_0 = x_batch
|
304 |
+
class_1 = y_batch
|
305 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
306 |
plt.scatter(class_1[:, 0], class_1[:, 1], c='red', label='Class 1')
|
307 |
plt.scatter(class_0[:, 0], class_0[:, 1], c='blue', label='Class 0')
|
308 |
|
|
|
311 |
ax.set_title('Dataset Distribution')
|
312 |
|
313 |
elif plot_type == "ROC-AUC":
|
314 |
+
roc_auc, fpr, tpr = state.get("roc_auc"), state.get("fpr"), state.get("tpr")
|
315 |
ax.set_title("ROC-AUC Curve")
|
316 |
try:
|
317 |
ax.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc:.4f})')
|
|
|
326 |
ax.legend(loc='lower right')
|
327 |
|
328 |
elif plot_type == "Parity Plot":
|
329 |
+
RMSE, y_batch_test, y_prob = state.get("RMSE"), state.get("y_batch_test"), state.get("y_prob")
|
330 |
ax.set_title("Parity plot")
|
331 |
|
332 |
# change format
|
|
|
349 |
print(y_prob)
|
350 |
|
351 |
|
|
|
|
|
|
|
352 |
ax.set_xlabel('Actual Values')
|
353 |
ax.set_ylabel('Predicted Values')
|
354 |
|
|
|
424 |
return "Model not supported."
|
425 |
|
426 |
return f"{model_name} * {model.get_params()}"
|
427 |
+
|
428 |
def model_selector(model_name):
|
429 |
# Dynamically return the appropriate hyperparameter components based on the selected model
|
430 |
if model_name == "XGBClassifier":
|
|
|
450 |
return ()
|
451 |
|
452 |
|
|
|
453 |
# Define the Gradio layout
|
454 |
# with gr.Blocks(theme=my_theme) as demo:
|
455 |
with gr.Blocks() as demo:
|
456 |
+
state = gr.State({})
|
457 |
with gr.Row():
|
458 |
# Left Column
|
459 |
with gr.Column():
|
|
|
462 |
<h3 style="color: #FFFFFF; margin: 0;font-size: 20px;"> Data & Model Setting</h3>
|
463 |
</div>
|
464 |
''')
|
|
|
|
|
|
|
465 |
# Dropdown menu for predefined datasets including "Custom Dataset" option
|
466 |
dataset_selector = gr.Dropdown(label="Select Dataset",
|
467 |
choices=list(predefined_datasets.keys()) + ["Custom Dataset"])
|
|
|
482 |
interactive=False)
|
483 |
|
484 |
|
|
|
485 |
# Dropdowns for selecting input and output columns for the custom dataset
|
486 |
input_column_selector = gr.Dropdown(label="Select Input Column", choices=[], visible=False)
|
487 |
output_column_selector = gr.Dropdown(label="Select Output Column", choices=[], visible=False)
|
488 |
|
|
|
|
|
489 |
# When a dataset is selected, show either file upload fields (for custom) or load predefined datasets
|
490 |
dataset_selector.change(handle_dataset_selection,
|
491 |
inputs=dataset_selector,
|
|
|
520 |
|
521 |
model_checkbox = gr.CheckboxGroup(choices=models_enabled, label="Select Model")
|
522 |
|
|
|
|
|
|
|
|
|
523 |
task_radiobutton = gr.Radio(choices=["Classification", "Regression"], label="Task Type")
|
524 |
|
525 |
####### adding hyper parameter tuning ###########
|
|
|
584 |
fusion_radiobutton = gr.Radio(choices=fusion_available, label="Fusion Type")
|
585 |
|
586 |
|
|
|
587 |
eval_button = gr.Button("Train downstream model")
|
|
|
588 |
|
589 |
# Middle Column
|
590 |
with gr.Column():
|
|
|
593 |
<h3 style="color: #FFFFFF; margin: 0;font-size: 20px;"> Downstream Task 1: Property Prediction</h3>
|
594 |
</div>
|
595 |
''')
|
|
|
596 |
eval_output = gr.Textbox(label="Train downstream model")
|
597 |
|
598 |
plot_radio = gr.Radio(choices=["ROC-AUC", "Parity Plot", "Latent Space"], label="Select Plot Type")
|
599 |
plot_output = gr.Plot(label="Visualization")#, height=250, width=250)
|
600 |
|
|
|
|
|
601 |
create_log = gr.Button("Store log")
|
602 |
|
603 |
log_table = gr.Dataframe(value=log_df, label="Log of Selections and Results", interactive=False)
|
604 |
|
605 |
eval_button.click(display_eval,
|
606 |
+
inputs=[model_checkbox, selected_columns_message, task_radiobutton, output, fusion_radiobutton, state],
|
607 |
+
outputs=[eval_output, state])
|
608 |
|
609 |
+
plot_radio.change(display_plot, inputs=[plot_radio, state], outputs=plot_output)
|
610 |
|
611 |
|
612 |
# Function to gather selected models
|
|
|
617 |
|
618 |
create_log.click(evaluate_and_log, inputs=[model_checkbox, dataset_name, task_radiobutton, eval_output],
|
619 |
outputs=log_table)
|
|
|
|
|
|
|
620 |
# Right Column
|
621 |
with gr.Column():
|
622 |
gr.HTML('''
|
|
|
624 |
<h3 style="color: #FFFFFF; margin: 0;font-size: 20px;"> Downstream Task 2: Molecule Generation</h3>
|
625 |
</div>
|
626 |
''')
|
|
|
627 |
smiles_input = gr.Textbox(label="Input SMILES String")
|
628 |
image_display = gr.Image(label="Molecule Image", height=250, width=250)
|
629 |
# Show images for selection
|
|
|
632 |
choices=list(smiles_image_mapping.keys()),
|
633 |
label="Select from sample molecules",
|
634 |
value=None,
|
|
|
635 |
)
|
636 |
image_selector.change(load_image, image_selector, image_display)
|
637 |
generate_button = gr.Button("Generate")
|
|
|
640 |
property_table = gr.Dataframe(label="Molecular Properties Comparison")
|
641 |
|
642 |
|
|
|
643 |
# Handle image selection
|
644 |
image_selector.change(handle_image_selection, inputs=image_selector, outputs=[smiles_input, image_display])
|
645 |
smiles_input.change(smiles_to_image, inputs=smiles_input, outputs=image_display)
|
requirements.txt
CHANGED
@@ -16,7 +16,7 @@ datasets>=2.13.1
|
|
16 |
evaluate>=0.4.0
|
17 |
selfies>=2.1.0
|
18 |
gradio==3.41.0
|
19 |
-
transformers
|
20 |
requests>=2.32.2
|
21 |
urllib3>=2.2.2
|
22 |
aiohttp>=3.10.2
|
|
|
16 |
evaluate>=0.4.0
|
17 |
selfies>=2.1.0
|
18 |
gradio==3.41.0
|
19 |
+
transformers==4.38.1
|
20 |
requests>=2.32.2
|
21 |
urllib3>=2.2.2
|
22 |
aiohttp>=3.10.2
|