Spaces: Running on Zero
import gradio as gr | |
import os | |
from gradio_molecule3d import Molecule3D | |
import sys | |
import os | |
# Make the vendored Boltz sources (./boltz/src next to this file) importable.
# Insert at the FRONT of sys.path so this checkout takes precedence over any
# pip-installed `boltz` package; with append() an installed copy would be found
# first, and it presumably does not ship the custom `main_test` module.
current_dir = os.path.dirname(os.path.abspath(__file__))
src_path = os.path.join(current_dir, "boltz", "src")
if src_path not in sys.path:  # avoid duplicate entries if the script is re-run
    sys.path.insert(0, src_path)

# Must come after the sys.path tweak above.
from boltz.main_test import predict  # noqa: E402
# Representation layers handed to the Molecule3D viewer below. There is a
# single layer: model 0, drawn as sticks with the "whiteCarbon" color scheme.
# The keys follow the gradio_molecule3d `reps` schema. Empty strings leave
# that selector unset (presumably "match everything" — confirm in the
# component docs).
reps = [
    dict(
        model=0,
        chain="",
        resname="",
        style="stick",
        color="whiteCarbon",
        residue_range="",
        around=0,
        byres=False,
        visible=False,
    )
]
# Your prediction function | |
def run_prediction(input_file): | |
# Assuming `input_file` is a Gradio `File` object | |
data = input_file.name # Get the path to the uploaded .fasta file | |
out_dir = "./predict" # Set your output directory | |
cache = "~/.boltz" | |
accelerator = "cpu" | |
sampling_steps = 1 | |
diffusion_samples = 1 | |
output_format = "pdb" | |
# Call your original predict function | |
predict( | |
data=data, | |
out_dir=out_dir, | |
accelerator=accelerator, | |
sampling_steps=sampling_steps, | |
output_format=output_format | |
) | |
# Fetch the generated PDB file | |
output_pdb_path = os.path.join(out_dir, "output.pdb") | |
if os.path.exists(output_pdb_path): | |
print("Generated PDB file found:", output_pdb_path) | |
return output_pdb_path # Return the path for Molecule3D to load | |
else: | |
print("Generated PDB file not found") | |
return None | |
# Gradio UI: upload a FASTA file, run the prediction when Predict is clicked,
# and render the returned PDB path in a Molecule3D viewer.
with gr.Blocks() as demo:
    gr.Markdown("# Molecule3D - Upload a .fasta File for Prediction")

    # Input: the file picker is restricted to the .fasta extension.
    inp = gr.File(file_types=[".fasta"], label="Upload a .fasta File")
    # Output: 3D viewer that uses the representation layers from `reps`.
    out = Molecule3D(reps=reps, label="Generated Molecule")
    btn = gr.Button(value="Predict")

    # On click, the uploaded file is passed to run_prediction. Its return
    # value (a PDB path or None) becomes the viewer's value.
    btn.click(fn=run_prediction, inputs=[inp], outputs=[out])

if __name__ == "__main__":
    demo.launch()