Spaces:
Sleeping
Sleeping
Massimo G. Totaro
committed on
Commit
•
fba8f5e
1
Parent(s):
82caf01
update fix
Browse files- .gitignore +3 -1
- LICENSE +11 -0
- README.md +2 -2
- app.py +90 -19
- data.py +169 -40
- instructions.md +39 -13
- model.py +74 -47
- requirements.txt +1 -1
.gitignore
CHANGED
@@ -1 +1,3 @@
|
|
1 |
-
|
|
|
|
|
|
1 |
+
Dockerfile
|
2 |
+
*.ipynb
|
3 |
+
*/
|
LICENSE
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Copyright (c) 2021, Massimo G. Totaro All rights reserved.
|
2 |
+
|
3 |
+
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
4 |
+
|
5 |
+
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
|
6 |
+
|
7 |
+
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
|
8 |
+
|
9 |
+
3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
|
10 |
+
|
11 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
README.md
CHANGED
@@ -4,10 +4,10 @@ emoji: π
|
|
4 |
colorFrom: gray
|
5 |
colorTo: red
|
6 |
sdk: gradio
|
7 |
-
sdk_version:
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
-
license:
|
11 |
---
|
12 |
|
13 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
4 |
colorFrom: gray
|
5 |
colorTo: red
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 4.8.0
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
+
license: bsd-3-clause
|
11 |
---
|
12 |
|
13 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
CHANGED
@@ -1,26 +1,97 @@
|
|
1 |
-
from model import MODELS
|
2 |
-
from data import Data
|
3 |
-
import gradio as gr
|
4 |
from tempfile import NamedTemporaryFile
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
-
|
7 |
-
|
8 |
|
9 |
def app(*argv):
|
10 |
-
|
11 |
-
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
|
|
|
25 |
if __name__ == "__main__":
|
26 |
-
|
|
|
|
|
|
|
|
|
1 |
from tempfile import NamedTemporaryFile
|
2 |
+
from gradio import Blocks, Button, Checkbox, Dropdown, Examples, File, HTML, Markdown, Textbox
|
3 |
+
|
4 |
+
from model import get_models
|
5 |
+
from data import Data
|
6 |
+
|
7 |
+
# Define scoring strategies
|
8 |
+
SCORING = ["wt-marginals", "masked-marginals"]
|
9 |
|
10 |
+
# Get available models
|
11 |
+
MODELS = get_models()
|
12 |
|
13 |
def app(*argv):
    """
    Main application function: run one scoring job for the interface.

    Positional inputs, in order: the sequence text, the substitutions text
    and the model name. An optional fourth value is read as the state of
    the "masked-marginals" checkbox; when it is not supplied, the checkbox
    component's configured value is used.

    Returns a pair: the HTML representation of the result and a visible
    download component on success, or an HTML error banner and ``None``
    when the calculation raises.
    """
    from html import escape  # local import: only needed for the error banner

    # Unpack the arguments; anything after the third input is optional
    seq, trg, model_name, *extra = argv
    use_masked = extra[0] if extra else scoring_strategy.value
    # A bool indexes SCORING: False -> "wt-marginals", True -> "masked-marginals"
    scoring = SCORING[bool(use_masked)]
    try:
        # Calculate the data based on the input parameters
        data = Data(seq, trg, model_name, scoring, out_file).calculate()
    except Exception as e:
        # UI boundary: report the failure instead of crashing the callback.
        # The message may echo user input, so escape it before embedding it.
        return f'<!DOCTYPE html><html><body><h1 style="background-color:#F70D1A;text-align:center;">Error: {escape(str(e))}</h1></body></html>', None
    # If no error occurs, return the calculated data
    return repr(data), File(value=out_file.name, visible=True)
|
28 |
|
29 |
+
# Create the Gradio interface
|
30 |
+
with open("instructions.md", "r", encoding="utf-8") as md,\
|
31 |
+
NamedTemporaryFile(mode='w+') as out_file,\
|
32 |
+
Blocks() as esm_scan:
|
33 |
+
|
34 |
+
# Define the interface components
|
35 |
+
Markdown(md.read())
|
36 |
+
seq = Textbox(
|
37 |
+
lines=2,
|
38 |
+
label="Sequence",
|
39 |
+
placeholder="FASTA sequence here...",
|
40 |
+
value=''
|
41 |
+
)
|
42 |
+
trg = Textbox(
|
43 |
+
lines=1,
|
44 |
+
label="Substitutions",
|
45 |
+
placeholder="Substitutions here...",
|
46 |
+
value=""
|
47 |
+
)
|
48 |
+
model_name = Dropdown(MODELS, label="Model", value="facebook/esm2_t30_150M_UR50D")
|
49 |
+
scoring_strategy = Checkbox(value=True, label="Use masked-marginals scoring")
|
50 |
+
btn = Button(value="Run")
|
51 |
+
out = HTML()
|
52 |
+
bto = File(
|
53 |
+
value=out_file.name,
|
54 |
+
visible=False,
|
55 |
+
label="Download",
|
56 |
+
file_count='single',
|
57 |
+
interactive=False
|
58 |
+
)
|
59 |
+
# Wire the Run button to the callback. The checkbox is listed as a fourth
# input so that its current state is sent along with the text inputs
# (the callback accepts extra positional inputs).
btn.click(
    fn=app,
    inputs=[seq, trg, model_name, scoring_strategy],
    outputs=[out, bto]
)
|
64 |
+
ex = Examples(
|
65 |
+
examples=[
|
66 |
+
[
|
67 |
+
"MVEQYLLEAIVRDARDGITISDCSRPDNPLVFVNDAFTRMTGYDAEEVIGKNCRFLQRGDINLSAVHTIKIAMLTHEPCLVTLKNYRKDGTIFWNELSLTPIINKNGLITHYLGIQKDVSAQVILNQTLHEENHLLKSNKEMLEYLVNIDALTGLHNRRFLEDQLVIQWKLASRHINTITIFMIDIDYFKAFNDTYGHTAGDEALRTIAKTLNNCFMRGSDFVARYGGEEFTILAIGMTELQAHEYSTKLVQKIENLNIHHKGSPLGHLTISLGYSQANPQYHNDQNLVIEQADRALYSAKVEGKNRAVAYREQ",
|
68 |
+
"deep mutational scanning",
|
69 |
+
"facebook/esm2_t6_8M_UR50D"
|
70 |
+
],
|
71 |
+
[
|
72 |
+
"MVEQYLLEAIVRDARDGITISDCSRPDNPLVFVNDAFTRMTGYDAEEVIGKNCRFLQRGDINLSAVHTIKIAMLTHEPCLVTLKNYRKDGTIFWNELSLTPIINKNGLITHYLGIQKDVSAQVILNQTLHEENHLLKSNKEMLEYLVNIDALTGLHNRRFLEDQLVIQWKLASRHINTITIFMIDIDYFKAFNDTYGHTAGDEALRTIAKTLNNCFMRGSDFVARYGGEEFTILAIGMTELQAHEYSTKLVQKIENLNIHHKGSPLGHLTISLGYSQANPQYHNDQNLVIEQADRALYSAKVEGKNRAVAYREQ",
|
73 |
+
"217 218 219",
|
74 |
+
"facebook/esm2_t12_35M_UR50D"
|
75 |
+
],
|
76 |
+
[
|
77 |
+
"MVEQYLLEAIVRDARDGITISDCSRPDNPLVFVNDAFTRMTGYDAEEVIGKNCRFLQRGDINLSAVHTIKIAMLTHEPCLVTLKNYRKDGTIFWNELSLTPIINKNGLITHYLGIQKDVSAQVILNQTLHEENHLLKSNKEMLEYLVNIDALTGLHNRRFLEDQLVIQWKLASRHINTITIFMIDIDYFKAFNDTYGHTAGDEALRTIAKTLNNCFMRGSDFVARYGGEEFTILAIGMTELQAHEYSTKLVQKIENLNIHHKGSPLGHLTISLGYSQANPQYHNDQNLVIEQADRALYSAKVEGKNRAVAYREQ",
|
78 |
+
"R218K R218S R218N R218A R218V R218D",
|
79 |
+
"facebook/esm2_t30_150M_UR50D",
|
80 |
+
],
|
81 |
+
[
|
82 |
+
"MVEQYLLEAIVRDARDGITISDCSRPDNPLVFVNDAFTRMTGYDAEEVIGKNCRFLQRGDINLSAVHTIKIAMLTHEPCLVTLKNYRKDGTIFWNELSLTPIINKNGLITHYLGIQKDVSAQVILNQTLHEENHLLKSNKEMLEYLVNIDALTGLHNRRFLEDQLVIQWKLASRHINTITIFMIDIDYFKAFNDTYGHTAGDEALRTIAKTLNNCFMRGSDFVARYGGEEFTILAIGMTELQAHEYSTKLVQKIENLNIHHKGSPLGHLTISLGYSQANPQYHNDQNLVIEQADRALYSAKVEGKNRAVAYREQ",
|
83 |
+
"MVEQYLLEAIVRDARDGITISDCSRPDNPLVFVNDAFTRMTGYDAEEVIGKNCRFLQRGDINLSAVHTIKIAMLTHEPCLVTLKNYRKDGTIFWNELSLTPIINKNGLITHYLGIQKDVSAQVILNQTLHEENHLLKSNKEMLEYLVNIDALTGLHNRRFLEDQLVIQWKLASRHINTITIFMIDIDYFKAFNDTYGHTAGDEALRTIAKTLNNCFMWGSDFVARYGGEEFTILAIGMTELQAHEYSTKLVQKIENLNIHHKGSPLGHLTISLGYSQANPQYHNDQNLVIEQADRALYSAKVEGKNRAVAYREQ",
|
84 |
+
"facebook/esm2_t33_650M_UR50D",
|
85 |
+
],
|
86 |
+
],
|
87 |
+
inputs=[seq,
|
88 |
+
trg,
|
89 |
+
model_name],
|
90 |
+
outputs=[out,
|
91 |
+
bto],
|
92 |
+
fn=app
|
93 |
+
)
|
94 |
|
95 |
+
# Launch the Gradio interface
|
96 |
if __name__ == "__main__":
|
97 |
+
esm_scan.launch()
|
data.py
CHANGED
@@ -1,80 +1,209 @@
|
|
|
|
|
|
|
|
|
|
1 |
from model import Model
|
|
|
|
|
|
|
2 |
import pandas as pd
|
3 |
-
|
|
|
|
|
4 |
|
5 |
class Data:
|
6 |
"""Container for input and output data"""
|
7 |
-
|
8 |
model = Model()
|
9 |
|
10 |
-
def parse_seq(self, src:str):
|
11 |
-
"
|
12 |
-
self.seq = src.strip().upper()
|
13 |
-
if not all(x in self.model.alphabet for x in
|
14 |
raise RuntimeError("Unrecognised characters in sequence")
|
15 |
|
16 |
-
def parse_sub(self, trg:str):
|
17 |
-
"
|
18 |
self.mode = None
|
19 |
self.sub = list()
|
20 |
self.trg = trg.strip().upper()
|
|
|
21 |
|
22 |
-
|
23 |
-
if len(self.trg.split()) == 1 and len(self.trg.split()[0]) == len(self.seq)
|
24 |
-
|
25 |
-
|
|
|
26 |
if src != trg:
|
27 |
self.sub.append(f"{src}{resi}{trg}")
|
|
|
28 |
else:
|
29 |
self.trg = self.trg.split()
|
30 |
-
if all(match(r'\d+', x) for x in self.trg):
|
|
|
31 |
self.mode = 'DMS'
|
32 |
for resi in map(int, self.trg):
|
33 |
src = self.seq[resi-1]
|
34 |
-
for trg in "ACDEFGHIKLMNPQRSTVWY".replace(src,''):
|
35 |
self.sub.append(f"{src}{resi}{trg}")
|
36 |
-
|
|
|
|
|
37 |
self.mode = 'MUT'
|
38 |
self.sub = self.trg
|
|
|
|
|
|
|
|
|
39 |
else:
|
40 |
-
|
41 |
-
|
|
|
|
|
|
|
|
|
42 |
self.sub = pd.DataFrame(self.sub, columns=['0'])
|
43 |
|
44 |
-
def __init__(self, src:str, trg:str, model_name:str, scoring_strategy:str, out_file):
|
45 |
"initialise data"
|
46 |
# if model has changed, load new model
|
47 |
if self.model.model_name != model_name:
|
48 |
self.model_name = model_name
|
49 |
self.model = Model(model_name)
|
50 |
self.parse_seq(src)
|
|
|
51 |
self.parse_sub(trg)
|
52 |
self.scoring_strategy = scoring_strategy
|
|
|
53 |
self.out = pd.DataFrame(self.sub, columns=['0', self.model_name])
|
54 |
-
self.
|
|
|
55 |
|
56 |
-
def parse_output(self) ->
|
57 |
"format output data for visualisation"
|
58 |
-
if self.mode == '
|
59 |
-
self.
|
60 |
-
|
61 |
-
self.
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
77 |
def calculate(self):
|
78 |
"run model and parse output"
|
79 |
self.model.run_model(self)
|
80 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from math import ceil
|
2 |
+
from re import match
|
3 |
+
import seaborn as sns
|
4 |
+
|
5 |
from model import Model
|
6 |
+
|
7 |
+
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
import pandas as pd
|
10 |
+
import seaborn as sns
|
11 |
+
|
12 |
+
from model import Model
|
13 |
|
14 |
class Data:
|
15 |
"""Container for input and output data"""
|
16 |
+
# Initialise empty model as static class member for efficiency
|
17 |
model = Model()
|
18 |
|
19 |
+
def parse_seq(self, src: str):
    """Parse input sequence.

    Accepts a bare sequence or FASTA-formatted text: lines starting with
    ``>`` are dropped, all whitespace (including ``\\r``) is removed and the
    result is upper-cased and stored in ``self.seq``.

    Raises:
        RuntimeError: if the sequence is empty or contains a character that
            is not a key of ``self.model.alphabet``.
    """
    # Drop FASTA header lines, then glue the remaining lines together
    body = [line for line in src.splitlines() if not line.lstrip().startswith('>')]
    self.seq = ''.join(''.join(body).split()).upper()
    if not self.seq:
        # all() over an empty string is True, so check emptiness explicitly
        raise RuntimeError("Empty sequence")
    if not all(x in self.model.alphabet for x in self.seq):
        raise RuntimeError("Unrecognised characters in sequence")
|
24 |
|
25 |
+
def parse_sub(self, trg: str):
    """Parse input substitutions and select the running mode.

    Reads ``self.seq`` and sets ``self.mode``, ``self.trg``, ``self.resi``
    (1-based positions, one per substitution) and ``self.sub`` (a
    single-column DataFrame, column ``'0'``, of ``<wt><pos><mt>`` labels).

    Modes, checked in this order:
      * one token as long as the sequence -> ``'MUT'``, one substitution
        per differing position;
      * only integer tokens -> ``'DMS'``, every standard substitution at
        each listed position;
      * only ``X#Y`` tokens -> ``'MUT'``, the listed substitutions;
      * anything else (including empty input) -> ``'TMS'``, every standard
        substitution at every position.

    Raises:
        RuntimeError: if a position lies outside the sequence, or the
            wild-type letter of an ``X#Y`` token does not match the sequence.
    """
    from re import fullmatch  # local import: whole-token matching

    amino_acids = "ACDEFGHIKLMNPQRSTVWY"
    self.mode = None
    self.sub = list()
    self.trg = trg.strip().upper()
    self.resi = list()

    tokens = self.trg.split()

    def check_position(pos):
        # Position 0 would silently wrap to the last residue; larger ones overflow
        if not 1 <= pos <= len(self.seq):
            raise RuntimeError(f"Residue position {pos} is outside the sequence (1-{len(self.seq)})")

    # Identify running mode
    if len(tokens) == 1 and len(tokens[0]) == len(self.seq) and fullmatch(r'\w+', tokens[0]):
        # If single string of same length as sequence, seq vs seq mode
        self.mode = 'MUT'
        for resi, (src, new) in enumerate(zip(self.seq, tokens[0]), 1):
            if src != new:
                self.sub.append(f"{src}{resi}{new}")
                self.resi.append(resi)
    else:
        self.trg = tokens
        if tokens and all(fullmatch(r'\d+', x) for x in tokens):
            # If all strings are numbers, deep mutational scanning mode
            self.mode = 'DMS'
            for resi in map(int, tokens):
                check_position(resi)
                src = self.seq[resi-1]
                for new in amino_acids.replace(src, ''):
                    self.sub.append(f"{src}{resi}{new}")
                    self.resi.append(resi)
        elif tokens and all(fullmatch(r'[A-Z]\d+[A-Z]', x) for x in tokens):
            # If all strings are of the form X#Y, single substitution mode
            self.mode = 'MUT'
            self.sub = tokens
            self.resi = [int(x[1:-1]) for x in tokens]
            for token, resi in zip(tokens, self.resi):
                check_position(resi)
                if self.seq[resi-1] != token[0]:
                    raise RuntimeError(f"Unrecognised input substitution {self.seq[resi-1]}{resi} /= {token[0]}{resi}")
        else:
            # Any other input: scan the whole sequence
            self.mode = 'TMS'
            for resi, src in enumerate(self.seq, 1):
                for new in amino_acids.replace(src, ''):
                    self.sub.append(f"{src}{resi}{new}")
                    self.resi.append(resi)

    self.sub = pd.DataFrame(self.sub, columns=['0'])
|
66 |
|
67 |
+
def __init__(self, src:str, trg:str, model_name:str='facebook/esm2_t33_650M_UR50D', scoring_strategy:str='masked-marginals', out_file=None):
    """Initialise data.

    Args:
        src: input sequence text, handed to ``parse_seq``.
        trg: substitutions text, handed to ``parse_sub``.
        model_name: name of the model to load through ``Model``.
        scoring_strategy: stored for the model run.
        out_file: path, or an object with a ``name`` attribute, to which the
            output is written; ``None`` disables writing.
    """
    # If the model has changed, load the new one. It is stored on the class,
    # not the instance, so that later instances reuse it instead of reloading.
    cls = type(self)
    if cls.model.model_name != model_name:
        cls.model = Model(model_name)
    # Always record the name: it labels the score column below
    self.model_name = model_name
    self.parse_seq(src)
    self.offset = 0
    self.parse_sub(trg)
    self.scoring_strategy = scoring_strategy
    self.token_probs = None
    self.out = pd.DataFrame(self.sub, columns=['0', self.model_name])
    self.out_str = None
    # Accept both file-like objects (use their path) and plain paths
    self.out_buffer = getattr(out_file, 'name', out_file)
|
81 |
|
82 |
+
def parse_output(self) -> None:
    """Format output data for visualisation.

    ``'TMS'`` results are delegated to ``process_tms_mode``. ``'DMS'`` and
    ``'MUT'`` results are sorted, written as CSV to ``self.out_buffer`` when
    it is set, and rendered as a colour-graded HTML table in ``self.out_str``.

    Raises:
        RuntimeError: if ``self.mode`` is none of the three known modes.
    """
    if self.mode == 'TMS':
        self.process_tms_mode()
        return
    if self.mode not in ('DMS', 'MUT'):
        raise RuntimeError(f"Unrecognised mode {self.mode}")

    if self.mode == 'DMS':
        self.sort_by_residue_and_score()
    else:
        self.sort_by_score()

    if self.out_buffer:
        rounded = self.out.round(2)
        rounded.to_csv(self.out_buffer, index=False, header=False)

    # Two-decimal floats, no index or header, red-to-green gradient on [-8, 8]
    styler = self.out.style
    styler = styler.format(lambda x: f'{x:.2f}' if isinstance(x, float) else x)
    styler = styler.hide(axis=0)
    styler = styler.hide(axis=1)
    styler = styler.background_gradient(cmap="RdYlGn", vmax=8, vmin=-8)
    self.out_str = styler.to_html(justify='center')
|
101 |
+
|
102 |
+
def sort_by_score(self):
    """Reorder ``self.out`` by the model's score column, highest score first."""
    score_column = self.model_name
    ranked = self.out.sort_values(by=score_column, ascending=False)
    self.out = ranked
|
104 |
+
|
105 |
+
def sort_by_residue_and_score(self):
    """Lay the per-position results out side by side.

    Rows are ordered by position (ascending) and score (descending), at most
    19 rows are kept per position, and the table is then cut into 19-row
    slices that are placed next to each other. The resulting columns are
    renumbered from 0.
    """
    positions = self.out['0'].str.extract(r'(\d+)', expand=False).astype(int)
    ranked = self.out.assign(resi=positions)
    ranked = ranked.sort_values(['resi', self.model_name], ascending=[True, False])
    ranked = ranked.groupby(['resi']).head(19).drop(['resi'], axis=1)
    self.out = ranked

    n_slices = ranked.shape[0] // 19
    slices = []
    for k in range(n_slices):
        slices.append(ranked.iloc[19*k:19*(k+1)].reset_index(drop=True))
    # Each slice contributes two columns: the label and the score
    side_by_side = pd.concat(slices, axis=1)
    self.out = side_by_side.set_axis(range(n_slices * 2), axis='columns')
|
113 |
+
|
114 |
+
def process_tms_mode(self):
    """Turn full-sequence scan results into a normalised heatmap.

    Reshapes ``self.out`` through ``assign_resi_and_group`` and
    ``concat_and_set_axis``, divides it by its largest absolute value, picks
    the column count closest to 60 among ``calculate_divs()``, derives the
    number of rows of panels and plots.
    """
    self.out = self.assign_resi_and_group()
    self.out = self.concat_and_set_axis()
    peak = self.out.abs().max().max()
    self.out /= peak
    candidates = self.calculate_divs()
    ncols = min(candidates, key=lambda width: abs(width - 60))
    nrows = ceil(self.out.shape[1] / ncols)
    ncols = self.adjust_ncols(ncols, nrows)
    self.plot_heatmap(ncols, nrows)
|
123 |
+
|
124 |
+
def assign_resi_and_group(self):
    """Return ``self.out`` with an integer ``resi`` column, capped at 19 rows per position.

    The position is the first run of digits in each label of column ``'0'``.
    """
    positions = self.out['0'].str.extract(r'(\d+)', expand=False).astype(int)
    with_position = self.out.assign(resi=positions)
    return with_position.groupby(['resi']).head(19)
|
128 |
+
|
129 |
+
def concat_and_set_axis(self):
    """Build the residue-by-position score matrix.

    ``self.out`` is cut into 19-row slices. Each slice gets a wild-type row
    from ``create_dataframe``, is sorted by label, stripped of the ``'resi'``
    and ``'0'`` columns, indexed by the 20 residue letters and cast to float.
    The slices are joined column-wise and the columns are labelled
    ``<residue><position>`` from ``self.seq``.
    """
    residues = ['A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L',
                'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y']
    per_position = []
    for k in range(self.out.shape[0] // 19):
        block = self.out.iloc[19*k:19*(k+1)]
        block = self.create_dataframe(block)
        block = block.sort_values(['0'], ascending=[True])
        block = block.drop(['resi', '0'], axis=1)
        block = block.set_axis(residues)
        per_position.append(block.astype(float))
    matrix = pd.concat(per_position, axis=1)
    labels = [f'{a}{i}' for i, a in enumerate(self.seq, 1)]
    return matrix.set_axis(labels, axis='columns')
|
140 |
+
|
141 |
+
def create_dataframe(self, df):
    """Return ``df`` with a leading row for the unchanged residue.

    The new row's label is the first label of ``df`` with its last character
    replaced by its first one; the other two columns are set to 0. The index
    of the result is renumbered.
    """
    first_label = df.iloc[0, 0]
    unchanged = first_label[:-1] + first_label[0]
    leading_row = pd.Series([unchanged, 0, 0], index=df.columns).to_frame().T
    return pd.concat([leading_row, df], axis=0, ignore_index=True)
|
143 |
+
|
144 |
+
def calculate_divs(self):
    """Return the divisors of the column count that lie in [30, 60], or ``[60]`` if there are none."""
    width = self.out.shape[1]
    divisors = [d for d in range(30, min(width, 60) + 1) if width % d == 0]
    return divisors if divisors else [60]
|
146 |
+
|
147 |
+
def adjust_ncols(self, ncols, nrows):
    """Shrink the per-row column count while the grid stays large enough.

    ``ncols`` is decremented as long as the columns do not fill ``nrows``
    rows exactly, ``ncols`` is above 45 and ``ncols*nrows`` still covers all
    columns. The value returned is one more than where the loop stops.
    """
    # NOTE(review): the +1 is also applied when the loop never runs - confirm this is intended
    total = self.out.shape[1]
    while True:
        keep_going = total/ncols < nrows and ncols > 45 and ncols*nrows >= total
        if not keep_going:
            break
        ncols -= 1
    return ncols + 1
|
151 |
+
|
152 |
+
def plot_heatmap(self, ncols, nrows):
    """Draw the heatmap and, when an output path is set, export it as SVG.

    One panel is drawn when ``nrows`` is below 2, otherwise ``nrows`` stacked
    panels of ``ncols`` columns each. When ``self.out_buffer`` is set, the
    current figure is saved there as SVG and the file content is read back
    into ``self.out_str``.
    """
    if nrows < 2:
        self.plot_single_heatmap()
    else:
        self.plot_multiple_heatmaps(ncols, nrows)

    if self.out_buffer:
        try:
            plt.savefig(self.out_buffer, format='svg')
        finally:
            # pyplot keeps every figure until it is closed; release it once
            # exported so repeated runs do not accumulate figures
            plt.close()
        with open(self.out_buffer, 'r', encoding='utf-8') as f:
            self.out_str = f.read()
|
162 |
+
|
163 |
+
def plot_single_heatmap(self):
    """Draw ``self.out`` as one annotated heatmap on a new 12x6 figure.

    Cells equal to 0 are annotated with a middle dot; every other cell gets
    a blank annotation.
    """
    # U+00B7 MIDDLE DOT, written as an escape so it survives re-encoding
    marker = '\u00b7'
    fig = plt.figure(figsize=(12, 6))
    sns.heatmap(self.out
                , cmap='RdBu'
                , cbar=False
                , square=True
                , xticklabels=1
                , yticklabels=1
                , center=0
                , annot=self.out.map(lambda x: ' ' if x != 0 else marker)
                , fmt='s'
                , annot_kws={'size': 'xx-large'})
    fig.tight_layout()
|
176 |
+
|
177 |
+
def plot_multiple_heatmaps(self, ncols, nrows):
    """Draw ``self.out`` as ``nrows`` stacked heatmaps of ``ncols`` columns each.

    Cells equal to 0 are annotated with a middle dot; every other cell gets
    a blank annotation. Row labels are kept horizontal and column labels
    are rotated by 90 degrees.
    """
    # U+00B7 MIDDLE DOT, written as an escape so it survives re-encoding
    marker = '\u00b7'
    fig, ax = plt.subplots(nrows=nrows, figsize=(12, 6*nrows))
    for i in range(nrows):
        # Columns shown in the i-th panel
        tmp = self.out.iloc[:, i*ncols:(i+1)*ncols]
        label = tmp.map(lambda x: ' ' if x != 0 else marker)
        sns.heatmap(tmp
                    , ax=ax[i]
                    , cmap='RdBu'
                    , cbar=False
                    , square=True
                    , xticklabels=1
                    , yticklabels=1
                    , center=0
                    , annot=label
                    , fmt='s'
                    , annot_kws={'size': 'xx-large'})
        ax[i].set_yticklabels(ax[i].get_yticklabels(), rotation=0)
        ax[i].set_xticklabels(ax[i].get_xticklabels(), rotation=90)
    fig.tight_layout()
|
196 |
+
|
197 |
def calculate(self):
    """Run the model on this object, format the output and return ``self``.

    Returning ``self`` lets the call be chained onto the constructor.
    """
    self.model.run_model(self)
    self.parse_output()
    return self
|
202 |
+
|
203 |
+
def __str__(self):
    """Return the output data as the text rendering of the ``self.out`` DataFrame."""
    return str(self.out)
|
206 |
+
|
207 |
+
def __repr__(self):
    """Return the rendered output (HTML table or SVG text) held in ``self.out_str``.

    ``self.out_str`` stays ``None`` until output has been rendered, and
    ``repr()`` raises ``TypeError`` on a non-string result, so the text
    rendering of ``self.out`` is returned in that case.
    """
    if self.out_str is None:
        return str(self.out)
    return self.out_str
|
instructions.md
CHANGED
@@ -1,13 +1,39 @@
|
|
1 |
-
# **ESM
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# **ESM-Scan**
|
2 |
+
Calculate the <u>fitness of single amino acid substitutions</u> on proteins, using a [zero-shot](https://doi.org/10.1101/2021.07.09.450648) [language model predictor](https://github.com/facebookresearch/esm)
|
3 |
+
|
4 |
+
<details>
|
5 |
+
<summary> <b> USAGE INSTRUCTIONS </b> </summary>
|
6 |
+
|
7 |
+
### **Setup**
|
8 |
+
No setup is required, just fill the input boxes with the required data and click on the `Run` button.
|
9 |
+
A list of examples can be found at the bottom of the page, click on them to autofill the fields.
|
10 |
+
If the server is not used for some time, it will go into standby.
|
11 |
+
Running a calculation resumes the tool from standby; the first run might take longer due to startup and model loading.
|
12 |
+
|
13 |
+
### **Input**
|
14 |
+
- write the protein full amino acid sequence to be analysed in the **Sequence** text box
|
15 |
+
wildcard characters (e.g. `-X.B`) can be inserted but, at the moment, visualisation cannot handle them
|
16 |
+
- write the substitutions to test in the **Substitutions** box
|
17 |
+
there are four running modes that can be used, depending on the input:
|
18 |
+
+ *single substitution* or list thereof (in the form of `R218K R218W`): the single substitution is scored
|
19 |
+
+ *residue position* or list thereof: all possible substitutions will be evaluated
|
20 |
+
+ *same-length sequence*: the differing amino acid substitutions will be evaluated, one by one
|
21 |
+
+ any other *different input*: a deep mutational scan of the full sequence will be performed
|
22 |
+
- the ESM model to use for the calculations can be chosen among those that are available on Hugging Face Model Hub;
|
23 |
+
`esm2_t33_650M_UR50D` offers the best expense-accuracy tradeoff[*](https://doi.org/10.1126/science.ade2574)
|
24 |
+
- the `masked-marginals` scoring strategy considers sequence context at inference time, being slower but more accurate;
|
25 |
+
in case of long runtimes, you can tick the box off to speed the calculations up significantly, sacrificing accuracy
|
26 |
+
- when running a deep mutational scan, it is recommended to use smaller models (8M, 35M, 150M parameters), since the runtime is significant, especially for longer sequences and the server might be overloaded;
|
27 |
+
over 30 min might be necessary for calculating a 300-residue-long sequence with larger models
|
28 |
+
in general, accuracy is influenced significantly by the scoring strategy and less so by the model size, so it is suggested to reduce the latter first when optimising for runtime;
|
29 |
+
the scoring strategy's computational cost scales with the number of substitutions tested, while the model's scales with the wild-type sequence length
|
30 |
+
- it is possible to calculate the effect of multiple concurrent substitutions, but this has to be done manually, by changing the input sequence and running the calculation again
|
31 |
+
|
32 |
+
### **Output**
|
33 |
+
Your results will be shown in a color-coded table, except for the deep mutational scan which will yield a heatmap.
|
34 |
+
The output data can be downloaded from the box at the bottom.
|
35 |
+
File extensions are not supported by the server and need to be appended to the filenames after downloading:
|
36 |
+
- `CSV` for tables
|
37 |
+
- `SVG` for full-sequence deep mutational scan
|
38 |
+
|
39 |
+
</details>
|
model.py
CHANGED
@@ -1,72 +1,99 @@
|
|
1 |
from huggingface_hub import HfApi, ModelFilter
|
2 |
import torch
|
3 |
from transformers import AutoTokenizer, AutoModelForMaskedLM
|
|
|
|
|
4 |
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
|
|
10 |
class Model:
|
11 |
-
"""Wrapper for ESM models"""
|
12 |
-
def __init__(self, model_name:str=""):
|
13 |
-
"
|
14 |
self.model_name = model_name
|
15 |
if model_name:
|
16 |
self.model = AutoModelForMaskedLM.from_pretrained(model_name)
|
17 |
self.batch_converter = AutoTokenizer.from_pretrained(model_name)
|
18 |
self.alphabet = self.batch_converter.get_vocab()
|
|
|
19 |
if torch.cuda.is_available():
|
20 |
self.model = self.model.cuda()
|
21 |
|
22 |
-
def
|
23 |
-
"
|
24 |
-
return self.
|
25 |
-
|
26 |
-
def
|
27 |
-
"
|
28 |
-
return self.
|
29 |
|
30 |
-
def __getitem__(self, key:str) -> int:
|
31 |
-
"
|
32 |
return self.alphabet[key]
|
33 |
-
|
34 |
def run_model(self, data):
|
35 |
-
"
|
36 |
def label_row(row, token_probs):
|
37 |
-
"
|
|
|
38 |
wt, idx, mt = row[0], int(row[1:-1])-1, row[-1]
|
|
|
39 |
score = token_probs[0, 1+idx, self[mt]] - token_probs[0, 1+idx, self[wt]]
|
40 |
return score.item()
|
41 |
-
|
42 |
-
batch_tokens = self<<data.seq
|
43 |
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
elif data.scoring_strategy.startswith("masked-marginals"):
|
56 |
all_token_probs = []
|
|
|
57 |
for i in range(batch_tokens.size()[1]):
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
token_probs = torch.cat(all_token_probs, dim=0).unsqueeze(0)
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
)
|
|
|
|
|
|
1 |
from huggingface_hub import HfApi, ModelFilter
|
2 |
import torch
|
3 |
from transformers import AutoTokenizer, AutoModelForMaskedLM
|
4 |
+
from transformers.tokenization_utils_base import BatchEncoding
|
5 |
+
from transformers.modeling_outputs import MaskedLMOutput
|
6 |
|
7 |
+
# Function to fetch suitable ESM models from HuggingFace Hub
|
8 |
+
def get_models() -> list[None|str]:
    """Fetch suitable ESM models from HuggingFace Hub.

    Queries the Hub for models by author "facebook" whose name matches "esm"
    and whose task is "fill-mask", sorted by "lastModified" with direction -1,
    and returns their ids.

    Raises:
        RuntimeError: if the query yields no usable model id.
    """
    query = ModelFilter(author="facebook", model_name="esm", task="fill-mask")
    listing = HfApi().list_models(filter=query, sort="lastModified", direction=-1)
    out = [m.modelId for m in listing]
    if not any(out):
        raise RuntimeError("Error while retrieving models from HuggingFace Hub")
    return out
|
23 |
|
24 |
+
# Class to wrap ESM models
|
25 |
class Model:
    """Wrapper for ESM models."""

    def __init__(self, model_name: str = ""):
        """Load selected model and tokenizer.

        With an empty ``model_name`` nothing is loaded and only the name is
        stored, which makes the instance usable as an "unloaded" placeholder.
        """
        self.model_name = model_name
        if model_name:
            self.model = AutoModelForMaskedLM.from_pretrained(model_name)
            self.batch_converter = AutoTokenizer.from_pretrained(model_name)
            self.alphabet = self.batch_converter.get_vocab()
            # Check if CUDA is available and if so, use it
            if torch.cuda.is_available():
                self.model = self.model.cuda()

    def tokenise(self, input: str) -> BatchEncoding:
        """Convert input string to batch of tokens (PyTorch tensors)."""
        return self.batch_converter(input, return_tensors="pt")

    def __call__(self, batch_tokens: torch.Tensor, **kwargs) -> MaskedLMOutput:
        """Run model on batch of tokens."""
        return self.model(batch_tokens, **kwargs)

    def __getitem__(self, key: str) -> int:
        """Get token ID from character."""
        return self.alphabet[key]

    def run_model(self, data):
        """Score every substitution in ``data.sub`` and store the scores.

        Reads ``data.seq``, ``data.scoring_strategy``, ``data.resi`` and
        ``data.sub``; writes the unmasked log-probabilities to
        ``data.token_probs`` (NumPy array) and the scores to
        ``data.out[self.model_name]``.
        """
        def label_row(row, token_probs):
            """Return the score of one ``<wt><pos><mt>`` label."""
            # Extract wild type, zero-based index and mutant type from the label
            wt, idx, mt = row[0], int(row[1:-1])-1, row[-1]
            # Score = log-probability of the mutant minus that of the wild type.
            # The token index is shifted by one relative to the residue index
            # (presumably a leading special token - confirm with the tokenizer).
            score = token_probs[0, 1+idx, self[mt]] - token_probs[0, 1+idx, self[wt]]
            return score.item()

        # Tokenise the sequence and move the tokens to the model's device;
        # the model may have been moved to CUDA in __init__
        batch_tokens = self.tokenise(data.seq).input_ids.to(self.model.device)

        # Calculate the token log-probabilities without tracking gradients
        with torch.no_grad():
            token_probs = torch.log_softmax(self(batch_tokens).logits, dim=-1)
        # Store the unmasked token probabilities in the data
        data.token_probs = token_probs.cpu().numpy()

        if data.scoring_strategy.startswith("masked-marginals"):
            # Set for O(1) membership tests; data.resi has one entry per substitution
            masked_positions = set(data.resi)
            all_token_probs = []
            for i in range(batch_tokens.size()[1]):
                if i in masked_positions:
                    # Re-run the model with only this token masked
                    batch_tokens_masked = batch_tokens.clone()
                    batch_tokens_masked[0, i] = self['<mask>']
                    with torch.no_grad():
                        masked_token_probs = torch.log_softmax(
                            self(batch_tokens_masked).logits, dim=-1
                        )
                else:
                    # Position is not scored: reuse the unmasked probabilities
                    masked_token_probs = token_probs
                all_token_probs.append(masked_token_probs[:, i])
            # Stitch the per-position rows back into one (1, tokens, vocab) tensor
            token_probs = torch.cat(all_token_probs, dim=0).unsqueeze(0)

        # Apply label_row to each row of the substitutions dataframe
        data.out[self.model_name] = data.sub.apply(
            lambda row: label_row(
                row['0'],
                token_probs,
            ),
            axis=1,
        )
|
requirements.txt
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
gradio
|
2 |
-
huggingface_hub
|
3 |
pandas
|
|
|
4 |
torch
|
5 |
transformers
|
|
|
1 |
gradio
|
|
|
2 |
pandas
|
3 |
+
seaborn
|
4 |
torch
|
5 |
transformers
|