|
--- |
|
license: mit |
|
--- |
|
|
|
# ESM-2 Finetuned for PPI Prediction |
|
This is a finetuned version of `facebook/esm2_t33_650M_UR50D` using the masked language modeling objective. |
|
The model was finetuned for four epochs on concatenated pairs of interacting proteins, clustered using persistent homology landscapes as explained in [this post](https://huggingface.co/blog/AmelieSchreiber/faster-pha). The dataset consists of 10,000 protein pairs, which can be [found here](https://huggingface.co/datasets/AmelieSchreiber/pha_clustered_protein_complexes).

Clustering protein-protein complexes with persistent homology landscapes is a very new approach.
|
|
|
Using the MLM loss to predict pairs of interacting proteins was inspired by [this paper](https://arxiv.org/abs/2308.07136). However, the authors of that paper do not finetune their models for this task, so we reasoned that the method's performance could be improved by finetuning the model on pairs of interacting proteins.
|
|
|
## Using the Model |
|
To use the model, we follow [this blog post](https://huggingface.co/blog/AmelieSchreiber/protein-binding-partners-with-esm2). |
|
Below we show how to use the model to rank potential binders for a target protein of interest. The lower the average MLM loss, the more likely the two proteins are to interact with one another.
|
|
|
```python |
|
import numpy as np |
|
from transformers import AutoTokenizer, EsmForMaskedLM |
|
import torch |
|
|
|
# Tokenizer from the base ESM-2 (650M) checkpoint; weights from the
# PPI-finetuned checkpoint described above. `.eval()` returns the model
# itself, switched to inference mode.
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
model = EsmForMaskedLM.from_pretrained("AmelieSchreiber/esm_mlmppi_ph50_v3").eval()
|
|
|
# Target protein whose binding partners we want to rank.
protein_of_interest = "MLTEVMEVWHGLVIAVVSLFLQACFLTAINYLLSRHMAHKSEQILKAASLQVPRPSPGHHHPPAVKEMKETQTERDIPMSDSLYRHDSDTPSDSLDSSCSSPPACQATEDVDYTQVVFSDPGELKNDSPLDYENIKEITDYVNVNPERHKPSFWYFVNPALSEPAEYDQVAM"

# Candidate partners; each one is scored against `protein_of_interest`.
potential_binders = [
    # --- known to interact ---
    "MASPGSGFWSFGSEDGSGDSENPGTARAWCQVAQKFTGGIGNKLCALLYGDAEKPAESGGSQPPRAAARKAACACDQKPCSCSKVDVNYAFLHATDLLPACDGERPTLAFLQDVMNILLQYVVKSFDRSTKVIDFHYPNELLQEYNWELADQPQNLEEILMHCQTTLKYAIKTGHPRYFNQLSTGLDMVGLAADWLTSTANTNMFTYEIAPVFVLLEYVTLKKMREIIGWPGGSGDGIFSPGGAISNMYAMMIARFKMFPEVKEKGMAALPRLIAFTSEHSHFSLKKGAAALGIGTDSVILIKCDERGKMIPSDLERRILEAKQKGFVPFLVSATAGTTVYGAFDPLLAVADICKKYKIWMHVDAAWGGGLLMSRKHKWKLSGVERANSVTWNPHKMMGVPLQCSALLVREEGLMQNCNQMHASYLFQQDKHYDLSYDTGDKALQCGRHVDVFKLWLMWRAKGTTGFEAHVDKCLELAEYLYNIIKNREGYEMVFDGKPQHTNVCFWYIPPSLRTLEDNEERMSRLSKVAPVIKARMMEYGTTMVSYQPLGDKVNFFRMVISNPAATHQDIDFLIEEIERLGQDL",
    "MAAGVAGWGVEAEEFEDAPDVEPLEPTLSNIIEQRSLKWIFVGGKGGVGKTTCSCSLAVQLSKGRESVLIISTDPAHNISDAFDQKFSKVPTKVKGYDNLFAMEIDPSLGVAELPDEFFEEDNMLSMGKKMMQEAMSAFPGIDEAMSYAEVMRLVKGMNFSVVVFDTAPTGHTLRLLNFPTIVERGLGRLMQIKNQISPFISQMCNMLGLGDMNADQLASKLEETLPVIRSVSEQFKDPEQTTFICVCIAEFLSLYETERLIQELAKCKIDTHNIIVNQLVFPDPEKPCKMCEARHKIQAKYLDQMEDLYEDFHIVKLPLLPHEVRGADKVNTFSALLLEPYKPPSAQ",
    "EKTGLSIRGAQEEDPPDPQLMRLDNMLLAEGVSGPEKGGGSAAAAAAAAASGGSSDNSIEHSDYRAKLTQIRQIYHTELEKYEQACNEFTTHVMNLLREQSRTRPISPKEIERMVGIIHRKFSSIQMQLKQSTCEAVMILRSRFLDARRKRRNFSKQATEILNEYFYSHLSNPYPSEEAKEELAKKCSITVSQSLVKDPKERGSKGSDIQPTSVVSNWFGNKRIRYKKNIGKFQEEANLYAAKTAVTAAHAVAAAVQNNQTNSPTTPNSGSSGSFNLPNSGDMFMNMQSLNGDSYQGSQVGANVQSQVDTLRHVINQTGGYSDGLGGNSLYSPHNLNANGGWQDATTPSSVTSPTEGPGSVHSDTSN",
    # --- not known to interact ---
    "MRQRLLPSVTSLLLVALLFPGSSQARHVNHSATEALGELRERAPGQGTNGFQLLRHAVKRDLLPPRTPPYQVHISHREARGPSFRICVDFLGPRWARGCSTGN",
    "MSGIALSRLAQERKAWRKDHPFGFVAVPTKNPDGTMNLMNWECAIPGKKGTPWEGGLFKLRMLFKDDYPSSPPKCKFEPPLFHPNVYPSGTVCLSILEEDKDWRPAITIKQILLGIQELLNEPNIQDPAQAEAYTIYCQNRVEYEKRVRAQAKKFAPS",
]  # Add potential binding sequences here
|
|
|
def compute_mlm_loss(protein, binder, iterations=5, mask_rate=0.35):
    """Return the average masked-language-modeling loss of ``protein + binder``.

    The two sequences are concatenated directly (no separator or linker). In
    each iteration a fresh random ``mask_rate`` fraction of the residue tokens
    is replaced by ``<mask>``, and the model is scored only on recovering the
    true residues at those positions. Lower values suggest the pair is more
    likely to interact.

    Args:
        protein: Amino-acid sequence of the protein of interest.
        binder: Amino-acid sequence of the candidate binder.
        iterations: Number of independent random maskings to average over.
        mask_rate: Fraction of residue tokens masked in each iteration.

    Returns:
        The mean MLM loss (a Python float) over ``iterations`` maskings.

    Raises:
        ValueError: if ``iterations`` < 1 or the sequence has no residues.
    """
    if iterations < 1:
        raise ValueError("iterations must be at least 1")

    # Tokenize once; only the masking changes between iterations. The
    # tokenizer's default <cls>/<eos> wrapping is kept, matching the input
    # format ESM-2 was trained on.
    encoded = tokenizer(protein + binder, return_tensors="pt", truncation=True, max_length=1024)
    input_ids = encoded["input_ids"][0]

    # Residue positions = everything except the structural special tokens.
    structural_ids = {tokenizer.cls_token_id, tokenizer.eos_token_id, tokenizer.pad_token_id}
    candidates = torch.tensor(
        [pos for pos, tok in enumerate(input_ids.tolist()) if tok not in structural_ids],
        dtype=torch.long,
    )
    if candidates.numel() == 0:
        raise ValueError("concatenated sequence contains no residues to mask")
    # Mask at least one residue; with zero masked positions the loss is NaN.
    num_mask = max(1, int(candidates.numel() * mask_rate))

    device = model.device
    attention_mask = encoded["attention_mask"].to(device)

    total_loss = 0.0
    for _ in range(iterations):
        chosen = candidates[torch.randperm(candidates.numel())[:num_mask]]

        masked_ids = input_ids.clone()
        masked_ids[chosen] = tokenizer.mask_token_id

        # Labels are the TRUE residues at masked positions and -100 (ignored
        # by the loss) everywhere else, so the loss measures how well the
        # model reconstructs the hidden residues.
        labels = torch.full_like(input_ids, -100)
        labels[chosen] = input_ids[chosen]

        with torch.no_grad():
            outputs = model(
                input_ids=masked_ids.unsqueeze(0).to(device),
                attention_mask=attention_mask,
                labels=labels.unsqueeze(0).to(device),
            )
        total_loss += outputs.loss.item()

    return total_loss / iterations
|
|
|
# Score every candidate against the target (keyed by sequence), then list
# them from lowest loss (most likely binder) to highest.
mlm_losses = {
    binder: compute_mlm_loss(protein_of_interest, binder)
    for binder in potential_binders
}
ranked_binders = sorted(mlm_losses, key=lambda seq: mlm_losses[seq])

print("Ranking of Potential Binders:")
for rank, binder in enumerate(ranked_binders, start=1):
    print(f"{rank}. {binder} - MLM Loss: {mlm_losses[binder]}")
|
``` |
|
|
|
## PPI Networks |
|
|
|
To construct a protein-protein interaction network, try running the following code. Experiment with the length of the glycine connector (for example, anywhere from 0 to 25 residues), as well as the number of iterations and the masking percentage.
|
|
|
```python |
|
import networkx as nx |
|
import numpy as np |
|
import torch |
|
from transformers import AutoTokenizer, AutoModelForMaskedLM, EsmForMaskedLM |
|
import plotly.graph_objects as go |
|
from ipywidgets import interact |
|
from ipywidgets import widgets |
|
|
|
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Base ESM-2 tokenizer plus the PPI-finetuned weights. The model is moved to
# `device` and switched to inference mode. Both `.to()` and `.eval()` return
# the module itself, so they can be chained.
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
model = EsmForMaskedLM.from_pretrained("AmelieSchreiber/esm_mlmppi_ph50_v3").to(device).eval()
|
|
|
# Proteins to place in the network (replace with your own list). They are
# listed in blocks of three for readability only; every protein is treated
# independently and scored against every other one.
all_proteins = [
    "MFLSILVALCLWLHLALGVRGAPCEAVRIPMCRHMPWNITRMPNHLHHSTQENAILAIEQYEELVDVNCSAVLRFFLCAMYAPICTLEFLHDPIKPCKSVCQRARDDCEPLMKMYNHSWPESLACDELPVYDRGVCISPEAIVTDLPEDVKWIDITPDMMVQERPLDVDCKRLSPDRCKCKKVKPTLATYLSKNYSYVIHAKIKAVQRSGCNEVTTVVDVKEIFKSSSPIPRTQVPLITNSSCQCPHILPHQDVLIMCYEWRSRMMLLENCLVEKWRDQLSKRSIQWEERLQEQRRTVQDKKKTAGRTSRSNPPKPKGKPPAPKPASPKKNIKTRSAQKRTNPKRV",
    "MDAVEPGGRGWASMLACRLWKAISRALFAEFLATGLYVFFGVGSVMRWPTALPSVLQIAITFNLVTAMAVQVTWKASGAHANPAVTLAFLVGSHISLPRAVAYVAAQLVGATVGAALLYGVMPGDIRETLGINVVRNSVSTGQAVAVELLLTLQLVLCVFASTDSRQTSGSPATMIGISVALGHLIGIHFTGCSMNPARSFGPAIIIGKFTVHWVFWVGPLMGALLASLIYNFVLFPDTKTLAQRLAILTGTVEVGTGAGAGAEPLKKESQPGSGAVEMESV",
    "MKFLLDILLLLPLLIVCSLESFVKLFIPKRRKSVTGEIVLITGAGHGIGRLTAYEFAKLKSKLVLWDINKHGLEETAAKCKGLGAKVHTFVVDCSNREDIYSSAKKVKAEIGDVSILVNNAGVVYTSDLFATQDPQIEKTFEVNVLAHFWTTKAFLPAMTKNNHGHIVTVASAAGHVSVPFLLAYCSSKFAAVGFHKTLTDELAALQITGVKTTCLCPNFVNTGFIKNPSTSLGPTLEPEEVVNRLMHGILTEQKMIFIPSSIAFLTTLERILPERFLAVLKRKISVKFDAVIGYKMKAQ",

    "MAAAVPRRPTQQGTVTFEDVAVNFSQEEWCLLSEAQRCLYRDVMLENLALISSLGCWCGSKDEEAPCKQRISVQRESQSRTPRAGVSPKKAHPCEMCGLILEDVFHFADHQETHHKQKLNRSGACGKNLDDTAYLHQHQKQHIGEKFYRKSVREASFVKKRKLRVSQEPFVFREFGKDVLPSSGLCQEEAAVEKTDSETMHGPPFQEGKTNYSCGKRTKAFSTKHSVIPHQKLFTRDGCYVCSDCGKSFSRYVSFSNHQRDHTAKGPYDCGECGKSYSRKSSLIQHQRVHTGQTAYPCEECGKSFSQKGSLISHQLVHTGEGPYECRECGKSFGQKGNLIQHQQGHTGERAYHCGECGKSFRQKFCFINHQRVHTGERPYKCGECGKSFGQKGNLVHHQRGHTGERPYECKECGKSFRYRSHLTEHQRLHTGERPYNCRECGKLFNRKYHLLVHERVHTGERPYACEVCGKLFGNKHSVTIHQRIHTGERPYECSECGKSFLSSSALHVHKRVHSGQKPYKCSECGKSFSECSSLIKHRRIHTGERPYECTKCGKTFQRSSTLLHHQSSHRRKAL",
    "MGQPWAAGSTDGAPAQLPLVLTALWAAAVGLELAYVLVLGPGPPPLGPLARALQLALAAFQLLNLLGNVGLFLRSDPSIRGVMLAGRGLGQGWAYCYQCQSQVPPRSGHCSACRVCILRRDHHCRLLGRCVGFGNYRPFLCLLLHAAGVLLHVSVLLGPALSALLRAHTPLHMAALLLLPWLMLLTGRVSLAQFALAFVTDTCVAGALLCGAGLLFHGMLLLRGQTTWEWARGQHSYDLGPCHNLQAALGPRWALVWLWPFLASPLPGDGITFQTTADVGHTAS",
    "MGLRIHFVVDPHGWCCMGLIVFVWLYNIVLIPKIVLFPHYEEGHIPGILIIIFYGISIFCLVALVRASITDPGRLPENPKIPHGEREFWELCNKCNLMRPKRSHHCSRCGHCVRRMDHHCPWINNCVGEDNHWLFLQLCFYTELLTCYALMFSFCHYYYFLPLKKRNLDLFVFRHELAIMRLAAFMGITMLVGITGLFYTQLIGIITDTTSIEKMSNCCEDISRPRKPWQQTFSEVFGTRWKILWFIPFRQRQPLRVPYHFANHV",

    "MLLLGAVLLLLALPGHDQETTTQGPGVLLPLPKGACTGWMAGIPGHPGHNGAPGRDGRDGTPGEKGEKGDPGLIGPKGDIGETGVPGAEGPRGFPGIQGRKGEPGEGAYVYRSAFSVGLETYVTIPNMPIRFTKIFYNQQNHYDGSTGKFHCNIPGLYYFAYHITVYMKDVKVSLFKKDKAMLFTYDQYQENNVDQASGSVLLHLEVGDQVWLQVYGEGERNGLYADNDNDSTFTGFLLYHDTN",
    "MGLLAFLKTQFVLHLLVGFVFVVSGLVINFVQLCTLALWPVSKQLYRRLNCRLAYSLWSQLVMLLEWWSCTECTLFTDQATVERFGKEHAVIILNHNFEIDFLCGWTMCERFGVLGSSKVLAKKELLYVPLIGWTWYFLEIVFCKRKWEEDRDTVVEGLRRLSDYPEYMWFLLYCEGTRFTETKHRVSMEVAAAKGLPVLKYHLLPRTKGFTTAVKCLRGTVAAVYDVTLNFRGNKNPSLLGILYGKKYEADMCVRRFPLEDIPLDEKEAAQWLHKLYQEKDALQEIYNQKGMFPGEQFKPARRPWTLLNFLSWATILLSPLFSFVLGVFASGSPLLILTFLGFVGAASFGVRRLIGVTEIEKGSSYGNQEFKKKE",
    "MDLAGLLKSQFLCHLVFCYVFIASGLIINTIQLFTLLLWPINKQLFRKINCRLSYCISSQLVMLLEWWSGTECTIFTDPRAYLKYGKENAIVVLNHKFEIDFLCGWSLSERFGLLGGSKVLAKKELAYVPIIGWMWYFTEMVFCSRKWEQDRKTVATSLQHLRDYPEKYFFLIHCEGTRFTEKKHEISMQVARAKGLPRLKHHLLPRTKGFAITVRSLRNVVSAVYDCTLNFRNNENPTLLGVLNGKKYHADLYVRRIPLEDIPEDDDECSAWLHKLYQEKDAFQEEYYRTGTFPETPMVPPRRPWTLVNWLFWASLVLYPFFQFLVSMIRSGSSLTLASFILVFFVASVGVRWMIGVTEIDKGSAYGNSDSKQKLND",

    "MALLLCFVLLCGVVDFARSLSITTPEEMIEKAKGETAYLPCKFTLSPEDQGPLDIEWLISPADNQKVDQVIILYSGDKIYDDYYPDLKGRVHFTSNDLKSGDASINVTNLQLSDIGTYQCKVKKAPGVANKKIHLVVLVKPSGARCYVDGSEEIGSDFKIKCEPKEGSLPLQYEWQKLSDSQKMPTSWLAEMTSSVISVKNASSEYSGTYSCTVRNRVGSDQCLLRLNVVPPSNKAGLIAGAIIGTLLALALIGLIIFCCRKKRREEKYEKEVHHDIREDVPPPKSRTSTARSYIGSNHSSLGSMSPSNMEGYSKTQYNQVPSEDFERTPQSPTLPPAKVAAPNLSRMGAIPVMIPAQSKDGSIV",
    "MSYVFVNDSSQTNVPLLQACIDGDFNYSKRLLESGFDPNIRDSRGRTGLHLAAARGNVDICQLLHKFGADLLATDYQGNTALHLCGHVDTIQFLVSNGLKIDICNHQGATPLVLAKRRGVNKDVIRLLESLEEQEVKGFNRGTHSKLETMQTAESESAMESHSLLNPNLQQGEGVLSSFRTTWQEFVEDLGFWRVLLLIFVIALLSLGIAYYVSGVLPFVENQPELVH",
    "MRVAGAAKLVVAVAVFLLTFYVISQVFEIKMDASLGNLFARSALDTAARSTKPPRYKCGISKACPEKHFAFKMASGAANVVGPKICLEDNVLMSGVKNNVGRGINVALANGKTGEVLDTKYFDMWGGDVAPFIEFLKAIQDGTIVLMGTYDDGATKLNDEARRLIADLGSTSITNLGFRDNWVFCGGKGIKTKSPFEQHIKNNKDTNKYEGWPEVVEMEGCIPQKQD",

    "MAPAAATGGSTLPSGFSVFTTLPDLLFIFEFIFGGLVWILVASSLVPWPLVQGWVMFVSVFCFVATTTLIILYIIGAHGGETSWVTLDAAYHCTAALFYLSASVLEALATITMQDGFTYRHYHENIAAVVFSYIATLLYVVHAVFSLIRWKSS",
    "MRLQGAIFVLLPHLGPILVWLFTRDHMSGWCEGPRMLSWCPFYKVLLLVQTAIYSVVGYASYLVWKDLGGGLGWPLALPLGLYAVQLTISWTVLVLFFTVHNPGLALLHLLLLYGLVVSTALIWHPINKLAALLLLPYLAWLTVTSALTYHLWRDSLCPVHQPQPTEKSD",
    "MEESVVRPSVFVVDGQTDIPFTRLGRSHRRQSCSVARVGLGLLLLLMGAGLAVQGWFLLQLHWRLGEMVTRLPDGPAGSWEQLIQERRSHEVNPAAHLTGANSSLTGSGGPLLWETQLGLAFLRGLSYHDGALVVTKAGYYYIYSKVQLGGVGCPLGLASTITHGLYKRTPRYPEELELLVSQQSPCGRATSSSRVWWDSSFLGGVVHLEAGEKVVVRVLDERLVRLRDGTRSYFGAFMV",
]
|
|
|
def compute_average_mlm_loss(protein1, protein2, iterations=10, connector_length=25, mask_prob=0.35):
    """Return the average MLM loss of ``protein1 + connector + protein2``.

    The two chains are joined by a poly-glycine connector of
    ``connector_length`` residues. In each iteration every residue outside the
    connector is masked independently with probability ``mask_prob``. The
    model is scored only on recovering the true residues at the masked
    positions. Lower values suggest the two proteins are more likely to
    interact.

    Args:
        protein1: Amino-acid sequence of the first protein.
        protein2: Amino-acid sequence of the second protein.
        iterations: Number of independent random maskings to average over.
        connector_length: Number of glycines inserted between the chains
            (0 joins them directly).
        mask_prob: Per-residue masking probability.

    Returns:
        The mean MLM loss (a Python float) over ``iterations`` maskings.

    Raises:
        ValueError: if ``iterations`` < 1, or if no residue outside the
            connector survives tokenization/truncation.
    """
    if iterations < 1:
        raise ValueError("iterations must be at least 1")

    connector = "G" * connector_length
    encoded = tokenizer(protein1 + connector + protein2, return_tensors="pt", truncation=True, max_length=1024)
    input_ids = encoded["input_ids"][0]  # kept on the CPU; moved per call below
    attention_mask = encoded["attention_mask"].to(device)

    # Token positions of residues, i.e. everything except <cls>/<eos>/<pad>.
    # Counting the connector window in residue space means the leading <cls>
    # token cannot shift it, and special tokens are never masked.
    structural_ids = {tokenizer.cls_token_id, tokenizer.eos_token_id, tokenizer.pad_token_id}
    residue_positions = [pos for pos, tok in enumerate(input_ids.tolist()) if tok not in structural_ids]

    # Exclude the connector glycines from masking.
    start_connector = len(tokenizer.encode(protein1, add_special_tokens=False))
    end_connector = start_connector + len(tokenizer.encode(connector, add_special_tokens=False))
    candidates = torch.tensor(
        residue_positions[:start_connector] + residue_positions[end_connector:],
        dtype=torch.long,
    )
    if candidates.numel() == 0:
        raise ValueError("no maskable residues outside the connector")

    total_loss = 0.0
    for _ in range(iterations):
        selected = torch.rand(candidates.numel()) < mask_prob
        if not selected.any():
            # Mask at least one residue; with nothing masked the loss is NaN.
            selected[torch.randint(candidates.numel(), (1,))] = True
        chosen = candidates[selected]

        masked_ids = input_ids.clone()
        masked_ids[chosen] = tokenizer.mask_token_id

        # Labels hold the TRUE residues at masked positions and -100 (ignored
        # by the loss) elsewhere, so the loss measures reconstruction of the
        # hidden residues rather than copying the masked input.
        labels = torch.full_like(input_ids, -100)
        labels[chosen] = input_ids[chosen]

        with torch.no_grad():
            outputs = model(
                input_ids=masked_ids.unsqueeze(0).to(device),
                attention_mask=attention_mask,
                labels=labels.unsqueeze(0).to(device),
            )
        total_loss += outputs.loss.item()

    return total_loss / iterations
|
|
|
# Score every unordered pair (i < j) exactly once. plot_graph walks the pairs
# in this same i-major order to map each loss back to its pair.
all_losses = [
    compute_average_mlm_loss(first, second)
    for idx, first in enumerate(all_proteins)
    for second in all_proteins[idx + 1:]
]

# The largest pairwise loss becomes the slider's upper bound.
max_threshold = max(all_losses)
print(f"Maximum loss (maximum threshold for slider): {max_threshold}")
|
|
|
def plot_graph(threshold):
    """Render the protein network as an interactive 3-D Plotly figure.

    Every entry of the module-level ``all_proteins`` becomes a node named
    ``"protein <k>"`` (1-based). An edge joins proteins i < j when their
    pairwise loss is strictly below ``threshold``. Losses are read from the
    module-level ``all_losses``, which must list the pairs in i-major, i < j
    order.

    Args:
        threshold: Exclusive upper loss bound for drawing an edge.
    """
    node_names = [f"protein {k}" for k, _ in enumerate(all_proteins, start=1)]
    graph = nx.Graph()
    graph.add_nodes_from(node_names)

    count = len(node_names)
    pairs = [(a, b) for a in range(count) for b in range(a + 1, count)]
    for pair_idx, (a, b) in enumerate(pairs):
        pair_loss = all_losses[pair_idx]
        if pair_loss < threshold:
            graph.add_edge(node_names[a], node_names[b], weight=round(pair_loss, 3))

    # Seeded 3-D spring layout; a smaller k pulls the nodes closer together.
    pos = nx.spring_layout(graph, dim=3, seed=42, k=2)

    # Plotly breaks a line trace wherever it meets None, so each edge is
    # emitted as (start, end, None) on every axis.
    edge_xyz = ([], [], [])
    for u, v in graph.edges():
        for axis in range(3):
            edge_xyz[axis].extend([pos[u][axis], pos[v][axis], None])
    edge_trace = go.Scatter3d(
        x=edge_xyz[0],
        y=edge_xyz[1],
        z=edge_xyz[2],
        mode='lines',
        line=dict(width=0.5, color='grey'),
    )

    node_list = list(graph.nodes())
    node_x, node_y, node_z = ([pos[name][axis] for name in node_list] for axis in range(3))
    node_trace = go.Scatter3d(
        x=node_x,
        y=node_y,
        z=node_z,
        mode='markers',
        marker=dict(size=5),
        hoverinfo='text',
        hovertext=node_list,
    )

    layout = go.Layout(
        title='Protein Interaction Graph',
        title_x=0.5,
        scene={f"{ax}axis": dict(showbackground=False) for ax in "xyz"},
    )
    go.Figure(data=[edge_trace, node_trace], layout=layout).show()
|
|
|
# Interactive slider for the edge threshold. The default is the median
# pairwise loss, so about half of the possible edges are drawn initially. A
# data-driven default stays inside [min, max] whatever scale the losses have,
# unlike a hard-coded constant. The finer step keeps the slider usable when
# losses span only a few units.
interact(
    plot_graph,
    threshold=widgets.FloatSlider(
        min=0.0,
        max=max_threshold,
        step=0.01,
        value=float(np.median(all_losses)),
    ),
)
|
``` |