---
license: mit
---
|
|
|
# ESM-2 QLoRA for Predicting Binding Sites
|
|
|
This model is the ESM-2 model [esm2_t12_35M_UR50D](https://huggingface.co/facebook/esm2_t12_35M_UR50D) finetuned with QLoRA on
[this dataset](https://huggingface.co/datasets/AmelieSchreiber/2600K_binding_sites) of 2.6M protein sequences with binding and active
site annotations from UniProt. The model and dataset sizes were scaled up one-to-one (following the Chinchilla paper) from the smaller
QLoRA adaptations of the `esm2_t6_8M_UR50D` model, which were trained on 600K proteins. Since this model is 4.375 times larger, a dataset
approximately 4.375 times larger is needed if Chinchilla scaling laws hold for QLoRA finetuning of protein language models. Determining
whether such scaling laws do hold is part of this project, so we are checking for improvements in the performance metrics over three
epochs, as well as for signs of overfitting after each epoch.
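
As a quick sanity check on the scaling arithmetic above (illustrative only, not training code; the parameter counts are the nominal model sizes):

```python
# Chinchilla-style one-to-one scaling of dataset size with model size.
base_params = 8_000_000        # esm2_t6_8M_UR50D (nominal size)
scaled_params = 35_000_000     # esm2_t12_35M_UR50D (nominal size)
base_dataset = 600_000         # proteins used for the smaller QLoRA adapters

scale = scaled_params / base_params    # 4.375
target_dataset = base_dataset * scale  # 2,625,000, i.e. roughly 2.6M proteins
print(scale, target_dataset)
```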
|
|
|
|
|
## QLoRA Info
|
|
|
```
trainable params: 71046 || all params: 17246053 || trainable%: 0.41195512967517844
```
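
The summary above is what `peft` prints for a small LoRA adapter. Below is a minimal sketch of how such an adapter can be configured for token classification; the rank, alpha, dropout, and target modules are illustrative assumptions rather than the exact settings used for this checkpoint, and the actual QLoRA setup additionally loads the base model in 4-bit with `bitsandbytes`.

```python
from transformers import AutoModelForTokenClassification
from peft import LoraConfig, TaskType, get_peft_model

# Base ESM-2 model with a 2-class token classification head.
base_model = AutoModelForTokenClassification.from_pretrained(
    "facebook/esm2_t12_35M_UR50D", num_labels=2
)

# Illustrative LoRA configuration (rank, alpha, dropout, and target modules
# are assumptions, not the exact hyperparameters of this checkpoint).
lora_config = LoraConfig(
    task_type=TaskType.TOKEN_CLS,
    r=2,
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=["query", "key", "value"],  # ESM-2 attention projections
)

model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()  # prints a summary like the one above
```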
|
|
|
```python
'eval_loss': 0.6011912822723389,
'eval_accuracy': 0.9297529150299436,
'eval_precision': 0.22835223718675476,
'eval_recall': 0.697386656717114,
'eval_f1': 0.3440490710592986,
'eval_auc': 0.8167222019799886,
'eval_mcc': 0.3730152153022164
```
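
These are token-level metrics. A hedged sketch of how metrics like these can be computed with scikit-learn inside a `compute_metrics` callback is shown below; the `-100` masking convention for special/padded positions and the softmax used for the AUC are assumptions, not the exact evaluation code behind the numbers above.

```python
import numpy as np
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score,
    f1_score, roc_auc_score, matthews_corrcoef,
)

def compute_metrics(eval_pred):
    logits, labels = eval_pred          # shapes: (batch, seq_len, 2), (batch, seq_len)
    predictions = np.argmax(logits, axis=-1)

    # Keep only real residues; special/padded positions are labelled -100.
    mask = labels != -100
    y_true = labels[mask]
    y_pred = predictions[mask]

    # Softmax probability of the positive ("binding site") class for the AUC.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(shifted) / np.exp(shifted).sum(axis=-1, keepdims=True)
    y_prob = probs[..., 1][mask]

    return {
        "accuracy": accuracy_score(y_true, y_pred),
        "precision": precision_score(y_true, y_pred),
        "recall": recall_score(y_true, y_pred),
        "f1": f1_score(y_true, y_pred),
        "auc": roc_auc_score(y_true, y_prob),
        "mcc": matthews_corrcoef(y_true, y_pred),
    }
```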
|
|
|
To use this model, run:
|
|
|
```
!pip install transformers -q
!pip install peft -q
```
|
|
|
Then run:
|
|
|
```python
from transformers import AutoModelForTokenClassification, AutoTokenizer
from peft import PeftModel
import torch

# Path to the saved LoRA model
model_path = "AmelieSchreiber/esm2_t12_35M_qlora_binding_2600K_cp1"
# ESM2 base model
base_model_path = "facebook/esm2_t12_35M_UR50D"

# Load the base model and apply the LoRA adapter
base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
loaded_model = PeftModel.from_pretrained(base_model, model_path)

# Ensure the model is in evaluation mode
loaded_model.eval()

# Load the tokenizer
loaded_tokenizer = AutoTokenizer.from_pretrained(base_model_path)

# Protein sequence for inference
protein_sequence = "MAVPETRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKMKGT"  # Replace with your actual sequence

# Tokenize the sequence
inputs = loaded_tokenizer(protein_sequence, return_tensors="pt", truncation=True, max_length=1024, padding='max_length')

# Run the model
with torch.no_grad():
    logits = loaded_model(**inputs).logits

# Get predictions
tokens = loaded_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])  # Convert input ids back to tokens
predictions = torch.argmax(logits, dim=2)

# Define labels
id2label = {
    0: "No binding site",
    1: "Binding site"
}

# Print the predicted label for each residue, skipping special tokens
for token, prediction in zip(tokens, predictions[0].numpy()):
    if token not in ['<pad>', '<cls>', '<eos>']:
        print((token, id2label[prediction]))
```
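
If per-residue probabilities or just the predicted binding-site positions are more convenient than the token-by-token printout, something like the following works with the variables defined above (an illustrative extension, not part of the original snippet):

```python
# Softmax over the two classes; keep P(binding site) for each token.
probabilities = torch.softmax(logits, dim=-1)[0, :, 1]

# Collect (token index, residue, probability) for every predicted binding site,
# skipping special tokens. Note the index counts tokens, including <cls> at 0.
binding_sites = [
    (i, token, probabilities[i].item())
    for i, (token, pred) in enumerate(zip(tokens, predictions[0].numpy()))
    if token not in ['<pad>', '<cls>', '<eos>'] and pred == 1
]
print(binding_sites)
```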
|
|