# Uploaded to Hugging Face by Jauhar — commit deb7039 ("final commit to hf"), 567 bytes.
import torch
def test(model, testloader, criterion, DEVICE):
model.eval()
test_loss, correct = 0.0, 0
with torch.no_grad():
for imgs, targets in testloader:
imgs, targets = imgs.to(DEVICE), targets.to(DEVICE)
pred = model(imgs)
loss = criterion(pred, targets)
test_loss += loss.item()
correct += (pred.argmax(1) == targets).type(torch.float).sum().item()
# test_loss = test_loss / len(testloader)
accuracy = correct / len(testloader.dataset) * 100
return accuracy