LUWA / utils /experiment_utils.py
DanielXu0208's picture
Initial commit
785ef2b
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import logging
from collections import Counter
from utils.MAE import mae_vit_large_patch16_dec512d8b as MAE_large
def get_model(args) -> nn.Module:
    """Build the classification network selected by ``args.model``.

    The model family is chosen by substring match on ``args.model``
    ('ResNet', 'ConvNext', 'ViT', 'DINOv2', 'MAE', checked in that order).
    In every branch the stock classification head is replaced by a freshly
    initialised ``nn.Linear`` *after* the optional freeze, so the new head
    is always trainable.

    Args:
        args: Namespace-like object providing
            ``model`` - architecture name ('ResNet50', 'ResNet152',
            'ConvNext_Tiny', 'ConvNext_Large', or any name containing
            'ViT', 'DINOv2' or 'MAE');
            ``pretrained`` - the string 'pretrained' to load pretrained weights;
            ``frozen`` - the string 'frozen' to freeze the backbone;
            ``num_classes`` (optional) - width of the new head; 6 when the
            attribute is missing or ``None``.

    Returns:
        The assembled ``nn.Module``.

    Raises:
        NotImplementedError: for an unknown architecture, or when ViT,
            DINOv2 or MAE is requested without pretrained weights.
    """
    # 6 is the value the original code hard-coded; an optional attribute on
    # ``args`` lets callers override it without changing existing behaviour.
    num_classes = getattr(args, 'num_classes', None)
    if num_classes is None:
        num_classes = 6

    if 'ResNet' in args.model:
        # resnet family
        if args.model == 'ResNet50':
            builder = torchvision.models.resnet50
        elif args.model == 'ResNet152':
            builder = torchvision.models.resnet152
        else:
            raise NotImplementedError(f'Unsupported ResNet variant: {args.model}')
        if args.pretrained == 'pretrained':
            model = builder(weights='IMAGENET1K_V2')
        else:
            model = builder()
        if args.frozen == 'frozen':
            model = freeze_backbone(model)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    elif 'ConvNext' in args.model:
        if args.model == 'ConvNext_Tiny':
            builder = torchvision.models.convnext_tiny
        elif args.model == 'ConvNext_Large':
            builder = torchvision.models.convnext_large
        else:
            raise NotImplementedError(f'Unsupported ConvNext variant: {args.model}')
        if args.pretrained == 'pretrained':
            model = builder(weights='IMAGENET1K_V1')
        else:
            model = builder()
        if args.frozen == 'frozen':
            model = freeze_backbone(model)
        num_ftrs = model.classifier[2].in_features
        model.classifier[2] = nn.Linear(int(num_ftrs), num_classes)
    elif 'ViT' in args.model:
        if args.pretrained == 'pretrained':
            model = torchvision.models.vit_h_14(weights='IMAGENET1K_SWAG_LINEAR_V1')
        else:
            raise NotImplementedError('ViT does not support training from scratch')
        if args.frozen == 'frozen':
            model = freeze_backbone(model)
        model.heads[0] = torch.nn.Linear(model.heads[0].in_features, num_classes)
    elif 'DINOv2' in args.model:
        if args.pretrained == 'pretrained':
            # Fetched through torch.hub (downloads on first use).
            model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14_reg_lc')
        else:
            raise NotImplementedError('DINOv2 does not support training from scratch')
        if args.frozen == 'frozen':
            model = freeze_backbone(model)
        model.linear_head = torch.nn.Linear(model.linear_head.in_features, num_classes)
    elif 'MAE' in args.model:
        if args.pretrained == 'pretrained':
            model = MAE_large()
            # map_location='cpu' lets a checkpoint saved on a GPU be read on
            # any host; load_state_dict then copies into the model's tensors.
            # NOTE(review): absolute, machine-specific path - consider making
            # it configurable.
            checkpoint = torch.load(
                '/scratch/zf540/LUWA/workspace/utils/pretrained_weights/mae_visualize_vit_large.pth',
                map_location='cpu',
            )
            model.load_state_dict(checkpoint['model'])
        else:
            raise NotImplementedError('MAE does not support training from scratch')
        if args.frozen == 'frozen':
            model = freeze_backbone(model)
        # NOTE(review): assumes the MAE module's forward yields a tensor whose
        # last dimension is 1024 - confirm against utils.MAE.
        model = nn.Sequential(model, nn.Linear(1024, num_classes))
    else:
        raise NotImplementedError(f'Unsupported model: {args.model}')
    return model
def freeze_backbone(model):
    """Disable gradient tracking on every parameter of ``model``.

    The whole network is frozen, including its current classifier; callers
    are expected to attach a new, trainable head afterwards.

    Args:
        model: Object exposing ``parameters()`` (e.g. an ``nn.Module``).

    Returns:
        The same ``model`` instance, modified in place.
    """
    for weight in model.parameters():
        weight.requires_grad_(False)
    return model
def get_name(args):
    """Compose the experiment name from the run configuration.

    The name is the underscore-joined sequence: model, resolution,
    magnification, modality, then one tag each for the pretrained / frozen /
    vote switches ('pretrained' or 'scratch', 'frozen' or 'unfrozen',
    'vote' or 'novote').

    Args:
        args: Namespace-like object with ``model``, ``resolution``,
            ``magnification``, ``modality``, ``pretrained``, ``frozen`` and
            ``vote`` attributes. Only ``resolution`` is converted with
            ``str``; the other name parts must already be strings.

    Returns:
        The experiment name as a string.
    """
    parts = [
        args.model,
        str(args.resolution),
        args.magnification,
        args.modality,
        'pretrained' if args.pretrained == 'pretrained' else 'scratch',
        'frozen' if args.frozen == 'frozen' else 'unfrozen',
        'vote' if args.vote == 'vote' else 'novote',
    ]
    return '_'.join(parts)
def get_logger(path, name):
    """Return the logger ``name`` writing INFO records to ``<path>/<name>_log.txt``.

    ``logging.getLogger`` hands back the same logger object for the same
    name, so the file handler is attached only if that logger does not
    already have one for the target file. Without this check every repeated
    call would add another handler and each record would be written
    multiple times.

    Args:
        path: Directory for the log file; must provide ``joinpath``
            (e.g. a ``pathlib.Path``).
        name: Logger name, also used as the log-file prefix.

    Returns:
        The configured ``logging.Logger``. A banner line is logged on every
        call.
    """
    import os  # local import: only needed to normalise the log-file path

    # set up logger
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)

    log_file = path.joinpath(f'{name}_log.txt')
    # FileHandler stores its target as an absolute path in ``baseFilename``.
    target = os.path.abspath(os.fspath(log_file))
    already_attached = any(
        isinstance(handler, logging.FileHandler) and handler.baseFilename == target
        for handler in logger.handlers
    )
    if not already_attached:
        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        file_handler = logging.FileHandler(log_file)
        file_handler.setLevel(logging.INFO)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

    logger.info('---------------------------------------------------TRANING---------------------------------------------------')
    return logger
def calculate_topk_accuracy(y_pred, y, k = 3):
    """Compute top-1 and top-k accuracy for one batch.

    Args:
        y_pred: Score tensor indexed as (sample, class); ``topk`` is taken
            along dimension 1.
        y: Tensor of target class indices, one per sample.
        k: Number of highest-scoring classes that count as a hit. Clamped to
            the number of classes, because ``topk`` raises when asked for
            more entries than exist and top-k accuracy is trivially 1.0 in
            that case anyway.

    Returns:
        ``(acc_1, acc_k)``: two tensors of shape ``(1,)`` holding the
        fraction of samples whose target is the top prediction, and among
        the top ``k`` predictions, respectively.
    """
    with torch.no_grad():
        batch_size = y.shape[0]
        k = min(k, y_pred.shape[1])
        _, top_pred = y_pred.topk(k, 1)
        # Transpose to (k, batch) so row r holds every sample's rank-r guess.
        top_pred = top_pred.t()
        correct = top_pred.eq(y.view(1, -1).expand_as(top_pred))
        correct_1 = correct[:1].reshape(-1).float().sum(0, keepdim = True)
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim = True)
        acc_1 = correct_1 / batch_size
        acc_k = correct_k / batch_size
    return acc_1, acc_k
def train(model, iterator, optimizer, criterion, scheduler, device):
    """Run one training epoch and return the per-batch averages.

    Args:
        model: Network to optimise; switched to training mode here.
        iterator: Sized iterable yielding ``(image, label, image_name)``
            triples; the name is not used during training.
        optimizer: Optimiser stepped once per batch.
        criterion: Loss callable invoked as ``criterion(y_pred, y)``.
        scheduler: Learning-rate scheduler, stepped once per batch (not once
            per epoch).
        device: Device the image and label batches are moved to.

    Returns:
        ``(epoch_loss, epoch_acc_1, epoch_acc_3)``: loss, top-1 accuracy and
        top-3 accuracy, each averaged over ``len(iterator)`` batches.

    Raises:
        ZeroDivisionError: if ``iterator`` is empty.
    """
    epoch_loss = 0
    epoch_acc_1 = 0
    epoch_acc_3 = 0

    model.train()
    for image, label, _image_name in iterator:
        x = image.to(device)
        y = label.to(device)

        optimizer.zero_grad()
        y_pred = model(x)
        loss = criterion(y_pred, y)
        acc_1, acc_3 = calculate_topk_accuracy(y_pred, y)

        loss.backward()
        optimizer.step()
        scheduler.step()

        epoch_loss += loss.item()
        epoch_acc_1 += acc_1.item()
        epoch_acc_3 += acc_3.item()

    epoch_loss /= len(iterator)
    epoch_acc_1 /= len(iterator)
    epoch_acc_3 /= len(iterator)
    return epoch_loss, epoch_acc_1, epoch_acc_3
def evaluate(model, iterator, criterion, device):
    """Score ``model`` on ``iterator`` without updating any weights.

    Args:
        model: Network to evaluate; switched to eval mode here.
        iterator: Sized iterable yielding ``(image, label, image_name)``
            triples; the name is ignored.
        criterion: Loss callable invoked as ``criterion(y_pred, y)``.
        device: Device the image and label batches are moved to.

    Returns:
        ``(loss, top1_accuracy, top3_accuracy)``, each averaged over the
        number of batches in ``iterator``.
    """
    model.eval()
    loss_total, top1_total, top3_total = 0, 0, 0

    with torch.no_grad():
        for images, targets, _names in iterator:
            inputs = images.to(device)
            targets = targets.to(device)
            logits = model(inputs)
            batch_loss = criterion(logits, targets)
            top1, top3 = calculate_topk_accuracy(logits, targets)
            loss_total += batch_loss.item()
            top1_total += top1.item()
            top3_total += top3.item()

    n_batches = len(iterator)
    return loss_total / n_batches, top1_total / n_batches, top3_total / n_batches
def evaluate_vote(model, iterator, device):
    """Evaluate accuracy after majority voting across images of one stone.

    Per-image predictions are collected for the whole iterator, replaced in
    place by the majority vote of their group via ``conduct_voting``, and
    then compared with the labels.

    Args:
        model: Network to evaluate; switched to eval mode here.
        iterator: Iterable yielding ``(image, label, image_name)`` triples,
            where ``image_name`` is a sequence of names for the batch.
        device: Device the image batch is moved to.

    Returns:
        Fraction of images whose voted prediction equals its label.
    """
    model.eval()
    image_names = []
    labels = []
    predictions = []
    with torch.no_grad():
        for image, label, image_name in iterator:
            x = image.to(device)
            y_pred = model(x)
            y_prob = F.softmax(y_pred, dim = -1)
            # Reduce along the class dimension without keepdim so the result
            # is always 1-D. The previous keepdim + squeeze() collapsed a
            # batch of size 1 to a 0-d value that cannot be iterated.
            top_pred = y_prob.argmax(1)
            image_names.extend(image_name)
            # .cpu() is a no-op for CPU tensors and makes GPU labels safe.
            labels.extend(label.cpu().tolist())
            predictions.extend(top_pred.cpu().tolist())
    conduct_voting(image_names, predictions)
    correct_count = sum(1 for truth, voted in zip(labels, predictions) if truth == voted)
    accuracy = correct_count/len(labels)
    return accuracy
def conduct_voting(image_names, predictions):
    """Replace each prediction, in place, by its group's majority vote.

    Images are grouped by their name with the last 8 characters removed
    (presumably the per-partition suffix - confirm against the dataset's
    naming scheme). Only *consecutive* images with the same key form a
    group, so images of one stone must be adjacent in ``image_names``.

    Args:
        image_names: Sequence of image-name strings, aligned with
            ``predictions``.
        predictions: Mutable list of per-image predictions; overwritten in
            place with the voted result.

    Returns:
        None. Empty input is a no-op.
    """
    if not image_names:
        # Nothing to vote on; avoids indexing image_names[0] below.
        return
    # we need to do this because not all stones have the same number of partitions
    last_stone = image_names[0][:-8]  # the name of the stone of the last image
    voting_list = []
    for i, image_name in enumerate(image_names):
        image_area_name = image_name[:-8]
        if image_area_name != last_stone:
            # we have run through all the images of the last stone. We start voting
            vote(voting_list, predictions, i)
            voting_list = []  # reset the voting list
            last_stone = image_area_name  # update the last stone name
        voting_list.append(predictions[i])
    # vote for the last stone
    vote(voting_list, predictions, len(image_names))
def vote(voting_list, predictions, i):
    """Overwrite one group's predictions with the group's most common value.

    Args:
        voting_list: Predictions belonging to the group being voted on.
        predictions: Full prediction list, modified in place.
        i: Index one past the group's last element in ``predictions``; the
            group occupies ``predictions[i - len(voting_list):i]``.
    """
    group_size = len(voting_list)
    ranked = Counter(voting_list).most_common(1)
    winner = ranked[0][0]  # the most common prediction in the group
    start = i - group_size
    # replace the whole group's slice with the winning prediction
    predictions[start:i] = [winner] * group_size
# def get_predictions(model, iterator):
# model.eval()
# images = []
# labels = []
# probs = []
# with torch.no_grad():
# for (x, y) in iterator:
# x = x.to(device)
# y_pred = model(x)
# y_prob = F.softmax(y_pred, dim = -1)
# top_pred = y_prob.argmax(1, keepdim = True)
# images.append(x.cpu())
# labels.append(y.cpu())
# probs.append(y_prob.cpu())
# images = torch.cat(images, dim = 0)
# labels = torch.cat(labels, dim = 0)
# probs = torch.cat(probs, dim = 0)
# return images, labels, probs
# def get_representations(model, iterator):
# model.eval()
# outputs = []
# intermediates = []
# labels = []
# with torch.no_grad():
# for (x, y) in iterator:
# x = x.to(device)
# y_pred = model(x)
# outputs.append(y_pred.cpu())
# labels.append(y)
# outputs = torch.cat(outputs, dim=0)
# labels = torch.cat(labels, dim=0)
# return outputs, labels