silterra commited on
Commit
6daa4cc
1 Parent(s): 74830dc

Have gradio app call inference.py

Browse files
Files changed (3) hide show
  1. Dockerfile +1 -1
  2. main.py +33 -26
  3. run_utils.py +85 -0
Dockerfile CHANGED
@@ -3,7 +3,7 @@ FROM silterra/diffdock-pocket-dev
3
  USER $APPUSER
4
  WORKDIR $HOME/app
5
 
6
- COPY --chown=$APPUSER . $HOME/app
7
 
8
  # Expose port for web service
9
  ENV PORT=7860
 
3
  USER $APPUSER
4
  WORKDIR $HOME/app
5
 
6
+ COPY --chown=$APPUSER: . $HOME/app
7
 
8
  # Expose port for web service
9
  ENV PORT=7860
main.py CHANGED
@@ -1,45 +1,52 @@
1
- import gradio as gr
2
- import torch
3
-
4
- if False:
5
- import requests
6
- from torchvision import transforms
7
- model = torch.hub.load("pytorch/vision:v0.6.0", "resnet18", pretrained=True).eval()
8
- response = requests.get("https://git.io/JJkYN")
9
- labels = response.text.split("\n")
10
 
 
11
 
12
- def predict(inp, *args, **kwargs):
13
- inp = transforms.ToTensor()(inp).unsqueeze(0)
14
- with torch.no_grad():
15
- prediction = torch.nn.functional.softmax(model(inp)[0], dim=0)
16
- confidences = {labels[i]: float(prediction[i]) for i in range(1000)}
17
- return confidences
18
 
19
 
20
- def calculate(*args, **kwargs) -> str:
21
- output_file_path = "main_output.txt"
22
- with open(output_file_path, "w") as fi:
23
- fi.write(f"args: {args}\n")
24
- fi.write(f"kwargs: {kwargs}\n")
25
- return output_file_path
26
 
27
 
28
  def run():
29
  iface = gr.Interface(
30
- fn=calculate,
31
  inputs=[
32
  gr.File(label="Protein PDB", file_types=[".pdb"]),
33
  gr.File(label="Ligand SDF", file_types=[".sdf"]),
34
- gr.Number(label="Samples Per Complex", value=4, minimum=1, maximum=100, precision=0),
 
 
 
 
 
 
35
  gr.Checkbox(label="Keep Local Structures", value=True),
36
- gr.Checkbox(label="Save Visualization", value=True)
37
  ],
38
- outputs=gr.File(label="Result")
39
  )
40
 
41
- iface.launch(server_name="0.0.0.0", server_port=7860)
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
 
44
  if __name__ == "__main__":
 
 
 
45
  run()
 
1
+ import logging
2
+ import os.path
 
 
 
 
 
 
 
3
 
4
+ import gradio as gr
5
 
6
+ import run_utils
 
 
 
 
 
7
 
8
 
9
+ def run_wrapper(protein_file, ligand_file, *args, **kwargs) -> str:
10
+ return run_utils.run_cli_command(protein_file.name, ligand_file.name,
11
+ *args, **kwargs)
 
 
 
12
 
13
 
14
  def run():
15
  iface = gr.Interface(
16
+ fn=run_wrapper,
17
  inputs=[
18
  gr.File(label="Protein PDB", file_types=[".pdb"]),
19
  gr.File(label="Ligand SDF", file_types=[".sdf"]),
20
+ gr.Number(
21
+ label="Samples Per Complex",
22
+ value=1,
23
+ minimum=1,
24
+ maximum=100,
25
+ precision=0,
26
+ ),
27
  gr.Checkbox(label="Keep Local Structures", value=True),
28
+ gr.Checkbox(label="Save Visualisation", value=True),
29
  ],
30
+ outputs=gr.File(label="Result"),
31
  )
32
 
33
+ iface.launch(server_name="0.0.0.0", server_port=7860, share=False)
34
+
35
+
36
+ def set_env_variables():
37
+ if "DiffDock-Pocket-Dir" not in os.environ:
38
+ work_dir = os.path.abspath(os.path.join("../DiffDock-Pocket"))
39
+ if os.path.exists(work_dir):
40
+ os.environ["DiffDock-Pocket-Dir"] = work_dir
41
+ else:
42
+ raise ValueError(f"DiffDock-Pocket-Dir {work_dir} not found")
43
+
44
+ if "LOG_LEVEL" not in os.environ:
45
+ os.environ["LOG_LEVEL"] = "INFO"
46
 
47
 
48
  if __name__ == "__main__":
49
+ set_env_variables()
50
+ run_utils.configure_logging()
51
+
52
  run()
run_utils.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import os
3
+ import shutil
4
+ import subprocess
5
+ import tempfile
6
+ import uuid
7
+
8
+ import logging
9
+
10
+
11
+ def configure_logging(level=None):
12
+ if level is None:
13
+ level = getattr(logging, os.environ.get("LOG_LEVEL", "INFO"))
14
+
15
+ # Note that this sets the universal logger,
16
+ # which includes other libraries.
17
+ logging.basicConfig(
18
+ level=level,
19
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
20
+ handlers=[
21
+ logging.StreamHandler(), # Outputs logs to stderr by default
22
+ # If you also want to log to a file, uncomment the following line:
23
+ # logging.FileHandler('my_app.log', mode='a', encoding='utf-8')
24
+ ]
25
+ )
26
+
27
+
28
+ def run_cli_command(protein_path: str, ligand: str, samples_per_complex: int,
29
+ keep_local_structures: bool, save_visualisation: bool, work_dir=None):
30
+
31
+ if work_dir is None:
32
+ work_dir = os.environ.get("DiffDock-Pocket-Dir",
33
+ os.path.join(os.environ["HOME"], "DiffDock-Pocket"))
34
+
35
+ command = ["python3", "inference.py", f"--protein_path={protein_path}", f"--ligand={ligand}",
36
+ f"--samples_per_complex={samples_per_complex}"]
37
+
38
+ # Adding boolean arguments only if they are True
39
+ if keep_local_structures:
40
+ command.append("--keep_local_structures")
41
+ if save_visualisation:
42
+ command.append("--save_visualisation")
43
+
44
+ with tempfile.TemporaryDirectory() as temp_dir:
45
+ temp_dir_path = temp_dir
46
+ logging.debug(f"temp dir: {temp_dir}")
47
+ command.append(f"--out_dir={temp_dir_path}")
48
+
49
+ # Convert command list to string for printing
50
+ command_str = " ".join(command)
51
+ logging.info(f"Executing command: {command_str}")
52
+
53
+ # Running the command
54
+ try:
55
+ result = subprocess.run(
56
+ command, cwd=work_dir, check=False, text=True, capture_output=True, env=os.environ
57
+ )
58
+ logging.debug(f"Command output:\n{result.stdout}")
59
+ if result.stderr:
60
+ logging.error(f"Command error:\n{result.stderr}")
61
+ except subprocess.CalledProcessError as e:
62
+ logging.error(f"An error occurred while executing the command: {e}")
63
+
64
+ # Zip the output directory
65
+ # Generate a unique filename using a timestamp and a UUID
66
+ timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
67
+ uuid_tag = str(uuid.uuid4())[0:8]
68
+ unique_filename = f"output_{timestamp}_{uuid_tag}"
69
+ zip_base_name = os.path.join("tmp", unique_filename)
70
+ full_zip_path = shutil.make_archive(zip_base_name, 'zip', temp_dir)
71
+
72
+ logging.debug(f"Directory '{temp_dir}' zipped to '{full_zip_path}'")
73
+
74
+ return full_zip_path
75
+
76
+
77
+ if False and __name__ == "__main__":
78
+ # Testing code
79
+ work_dir = os.path.expanduser("~/Projects/DiffDock-Pocket")
80
+ os.environ["DiffDock-Pocket-Dir"] = work_dir
81
+ protein_path = os.path.join(work_dir, "example_data", "3dpf_protein.pdb")
82
+ ligand = os.path.join(work_dir, "example_data", "3dpf_ligand.sdf")
83
+
84
+ run_cli_command(protein_path, ligand, samples_per_complex=1,
85
+ keep_local_structures=True, save_visualisation=True)