boltz / app.py
jadechoghari's picture
add boltz
e9c831c
raw
history blame
2.2 kB
import gradio as gr
import os
from gradio_molecule3d import Molecule3D
import sys
import os
# from boltz.main_test import predict # Import your predict function
# Make the vendored Boltz sources (<this script's dir>/boltz/src) importable.
# The directory is *prepended* to sys.path because entries are searched in
# order: appending would let any `boltz` distribution installed in
# site-packages shadow this checkout. The membership check keeps repeated
# imports of this module (e.g. on reload) from adding duplicate entries.
current_dir = os.path.dirname(os.path.abspath(__file__))
src_path = os.path.join(current_dir, "boltz", "src")
if src_path not in sys.path:
    sys.path.insert(0, src_path)
from boltz.main_test import predict # Import your predict function
# from boltz.src.boltz.main_test import predict
# Molecule3D representation settings handed to the viewer below: a single
# entry for model 0 drawn in "stick" style with the "whiteCarbon" colour
# scheme; chain/resname/residue_range are left empty, around is 0, and the
# byres and visible flags are both False.
reps = [
    dict(
        model=0,
        chain="",
        resname="",
        style="stick",
        color="whiteCarbon",
        residue_range="",
        around=0,
        byres=False,
        visible=False,
    ),
]
# Your prediction function
def run_prediction(input_file):
# Assuming `input_file` is a Gradio `File` object
data = input_file.name # Get the path to the uploaded .fasta file
out_dir = "./predict" # Set your output directory
cache = "~/.boltz"
accelerator = "cpu"
sampling_steps = 1
diffusion_samples = 1
output_format = "pdb"
# Call your original predict function
predict(
data=data,
out_dir=out_dir,
accelerator=accelerator,
sampling_steps=sampling_steps,
output_format=output_format
)
# Fetch the generated PDB file
output_pdb_path = os.path.join(out_dir, "output.pdb")
if os.path.exists(output_pdb_path):
print("Generated PDB file found:", output_pdb_path)
return output_pdb_path # Return the path for Molecule3D to load
else:
print("Generated PDB file not found")
return None
# Build the Gradio UI: a FASTA upload, a Predict button, and a Molecule3D
# viewer that receives whatever run_prediction returns. Components are
# created in display order, since creation order determines the layout.
with gr.Blocks() as demo:
    gr.Markdown("# Molecule3D - Upload a .fasta File for Prediction")

    # Input: file picker restricted to the .fasta extension.
    inp = gr.File(
        label="Upload a .fasta File",
        file_types=[".fasta"],
    )

    # Output: 3D viewer configured with the representation settings in `reps`.
    out = Molecule3D(
        label="Generated Molecule",
        reps=reps,
    )

    btn = gr.Button("Predict")

    # Clicking Predict sends the uploaded file to run_prediction and routes
    # its return value into the viewer.
    btn.click(fn=run_prediction, inputs=[inp], outputs=[out])

# Start the server only when this file is executed directly, not on import.
if __name__ == "__main__":
    demo.launch()