---
license: mit
datasets:
- AmelieSchreiber/general_binding_sites
language:
- en
metrics:
- precision
- recall
- f1
library_name: transformers
tags:
- biology
- esm
- esm2
- ESM-2
- protein language model
---
# ESM-2 for General Protein Binding Site Prediction
This model predicts general protein binding sites from the amino acid sequence alone. It is a fine-tuned version of
`esm2_t6_8M_UR50D`, trained on [this dataset](https://huggingface.co/datasets/AmelieSchreiber/general_binding_sites). The data is
not filtered by protein family, so closely related sequences may appear in both the training and evaluation data, and the model may be overfit to some degree.
## Training
```
epoch 3:
Training Loss  Validation Loss  Precision  Recall    F1        Auc
0.031100       0.074720         0.684798   0.966856  0.801743  0.980853
```
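As a quick sanity check on the numbers above, F1 is the harmonic mean of precision and recall, so the reported value can be recomputed from the other two columns. The snippet below is only an arithmetic sketch; the variable names are illustrative and not part of the training code.

```python
# Values copied from the epoch 3 row above.
precision = 0.684798
recall = 0.966856

# F1 is the harmonic mean of precision and recall.
f1 = 2 * precision * recall / (precision + recall)

print(f"F1 recomputed from precision and recall: {f1:.6f}")
```

The recomputed value agrees with the reported F1 to the printed precision. The gap between recall and precision suggests the model flags most true binding residues but also produces a fair number of false positives.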
```
wandb: lr: 0.0004977045729600779
wandb: lr_scheduler_type: cosine
wandb: max_grad_norm: 0.5
wandb: num_train_epochs: 3
wandb: per_device_train_batch_size: 8
wandb: weight_decay: 0.025
```
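For anyone reproducing the run with the `transformers` `Trainer`, the logged values could be passed along the following lines. This is a sketch under assumptions: the mapping of the wandb key `lr` to `learning_rate` is assumed, the `output_dir` value is a hypothetical placeholder, and any setting not shown in the log is left at the library default.

```python
# Hyperparameters copied from the wandb log above. The keyword names are an
# assumed mapping onto transformers.TrainingArguments ("lr" -> "learning_rate").
training_kwargs = {
    "learning_rate": 0.0004977045729600779,
    "lr_scheduler_type": "cosine",
    "max_grad_norm": 0.5,
    "num_train_epochs": 3,
    "per_device_train_batch_size": 8,
    "weight_decay": 0.025,
}


def build_training_arguments(output_dir="esm2_binding_sites_run"):
    """Build TrainingArguments from the logged values.

    output_dir is a hypothetical placeholder, not a path from the original run.
    """
    # Imported lazily so the dictionary above can be inspected without transformers.
    from transformers import TrainingArguments

    return TrainingArguments(output_dir=output_dir, **training_kwargs)
```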
## Using the Model
Try pasting a protein sequence into the cell on the right and clicking on "Compute". For example, try
```
MNSVTVSHAPYTIAYHDDWEPVMSQLVEFYNEAASWLLRDETSPIPSKFNIQLKQPLRNKRVCVFGIDPYPKDGTGVPFESPNFTKKSIKEIASSISRLMGVIDYEGYNLNIIDGVIPWNYYLSCKLGETKSHAIYWDKISKLLLQHITKHVSVLYCLGKTDFSNIRAKLESPVTTIVGYHPSARDRQFEKDRSFEIINVLLELDNKVPLNWAQGFIY
```
To use the model, try running:
```python
import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer


def predict_binding_sites(model_path, protein_sequences):
    """
    Predict binding sites for a collection of protein sequences.

    Parameters:
    - model_path (str): Path to the saved model.
    - protein_sequences (List[str]): List of protein sequences.

    Returns:
    - List[List[str]]: Predicted labels for each token position of each sequence
      (this includes special-token and padding positions).
    """
    # Load tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForTokenClassification.from_pretrained(model_path)

    # Ensure model is in evaluation mode
    model.eval()

    # Tokenize sequences and move the tensors to the same device as the model
    inputs = tokenizer(protein_sequences, return_tensors="pt", padding=True, truncation=True)
    inputs = inputs.to(model.device)

    # Obtain logits without tracking gradients
    with torch.no_grad():
        logits = model(**inputs).logits

    # Obtain predicted label IDs as plain Python ints
    predicted_labels = torch.argmax(logits, dim=-1).cpu().tolist()

    # Convert label IDs to human-readable labels
    id2label = model.config.id2label
    human_readable_labels = [[id2label[label_id] for label_id in sequence] for sequence in predicted_labels]
    return human_readable_labels

# Usage:
model_path = "AmelieSchreiber/esm2_t6_8M_general_binding_sites_v2" # Replace with your model's path
unseen_proteins = [
"MKVEEILEKALELVIPDEEEVRKGREAEEELRRRLDELGVEYVFVGSYARNTWLKGSLEIDVFLLFPEEFSKEELRERGLEIGKAVLDSYEIRYAEHPYVHGVVKGVEVDVVPCYKLKEPKNIKSAVDRTPFHHKWLEGRIKGKENEVRLLKGFLKANGIYGAEYKVRGFSGYLCELLIVFYGSFLETVKNARRWTRRTVIDVAKGEVRKGEEFFVVDPVDEKRNVAANLSLDNLARFVHLCREFMEAPSLGFFKPKHPLEIEPERLRKIVEERGTAVFAVKFRKPDIVDDNLYPQLERASRKIFEFLERENFMPLRSAFKASEEFCYLLFECQIKEISRVFRRMGPQFEDERNVKKFLSRNRAFRPFIENGRWWAFEMRKFTTPEEGVRSYASTHWHTLGKNVGESIREYFEIISGEKLFKEPVTAELCEMMGVKD",
"MKVEEILEKALELVIPDEEEVRKGREAEEELRRRLDELGVEYVFVGSYARNTWLKGSLEIAVFLLFPEEFSKEELRERGLEIGKAVLDSYEIRYAEHPYVHGVVKGVEVDVVPCYKLKEPKNIKSAVDRTPFHHKWLEGRIKGKENEVRLLKGFLKANGIYGAEYKVRGFSGYLCELLIVFYGSFLETVKNARRWTRRTVIDVAKGEVRKGEEFFVVDPVDEKRNVAANLSLDNLARFVHLCREFMEAPSLGFFKVKHPLEIEPERLRKIVEERGTAVFAVKFRKPDIVDDNLYPQLERASRKIFEFLERENFMPLRSAFKASEEFCYLLFECQIKEISRVFRRMGPQFEDERNVKKFLSRNRAFRPFIENGRWWAFEMRKFTTPEEGVRSYASTHWHTLGKNVGESIREYFEIISGEKLFKEPVTAELCEMMGVKD",
"MKVEEILEKALELVIPDEEEVRKGREAEEELRRRLDELGVEAVFVGSYARNTWLKGSLEIAVFLLFPEEFSKEELRERGLEIEKAVLDSYEIRYAEHPYVHGVVKGVEVDVVPCYKLKEPKNIKSAVDRTPFHHKELEGRIKGKENEVRLLKGFLKANGIYGAEYAVRGFSGYLCELLIVFYGSFLETVKNARRWTRRTVIDVAKGEVRKGEEFFVVDPVDEKRNVAANLSLDNLARFVHLCREFMEAPSLGFFKVKHPLEIEPERLRKIVEERGTAVFMVKFRKPDIVDDNLYPQLRRASRKIFEFLERNNFMPLRSAFKASEEFCYLLFECQIKEISDVFRRMGPLFEDERNVKKFLSRNRALRPFIENGRWWIFEMRKFTTPEEGVRSYASTHWHTLGKNVGESIREYFEIISGEKLFKEPVTAELCRMMGVKD",
"MKVEEILEKALELVIPDEEEVRKGREAEEELRRRLDELGVEAVFVGSYARNTWLKGSLEIAVFLLFPEEFSKEELRERGLEIEKAVLDSYGIRYAEHPYVHGVVKGVELDVVPCYKLKEPKNIKSAVDRTPFHHKELEGRIKGKENEYRSLKGFLKANGIYGAEYAVRGFSGYLCELLIVFYGSFLETVKNARRWTRKTVIDVAKGEVRKGEEFFVVDPVDEKRNVAALLSLDNLARFVHLCREFMEAVSLGFFKVKHPLEIEPERLRKIVEERGTAVFMVKFRKPDIVDDNLYPQLRRASRKIFEFLERNNFMPLRRAFKASEEFCYLLFEQQIKEISDVFRRMGPLFEDERNVKKFLSRNRALRPFIENGRWWIFEMRKFTTPEEGVRSYASTHWHTLGKNVGESIREYFEIIEGEKLFKEPVTAELCRMMGVKD"
] # Replace with your unseen protein sequences
predictions = predict_binding_sites(model_path, unseen_proteins)
predictions
```
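The function above returns one label per token position, not per residue. A minimal sketch for lining the labels up with the residues follows. It assumes the ESM tokenizer prepends one `<cls>` token, appends one `<eos>` token, and pads on the right. The helper name `align_labels_to_residues` and the toy label strings are hypothetical and used only for illustration.

```python
def align_labels_to_residues(sequences, token_labels):
    """Pair each residue with its predicted label.

    Assumes position 0 of every label list belongs to a leading special token
    and that anything after the last residue is <eos> or padding.
    """
    aligned = []
    for seq, labels in zip(sequences, token_labels):
        # Skip position 0 and keep at most one label per residue.
        residue_labels = list(labels)[1:1 + len(seq)]
        aligned.append(list(zip(seq, residue_labels)))
    return aligned


# Toy input: "x" marks special-token or padding positions, "a"/"b" are made-up labels.
toy_sequences = ["MKV", "MK"]
toy_labels = [
    ["x", "a", "b", "a", "x"],
    ["x", "b", "b", "x", "x"],
]
toy_aligned = align_labels_to_residues(toy_sequences, toy_labels)
print(toy_aligned)
```

With real predictions, `align_labels_to_residues(unseen_proteins, predictions)` would give a list of `(residue, label)` pairs per sequence. If a sequence is truncated by the tokenizer, `zip` stops at the shorter input, so only the residues that received a label are returned.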