colbyford committed on
Commit
094fd02
1 Parent(s): 620584c

Add protein visualization capabilities

Browse files
Files changed (2) hide show
  1. app.py +116 -25
  2. requirements.txt +2 -1
app.py CHANGED
@@ -1,6 +1,7 @@
1
 
2
  import gradio as gr
3
  import numpy as np
 
4
  import torch
5
  import py3Dmol
6
  from huggingface_hub import login
@@ -12,6 +13,9 @@ from esm.sdk.api import (
12
  GenerationConfig,
13
  )
14
 
 
 
 
15
  theme = gr.themes.Monochrome(
16
  primary_hue="gray",
17
  )
@@ -28,26 +32,51 @@ def get_model(model_name, token):
28
  # model = ESM3.from_pretrained(model_name, device=torch.device("cpu"))
29
  return model
30
 
31
- ## Function to render 3D structure using py3Dmol
32
- def render_pdb(pdb_string, motif_start=None, motif_end=None):
33
- view = py3Dmol.view(width=800, height=800)
34
- view.addModel(pdb_string, "pdb")
35
- view.setStyle({"cartoon": {"color": "spectrum"}})
36
- if motif_start is not None and motif_end is not None:
37
- motif_inds = np.arange(motif_start, motif_end)
38
- view.setStyle({"cartoon": {"color": "lightgrey"}})
39
- motif_res_inds = (motif_inds + 1).tolist()
40
- view.addStyle({"resi": motif_res_inds}, {"cartoon": {"color": "cyan"}})
41
- view.zoomTo()
42
- return view
43
-
44
  ## Function to get PDB data
45
  def get_pdb(pdb_id, chain_id):
46
  pdb = ProteinChain.from_rcsb(pdb_id, chain_id)
47
  # return [pdb.sequence, render_pdb(pdb.to_pdb_string())]
48
  return pdb
49
 
 
 
 
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  def scaffold(model_name, token, pdb_id, chain_id, motif_start, motif_end, prompt_length, insert_size):
52
  pdb = get_pdb(pdb_id, chain_id)
53
 
@@ -75,16 +104,39 @@ def scaffold(model_name, token, pdb_id, chain_id, motif_start, motif_end, prompt
75
  sequence_generation = model.generate(protein_prompt, sequence_generation_config)
76
  generated_sequence = sequence_generation.sequence
77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  return [
79
  pdb.sequence,
80
  motif_sequence,
 
81
  # motif_atom37_positions,
82
  sequence_prompt,
83
  # structure_prompt,
84
  # protein_prompt
85
- generated_sequence
 
 
 
86
  ]
87
 
 
88
  def ss_edit(model_name, token, pdb_id, chain_id, region_start, region_end, shortened_region_length, shortening_ss8):
89
  pdb = get_pdb(pdb_id, chain_id)
90
  edit_region = np.arange(region_start, region_end)
@@ -109,17 +161,28 @@ def ss_edit(model_name, token, pdb_id, chain_id, region_start, region_end, short
109
  model = get_model(model_name, token)
110
  sequence_generation = model.generate(protein_prompt, GenerationConfig(track="sequence", num_steps=protein_prompt.sequence.count("_") // 2, temperature=0.5))
111
 
 
 
 
 
 
 
 
 
112
  return [
113
  original_sequence,
114
  original_ss8,
115
  original_ss8_region,
 
116
  sequence_prompt,
117
  ss8_prompt,
118
  proposed_ss8_region,
119
  # protein_prompt,
120
- sequence_generation
 
121
  ]
122
 
 
123
  def sasa_edit(model_name, token, pdb_id, chain_id, span_start, span_end, n_samples):
124
  pdb = get_pdb(pdb_id, chain_id)
125
 
@@ -144,10 +207,16 @@ def sasa_edit(model_name, token, pdb_id, chain_id, span_start, span_end, n_sampl
144
  ## Sort generations by ptm
145
  generated_proteins = sorted(generated_proteins, key=lambda x: x.ptm.item(), reverse=True)
146
 
 
 
 
 
147
  return [
148
- protein_prompt,
149
- sequence_generation,
150
- generated_proteins
 
 
151
  ]
152
 
153
 
@@ -166,13 +235,14 @@ scaffold_app = gr.Interface(
166
  ],
167
  outputs=[
168
  gr.Textbox(label="Sequence"),
169
- # gr.Plot(label="3D Structure")
170
  gr.Textbox(label="Motif Sequence"),
 
171
  # gr.Textbox(label="Motif Positions")
172
  gr.Textbox(label="Sequence Prompt"),
173
  # gr.Textbox(label="Structure Prompt"),
174
  # gr.Textbox(label="Protein Prompt"),
175
- gr.Textbox(label="Generated Sequence")
 
176
  ]
177
  )
178
 
@@ -193,11 +263,13 @@ ss_app = gr.Interface(
193
  gr.Textbox(label="Original Sequence"),
194
  gr.Textbox(label="Original SS8"),
195
  gr.Textbox(label="Original SS8 Edit Region"),
 
196
  gr.Textbox(label="Sequence Prompt"),
197
  gr.Textbox(label="Edited SS8 Prompt"),
198
  gr.Textbox(label="Proposed SS8 of Edit Region"),
199
  # gr.Textbox(label="Protein Prompt"),
200
- gr.Textbox(label="Generated Sequence")
 
201
  ]
202
  )
203
 
@@ -212,15 +284,32 @@ sasa_app = gr.Interface(
212
  gr.Number(value=105, label="Span Start"),
213
  gr.Number(value=116, label="Span End"),
214
  # gr.Textbox(value="CCSSCCCCSSCHHHHHHTEEETTBBTTBCSSEEEEECCTTCCHHHHHTTTHHHHHHHTTCEEEEECCTTTTCSCHHHHHHHHHHHHHHHHHHTTSCCEEEEEETHHHHHHHHHHHHCGGGGGTEEEEEEESCCTTCBGGGHHHHHTTCBCHHHHHTBTTCHHHHHHHHTTTTBCSSCEEEEECTTCSSSCCCCSSSTTSTTCCBTSEEEEHHHHHCTTCCCCSHHHHHBHHHHHHHHHHHHCTTSSCCGGGCCSTTCCCSBCTTSCHHHHHHHHSTHHHHHHHHHHSCCBSSCCCCCGGGGGGSTTCEETTEECCC", label="SS8 String")
215
- gr.Number(value=4, label="Number of Samples")
216
  ],
217
  outputs = [
218
  gr.Textbox(label="Protein Prompt"),
 
219
  gr.Textbox(label="Generated Sequences"),
220
- gr.Textbox(label="Generated Proteins")
 
221
  ]
222
  )
223
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
  ## Main Interface
225
  with gr.Blocks(theme=theme) as esm_app:
226
  with gr.Row():
@@ -239,12 +328,14 @@ with gr.Blocks(theme=theme) as esm_app:
239
  gr.TabbedInterface([
240
  scaffold_app,
241
  ss_app,
242
- sasa_app
 
243
  ],
244
  [
245
  "Scaffolding Example",
246
  "Secondary Structure Editing Example",
247
- "SASA Editing Example"
 
248
  ])
249
 
250
  if __name__ == "__main__":
 
1
 
2
  import gradio as gr
3
  import numpy as np
4
+ import os, tempfile
5
  import torch
6
  import py3Dmol
7
  from huggingface_hub import login
 
13
  GenerationConfig,
14
  )
15
 
16
+ from gradio_molecule3d import Molecule3D
17
+
18
+
19
  theme = gr.themes.Monochrome(
20
  primary_hue="gray",
21
  )
 
32
  # model = ESM3.from_pretrained(model_name, device=torch.device("cpu"))
33
  return model
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  ## Function to get PDB data
36
  def get_pdb(pdb_id, chain_id):
37
  pdb = ProteinChain.from_rcsb(pdb_id, chain_id)
38
  # return [pdb.sequence, render_pdb(pdb.to_pdb_string())]
39
  return pdb
40
 
41
+ ## Function to generate rep for 3D structure
42
def make_reps(res_start=None, res_end=None, main_color="whiteCarbon", highlight_color="redCarbon", main_style="cartoon", highlight_style="cartoon"):
    """Build the two representation dicts passed to the 3D viewer as ``reps``.

    The first entry styles the model with ``main_style``/``main_color`` and an
    empty ``residue_range``; the second applies ``highlight_style``/
    ``highlight_color`` to the ``"start-end"`` residue range.

    Args:
        res_start: First residue of the highlighted span, or None.
        res_end: Last residue of the highlighted span, or None.
        main_color: Colour name for the base representation.
        highlight_color: Colour name for the highlighted span.
        main_style: Drawing style for the base representation.
        highlight_style: Drawing style for the highlighted span.

    Returns:
        A list of two dicts, base representation first, highlight second.
    """
    # Bounds may arrive as floats (e.g. from a numeric form field) or as None.
    # Normalise to ints so the range reads "10-20" rather than "10.0-20.0",
    # and never interpolate a literal "None" into the selection string.
    residue_range = ""
    if res_start is not None and res_end is not None:
        first, last = int(res_start), int(res_end)
        # Equal bounds keep the original behaviour: an empty range string.
        if first != last:
            residue_range = f"{first}-{last}"

    def _rep(style, color, selection):
        # Both entries share every field except style, colour and selection.
        return {
            "model": 0,
            "chain": "",
            "resname": "",
            "style": style,
            "color": color,
            "residue_range": selection,
            "around": 0,
            "byres": False,
            "visible": True,
        }

    return [
        _rep(main_style, main_color, ""),
        _rep(highlight_style, highlight_color, residue_range),
    ]
68
+
69
+ ## Function to render 3D structure
70
def render_pdb(pdb_id, chain_id, res_start, res_end, pdb_string=None):
    """Write a PDB string to a temporary file and return a Molecule3D for it.

    Args:
        pdb_id: PDB identifier; used to fetch the structure when
            ``pdb_string`` is None and as part of the temp-file prefix.
        chain_id: Chain identifier; used the same way as ``pdb_id``.
        res_start: First residue of the span to highlight.
        res_end: Last residue of the span to highlight.
        pdb_string: Optional PDB-format text to render instead of fetching
            the structure through ``get_pdb``.

    Returns:
        A ``Molecule3D`` built from the temp-file path, with representations
        from ``make_reps``.
    """
    if pdb_string is None:
        pdb_string = get_pdb(pdb_id, chain_id).to_pdb_string()

    # delete=False keeps the file on disk after the handle is closed, so its
    # path can be handed to the viewer. The `with` block guarantees the
    # buffered bytes are flushed and the handle is closed first; previously
    # the handle was left open and the viewer could see a partial file.
    with tempfile.NamedTemporaryFile(
        delete=False,
        prefix=f"{pdb_id}_chain{chain_id}_",
        suffix=".pdb",
    ) as tmp_pdb:
        tmp_pdb.write(pdb_string.encode())

    return Molecule3D(tmp_pdb.name, reps=make_reps(res_start=res_start, res_end=res_end))
78
+
79
+ ## Function for Scaffolding
80
  def scaffold(model_name, token, pdb_id, chain_id, motif_start, motif_end, prompt_length, insert_size):
81
  pdb = get_pdb(pdb_id, chain_id)
82
 
 
104
  sequence_generation = model.generate(protein_prompt, sequence_generation_config)
105
  generated_sequence = sequence_generation.sequence
106
 
107
+ ## Generate structure
108
+ structure_prediction_config = GenerationConfig(
109
+ track="structure", # We want ESM3 to generate tokens for the structure track
110
+ num_steps=len(sequence_generation) // 8,
111
+ temperature=0.7,
112
+ )
113
+ structure_prediction_prompt = ESMProtein(sequence=sequence_generation.sequence)
114
+ structure_prediction = model.generate(structure_prediction_prompt, structure_prediction_config)
115
+ ## Convert the generated structure back into a ProteinChain object
116
+ structure_prediction_chain = structure_prediction.to_protein_chain()
117
+ motif_inds_in_generation = np.arange(insert_size, insert_size+len(motif_sequence))
118
+ structure_prediction_chain.align(pdb, mobile_inds=motif_inds_in_generation, target_inds=motif_inds)
119
+ # crmsd = structure_prediction_chain.rmsd(renal_dipep_chain, mobile_inds=motif_inds_in_generation, target_inds=motif_inds)
120
+
121
+ structure_orig_highlight = render_pdb(pdb_id, chain_id, res_start=motif_start, res_end=motif_end)
122
+ structure_new_highlight = render_pdb(pdb_id, chain_id, res_start=insert_size, res_end=insert_size+len(motif_sequence),
123
+ pdb_string=structure_prediction_chain.to_pdb_string())
124
+
125
  return [
126
  pdb.sequence,
127
  motif_sequence,
128
+ structure_orig_highlight,
129
  # motif_atom37_positions,
130
  sequence_prompt,
131
  # structure_prompt,
132
  # protein_prompt
133
+ generated_sequence,
134
+ # structure_prediction,
135
+ # structure_prediction_chain,
136
+ structure_new_highlight
137
  ]
138
 
139
+ ## Function for Secondary Structure Editing
140
  def ss_edit(model_name, token, pdb_id, chain_id, region_start, region_end, shortened_region_length, shortening_ss8):
141
  pdb = get_pdb(pdb_id, chain_id)
142
  edit_region = np.arange(region_start, region_end)
 
161
  model = get_model(model_name, token)
162
  sequence_generation = model.generate(protein_prompt, GenerationConfig(track="sequence", num_steps=protein_prompt.sequence.count("_") // 2, temperature=0.5))
163
 
164
+ ## Generate structure
165
+ structure_prediction = model.generate(ESMProtein(sequence=sequence_generation.sequence), GenerationConfig(track="structure", num_steps=len(protein_prompt) // 4, temperature=0))
166
+ structure_prediction_chain = structure_prediction.to_protein_chain()
167
+
168
+ structure_orig_highlight = render_pdb(pdb_id, chain_id, res_start=region_start, res_end=region_end)
169
+ structure_new_highlight = render_pdb(pdb_id, chain_id, res_start=region_start, res_end=region_end,
170
+ pdb_string=structure_prediction_chain.to_pdb_string())
171
+
172
  return [
173
  original_sequence,
174
  original_ss8,
175
  original_ss8_region,
176
+ structure_orig_highlight,
177
  sequence_prompt,
178
  ss8_prompt,
179
  proposed_ss8_region,
180
  # protein_prompt,
181
+ sequence_generation,
182
+ structure_new_highlight
183
  ]
184
 
185
+ ## Function for SASA Editing
186
  def sasa_edit(model_name, token, pdb_id, chain_id, span_start, span_end, n_samples):
187
  pdb = get_pdb(pdb_id, chain_id)
188
 
 
207
  ## Sort generations by ptm
208
  generated_proteins = sorted(generated_proteins, key=lambda x: x.ptm.item(), reverse=True)
209
 
210
+ structure_orig_highlight = render_pdb(pdb_id, chain_id, res_start=span_start, res_end=span_end)
211
+ structure_new_highlight = render_pdb(pdb_id, chain_id, res_start=span_start, res_end=span_end,
212
+ pdb_string=generated_proteins[0].to_protein_chain().to_pdb_string())
213
+
214
  return [
215
+ protein_prompt.sequence,
216
+ structure_orig_highlight,
217
+ [seq.sequence for seq in sequence_generation],
218
+ # [pro.sequence for pro in generated_proteins]
219
+ structure_new_highlight
220
  ]
221
 
222
 
 
235
  ],
236
  outputs=[
237
  gr.Textbox(label="Sequence"),
 
238
  gr.Textbox(label="Motif Sequence"),
239
+ Molecule3D(label="Original Structure"),
240
  # gr.Textbox(label="Motif Positions")
241
  gr.Textbox(label="Sequence Prompt"),
242
  # gr.Textbox(label="Structure Prompt"),
243
  # gr.Textbox(label="Protein Prompt"),
244
+ gr.Textbox(label="Generated Sequence"),
245
+ Molecule3D(label="Generated Structure")
246
  ]
247
  )
248
 
 
263
  gr.Textbox(label="Original Sequence"),
264
  gr.Textbox(label="Original SS8"),
265
  gr.Textbox(label="Original SS8 Edit Region"),
266
+ Molecule3D(label="Original Structure"),
267
  gr.Textbox(label="Sequence Prompt"),
268
  gr.Textbox(label="Edited SS8 Prompt"),
269
  gr.Textbox(label="Proposed SS8 of Edit Region"),
270
  # gr.Textbox(label="Protein Prompt"),
271
+ gr.Textbox(label="Generated Sequence"),
272
+ Molecule3D(label="Generated Structure")
273
  ]
274
  )
275
 
 
284
  gr.Number(value=105, label="Span Start"),
285
  gr.Number(value=116, label="Span End"),
286
  # gr.Textbox(value="CCSSCCCCSSCHHHHHHTEEETTBBTTBCSSEEEEECCTTCCHHHHHTTTHHHHHHHTTCEEEEECCTTTTCSCHHHHHHHHHHHHHHHHHHTTSCCEEEEEETHHHHHHHHHHHHCGGGGGTEEEEEEESCCTTCBGGGHHHHHTTCBCHHHHHTBTTCHHHHHHHHTTTTBCSSCEEEEECTTCSSSCCCCSSSTTSTTCCBTSEEEEHHHHHCTTCCCCSHHHHHBHHHHHHHHHHHHCTTSSCCGGGCCSTTCCCSBCTTSCHHHHHHHHSTHHHHHHHHHHSCCBSSCCCCCGGGGGGSTTCEETTEECCC", label="SS8 String")
287
+ gr.Number(value=1, label="Number of Samples")
288
  ],
289
  outputs = [
290
  gr.Textbox(label="Protein Prompt"),
291
+ Molecule3D(label="Original Structure"),
292
  gr.Textbox(label="Generated Sequences"),
293
+ # gr.Textbox(label="Generated Proteins")
294
+ Molecule3D(label="Best Generated Structure")
295
  ]
296
  )
297
 
298
+
299
## Standalone viewer tab: fetch a PDB chain and highlight a residue span
protein_viewer = gr.Interface(
    fn=render_pdb,
    inputs=[
        gr.Textbox(value="1LBS", label="PDB ID"),
        gr.Textbox(value="A", label="Chain ID"),
        # precision=0 makes the field deliver integers: residue positions are
        # whole numbers and are interpolated into a "start-end" range string.
        gr.Number(value=10, precision=0, label="Residue Highlight Start"),
        gr.Number(value=20, precision=0, label="Residue Highlight End"),
    ],
    outputs=[
        Molecule3D(label="3D Structure"),
    ],
)
311
+
312
+
313
  ## Main Interface
314
  with gr.Blocks(theme=theme) as esm_app:
315
  with gr.Row():
 
328
  gr.TabbedInterface([
329
  scaffold_app,
330
  ss_app,
331
+ sasa_app,
332
+ protein_viewer
333
  ],
334
  [
335
  "Scaffolding Example",
336
  "Secondary Structure Editing Example",
337
+ "SASA Editing Example",
338
+ "PDB Viewer"
339
  ])
340
 
341
  if __name__ == "__main__":
requirements.txt CHANGED
@@ -2,4 +2,5 @@ esm
2
  numpy
3
  torch>=2.3.0
4
  py3Dmol
5
- huggingface_hub
 
 
2
  numpy
3
  torch>=2.3.0
4
  py3Dmol
5
+ huggingface_hub
6
+ gradio_molecule3d