---
widget:
- text: >-
    MEPLDDLDLLLLEEDSGAEAVPRMEILQKKADAFFAETVLSRGVDNRYLVLAVETKLNERGAEEKHLLITVSQEGEQEVLCILRNGWSSVPVEPGDIIHIEGDCTSEPWIVDDDFGYFILSPDMLISGTSVASSIRCLRRAVLSETFRVSDTATRQMLIGTILHEVFQKAISESFAPEKLQELALQTLREVRHLKEMYRLNLSQDEVRCEVEEYLPSFSKWADEFMHKGTKAEFPQMHLSLPSDSSDRSSPCNIEVVKSLDIEESIWSPRFGLKGKIDVTVGVKIHRDCKTKYKIMPLELKTGKESNSIEHRGQVILYTLLSQERREDPEAGWLLYLKTGQMYPVPANHLDKRELLKLRNQLAFSLLHRVSRAAAGEEARLLALPQIIEEEKTCKYCSQMGNCALYSRAVEQVHDTSIPEGMRSKIQEGTQHLTRAHLKYFSLWCLMLTLESQSKDTKKSHQSIWLTPASKLEESGNCIGSLVRTEPVKRVCDGHYLHNFQRKNGPMPATNLMAGDRIILSGEERKLFALSKGYVKRIDTAAVTCLLDRNLSTLPETTLFRLDREEKHGDINTPLGNLSKLMENTDSSKRLRELIIDFKEPQFIAYLSSVLPHDAKDTVANILKGLNKPQRQAMKKVLLSKDYTLIVGMPGTGKTTTICALVRILSACGFSVLLTSYTHSAVDNILLKLAKFKIGFLRLGQSHKVHPDIQKFTEEEMCRLRSIASLAHLEELYNSHPVVATTCMGISHPMFSRKTFDFCIVDEASQISQPICLGPLFFSRRFVLVGDHKQLPPLVLNREARALGMSESLFKRLERNESAVVQLTIQYRMNRKIMSLSNKLTYEGKLECGSDRVANAVITLPNLKDVRLEFYADYSDNPWLAGVFEPDNPVCFLNTDKVPAPEQIENGGVSNVTEARLIVFLTSTFIKAGCSPSDIGIIAPYRQQLRTITDLLARSSVGMVEVNTVDKYQGRDKSLILVSFVRSNEDGTLGELLKDWRRLNVAITRAKHKLILLGSVSSLKRF
  example_title: Protein Sequence 1
- text: >-
    MNSVTVSHAPYYIVYHDDWEPVMSQLVEFYNEVASWLLRDETSPIPPKFFIQLKQMLRNKRVCVCGILPYPIDGTGVPFESPNFTKKSIKEIASSISRLTGVIDYKGYNLNIIDGVIPWNYYLSCKLGETKSHAIYWDKISKLLLQHITKHVSVLYCLGKTDFSNIRAKLESPVTTIVGYHPAARDRQFEKDRSFEIINELLELDNKVPINWAQGFIY
  example_title: Protein Sequence 2
- text: >-
    MNSVTVSHAPYTIAYHDDWEPVMSQLVEFYNEAASWLLRDETSPIPSKFNIQLKQPLRNKRVCVFGIDPYPKDGTGVPFESPNFTKKSIKEIASSISRLMGVIDYEGYNLNIIDGVIPWNYYLSCKLGETKSHAIYWDKISKLLLQHITKHVSVLYCLGKTDFSNIRAKLESPVTTIVGYHPSARDRQFEKDRSFEIINVLLELDNKVPLNWAQGFIY
  example_title: Protein Sequence 3
license: mit
datasets:
- AmelieSchreiber/general_binding_sites
language:
- en
metrics:
- precision
- recall
- f1
library_name: transformers
tags:
- biology
- esm
- esm2
- ESM-2
- protein language model
---
# ESM-2 for General Protein Binding Site Prediction
This model predicts general protein binding sites from the sequence alone. It is a fine-tuned version of `esm2_t6_8M_UR50D` (`facebook/esm2_t6_8M_UR50D` on the Hugging Face Hub), trained on the `AmelieSchreiber/general_binding_sites` dataset. The data is not filtered by protein family, so the model may be overfit to some degree, and the metrics reported below may be optimistic for proteins from families not seen during training. The Hugging Face Inference API widget to the right includes three example protein sequences. The first is a DNA-binding protein truncated to its first 1022 amino acid residues (see its UniProt entry).
The second and third were obtained using EvoProtGrad, a Markov chain Monte Carlo method for in silico directed evolution of proteins based on a form of Gibbs sampling. In theory, these mutant sequences should have binding sites similar to those of the wild-type sequence, though perhaps with higher binding affinity. Running the model on them, we see that the two proteins indeed have the same predicted binding sites, which offers some evidence that the model has learned to predict binding sites well (and that EvoProtGrad works as intended).
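The comparison of the two EvoProtGrad variants can be made quantitative. A minimal sketch, assuming per-residue label lists like those returned by the inference code in this card and a known label string that marks binding residues: the helper names and the toy labels `"B"`/`"-"` below are hypothetical, not the model's actual label names (check `model.config.id2label`). It collects the binding positions of each sequence and reports their Jaccard overlap, which is 1.0 exactly when the predicted sites coincide. Position-by-position comparison only makes sense for substitution-only variants of equal length, as here.

```python
def binding_positions(labels: list[str], binding_label: str) -> set[int]:
    """Return the 1-based residue positions predicted as binding."""
    return {i for i, label in enumerate(labels, start=1) if label == binding_label}


def jaccard(a: set[int], b: set[int]) -> float:
    """Jaccard similarity |a & b| / |a | b|; defined as 1.0 when both sets are empty."""
    union = a | b
    return len(a & b) / len(union) if union else 1.0


# Toy per-residue predictions for two equal-length variants (hypothetical labels).
variant_a = ["-", "B", "B", "-", "-", "B"]
variant_b = ["-", "B", "B", "-", "B", "B"]

sites_a = binding_positions(variant_a, "B")
sites_b = binding_positions(variant_b, "B")
print(sorted(sites_a), sorted(sites_b), jaccard(sites_a, sites_b))
# → [2, 3, 6] [2, 3, 5, 6] 0.75
```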
## Training

Results at epoch 3:

| Epoch | Training Loss | Validation Loss | Precision | Recall | F1 | AUC |
|---|---|---|---|---|---|---|
| 3 | 0.031100 | 0.074720 | 0.684798 | 0.966856 | 0.801743 | 0.980853 |
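The F1 score in the table is the harmonic mean of precision and recall, F1 = 2PR / (P + R). A quick check in Python reproduces it from the precision and recall columns:

```python
precision = 0.684798  # from the epoch-3 row
recall = 0.966856     # from the epoch-3 row


def f1_score(p: float, r: float) -> float:
    """Harmonic mean of precision and recall (0.0 if both are zero)."""
    return 2 * p * r / (p + r) if p + r else 0.0


f1 = f1_score(precision, recall)
print(f"{f1:.6f}")  # → 0.801743
```

The gap between precision (about 0.68) and recall (about 0.97) suggests the model flags more residues as binding than the labels contain, rather than missing true binding residues.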
Hyperparameters (as logged to Weights & Biases):

- `lr`: 0.0004977045729600779
- `lr_scheduler_type`: cosine
- `max_grad_norm`: 0.5
- `num_train_epochs`: 3
- `per_device_train_batch_size`: 8
- `weight_decay`: 0.025
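With `lr_scheduler_type: cosine`, the learning rate decays from its peak value along a half cosine curve to zero by the end of training. The sketch below reproduces that shape in plain Python. It assumes no warmup (the card lists no warmup setting) and uses a hypothetical total step count, since the real number depends on dataset size and batch size; the formula follows the one behind `get_cosine_schedule_with_warmup` in `transformers` with its default `num_cycles=0.5`.

```python
import math

PEAK_LR = 0.0004977045729600779  # the logged `lr` value


def cosine_lr(step: int, total_steps: int, peak_lr: float = PEAK_LR, warmup_steps: int = 0) -> float:
    """Learning rate at `step`: linear warmup, then half-cosine decay to zero."""
    if step < warmup_steps:
        return peak_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return peak_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


total_steps = 1000  # hypothetical; the real value depends on the dataset
for step in (0, 250, 500, 750, 1000):
    print(step, cosine_lr(step, total_steps))
```

At the midpoint of training the learning rate is half its peak, and it reaches zero on the final step.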
## Using the Model

To use the model, try running:

```python
import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer


def predict_binding_sites(model_path, protein_sequences):
    """
    Predict binding sites for a collection of protein sequences.

    Parameters:
    - model_path (str): Hugging Face Hub ID or local path of the saved model.
    - protein_sequences (List[str]): List of protein sequences.

    Returns:
    - List[List[str]]: Predicted labels for each sequence, one label per
      residue (the <cls>, <eos> and padding tokens are removed).
    """
    # Load tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForTokenClassification.from_pretrained(model_path)

    # Run on a GPU if one is available, and put the model in evaluation mode
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    # Tokenize sequences; the ESM tokenizer wraps each one in <cls> ... <eos>
    inputs = tokenizer(protein_sequences, return_tensors="pt", padding=True, truncation=True)
    # Move the inputs to the same device as the model
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    # Obtain logits without tracking gradients
    with torch.no_grad():
        logits = model(**inputs).logits

    # Predicted label IDs, plus the number of real (non-padding) tokens per sequence
    predicted_ids = torch.argmax(logits, dim=-1).cpu().tolist()
    token_counts = inputs["attention_mask"].sum(dim=1).cpu().tolist()

    # Convert label IDs to human-readable labels, skipping <cls> (first token),
    # <eos> (last real token) and any padding after it
    id2label = model.config.id2label
    human_readable_labels = [
        [id2label[label_id] for label_id in ids[1 : n_tokens - 1]]
        for ids, n_tokens in zip(predicted_ids, token_counts)
    ]
    return human_readable_labels


# Usage:
model_path = "AmelieSchreiber/esm2_t6_8M_general_binding_sites_v2"  # Replace with your model's path
unseen_proteins = [
    "MKVEEILEKALELVIPDEEEVRKGREAEEELRRRLDELGVEYVFVGSYARNTWLKGSLEIDVFLLFPEEFSKEELRERGLEIGKAVLDSYEIRYAEHPYVHGVVKGVEVDVVPCYKLKEPKNIKSAVDRTPFHHKWLEGRIKGKENEVRLLKGFLKANGIYGAEYKVRGFSGYLCELLIVFYGSFLETVKNARRWTRRTVIDVAKGEVRKGEEFFVVDPVDEKRNVAANLSLDNLARFVHLCREFMEAPSLGFFKPKHPLEIEPERLRKIVEERGTAVFAVKFRKPDIVDDNLYPQLERASRKIFEFLERENFMPLRSAFKASEEFCYLLFECQIKEISRVFRRMGPQFEDERNVKKFLSRNRAFRPFIENGRWWAFEMRKFTTPEEGVRSYASTHWHTLGKNVGESIREYFEIISGEKLFKEPVTAELCEMMGVKD",
    "MKVEEILEKALELVIPDEEEVRKGREAEEELRRRLDELGVEYVFVGSYARNTWLKGSLEIAVFLLFPEEFSKEELRERGLEIGKAVLDSYEIRYAEHPYVHGVVKGVEVDVVPCYKLKEPKNIKSAVDRTPFHHKWLEGRIKGKENEVRLLKGFLKANGIYGAEYKVRGFSGYLCELLIVFYGSFLETVKNARRWTRRTVIDVAKGEVRKGEEFFVVDPVDEKRNVAANLSLDNLARFVHLCREFMEAPSLGFFKVKHPLEIEPERLRKIVEERGTAVFAVKFRKPDIVDDNLYPQLERASRKIFEFLERENFMPLRSAFKASEEFCYLLFECQIKEISRVFRRMGPQFEDERNVKKFLSRNRAFRPFIENGRWWAFEMRKFTTPEEGVRSYASTHWHTLGKNVGESIREYFEIISGEKLFKEPVTAELCEMMGVKD",
    "MKVEEILEKALELVIPDEEEVRKGREAEEELRRRLDELGVEAVFVGSYARNTWLKGSLEIAVFLLFPEEFSKEELRERGLEIEKAVLDSYEIRYAEHPYVHGVVKGVEVDVVPCYKLKEPKNIKSAVDRTPFHHKELEGRIKGKENEVRLLKGFLKANGIYGAEYAVRGFSGYLCELLIVFYGSFLETVKNARRWTRRTVIDVAKGEVRKGEEFFVVDPVDEKRNVAANLSLDNLARFVHLCREFMEAPSLGFFKVKHPLEIEPERLRKIVEERGTAVFMVKFRKPDIVDDNLYPQLRRASRKIFEFLERNNFMPLRSAFKASEEFCYLLFECQIKEISDVFRRMGPLFEDERNVKKFLSRNRALRPFIENGRWWIFEMRKFTTPEEGVRSYASTHWHTLGKNVGESIREYFEIISGEKLFKEPVTAELCRMMGVKD",
    "MKVEEILEKALELVIPDEEEVRKGREAEEELRRRLDELGVEAVFVGSYARNTWLKGSLEIAVFLLFPEEFSKEELRERGLEIEKAVLDSYGIRYAEHPYVHGVVKGVELDVVPCYKLKEPKNIKSAVDRTPFHHKELEGRIKGKENEYRSLKGFLKANGIYGAEYAVRGFSGYLCELLIVFYGSFLETVKNARRWTRKTVIDVAKGEVRKGEEFFVVDPVDEKRNVAALLSLDNLARFVHLCREFMEAVSLGFFKVKHPLEIEPERLRKIVEERGTAVFMVKFRKPDIVDDNLYPQLRRASRKIFEFLERNNFMPLRRAFKASEEFCYLLFEQQIKEISDVFRRMGPLFEDERNVKKFLSRNRALRPFIENGRWWIFEMRKFTTPEEGVRSYASTHWHTLGKNVGESIREYFEIIEGEKLFKEPVTAELCRMMGVKD",
]  # Replace with your protein sequences

predictions = predict_binding_sites(model_path, unseen_proteins)
print(predictions)
```
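The prediction function returns one label per residue. To read off binding sites as residue ranges, consecutive binding positions can be merged into segments. A minimal sketch, assuming the label that marks binding residues is passed in explicitly: the helper name `binding_segments` and the toy labels `"B"`/`"-"` are hypothetical, so check `model.config.id2label` for the real label names. It raises an error when the label count does not match the sequence length, which can happen if a long sequence was truncated by the tokenizer.

```python
def binding_segments(sequence: str, labels: list[str], binding_label: str) -> list[tuple[int, int, str]]:
    """Group consecutive binding residues into (start, end, residues) segments.

    Positions are 1-based and inclusive, matching UniProt-style numbering.
    """
    if len(sequence) != len(labels):
        raise ValueError("expected exactly one label per residue")
    segments = []
    start = None  # 0-based index where the current binding segment began
    for i, label in enumerate(labels):
        if label == binding_label and start is None:
            start = i  # a new segment begins
        elif label != binding_label and start is not None:
            segments.append((start + 1, i, sequence[start:i]))
            start = None
    if start is not None:  # a segment runs to the end of the sequence
        segments.append((start + 1, len(labels), sequence[start:]))
    return segments


print(binding_segments("MKVEEIL", ["-", "B", "B", "-", "-", "B", "B"], "B"))
# → [(2, 3, 'KV'), (6, 7, 'IL')]
```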