import logging
import os
from collections import Counter

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

from utils.MAE import mae_vit_large_patch16_dec512d8b as MAE_large
def get_model(args, num_classes: int = 6) -> nn.Module:
    """Build the classification network selected by ``args.model``.

    The backbone is created (optionally with pretrained weights), optionally
    frozen, and its final classification layer is then replaced by a fresh
    ``nn.Linear`` with ``num_classes`` outputs.  Because the head is swapped
    in *after* freezing, the new head always remains trainable.

    Args:
        args: Object exposing the attributes ``model``, ``pretrained`` and
            ``frozen``.  Weights are loaded only when ``args.pretrained`` is
            exactly ``'pretrained'``; the backbone is frozen only when
            ``args.frozen`` is exactly ``'frozen'``.
        num_classes: Number of outputs of the new classification head.
            Defaults to 6, the value that was previously hard-coded.

    Returns:
        The assembled model.

    Raises:
        NotImplementedError: If ``args.model`` names an unsupported
            architecture, or if a ViT / DINOv2 / MAE model is requested
            without pretrained weights.
    """
    use_pretrained = args.pretrained == 'pretrained'
    freeze = args.frozen == 'frozen'

    if 'ResNet' in args.model:
        # resnet family
        weights = 'IMAGENET1K_V2' if use_pretrained else None
        if args.model == 'ResNet50':
            model = torchvision.models.resnet50(weights=weights)
        elif args.model == 'ResNet152':
            model = torchvision.models.resnet152(weights=weights)
        else:
            raise NotImplementedError
        if freeze:
            model = freeze_backbone(model)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    elif 'ConvNext' in args.model:
        weights = 'IMAGENET1K_V1' if use_pretrained else None
        if args.model == 'ConvNext_Tiny':
            model = torchvision.models.convnext_tiny(weights=weights)
        elif args.model == 'ConvNext_Large':
            model = torchvision.models.convnext_large(weights=weights)
        else:
            raise NotImplementedError
        if freeze:
            model = freeze_backbone(model)
        # The last entry of the classifier is the linear layer to replace.
        num_ftrs = model.classifier[2].in_features
        model.classifier[2] = nn.Linear(int(num_ftrs), num_classes)
    elif 'ViT' in args.model:
        if not use_pretrained:
            raise NotImplementedError('ViT does not support training from scratch')
        model = torchvision.models.vit_h_14(weights='IMAGENET1K_SWAG_LINEAR_V1')
        if freeze:
            model = freeze_backbone(model)
        model.heads[0] = torch.nn.Linear(model.heads[0].in_features, num_classes)
    elif 'DINOv2' in args.model:
        if not use_pretrained:
            raise NotImplementedError('DINOv2 does not support training from scratch')
        model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14_reg_lc')
        if freeze:
            model = freeze_backbone(model)
        model.linear_head = torch.nn.Linear(model.linear_head.in_features, num_classes)
    elif 'MAE' in args.model:
        if not use_pretrained:
            raise NotImplementedError('MAE does not support training from scratch')
        model = MAE_large()
        # NOTE(review): absolute, machine-specific checkpoint path; consider
        # making it configurable.
        # map_location='cpu' lets the checkpoint load on hosts without the
        # device it was saved from; the caller moves the model afterwards.
        checkpoint = torch.load(
            '/scratch/zf540/LUWA/workspace/utils/pretrained_weights/mae_visualize_vit_large.pth',
            map_location='cpu',
        )
        model.load_state_dict(checkpoint['model'])
        if freeze:
            model = freeze_backbone(model)
        # assumes the MAE encoder emits 1024-wide features -- TODO confirm
        model = nn.Sequential(model, nn.Linear(1024, num_classes))
        print(model)
    else:
        raise NotImplementedError
    return model
def freeze_backbone(model):
    """Disable gradient tracking for every parameter of ``model``.

    All parameters are frozen, including those of the default classifier;
    callers are expected to attach a new trainable head afterwards.

    Returns:
        The same ``model`` object, modified in place.
    """
    for weight in model.parameters():
        weight.requires_grad_(False)
    return model
def get_name(args):
    """Compose the experiment name from the run configuration.

    The name is the underscore-joined sequence
    ``<model>_<resolution>_<magnification>_<modality>_<pretrained|scratch>_<frozen|unfrozen>_<vote|novote>``.

    Args:
        args: Object exposing ``model``, ``resolution``, ``magnification``,
            ``modality``, ``pretrained``, ``frozen`` and ``vote``.

    Returns:
        The experiment name as a string.
    """
    # NOTE(review): the attribute compared against 'vote' was missing in the
    # source (``if == 'vote'``); ``args.vote`` is the reconstructed operand --
    # confirm against the argument parser.
    parts = [
        args.model,
        str(args.resolution),
        args.magnification,
        args.modality,
        'pretrained' if args.pretrained == 'pretrained' else 'scratch',
        'frozen' if args.frozen == 'frozen' else 'unfrozen',
        'vote' if args.vote == 'vote' else 'novote',
    ]
    return '_'.join(parts)
def get_logger(path, name):
    """Return an INFO-level logger that writes to ``<path>/<name>_log.txt``.

    A file handler for the log file is attached only if the logger does not
    already have one for that same file, so calling this function repeatedly
    with the same arguments does not duplicate every log record.  A banner
    line is logged on every call.

    Args:
        path: Directory for the log file; must provide ``joinpath`` (e.g. a
            ``pathlib.Path``).
        name: Logger name, also used as the log file prefix.

    Returns:
        The configured ``logging.Logger``.
    """
    # set up logger
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)

    log_file = path.joinpath(f'{name}_log.txt')
    # FileHandler stores the absolute path of its file in ``baseFilename``.
    target = os.path.abspath(os.fspath(log_file))
    already_attached = any(
        isinstance(handler, logging.FileHandler) and handler.baseFilename == target
        for handler in logger.handlers
    )
    if not already_attached:
        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        file_handler = logging.FileHandler(log_file)
        file_handler.setLevel(logging.INFO)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

    logger.info('---------------------------------------------------TRANING---------------------------------------------------')
    return logger
def calculate_topk_accuracy(y_pred, y, k = 3):
    """Compute top-1 and top-k accuracy for one batch.

    Args:
        y_pred: Score tensor indexed as (sample, class); ``topk`` is taken
            along dimension 1.
        y: Tensor of target class indices, one per sample.
        k: Number of highest-scoring classes that count as a hit.  Values
            larger than the number of classes are clamped to it, so the
            call no longer fails for models with fewer than ``k`` classes.

    Returns:
        Tuple ``(acc_1, acc_k)`` of one-element float tensors holding the
        fraction of samples whose target is the top prediction and among
        the top ``k`` predictions, respectively.
    """
    with torch.no_grad():
        batch_size = y.shape[0]
        # topk raises if asked for more entries than the dimension holds.
        k = min(k, y_pred.shape[1])
        _, top_pred = y_pred.topk(k, 1)
        # Transpose to (k, batch) so that row r holds every sample's rank-r guess.
        top_pred = top_pred.t()
        correct = top_pred.eq(y.view(1, -1).expand_as(top_pred))
        correct_1 = correct[:1].reshape(-1).float().sum(0, keepdim = True)
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim = True)
        acc_1 = correct_1 / batch_size
        acc_k = correct_k / batch_size
    return acc_1, acc_k
def train(model, iterator, optimizer, criterion, scheduler, device):
    """Train ``model`` for one pass over ``iterator``.

    Args:
        model: Network to optimise; switched to training mode.
        iterator: Sized iterable yielding ``(image, label, image_name)``
            batches.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss callable invoked as ``criterion(y_pred, y)``.
        scheduler: Learning-rate scheduler, stepped once per batch (not once
            per epoch).
        device: Device the image and label batches are moved to.

    Returns:
        Tuple ``(loss, top1_accuracy, top3_accuracy)``, each averaged over
        the number of batches.

    Raises:
        ZeroDivisionError: If ``iterator`` is empty.
    """
    epoch_loss = 0
    epoch_acc_1 = 0
    epoch_acc_3 = 0

    model.train()

    for image, label, _image_name in iterator:
        # NOTE(review): the right-hand sides of these two assignments were
        # missing in the source; moving the batch to ``device`` is the
        # reconstructed intent.
        x = image.to(device)
        y = label.to(device)

        optimizer.zero_grad()
        y_pred = model(x)
        loss = criterion(y_pred, y)
        acc_1, acc_3 = calculate_topk_accuracy(y_pred, y)

        loss.backward()
        optimizer.step()
        scheduler.step()

        epoch_loss += loss.item()
        epoch_acc_1 += acc_1.item()
        epoch_acc_3 += acc_3.item()

    num_batches = len(iterator)
    epoch_loss /= num_batches
    epoch_acc_1 /= num_batches
    epoch_acc_3 /= num_batches
    return epoch_loss, epoch_acc_1, epoch_acc_3
def evaluate(model, iterator, criterion, device):
    """Evaluate ``model`` over ``iterator`` without updating weights.

    Args:
        model: Network to evaluate; switched to evaluation mode.
        iterator: Sized iterable yielding ``(image, label, image_name)``
            batches.
        criterion: Loss callable invoked as ``criterion(y_pred, y)``.
        device: Device the image and label batches are moved to.

    Returns:
        Tuple ``(loss, top1_accuracy, top3_accuracy)``, each averaged over
        the number of batches.

    Raises:
        ZeroDivisionError: If ``iterator`` is empty.
    """
    epoch_loss = 0
    epoch_acc_1 = 0
    epoch_acc_3 = 0

    model.eval()

    with torch.no_grad():
        for image, label, _image_name in iterator:
            # NOTE(review): the right-hand sides of these two assignments
            # were missing in the source; moving the batch to ``device`` is
            # the reconstructed intent.
            x = image.to(device)
            y = label.to(device)

            y_pred = model(x)
            loss = criterion(y_pred, y)
            acc_1, acc_3 = calculate_topk_accuracy(y_pred, y)

            epoch_loss += loss.item()
            epoch_acc_1 += acc_1.item()
            epoch_acc_3 += acc_3.item()

    num_batches = len(iterator)
    epoch_loss /= num_batches
    epoch_acc_1 /= num_batches
    epoch_acc_3 /= num_batches
    return epoch_loss, epoch_acc_1, epoch_acc_3
def evaluate_vote(model, iterator, device):
    """Compute accuracy after majority voting across images of one group.

    Per-image predictions are collected for the whole iterator, then
    ``conduct_voting`` overwrites them in place with the majority prediction
    of each run of consecutive images that share a name prefix.  Accuracy is
    the fraction of images whose voted prediction equals its label.

    Args:
        model: Network to evaluate; switched to evaluation mode.
        iterator: Iterable yielding ``(image, label, image_name)`` batches,
            where ``image_name`` is a sequence of strings.
        device: Device the image batch is moved to.

    Returns:
        The accuracy as a float.
    """
    model.eval()

    image_names = []
    labels = []
    predictions = []

    with torch.no_grad():
        for image, label, image_name in iterator:
            # NOTE(review): the right-hand side was missing in the source;
            # moving the batch to ``device`` is the reconstructed intent.
            x = image.to(device)

            y_pred = model(x)
            y_prob = F.softmax(y_pred, dim = -1)
            top_pred = y_prob.argmax(1, keepdim = True)

            image_names.extend(image_name)
            # Flatten explicitly instead of squeeze(): squeezing a batch of
            # one yields a 0-d array, which cannot be iterated by extend().
            labels.extend(label.cpu().reshape(-1).tolist())
            predictions.extend(top_pred.cpu().reshape(-1).tolist())

    conduct_voting(image_names, predictions)

    correct_count = sum(
        1 for truth, guess in zip(labels, predictions) if truth == guess
    )
    accuracy = correct_count/len(labels)
    return accuracy
def conduct_voting(image_names, predictions):
    """Replace each prediction by the majority vote of its group, in place.

    Images are grouped into runs of consecutive entries whose names are
    equal once the last 8 characters are dropped.  Runs are needed because
    not all stones have the same number of partitions.

    Args:
        image_names: Sequence of image name strings, ordered so that images
            of the same stone are adjacent.
        predictions: Mutable list of predictions aligned with
            ``image_names``; overwritten in place.

    Returns:
        None.  Empty input is a no-op.
    """
    if not image_names:
        return

    # presumably the trailing 8 characters are the partition suffix plus
    # file extension -- TODO confirm against the dataset naming scheme
    last_stone = image_names[0][:-8]  # the name of the stone of the last image
    voting_list = []
    for i, image_name in enumerate(image_names):
        image_area_name = image_name[:-8]
        if image_area_name != last_stone:
            # we have run through all the images of the last stone. We start voting
            vote(voting_list, predictions, i)
            voting_list = []  # reset the voting list
        voting_list.append(predictions[i])
        last_stone = image_area_name  # update the last stone name

    # vote for the last stone
    vote(voting_list, predictions, len(image_names))
def vote(voting_list, predictions, i):
    """Overwrite one group's predictions with its most common prediction.

    The group occupies ``predictions[i - len(voting_list):i]``.

    Args:
        voting_list: Predictions of the group being voted on.
        predictions: Full prediction list; modified in place.
        i: Index one past the last element of the group in ``predictions``.

    Returns:
        None.  An empty ``voting_list`` leaves ``predictions`` untouched.
    """
    if not voting_list:
        # Nothing to vote on; most_common(1) would be empty.
        return
    group_size = len(voting_list)
    vote_result = Counter(voting_list).most_common(1)[0][0]  # the most common prediction in the list
    # replace the predictions of the last stone with the vote result
    predictions[i - group_size:i] = [vote_result] * group_size
# def get_predictions(model, iterator):
#     model.eval()
#     images = []
#     labels = []
#     probs = []
#     with torch.no_grad():
#         for (x, y) in iterator:
#             x = x.to(device)
#             y_pred = model(x)
#             y_prob = F.softmax(y_pred, dim = -1)
#             top_pred = y_prob.argmax(1, keepdim = True)
#             images.append(x.cpu())
#             labels.append(y.cpu())
#             probs.append(y_prob.cpu())
#     images = torch.cat(images, dim = 0)
#     labels = torch.cat(labels, dim = 0)
#     probs = torch.cat(probs, dim = 0)
#     return images, labels, probs
# def get_representations(model, iterator):
#     model.eval()
#     outputs = []
#     intermediates = []
#     labels = []
#     with torch.no_grad():
#         for (x, y) in iterator:
#             x = x.to(device)
#             y_pred = model(x)
#             outputs.append(y_pred.cpu())
#             labels.append(y)
#     outputs = torch.cat(outputs, dim=0)
#     labels = torch.cat(labels, dim=0)
#     return outputs, labels