AmelieSchreiber's picture
Update README.md
491cd4b
|
raw
history blame
4.6 kB
---
license: mit
datasets:
- AmelieSchreiber/general_binding_sites
language:
- en
metrics:
- precision
- recall
- f1
library_name: transformers
tags:
- biology
- esm
- esm2
- ESM-2
- protein language model
---
# ESM-2 for General Protein Binding Site Prediction
This model is trained to predict general binding sites of proteins using on the sequence. This is a finetuned version of
`esm2_t6_8M_UR50D`, trained on [this dataset](https://huggingface.co/datasets/AmelieSchreiber/general_binding_sites). The data is
not filtered by family, and thus the model may be overfit to some degree.
## Training
```
epoch 3:
Training Loss Validation Loss Precision Recall F1 Auc
0.031100 0.074720 0.684798 0.966856 0.801743 0.980853
```
```
wandb: lr: 0.0004977045729600779
wandb: lr_scheduler_type: cosine
wandb: max_grad_norm: 0.5
wandb: num_train_epochs: 3
wandb: per_device_train_batch_size: 8
wandb: weight_decay: 0.025
```
## Using the Model
Try pasting a protein sequence into the cell on the right and clicking on "Compute". For example, try
```
MNSVTVSHAPYTIAYHDDWEPVMSQLVEFYNEAASWLLRDETSPIPSKFNIQLKQPLRNKRVCVFGIDPYPKDGTGVPFESPNFTKKSIKEIASSISRLMGVIDYEGYNLNIIDGVIPWNYYLSCKLGETKSHAIYWDKISKLLLQHITKHVSVLYCLGKTDFSNIRAKLESPVTTIVGYHPSARDRQFEKDRSFEIINVLLELDNKVPLNWAQGFIY
```
To use the model, try running:
```python
import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer
def predict_binding_sites(model_path, protein_sequences):
"""
Predict binding sites for a collection of protein sequences.
Parameters:
- model_path (str): Path to the saved model.
- protein_sequences (List[str]): List of protein sequences.
Returns:
- List[List[str]]: Predicted labels for each sequence.
"""
# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForTokenClassification.from_pretrained(model_path)
# Ensure model is in evaluation mode
model.eval()
# Tokenize sequences
inputs = tokenizer(protein_sequences, return_tensors="pt", padding=True, truncation=True)
# Move to the same device as model and obtain logits
with torch.no_grad():
logits = model(**inputs).logits
# Obtain predicted labels
predicted_labels = torch.argmax(logits, dim=-1).cpu().numpy()
# Convert label IDs to human-readable labels
id2label = model.config.id2label
human_readable_labels = [[id2label[label_id] for label_id in sequence] for sequence in predicted_labels]
return human_readable_labels
# Usage:
model_path = "AmelieSchreiber/esm2_t6_8M_general_binding_sites_v2" # Replace with your model's path
unseen_proteins = [
"MKVEEILEKALELVIPDEEEVRKGREAEEELRRRLDELGVEYVFVGSYARNTWLKGSLEIDVFLLFPEEFSKEELRERGLEIGKAVLDSYEIRYAEHPYVHGVVKGVEVDVVPCYKLKEPKNIKSAVDRTPFHHKWLEGRIKGKENEVRLLKGFLKANGIYGAEYKVRGFSGYLCELLIVFYGSFLETVKNARRWTRRTVIDVAKGEVRKGEEFFVVDPVDEKRNVAANLSLDNLARFVHLCREFMEAPSLGFFKPKHPLEIEPERLRKIVEERGTAVFAVKFRKPDIVDDNLYPQLERASRKIFEFLERENFMPLRSAFKASEEFCYLLFECQIKEISRVFRRMGPQFEDERNVKKFLSRNRAFRPFIENGRWWAFEMRKFTTPEEGVRSYASTHWHTLGKNVGESIREYFEIISGEKLFKEPVTAELCEMMGVKD",
"MKVEEILEKALELVIPDEEEVRKGREAEEELRRRLDELGVEYVFVGSYARNTWLKGSLEIAVFLLFPEEFSKEELRERGLEIGKAVLDSYEIRYAEHPYVHGVVKGVEVDVVPCYKLKEPKNIKSAVDRTPFHHKWLEGRIKGKENEVRLLKGFLKANGIYGAEYKVRGFSGYLCELLIVFYGSFLETVKNARRWTRRTVIDVAKGEVRKGEEFFVVDPVDEKRNVAANLSLDNLARFVHLCREFMEAPSLGFFKVKHPLEIEPERLRKIVEERGTAVFAVKFRKPDIVDDNLYPQLERASRKIFEFLERENFMPLRSAFKASEEFCYLLFECQIKEISRVFRRMGPQFEDERNVKKFLSRNRAFRPFIENGRWWAFEMRKFTTPEEGVRSYASTHWHTLGKNVGESIREYFEIISGEKLFKEPVTAELCEMMGVKD",
"MKVEEILEKALELVIPDEEEVRKGREAEEELRRRLDELGVEAVFVGSYARNTWLKGSLEIAVFLLFPEEFSKEELRERGLEIEKAVLDSYEIRYAEHPYVHGVVKGVEVDVVPCYKLKEPKNIKSAVDRTPFHHKELEGRIKGKENEVRLLKGFLKANGIYGAEYAVRGFSGYLCELLIVFYGSFLETVKNARRWTRRTVIDVAKGEVRKGEEFFVVDPVDEKRNVAANLSLDNLARFVHLCREFMEAPSLGFFKVKHPLEIEPERLRKIVEERGTAVFMVKFRKPDIVDDNLYPQLRRASRKIFEFLERNNFMPLRSAFKASEEFCYLLFECQIKEISDVFRRMGPLFEDERNVKKFLSRNRALRPFIENGRWWIFEMRKFTTPEEGVRSYASTHWHTLGKNVGESIREYFEIISGEKLFKEPVTAELCRMMGVKD",
"MKVEEILEKALELVIPDEEEVRKGREAEEELRRRLDELGVEAVFVGSYARNTWLKGSLEIAVFLLFPEEFSKEELRERGLEIEKAVLDSYGIRYAEHPYVHGVVKGVELDVVPCYKLKEPKNIKSAVDRTPFHHKELEGRIKGKENEYRSLKGFLKANGIYGAEYAVRGFSGYLCELLIVFYGSFLETVKNARRWTRKTVIDVAKGEVRKGEEFFVVDPVDEKRNVAALLSLDNLARFVHLCREFMEAVSLGFFKVKHPLEIEPERLRKIVEERGTAVFMVKFRKPDIVDDNLYPQLRRASRKIFEFLERNNFMPLRRAFKASEEFCYLLFEQQIKEISDVFRRMGPLFEDERNVKKFLSRNRALRPFIENGRWWIFEMRKFTTPEEGVRSYASTHWHTLGKNVGESIREYFEIIEGEKLFKEPVTAELCRMMGVKD"
] # Replace with your unseen protein sequences
predictions = predict_binding_sites(model_path, unseen_proteins)
predictions
```