---
license: mit
---

# ESM-2 for Predicting Binding Sites

This is the 650M-parameter version of ESM-2, fine-tuned with QLoRA to predict protein binding sites from single sequences alone.
No multiple sequence alignment or structure is required. The embeddings from this model can also be used in structural models. The model was trained on
approximately 12M protein sequences from UniProt, with an 80/20 train/test split.
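
For example, per-residue embeddings can be read off the fine-tuned model's hidden states via the standard Transformers `output_hidden_states` flag. The snippet below is a minimal sketch (the sequence shown is only a placeholder); see the "Using the Model" section below for the full inference example.

```python
from transformers import AutoModelForTokenClassification, AutoTokenizer
from peft import PeftModel
import torch

base_model_path = "facebook/esm2_t33_650M_UR50D"
model_path = "AmelieSchreiber/esm2_t33_650M_qlora_binding_12M"

# Load the LoRA-adapted model and tokenizer (same setup as in the usage section below)
base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
model = PeftModel.from_pretrained(base_model, model_path)
model.eval()
tokenizer = AutoTokenizer.from_pretrained(base_model_path)

sequence = "MAVPETRPNHTIYINNLNEKIKKDELKK"  # placeholder sequence
inputs = tokenizer(sequence, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)

# Last-layer hidden states, shape (batch, seq_len, 1280) for the 650M model;
# seq_len includes the <cls> and <eos> tokens added by the tokenizer.
embeddings = outputs.hidden_states[-1]
```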

## Metrics

### Train Metrics

(Computed on a 40% sample of the training set.)

```python
{
    'eval_loss': 0.05597764626145363,
    'eval_accuracy': 0.9829392036087405,
    'eval_precision': 0.5626191259397847,
    'eval_recall': 0.9488112528941492,
    'eval_f1': 0.7063763773187873,
    'eval_auc': 0.9662524626230765,
    'eval_mcc': 0.7235838533979579
}
```

### Test Metrics

Due to the size of the dataset, the test metrics were computed in chunks and then aggregated. For the metrics of each individual chunk,
[refer to this text file](https://huggingface.co/AmelieSchreiber/esm2_t33_650M_qlora_binding_12M/blob/main/test_metrics.txt).

```python
{
    'eval_loss': 0.16281947493553162,
    'eval_accuracy': 0.9569658774883986,
    'eval_precision': 0.3209956738348438,
    'eval_recall': 0.7883697002335764,
    'eval_f1': 0.4562306866120791,
    'eval_auc': 0.8746433990040084,
    'eval_mcc': 0.48648765699020435
}
```

The metrics for the earlier checkpoints are not reported here yet. 
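
The exact aggregation procedure used for the chunked test metrics above is not spelled out in this card. As a rough, hypothetical sketch of one reasonable approach, the per-chunk labels and probabilities could be pooled and scored once at the end (rather than averaging per-chunk scores, which need not equal the pooled values for metrics such as precision, F1, and MCC). The `chunks` argument below is an assumed placeholder for whatever per-chunk arrays the evaluation loop produces.

```python
import numpy as np
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    roc_auc_score, matthews_corrcoef,
)

def aggregate_chunk_metrics(chunks):
    """Combine per-chunk evaluation results into a single set of test metrics.

    `chunks` is an iterable of (labels, probabilities) NumPy array pairs,
    one pair per evaluation chunk, with padding positions already removed.
    """
    chunks = list(chunks)
    y_true = np.concatenate([labels for labels, _ in chunks])
    y_prob = np.concatenate([probs for _, probs in chunks])
    y_pred = (y_prob >= 0.5).astype(int)  # default 0.5 decision threshold

    return {
        "accuracy": accuracy_score(y_true, y_pred),
        "precision": precision_score(y_true, y_pred),
        "recall": recall_score(y_true, y_pred),
        "f1": f1_score(y_true, y_pred),
        "auc": roc_auc_score(y_true, y_prob),
        "mcc": matthews_corrcoef(y_true, y_pred),
    }
```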

## Using the Model

```python
from transformers import AutoModelForTokenClassification, AutoTokenizer
from peft import PeftModel
import torch

# Path to the saved LoRA model
model_path = "AmelieSchreiber/esm2_t33_650M_qlora_binding_12M"
# ESM2 base model
base_model_path = "facebook/esm2_t33_650M_UR50D"

# Load the model
base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
loaded_model = PeftModel.from_pretrained(base_model, model_path)

# Ensure the model is in evaluation mode
loaded_model.eval()

# Load the tokenizer
loaded_tokenizer = AutoTokenizer.from_pretrained(base_model_path)

# Protein sequence for inference
protein_sequence = "MAVPETRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKMKGT"  # Replace with your actual sequence

# Tokenize the sequence
inputs = loaded_tokenizer(protein_sequence, return_tensors="pt", truncation=True, max_length=1024, padding='max_length')

# Run the model
with torch.no_grad():
    logits = loaded_model(**inputs).logits

# Get predictions
tokens = loaded_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])  # Convert input ids back to tokens
predictions = torch.argmax(logits, dim=2)

# Define labels
id2label = {
    0: "No binding site",
    1: "Binding site"
}

# Print the predicted labels for each token
for token, prediction in zip(tokens, predictions[0].numpy()):
    if token not in ['<pad>', '<cls>', '<eos>']:
        print((token, id2label[prediction]))
```
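
As a quick follow-up (this continues directly from the snippet above, so `tokens` and `predictions` are assumed to already be defined), the per-token labels can be collapsed into a list of predicted binding-site residue positions:

```python
# Continues from the snippet above: `tokens` and `predictions` are already defined.
special_tokens = {'<pad>', '<cls>', '<eos>'}

# Pair each residue with its predicted label, skipping special tokens
residue_labels = [
    (token, int(prediction))
    for token, prediction in zip(tokens, predictions[0].numpy())
    if token not in special_tokens
]

# 1-based sequence positions of residues predicted to be binding sites
binding_positions = [i + 1 for i, (_, label) in enumerate(residue_labels) if label == 1]
print(binding_positions)
```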