DanielXu0208 committed
Commit
785ef2b
0 Parent(s):

Initial commit

Files changed (40)
  1. dl_supervised_pipeline.py +158 -0
  2. run_gradio.py +55 -0
  3. svm_pipeline.py +100 -0
  4. utils/MAE.py +253 -0
  5. utils/__init__.py +1 -0
  6. utils/__pycache__/MAE.cpython-311.pyc +0 -0
  7. utils/__pycache__/MAE.cpython-38.pyc +0 -0
  8. utils/__pycache__/MAE.cpython-38.pyc:Zone.Identifier +3 -0
  9. utils/__pycache__/__init__.cpython-310.pyc +0 -0
  10. utils/__pycache__/__init__.cpython-310.pyc:Zone.Identifier +3 -0
  11. utils/__pycache__/__init__.cpython-311.pyc +0 -0
  12. utils/__pycache__/__init__.cpython-38.pyc +0 -0
  13. utils/__pycache__/__init__.cpython-38.pyc:Zone.Identifier +3 -0
  14. utils/__pycache__/__init__.cpython-39.pyc +0 -0
  15. utils/__pycache__/__init__.cpython-39.pyc:Zone.Identifier +3 -0
  16. utils/__pycache__/arg_utils.cpython-38.pyc +0 -0
  17. utils/__pycache__/arg_utils.cpython-38.pyc:Zone.Identifier +3 -0
  18. utils/__pycache__/arg_utils.cpython-39.pyc +0 -0
  19. utils/__pycache__/arg_utils.cpython-39.pyc:Zone.Identifier +3 -0
  20. utils/__pycache__/experiment_utils.cpython-311.pyc +0 -0
  21. utils/__pycache__/experiment_utils.cpython-38.pyc +0 -0
  22. utils/__pycache__/experiment_utils.cpython-38.pyc:Zone.Identifier +3 -0
  23. utils/__pycache__/experiment_utils.cpython-39.pyc +0 -0
  24. utils/__pycache__/experiment_utils.cpython-39.pyc:Zone.Identifier +3 -0
  25. utils/__pycache__/model_utils.cpython-311.pyc +0 -0
  26. utils/__pycache__/model_utils.cpython-38.pyc +0 -0
  27. utils/__pycache__/model_utils.cpython-38.pyc:Zone.Identifier +3 -0
  28. utils/__pycache__/util_function.cpython-310.pyc +0 -0
  29. utils/__pycache__/util_function.cpython-310.pyc:Zone.Identifier +3 -0
  30. utils/__pycache__/util_function.cpython-311.pyc +0 -0
  31. utils/__pycache__/util_function.cpython-38.pyc +0 -0
  32. utils/__pycache__/util_function.cpython-38.pyc:Zone.Identifier +3 -0
  33. utils/__pycache__/util_function.cpython-39.pyc +0 -0
  34. utils/__pycache__/util_function.cpython-39.pyc:Zone.Identifier +3 -0
  35. utils/arg_utils.py +18 -0
  36. utils/experiment_utils.py +298 -0
  37. utils/model_utils.py +96 -0
  38. utils/util_function.py +238 -0
  39. vis_confusion_mtx.py +54 -0
  40. vote_analysis.py +107 -0
dl_supervised_pipeline.py ADDED
@@ -0,0 +1,158 @@
+ # Code modified from pytorch-image-classification
+ # obtained from https://colab.research.google.com/github/bentrevett/pytorch-image-classification/blob/master/5_resnet.ipynb#scrollTo=4QmwmcXuPuLo
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ import torch.optim as optim
+ import torch.optim.lr_scheduler as lr_scheduler
+
+ import torch.utils.data as data
+
+ import numpy as np
+ import random
+ import tqdm
+ import os
+ from pathlib import Path
+
+ from data_utils.data_tribology import TribologyDataset
+ from utils.experiment_utils import get_model, get_name, get_logger, train, evaluate, evaluate_vote
+ from utils.arg_utils import get_args
+
+ def main(args):
+     '''Reproducibility'''
+     SEED = args.seed
+     random.seed(SEED)
+     np.random.seed(SEED)
+     torch.manual_seed(SEED)
+     torch.cuda.manual_seed(SEED)
+     torch.backends.cudnn.deterministic = True
+     torch.backends.cudnn.benchmark = False
+
+     '''Folder Creation'''
+     basepath = os.getcwd()
+     experiment_dir = Path(os.path.join(basepath, 'experiments', args.model, args.resolution, args.magnification, args.modality, args.pretrained, args.frozen, args.vote))
+     experiment_dir.mkdir(parents=True, exist_ok=True)
+     checkpoint_dir = Path(os.path.join(experiment_dir, 'checkpoints'))
+     checkpoint_dir.mkdir(parents=True, exist_ok=True)
+
+     '''Logging'''
+     model_name = get_name(args)
+     print(model_name, 'STARTED')
+     if os.path.exists(checkpoint_dir / 'epoch10.pth'):
+         print('CHECKPOINT FOUND')
+         print('TERMINATING TRAINING')
+         return 0  # terminate training if checkpoint exists
+
+     logger = get_logger(experiment_dir, model_name)
+
+     '''Data Loading'''
+     train_csv_path = f"./LUA_Dataset/CSV/{args.resolution}_{args.magnification}_6w_train.csv"
+     test_csv_path = f"./LUA_Dataset/CSV/{args.resolution}_{args.magnification}_6w_test.csv"
+     img_path = f"./LUA_Dataset/{args.resolution}/{args.magnification}/{args.modality}"
+
+     # results_acc_1 = {}
+     # results_acc_3 = {}
+     # classes_num = 6
+     BATCHSIZE = args.batch_size
+     train_dataset = TribologyDataset(csv_path=train_csv_path, img_path=img_path)
+     test_dataset = TribologyDataset(csv_path=test_csv_path, img_path=img_path)
+
+     # prepare the data augmentation
+     means, stds = train_dataset.get_statistics()
+     train_dataset.prepare_transform(means, stds, mode='train')
+     test_dataset.prepare_transform(means, stds, mode='test')
+
+     VALID_RATIO = 0.1
+
+     num_train = len(train_dataset)
+     num_valid = int(VALID_RATIO * num_train)
+     train_dataset, valid_dataset = data.random_split(train_dataset, [num_train - num_valid, num_valid])
+     logger.info(f'Number of training samples: {len(train_dataset)}')
+     logger.info(f'Number of validation samples: {len(valid_dataset)}')
+     train_iterator = torch.utils.data.DataLoader(train_dataset,
+                                                  batch_size=BATCHSIZE,
+                                                  num_workers=4,
+                                                  shuffle=True,
+                                                  pin_memory=True,
+                                                  drop_last=False)
+
+     valid_iterator = torch.utils.data.DataLoader(valid_dataset,
+                                                  batch_size=BATCHSIZE,
+                                                  num_workers=4,
+                                                  shuffle=True,
+                                                  pin_memory=True,
+                                                  drop_last=False)
+     test_iterator = torch.utils.data.DataLoader(test_dataset,
+                                                 batch_size=BATCHSIZE,
+                                                 num_workers=4,
+                                                 shuffle=False,
+                                                 pin_memory=True,
+                                                 drop_last=False)
+     print('DATA LOADED')
+
+     # Define model
+     model = get_model(args)
+     print('MODEL LOADED')
+
+     # Define optimizer and scheduler
+     START_LR = args.start_lr
+     optimizer = optim.Adam(model.parameters(), lr=START_LR)
+     STEPS_PER_EPOCH = len(train_iterator)
+     print('STEPS_PER_EPOCH:', STEPS_PER_EPOCH)
+     print('VALIDATION STEPS:', len(valid_iterator))
+     scheduler = lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=max(STEPS_PER_EPOCH, STEPS_PER_EPOCH // 10))
+
+     # Define loss function
+     criterion = nn.CrossEntropyLoss()
+
+     # Define device
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+     model = model.to(device)
+     criterion = criterion.to(device)
+
+     EPOCHS = args.epochs
+
+     print('SETUP DONE')
+     # train our model
+
+     print('TRAINING STARTED')
+     for epoch in tqdm.tqdm(range(EPOCHS)):
+
+         train_loss, train_acc_1, train_acc_3 = train(model, train_iterator, optimizer, criterion, scheduler, device)
+
+         torch.cuda.empty_cache()  # clear cache between train and val
+
+         valid_loss, valid_acc_1, valid_acc_3 = evaluate(model, valid_iterator, criterion, device)
+
+         torch.save(model.state_dict(), checkpoint_dir / f'epoch{epoch+1}.pth')
+
+         logger.info(f'Epoch: {epoch + 1:02}')
+         logger.info(f'\tTrain Loss: {train_loss:.3f} | Train Acc @1: {train_acc_1 * 100:6.2f}% | ' \
+                     f'Train Acc @3: {train_acc_3 * 100:6.2f}%')
+         logger.info(f'\tValid Loss: {valid_loss:.3f} | Valid Acc @1: {valid_acc_1 * 100:6.2f}% | ' \
+                     f'Valid Acc @3: {valid_acc_3 * 100:6.2f}%')
+
+     logger.info('-------------------End of Training-------------------')
+     print('TRAINING DONE')
+     logger.info('-------------------Beginning of Testing-------------------')
+     print('TESTING STARTED')
+     for epoch in tqdm.tqdm(range(EPOCHS)):
+         model.load_state_dict(torch.load(checkpoint_dir / f'epoch{epoch+1}.pth'))
+
+         if args.vote == 'vote':
+             test_acc = evaluate_vote(model, test_iterator, device)
+             logger.info(f'Test Acc @1: {test_acc * 100:6.2f}%')
+         else:
+             test_loss, test_acc_1, test_acc_3 = evaluate(model, test_iterator, criterion, device)
+
+             logger.info(f'Test Acc @1: {test_acc_1 * 100:6.2f}% | ' \
+                         f'Test Acc @3: {test_acc_3 * 100:6.2f}%')
+     logger.info('-------------------End of Testing-------------------')
+     print('TESTING DONE')
+
+
+ if __name__ == '__main__':
+     args = get_args()
+     main(args)
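The train and evaluate loops unpack each batch as an (image, label, image_name) triple, so TribologyDataset — imported from data_utils, which is not among the 40 files in this commit — must yield exactly that. A minimal sketch of a compatible dataset, assuming a CSV with image_name and label columns (the column names and transform plumbing are assumptions, not the committed code):

    import pandas as pd
    from PIL import Image
    from torch.utils.data import Dataset

    class TriboDatasetSketch(Dataset):
        """Minimal stand-in for TribologyDataset: __getitem__ returns (image, label, image_name)."""
        def __init__(self, csv_path, img_path, transform=None):
            self.df = pd.read_csv(csv_path)   # assumed columns: image_name, label
            self.img_path = img_path
            self.transform = transform

        def __len__(self):
            return len(self.df)

        def __getitem__(self, idx):
            row = self.df.iloc[idx]
            image = Image.open(f"{self.img_path}/{row['image_name']}").convert('RGB')
            if self.transform is not None:
                image = self.transform(image)
            return image, int(row['label']), row['image_name']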
run_gradio.py ADDED
@@ -0,0 +1,55 @@
+ import gradio as gr
+ import torch
+ import torchvision
+ from utils.experiment_utils import get_model
+
+
+ # Load the DINOv2 model
+ def load_model():
+     class Args:
+         model = 'DINOv2'
+         pretrained = 'pretrained'
+         frozen = 'unfrozen'
+
+     args = Args()
+     model = get_model(args)
+     model.eval()
+     return model
+
+
+ model = load_model()
+
+ # List of class names
+ CLASS_NAMES = ["ANTLER", "BEECHWOOD", "BEFOREUSE", "BONE", "IVORY", "SPRUCEWOOD"]
+
+
+ # Prediction function: returns the probability of each class
+ def predict(image):
+     transform = torchvision.transforms.Compose([
+         torchvision.transforms.Resize((224, 224)),
+         torchvision.transforms.ToTensor(),
+         torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+     ])
+
+     image = transform(image).unsqueeze(0)
+     with torch.no_grad():
+         output = model(image)
+         probabilities = torch.nn.functional.softmax(output, dim=1).squeeze().tolist()
+
+     # Pair each class with its probability
+     results = {CLASS_NAMES[i]: probabilities[i] for i in range(len(CLASS_NAMES))}
+
+     return results
+
+
+ # Create the Gradio interface
+ interface = gr.Interface(
+     fn=predict,
+     inputs=gr.Image(type="pil"),
+     outputs=gr.Label(num_top_classes=len(CLASS_NAMES)),
+     title="LUWA DINOv2 Prediction",
+     description="Upload an image to get the probabilities for each class using the DINOv2 model."
+ )
+
+ if __name__ == "__main__":
+     interface.launch(share=True)
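Because predict() accepts a PIL image directly, the model can also be exercised without the browser UI; a quick sketch (the image path is a placeholder):

    from PIL import Image

    img = Image.open('sample.png').convert('RGB')   # placeholder path
    probs = predict(img)                            # {class name: probability}
    print(max(probs, key=probs.get))                # most likely class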
svm_pipeline.py ADDED
@@ -0,0 +1,100 @@
+ import numpy as np
+ from sklearn.svm import LinearSVC
+
+ from skimage.feature import fisher_vector, learn_gmm
+
+ import random
+ import os
+ from pathlib import Path
+
+ from data_utils.data_tribology import TribologyDataset
+ from utils.arg_utils import get_args
+ from utils.experiment_utils import get_name, get_logger, SIFT_extraction, conduct_voting  # note: SIFT_extraction is not defined in the committed utils/experiment_utils.py
+ from utils.visualization_utils import plot_confusion_matrix  # note: utils/visualization_utils.py is not part of this commit; a plot_confusion_matrix also lives in utils.util_function
+ from vis_confusion_mtx import generate_confusion_matrix
+
+ def main(args):
+     '''Reproducibility'''
+     SEED = args.seed
+     random.seed(SEED)
+     np.random.seed(SEED)
+
+     '''Folder Creation'''
+     basepath = os.getcwd()
+     experiment_dir = Path(os.path.join(basepath, 'experiments', args.model, args.resolution, args.magnification, args.modality, args.vote))
+     experiment_dir.mkdir(parents=True, exist_ok=True)
+     checkpoint_dir = Path(os.path.join(experiment_dir, 'checkpoints'))
+     checkpoint_dir.mkdir(parents=True, exist_ok=True)
+
+     '''Logging'''
+     model_name = get_name(args)
+     print(model_name, 'STARTED', flush=True)
+     logger = get_logger(experiment_dir, model_name)
+
+     '''Data Loading'''
+     train_csv_path = f"./LUA_Dataset/CSV/{args.resolution}_{args.magnification}_6w_train.csv"
+     test_csv_path = f"./LUA_Dataset/CSV/{args.resolution}_{args.magnification}_6w_test.csv"
+     img_path = f"./LUA_Dataset/{args.resolution}/{args.magnification}/{args.modality}"
+
+     BATCHSIZE = args.batch_size
+     train_dataset = TribologyDataset(csv_path=train_csv_path, img_path=img_path)
+     test_dataset = TribologyDataset(csv_path=test_csv_path, img_path=img_path)
+
+     # prepare the data augmentation
+     means, stds = train_dataset.get_statistics()
+     train_dataset.prepare_transform(means, stds, mode='train')
+     test_dataset.prepare_transform(means, stds, mode='test')
+
+     VALID_RATIO = 0.1
+
+     num_train = len(train_dataset)
+     num_valid = int(VALID_RATIO * num_train)
+     # train_dataset, valid_dataset = data.random_split(train_dataset, [num_train - num_valid, num_valid])
+     # logger.info(f'Number of training samples: {len(train_dataset)}')
+     # logger.info(f'Number of validation samples: {len(valid_dataset)}')
+
+     train_names, train_descriptor, train_labels = SIFT_extraction(train_dataset)
+     test_names, test_descriptor, test_labels = SIFT_extraction(test_dataset)
+     # val_descriptor, val_labels = SIFT_extraction(valid_dataset)
+     print('DATA LOADED', flush=True)
+
+     print('TRAINING STARTED', flush=True)
+
+     # Train a K-mode GMM
+     k = 16
+     gmm = learn_gmm(train_descriptor, n_modes=k)
+
+     # Compute the Fisher vectors
+     training_fvs = np.array([
+         fisher_vector(descriptor_mat, gmm)
+         for descriptor_mat in train_descriptor
+     ])
+
+     testing_fvs = np.array([
+         fisher_vector(descriptor_mat, gmm)
+         for descriptor_mat in test_descriptor
+     ])
+
+     svm = LinearSVC().fit(training_fvs, train_labels)
+
+     logger.info('-------------------End of Training-------------------')
+     print('TRAINING DONE')
+     logger.info('-------------------Beginning of Testing-------------------')
+     print('TESTING STARTED')
+     predictions = svm.predict(testing_fvs)
+     conduct_voting(test_names, predictions)
+     plot_confusion_matrix('visualization_results/SIFT+FVs_confusion_mtx.png', predictions, test_labels, classes=["ANTLER", "BEECHWOOD", "BEFOREUSE", "BONE", "IVORY", "SPRUCEWOOD"])
+     correct = 0
+     for i in range(len(predictions)):
+         if predictions[i] == test_labels[i]:
+             correct += 1
+     test_acc = float(correct) / len(predictions)
+     logger.info(f'Test Acc @1: {test_acc * 100:6.2f}%')
+
+     logger.info('-------------------End of Testing-------------------')
+     print('TESTING DONE')
+
+ if __name__ == '__main__':
+     args = get_args()
+     main(args)
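Each image yields a variable-length (n_i, 128) matrix of SIFT descriptors; learn_gmm fits one K-mode GMM over all of them, and fisher_vector maps every matrix to a fixed-length vector of dimension K * (2 * 128 + 1) that LinearSVC can consume. A toy sketch with random descriptors standing in for SIFT output (fisher_vector requires scikit-image >= 0.21):

    import numpy as np
    from skimage.feature import fisher_vector, learn_gmm

    rng = np.random.default_rng(0)
    # ten "images", each with a different number of 128-D descriptors
    descriptors = [rng.normal(size=(rng.integers(5, 15), 128)) for _ in range(10)]

    gmm = learn_gmm(descriptors, n_modes=3)
    fvs = np.array([fisher_vector(d, gmm) for d in descriptors])
    print(fvs.shape)   # (10, 3 * (2 * 128 + 1)) = (10, 771)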
utils/MAE.py ADDED
@@ -0,0 +1,253 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+ # --------------------------------------------------------
+ # References:
+ # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
+ # DeiT: https://github.com/facebookresearch/deit
+ # --------------------------------------------------------
+
+ from functools import partial
+
+ import torch
+ import torch.nn as nn
+
+ from timm.models.vision_transformer import PatchEmbed, Block
+
+ from utils.model_utils import get_2d_sincos_pos_embed
+
+
+ class MaskedAutoencoderViT(nn.Module):
+     """ Masked Autoencoder with VisionTransformer backbone
+     """
+     def __init__(self, img_size=224, patch_size=16, in_chans=3,
+                  embed_dim=1024, depth=24, num_heads=16,
+                  decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
+                  mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False):
+         super().__init__()
+
+         # --------------------------------------------------------------------------
+         # MAE encoder specifics
+         self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
+         num_patches = self.patch_embed.num_patches
+
+         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+         self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False)  # fixed sin-cos embedding
+
+         self.blocks = nn.ModuleList([
+             Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
+             for i in range(depth)])
+         self.norm = norm_layer(embed_dim)
+         # --------------------------------------------------------------------------
+
+         # --------------------------------------------------------------------------
+         # MAE decoder specifics
+         self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
+
+         self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
+
+         self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False)  # fixed sin-cos embedding
+
+         self.decoder_blocks = nn.ModuleList([
+             Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
+             for i in range(decoder_depth)])
+
+         self.decoder_norm = norm_layer(decoder_embed_dim)
+         self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True)  # decoder to patch
+         # --------------------------------------------------------------------------
+
+         self.norm_pix_loss = norm_pix_loss
+
+         self.initialize_weights()
+
+     def initialize_weights(self):
+         # initialization
+         # initialize (and freeze) pos_embed by sin-cos embedding
+         pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
+         self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
+
+         decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
+         self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))
+
+         # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
+         w = self.patch_embed.proj.weight.data
+         torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+
+         # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
+         torch.nn.init.normal_(self.cls_token, std=.02)
+         torch.nn.init.normal_(self.mask_token, std=.02)
+
+         # initialize nn.Linear and nn.LayerNorm
+         self.apply(self._init_weights)
+
+     def _init_weights(self, m):
+         if isinstance(m, nn.Linear):
+             # we use xavier_uniform following official JAX ViT:
+             torch.nn.init.xavier_uniform_(m.weight)
+             if isinstance(m, nn.Linear) and m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+         elif isinstance(m, nn.LayerNorm):
+             nn.init.constant_(m.bias, 0)
+             nn.init.constant_(m.weight, 1.0)
+
+     def patchify(self, imgs):
+         """
+         imgs: (N, 3, H, W)
+         x: (N, L, patch_size**2 *3)
+         """
+         p = self.patch_embed.patch_size[0]
+         assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
+
+         h = w = imgs.shape[2] // p
+         x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
+         x = torch.einsum('nchpwq->nhwpqc', x)
+         x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
+         return x
+
+     def unpatchify(self, x):
+         """
+         x: (N, L, patch_size**2 *3)
+         imgs: (N, 3, H, W)
+         """
+         p = self.patch_embed.patch_size[0]
+         h = w = int(x.shape[1]**.5)
+         assert h * w == x.shape[1]
+
+         x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
+         x = torch.einsum('nhwpqc->nchpwq', x)
+         imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
+         return imgs
+
+     def random_masking(self, x, mask_ratio):
+         """
+         Perform per-sample random masking by per-sample shuffling.
+         Per-sample shuffling is done by argsort random noise.
+         x: [N, L, D], sequence
+         """
+         N, L, D = x.shape  # batch, length, dim
+         len_keep = int(L * (1 - mask_ratio))
+
+         noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]
+
+         # sort noise for each sample
+         ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
+         ids_restore = torch.argsort(ids_shuffle, dim=1)
+
+         # keep the first subset
+         ids_keep = ids_shuffle[:, :len_keep]
+         x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
+
+         # generate the binary mask: 0 is keep, 1 is remove
+         mask = torch.ones([N, L], device=x.device)
+         mask[:, :len_keep] = 0
+         # unshuffle to get the binary mask
+         mask = torch.gather(mask, dim=1, index=ids_restore)
+
+         return x_masked, mask, ids_restore
+
+     def forward_encoder(self, x, mask_ratio):
+         # embed patches
+         x = self.patch_embed(x)
+
+         # add pos embed w/o cls token
+         x = x + self.pos_embed[:, 1:, :]
+
+         # masking: length -> length * mask_ratio
+         x, mask, ids_restore = self.random_masking(x, mask_ratio)
+
+         # append cls token
+         cls_token = self.cls_token + self.pos_embed[:, :1, :]
+         cls_tokens = cls_token.expand(x.shape[0], -1, -1)
+         x = torch.cat((cls_tokens, x), dim=1)
+
+         # apply Transformer blocks
+         for blk in self.blocks:
+             x = blk(x)
+         x = self.norm(x)
+
+         return x, mask, ids_restore
+
+     def forward_decoder(self, x, ids_restore):
+         # embed tokens
+         x = self.decoder_embed(x)
+
+         # append mask tokens to sequence
+         mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
+         x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
+         x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle
+         x = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token
+
+         # add pos embed
+         x = x + self.decoder_pos_embed
+
+         # apply Transformer blocks
+         for blk in self.decoder_blocks:
+             x = blk(x)
+         x = self.decoder_norm(x)
+
+         # predictor projection
+         x = self.decoder_pred(x)
+
+         # remove cls token
+         x = x[:, 1:, :]
+
+         return x
+
+     def forward_loss(self, imgs, pred, mask):
+         """
+         imgs: [N, 3, H, W]
+         pred: [N, L, p*p*3]
+         mask: [N, L], 0 is keep, 1 is remove,
+         """
+         target = self.patchify(imgs)
+         if self.norm_pix_loss:
+             mean = target.mean(dim=-1, keepdim=True)
+             var = target.var(dim=-1, keepdim=True)
+             target = (target - mean) / (var + 1.e-6)**.5
+
+         loss = (pred - target) ** 2
+         loss = loss.mean(dim=-1)  # [N, L], mean loss per patch
+
+         loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
+         return loss
+
+     def forward(self, imgs, mask_ratio=0.75):
+         latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
+         # pred = self.forward_decoder(latent, ids_restore)  # [N, L, p*p*3]
+         # loss = self.forward_loss(imgs, pred, mask)
+         # return loss, pred, mask
+         print(latent.shape)
+         return latent
+
+
+
+ def mae_vit_base_patch16_dec512d8b(**kwargs):
+     model = MaskedAutoencoderViT(
+         patch_size=16, embed_dim=768, depth=12, num_heads=12,
+         decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
+         mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+     return model
+
+
+ def mae_vit_large_patch16_dec512d8b(**kwargs):
+     model = MaskedAutoencoderViT(
+         patch_size=16, embed_dim=1024, depth=24, num_heads=16,
+         decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
+         mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+     return model
+
+
+ def mae_vit_huge_patch14_dec512d8b(**kwargs):
+     model = MaskedAutoencoderViT(
+         patch_size=14, embed_dim=1280, depth=32, num_heads=16,
+         decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
+         mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+     return model
+
+
+ # set recommended archs
+ mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b  # decoder: 512 dim, 8 blocks
+ mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b  # decoder: 512 dim, 8 blocks
+ mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b  # decoder: 512 dim, 8 blocks
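In this commit forward() bypasses the decoder and returns the encoder latent, so the model acts as a feature extractor; random_masking keeps int(L * (1 - mask_ratio)) tokens per sample. A quick sanity check of the masking step (a sketch, not committed code):

    import torch
    from utils.MAE import mae_vit_base_patch16

    model = mae_vit_base_patch16()
    tokens = torch.randn(2, 196, 768)   # (N, L, D) patch tokens for a 14x14 grid
    kept, mask, ids_restore = model.random_masking(tokens, mask_ratio=0.75)
    print(kept.shape)                   # torch.Size([2, 49, 768]) - 25% of tokens kept
    print(mask.sum(dim=1))              # tensor([147., 147.]) - 75% flagged as removed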
utils/__init__.py ADDED
@@ -0,0 +1 @@
+ from .util_function import epoch_time, plot_lr_finder, plot_confusion_matrix, plot_most_incorrect, get_pca, plot_representations, plot_filtered_images, plot_filters
utils/__pycache__/MAE.cpython-311.pyc ADDED
Binary file (14 kB).
 
utils/__pycache__/MAE.cpython-38.pyc ADDED
Binary file (7.16 kB).
 
utils/__pycache__/MAE.cpython-38.pyc:Zone.Identifier ADDED
@@ -0,0 +1,3 @@
+ [ZoneTransfer]
+ ZoneId=3
+ ReferrerUrl=C:\Users\Daniel\Desktop\LUWA-main.zip
utils/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (422 Bytes).
 
utils/__pycache__/__init__.cpython-310.pyc:Zone.Identifier ADDED
@@ -0,0 +1,3 @@
+ [ZoneTransfer]
+ ZoneId=3
+ ReferrerUrl=C:\Users\Daniel\Desktop\LUWA-main.zip
utils/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (516 Bytes).
 
utils/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (376 Bytes).
 
utils/__pycache__/__init__.cpython-38.pyc:Zone.Identifier ADDED
@@ -0,0 +1,3 @@
+ [ZoneTransfer]
+ ZoneId=3
+ ReferrerUrl=C:\Users\Daniel\Desktop\LUWA-main.zip
utils/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (392 Bytes).
 
utils/__pycache__/__init__.cpython-39.pyc:Zone.Identifier ADDED
@@ -0,0 +1,3 @@
+ [ZoneTransfer]
+ ZoneId=3
+ ReferrerUrl=C:\Users\Daniel\Desktop\LUWA-main.zip
utils/__pycache__/arg_utils.cpython-38.pyc ADDED
Binary file (1.03 kB).
 
utils/__pycache__/arg_utils.cpython-38.pyc:Zone.Identifier ADDED
@@ -0,0 +1,3 @@
+ [ZoneTransfer]
+ ZoneId=3
+ ReferrerUrl=C:\Users\Daniel\Desktop\LUWA-main.zip
utils/__pycache__/arg_utils.cpython-39.pyc ADDED
Binary file (1.05 kB).
 
utils/__pycache__/arg_utils.cpython-39.pyc:Zone.Identifier ADDED
@@ -0,0 +1,3 @@
+ [ZoneTransfer]
+ ZoneId=3
+ ReferrerUrl=C:\Users\Daniel\Desktop\LUWA-main.zip
utils/__pycache__/experiment_utils.cpython-311.pyc ADDED
Binary file (12.9 kB).
 
utils/__pycache__/experiment_utils.cpython-38.pyc ADDED
Binary file (5.71 kB).
 
utils/__pycache__/experiment_utils.cpython-38.pyc:Zone.Identifier ADDED
@@ -0,0 +1,3 @@
+ [ZoneTransfer]
+ ZoneId=3
+ ReferrerUrl=C:\Users\Daniel\Desktop\LUWA-main.zip
utils/__pycache__/experiment_utils.cpython-39.pyc ADDED
Binary file (4.93 kB).
 
utils/__pycache__/experiment_utils.cpython-39.pyc:Zone.Identifier ADDED
@@ -0,0 +1,3 @@
+ [ZoneTransfer]
+ ZoneId=3
+ ReferrerUrl=C:\Users\Daniel\Desktop\LUWA-main.zip
utils/__pycache__/model_utils.cpython-311.pyc ADDED
Binary file (4.24 kB).
 
utils/__pycache__/model_utils.cpython-38.pyc ADDED
Binary file (2.4 kB).
 
utils/__pycache__/model_utils.cpython-38.pyc:Zone.Identifier ADDED
@@ -0,0 +1,3 @@
+ [ZoneTransfer]
+ ZoneId=3
+ ReferrerUrl=C:\Users\Daniel\Desktop\LUWA-main.zip
utils/__pycache__/util_function.cpython-310.pyc ADDED
Binary file (5.35 kB).
 
utils/__pycache__/util_function.cpython-310.pyc:Zone.Identifier ADDED
@@ -0,0 +1,3 @@
+ [ZoneTransfer]
+ ZoneId=3
+ ReferrerUrl=C:\Users\Daniel\Desktop\LUWA-main.zip
utils/__pycache__/util_function.cpython-311.pyc ADDED
Binary file (14.6 kB).
 
utils/__pycache__/util_function.cpython-38.pyc ADDED
Binary file (6.8 kB).
 
utils/__pycache__/util_function.cpython-38.pyc:Zone.Identifier ADDED
@@ -0,0 +1,3 @@
+ [ZoneTransfer]
+ ZoneId=3
+ ReferrerUrl=C:\Users\Daniel\Desktop\LUWA-main.zip
utils/__pycache__/util_function.cpython-39.pyc ADDED
Binary file (6.82 kB).
 
utils/__pycache__/util_function.cpython-39.pyc:Zone.Identifier ADDED
@@ -0,0 +1,3 @@
+ [ZoneTransfer]
+ ZoneId=3
+ ReferrerUrl=C:\Users\Daniel\Desktop\LUWA-main.zip
utils/arg_utils.py ADDED
@@ -0,0 +1,18 @@
+ import argparse
+
+ def get_args():
+     # Training settings
+     parser = argparse.ArgumentParser('train')
+
+     parser.add_argument('--resolution', type=str, default='256', help='Resolution of input image')
+     parser.add_argument('--magnification', type=str, default='20x', help='Magnification of input image')
+     parser.add_argument('--modality', type=str, default='texture', help='Modality of input image')
+     parser.add_argument('--model', type=str, default='ResNet50', help='Model to use')
+     parser.add_argument('--pretrained', type=str, default='pretrained', help='Use pretrained model')
+     parser.add_argument('--frozen', type=str, default='unfrozen', help='Freeze pretrained model')
+     parser.add_argument('--vote', type=str, default='vote', help='Conduct voting')
+     parser.add_argument('--epochs', type=int, default=2, help='Number of epochs to train')
+     parser.add_argument('--batch_size', type=int, default=100, help='Batch size')
+     parser.add_argument('--start_lr', type=float, default=0.01, help='Learning rate')
+     parser.add_argument('--seed', type=int, default=1234, help='Random seed')
+     return parser.parse_args()
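Note that the on/off options (--pretrained, --frozen, --vote) are plain strings that the pipelines compare against literal values, not booleans. A small sketch of driving the parser programmatically, with values mirroring the defaults above:

    import sys
    from utils.arg_utils import get_args

    # simulate a command line; any omitted flag falls back to its default
    sys.argv = ['train', '--model', 'ResNet152', '--frozen', 'frozen', '--epochs', '10']
    args = get_args()
    print(args.model, args.frozen, args.epochs)   # ResNet152 frozen 10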
utils/experiment_utils.py ADDED
@@ -0,0 +1,298 @@
+ import torch
+ import torchvision
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import logging
+ from collections import Counter
+ from utils.MAE import mae_vit_large_patch16_dec512d8b as MAE_large
+
+ def get_model(args) -> nn.Module:
+     if 'ResNet' in args.model:
+         # resnet family
+         if args.model == 'ResNet50':
+             if args.pretrained == 'pretrained':
+                 model = torchvision.models.resnet50(weights='IMAGENET1K_V2')
+             else:
+                 model = torchvision.models.resnet50()
+         elif args.model == 'ResNet152':
+             if args.pretrained == 'pretrained':
+                 model = torchvision.models.resnet152(weights='IMAGENET1K_V2')
+             else:
+                 model = torchvision.models.resnet152()
+         else:
+             raise NotImplementedError
+         if args.frozen == 'frozen':
+             model = freeze_backbone(model)
+         model.fc = nn.Linear(model.fc.in_features, 6)
+
+     elif 'ConvNext' in args.model:
+         if args.model == 'ConvNext_Tiny':
+             if args.pretrained == 'pretrained':
+                 model = torchvision.models.convnext_tiny(weights='IMAGENET1K_V1')
+             else:
+                 model = torchvision.models.convnext_tiny()
+         elif args.model == 'ConvNext_Large':
+             if args.pretrained == 'pretrained':
+                 model = torchvision.models.convnext_large(weights='IMAGENET1K_V1')
+             else:
+                 model = torchvision.models.convnext_large()
+         else:
+             raise NotImplementedError
+         if args.frozen == 'frozen':
+             model = freeze_backbone(model)
+         num_ftrs = model.classifier[2].in_features
+         model.classifier[2] = nn.Linear(int(num_ftrs), 6)
+
+     elif 'ViT' in args.model:
+         if args.pretrained == 'pretrained':
+             model = torchvision.models.vit_h_14(weights='IMAGENET1K_SWAG_LINEAR_V1')
+         else:
+             raise NotImplementedError('ViT does not support training from scratch')
+         if args.frozen == 'frozen':
+             model = freeze_backbone(model)
+         model.heads[0] = torch.nn.Linear(model.heads[0].in_features, 6)
+
+     elif 'DINOv2' in args.model:
+         if args.pretrained == 'pretrained':
+             model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14_reg_lc')
+         else:
+             raise NotImplementedError('DINOv2 does not support training from scratch')
+         if args.frozen == 'frozen':
+             model = freeze_backbone(model)
+         model.linear_head = torch.nn.Linear(model.linear_head.in_features, 6)
+
+     elif 'MAE' in args.model:
+         if args.pretrained == 'pretrained':
+             model = MAE_large()
+             # hard-coded cluster path to the pretrained MAE weights
+             model.load_state_dict(torch.load('/scratch/zf540/LUWA/workspace/utils/pretrained_weights/mae_visualize_vit_large.pth')['model'])
+         else:
+             raise NotImplementedError('MAE does not support training from scratch')
+         if args.frozen == 'frozen':
+             model = freeze_backbone(model)
+         model = nn.Sequential(model, nn.Linear(1024, 6))
+         print(model)
+     else:
+         raise NotImplementedError
+     return model
+
+
+ def freeze_backbone(model):
+     # freeze backbone
+     # we will replace the classifier at the end with a trainable one anyway, so we freeze the default here as well
+     for param in model.parameters():
+         param.requires_grad = False
+     return model
+
+ def get_name(args):
+     name = args.model
+     name += '_' + str(args.resolution)
+     name += '_' + args.magnification
+     name += '_' + args.modality
+     if args.pretrained == 'pretrained':
+         name += '_pretrained'
+     else:
+         name += '_scratch'
+     if args.frozen == 'frozen':
+         name += '_frozen'
+     else:
+         name += '_unfrozen'
+     if args.vote == 'vote':
+         name += '_vote'
+     else:
+         name += '_novote'
+     return name
+
+ def get_logger(path, name):
+     # set up logger
+
+     logger = logging.getLogger(name)
+     logger.setLevel(logging.INFO)
+     formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+     file_handler = logging.FileHandler(path.joinpath(f'{name}_log.txt'))
+     file_handler.setLevel(logging.INFO)
+     file_handler.setFormatter(formatter)
+     logger.addHandler(file_handler)
+     logger.info('---------------------------------------------------TRAINING---------------------------------------------------')
+
+     return logger
+
+ def calculate_topk_accuracy(y_pred, y, k=3):
+     with torch.no_grad():
+         batch_size = y.shape[0]
+         _, top_pred = y_pred.topk(k, 1)
+         top_pred = top_pred.t()
+         correct = top_pred.eq(y.view(1, -1).expand_as(top_pred))
+         correct_1 = correct[:1].reshape(-1).float().sum(0, keepdim=True)
+         correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
+         acc_1 = correct_1 / batch_size
+         acc_k = correct_k / batch_size
+     return acc_1, acc_k
+
+ def train(model, iterator, optimizer, criterion, scheduler, device):
+     epoch_loss = 0
+     epoch_acc_1 = 0
+     epoch_acc_3 = 0
+
+     model.train()
+
+     for image, label, image_name in iterator:
+         x = image.to(device)
+         y = label.to(device)
+
+         optimizer.zero_grad()
+
+         y_pred = model(x)
+         # print(y_pred.shape)  # debug leftovers, disabled to avoid per-batch log spam
+         # print(y.shape)
+         loss = criterion(y_pred, y)
+
+         acc_1, acc_3 = calculate_topk_accuracy(y_pred, y)
+
+         loss.backward()
+
+         optimizer.step()
+
+         scheduler.step()
+
+         epoch_loss += loss.item()
+         epoch_acc_1 += acc_1.item()
+         epoch_acc_3 += acc_3.item()
+
+     epoch_loss /= len(iterator)
+     epoch_acc_1 /= len(iterator)
+     epoch_acc_3 /= len(iterator)
+
+     return epoch_loss, epoch_acc_1, epoch_acc_3
+
+
+ def evaluate(model, iterator, criterion, device):
+     epoch_loss = 0
+     epoch_acc_1 = 0
+     epoch_acc_3 = 0
+
+     model.eval()
+
+     with torch.no_grad():
+         for image, label, image_name in iterator:
+             x = image.to(device)
+             y = label.to(device)
+
+             y_pred = model(x)
+             loss = criterion(y_pred, y)
+
+             acc_1, acc_3 = calculate_topk_accuracy(y_pred, y)
+
+             epoch_loss += loss.item()
+             epoch_acc_1 += acc_1.item()
+             epoch_acc_3 += acc_3.item()
+
+     epoch_loss /= len(iterator)
+     epoch_acc_1 /= len(iterator)
+     epoch_acc_3 /= len(iterator)
+
+     return epoch_loss, epoch_acc_1, epoch_acc_3
+
+ def evaluate_vote(model, iterator, device):
+
+     model.eval()
+
+     image_names = []
+     labels = []
+     predictions = []
+
+     with torch.no_grad():
+
+         for image, label, image_name in iterator:
+
+             x = image.to(device)
+
+             y_pred = model(x)
+             y_prob = F.softmax(y_pred, dim=-1)
+             top_pred = y_prob.argmax(1, keepdim=True)
+
+             image_names.extend(image_name)
+             labels.extend(label.numpy())
+             predictions.extend(top_pred.cpu().squeeze().numpy())
+
+     conduct_voting(image_names, predictions)
+
+     correct_count = 0
+     for i in range(len(labels)):
+         if labels[i] == predictions[i]:
+             correct_count += 1
+     accuracy = correct_count / len(labels)
+     return accuracy
+
+ def conduct_voting(image_names, predictions):
+     # we need to do this because not all stones have the same number of partitions
+     last_stone = image_names[0][:-8]  # the name of the stone of the last image
+     voting_list = []
+     for i in range(len(image_names)):
+         image_area_name = image_names[i][:-8]
+         if image_area_name != last_stone:
+             # we have run through all the images of the last stone. We start voting
+             vote(voting_list, predictions, i)
+             voting_list = []  # reset the voting list
+         voting_list.append(predictions[i])
+         last_stone = image_area_name  # update the last stone name
+
+     # vote for the last stone
+     vote(voting_list, predictions, len(image_names))
+
+ def vote(voting_list, predictions, i):
+     vote_result = Counter(voting_list).most_common(1)[0][0]  # the most common prediction in the list
+     predictions[i-len(voting_list):i] = [vote_result]*len(voting_list)  # replace the predictions of the last stone with the vote result
+
+
+ # def get_predictions(model, iterator):
+
+ #     model.eval()
+
+ #     images = []
+ #     labels = []
+ #     probs = []
+
+ #     with torch.no_grad():
+
+ #         for (x, y) in iterator:
+
+ #             x = x.to(device)
+
+ #             y_pred = model(x)
+
+ #             y_prob = F.softmax(y_pred, dim = -1)
+ #             top_pred = y_prob.argmax(1, keepdim = True)
+
+ #             images.append(x.cpu())
+ #             labels.append(y.cpu())
+ #             probs.append(y_prob.cpu())
+
+ #     images = torch.cat(images, dim = 0)
+ #     labels = torch.cat(labels, dim = 0)
+ #     probs = torch.cat(probs, dim = 0)
+
+ #     return images, labels, probs
+
+
+ # def get_representations(model, iterator):
+ #     model.eval()
+
+ #     outputs = []
+ #     intermediates = []
+ #     labels = []
+
+ #     with torch.no_grad():
+ #         for (x, y) in iterator:
+ #             x = x.to(device)
+
+ #             y_pred = model(x)
+
+ #             outputs.append(y_pred.cpu())
+ #             labels.append(y)
+
+ #     outputs = torch.cat(outputs, dim=0)
+ #     labels = torch.cat(labels, dim=0)
+
+ #     return outputs, labels
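conduct_voting groups test images by stripping the last 8 characters of each image name (assumed to be the per-partition suffix of a stone) and overwrites every prediction in a group, in place, with the group's majority vote. A toy example with made-up names that follow that naming assumption:

    from utils.experiment_utils import conduct_voting

    names = ['stoneA__part_01', 'stoneA__part_02', 'stoneA__part_03', 'stoneB__part_01']
    preds = [2, 2, 5, 1]
    conduct_voting(names, preds)   # majority vote within each stone
    print(preds)                   # [2, 2, 2, 1]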
utils/model_utils.py ADDED
@@ -0,0 +1,96 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+ # --------------------------------------------------------
+ # Position embedding utils
+ # --------------------------------------------------------
+
+ import numpy as np
+
+ import torch
+
+ # --------------------------------------------------------
+ # 2D sine-cosine position embedding
+ # References:
+ # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
+ # MoCo v3: https://github.com/facebookresearch/moco-v3
+ # --------------------------------------------------------
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
+     """
+     grid_size: int of the grid height and width
+     return:
+     pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+     """
+     grid_h = np.arange(grid_size, dtype=np.float32)
+     grid_w = np.arange(grid_size, dtype=np.float32)
+     grid = np.meshgrid(grid_w, grid_h)  # here w goes first
+     grid = np.stack(grid, axis=0)
+
+     grid = grid.reshape([2, 1, grid_size, grid_size])
+     pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+     if cls_token:
+         pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
+     return pos_embed
+
+
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+     assert embed_dim % 2 == 0
+
+     # use half of dimensions to encode grid_h
+     emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
+     emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
+
+     emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
+     return emb
+
+
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+     """
+     embed_dim: output dimension for each position
+     pos: a list of positions to be encoded: size (M,)
+     out: (M, D)
+     """
+     assert embed_dim % 2 == 0
+     omega = np.arange(embed_dim // 2, dtype=float)
+     omega /= embed_dim / 2.
+     omega = 1. / 10000**omega  # (D/2,)
+
+     pos = pos.reshape(-1)  # (M,)
+     out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product
+
+     emb_sin = np.sin(out)  # (M, D/2)
+     emb_cos = np.cos(out)  # (M, D/2)
+
+     emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
+     return emb
+
+
+ # --------------------------------------------------------
+ # Interpolate position embeddings for high-resolution
+ # References:
+ # DeiT: https://github.com/facebookresearch/deit
+ # --------------------------------------------------------
+ def interpolate_pos_embed(model, checkpoint_model):
+     if 'pos_embed' in checkpoint_model:
+         pos_embed_checkpoint = checkpoint_model['pos_embed']
+         embedding_size = pos_embed_checkpoint.shape[-1]
+         num_patches = model.patch_embed.num_patches
+         num_extra_tokens = model.pos_embed.shape[-2] - num_patches
+         # height (== width) for the checkpoint position embedding
+         orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
+         # height (== width) for the new position embedding
+         new_size = int(num_patches ** 0.5)
+         # class_token and dist_token are kept unchanged
+         if orig_size != new_size:
+             print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
+             extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+             # only the position tokens are interpolated
+             pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+             pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
+             pos_tokens = torch.nn.functional.interpolate(
+                 pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
+             pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+             new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+             checkpoint_model['pos_embed'] = new_pos_embed
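For the ViT-L/16 encoder used here (224/16 = 14x14 patch grid) with a class token, the sin-cos table has 1 + 196 rows; the prepended cls row is all zeros and the rest are fixed sin-cos values. A quick shape check:

    from utils.model_utils import get_2d_sincos_pos_embed

    pos = get_2d_sincos_pos_embed(embed_dim=1024, grid_size=14, cls_token=True)
    print(pos.shape)        # (197, 1024)
    print(pos[0].any())     # False - the cls-token row is all zeros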
utils/util_function.py ADDED
@@ -0,0 +1,238 @@
+ import cv2
+ from sklearn.manifold import TSNE
+ import torch
+ import numpy as np
+ import matplotlib.pyplot as plt
+ import torch.nn.functional as F
+ from sklearn.metrics import confusion_matrix
+ from sklearn.metrics import ConfusionMatrixDisplay
+ from sklearn import decomposition
+ import itertools
+
+ def normalize_image(image):
+     image_min = image.min()
+     image_max = image.max()
+     image.clamp_(min=image_min, max=image_max)
+     image.add_(-image_min).div_(image_max - image_min + 1e-5)
+     return image
+
+
+ def plot_lr_finder(fig_name, lrs, losses, skip_start=5, skip_end=5):
+     if skip_end == 0:
+         lrs = lrs[skip_start:]
+         losses = losses[skip_start:]
+     else:
+         lrs = lrs[skip_start:-skip_end]
+         losses = losses[skip_start:-skip_end]
+
+     fig = plt.figure(figsize=(16, 8))
+     ax = fig.add_subplot(1, 1, 1)
+     ax.plot(lrs, losses)
+     ax.set_xscale('log')
+     ax.set_xlabel('Learning rate')
+     ax.set_ylabel('Loss')
+     ax.grid(True, 'both', 'x')
+     plt.show()
+     plt.savefig(fig_name)
+
+ def epoch_time(start_time, end_time):
+     elapsed_time = end_time - start_time
+     elapsed_mins = int(elapsed_time / 60)
+     elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
+     return elapsed_mins, elapsed_secs
+
+
+ def plot_confusion_matrix(fig_name, labels, pred_labels, classes):
+     fig = plt.figure(figsize=(50, 50))
+     ax = fig.add_subplot(1, 1, 1)
+     cm = confusion_matrix(labels, pred_labels)
+     cm = ConfusionMatrixDisplay(cm, display_labels=classes)
+     cm.plot(values_format='d', cmap='Blues', ax=ax)
+     fig.delaxes(fig.axes[1])  # delete colorbar
+     plt.xticks(rotation=90, fontsize=50)
+     plt.yticks(fontsize=50)
+     plt.rcParams.update({'font.size': 50})
+     plt.xlabel('Predicted Label', fontsize=50)
+     plt.ylabel('True Label', fontsize=50)
+     plt.savefig(fig_name)
+
+ def plot_confusion_matrix_SVM(fig_name, true_labels, predicted_labels, classes):
+     fig = plt.figure(figsize=(100, 100))
+     ax = fig.add_subplot(1, 1, 1)
+
+     cm = confusion_matrix(true_labels, predicted_labels)
+     cm_display = ConfusionMatrixDisplay(cm, display_labels=classes)
+
+     cm_display.plot(values_format='d', cmap='Blues', ax=ax)
+
+     fig.delaxes(fig.axes[1])  # delete colorbar
+     plt.xticks(rotation=90, fontsize=50)
+     plt.yticks(fontsize=50)
+     plt.rcParams.update({'font.size': 50})
+     plt.xlabel('Predicted Label', fontsize=50)
+     plt.ylabel('True Label', fontsize=50)
+     plt.savefig(fig_name)
+
+
+ def plot_most_incorrect(fig_name, incorrect, classes, n_images, normalize=True):
+     rows = int(np.sqrt(n_images))
+     cols = int(np.sqrt(n_images))
+
+     fig = plt.figure(figsize=(25, 20))
+
+     for i in range(rows * cols):
+
+         ax = fig.add_subplot(rows, cols, i + 1)
+
+         image, true_label, probs = incorrect[i]
+         image = image.permute(1, 2, 0)
+         true_prob = probs[true_label]
+         incorrect_prob, incorrect_label = torch.max(probs, dim=0)
+         true_class = classes[true_label]
+         incorrect_class = classes[incorrect_label]
+
+         if normalize:
+             image = normalize_image(image)
+
+         ax.imshow(image.cpu().numpy())
+         ax.set_title(f'true label: {true_class} ({true_prob:.3f})\n' \
+                      f'pred label: {incorrect_class} ({incorrect_prob:.3f})')
+         ax.axis('off')
+
+     fig.subplots_adjust(hspace=0.4)
+     plt.savefig(fig_name)
+
+ def get_pca(data, n_components=2):
+     pca = decomposition.PCA()
+     pca.n_components = n_components
+     pca_data = pca.fit_transform(data)
+     return pca_data
+
+
+ def plot_representations(fig_name, data, labels, classes, n_images=None):
+     if n_images is not None:
+         data = data[:n_images]
+         labels = labels[:n_images]
+
+     fig = plt.figure(figsize=(15, 15))
+     ax = fig.add_subplot(111)
+     scatter = ax.scatter(data[:, 0], data[:, 1], c=labels, cmap='hsv')
+     # handles, _ = scatter.legend_elements(num = None)
+     # legend = plt.legend(handles = handles, labels = classes)
+     plt.savefig(fig_name)
+
+
+ def plot_filtered_images(fig_name, images, filters, n_filters=None, normalize=True):
+
+     images = torch.cat([i.unsqueeze(0) for i in images], dim=0).cpu()
+     filters = filters.cpu()
+
+     if n_filters is not None:
+         filters = filters[:n_filters]
+
+     n_images = images.shape[0]
+     n_filters = filters.shape[0]
+
+     filtered_images = F.conv2d(images, filters)
+
+     fig = plt.figure(figsize=(30, 30))
+
+     for i in range(n_images):
+
+         image = images[i]
+
+         if normalize:
+             image = normalize_image(image)
+
+         ax = fig.add_subplot(n_images, n_filters + 1, i + 1 + (i * n_filters))
+         ax.imshow(image.permute(1, 2, 0).numpy())
+         ax.set_title('Original')
+         ax.axis('off')
+
+         for j in range(n_filters):
+             image = filtered_images[i][j]
+
+             if normalize:
+                 image = normalize_image(image)
+
+             ax = fig.add_subplot(n_images, n_filters + 1, i + 1 + (i * n_filters) + j + 1)
+             ax.imshow(image.numpy(), cmap='bone')
+             ax.set_title(f'Filter {j+1}')
+             ax.axis('off')
+
+     fig.subplots_adjust(hspace=-0.7)
+     plt.savefig(fig_name)
+
+
+ def plot_filters(fig_name, filters, normalize=True):
+     filters = filters.cpu()
+
+     n_filters = filters.shape[0]
+
+     rows = int(np.sqrt(n_filters))
+     cols = int(np.sqrt(n_filters))
+
+     fig = plt.figure(figsize=(30, 15))
+
+     for i in range(rows * cols):
+
+         image = filters[i]
+
+         if normalize:
+             image = normalize_image(image)
+
+         ax = fig.add_subplot(rows, cols, i + 1)
+         ax.imshow(image.permute(1, 2, 0))
+         ax.axis('off')
+
+     fig.subplots_adjust(wspace=-0.9)
+     plt.savefig(fig_name)
+
+ def plot_tsne(fig_name, all_features, all_labels):
+     tsne = TSNE(n_components=2, random_state=42)
+     tsne_results = tsne.fit_transform(all_features)
+     plt.figure(figsize=(10, 7))
+     scatter = plt.scatter(tsne_results[:, 0], tsne_results[:, 1], c=all_labels, cmap='viridis', s=5)
+     plt.colorbar(scatter)
+     plt.title('t-SNE Visualization')
+     plt.show()
+     plt.savefig(fig_name)
+
+
+ def plot_grad_cam(images, cams, predicted_labels, true_labels, classes, path):
+     fig, axs = plt.subplots(nrows=2, ncols=len(images), figsize=(20, 10))
+
+     for i, (img, cam, pred_label, true_label) in enumerate(zip(images, cams, predicted_labels, true_labels)):
+         # Display the original image on the top row
+         axs[0, i].imshow(img.permute(1, 2, 0).cpu().numpy())
+         pred_class_name = classes[pred_label]
+         true_class_name = classes[true_label]
+         axs[0, i].set_title(f"Predicted: {pred_class_name}\nTrue: {true_class_name}", fontsize=12)
+         axs[0, i].axis('off')
+
+         # Add label to the leftmost plot
+         if i == 0:
+             axs[0, i].set_ylabel("Original Image", fontsize=14, rotation=90, labelpad=10)
+
+         # Convert the original image to grayscale
+         grayscale_img = cv2.cvtColor(img.permute(1, 2, 0).cpu().numpy(), cv2.COLOR_RGB2GRAY)
+         grayscale_img = cv2.cvtColor(grayscale_img, cv2.COLOR_GRAY2RGB)
+
+         # Overlay the Grad-CAM heatmap on the grayscale image
+         heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
+         heatmap = np.float32(heatmap) / 255
+         cam_img = heatmap + np.float32(grayscale_img)
+         cam_img = cam_img / np.max(cam_img)
+
+         # Display the Grad-CAM image on the bottom row
+         axs[1, i].imshow(cam_img)
+         axs[1, i].axis('off')
+
+         # Add label to the leftmost plot
+         if i == 0:
+             axs[1, i].set_ylabel("Grad-CAM", fontsize=14, rotation=90, labelpad=10)
+
+     plt.tight_layout()
+     plt.savefig(path)
+     plt.close()
+
+
vis_confusion_mtx.py ADDED
@@ -0,0 +1,54 @@
+ import torch
+
+ import os
+ from pathlib import Path
+
+ from data_utils.data_tribology import TribologyDataset
+ from utils.experiment_utils import get_model, get_prediction  # note: get_prediction is not defined in the committed utils/experiment_utils.py (only a commented-out get_predictions)
+ from utils.arg_utils import get_args
+ from utils.visualization_utils import plot_confusion_matrix  # note: utils/visualization_utils.py is not part of this commit
+
+ def generate_confusion_matrix(image_name, model, iterator, device):
+     labels, predictions = get_prediction(model, iterator, device)
+     plot_confusion_matrix('visualization_results/' + image_name + '_confusion_mtx.png', labels, predictions, classes=["ANTLER", "BEECHWOOD", "BEFOREUSE", "BONE", "IVORY", "SPRUCEWOOD"])
+
+
+ def main(args):
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+     model = get_model(args)
+
+     basepath = os.getcwd()
+     experiment_dir = Path(os.path.join(basepath, 'experiments', args.model, args.resolution, args.magnification, args.modality, args.pretrained, args.frozen, args.vote))
+     if args.model == 'ViT':
+         experiment_dir = Path(os.path.join(basepath, 'experiments', 'ViT_H', args.resolution, args.magnification, args.modality, args.pretrained, args.frozen, args.vote))
+     checkpoint_dir = Path(os.path.join(experiment_dir, 'checkpoints'))
+     checkpoint_path = checkpoint_dir / f'epoch{str(args.epochs)}.pth'
+     model.load_state_dict(torch.load(checkpoint_path))
+     model = model.to(device)
+
+     train_csv_path = f"./LUA_Dataset/CSV/{args.resolution}_{args.magnification}_6w_train.csv"
+     test_csv_path = f"./LUA_Dataset/CSV/{args.resolution}_{args.magnification}_6w_test.csv"
+     img_path = f"./LUA_Dataset/{args.resolution}/{args.magnification}/{args.modality}"
+     BATCHSIZE = args.batch_size
+     train_dataset = TribologyDataset(csv_path=train_csv_path, img_path=img_path)
+     test_dataset = TribologyDataset(csv_path=test_csv_path, img_path=img_path)
+
+     means, stds = train_dataset.get_statistics()
+     train_dataset.prepare_transform(means, stds, mode='train')
+     test_dataset.prepare_transform(means, stds, mode='test')
+
+     test_iterator = torch.utils.data.DataLoader(test_dataset,
+                                                 batch_size=BATCHSIZE,
+                                                 num_workers=4,
+                                                 shuffle=False,
+                                                 pin_memory=True,
+                                                 drop_last=False)
+
+     generate_confusion_matrix(args.model, model, test_iterator, device)
+
+ if __name__ == "__main__":
+     args = get_args()
+     main(args)
vote_analysis.py ADDED
@@ -0,0 +1,107 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ import torch.optim as optim
+ import torch.optim.lr_scheduler as lr_scheduler
+
+ import torch.utils.data as data
+
+ import numpy as np
+ import random
+ import tqdm
+ import os
+ from pathlib import Path
+
+ from data_utils.data_tribology import TribologyDataset
+ from utils.experiment_utils import get_model, get_name, get_logger, train, evaluate, evaluate_vote, evaluate_vote_analysis  # note: evaluate_vote_analysis is not defined in the committed utils/experiment_utils.py
+ from utils.arg_utils import get_args
+
+ def main(args):
+     '''Reproducibility'''
+     SEED = args.seed
+     random.seed(SEED)
+     np.random.seed(SEED)
+     torch.manual_seed(SEED)
+     torch.cuda.manual_seed(SEED)
+     torch.backends.cudnn.deterministic = True
+     torch.backends.cudnn.benchmark = False
+
+     '''Folder Creation'''
+     basepath = os.getcwd()
+     experiment_dir = Path(os.path.join(basepath, 'experiments', args.model, args.resolution, args.magnification, args.modality, args.pretrained, args.frozen, args.vote))
+     experiment_dir.mkdir(parents=True, exist_ok=True)
+     checkpoint_dir = Path(os.path.join(experiment_dir, 'checkpoints'))
+     checkpoint_dir.mkdir(parents=True, exist_ok=True)
+
+     '''Logging'''
+     model_name = get_name(args)
+     print(model_name, 'STARTED')
+
+     logger = get_logger(experiment_dir, 'vote_analysis')
+
+     '''Data Loading'''
+     train_csv_path = f"./LUA_Dataset/CSV/{args.resolution}_{args.magnification}_6w_train.csv"
+     test_csv_path = f"./LUA_Dataset/CSV/{args.resolution}_{args.magnification}_6w_test.csv"
+     img_path = f"./LUA_Dataset/{args.resolution}/{args.magnification}/{args.modality}"
+
+     # results_acc_1 = {}
+     # results_acc_3 = {}
+     # classes_num = 6
+     BATCHSIZE = args.batch_size
+     train_dataset = TribologyDataset(csv_path=train_csv_path, img_path=img_path)
+     test_dataset = TribologyDataset(csv_path=test_csv_path, img_path=img_path)
+
+     # prepare the data augmentation
+     means, stds = train_dataset.get_statistics()
+     train_dataset.prepare_transform(means, stds, mode='train')
+     test_dataset.prepare_transform(means, stds, mode='test')
+
+     VALID_RATIO = 0.1
+
+     num_train = len(train_dataset)
+     num_valid = int(VALID_RATIO * num_train)
+     train_dataset, valid_dataset = data.random_split(train_dataset, [num_train - num_valid, num_valid])
+     logger.info(f'Number of training samples: {len(train_dataset)}')
+     logger.info(f'Number of validation samples: {len(valid_dataset)}')
+
+     test_iterator = torch.utils.data.DataLoader(test_dataset,
+                                                 batch_size=BATCHSIZE,
+                                                 num_workers=4,
+                                                 shuffle=False,
+                                                 pin_memory=True,
+                                                 drop_last=False)
+     print('DATA LOADED')
+
+     # Define model
+     model = get_model(args)
+     print('MODEL LOADED')
+
+     # Define device
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+     model = model.to(device)
+
+
+     print('SETUP DONE')
+     # no training here: load the checkpoint produced by dl_supervised_pipeline.py
+
+     print('TRAINING STARTED')
+
+     model.load_state_dict(torch.load(checkpoint_dir / f'epoch{args.epochs}.pth'))
+     logger.info('-------------------Beginning of Testing-------------------')
+     print('TESTING STARTED')
+
+     vote_accuracy, correct_case_accuracy, incorrect_case_accuracy, incorrect_most_common, novote_accuracy = evaluate_vote_analysis(model, test_iterator, device)
+     logger.info(f'Test Acc @1: {vote_accuracy * 100:6.2f}%')
+     logger.info(f'No Vote Accuracy @1: {novote_accuracy * 100:6.2f}%')
+     logger.info(f'Correct Case Consistency @1: {correct_case_accuracy * 100:6.2f}%')
+     logger.info(f'Incorrect Case Consistency @1: {incorrect_case_accuracy * 100:6.2f}%')
+     logger.info(f'Incorrect Most Common: {incorrect_most_common * 100:6.2f}%')
+
+     logger.info('-------------------End of Testing-------------------')
+     print('TESTING DONE')
+
+
+ if __name__ == '__main__':
+     args = get_args()
+     main(args)