# Andromeda/testing/accuracy.py
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

from nltk.translate.bleu_score import corpus_bleu
from rouge import Rouge
from sklearn.metrics import f1_score

from Andromeda.model import Andromeda

# Seed everything for reproducible runs.
torch.manual_seed(0)
if torch.cuda.is_available():
    torch.cuda.manual_seed(0)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class AccuracyMetrics:
    """Wrappers around common text-generation quality metrics."""

    def __init__(self):
        self.rouge = Rouge()

    def calculate_perplexity(self, model, data_loader):
        """Perplexity = exp(mean token-level cross-entropy over the loader)."""
        model.eval()
        total_loss = 0.0
        with torch.no_grad():
            for input_ids, labels in data_loader:
                input_ids = input_ids.to(device)
                labels = labels.to(device)
                output = model(input_ids)
                loss = F.cross_entropy(
                    output.view(-1, output.size(-1)), labels.view(-1)
                )
                total_loss += loss.item()
        return torch.exp(torch.tensor(total_loss / len(data_loader)))

    def calculate_bleu(self, references, hypotheses):
        """Corpus BLEU. NLTK's format: `references` is a list of
        reference-lists (one list of tokenized references per hypothesis);
        `hypotheses` is a parallel list of token lists."""
        return corpus_bleu(references, hypotheses)

    def calculate_rouge(self, references, hypotheses):
        """Average ROUGE scores; both arguments are lists of plain strings."""
        return self.rouge.get_scores(hypotheses, references, avg=True)

    def calculate_f1(self, true_labels, pred_labels):
        """Weighted F1 over parallel label sequences."""
        return f1_score(true_labels, pred_labels, average="weighted")
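# Illustrative only: NLTK's corpus_bleu nests references one level deeper
# than hypotheses, because each hypothesis may have several references.
# The tokens below are made-up sample data, not from the original file.
#
#   refs = [[["the", "cat", "sat", "on", "the", "mat"]]]
#   hyps = [["the", "cat", "sat", "on", "the", "mat"]]
#   corpus_bleu(refs, hyps)  # -> 1.0 for an exact match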
if __name__ == "__main__":
    # Mock test data: Andromeda is a language model, so feed random token
    # ids rather than images (the original FakeData dataset yields images,
    # which the model cannot consume). Labels mirror the inputs.
    input_ids = torch.randint(0, 50304, (100, 128))
    test_dataset = TensorDataset(input_ids, input_ids.clone())
    data_loader = DataLoader(test_dataset, batch_size=4)

    # Model
    model = Andromeda(
        num_tokens=50304,
        dim=1024,
        depth=24,
        dim_head=128,
        heads=8,
        alibi_num_heads=4,
    ).to(device)

    accuracy_metrics = AccuracyMetrics()

    # Calculate Perplexity
    perplexity = accuracy_metrics.calculate_perplexity(model, data_loader)
    print("Perplexity:", perplexity)

    # Calculate BLEU (tokenized sentences, NLTK's nested-reference format).
    references = [[["the", "cat", "sat", "on", "the", "mat"]]]
    hypotheses = [["the", "cat", "sat", "on", "the", "mat"]]
    bleu = accuracy_metrics.calculate_bleu(references, hypotheses)
    print("BLEU Score:", bleu)

    # Calculate ROUGE (plain strings, unlike BLEU).
    rouge_scores = accuracy_metrics.calculate_rouge(
        ["the cat sat on the mat"], ["the cat sat on the mat"]
    )
    print("ROUGE Scores:", rouge_scores)

    # Calculate F1 Score over parallel label sequences.
    true_labels = [0, 1, 1, 0]
    pred_labels = [0, 1, 0, 0]
    f1 = accuracy_metrics.calculate_f1(true_labels, pred_labels)
    print("F1 Score:", f1)
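# Hedged sketch: one way to run calculate_perplexity on real text instead of
# random tokens, using the Hugging Face `transformers` tokenizer API. The
# "gpt2" checkpoint, the sample sentences, and max_length are placeholders,
# and Andromeda's forward signature is assumed from the usage above.
#
# from transformers import AutoTokenizer
# tok = AutoTokenizer.from_pretrained("gpt2")
# tok.pad_token = tok.eos_token  # GPT-2 has no pad token by default
# enc = tok(["some example text", "another example"], padding=True,
#           truncation=True, max_length=128, return_tensors="pt")
# real_loader = DataLoader(
#     TensorDataset(enc["input_ids"], enc["input_ids"].clone()), batch_size=2
# )
# print("Real-text perplexity:",
#       AccuracyMetrics().calculate_perplexity(model, real_loader))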