Spaces:
Sleeping
Sleeping
from math import ceil | |
from re import match | |
import seaborn as sns | |
from model import Model | |
import matplotlib.pyplot as plt | |
import pandas as pd | |
import seaborn as sns | |
from model import Model | |
class Data:
    """Container for input and output data.

    An instance stores the parsed sequence (``seq``), the substitutions to
    score (``sub``, ``resi``), the running mode (``mode``: 'MUT', 'DMS' or
    'TMS'), the score table (``out``) and its rendered form (``out_str``).
    """
    # Initialise empty model as static class member for efficiency:
    # the attribute lives on the class, so it is created once when the class
    # body is executed and is visible from every instance
    model = Model()
def parse_seq(self, src: str):
    """Parse input sequence.

    Upper-cases ``src``, removes every whitespace character (spaces, tabs,
    line feeds, carriage returns) and stores the result in ``self.seq``.

    Raises:
        RuntimeError: if nothing is left after cleaning, or if a character
            is not contained in ``self.model.alphabet``.
    """
    # split()/join drops all whitespace runs, so multi-line input pasted with
    # Windows line endings or column spacing is accepted as well
    self.seq = ''.join(src.upper().split())
    if not self.seq:
        # all() over an empty sequence is True, so this must be checked
        # explicitly before the alphabet test
        raise RuntimeError("Empty sequence")
    if not all(x in self.model.alphabet for x in self.seq):
        raise RuntimeError("Unrecognised characters in sequence")
def parse_sub(self, trg: str):
    """Parse input substitutions and select the running mode.

    The cleaned input is classified as follows:

    * one token as long as the sequence: sequence-vs-sequence comparison,
      mode 'MUT', one substitution per differing position;
    * only numbers: deep mutational scan of those positions, mode 'DMS';
    * only tokens of the form X#Y: explicit substitutions, mode 'MUT';
    * anything else (including empty input): scan of every position,
      mode 'TMS'.

    Sets ``self.mode``, ``self.trg``, ``self.resi`` (one 1-based position per
    substitution) and ``self.sub`` (DataFrame with a single column '0').

    Raises:
        RuntimeError: if a position lies outside the sequence, or if the
            source residue of an X#Y token differs from the sequence.
    """
    residues = "ACDEFGHIKLMNPQRSTVWY"
    self.mode = None
    self.sub = list()
    self.trg = trg.strip().upper()
    self.resi = list()
    tokens = self.trg.split()

    def checked(resi: int) -> int:
        # Position 0 would silently wrap around to the last residue and a
        # too large one would surface as a bare IndexError
        if not 1 <= resi <= len(self.seq):
            raise RuntimeError(f"Residue position {resi} is outside the sequence (1-{len(self.seq)})")
        return resi

    # Identify running mode
    if len(tokens) == 1 and len(tokens[0]) == len(self.seq) and all(match(r'\w+', x) for x in tokens[0]):
        # If single string of same length as sequence, seq vs seq mode
        self.mode = 'MUT'
        for resi, (src, dst) in enumerate(zip(self.seq, tokens[0]), 1):
            if src != dst:
                self.sub.append(f"{src}{resi}{dst}")
                self.resi.append(resi)
    else:
        self.trg = tokens
        # \Z anchors the pattern at the end of the token: match() alone only
        # checks a prefix, which let tokens such as '12A' through to int()
        if tokens and all(match(r'\d+\Z', x) for x in tokens):
            # If all strings are numbers, deep mutational scanning mode
            self.mode = 'DMS'
            for resi in map(int, tokens):
                src = self.seq[checked(resi) - 1]
                for dst in residues.replace(src, ''):
                    self.sub.append(f"{src}{resi}{dst}")
                    self.resi.append(resi)
        elif tokens and all(match(r'[A-Z]\d+[A-Z]\Z', x) for x in tokens):
            # If all strings are of the form X#Y, single substitution mode
            self.mode = 'MUT'
            self.sub = tokens
            self.resi = [int(x[1:-1]) for x in tokens]
            for token, resi in zip(tokens, self.resi):
                found = self.seq[checked(resi) - 1]
                if found != token[0]:
                    raise RuntimeError(f"Unrecognised input substitution {found}{resi} /= {token[0]}{resi}")
        else:
            # Everything else, including empty input, scans the whole sequence
            self.mode = 'TMS'
            for resi, src in enumerate(self.seq, 1):
                for dst in residues.replace(src, ''):
                    self.sub.append(f"{src}{resi}{dst}")
                    self.resi.append(resi)
    self.sub = pd.DataFrame(self.sub, columns=['0'])
def __init__(self, src:str, trg:str, model_name:str='facebook/esm2_t33_650M_UR50D', scoring_strategy:str='masked-marginals', out_file=None):
    """Initialise data.

    Args:
        src: input sequence, handed to ``parse_seq``.
        trg: substitution specification, handed to ``parse_sub``.
        model_name: name of the model to score with.
        scoring_strategy: stored as ``self.scoring_strategy``.
        out_file: output destination; an object with a ``name`` attribute
            is replaced by that name, anything else is stored unchanged.
    """
    # if model has changed, load new model and keep it on the class, so that
    # later instances asking for the same model reuse it instead of reloading
    cached = type(self).model
    if cached.model_name != model_name:
        cached = Model(model_name)
        type(self).model = cached
    # pin the model on the instance: a later swap of the class-level model
    # must not change what this instance runs
    self.model = cached
    # always record the name, it labels the score column of ``self.out``
    self.model_name = model_name
    self.parse_seq(src)
    self.offset = 0
    self.parse_sub(trg)
    self.scoring_strategy = scoring_strategy
    self.token_probs = None
    self.out = pd.DataFrame(self.sub, columns=['0', self.model_name])
    self.out_str = None
    self.out_buffer = getattr(out_file, 'name', out_file)
def parse_output(self) -> None:
    """Format output data for visualisation.

    'TMS' results are delegated to ``process_tms_mode``. 'DMS' and 'MUT'
    results are sorted, written as CSV to ``self.out_buffer`` when one is
    set, and rendered as a colour-graded HTML table into ``self.out_str``.

    Raises:
        RuntimeError: if ``self.mode`` is none of 'TMS', 'DMS', 'MUT'.
    """
    mode = self.mode
    if mode == 'TMS':
        self.process_tms_mode()
        return
    if mode not in ('DMS', 'MUT'):
        raise RuntimeError(f"Unrecognised mode {self.mode}")

    if mode == 'DMS':
        self.sort_by_residue_and_score()
    else:
        self.sort_by_score()

    if self.out_buffer:
        rounded = self.out.round(2)
        rounded.to_csv(self.out_buffer, index=False, header=False)

    # floats are shown with two decimals, every other cell as it is
    styled = self.out.style.format(lambda x: f'{x:.2f}' if isinstance(x, float) else x)
    styled = styled.hide(axis=0).hide(axis=1)
    styled = styled.background_gradient(cmap="RdYlGn", vmax=8, vmin=-8)
    self.out_str = styled.to_html(justify='center')
def sort_by_score(self):
    """Order ``self.out`` by the model score column, highest score first."""
    ranked = self.out.sort_values(by=self.model_name, ascending=False)
    self.out = ranked
def sort_by_residue_and_score(self):
    """Lay out the scored substitutions as one column pair per residue.

    Rows are ordered by residue number, then by descending score, and at
    most 19 rows are kept per residue. The consecutive blocks of 19 rows are
    then placed side by side and the columns renumbered from 0.
    """
    positions = self.out['0'].str.extract(r'(\d+)', expand=False).astype(int)
    ranked = self.out.assign(resi=positions)
    ranked = ranked.sort_values(['resi', self.model_name], ascending=[True, False])
    ranked = ranked.groupby(['resi']).head(19).drop(['resi'], axis=1)
    self.out = ranked

    n_blocks = ranked.shape[0] // 19
    blocks = []
    for k in range(n_blocks):
        # indices are reset so that the blocks align row by row
        blocks.append(ranked.iloc[19*k:19*(k+1)].reset_index(drop=True))
    side_by_side = pd.concat(blocks, axis=1)
    self.out = side_by_side.set_axis(range(n_blocks*2), axis='columns')
def process_tms_mode(self):
    """Turn the scan results into a normalised matrix and plot it.

    The table is reshaped by ``assign_resi_and_group`` and
    ``concat_and_set_axis``, divided by its largest absolute value, and then
    drawn with a row width chosen among the candidates of
    ``calculate_divs`` (the one closest to 60).
    """
    self.out = self.assign_resi_and_group()
    self.out = self.concat_and_set_axis()
    peak = self.out.abs().max().max()
    self.out /= peak

    candidates = self.calculate_divs()
    ncols = min(candidates, key=lambda width: abs(width - 60))
    nrows = ceil(self.out.shape[1] / ncols)
    ncols = self.adjust_ncols(ncols, nrows)
    self.plot_heatmap(ncols, nrows)
def assign_resi_and_group(self):
    """Return ``self.out`` with a ``resi`` column, capped at 19 rows per residue.

    ``resi`` is the first run of digits found in column '0', as an integer.
    """
    positions = self.out['0'].str.extract(r'(\d+)', expand=False).astype(int)
    tagged = self.out.assign(resi=positions)
    return tagged.groupby(['resi']).head(19)
def concat_and_set_axis(self):
    """Reshape the long score table into a 20-row matrix, one column per residue.

    Each consecutive block of 19 rows is completed by ``create_dataframe``,
    sorted by the substitution string in column '0', stripped of the 'resi'
    and '0' columns, indexed by the twenty amino acid letters and cast to
    float. The blocks are joined side by side and the columns are labelled
    with residue letter plus 1-based position taken from ``self.seq``.
    """
    amino_acids = ['A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L',
                   'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y']
    n_sites = self.out.shape[0] // 19
    site_columns = []
    for k in range(n_sites):
        block = self.out.iloc[19*k:19*(k+1)]
        block = self.create_dataframe(block)
        block = block.sort_values(['0'], ascending=[True])
        block = block.drop(['resi', '0'], axis=1)
        block = block.set_axis(amino_acids)
        site_columns.append(block.astype(float))
    matrix = pd.concat(site_columns, axis=1)
    labels = [f'{a}{i}' for i, a in enumerate(self.seq, 1)]
    return matrix.set_axis(labels, axis='columns')
def create_dataframe(self, df):
    """Return ``df`` with an extra first row for the unchanged residue.

    The new row takes the first entry of the first column, replaces its last
    character by its first one (e.g. 'M1A' becomes 'M1M') and carries 0 in
    the two remaining columns. The index of the result is renumbered.
    """
    first_sub = df.iloc[0, 0]
    unchanged = first_sub[:-1] + first_sub[0]
    extra_row = pd.Series([unchanged, 0, 0], index=df.columns).to_frame().T
    return pd.concat([extra_row, df], axis=0, ignore_index=True)
def calculate_divs(self):
    """Return the divisors of the column count that lie between 30 and 60.

    Falls back to ``[60]`` when the column count has no divisor in that
    range.
    """
    width = self.out.shape[1]
    divisors = [x for x in range(30, min(width, 60) + 1) if width % x == 0]
    return divisors if divisors else [60]
def adjust_ncols(self, ncols, nrows):
    """Shrink the row width as far as ``nrows`` rows still hold every column.

    Args:
        ncols: starting number of matrix columns per heatmap row.
        nrows: number of heatmap rows.

    Returns:
        The smallest width, not below 46, for which ``nrows`` rows still
        cover all columns of ``self.out``; ``ncols`` itself when it cannot
        be reduced.
    """
    total = self.out.shape[1]
    # Stop before the width that would no longer fit, instead of overshooting
    # and adding one back: the unconditional "+ 1" also widened widths that
    # divide the column count exactly, giving rows of e.g. 61 and 59 columns
    # for a width of 60
    while ncols > 46 and (ncols - 1) * nrows >= total:
        ncols -= 1
    return ncols
def plot_heatmap(self, ncols, nrows):
    """Draw the heatmap and, if an output path is set, export it as SVG.

    A single panel is drawn when ``nrows`` is below 2, otherwise one panel
    per row. With ``self.out_buffer`` set, the figure is saved there as SVG,
    closed, and the file content is stored in ``self.out_str``. Without an
    output path the figure is left open as the current pyplot figure.

    Args:
        ncols: number of matrix columns per heatmap row.
        nrows: number of heatmap rows.
    """
    if nrows < 2:
        self.plot_single_heatmap()
    else:
        self.plot_multiple_heatmaps(ncols, nrows)
    if self.out_buffer:
        try:
            plt.savefig(self.out_buffer, format='svg')
        finally:
            # pyplot keeps every figure until it is closed explicitly, so
            # repeated calls would otherwise pile up figures in memory
            plt.close()
        with open(self.out_buffer, 'r', encoding='utf-8') as f:
            self.out_str = f.read()
def plot_single_heatmap(self):
    """Draw all of ``self.out`` as one heatmap on a new 12 x 6 figure.

    Cells equal to 0 are annotated with a centred dot, every other cell
    with a blank.
    """
    fig = plt.figure(figsize=(12, 6))
    markers = self.out.map(lambda x: ' ' if x != 0 else '·')
    sns.heatmap(
        self.out,
        cmap='RdBu',
        cbar=False,
        square=True,
        xticklabels=1,
        yticklabels=1,
        center=0,
        annot=markers,
        fmt='s',
        annot_kws={'size': 'xx-large'},
    )
    fig.tight_layout()
def plot_multiple_heatmaps(self, ncols, nrows):
    """Draw ``self.out`` as ``nrows`` stacked heatmaps of ``ncols`` columns each.

    Cells equal to 0 are annotated with a centred dot, every other cell
    with a blank. Row labels are kept horizontal, column labels are rotated
    by 90 degrees.
    """
    fig, axes = plt.subplots(nrows=nrows, figsize=(12, 6*nrows))
    for row in range(nrows):
        chunk = self.out.iloc[:, row*ncols:(row+1)*ncols]
        markers = chunk.map(lambda x: ' ' if x != 0 else '·')
        panel = axes[row]
        sns.heatmap(
            chunk,
            ax=panel,
            cmap='RdBu',
            cbar=False,
            square=True,
            xticklabels=1,
            yticklabels=1,
            center=0,
            annot=markers,
            fmt='s',
            annot_kws={'size': 'xx-large'},
        )
        panel.set_yticklabels(panel.get_yticklabels(), rotation=0)
        panel.set_xticklabels(panel.get_xticklabels(), rotation=90)
    fig.tight_layout()
def calculate(self):
    """Run the model on this object, format the output and return ``self``.

    Returning the instance allows the call to be chained.
    """
    scorer = self.model
    scorer.run_model(self)
    self.parse_output()
    return self
def __str__(self):
    """Return the output table ``self.out`` converted to plain text."""
    table = self.out
    return str(table)
def __repr__(self):
    """Return the rendered output (``self.out_str``).

    ``self.out_str`` stays ``None`` until output has been rendered; in that
    case the plain-text form of ``self.out`` is returned instead, because
    ``repr()`` raises TypeError when ``__repr__`` returns a non-string.
    """
    if self.out_str is None:
        return str(self.out)
    return self.out_str