---
license: mit
datasets:
- AmelieSchreiber/general_binding_sites
language:
- en
metrics:
- precision
- recall
- f1
library_name: transformers
tags:
- biology
- esm
- esm2
- ESM-2
- protein language model
---

# ESM-2 for General Protein Binding Site Prediction

This model predicts general protein binding sites from sequence alone. It is a fine-tuned version of 
`esm2_t6_8M_UR50D`, trained on [this dataset](https://huggingface.co/datasets/AmelieSchreiber/general_binding_sites). The data is 
not filtered by protein family, so the model may be overfit to some degree.
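
As a quick sanity check, the training data can be inspected with the `datasets` library. This is a minimal sketch; the split and column names are assumptions, so inspect the loaded dataset to confirm them.

```python
from datasets import load_dataset

# Load the binding-site dataset used for fine-tuning.
ds = load_dataset("AmelieSchreiber/general_binding_sites")

# Show the available splits and the fields of one example (column names may vary).
print(ds)
print(ds["train"][0].keys())  # "train" split assumed
```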

## Training

| Epoch | Training Loss | Validation Loss | Precision | Recall   | F1       | AUC      |
|-------|---------------|-----------------|-----------|----------|----------|----------|
| 3     | 0.031100      | 0.074720        | 0.684798  | 0.966856 | 0.801743 | 0.980853 |
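
For reference, token-level metrics like those above can be computed with a function along these lines. This is a sketch, assuming binary binding-site labels and `-100` used to mask special/padding token positions, as in a standard `Trainer` token-classification setup.

```python
import numpy as np
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score

def compute_metrics(eval_pred):
    """Token-level precision/recall/F1/AUC for binary binding-site prediction."""
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)

    # Keep only positions with real labels (-100 marks special/padding tokens).
    mask = labels != -100
    y_true, y_pred = labels[mask], preds[mask]

    precision, recall, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average="binary"
    )

    # AUC from the softmax probability of the positive (binding-site) class.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(shifted) / np.exp(shifted).sum(axis=-1, keepdims=True)
    auc = roc_auc_score(y_true, probs[..., 1][mask])

    return {"precision": precision, "recall": recall, "f1": f1, "auc": auc}
```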

Key hyperparameters from the Weights & Biases run log:

```
wandb: 	lr: 0.0004977045729600779
wandb: 	lr_scheduler_type: cosine
wandb: 	max_grad_norm: 0.5
wandb: 	num_train_epochs: 3
wandb: 	per_device_train_batch_size: 8
wandb: 	weight_decay: 0.025
```
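
A fine-tuning setup consistent with these logged values would look roughly like the following. This is a sketch rather than the original training script; `output_dir` and any argument not shown in the log above are placeholders.

```python
from transformers import TrainingArguments

# Hyperparameters copied from the wandb log; everything else is a placeholder.
training_args = TrainingArguments(
    output_dir="esm2_t6_8M_general_binding_sites",  # hypothetical output path
    learning_rate=0.0004977045729600779,
    lr_scheduler_type="cosine",
    max_grad_norm=0.5,
    num_train_epochs=3,
    per_device_train_batch_size=8,
    weight_decay=0.025,
)
```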

## Using the Model

Try pasting a protein sequence into the inference widget on the right and clicking "Compute". For example, try:

```
MNSVTVSHAPYTIAYHDDWEPVMSQLVEFYNEAASWLLRDETSPIPSKFNIQLKQPLRNKRVCVFGIDPYPKDGTGVPFESPNFTKKSIKEIASSISRLMGVIDYEGYNLNIIDGVIPWNYYLSCKLGETKSHAIYWDKISKLLLQHITKHVSVLYCLGKTDFSNIRAKLESPVTTIVGYHPSARDRQFEKDRSFEIINVLLELDNKVPLNWAQGFIY
```

To use the model, try running:
```python
import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer

def predict_binding_sites(model_path, protein_sequences):
    """
    Predict binding sites for a collection of protein sequences.

    Parameters:
    - model_path (str): Path to the saved model.
    - protein_sequences (List[str]): List of protein sequences.

    Returns:
    - List[List[str]]: Predicted labels for each sequence.
    """

    # Load tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForTokenClassification.from_pretrained(model_path)

    # Ensure model is in evaluation mode
    model.eval()

    # Tokenize sequences
    inputs = tokenizer(protein_sequences, return_tensors="pt", padding=True, truncation=True)

    # Move the model and inputs to the GPU if one is available
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Run inference without gradient tracking and obtain per-token logits
    with torch.no_grad():
        logits = model(**inputs).logits

    # Obtain predicted labels
    predicted_labels = torch.argmax(logits, dim=-1).cpu().numpy()

    # Convert label IDs to human-readable labels
    id2label = model.config.id2label
    human_readable_labels = [[id2label[label_id] for label_id in sequence] for sequence in predicted_labels]

    return human_readable_labels

# Usage:
model_path = "AmelieSchreiber/esm2_t6_8M_general_binding_sites_v2"  # Replace with your model's path
unseen_proteins = [
    "MKVEEILEKALELVIPDEEEVRKGREAEEELRRRLDELGVEYVFVGSYARNTWLKGSLEIDVFLLFPEEFSKEELRERGLEIGKAVLDSYEIRYAEHPYVHGVVKGVEVDVVPCYKLKEPKNIKSAVDRTPFHHKWLEGRIKGKENEVRLLKGFLKANGIYGAEYKVRGFSGYLCELLIVFYGSFLETVKNARRWTRRTVIDVAKGEVRKGEEFFVVDPVDEKRNVAANLSLDNLARFVHLCREFMEAPSLGFFKPKHPLEIEPERLRKIVEERGTAVFAVKFRKPDIVDDNLYPQLERASRKIFEFLERENFMPLRSAFKASEEFCYLLFECQIKEISRVFRRMGPQFEDERNVKKFLSRNRAFRPFIENGRWWAFEMRKFTTPEEGVRSYASTHWHTLGKNVGESIREYFEIISGEKLFKEPVTAELCEMMGVKD", 
    "MKVEEILEKALELVIPDEEEVRKGREAEEELRRRLDELGVEYVFVGSYARNTWLKGSLEIAVFLLFPEEFSKEELRERGLEIGKAVLDSYEIRYAEHPYVHGVVKGVEVDVVPCYKLKEPKNIKSAVDRTPFHHKWLEGRIKGKENEVRLLKGFLKANGIYGAEYKVRGFSGYLCELLIVFYGSFLETVKNARRWTRRTVIDVAKGEVRKGEEFFVVDPVDEKRNVAANLSLDNLARFVHLCREFMEAPSLGFFKVKHPLEIEPERLRKIVEERGTAVFAVKFRKPDIVDDNLYPQLERASRKIFEFLERENFMPLRSAFKASEEFCYLLFECQIKEISRVFRRMGPQFEDERNVKKFLSRNRAFRPFIENGRWWAFEMRKFTTPEEGVRSYASTHWHTLGKNVGESIREYFEIISGEKLFKEPVTAELCEMMGVKD", 
    "MKVEEILEKALELVIPDEEEVRKGREAEEELRRRLDELGVEAVFVGSYARNTWLKGSLEIAVFLLFPEEFSKEELRERGLEIEKAVLDSYEIRYAEHPYVHGVVKGVEVDVVPCYKLKEPKNIKSAVDRTPFHHKELEGRIKGKENEVRLLKGFLKANGIYGAEYAVRGFSGYLCELLIVFYGSFLETVKNARRWTRRTVIDVAKGEVRKGEEFFVVDPVDEKRNVAANLSLDNLARFVHLCREFMEAPSLGFFKVKHPLEIEPERLRKIVEERGTAVFMVKFRKPDIVDDNLYPQLRRASRKIFEFLERNNFMPLRSAFKASEEFCYLLFECQIKEISDVFRRMGPLFEDERNVKKFLSRNRALRPFIENGRWWIFEMRKFTTPEEGVRSYASTHWHTLGKNVGESIREYFEIISGEKLFKEPVTAELCRMMGVKD", 
    "MKVEEILEKALELVIPDEEEVRKGREAEEELRRRLDELGVEAVFVGSYARNTWLKGSLEIAVFLLFPEEFSKEELRERGLEIEKAVLDSYGIRYAEHPYVHGVVKGVELDVVPCYKLKEPKNIKSAVDRTPFHHKELEGRIKGKENEYRSLKGFLKANGIYGAEYAVRGFSGYLCELLIVFYGSFLETVKNARRWTRKTVIDVAKGEVRKGEEFFVVDPVDEKRNVAALLSLDNLARFVHLCREFMEAVSLGFFKVKHPLEIEPERLRKIVEERGTAVFMVKFRKPDIVDDNLYPQLRRASRKIFEFLERNNFMPLRRAFKASEEFCYLLFEQQIKEISDVFRRMGPLFEDERNVKKFLSRNRALRPFIENGRWWIFEMRKFTTPEEGVRSYASTHWHTLGKNVGESIREYFEIIEGEKLFKEPVTAELCRMMGVKD"
]  # Replace with your unseen protein sequences
predictions = predict_binding_sites(model_path, unseen_proteins)
print(predictions)
```
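
The returned labels are aligned with the tokenized input, which includes the leading `<cls>` token and any trailing `<eos>`/padding tokens added by the ESM-2 tokenizer. A minimal sketch for mapping predictions back to residues (the label strings come from `model.config.id2label`, so check them before filtering on a particular class name):

```python
# Align each sequence's predictions with its residues.
# Index 0 of each label list is the <cls> token; positions beyond len(seq)
# correspond to <eos> and padding, so both are trimmed away here.
for seq, labels in zip(unseen_proteins, predictions):
    residue_labels = labels[1 : len(seq) + 1]
    print(list(zip(seq, residue_labels))[:10])  # first 10 (residue, label) pairs
```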