Training:
For a report on the training please see here and
here.
Metrics:
Train:
{'accuracy': 0.9406146072672105,
'precision': 0.2947122459102886,
'recall': 0.952624323712029,
'f1': 0.4501592605994876,
'auc': 0.9464622170085311,
'mcc': 0.5118390407598565}
Test:
{'accuracy': 0.9266827008067329,
'precision': 0.22378953253253775,
'recall': 0.7790246675002842,
'f1': 0.3476966444342296,
'auc': 0.8547531675185658,
'mcc': 0.3930283737012391}
Using the Model
Using the Model on Your Protein Sequences
To use the model on one of your own protein sequences, run the following:
from transformers import AutoModelForTokenClassification, AutoTokenizer
from peft import PeftModel
import torch
# Load the ESM-2 base model and wrap it with the LoRA binding-site adapter.
model_path = "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_cp1"
base_model_path = "facebook/esm2_t12_35M_UR50D"
base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
loaded_model = PeftModel.from_pretrained(base_model, model_path)
loaded_model.eval()  # switch to inference mode

# The tokenizer is loaded from the base model path.
loaded_tokenizer = AutoTokenizer.from_pretrained(base_model_path)

protein_sequence = "MAVPETRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKMKGT"

# A single sequence needs no padding; padding it to 1024 tokens only adds
# work, because pad tokens are skipped when printing anyway. Sequences that
# are too long are still truncated to 1024 tokens.
inputs = loaded_tokenizer(protein_sequence, return_tensors="pt", truncation=True, max_length=1024)

# No gradients are needed for prediction.
with torch.no_grad():
    logits = loaded_model(**inputs).logits

# One predicted class id per token: argmax over the last (class) dimension.
tokens = loaded_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
predictions = torch.argmax(logits, dim=2)

id2label = {
    0: "No binding site",
    1: "Binding site",
}

# Print one (residue, label) pair per token, skipping the special tokens.
special_tokens = {'<pad>', '<cls>', '<eos>'}
for token, prediction in zip(tokens, predictions[0].tolist()):
    if token not in special_tokens:
        print((token, id2label[prediction]))
Getting the Train/Test Metrics:
First, download the dataset from here.
Once you have the pickle files saved locally, run the following:
from datasets import Dataset
from transformers import AutoTokenizer
import pickle
# Tokenizer of the base ESM-2 checkpoint; it encodes the train/test sequences.
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")
def truncate_labels(labels, max_length):
    """Truncate every label sequence to at most ``max_length`` entries.

    Args:
        labels: Iterable of sliceable per-sequence label sequences.
        max_length: Maximum number of labels to keep per sequence.

    Returns:
        A new list holding ``label[:max_length]`` for each input sequence.
        Sequences already within the limit keep all of their entries.
    """
    return [label[:max_length] for label in labels]
max_sequence_length = 1000


def _load_pickle(path):
    """Return the object stored in the pickle file at ``path``."""
    # SECURITY: unpickling can execute arbitrary code. Only load pickle
    # files that were downloaded from a source you trust.
    with open(path, "rb") as f:
        return pickle.load(f)


train_sequences = _load_pickle("train_sequences_chunked_by_family.pkl")
test_sequences = _load_pickle("test_sequences_chunked_by_family.pkl")
train_labels = _load_pickle("train_labels_chunked_by_family.pkl")
test_labels = _load_pickle("test_labels_chunked_by_family.pkl")

# Tokenize each split, padding to the longest sequence in the split and
# truncating anything longer than max_sequence_length tokens.
train_tokenized = tokenizer(train_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)
test_tokenized = tokenizer(test_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)

# Cap the label sequences at the same maximum length.
# NOTE(review): labels are truncated but not shifted or padded here to
# account for any special tokens the tokenizer adds -- confirm that the
# label/token alignment is handled as intended downstream.
train_labels = truncate_labels(train_labels, max_sequence_length)
test_labels = truncate_labels(test_labels, max_sequence_length)

# Build the datasets from the tokenizer output and attach the labels.
train_dataset = Dataset.from_dict(dict(train_tokenized.items())).add_column("labels", train_labels)
test_dataset = Dataset.from_dict(dict(test_tokenized.items())).add_column("labels", test_labels)
Then run the following to get the train/test metrics:
from sklearn.metrics import(
matthews_corrcoef,
accuracy_score,
precision_recall_fscore_support,
roc_auc_score
)
from peft import PeftModel
from transformers import DataCollatorForTokenClassification, AutoModelForTokenClassification
from transformers import Trainer
from accelerate import Accelerator
# Checkpoints: the ESM-2 base model and the LoRA adapter applied on top of it.
base_model_path = "facebook/esm2_t12_35M_UR50D"
lora_model_path = "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_cp1"

# Label vocabulary and its inverse mapping.
id2label = {
    0: "No binding site",
    1: "Binding site",
}
label2id = {name: index for index, name in id2label.items()}

# Load the base model, attach the adapter, and hand the result to Accelerate.
accelerator = Accelerator()
base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
model = accelerator.prepare(PeftModel.from_pretrained(base_model, lora_model_path))

# Collator used by the Trainer when batching the tokenized examples.
data_collator = DataCollatorForTokenClassification(tokenizer)
def compute_metrics(dataset):
    """Run the module-level ``model`` on ``dataset`` and score its predictions.

    Args:
        dataset: Dataset handed to ``Trainer.predict``.

    Returns:
        A dict with the keys "accuracy", "precision", "recall", "f1", "auc"
        and "mcc", computed over all positions whose label is not -100.
    """
    # Imported locally because the snippet's import list does not bring
    # NumPy into scope, and np.argmax is needed below.
    import numpy as np

    trainer = Trainer(model=model, data_collator=data_collator)
    predictions, labels, _ = trainer.predict(test_dataset=dataset)

    # Positions labelled -100 are excluded from scoring.
    mask = labels != -100
    true_labels = labels[mask].flatten()
    # Predicted class per position: argmax over the last (class) axis.
    flat_predictions = np.argmax(predictions, axis=2)[mask].flatten().tolist()

    accuracy = accuracy_score(true_labels, flat_predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(true_labels, flat_predictions, average='binary')
    # NOTE: AUC is computed from the hard class predictions, not from scores.
    auc = roc_auc_score(true_labels, flat_predictions)
    mcc = matthews_corrcoef(true_labels, flat_predictions)

    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1, "auc": auc, "mcc": mcc}
train_metrics = compute_metrics(train_dataset)
test_metrics = compute_metrics(test_dataset)
# Print explicitly: a bare expression is only displayed in a notebook/REPL,
# so a plain script would otherwise show nothing.
print((train_metrics, test_metrics))