boltz / app.py
jadechoghari's picture
add final
e0f2a0e
raw
history blame
4.82 kB
import os
import gradio as gr
from gradio_molecule3d import Molecule3D
import spaces
import subprocess
import glob
# directory to store cached outputs
CACHE_DIR = "gradio_cached_examples"

# Representation settings handed to the Molecule3D viewer component.
reps = [
    {
        "model": 0,
        "chain": "",
        "resname": "",
        "style": "stick",
        "color": "whiteCarbon",
        "residue_range": "",
        "around": 0,
        "byres": False,
        "visible": False,
    }
]

# Ensure the cache directory exists (runs once, at import time).
os.makedirs(CACHE_DIR, exist_ok=True)

# Example FASTA inputs offered in the UI.
example_fasta_files = [
    "cache_examples/boltz_0.fasta",
    "cache_examples/Armadillo_6.fasta",
    "cache_examples/Covid_3.fasta",
    "cache_examples/Malaria_2.fasta",
    "cache_examples/MITOCHONDRIAL_9.fasta",
    "cache_examples/Monkeypox_4.fasta",
    "cache_examples/Plasmodium_1.fasta",
    "cache_examples/PROTOCADHERIN_8.fasta",
    "cache_examples/Vault_5.fasta",
    "cache_examples/Zipper_7.fasta",
]

# Matching `.pdb` file in CACHE_DIR for each example: same stem, `.pdb` suffix.
# splitext swaps only the final extension; str.replace would rewrite every
# ".fasta" substring in the name.
example_outputs = [
    os.path.join(CACHE_DIR, os.path.splitext(os.path.basename(fasta_file))[0] + ".pdb")
    for fasta_file in example_fasta_files
]
# must load cached outputs
def load_cached_example_outputs(fasta_file: str, cache_dir=None) -> str:
    """Return the path of the precomputed ``.pdb`` file for an example FASTA.

    Args:
        fasta_file: Path to the example FASTA file; only its base name is used.
        cache_dir: Directory holding the cached ``.pdb`` files (a ``str`` or
            ``None``). Defaults to the module-level ``CACHE_DIR``.

    Returns:
        ``<cache_dir>/<stem>.pdb``, where ``<stem>`` is the FASTA base name
        without its final extension.

    Raises:
        FileNotFoundError: If the cached ``.pdb`` file does not exist.
    """
    if cache_dir is None:
        cache_dir = CACHE_DIR
    # splitext replaces only the final extension (str.replace would rewrite
    # every ".fasta" substring, e.g. "x.fasta.fasta" -> "x.pdb.pdb").
    stem, _ext = os.path.splitext(os.path.basename(fasta_file))
    pdb_file = stem + ".pdb"
    cached_pdb_path = os.path.join(cache_dir, pdb_file)
    if not os.path.exists(cached_pdb_path):
        raise FileNotFoundError(f"Cached output not found for {pdb_file}")
    return cached_pdb_path
# Callback used when one of the example FASTA files is selected.
def on_example_click(fasta_file: str) -> str:
    """Return the cached ``.pdb`` path for the selected example FASTA file.

    Delegates to ``load_cached_example_outputs``; any ``FileNotFoundError``
    it raises propagates unchanged.
    """
    cached_path = load_cached_example_outputs(fasta_file)
    return cached_path
# run predictions
@spaces.GPU(duration=120)
def predict(data,
            accelerator="gpu", sampling_steps=50,
            diffusion_samples=1, out_dir="./"):
    """Run the ``boltz predict`` CLI on a FASTA file, writing PDB output.

    NOTE(review): ``spaces.GPU(duration=120)`` presumably requests a GPU
    slot of up to 120 s on Hugging Face Spaces — confirm against the
    ``spaces`` package docs.

    Args:
        data: Path to the input FASTA file, passed as the final CLI argument.
        accelerator: Value for ``--accelerator`` (the UI offers "gpu"/"cpu").
        sampling_steps: Value for ``--sampling_steps``.
        diffusion_samples: Value for ``--diffusion_samples``.
        out_dir: Value for ``--out_dir``; defaults to the current directory.

    Returns:
        True if the CLI exited with status 0, False otherwise. On failure the
        CLI's stderr is printed; no exception is raised.
    """
    print("Arguments passed to `predict` function:")
    print(f"  data: {data}")
    print(f"  accelerator: {accelerator}")
    print(f"  sampling_steps: {sampling_steps}")
    print(f"  diffusion_samples: {diffusion_samples}")
    print(f"  out_dir: {out_dir}")
    # Argument list (no shell) so the uploaded file path is never shell-parsed.
    command = [
        "boltz", "predict",
        "--out_dir", out_dir,
        "--accelerator", accelerator,
        "--sampling_steps", str(sampling_steps),
        "--diffusion_samples", str(diffusion_samples),
        "--output_format", "pdb",
        "--checkpoint", "./ckpt/boltz1.ckpt",
        data,
    ]
    result = subprocess.run(command, capture_output=True, text=True)
    if result.returncode != 0:
        print("Prediction failed :(")
        print("Error:", result.stderr)
        return False
    print("Prediction completed successfully...!")
    # out_dir is now defined in this scope; previously this line raised
    # NameError on every successful run.
    print(f"Output saved to: {out_dir}")
    return True
def _pdb_mtimes(search_path):
    """Map every file matching the recursive glob *search_path* to its mtime."""
    return {
        path: os.path.getmtime(path)
        for path in glob.glob(search_path, recursive=True)
    }


def run_prediction(input_file, accelerator, sampling_steps,
                   diffusion_samples):
    """Predict a structure for an uploaded FASTA file and return its ``.pdb`` path.

    Args:
        input_file: The uploaded file. Its path is read from ``.name`` when
            present (as the original code assumed); otherwise the value itself
            is treated as the path. ``None`` means nothing was uploaded.
        accelerator: Forwarded to ``predict``.
        sampling_steps: Forwarded to ``predict``.
        diffusion_samples: Forwarded to ``predict``.

    Returns:
        Path of the most recently modified ``.pdb`` file created or rewritten
        by THIS run under ``./boltz_results*/predictions/``, or ``None`` when
        no file was uploaded or the run produced no new ``.pdb`` file.
    """
    if input_file is None:
        print("No input file provided.")
        return None
    data = getattr(input_file, "name", input_file)
    print("the data : ", data)

    out_dir = "./"
    search_path = os.path.join(out_dir, "boltz_results*/predictions/**/*.pdb")
    # Snapshot existing outputs so a failed run cannot silently return a
    # stale structure left over from an earlier prediction.
    before = _pdb_mtimes(search_path)

    predict(
        data=data,
        accelerator=accelerator,
        sampling_steps=sampling_steps,
        diffusion_samples=diffusion_samples
    )

    after = _pdb_mtimes(search_path)
    # New files, or files whose mtime changed (overwritten by this run).
    fresh = [path for path, mtime in after.items() if before.get(path) != mtime]
    if not fresh:
        print("No new .pdb files found in the predictions folder.")
        return None
    return max(fresh, key=after.get)
# Build the Gradio UI: upload + advanced settings on the left, 3D viewer on the right.
with gr.Blocks() as demo:
    gr.Markdown("# 🔬 Boltz-1: Democratizing Biomolecular Interaction Modeling 🧬")
    with gr.Row():
        with gr.Column(scale=1):
            inp = gr.File(label="Upload a .fasta File", file_types=[".fasta"])
            with gr.Accordion("Advanced Settings", open=False):
                accelerator = gr.Radio(choices=["gpu", "cpu"], value="gpu", label="Accelerator")
                sampling_steps = gr.Slider(minimum=1, maximum=500, value=50, step=1, label="Sampling Steps")
                diffusion_samples = gr.Slider(minimum=1, maximum=10, value=1, step=1, label="Diffusion Samples")
            btn = gr.Button("Predict")
        with gr.Column(scale=3):
            out = Molecule3D(label="Generated Molecule", reps=reps)
    btn.click(
        run_prediction,
        inputs=[inp, accelerator, sampling_steps, diffusion_samples],
        outputs=out
    )
    # Examples resolve to precomputed .pdb files rather than running Boltz.
    # NOTE(review): with cache_examples=True Gradio presumably calls `fn` on
    # every example at startup to build its cache — confirm for the pinned
    # Gradio version.
    gr.Examples(
        examples=[[fasta_file] for fasta_file in example_fasta_files],
        inputs=[inp],
        outputs=out,
        fn=on_example_click,  # passed directly; a wrapping lambda added nothing
        cache_examples=True
    )

if __name__ == "__main__":
    demo.launch(share=True, debug=True)