# Gradio demo app for Boltz-1 biomolecular structure prediction.
# NOTE: web file-viewer residue that preceded the code (a "File size: 4,820 Bytes"
# banner, a gutter of commit hashes d4e347d / cf729de / e0f2a0e, and the line
# numbers 1-154) was removed here because it is not valid Python.
import os
import gradio as gr
from gradio_molecule3d import Molecule3D
import spaces
import subprocess
import glob


# Folder holding the precomputed `.pdb` outputs for the bundled examples.
CACHE_DIR = "gradio_cached_examples"

# Create the cache folder up front so later lookups never hit a missing dir.
os.makedirs(CACHE_DIR, exist_ok=True)

# Representation settings handed to the Molecule3D viewer component.
reps = [
    dict(
        model=0,
        chain="",
        resname="",
        style="stick",
        color="whiteCarbon",
        residue_range="",
        around=0,
        byres=False,
        visible=False,
    ),
]

# Base names (without extension) of the example inputs in `cache_examples/`.
_EXAMPLE_STEMS = (
    "boltz_0",
    "Armadillo_6",
    "Covid_3",
    "Malaria_2",
    "MITOCHONDRIAL_9",
    "Monkeypox_4",
    "Plasmodium_1",
    "PROTOCADHERIN_8",
    "Vault_5",
    "Zipper_7",
)

# Example `.fasta` inputs offered in the UI.
example_fasta_files = [f"cache_examples/{stem}.fasta" for stem in _EXAMPLE_STEMS]

# Precomputed `.pdb` output paths in `CACHE_DIR`, one per example input and
# in the same order.
example_outputs = [os.path.join(CACHE_DIR, f"{stem}.pdb") for stem in _EXAMPLE_STEMS]

# look up the precomputed output for an example input
def load_cached_example_outputs(fasta_file: str) -> str:
    """Return the cached `.pdb` path that corresponds to *fasta_file*.

    The lookup uses only the base name of *fasta_file*, with `.fasta`
    replaced by `.pdb`, resolved inside `CACHE_DIR`.

    Raises:
        FileNotFoundError: If no such file exists in `CACHE_DIR`.
    """
    pdb_name = os.path.basename(fasta_file).replace(".fasta", ".pdb")
    candidate = os.path.join(CACHE_DIR, pdb_name)
    # Guard clause: fail loudly when the precomputed output is missing.
    if not os.path.exists(candidate):
        raise FileNotFoundError(f"Cached output not found for {pdb_name}")
    return candidate

# callback used when an example entry is selected
def on_example_click(fasta_file: str) -> str:
    """Resolve an example `.fasta` path to its cached `.pdb` output path.

    Delegates to `load_cached_example_outputs`, so a missing cached file
    surfaces as `FileNotFoundError`.
    """
    cached_path = load_cached_example_outputs(fasta_file)
    return cached_path

# run predictions
@spaces.GPU(duration=120)
def predict(data,
            accelerator="gpu", sampling_steps=50,
            diffusion_samples=1):
    """Run the `boltz predict` command-line tool on one input file.

    Args:
        data: Path of the input file, passed as the positional argument.
        accelerator: Value forwarded to `--accelerator`.
        sampling_steps: Value forwarded to `--sampling_steps`.
        diffusion_samples: Value forwarded to `--diffusion_samples`.

    Returns:
        None. Success or failure is reported on stdout; on failure the
        captured stderr of the subprocess is printed.
    """
    # Output directory handed to the CLI. It was previously referenced in the
    # success message without being defined, which raised NameError.
    out_dir = "./"

    print("Arguments passed to `predict` function:")
    print(f"  data: {data}")
    print(f"  accelerator: {accelerator}")
    print(f"  sampling_steps: {sampling_steps}")
    print(f"  diffusion_samples: {diffusion_samples}")

    # Argument list (no shell) so the file path is never shell-interpreted.
    command = [
        "boltz", "predict",
        "--out_dir", out_dir,
        "--accelerator", accelerator,
        "--sampling_steps", str(sampling_steps),
        "--diffusion_samples", str(diffusion_samples),
        "--output_format", "pdb",
        "--checkpoint", "./ckpt/boltz1.ckpt",
        data,
    ]
    result = subprocess.run(command, capture_output=True, text=True)
    if result.returncode == 0:
        print("Prediction completed successfully...!")
        print(f"Output saved to: {out_dir}")
    else:
        print("Prediction failed :(")
        print("Error:", result.stderr)

def run_prediction(input_file, accelerator, sampling_steps,
                   diffusion_samples):
    """Run a prediction for an uploaded file and locate the resulting `.pdb`.

    Args:
        input_file: The uploaded file. Either an object exposing its path as
            `.name` or a plain path string; `None` when nothing was uploaded.
        accelerator: Forwarded to `predict`.
        sampling_steps: Forwarded to `predict`.
        diffusion_samples: Forwarded to `predict`.

    Returns:
        Path of the most recently modified `.pdb` file found under
        `./boltz_results*/predictions/`, or `None` when no input was given
        or no `.pdb` file exists.
    """
    # Clicking the button with no upload passes None; bail out instead of
    # failing with AttributeError on `.name`.
    if input_file is None:
        print("No input file provided.")
        return None

    # Accept both file-like objects (path in `.name`) and plain path strings.
    data = getattr(input_file, "name", input_file)
    print("the data : ", data)
    predict(
        data=data,
        accelerator=accelerator,
        sampling_steps=sampling_steps,
        diffusion_samples=diffusion_samples
    )

    # search for the latest .pdb file in the predictions folder
    out_dir = "./"
    search_path = os.path.join(out_dir, "boltz_results*/predictions/**/*.pdb")
    pdb_files = glob.glob(search_path, recursive=True)

    if not pdb_files:
        print("No .pdb files found in the predictions folder.")
        return None

    # NOTE(review): the newest file by modification time is returned, so if the
    # prediction above failed this may be an output left over from an earlier run.
    latest_pdb_file = max(pdb_files, key=os.path.getmtime)
    return latest_pdb_file


# Build the Gradio interface: a row with an input column (file upload,
# advanced settings, Predict button) and a wider column for the 3D viewer.
with gr.Blocks() as demo:
    gr.Markdown("# 🔬 Boltz-1: Democratizing Biomolecular Interaction Modeling 🧬")

    with gr.Row():
        # Input column (scale=1): upload widget and prediction settings.
        with gr.Column(scale=1):
            inp = gr.File(label="Upload a .fasta File", file_types=[".fasta"])

            # Collapsed by default; defaults match those of `predict`.
            with gr.Accordion("Advanced Settings", open=False):
                accelerator = gr.Radio(choices=["gpu", "cpu"], value="gpu", label="Accelerator")
                sampling_steps = gr.Slider(minimum=1, maximum=500, value=50, step=1, label="Sampling Steps")
                diffusion_samples = gr.Slider(minimum=1, maximum=10, value=1, step=1, label="Diffusion Samples")
            btn = gr.Button("Predict")
            
        # Output column (scale=3): molecule viewer using the `reps` settings.
        with gr.Column(scale=3):
            out = Molecule3D(label="Generated Molecule", reps=reps)

        

    # Clicking Predict calls `run_prediction` with the upload and the three
    # settings; the returned path (or None) is sent to the viewer.
    btn.click(
        run_prediction,
        inputs=[inp, accelerator, sampling_steps, diffusion_samples],
        outputs=out
    )
    # Bundled examples: each entry is one example `.fasta` path, and `fn`
    # maps it to its precomputed `.pdb` via `on_example_click` rather than
    # running a prediction.
    gr.Examples(
                examples=[[fasta_file] for fasta_file in example_fasta_files],
                inputs=[inp],
                outputs=out, 
                fn=lambda fasta_file: on_example_click(fasta_file),
                cache_examples=True
            )

    

# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch(share=True, debug=True)