File size: 2,198 Bytes
07cd985
d4e347d
 
e9c831c
 
 
 
 
 
d4e347d
e9c831c
d4e347d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e9c831c
d4e347d
 
e9c831c
d4e347d
 
 
 
 
 
 
 
 
e9c831c
d4e347d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
07cd985
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import gradio as gr
import os
from gradio_molecule3d import Molecule3D
import sys
import os
# from boltz.main_test import predict  # Import your predict function
current_dir = os.path.dirname(os.path.abspath(__file__))
src_path = os.path.join(current_dir, "boltz", "src")
sys.path.append(src_path)
from boltz.main_test import predict  # Import your predict function
# from boltz.src.boltz.main_test import predict
# Representation list handed to the Molecule3D viewer below. It holds one
# entry: a "stick" style with "whiteCarbon" coloring for model 0. The keys
# mirror the representation schema of gradio_molecule3d; the empty-string
# selectors and "visible": False are passed through unchanged
# (NOTE(review): confirm their exact meaning against the component docs).
reps = [
    dict(
        model=0,
        chain="",
        resname="",
        style="stick",
        color="whiteCarbon",
        residue_range="",
        around=0,
        byres=False,
        visible=False,
    )
]

# Your prediction function
def _find_output_pdb(out_dir, started_at):
    """Locate the PDB file produced by a prediction run, or return ``None``.

    ``<out_dir>/output.pdb`` is preferred whenever it exists. Otherwise this
    falls back to the most recently modified ``*.pdb`` file anywhere under
    ``out_dir`` whose modification time is at or after ``started_at``. The
    time filter keeps files left over from earlier runs from being shown.

    Args:
        out_dir: Directory the prediction was written into.
        started_at: ``time.time()`` value captured just before the run began.

    Returns:
        Path to the chosen PDB file, or ``None`` if none qualifies.
    """
    expected = os.path.join(out_dir, "output.pdb")
    if os.path.exists(expected):
        return expected

    # The exact output layout is decided by predict(); search recursively
    # rather than assuming one fixed file name.
    candidates = []
    for root, _dirs, files in os.walk(out_dir):
        for fname in files:
            if not fname.lower().endswith(".pdb"):
                continue
            path = os.path.join(root, fname)
            mtime = os.path.getmtime(path)
            if mtime >= started_at:
                candidates.append((mtime, path))
    return max(candidates)[1] if candidates else None


def run_prediction(input_file):
    """Run a structure prediction on an uploaded FASTA file.

    Args:
        input_file: Value produced by the ``gr.File`` input. Depending on the
            Gradio version this is either a filesystem path (``str`` /
            ``os.PathLike``) or a tempfile-like object exposing ``.name``.
            It is ``None`` when the button is clicked with nothing uploaded.

    Returns:
        Path to the generated PDB file for ``Molecule3D`` to load. Returns
        ``None`` when no file was uploaded or no PDB output could be found.
    """
    import time  # local import keeps the module's top-level imports unchanged

    if input_file is None:
        print("No input file uploaded")
        return None

    # Newer Gradio passes a path string; older versions passed a tempfile
    # wrapper with ``.name``. Don't use ``.name`` on a Path: that is only the
    # basename, not the full path.
    if isinstance(input_file, (str, os.PathLike)):
        data = os.fspath(input_file)
    else:
        data = input_file.name

    out_dir = "./predict"  # Set your output directory
    accelerator = "cpu"
    sampling_steps = 1
    output_format = "pdb"
    # NOTE(review): the previous version also defined cache="~/.boltz" and
    # diffusion_samples=1 but never passed them, so predict() uses its own
    # defaults for those. Pass them explicitly here if that was intended.

    started_at = time.time()
    predict(
        data=data,
        out_dir=out_dir,
        accelerator=accelerator,
        sampling_steps=sampling_steps,
        output_format=output_format,
    )

    output_pdb_path = _find_output_pdb(out_dir, started_at)
    if output_pdb_path is not None:
        print("Generated PDB file found:", output_pdb_path)
        return output_pdb_path  # Return the path for Molecule3D to load
    print("Generated PDB file not found")
    return None

# Web UI: a FASTA upload field, a 3D viewer for the predicted structure, and a
# button that feeds the upload through run_prediction into the viewer.
with gr.Blocks() as demo:
    gr.Markdown(value="# Molecule3D - Upload a .fasta File for Prediction")

    # Blocks renders components in creation order. Keep this sequence:
    # upload field, then viewer, then the trigger button.
    fasta_upload = gr.File(label="Upload a .fasta File", file_types=[".fasta"])
    structure_viewer = Molecule3D(label="Generated Molecule", reps=reps)
    predict_button = gr.Button(value="Predict")

    # On click, the uploaded file goes to run_prediction, and the path it
    # returns is displayed in the viewer.
    predict_button.click(
        fn=run_prediction,
        inputs=inp if False else fasta_upload,
        outputs=structure_viewer,
    )

# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()