|
---
license: mit
datasets:
- AmelieSchreiber/general_binding_sites
language:
- en
metrics:
- precision
- recall
- f1
library_name: transformers
tags:
- biology
- esm
- esm2
- ESM-2
- protein language model
---
|
|
|
# ESM-2 for General Protein Binding Site Prediction
|
|
|
This model predicts general protein binding sites from sequence alone. It is a fine-tuned version of `esm2_t6_8M_UR50D`, trained on [this dataset](https://huggingface.co/datasets/AmelieSchreiber/general_binding_sites). Because the data is not filtered by protein family, similar sequences may appear in both the training and validation splits, so the model may be overfit to some degree.
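
If you want to inspect the training data yourself, the dataset loads with the `datasets` library. A quick sketch (the split names and record structure are assumptions about the dataset layout, so print the `DatasetDict` to see what is actually there):

```python
from datasets import load_dataset

# Sketch: pull the binding-site dataset from the Hub and inspect it.
# The available splits and column names depend on the dataset repo.
ds = load_dataset("AmelieSchreiber/general_binding_sites")
print(ds)
```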
|
|
|
## Training
|
|
|
Metrics at the end of epoch 3:

| Epoch | Training Loss | Validation Loss | Precision | Recall | F1 | AUC |
|------:|--------------:|----------------:|----------:|-------:|---:|----:|
| 3 | 0.031100 | 0.074720 | 0.684798 | 0.966856 | 0.801743 | 0.980853 |
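
These are standard token-level classification metrics. As a reference, a minimal sketch of how they can be computed with scikit-learn on flattened per-residue labels (the arrays below are placeholders, not values from the actual evaluation):

```python
import numpy as np
from sklearn.metrics import f1_score, precision_score, recall_score, roc_auc_score

# Placeholder arrays standing in for flattened per-residue binary labels
# and predicted binding-site probabilities (real values come from the model).
y_true = np.array([0, 0, 1, 1, 0, 1])
y_prob = np.array([0.1, 0.4, 0.8, 0.9, 0.3, 0.6])
y_pred = (y_prob > 0.5).astype(int)

print("precision:", precision_score(y_true, y_pred))
print("recall:", recall_score(y_true, y_pred))
print("f1:", f1_score(y_true, y_pred))
print("auc:", roc_auc_score(y_true, y_prob))
```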
|
|
|
Hyperparameters logged by wandb:

```
wandb: lr: 0.0004977045729600779
wandb: lr_scheduler_type: cosine
wandb: max_grad_norm: 0.5
wandb: num_train_epochs: 3
wandb: per_device_train_batch_size: 8
wandb: weight_decay: 0.025
```
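
These settings map directly onto Hugging Face `TrainingArguments`. A minimal sketch, assuming a standard `Trainer` setup (the `output_dir` is a hypothetical path, and any evaluation or logging settings from the original run are unknown and omitted):

```python
from transformers import TrainingArguments

# Sketch of TrainingArguments matching the hyperparameters logged above;
# this is not the original training script, just the logged values.
training_args = TrainingArguments(
    output_dir="esm2_t6_8M_general_binding_sites",  # hypothetical path
    learning_rate=0.0004977045729600779,
    lr_scheduler_type="cosine",
    max_grad_norm=0.5,
    num_train_epochs=3,
    per_device_train_batch_size=8,
    weight_decay=0.025,
)
```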
|
|
|
## Using the Model
|
|
|
Try pasting a protein sequence into the inference widget on this model's page and clicking "Compute". For example, try:
|
|
|
```
MNSVTVSHAPYTIAYHDDWEPVMSQLVEFYNEAASWLLRDETSPIPSKFNIQLKQPLRNKRVCVFGIDPYPKDGTGVPFESPNFTKKSIKEIASSISRLMGVIDYEGYNLNIIDGVIPWNYYLSCKLGETKSHAIYWDKISKLLLQHITKHVSVLYCLGKTDFSNIRAKLESPVTTIVGYHPSARDRQFEKDRSFEIINVLLELDNKVPLNWAQGFIY
```
|
|
|
To use the model, try running:
|
```python
import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer

def predict_binding_sites(model_path, protein_sequences):
    """
    Predict binding sites for a collection of protein sequences.

    Parameters:
    - model_path (str): Path to the saved model.
    - protein_sequences (List[str]): List of protein sequences.

    Returns:
    - List[List[str]]: Predicted per-residue labels for each sequence.
    """
    # Load tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForTokenClassification.from_pretrained(model_path)

    # Ensure model is in evaluation mode
    model.eval()

    # Tokenize sequences
    inputs = tokenizer(protein_sequences, return_tensors="pt", padding=True, truncation=True)

    # Move inputs to the same device as the model and obtain logits
    device = next(model.parameters()).device
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        logits = model(**inputs).logits

    # Obtain predicted label IDs for every token position
    predicted_labels = torch.argmax(logits, dim=-1).cpu().numpy()

    # Convert label IDs to human-readable labels, trimming the <cls>/<eos>
    # special tokens and padding so the output aligns with the residues
    id2label = model.config.id2label
    attention_mask = inputs["attention_mask"].cpu().numpy()
    human_readable_labels = []
    for labels, mask in zip(predicted_labels, attention_mask):
        num_residues = int(mask.sum()) - 2  # drop <cls> and <eos>
        human_readable_labels.append(
            [id2label[label_id] for label_id in labels[1 : 1 + num_residues]]
        )

    return human_readable_labels

# Usage:
model_path = "AmelieSchreiber/esm2_t6_8M_general_binding_sites_v2"  # Replace with your model's path
unseen_proteins = [
    "MKVEEILEKALELVIPDEEEVRKGREAEEELRRRLDELGVEYVFVGSYARNTWLKGSLEIDVFLLFPEEFSKEELRERGLEIGKAVLDSYEIRYAEHPYVHGVVKGVEVDVVPCYKLKEPKNIKSAVDRTPFHHKWLEGRIKGKENEVRLLKGFLKANGIYGAEYKVRGFSGYLCELLIVFYGSFLETVKNARRWTRRTVIDVAKGEVRKGEEFFVVDPVDEKRNVAANLSLDNLARFVHLCREFMEAPSLGFFKPKHPLEIEPERLRKIVEERGTAVFAVKFRKPDIVDDNLYPQLERASRKIFEFLERENFMPLRSAFKASEEFCYLLFECQIKEISRVFRRMGPQFEDERNVKKFLSRNRAFRPFIENGRWWAFEMRKFTTPEEGVRSYASTHWHTLGKNVGESIREYFEIISGEKLFKEPVTAELCEMMGVKD",
    "MKVEEILEKALELVIPDEEEVRKGREAEEELRRRLDELGVEYVFVGSYARNTWLKGSLEIAVFLLFPEEFSKEELRERGLEIGKAVLDSYEIRYAEHPYVHGVVKGVEVDVVPCYKLKEPKNIKSAVDRTPFHHKWLEGRIKGKENEVRLLKGFLKANGIYGAEYKVRGFSGYLCELLIVFYGSFLETVKNARRWTRRTVIDVAKGEVRKGEEFFVVDPVDEKRNVAANLSLDNLARFVHLCREFMEAPSLGFFKVKHPLEIEPERLRKIVEERGTAVFAVKFRKPDIVDDNLYPQLERASRKIFEFLERENFMPLRSAFKASEEFCYLLFECQIKEISRVFRRMGPQFEDERNVKKFLSRNRAFRPFIENGRWWAFEMRKFTTPEEGVRSYASTHWHTLGKNVGESIREYFEIISGEKLFKEPVTAELCEMMGVKD",
    "MKVEEILEKALELVIPDEEEVRKGREAEEELRRRLDELGVEAVFVGSYARNTWLKGSLEIAVFLLFPEEFSKEELRERGLEIEKAVLDSYEIRYAEHPYVHGVVKGVEVDVVPCYKLKEPKNIKSAVDRTPFHHKELEGRIKGKENEVRLLKGFLKANGIYGAEYAVRGFSGYLCELLIVFYGSFLETVKNARRWTRRTVIDVAKGEVRKGEEFFVVDPVDEKRNVAANLSLDNLARFVHLCREFMEAPSLGFFKVKHPLEIEPERLRKIVEERGTAVFMVKFRKPDIVDDNLYPQLRRASRKIFEFLERNNFMPLRSAFKASEEFCYLLFECQIKEISDVFRRMGPLFEDERNVKKFLSRNRALRPFIENGRWWIFEMRKFTTPEEGVRSYASTHWHTLGKNVGESIREYFEIISGEKLFKEPVTAELCRMMGVKD",
    "MKVEEILEKALELVIPDEEEVRKGREAEEELRRRLDELGVEAVFVGSYARNTWLKGSLEIAVFLLFPEEFSKEELRERGLEIEKAVLDSYGIRYAEHPYVHGVVKGVELDVVPCYKLKEPKNIKSAVDRTPFHHKELEGRIKGKENEYRSLKGFLKANGIYGAEYAVRGFSGYLCELLIVFYGSFLETVKNARRWTRKTVIDVAKGEVRKGEEFFVVDPVDEKRNVAALLSLDNLARFVHLCREFMEAVSLGFFKVKHPLEIEPERLRKIVEERGTAVFMVKFRKPDIVDDNLYPQLRRASRKIFEFLERNNFMPLRRAFKASEEFCYLLFEQQIKEISDVFRRMGPLFEDERNVKKFLSRNRALRPFIENGRWWIFEMRKFTTPEEGVRSYASTHWHTLGKNVGESIREYFEIIEGEKLFKEPVTAELCRMMGVKD"
]  # Replace with your unseen protein sequences
predictions = predict_binding_sites(model_path, unseen_proteins)
print(predictions)
```
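
Each inner list in `predictions` holds one label string per residue. To turn those into binding-site positions, something like the sketch below works; the negative-class label string used here is an assumption, so check the checkpoint's `id2label` mapping first:

```python
from transformers import AutoConfig

# Inspect the checkpoint's label names first; "LABEL_0" as the negative
# class below is an assumption and may differ for this checkpoint.
print(AutoConfig.from_pretrained(model_path).id2label)

NEGATIVE_LABEL = "LABEL_0"
for seq, labels in zip(unseen_proteins, predictions):
    # 1-indexed residue positions predicted as binding sites
    sites = [i for i, lab in enumerate(labels, start=1) if lab != NEGATIVE_LABEL]
    print(f"{len(sites)} predicted binding residues, first 15: {sites[:15]}")
```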