"""Gradio demo: upload a FASTA file, run a Boltz prediction, and render the result in 3D."""

import glob
import os
import sys
import time

import gradio as gr
from gradio_molecule3d import Molecule3D

# Make the vendored Boltz sources importable. Inserting at the front (rather
# than appending) lets this checkout take precedence over any pip-installed
# ``boltz`` package, which would not provide the local ``main_test`` module.
current_dir = os.path.dirname(os.path.abspath(__file__))
src_path = os.path.join(current_dir, "boltz", "src")
if src_path not in sys.path:
    sys.path.insert(0, src_path)

from boltz.main_test import predict  # noqa: E402  (requires the sys.path tweak above)

# Representation settings passed to the Molecule3D viewer
# (stick style, "whiteCarbon" coloring, applied to model 0).
reps = [
    {
        "model": 0,
        "chain": "",
        "resname": "",
        "style": "stick",
        "color": "whiteCarbon",
        "residue_range": "",
        "around": 0,
        "byres": False,
        "visible": False,
    }
]

# Prediction settings forwarded to ``predict``.
OUT_DIR = "./predict"  # relative to the current working directory
ACCELERATOR = "cpu"
SAMPLING_STEPS = 1
OUTPUT_FORMAT = "pdb"
# NOTE(review): the original script also set cache="~/.boltz" and
# diffusion_samples=1 but never passed them to ``predict``. Forward them only
# after confirming that boltz.main_test.predict accepts those keyword arguments.


def _input_path(input_file):
    """Return the filesystem path of a Gradio ``File`` upload value.

    Depending on the Gradio version, ``gr.File`` delivers either a path
    string or a file-like wrapper exposing the path as ``.name``; accept both.
    """
    return getattr(input_file, "name", input_file)


def _find_output_pdb(out_dir, started_at, slack=2.0):
    """Locate the PDB file produced by the run that began at ``started_at``.

    ``<out_dir>/output.pdb`` is checked first (the location this app has
    always assumed). Otherwise, fall back to the newest ``*.pdb`` anywhere
    under ``out_dir`` whose modification time is not older than the run
    start. The exact output layout of ``predict`` is not visible here, so a
    recursive search is used. The freshness filter prevents a leftover file
    from a previous prediction from being shown by mistake.

    Args:
        out_dir: Directory that ``predict`` was told to write into.
        started_at: ``time.time()`` value captured just before ``predict`` ran.
        slack: Seconds of tolerance for filesystems with coarse mtime resolution.

    Returns:
        Path to the PDB file, or None if no suitable file exists.
    """
    expected = os.path.join(out_dir, "output.pdb")
    if os.path.exists(expected):
        return expected

    pattern = os.path.join(out_dir, "**", "*.pdb")
    fresh = [
        path
        for path in glob.glob(pattern, recursive=True)
        if os.path.getmtime(path) >= started_at - slack
    ]
    return max(fresh, key=os.path.getmtime, default=None)


def run_prediction(input_file):
    """Run a prediction on an uploaded FASTA file and return the PDB path.

    Args:
        input_file: Value from the ``gr.File`` upload component, or None when
            the user clicks Predict without uploading anything.

    Returns:
        Path to the generated PDB file for ``Molecule3D`` to load, or None if
        no output file could be found.

    Raises:
        gr.Error: If no file was uploaded (Gradio shows the message in the UI).
    """
    if input_file is None:
        raise gr.Error("Please upload a .fasta file first.")

    data = _input_path(input_file)
    started_at = time.time()

    predict(
        data=data,
        out_dir=OUT_DIR,
        accelerator=ACCELERATOR,
        sampling_steps=SAMPLING_STEPS,
        output_format=OUTPUT_FORMAT,
    )

    output_pdb_path = _find_output_pdb(OUT_DIR, started_at)
    if output_pdb_path is not None:
        print("Generated PDB file found:", output_pdb_path)
        return output_pdb_path  # Molecule3D loads the file from this path

    print("Generated PDB file not found")
    return None


# Gradio interface: FASTA upload -> Predict button -> 3D structure viewer.
with gr.Blocks() as demo:
    gr.Markdown("# Molecule3D - Upload a .fasta File for Prediction")

    inp = gr.File(label="Upload a .fasta File", file_types=[".fasta"])
    out = Molecule3D(label="Generated Molecule", reps=reps)
    btn = gr.Button("Predict")

    btn.click(run_prediction, inputs=inp, outputs=out)

if __name__ == "__main__":
    demo.launch()