PascalNotin's picture
Implemented first version of design app
1335bda
raw
history blame
10.9 kB
import functools

import datasets
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import transformers
from transformers import PreTrainedTokenizerFast

import tranception
from tranception import config, model_pytorch
# Special tokens handed to the fast tokenizer below.
_TOKENIZER_SPECIAL_TOKENS = {
    "unk_token": "[UNK]",
    "sep_token": "[SEP]",
    "pad_token": "[PAD]",
    "cls_token": "[CLS]",
    "mask_token": "[MASK]",
}
# Tokenizer shared by every scoring request; the path is relative to the
# current working directory.
tokenizer = PreTrainedTokenizerFast(
    tokenizer_file="./tranception/utils/tokenizers/Basic_tokenizer",
    **_TOKENIZER_SPECIAL_TOKENS,
)
#######################################################################################################################################
############################################### HELPER FUNCTIONS ####################################################################
#######################################################################################################################################
# Amino acids used as substitution targets (and heatmap rows), in display order.
AA_vocab = "ACDEFGHIKLMNPQRSTVWY"
def create_all_single_mutants(sequence,AA_vocab=AA_vocab,mutation_range_start=None,mutation_range_end=None):
    """Enumerate every single amino-acid substitution of `sequence` within a range.

    Args:
        sequence: starting protein sequence, one character per residue.
        AA_vocab: amino acids to substitute in (the wild-type residue itself is skipped).
        mutation_range_start: first position to mutate, 1-based and inclusive; defaults to 1.
        mutation_range_end: last position to mutate, 1-based and inclusive; defaults to
            len(sequence). Values past the end of the sequence are truncated to it.

    Returns:
        DataFrame with columns ['mutant', 'mutated_sequence'], where `mutant` is named
        <wild-type AA><1-based position><substitute AA> (e.g. "A12C").

    Raises:
        ValueError: if mutation_range_start is smaller than 1.
    """
    if mutation_range_start is None: mutation_range_start = 1
    if mutation_range_end is None: mutation_range_end = len(sequence)
    mutation_range_start = int(mutation_range_start)
    # Positions beyond the sequence are ignored (same truncation slicing gave before).
    mutation_range_end = min(int(mutation_range_end), len(sequence))
    if mutation_range_start < 1:
        raise ValueError("mutation_range_start must be >= 1 (positions are 1-based), got {}".format(mutation_range_start))
    all_single_mutants = {}
    # Iterate over absolute 1-based positions so both the mutated residue and the
    # mutant name refer to the true position in the full sequence.
    for position in range(mutation_range_start, mutation_range_end + 1):
        current_AA = sequence[position - 1]
        for mutated_AA in AA_vocab:
            if mutated_AA != current_AA:
                mutated_sequence = sequence[:position - 1] + mutated_AA + sequence[position:]
                all_single_mutants[current_AA + str(position) + mutated_AA] = mutated_sequence
    return pd.DataFrame(list(all_single_mutants.items()), columns=['mutant','mutated_sequence'])
def create_scoring_matrix_visual(scores,sequence,AA_vocab=AA_vocab,mutation_range_start=None,mutation_range_end=None):
    """Draw an annotated heatmap of substitution scores over a mutation range.

    Rows are the target amino acids of `AA_vocab` (in that order); columns are the
    1-based positions mutation_range_start..mutation_range_end. Each cell is annotated
    with the mutant name and its score. Mutants absent from `scores` (e.g. the
    wild-type residue at each position) are left uncoloured and annotated as 0.0000.
    The figure is also written to 'fitness_scoring_substitution_matrix.png'.

    Args:
        scores: DataFrame with 'mutant', 'position', 'target_AA' and 'avg_score' columns.
        sequence: starting protein sequence the mutants were derived from.
        AA_vocab: amino acids shown as heatmap rows.
        mutation_range_start: first plotted position (1-based, inclusive); defaults to 1.
        mutation_range_end: last plotted position (1-based, inclusive); defaults to len(sequence).

    Returns:
        The `matplotlib.pyplot` module, whose current figure holds the heatmap.
    """
    if mutation_range_start is None: mutation_range_start = 1
    if mutation_range_end is None: mutation_range_end = len(sequence)
    positions = list(range(int(mutation_range_start), int(mutation_range_end) + 1))
    target_AAs = list(AA_vocab)
    # Target AA (rows) x position (columns). Reindexing forces the full
    # AA_vocab x range grid in AA_vocab order, even when a row has no scored mutant
    # (e.g. the wild-type AA of a one-position range), so it always lines up with `labels`.
    piv = (scores.pivot(index='position',columns='target_AA',values='avg_score')
           .transpose()
           .reindex(index=target_AAs, columns=positions)
           .round(4))
    # One dict lookup per cell instead of filtering the whole frame for every cell.
    score_by_mutant = dict(zip(scores.mutant, scores.avg_score))
    label_rows = []
    for target_AA in target_AAs:
        row = []
        for position in positions:
            mutant = sequence[position-1] + str(position) + target_AA
            row.append("{} \n {:.4f}".format(mutant, float(score_by_mutant.get(mutant, 0.0))))
        label_rows.append(row)
    labels = np.asarray(label_rows)
    # Width scales with the number of plotted positions, not the whole sequence length.
    fig, ax = plt.subplots(figsize=(len(positions)*1.2,20))
    heat = sns.heatmap(piv,annot=labels,fmt="",cmap='RdYlGn',linewidths=0.30,
                       vmin=np.percentile(scores.avg_score,2),vmax=np.percentile(scores.avg_score,98),
                       cbar_kws={'label': 'Log likelihood ratio (mutant / starting sequence)'},
                       ax=ax)
    heat.figure.axes[-1].yaxis.label.set_size(20)
    heat.set_title("Higher predicted scores (green) imply higher protein fitness",fontsize=30, pad=40)
    heat.set_xlabel("Sequence position", fontsize = 20)
    heat.set_ylabel("Amino Acid mutation", fontsize = 20)
    plt.savefig('fitness_scoring_substitution_matrix.png')
    return plt
def suggest_mutations(scores):
    """Build a plain-text summary of the most promising substitutions.

    Args:
        scores: DataFrame with at least 'mutant', 'position' and 'avg_score' columns;
            any other (e.g. string) columns are ignored.

    Returns:
        A message listing the 5 highest-scoring mutants, and the 5 positions with the
        highest mean score over their positively-scored substitutions.
    """
    intro_message = "The following mutations may be sensible options to improve fitness: \n\n"
    # Best mutants: the 5 highest-scoring single substitutions overall.
    top_mutants = list(scores.sort_values(by=['avg_score'],ascending=False).head(5).mutant)
    mutant_recos = "The 5 single mutants with highest predicted fitness are:\n {} \n\n".format(", ".join(top_mutants))
    # Best positions: average only the beneficial (positive-score) substitutions.
    # Aggregate the score column alone; mean() over the whole frame fails on its
    # string columns in recent pandas.
    positive_scores = scores[scores.avg_score > 0]
    position_avg = positive_scores.groupby('position')['avg_score'].mean()
    top_positions = list(position_avg.sort_values(ascending=False).head(5).index.astype(str))
    position_list = ", ".join(top_positions) if top_positions else "none (no substitution is predicted to improve fitness)"
    position_recos = "The 5 positions with the highest average fitness increase are:\n {}".format(position_list)
    return intro_message+mutant_recos+position_recos
def get_mutated_protein(sequence,mutant):
    """Apply one substitution, written like "A12C", to `sequence`.

    Args:
        sequence: protein sequence to mutate.
        mutant: <wild-type AA><1-based position><substitute AA>, e.g. "A12C".

    Returns:
        The mutated sequence as a new string.

    Raises:
        ValueError: if the position is outside the sequence, or the residue at that
            position does not match the mutant's wild-type amino acid.
    """
    wild_type_AA, position, mutated_AA = mutant[0], int(mutant[1:-1]), mutant[-1]
    # Guard explicitly: position 0 would otherwise index -1 and silently mutate the last residue.
    if not 1 <= position <= len(sequence):
        raise ValueError("Mutant {} is outside the sequence (length {})".format(mutant, len(sequence)))
    if sequence[position-1] != wild_type_AA:
        raise ValueError("Mutant {} expects {} at position {}, found {}".format(mutant, wild_type_AA, position, sequence[position-1]))
    return sequence[:position-1] + mutated_AA + sequence[position:]
# Hugging Face Hub checkpoint for each selectable model size.
_MODEL_CHECKPOINTS = {
    "Small": "PascalNotin/Tranception_Small",
    "Medium": "PascalNotin/Tranception_Medium",
    "Large": "PascalNotin/Tranception_Large",
}


@functools.lru_cache(maxsize=1)
def _load_model(model_type):
    """Load the Tranception model for `model_type`, keeping the most recent one in memory.

    Raises:
        ValueError: if `model_type` is not one of the keys of _MODEL_CHECKPOINTS.
    """
    try:
        checkpoint = _MODEL_CHECKPOINTS[model_type]
    except KeyError:
        raise ValueError("Unknown model_type {!r}; expected one of {}".format(model_type, list(_MODEL_CHECKPOINTS))) from None
    model = tranception.model_pytorch.TranceptionLMHeadModel.from_pretrained(pretrained_model_name_or_path=checkpoint,use_auth_token=True)
    model.config.tokenizer = tokenizer
    return model


def _clean_sequence(sequence):
    """Drop an optional "NAME:" label prefix (as in the app's examples), whitespace, and upper-case the rest."""
    return "".join(sequence.rpartition(":")[2].split()).upper()


def score_and_create_matrix_all_singles(sequence,mutation_range_start=None,mutation_range_end=None,model_type="Small",scoring_mirror=False,batch_size_inference=20,num_workers=0,AA_vocab=AA_vocab):
    """Score all single substitutions in a range with Tranception and summarise them.

    Args:
        sequence: starting protein sequence; an optional "NAME: " prefix is ignored.
        mutation_range_start: first position to mutate (1-based); defaults to 1.
        mutation_range_end: last position to mutate (1-based); defaults to the sequence length.
        model_type: "Small", "Medium" or "Large".
        scoring_mirror, batch_size_inference, num_workers: forwarded to model.score_mutants.
        AA_vocab: amino acids to substitute in.

    Returns:
        (heatmap, recommendations): the pyplot module holding the score heatmap and
        the text produced by suggest_mutations.

    Raises:
        ValueError: if model_type is not a known model size.
    """
    sequence = _clean_sequence(sequence)
    # Gradio may hand over None for an empty Number field; resolve defaults once here.
    mutation_range_start = 1 if mutation_range_start is None else int(mutation_range_start)
    mutation_range_end = len(sequence) if mutation_range_end is None else int(mutation_range_end)
    model = _load_model(model_type)
    all_single_mutants = create_all_single_mutants(sequence,AA_vocab,mutation_range_start,mutation_range_end)
    scores = model.score_mutants(DMS_data=all_single_mutants,
                                 target_seq=sequence,
                                 scoring_mirror=scoring_mirror,
                                 batch_size_inference=batch_size_inference,
                                 num_workers=num_workers,
                                 indel_mode=False
                                 )
    # Recover the mutant names, then derive the position and target AA from them.
    scores = pd.merge(scores,all_single_mutants,on="mutated_sequence",how="left")
    scores["position"] = scores["mutant"].map(lambda x: int(x[1:-1]))
    scores["target_AA"] = scores["mutant"].map(lambda x: x[-1])
    score_heatmap = create_scoring_matrix_visual(scores,sequence,AA_vocab,mutation_range_start,mutation_range_end)
    return score_heatmap,suggest_mutations(scores)
#######################################################################################################################################
############################################### GRADIO INTERFACE ####################################################################
#######################################################################################################################################
# Static texts and example inputs displayed by the Gradio interface.
title = "Interactive in silico directed evolution with Tranception"
description = "Perform in silico directed evolution with Tranception to iteratively improve the fitness of a starting protein sequence one mutation at a time. At each step, the Tranception model computes the log likelihood ratios of all possible single amino acid substitutions vs. the starting sequence, and outputs a fitness heatmap and recommendations to guide the selection of the mutation to apply. Note: The current version does not leverage homolog retrieval at inference time to boost fitness prediction performance."
article = "<p style='text-align: center'><a href='https://proceedings.mlr.press/v162/notin22a.html' target='_blank'>Tranception: Protein Fitness Prediction with Autoregressive Transformers and Inference-time Retrieval</a></p>"
# Each example fills only the sequence input and is formatted as "NAME: SEQUENCE".
examples=[
    ['A4_HUMAN: MLPGLALLLLAAWTARALEVPTDGNAGLLAEPQIAMFCGRLNMHMNVQNGKWDSDPSGTKTCIDTKEGILQYCQEVYPELQITNVVEANQPVTIQNWCKRGRKQCKTHPHFVIPYRCLVGEFVSDALLVPDKCKFLHQERMDVCETHLHWHTVAKETCSEKSTNLHDYGMLLPCGIDKFRGVEFVCCPLAEESDNVDSADAEEDDSDVWWGGADTDYADGSEDKVVEVAEEEEVAEVEEEEADDDEDDEDGDEVEEEAEEPYEEATERTTSIATTTTTTTESVEEVVREVCSEQAETGPCRAMISRWYFDVTEGKCAPFFYGGCGGNRNNFDTEEYCMAVCGSAMSQSLLKTTQEPLARDPVKLPTTAASTPDAVDKYLETPGDENEHAHFQKAKERLEAKHRERMSQVMREWEEAERQAKNLPKADKKAVIQHFQEKVESLEQEAANERQQLVETHMARVEAMLNDRRRLALENYITALQAVPPRPRHVFNMLKKYVRAEQKDRQHTLKHFEHVRMVDPKKAAQIRSQVMTHLRVIYERMNQSLSLLYNVPAVAEEIQDEVDELLQKEQNYSDDVLANMISEPRISYGNDALMPSLTETKTTVELLPVNGEFSLDDLQPWHSFGADSVPANTENEVEPVDARPAADRGLTTRPGSGLTNIKTEEISEVKMDAEFRHDSGYEVHHQKLVFFAEDVGSNKGAIIGLMVGGVVIATVIVITLVMLKKKQYTSIHHGVVEVDAAVTPEERHLSKMQQNGYENPTYKFFEQMQN'],
    ['ADRB2_HUMAN: MGQPGNGSAFLLAPNGSHAPDHDVTQERDEVWVVGMGIVMSLIVLAIVFGNVLVITAIAKFERLQTVTNYFITSLACADLVMGLAVVPFGAAHILMKMWTFGNFWCEFWTSIDVLCVTASIETLCVIAVDRYFAITSPFKYQSLLTKNKARVIILMVWIVSGLTSFLPIQMHWYRATHQEAINCYANETCCDFFTNQAYAIASSIVSFYVPLVIMVFVYSRVFQEAKRQLQKIDKSEGRFHVQNLSQVEQDGRTGHGLRRSSKFCLKEHKALKTLGIIMGTFTLCWLPFFIVNIVHVIQDNLIRKEVYILLNWIGYVNSGFNPLIYCRSPDFRIAFQELLCLRRSSLKAYGNGYSSNGNTGEQSGYHVEQEKENKLLCEDLPGTEDFVGHQGTVPSDNIDSQGRNCSTNDSLL'],
    ['AMIE_PSEAE: MRHGDISSSNDTVGVAVVNYKMPRLHTAAEVLDNARKIAEMIVGMKQGLPGMDLVVFPEYSLQGIMYDPAEMMETAVAIPGEETEIFSRACRKANVWGVFSLTGERHEEHPRKAPYNTLVLIDNNGEIVQKYRKIIPWCPIEGWYPGGQTYVSEGPKGMKISLIICDDGNYPEIWRDCAMKGAELIVRCQGYMYPAKDQQVMMAKAMAWANNCYVAVANAAGFDGVYSYFGHSAIIGFDGRTLGECGEEEMGIQYAQLSLSQIRDARANDQSQNHLFKILHRGYSGLQASGDGDRGLAECPFEFYRTWVTDAEKARENVERLTRSTTGVAQCPVGRLPYEGLEKEA'],
    ['P53_HUMAN: MEEPQSDPSVEPPLSQETFSDLWKLLPENNVLSPLPSQAMDDLMLSPDDIEQWFTEDPGPDEAPRMPEAAPRVAPAPAAPTPAAPAPAPSWPLSSSVPSQKTYQGSYGFRLGFLHSGTAKSVTCTYSPALNKMFCQLAKTCPVQLWVDSTPPPGTRVRAMAIYKQSQHMTEVVRRCPHHERCSDSDGLAPPQHLIRVEGNLRVEYLDDRNTFRHSVVVPYEPPEVGSDCTTIHYNYMCNSSCMGGMNRRPILTIITLEDSSGNLLGRNSFEVRVCACPGRDRRTEEENLRKKGEPHHELPPGSTKRALPNNTSSSPQPKKKPLDGEYFTLQIRGRERFEMFRELNEALELKDAQAGKEPGGSRAHSSHLKSKKGQSTSRHKKLMFKTEGPDSD']
]
# Input components, listed in the same order as the parameters of
# score_and_create_matrix_all_singles (sequence, range start, range end, model size, mirror).
model_size_selection = gr.Radio(label="Tranception model size", choices=["Small","Medium","Large"], value="Small")
protein_sequence_input = gr.Textbox(lines=1, label="Input protein sequence (see below for examples; default = RL40A_YEAST)",value="MQIFVKTLTGKTITLEVESSDTIDNVKSKIQDKEGIPPDQQRLIFAGKQLEDGRTLSDYNIQKESTLHLVLRLRGGIIEPSLKALASKYNCDKSVCRKCYARLPPRATNCRKRKCGHTNQLRPKKKLK")
mutation_range_start = gr.Number(label="Start of mutation range (min value = 1)",value=1,precision=0)
mutation_range_end = gr.Number(label="End of mutation range (leave empty for full length)",value=10,precision=0)
scoring_mirror = gr.Checkbox(label="Score protein from both directions (leads to more robust fitness predictions, but doubles inference time)")
# TODO: make the output plot scrollable for long mutation ranges.
output_plot = gr.Plot(label="Fitness scores for all single amino acid substitutions in mutation range")
output_recommendations = gr.Textbox(label="Mutation recommendations")
# Use the labelled output components defined above (plain "plot"/"text" strings
# would create unlabelled defaults and leave these components unused).
demo = gr.Interface(
    fn=score_and_create_matrix_all_singles,
    inputs=[protein_sequence_input,mutation_range_start,mutation_range_end,model_size_selection,scoring_mirror],
    outputs=[output_plot,output_recommendations],
    title=title,
    description=description,
    article=article,
    examples=examples,
    enable_queue=True,
    allow_flagging="never"
)
demo.launch(debug=True)