AmelieSchreiber committed
Commit 3de8059 · Parent(s): bf2d7db
Upload metrics.py

metrics.py · ADDED (+116 lines)

import pickle
import numpy as np
import torch
from torch.utils.data import Dataset as TorchDataset
from transformers import AutoTokenizer, AutoModelForTokenClassification, Trainer
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score, matthews_corrcoef
from peft import PeftModel  # only the adapter loader is needed for evaluation
from accelerate import Accelerator
from tqdm import tqdm

# Initialize the Accelerator
accelerator = Accelerator()

class ProteinDataset(TorchDataset):
    """Tokenizes protein sequences and pads per-residue labels to max_length."""

    def __init__(self, sequences_path, labels_path, tokenizer, max_length):
        self.tokenizer = tokenizer
        self.max_length = max_length

        with open(sequences_path, "rb") as f:
            self.sequences = pickle.load(f)

        with open(labels_path, "rb") as f:
            self.labels = pickle.load(f)

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        sequence = self.sequences[idx]
        label = self.labels[idx]

        tokenized = self.tokenizer(
            sequence,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
            is_split_into_words=False,
            add_special_tokens=False,
        )

        # Remove the extra batch dimension added by return_tensors="pt"
        for key in tokenized:
            tokenized[key] = tokenized[key].squeeze(0)

        # Pad/truncate labels to match the tokenized input; -100 is the ignore
        # index, skipped by the token-classification loss and the metric mask
        label_padded = [-100] * self.max_length
        n = min(len(label), self.max_length)
        label_padded[:n] = label[:n]

        tokenized["labels"] = torch.tensor(label_padded)

        return tokenized

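# Note (illustrative addition, not part of the original commit): with
# add_special_tokens=False and ESM-2's one-token-per-residue tokenizer, the
# i-th label lines up with the i-th input token. A quick shape sanity check
# on one item, assuming a dataset built with max_length=512 (e.g. the
# train_dataset defined below):
#
#     sample = train_dataset[0]
#     assert sample["input_ids"].shape == sample["labels"].shape == (512,)
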
def compute_metrics(p):
    predictions, labels = p.predictions, p.label_ids
    predictions = np.argmax(predictions, axis=2)

    # Score only real residues: positions labeled -100 are padding
    mask = labels != -100
    predictions = predictions[mask].flatten()
    labels = labels[mask].flatten()

    accuracy = accuracy_score(labels, predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average='binary')
    auc = roc_auc_score(labels, predictions)
    mcc = matthews_corrcoef(labels, predictions)

    return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1, 'auc': auc, 'mcc': mcc}

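# Note (illustrative addition, not part of the original commit): the AUC above
# is computed from hard argmax predictions, which gives a coarser score than
# ranking by the model's confidence. A variant using the positive-class
# probability could look like this (a sketch, assuming binary labels with the
# positive class at index 1):
#
#     def compute_metrics_with_probs(p):
#         logits, labels = p.predictions, p.label_ids
#         mask = labels != -100
#         probs = torch.softmax(torch.tensor(logits), dim=2)[..., 1].numpy()
#         return {"auc": roc_auc_score(labels[mask], probs[mask])}
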
def evaluate_in_chunks(dataset, trainer, chunk_percentage=0.2, tag="eval"):
    chunk_size = int(len(dataset) * chunk_percentage)
    all_results = []

    # Wrap the loop with tqdm for a progress bar
    for i in tqdm(range(0, len(dataset), chunk_size), desc="Evaluating chunks"):
        chunk = [dataset[j] for j in range(i, min(i + chunk_size, len(dataset)))]
        chunk_results = trainer.evaluate(chunk)
        # trainer.evaluate() does not report a sample count, so record the
        # chunk size here; aggregate_results() weights by it below
        chunk_results["eval_samples"] = len(chunk)
        print(f"Results for chunk starting at index {i}: {chunk_results}")

        # Save the chunk results to disk; the tag keeps the train and test
        # runs from overwriting each other's files
        with open(f"results_{tag}_chunk_{i}.pkl", "wb") as f:
            pickle.dump(chunk_results, f)

        all_results.append(chunk_results)

    return all_results

def aggregate_results(results_list):
    # Sample-weighted average of each metric across chunks
    total_samples = sum(res["eval_samples"] for res in results_list)
    aggregated_results = {}

    for key in results_list[0].keys():
        if key == "eval_samples":
            continue
        aggregated_results[key] = sum(res[key] * res["eval_samples"] for res in results_list) / total_samples

    return aggregated_results

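# Worked example: if one chunk reports eval_accuracy=0.90 over 1000 samples
# and another reports 0.80 over 500, the aggregate is
# (0.90*1000 + 0.80*500) / 1500 = 0.8667, i.e. a per-sample average rather
# than a per-chunk average.
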
# Initialize tokenizer and datasets
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
train_dataset = ProteinDataset("data/12M_data/512_train_sequences_chunked_by_family.pkl", "data/12M_data/512_train_labels_chunked_by_family.pkl", tokenizer, 512)
test_dataset = ProteinDataset("data/12M_data/512_test_sequences_chunked_by_family.pkl", "data/12M_data/512_test_labels_chunked_by_family.pkl", tokenizer, 512)

# Load the pre-trained LoRA model: base ESM-2 weights plus the fine-tuned adapter
base_model_path = "facebook/esm2_t33_650M_UR50D"
lora_model_path = "qlora_binding_sites/best_model_esm2_t33_650M_qlora_2023-10-18_02-14-48"
base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
model = PeftModel.from_pretrained(base_model, lora_model_path)
model = accelerator.prepare(model)

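# Note (illustrative addition, not part of the original commit): for
# evaluation-only runs, the LoRA adapter could optionally be folded into the
# base weights to avoid the adapter overhead on every forward pass:
#
#     model = model.merge_and_unload()
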
# Initialize the Trainer (default TrainingArguments; used here only for evaluation)
trainer = Trainer(
    model=model,
    compute_metrics=compute_metrics
)

# Evaluate the model on chunks of the training dataset
train_results = evaluate_in_chunks(train_dataset, trainer, tag="train")
aggregated_train_results = aggregate_results(train_results)
print(f"Aggregated Training Results: {aggregated_train_results}")

# Evaluate the model on chunks of the test dataset
test_results = evaluate_in_chunks(test_dataset, trainer, tag="test")
aggregated_test_results = aggregate_results(test_results)
print(f"Aggregated Test Results: {aggregated_test_results}")
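
Since each chunk's metrics are pickled as they are produced, the aggregation can also be re-run offline without reloading the model. A minimal sketch (assuming the test-run files from above sit in the working directory; file order does not matter because the average is weighted per sample):

import glob
import pickle

# Load every saved test-chunk metrics dict
results = []
for path in glob.glob("results_test_chunk_*.pkl"):
    with open(path, "rb") as f:
        results.append(pickle.load(f))

# Same sample-weighted average as aggregate_results() above
total = sum(r["eval_samples"] for r in results)
aggregated = {k: sum(r[k] * r["eval_samples"] for r in results) / total
              for k in results[0] if k != "eval_samples"}
print(aggregated)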