DanielXu0208 committed
Commit • 785ef2b
Parent(s):
Initial commit
Browse files
- dl_supervised_pipeline.py +158 -0
- run_gradio.py +55 -0
- svm_pipeline.py +100 -0
- utils/MAE.py +253 -0
- utils/__init__.py +1 -0
- utils/__pycache__/MAE.cpython-311.pyc +0 -0
- utils/__pycache__/MAE.cpython-38.pyc +0 -0
- utils/__pycache__/MAE.cpython-38.pyc:Zone.Identifier +3 -0
- utils/__pycache__/__init__.cpython-310.pyc +0 -0
- utils/__pycache__/__init__.cpython-310.pyc:Zone.Identifier +3 -0
- utils/__pycache__/__init__.cpython-311.pyc +0 -0
- utils/__pycache__/__init__.cpython-38.pyc +0 -0
- utils/__pycache__/__init__.cpython-38.pyc:Zone.Identifier +3 -0
- utils/__pycache__/__init__.cpython-39.pyc +0 -0
- utils/__pycache__/__init__.cpython-39.pyc:Zone.Identifier +3 -0
- utils/__pycache__/arg_utils.cpython-38.pyc +0 -0
- utils/__pycache__/arg_utils.cpython-38.pyc:Zone.Identifier +3 -0
- utils/__pycache__/arg_utils.cpython-39.pyc +0 -0
- utils/__pycache__/arg_utils.cpython-39.pyc:Zone.Identifier +3 -0
- utils/__pycache__/experiment_utils.cpython-311.pyc +0 -0
- utils/__pycache__/experiment_utils.cpython-38.pyc +0 -0
- utils/__pycache__/experiment_utils.cpython-38.pyc:Zone.Identifier +3 -0
- utils/__pycache__/experiment_utils.cpython-39.pyc +0 -0
- utils/__pycache__/experiment_utils.cpython-39.pyc:Zone.Identifier +3 -0
- utils/__pycache__/model_utils.cpython-311.pyc +0 -0
- utils/__pycache__/model_utils.cpython-38.pyc +0 -0
- utils/__pycache__/model_utils.cpython-38.pyc:Zone.Identifier +3 -0
- utils/__pycache__/util_function.cpython-310.pyc +0 -0
- utils/__pycache__/util_function.cpython-310.pyc:Zone.Identifier +3 -0
- utils/__pycache__/util_function.cpython-311.pyc +0 -0
- utils/__pycache__/util_function.cpython-38.pyc +0 -0
- utils/__pycache__/util_function.cpython-38.pyc:Zone.Identifier +3 -0
- utils/__pycache__/util_function.cpython-39.pyc +0 -0
- utils/__pycache__/util_function.cpython-39.pyc:Zone.Identifier +3 -0
- utils/arg_utils.py +18 -0
- utils/experiment_utils.py +298 -0
- utils/model_utils.py +96 -0
- utils/util_function.py +238 -0
- vis_confusion_mtx.py +54 -0
- vote_analysis.py +107 -0
dl_supervised_pipeline.py
ADDED
@@ -0,0 +1,158 @@
+# Code modified from pytorch-image-classification
+# obtained from https://colab.research.google.com/github/bentrevett/pytorch-image-classification/blob/master/5_resnet.ipynb#scrollTo=4QmwmcXuPuLo
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import torch.optim as optim
+import torch.optim.lr_scheduler as lr_scheduler
+
+import torch.utils.data as data
+
+import numpy as np
+import random
+import tqdm
+import os
+from pathlib import Path
+
+from data_utils.data_tribology import TribologyDataset
+from utils.experiment_utils import get_model, get_name, get_logger, train, evaluate, evaluate_vote
+from utils.arg_utils import get_args
+
+def main(args):
+    '''Reproducibility'''
+    SEED = args.seed
+    random.seed(SEED)
+    np.random.seed(SEED)
+    torch.manual_seed(SEED)
+    torch.cuda.manual_seed(SEED)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+
+    '''Folder Creation'''
+    basepath = os.getcwd()
+    experiment_dir = Path(os.path.join(basepath, 'experiments', args.model, args.resolution, args.magnification, args.modality, args.pretrained, args.frozen, args.vote))
+    experiment_dir.mkdir(parents=True, exist_ok=True)
+    checkpoint_dir = Path(os.path.join(experiment_dir, 'checkpoints'))
+    checkpoint_dir.mkdir(parents=True, exist_ok=True)
+
+    '''Logging'''
+    model_name = get_name(args)
+    print(model_name, 'STARTED')
+    if os.path.exists(checkpoint_dir / 'epoch10.pth'):
+        print('CHECKPOINT FOUND')
+        print('TERMINATING TRAINING')
+        return 0  # terminate training if checkpoint exists
+
+    logger = get_logger(experiment_dir, model_name)
+
+    '''Data Loading'''
+    train_csv_path = f"./LUA_Dataset/CSV/{args.resolution}_{args.magnification}_6w_train.csv"
+    test_csv_path = f"./LUA_Dataset/CSV/{args.resolution}_{args.magnification}_6w_test.csv"
+    img_path = f"./LUA_Dataset/{args.resolution}/{args.magnification}/{args.modality}"
+
+    # results_acc_1 = {}
+    # results_acc_3 = {}
+    # classes_num = 6
+    BATCHSIZE = args.batch_size
+    train_dataset = TribologyDataset(csv_path=train_csv_path, img_path=img_path)
+    test_dataset = TribologyDataset(csv_path=test_csv_path, img_path=img_path)
+
+    # prepare the data augmentation
+    means, stds = train_dataset.get_statistics()
+    train_dataset.prepare_transform(means, stds, mode='train')
+    test_dataset.prepare_transform(means, stds, mode='test')
+
+    VALID_RATIO = 0.1
+
+    num_train = len(train_dataset)
+    num_valid = int(VALID_RATIO * num_train)
+    train_dataset, valid_dataset = data.random_split(train_dataset, [num_train - num_valid, num_valid])
+    logger.info(f'Number of training samples: {len(train_dataset)}')
+    logger.info(f'Number of validation samples: {len(valid_dataset)}')
+    train_iterator = torch.utils.data.DataLoader(train_dataset,
+                                                 batch_size=BATCHSIZE,
+                                                 num_workers=4,
+                                                 shuffle=True,
+                                                 pin_memory=True,
+                                                 drop_last=False)
+
+    valid_iterator = torch.utils.data.DataLoader(valid_dataset,
+                                                 batch_size=BATCHSIZE,
+                                                 num_workers=4,
+                                                 shuffle=True,
+                                                 pin_memory=True,
+                                                 drop_last=False)
+    test_iterator = torch.utils.data.DataLoader(test_dataset,
+                                                batch_size=BATCHSIZE,
+                                                num_workers=4,
+                                                shuffle=False,
+                                                pin_memory=True,
+                                                drop_last=False)
+    print('DATA LOADED')
+
+    # Define model
+    model = get_model(args)
+    print('MODEL LOADED')
+
+    # Define optimizer and scheduler
+    START_LR = args.start_lr
+    optimizer = optim.Adam(model.parameters(), lr=START_LR)
+    STEPS_PER_EPOCH = len(train_iterator)
+    print('STEPS_PER_EPOCH:', STEPS_PER_EPOCH)
+    print('VALIDATION STEPS:', len(valid_iterator))
+    scheduler = lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=max(STEPS_PER_EPOCH, STEPS_PER_EPOCH // 10))
+
+    # Define loss function
+    criterion = nn.CrossEntropyLoss()
+
+    # Define device
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    model = model.to(device)
+    criterion = criterion.to(device)
+
+    EPOCHS = args.epochs
+
+    print('SETUP DONE')
+    # train our model
+
+    print('TRAINING STARTED')
+    for epoch in tqdm.tqdm(range(EPOCHS)):
+
+        train_loss, train_acc_1, train_acc_3 = train(model, train_iterator, optimizer, criterion, scheduler, device)
+
+        torch.cuda.empty_cache()  # clear cache between train and val
+
+        valid_loss, valid_acc_1, valid_acc_3 = evaluate(model, valid_iterator, criterion, device)
+
+        torch.save(model.state_dict(), checkpoint_dir / f'epoch{epoch+1}.pth')
+
+        logger.info(f'Epoch: {epoch + 1:02}')
+        logger.info(f'\tTrain Loss: {train_loss:.3f} | Train Acc @1: {train_acc_1 * 100:6.2f}% | ' \
+                    f'Train Acc @3: {train_acc_3 * 100:6.2f}%')
+        logger.info(f'\tValid Loss: {valid_loss:.3f} | Valid Acc @1: {valid_acc_1 * 100:6.2f}% | ' \
+                    f'Valid Acc @3: {valid_acc_3 * 100:6.2f}%')
+
+    logger.info('-------------------End of Training-------------------')
+    print('TRAINING DONE')
+    logger.info('-------------------Beginning of Testing-------------------')
+    print('TESTING STARTED')
+    for epoch in tqdm.tqdm(range(EPOCHS)):
+        model.load_state_dict(torch.load(checkpoint_dir / f'epoch{epoch+1}.pth'))
+
+        if args.vote == 'vote':
+            test_acc = evaluate_vote(model, test_iterator, device)
+            logger.info(f'Test Acc @1: {test_acc * 100:6.2f}%')
+        else:
+            test_loss, test_acc_1, test_acc_3 = evaluate(model, test_iterator, criterion, device)
+            logger.info(f'Test Acc @1: {test_acc_1 * 100:6.2f}% | ' \
+                        f'Test Acc @3: {test_acc_3 * 100:6.2f}%')
+    logger.info('-------------------End of Testing-------------------')
+    print('TESTING DONE')
+
+
+if __name__ == '__main__':
+    args = get_args()
+    main(args)
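The training entry point above is configured entirely through the argparse flags in utils/arg_utils.py. As a minimal sketch, it can also be driven programmatically; the flag values below mirror the defaults in arg_utils.py and assume the LUA_Dataset folder and data_utils package are present (neither is part of this commit's diff):

import argparse
import dl_supervised_pipeline

# Hypothetical namespace equivalent to what get_args() would parse.
args = argparse.Namespace(
    resolution='256', magnification='20x', modality='texture',
    model='ResNet50', pretrained='pretrained', frozen='unfrozen',
    vote='vote', epochs=2, batch_size=100, start_lr=0.01, seed=1234)
dl_supervised_pipeline.main(args)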
run_gradio.py
ADDED
@@ -0,0 +1,55 @@
+import gradio as gr
+import torch
+import torchvision
+from utils.experiment_utils import get_model
+
+
+# Load the DINOv2 model
+def load_model():
+    class Args:
+        model = 'DINOv2'
+        pretrained = 'pretrained'
+        frozen = 'unfrozen'
+
+    args = Args()
+    model = get_model(args)
+    model.eval()
+    return model
+
+
+model = load_model()
+
+
+# Prediction function that returns the probability of each class
+def predict(image):
+    transform = torchvision.transforms.Compose([
+        torchvision.transforms.Resize((224, 224)),
+        torchvision.transforms.ToTensor(),
+        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+    ])
+
+    image = transform(image).unsqueeze(0)
+    with torch.no_grad():
+        output = model(image)
+        probabilities = torch.nn.functional.softmax(output, dim=1).squeeze().tolist()
+
+    # List of class names
+    class_names = ["ANTLER", "BEECHWOOD", "BEFOREUSE", "BONE", "IVORY", "SPRUCEWOOD"]
+
+    # Pair each class with its corresponding probability
+    results = {class_names[i]: probabilities[i] for i in range(len(class_names))}
+
+    return results
+
+
+# Create the Gradio interface
+interface = gr.Interface(
+    fn=predict,
+    inputs=gr.Image(type="pil"),
+    outputs=gr.Label(num_top_classes=len(["ANTLER", "BEECHWOOD", "BEFOREUSE", "BONE", "IVORY", "SPRUCEWOOD"])),
+    title="LUWA DINOv2 Prediction",
+    description="Upload an image to get the probabilities for each class using the DINOv2 model."
+)
+
+if __name__ == "__main__":
+    interface.launch(share=True)
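The predict function can be smoke-tested without launching the web UI. A minimal sketch, assuming the DINOv2 weights download successfully through torch.hub at import time; the random test image is purely illustrative:

import numpy as np
from PIL import Image
from run_gradio import predict  # importing run_gradio also builds the model

# Hypothetical input: a random RGB image standing in for a use-wear photo.
dummy = Image.fromarray(np.random.randint(0, 255, (256, 256, 3), dtype=np.uint8))
probs = predict(dummy)
print(sorted(probs.items(), key=lambda kv: -kv[1]))  # classes ranked by probability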
svm_pipeline.py
ADDED
@@ -0,0 +1,100 @@
+import numpy as np
+from sklearn.svm import LinearSVC
+
+from skimage.feature import fisher_vector, learn_gmm
+
+import random
+import os
+from pathlib import Path
+
+from data_utils.data_tribology import TribologyDataset
+from utils.arg_utils import get_args
+from utils.experiment_utils import get_name, get_logger, SIFT_extraction, conduct_voting
+from utils.visualization_utils import plot_confusion_matrix
+from vis_confusion_mtx import generate_confusion_matrix
+
+def main(args):
+    '''Reproducibility'''
+    SEED = args.seed
+    random.seed(SEED)
+    np.random.seed(SEED)
+
+    '''Folder Creation'''
+    basepath = os.getcwd()
+    experiment_dir = Path(os.path.join(basepath, 'experiments', args.model, args.resolution, args.magnification, args.modality, args.vote))
+    experiment_dir.mkdir(parents=True, exist_ok=True)
+    checkpoint_dir = Path(os.path.join(experiment_dir, 'checkpoints'))
+    checkpoint_dir.mkdir(parents=True, exist_ok=True)
+
+    '''Logging'''
+    model_name = get_name(args)
+    print(model_name, 'STARTED', flush=True)
+    logger = get_logger(experiment_dir, model_name)
+
+    '''Data Loading'''
+    train_csv_path = f"./LUA_Dataset/CSV/{args.resolution}_{args.magnification}_6w_train.csv"
+    test_csv_path = f"./LUA_Dataset/CSV/{args.resolution}_{args.magnification}_6w_test.csv"
+    img_path = f"./LUA_Dataset/{args.resolution}/{args.magnification}/{args.modality}"
+
+    BATCHSIZE = args.batch_size
+    train_dataset = TribologyDataset(csv_path=train_csv_path, img_path=img_path)
+    test_dataset = TribologyDataset(csv_path=test_csv_path, img_path=img_path)
+
+    # prepare the data augmentation
+    means, stds = train_dataset.get_statistics()
+    train_dataset.prepare_transform(means, stds, mode='train')
+    test_dataset.prepare_transform(means, stds, mode='test')
+
+    VALID_RATIO = 0.1
+
+    num_train = len(train_dataset)
+    num_valid = int(VALID_RATIO * num_train)
+    # train_dataset, valid_dataset = data.random_split(train_dataset, [num_train - num_valid, num_valid])
+    # logger.info(f'Number of training samples: {len(train_dataset)}')
+    # logger.info(f'Number of validation samples: {len(valid_dataset)}')
+
+    train_names, train_descriptor, train_labels = SIFT_extraction(train_dataset)
+    test_names, test_descriptor, test_labels = SIFT_extraction(test_dataset)
+    # val_descriptor, val_labels = SIFT_extraction(valid_dataset)
+    print('DATA LOADED', flush=True)
+
+    print('TRAINING STARTED', flush=True)
+
+    # Train a K-mode GMM
+    k = 16
+    gmm = learn_gmm(train_descriptor, n_modes=k)
+
+    # Compute the Fisher vectors
+    training_fvs = np.array([
+        fisher_vector(descriptor_mat, gmm)
+        for descriptor_mat in train_descriptor
+    ])
+
+    testing_fvs = np.array([
+        fisher_vector(descriptor_mat, gmm)
+        for descriptor_mat in test_descriptor
+    ])
+
+    svm = LinearSVC().fit(training_fvs, train_labels)
+
+    logger.info('-------------------End of Training-------------------')
+    print('TRAINING DONE')
+    logger.info('-------------------Beginning of Testing-------------------')
+    print('TESTING STARTED')
+    predictions = svm.predict(testing_fvs)
+    conduct_voting(test_names, predictions)
+    plot_confusion_matrix('visualization_results/SIFT+FVs_confusion_mtx.png', predictions, test_labels, classes=["ANTLER", "BEECHWOOD", "BEFOREUSE", "BONE", "IVORY", "SPRUCEWOOD"])
+    correct = 0
+    for i in range(len(predictions)):
+        if predictions[i] == test_labels[i]:
+            correct += 1
+    test_acc = float(correct) / len(predictions)
+    logger.info(f'Test Acc @1: {test_acc * 100:6.2f}%')
+
+    logger.info('-------------------End of Testing-------------------')
+    print('TESTING DONE')
+
+if __name__ == '__main__':
+    args = get_args()
+    main(args)
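The GMM + Fisher-vector step above turns a variable number of local SIFT descriptors per image into one fixed-length feature the LinearSVC can consume. A self-contained sketch with synthetic descriptors (the shapes are assumptions; learn_gmm and fisher_vector ship with recent scikit-image):

import numpy as np
from skimage.feature import fisher_vector, learn_gmm

rng = np.random.default_rng(0)
# Stand-ins for per-image SIFT descriptors: varying counts, 128-D each.
descriptors = [rng.normal(size=(n, 128)) for n in (40, 55, 32)]

gmm = learn_gmm(descriptors, n_modes=16)
fvs = np.array([fisher_vector(d, gmm) for d in descriptors])
# Each Fisher vector has length K*(2*D + 1) = 16*(2*128 + 1) = 4112,
# no matter how many descriptors the image produced.
print(fvs.shape)  # (3, 4112)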
utils/MAE.py
ADDED
@@ -0,0 +1,253 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
+# DeiT: https://github.com/facebookresearch/deit
+# --------------------------------------------------------
+
+from functools import partial
+
+import torch
+import torch.nn as nn
+
+from timm.models.vision_transformer import PatchEmbed, Block
+
+from utils.model_utils import get_2d_sincos_pos_embed
+
+
+class MaskedAutoencoderViT(nn.Module):
+    """ Masked Autoencoder with VisionTransformer backbone
+    """
+    def __init__(self, img_size=224, patch_size=16, in_chans=3,
+                 embed_dim=1024, depth=24, num_heads=16,
+                 decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
+                 mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False):
+        super().__init__()
+
+        # --------------------------------------------------------------------------
+        # MAE encoder specifics
+        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
+        num_patches = self.patch_embed.num_patches
+
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False)  # fixed sin-cos embedding
+
+        self.blocks = nn.ModuleList([
+            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
+            for i in range(depth)])
+        self.norm = norm_layer(embed_dim)
+        # --------------------------------------------------------------------------
+
+        # --------------------------------------------------------------------------
+        # MAE decoder specifics
+        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
+
+        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
+
+        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False)  # fixed sin-cos embedding
+
+        self.decoder_blocks = nn.ModuleList([
+            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
+            for i in range(decoder_depth)])
+
+        self.decoder_norm = norm_layer(decoder_embed_dim)
+        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True)  # decoder to patch
+        # --------------------------------------------------------------------------
+
+        self.norm_pix_loss = norm_pix_loss
+
+        self.initialize_weights()
+
+    def initialize_weights(self):
+        # initialization
+        # initialize (and freeze) pos_embed by sin-cos embedding
+        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
+        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
+
+        decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
+        self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))
+
+        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
+        w = self.patch_embed.proj.weight.data
+        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+
+        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
+        torch.nn.init.normal_(self.cls_token, std=.02)
+        torch.nn.init.normal_(self.mask_token, std=.02)
+
+        # initialize nn.Linear and nn.LayerNorm
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            # we use xavier_uniform following official JAX ViT:
+            torch.nn.init.xavier_uniform_(m.weight)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, nn.LayerNorm):
+            nn.init.constant_(m.bias, 0)
+            nn.init.constant_(m.weight, 1.0)
+
+    def patchify(self, imgs):
+        """
+        imgs: (N, 3, H, W)
+        x: (N, L, patch_size**2 *3)
+        """
+        p = self.patch_embed.patch_size[0]
+        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
+
+        h = w = imgs.shape[2] // p
+        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
+        x = torch.einsum('nchpwq->nhwpqc', x)
+        x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
+        return x
+
+    def unpatchify(self, x):
+        """
+        x: (N, L, patch_size**2 *3)
+        imgs: (N, 3, H, W)
+        """
+        p = self.patch_embed.patch_size[0]
+        h = w = int(x.shape[1]**.5)
+        assert h * w == x.shape[1]
+
+        x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
+        x = torch.einsum('nhwpqc->nchpwq', x)
+        imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
+        return imgs
+
+    def random_masking(self, x, mask_ratio):
+        """
+        Perform per-sample random masking by per-sample shuffling.
+        Per-sample shuffling is done by argsort random noise.
+        x: [N, L, D], sequence
+        """
+        N, L, D = x.shape  # batch, length, dim
+        len_keep = int(L * (1 - mask_ratio))
+
+        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]
+
+        # sort noise for each sample
+        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
+        ids_restore = torch.argsort(ids_shuffle, dim=1)
+
+        # keep the first subset
+        ids_keep = ids_shuffle[:, :len_keep]
+        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
+
+        # generate the binary mask: 0 is keep, 1 is remove
+        mask = torch.ones([N, L], device=x.device)
+        mask[:, :len_keep] = 0
+        # unshuffle to get the binary mask
+        mask = torch.gather(mask, dim=1, index=ids_restore)
+
+        return x_masked, mask, ids_restore
+
+    def forward_encoder(self, x, mask_ratio):
+        # embed patches
+        x = self.patch_embed(x)
+
+        # add pos embed w/o cls token
+        x = x + self.pos_embed[:, 1:, :]
+
+        # masking: length -> length * mask_ratio
+        x, mask, ids_restore = self.random_masking(x, mask_ratio)
+
+        # append cls token
+        cls_token = self.cls_token + self.pos_embed[:, :1, :]
+        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
+        x = torch.cat((cls_tokens, x), dim=1)
+
+        # apply Transformer blocks
+        for blk in self.blocks:
+            x = blk(x)
+        x = self.norm(x)
+
+        return x, mask, ids_restore
+
+    def forward_decoder(self, x, ids_restore):
+        # embed tokens
+        x = self.decoder_embed(x)
+
+        # append mask tokens to sequence
+        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
+        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
+        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle
+        x = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token
+
+        # add pos embed
+        x = x + self.decoder_pos_embed
+
+        # apply Transformer blocks
+        for blk in self.decoder_blocks:
+            x = blk(x)
+        x = self.decoder_norm(x)
+
+        # predictor projection
+        x = self.decoder_pred(x)
+
+        # remove cls token
+        x = x[:, 1:, :]
+
+        return x
+
+    def forward_loss(self, imgs, pred, mask):
+        """
+        imgs: [N, 3, H, W]
+        pred: [N, L, p*p*3]
+        mask: [N, L], 0 is keep, 1 is remove,
+        """
+        target = self.patchify(imgs)
+        if self.norm_pix_loss:
+            mean = target.mean(dim=-1, keepdim=True)
+            var = target.var(dim=-1, keepdim=True)
+            target = (target - mean) / (var + 1.e-6)**.5
+
+        loss = (pred - target) ** 2
+        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch
+
+        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
+        return loss
+
+    def forward(self, imgs, mask_ratio=0.75):
+        latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
+        # pred = self.forward_decoder(latent, ids_restore)  # [N, L, p*p*3]
+        # loss = self.forward_loss(imgs, pred, mask)
+        # return loss, pred, mask
+        print(latent.shape)
+        return latent
+
+
+def mae_vit_base_patch16_dec512d8b(**kwargs):
+    model = MaskedAutoencoderViT(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12,
+        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
+        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+    return model
+
+
+def mae_vit_large_patch16_dec512d8b(**kwargs):
+    model = MaskedAutoencoderViT(
+        patch_size=16, embed_dim=1024, depth=24, num_heads=16,
+        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
+        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+    return model
+
+
+def mae_vit_huge_patch14_dec512d8b(**kwargs):
+    model = MaskedAutoencoderViT(
+        patch_size=14, embed_dim=1280, depth=32, num_heads=16,
+        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
+        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+    return model
+
+
+# set recommended archs
+mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b  # decoder: 512 dim, 8 blocks
+mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b  # decoder: 512 dim, 8 blocks
+mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b  # decoder: 512 dim, 8 blocks
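Note that forward() as modified here returns only the encoder latent instead of the upstream (loss, pred, mask) triple, which is what lets get_model wrap the MAE with a linear head. A minimal shape check, assuming timm is installed (the batch of random images is purely illustrative):

import torch
from utils.MAE import mae_vit_large_patch16

model = mae_vit_large_patch16()
model.eval()
with torch.no_grad():
    latent = model(torch.randn(2, 3, 224, 224))
# 196 patches, 75% masked -> 49 kept, plus the cls token = 50 tokens of dim 1024.
print(latent.shape)  # torch.Size([2, 50, 1024])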
utils/__init__.py
ADDED
@@ -0,0 +1 @@
+from .util_function import epoch_time, plot_lr_finder, plot_confusion_matrix, plot_most_incorrect, get_pca, plot_representations, plot_filtered_images, plot_filters
utils/__pycache__/MAE.cpython-311.pyc
ADDED
Binary file (14 kB).

utils/__pycache__/MAE.cpython-38.pyc
ADDED
Binary file (7.16 kB).

utils/__pycache__/MAE.cpython-38.pyc:Zone.Identifier
ADDED
@@ -0,0 +1,3 @@
+[ZoneTransfer]
+ZoneId=3
+ReferrerUrl=C:\Users\Daniel\Desktop\LUWA-main.zip

utils/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (422 Bytes).

utils/__pycache__/__init__.cpython-310.pyc:Zone.Identifier
ADDED
@@ -0,0 +1,3 @@
+[ZoneTransfer]
+ZoneId=3
+ReferrerUrl=C:\Users\Daniel\Desktop\LUWA-main.zip

utils/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (516 Bytes).

utils/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (376 Bytes).

utils/__pycache__/__init__.cpython-38.pyc:Zone.Identifier
ADDED
@@ -0,0 +1,3 @@
+[ZoneTransfer]
+ZoneId=3
+ReferrerUrl=C:\Users\Daniel\Desktop\LUWA-main.zip

utils/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (392 Bytes).

utils/__pycache__/__init__.cpython-39.pyc:Zone.Identifier
ADDED
@@ -0,0 +1,3 @@
+[ZoneTransfer]
+ZoneId=3
+ReferrerUrl=C:\Users\Daniel\Desktop\LUWA-main.zip

utils/__pycache__/arg_utils.cpython-38.pyc
ADDED
Binary file (1.03 kB).

utils/__pycache__/arg_utils.cpython-38.pyc:Zone.Identifier
ADDED
@@ -0,0 +1,3 @@
+[ZoneTransfer]
+ZoneId=3
+ReferrerUrl=C:\Users\Daniel\Desktop\LUWA-main.zip

utils/__pycache__/arg_utils.cpython-39.pyc
ADDED
Binary file (1.05 kB).

utils/__pycache__/arg_utils.cpython-39.pyc:Zone.Identifier
ADDED
@@ -0,0 +1,3 @@
+[ZoneTransfer]
+ZoneId=3
+ReferrerUrl=C:\Users\Daniel\Desktop\LUWA-main.zip

utils/__pycache__/experiment_utils.cpython-311.pyc
ADDED
Binary file (12.9 kB).

utils/__pycache__/experiment_utils.cpython-38.pyc
ADDED
Binary file (5.71 kB).

utils/__pycache__/experiment_utils.cpython-38.pyc:Zone.Identifier
ADDED
@@ -0,0 +1,3 @@
+[ZoneTransfer]
+ZoneId=3
+ReferrerUrl=C:\Users\Daniel\Desktop\LUWA-main.zip

utils/__pycache__/experiment_utils.cpython-39.pyc
ADDED
Binary file (4.93 kB).

utils/__pycache__/experiment_utils.cpython-39.pyc:Zone.Identifier
ADDED
@@ -0,0 +1,3 @@
+[ZoneTransfer]
+ZoneId=3
+ReferrerUrl=C:\Users\Daniel\Desktop\LUWA-main.zip

utils/__pycache__/model_utils.cpython-311.pyc
ADDED
Binary file (4.24 kB).

utils/__pycache__/model_utils.cpython-38.pyc
ADDED
Binary file (2.4 kB).

utils/__pycache__/model_utils.cpython-38.pyc:Zone.Identifier
ADDED
@@ -0,0 +1,3 @@
+[ZoneTransfer]
+ZoneId=3
+ReferrerUrl=C:\Users\Daniel\Desktop\LUWA-main.zip

utils/__pycache__/util_function.cpython-310.pyc
ADDED
Binary file (5.35 kB).

utils/__pycache__/util_function.cpython-310.pyc:Zone.Identifier
ADDED
@@ -0,0 +1,3 @@
+[ZoneTransfer]
+ZoneId=3
+ReferrerUrl=C:\Users\Daniel\Desktop\LUWA-main.zip

utils/__pycache__/util_function.cpython-311.pyc
ADDED
Binary file (14.6 kB).

utils/__pycache__/util_function.cpython-38.pyc
ADDED
Binary file (6.8 kB).

utils/__pycache__/util_function.cpython-38.pyc:Zone.Identifier
ADDED
@@ -0,0 +1,3 @@
+[ZoneTransfer]
+ZoneId=3
+ReferrerUrl=C:\Users\Daniel\Desktop\LUWA-main.zip

utils/__pycache__/util_function.cpython-39.pyc
ADDED
Binary file (6.82 kB).

utils/__pycache__/util_function.cpython-39.pyc:Zone.Identifier
ADDED
@@ -0,0 +1,3 @@
+[ZoneTransfer]
+ZoneId=3
+ReferrerUrl=C:\Users\Daniel\Desktop\LUWA-main.zip
utils/arg_utils.py
ADDED
@@ -0,0 +1,18 @@
+import argparse
+
+def get_args():
+    # Training settings
+    parser = argparse.ArgumentParser('train')
+
+    parser.add_argument('--resolution', type=str, default='256', help='Resolution of input image')
+    parser.add_argument('--magnification', type=str, default='20x', help='Magnification of input image')
+    parser.add_argument('--modality', type=str, default='texture', help='Modality of input image')
+    parser.add_argument('--model', type=str, default='ResNet50', help='Model to use')
+    parser.add_argument('--pretrained', type=str, default='pretrained', help='Use pretrained model')
+    parser.add_argument('--frozen', type=str, default='unfrozen', help='Freeze pretrained model')
+    parser.add_argument('--vote', type=str, default='vote', help='Conduct voting')
+    parser.add_argument('--epochs', type=int, default=2, help='Number of epochs to train')
+    parser.add_argument('--batch_size', type=int, default=100, help='Batch size')
+    parser.add_argument('--start_lr', type=float, default=0.01, help='Learning rate')
+    parser.add_argument('--seed', type=int, default=1234, help='Random seed')
+    return parser.parse_args()
utils/experiment_utils.py
ADDED
@@ -0,0 +1,298 @@
+import torch
+import torchvision
+import torch.nn as nn
+import torch.nn.functional as F
+import logging
+from collections import Counter
+from utils.MAE import mae_vit_large_patch16_dec512d8b as MAE_large
+
+def get_model(args) -> nn.Module:
+    if 'ResNet' in args.model:
+        # resnet family
+        if args.model == 'ResNet50':
+            if args.pretrained == 'pretrained':
+                model = torchvision.models.resnet50(weights='IMAGENET1K_V2')
+            else:
+                model = torchvision.models.resnet50()
+        elif args.model == 'ResNet152':
+            if args.pretrained == 'pretrained':
+                model = torchvision.models.resnet152(weights='IMAGENET1K_V2')
+            else:
+                model = torchvision.models.resnet152()
+        else:
+            raise NotImplementedError
+        if args.frozen == 'frozen':
+            model = freeze_backbone(model)
+        model.fc = nn.Linear(model.fc.in_features, 6)
+
+    elif 'ConvNext' in args.model:
+        if args.model == 'ConvNext_Tiny':
+            if args.pretrained == 'pretrained':
+                model = torchvision.models.convnext_tiny(weights='IMAGENET1K_V1')
+            else:
+                model = torchvision.models.convnext_tiny()
+        elif args.model == 'ConvNext_Large':
+            if args.pretrained == 'pretrained':
+                model = torchvision.models.convnext_large(weights='IMAGENET1K_V1')
+            else:
+                model = torchvision.models.convnext_large()
+        else:
+            raise NotImplementedError
+        if args.frozen == 'frozen':
+            model = freeze_backbone(model)
+        num_ftrs = model.classifier[2].in_features
+        model.classifier[2] = nn.Linear(int(num_ftrs), 6)
+
+    elif 'ViT' in args.model:
+        if args.pretrained == 'pretrained':
+            model = torchvision.models.vit_h_14(weights='IMAGENET1K_SWAG_LINEAR_V1')
+        else:
+            raise NotImplementedError('ViT does not support training from scratch')
+        if args.frozen == 'frozen':
+            model = freeze_backbone(model)
+        model.heads[0] = torch.nn.Linear(model.heads[0].in_features, 6)
+
+    elif 'DINOv2' in args.model:
+        if args.pretrained == 'pretrained':
+            model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14_reg_lc')
+        else:
+            raise NotImplementedError('DINOv2 does not support training from scratch')
+        if args.frozen == 'frozen':
+            model = freeze_backbone(model)
+        model.linear_head = torch.nn.Linear(model.linear_head.in_features, 6)
+
+    elif 'MAE' in args.model:
+        if args.pretrained == 'pretrained':
+            model = MAE_large()
+            model.load_state_dict(torch.load('/scratch/zf540/LUWA/workspace/utils/pretrained_weights/mae_visualize_vit_large.pth')['model'])
+        else:
+            raise NotImplementedError('MAE does not support training from scratch')
+        if args.frozen == 'frozen':
+            model = freeze_backbone(model)
+        model = nn.Sequential(model, nn.Linear(1024, 6))
+        print(model)
+    else:
+        raise NotImplementedError
+    return model
+
+
+def freeze_backbone(model):
+    # freeze backbone
+    # we will replace the classifier at the end with a trainable one anyway, so we freeze the default here as well
+    for param in model.parameters():
+        param.requires_grad = False
+    return model
+
+def get_name(args):
+    name = args.model
+    name += '_' + str(args.resolution)
+    name += '_' + args.magnification
+    name += '_' + args.modality
+    if args.pretrained == 'pretrained':
+        name += '_pretrained'
+    else:
+        name += '_scratch'
+    if args.frozen == 'frozen':
+        name += '_frozen'
+    else:
+        name += '_unfrozen'
+    if args.vote == 'vote':
+        name += '_vote'
+    else:
+        name += '_novote'
+    return name
+
+def get_logger(path, name):
+    # set up logger
+
+    logger = logging.getLogger(name)
+    logger.setLevel(logging.INFO)
+    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+    file_handler = logging.FileHandler(path.joinpath(f'{name}_log.txt'))
+    file_handler.setLevel(logging.INFO)
+    file_handler.setFormatter(formatter)
+    logger.addHandler(file_handler)
+    logger.info('---------------------------------------------------TRAINING---------------------------------------------------')
+
+    return logger
+
+def calculate_topk_accuracy(y_pred, y, k=3):
+    with torch.no_grad():
+        batch_size = y.shape[0]
+        _, top_pred = y_pred.topk(k, 1)
+        top_pred = top_pred.t()
+        correct = top_pred.eq(y.view(1, -1).expand_as(top_pred))
+        correct_1 = correct[:1].reshape(-1).float().sum(0, keepdim=True)
+        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
+        acc_1 = correct_1 / batch_size
+        acc_k = correct_k / batch_size
+    return acc_1, acc_k
+
+def train(model, iterator, optimizer, criterion, scheduler, device):
+    epoch_loss = 0
+    epoch_acc_1 = 0
+    epoch_acc_3 = 0
+
+    model.train()
+
+    for image, label, image_name in iterator:
+        x = image.to(device)
+        y = label.to(device)
+
+        optimizer.zero_grad()
+
+        y_pred = model(x)
+        print(y_pred.shape)  # debug: batch output shape
+        print(y.shape)  # debug: batch label shape
+        loss = criterion(y_pred, y)
+
+        acc_1, acc_3 = calculate_topk_accuracy(y_pred, y)
+
+        loss.backward()
+
+        optimizer.step()
+
+        scheduler.step()
+
+        epoch_loss += loss.item()
+        epoch_acc_1 += acc_1.item()
+        epoch_acc_3 += acc_3.item()
+
+    epoch_loss /= len(iterator)
+    epoch_acc_1 /= len(iterator)
+    epoch_acc_3 /= len(iterator)
+
+    return epoch_loss, epoch_acc_1, epoch_acc_3
+
+
+def evaluate(model, iterator, criterion, device):
+    epoch_loss = 0
+    epoch_acc_1 = 0
+    epoch_acc_3 = 0
+
+    model.eval()
+
+    with torch.no_grad():
+        for image, label, image_name in iterator:
+            x = image.to(device)
+            y = label.to(device)
+
+            y_pred = model(x)
+            loss = criterion(y_pred, y)
+
+            acc_1, acc_3 = calculate_topk_accuracy(y_pred, y)
+
+            epoch_loss += loss.item()
+            epoch_acc_1 += acc_1.item()
+            epoch_acc_3 += acc_3.item()
+
+    epoch_loss /= len(iterator)
+    epoch_acc_1 /= len(iterator)
+    epoch_acc_3 /= len(iterator)
+
+    return epoch_loss, epoch_acc_1, epoch_acc_3
+
+def evaluate_vote(model, iterator, device):
+
+    model.eval()
+
+    image_names = []
+    labels = []
+    predictions = []
+
+    with torch.no_grad():
+
+        for image, label, image_name in iterator:
+
+            x = image.to(device)
+
+            y_pred = model(x)
+            y_prob = F.softmax(y_pred, dim=-1)
+            top_pred = y_prob.argmax(1, keepdim=True)
+
+            image_names.extend(image_name)
+            labels.extend(label.numpy())
+            predictions.extend(top_pred.cpu().squeeze().numpy())
+
+    conduct_voting(image_names, predictions)
+
+    correct_count = 0
+    for i in range(len(labels)):
+        if labels[i] == predictions[i]:
+            correct_count += 1
+    accuracy = correct_count / len(labels)
+    return accuracy
+
+def conduct_voting(image_names, predictions):
+    # we need to do this because not all stones have the same number of partitions
+    last_stone = image_names[0][:-8]  # the name of the stone of the last image
+    voting_list = []
+    for i in range(len(image_names)):
+        image_area_name = image_names[i][:-8]
+        if image_area_name != last_stone:
+            # we have run through all the images of the last stone. We start voting
+            vote(voting_list, predictions, i)
+            voting_list = []  # reset the voting list
+        voting_list.append(predictions[i])
+        last_stone = image_area_name  # update the last stone name
+
+    # vote for the last stone
+    vote(voting_list, predictions, len(image_names))
+
+def vote(voting_list, predictions, i):
+    vote_result = Counter(voting_list).most_common(1)[0][0]  # the most common prediction in the list
+    predictions[i-len(voting_list):i] = [vote_result]*len(voting_list)  # replace the predictions of the last stone with the vote result
+
+
+# def get_predictions(model, iterator):
+#     model.eval()
+#     images = []
+#     labels = []
+#     probs = []
+#     with torch.no_grad():
+#         for (x, y) in iterator:
+#             x = x.to(device)
+#             y_pred = model(x)
+#             y_prob = F.softmax(y_pred, dim = -1)
+#             top_pred = y_prob.argmax(1, keepdim = True)
+#             images.append(x.cpu())
+#             labels.append(y.cpu())
+#             probs.append(y_prob.cpu())
+#     images = torch.cat(images, dim = 0)
+#     labels = torch.cat(labels, dim = 0)
+#     probs = torch.cat(probs, dim = 0)
+#     return images, labels, probs
+
+
+# def get_representations(model, iterator):
+#     model.eval()
+#     outputs = []
+#     intermediates = []
+#     labels = []
+#     with torch.no_grad():
+#         for (x, y) in iterator:
+#             x = x.to(device)
+#             y_pred = model(x)
+#             outputs.append(y_pred.cpu())
+#             labels.append(y)
+#     outputs = torch.cat(outputs, dim=0)
+#     labels = torch.cat(labels, dim=0)
+#     return outputs, labels
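conduct_voting groups crops by stripping the last 8 characters of each file name (assumed to be a per-crop suffix such as '_000.png') and overwrites every prediction in a group, in place, with the group's most common label. A toy illustration with hypothetical file names:

names = ['stoneA_000.png', 'stoneA_001.png', 'stoneA_002.png',
         'stoneB_000.png', 'stoneB_001.png']
preds = [2, 2, 5, 1, 3]
conduct_voting(names, preds)
print(preds)  # [2, 2, 2, 1, 1] -- per-stone majority; ties keep the first value seen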
utils/model_utils.py
ADDED
@@ -0,0 +1,96 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# Position embedding utils
+# --------------------------------------------------------
+
+import numpy as np
+
+import torch
+
+# --------------------------------------------------------
+# 2D sine-cosine position embedding
+# References:
+# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
+# MoCo v3: https://github.com/facebookresearch/moco-v3
+# --------------------------------------------------------
+def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
+    """
+    grid_size: int of the grid height and width
+    return:
+    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+    """
+    grid_h = np.arange(grid_size, dtype=np.float32)
+    grid_w = np.arange(grid_size, dtype=np.float32)
+    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
+    grid = np.stack(grid, axis=0)
+
+    grid = grid.reshape([2, 1, grid_size, grid_size])
+    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+    if cls_token:
+        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
+    return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+    assert embed_dim % 2 == 0
+
+    # use half of dimensions to encode grid_h
+    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
+    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
+
+    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
+    return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+    """
+    embed_dim: output dimension for each position
+    pos: a list of positions to be encoded: size (M,)
+    out: (M, D)
+    """
+    assert embed_dim % 2 == 0
+    omega = np.arange(embed_dim // 2, dtype=float)
+    omega /= embed_dim / 2.
+    omega = 1. / 10000**omega  # (D/2,)
+
+    pos = pos.reshape(-1)  # (M,)
+    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product
+
+    emb_sin = np.sin(out)  # (M, D/2)
+    emb_cos = np.cos(out)  # (M, D/2)
+
+    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
+    return emb
+
+
+# --------------------------------------------------------
+# Interpolate position embeddings for high-resolution
+# References:
+# DeiT: https://github.com/facebookresearch/deit
+# --------------------------------------------------------
+def interpolate_pos_embed(model, checkpoint_model):
+    if 'pos_embed' in checkpoint_model:
+        pos_embed_checkpoint = checkpoint_model['pos_embed']
+        embedding_size = pos_embed_checkpoint.shape[-1]
+        num_patches = model.patch_embed.num_patches
+        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
+        # height (== width) for the checkpoint position embedding
+        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
+        # height (== width) for the new position embedding
+        new_size = int(num_patches ** 0.5)
+        # class_token and dist_token are kept unchanged
+        if orig_size != new_size:
+            print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
+            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+            # only the position tokens are interpolated
+            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
+            pos_tokens = torch.nn.functional.interpolate(
+                pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
+            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+            checkpoint_model['pos_embed'] = new_pos_embed
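A quick shape check of the embedding helper above; the numbers follow directly from the code (a 14x14 patch grid with a zero row prepended for the cls token):

import numpy as np
from utils.model_utils import get_2d_sincos_pos_embed

pos = get_2d_sincos_pos_embed(embed_dim=1024, grid_size=14, cls_token=True)
print(pos.shape)               # (197, 1024): 1 cls slot + 14*14 grid positions
print(np.allclose(pos[0], 0))  # True: the cls slot is left as zeros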
utils/util_function.py
ADDED
@@ -0,0 +1,238 @@
1 |
+
import cv2
|
2 |
+
from sklearn.manifold import TSNE
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from sklearn.metrics import confusion_matrix
|
8 |
+
from sklearn.metrics import ConfusionMatrixDisplay
|
9 |
+
from sklearn import decomposition
|
10 |
+
import itertools
|
11 |
+
|
12 |
+
def normalize_image(image):
|
13 |
+
image_min = image.min()
|
14 |
+
image_max = image.max()
|
15 |
+
image.clamp_(min = image_min, max = image_max)
|
16 |
+
image.add_(-image_min).div_(image_max - image_min + 1e-5)
|
17 |
+
return image
|
18 |
+
|
19 |
+
|
20 |
+
def plot_lr_finder(fig_name, lrs, losses, skip_start=5, skip_end=5):
|
21 |
+
if skip_end == 0:
|
22 |
+
lrs = lrs[skip_start:]
|
23 |
+
losses = losses[skip_start:]
|
24 |
+
else:
|
25 |
+
lrs = lrs[skip_start:-skip_end]
|
26 |
+
losses = losses[skip_start:-skip_end]
|
27 |
+
|
28 |
+
fig = plt.figure(figsize=(16, 8))
|
29 |
+
ax = fig.add_subplot(1, 1, 1)
|
30 |
+
ax.plot(lrs, losses)
|
31 |
+
ax.set_xscale('log')
|
32 |
+
ax.set_xlabel('Learning rate')
|
33 |
+
ax.set_ylabel('Loss')
|
34 |
+
ax.grid(True, 'both', 'x')
|
35 |
+
plt.show()
|
36 |
+
plt.savefig(fig_name)
|
37 |
+
|
38 |
+
def epoch_time(start_time, end_time):
|
39 |
+
elapsed_time = end_time - start_time
|
40 |
+
elapsed_mins = int(elapsed_time / 60)
|
41 |
+
elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
|
42 |
+
return elapsed_mins, elapsed_secs
|
43 |
+
|
44 |
+
|
45 |
+
def plot_confusion_matrix(fig_name, labels, pred_labels, classes):
|
46 |
+
fig = plt.figure(figsize=(50, 50));
|
47 |
+
ax = fig.add_subplot(1, 1, 1);
|
48 |
+
cm = confusion_matrix(labels, pred_labels);
|
49 |
+
cm = ConfusionMatrixDisplay(cm, display_labels=classes);
|
50 |
+
cm.plot(values_format='d', cmap='Blues', ax=ax)
|
51 |
+
fig.delaxes(fig.axes[1]) # delete colorbar
|
52 |
+
plt.xticks(rotation=90, fontsize=50)
|
53 |
+
plt.yticks(fontsize=50)
|
54 |
+
plt.rcParams.update({'font.size': 50})
|
55 |
+
plt.xlabel('Predicted Label', fontsize=50)
|
56 |
+
plt.ylabel('True Label', fontsize=50)
|
57 |
+
plt.savefig(fig_name)
|
58 |
+
|
59 |
+
def plot_confusion_matrix_SVM(fig_name, true_labels, predicted_labels, classes):
|
60 |
+
    fig = plt.figure(figsize=(100, 100))
    ax = fig.add_subplot(1, 1, 1)

    cm = confusion_matrix(true_labels, predicted_labels)
    cm_display = ConfusionMatrixDisplay(cm, display_labels=classes)

    cm_display.plot(values_format='d', cmap='Blues', ax=ax)

    fig.delaxes(fig.axes[1])  # delete colorbar
    plt.xticks(rotation=90, fontsize=50)
    plt.yticks(fontsize=50)
    plt.rcParams.update({'font.size': 50})
    plt.xlabel('Predicted Label', fontsize=50)
    plt.ylabel('True Label', fontsize=50)
    plt.savefig(fig_name)


def plot_most_incorrect(fig_name, incorrect, classes, n_images, normalize=True):
    rows = int(np.sqrt(n_images))
    cols = int(np.sqrt(n_images))

    fig = plt.figure(figsize=(25, 20))

    for i in range(rows * cols):
        ax = fig.add_subplot(rows, cols, i + 1)

        image, true_label, probs = incorrect[i]
        image = image.permute(1, 2, 0)
        true_prob = probs[true_label]
        incorrect_prob, incorrect_label = torch.max(probs, dim=0)
        true_class = classes[true_label]
        incorrect_class = classes[incorrect_label]

        if normalize:
            image = normalize_image(image)

        ax.imshow(image.cpu().numpy())
        ax.set_title(f'true label: {true_class} ({true_prob:.3f})\n'
                     f'pred label: {incorrect_class} ({incorrect_prob:.3f})')
        ax.axis('off')

    fig.subplots_adjust(hspace=0.4)
    plt.savefig(fig_name)


def get_pca(data, n_components=2):
    pca = decomposition.PCA(n_components=n_components)
    pca_data = pca.fit_transform(data)
    return pca_data


def plot_representations(fig_name, data, labels, classes, n_images=None):
    if n_images is not None:
        data = data[:n_images]
        labels = labels[:n_images]

    fig = plt.figure(figsize=(15, 15))
    ax = fig.add_subplot(111)
    scatter = ax.scatter(data[:, 0], data[:, 1], c=labels, cmap='hsv')
    # handles, _ = scatter.legend_elements(num=None)
    # legend = plt.legend(handles=handles, labels=classes)
    plt.savefig(fig_name)


def plot_filtered_images(fig_name, images, filters, n_filters=None, normalize=True):
    images = torch.cat([i.unsqueeze(0) for i in images], dim=0).cpu()
    filters = filters.cpu()

    if n_filters is not None:
        filters = filters[:n_filters]

    n_images = images.shape[0]
    n_filters = filters.shape[0]

    filtered_images = F.conv2d(images, filters)

    fig = plt.figure(figsize=(30, 30))

    for i in range(n_images):
        image = images[i]

        if normalize:
            image = normalize_image(image)

        ax = fig.add_subplot(n_images, n_filters + 1, i + 1 + (i * n_filters))
        ax.imshow(image.permute(1, 2, 0).numpy())
        ax.set_title('Original')
        ax.axis('off')

        for j in range(n_filters):
            image = filtered_images[i][j]

            if normalize:
                image = normalize_image(image)

            ax = fig.add_subplot(n_images, n_filters + 1, i + 1 + (i * n_filters) + j + 1)
            ax.imshow(image.numpy(), cmap='bone')
            ax.set_title(f'Filter {j + 1}')
            ax.axis('off')

    fig.subplots_adjust(hspace=-0.7)
    plt.savefig(fig_name)


def plot_filters(fig_name, filters, normalize=True):
    filters = filters.cpu()

    n_filters = filters.shape[0]

    rows = int(np.sqrt(n_filters))
    cols = int(np.sqrt(n_filters))

    fig = plt.figure(figsize=(30, 15))

    for i in range(rows * cols):
        image = filters[i]

        if normalize:
            image = normalize_image(image)

        ax = fig.add_subplot(rows, cols, i + 1)
        ax.imshow(image.permute(1, 2, 0))
        ax.axis('off')

    fig.subplots_adjust(wspace=-0.9)
    plt.savefig(fig_name)


def plot_tsne(fig_name, all_features, all_labels):
    tsne = TSNE(n_components=2, random_state=42)
    tsne_results = tsne.fit_transform(all_features)
    plt.figure(figsize=(10, 7))
    scatter = plt.scatter(tsne_results[:, 0], tsne_results[:, 1], c=all_labels, cmap='viridis', s=5)
    plt.colorbar(scatter)
    plt.title('t-SNE Visualization')
    # Save before calling show(): show() may release the active figure,
    # which would leave the file written by savefig() blank.
    plt.savefig(fig_name)
    plt.show()


def plot_grad_cam(images, cams, predicted_labels, true_labels, classes, path):
    fig, axs = plt.subplots(nrows=2, ncols=len(images), figsize=(20, 10))

    for i, (img, cam, pred_label, true_label) in enumerate(zip(images, cams, predicted_labels, true_labels)):
        # Display the original image on the top row
        axs[0, i].imshow(img.permute(1, 2, 0).cpu().numpy())
        pred_class_name = classes[pred_label]
        true_class_name = classes[true_label]
        axs[0, i].set_title(f"Predicted: {pred_class_name}\nTrue: {true_class_name}", fontsize=12)
        axs[0, i].axis('off')

        # Add label to the leftmost plot
        if i == 0:
            axs[0, i].set_ylabel("Original Image", fontsize=14, rotation=90, labelpad=10)

        # Convert the original image to grayscale
        grayscale_img = cv2.cvtColor(img.permute(1, 2, 0).cpu().numpy(), cv2.COLOR_RGB2GRAY)
        grayscale_img = cv2.cvtColor(grayscale_img, cv2.COLOR_GRAY2RGB)

        # Overlay the Grad-CAM heatmap on the grayscale image
        heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
        heatmap = np.float32(heatmap) / 255
        cam_img = heatmap + np.float32(grayscale_img)
        cam_img = cam_img / np.max(cam_img)

        # Display the Grad-CAM image on the bottom row
        axs[1, i].imshow(cam_img)
        axs[1, i].axis('off')

        # Add label to the leftmost plot
        if i == 0:
            axs[1, i].set_ylabel("Grad-CAM", fontsize=14, rotation=90, labelpad=10)

    plt.tight_layout()
    plt.savefig(path)
    plt.close()
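A minimal usage sketch for the plotting helpers above, assuming they are importable from utils.visualization_utils (the path used by vis_confusion_mtx.py below). The file names and the random stand-in arrays are illustrative only, not repository code:

# Illustrative sketch, not part of the commit.
import numpy as np
from utils.visualization_utils import plot_confusion_matrix, plot_tsne

classes = ["ANTLER", "BEECHWOOD", "BEFOREUSE", "BONE", "IVORY", "SPRUCEWOOD"]
true_labels = np.random.randint(0, len(classes), size=200)       # stand-in ground truth
predicted_labels = np.random.randint(0, len(classes), size=200)  # stand-in predictions
features = np.random.randn(200, 512)                             # stand-in feature vectors

plot_confusion_matrix('demo_confusion_mtx.png', true_labels, predicted_labels, classes)
plot_tsne('demo_tsne.png', features, true_labels)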
vis_confusion_mtx.py
ADDED
@@ -0,0 +1,54 @@
import torch

import os
from pathlib import Path

from data_utils.data_tribology import TribologyDataset
from utils.experiment_utils import get_model, get_prediction
from utils.arg_utils import get_args
from utils.visualization_utils import plot_confusion_matrix


def generate_confusion_matrix(image_name, model, iterator, device):
    labels, predictions = get_prediction(model, iterator, device)
    plot_confusion_matrix('visualization_results/' + image_name + '_confusion_mtx.png', labels, predictions,
                          classes=["ANTLER", "BEECHWOOD", "BEFOREUSE", "BONE", "IVORY", "SPRUCEWOOD"])


def main(args):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = get_model(args)

    basepath = os.getcwd()
    experiment_dir = Path(os.path.join(basepath, 'experiments', args.model, args.resolution, args.magnification,
                                       args.modality, args.pretrained, args.frozen, args.vote))
    if args.model == 'ViT':
        experiment_dir = Path(os.path.join(basepath, 'experiments', 'ViT_H', args.resolution, args.magnification,
                                           args.modality, args.pretrained, args.frozen, args.vote))
    checkpoint_dir = Path(os.path.join(experiment_dir, 'checkpoints'))
    checkpoint_path = checkpoint_dir / f'epoch{args.epochs}.pth'
    model.load_state_dict(torch.load(checkpoint_path))
    model = model.to(device)

    train_csv_path = f"./LUA_Dataset/CSV/{args.resolution}_{args.magnification}_6w_train.csv"
    test_csv_path = f"./LUA_Dataset/CSV/{args.resolution}_{args.magnification}_6w_test.csv"
    img_path = f"./LUA_Dataset/{args.resolution}/{args.magnification}/{args.modality}"
    BATCHSIZE = args.batch_size
    train_dataset = TribologyDataset(csv_path=train_csv_path, img_path=img_path)
    test_dataset = TribologyDataset(csv_path=test_csv_path, img_path=img_path)

    # Normalization statistics come from the training split
    means, stds = train_dataset.get_statistics()
    train_dataset.prepare_transform(means, stds, mode='train')
    test_dataset.prepare_transform(means, stds, mode='test')

    test_iterator = torch.utils.data.DataLoader(test_dataset,
                                                batch_size=BATCHSIZE,
                                                num_workers=4,
                                                shuffle=False,
                                                pin_memory=True,
                                                drop_last=False)

    generate_confusion_matrix(args.model, model, test_iterator, device)


if __name__ == "__main__":
    args = get_args()
    main(args)
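For reference, a sketch of how this script might be invoked, assuming get_args() in utils/arg_utils.py (not shown here) exposes command-line flags named after the attributes read above (--model, --resolution, --magnification, --modality, --pretrained, --frozen, --vote, --epochs, --batch_size). Every value below is a placeholder, not a documented option:

python vis_confusion_mtx.py --model ResNet --resolution 256p --magnification 20x \
    --modality texture --pretrained pretrained --frozen unfrozen --vote novote \
    --epochs 50 --batch_size 32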
vote_analysis.py
ADDED
@@ -0,0 +1,107 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler

import torch.utils.data as data

import numpy as np
import random
import tqdm
import os
from pathlib import Path

from data_utils.data_tribology import TribologyDataset
from utils.experiment_utils import get_model, get_name, get_logger, train, evaluate, evaluate_vote, evaluate_vote_analysis
from utils.arg_utils import get_args


def main(args):
    '''Reproducibility'''
    SEED = args.seed
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    '''Folder Creation'''
    basepath = os.getcwd()
    experiment_dir = Path(os.path.join(basepath, 'experiments', args.model, args.resolution, args.magnification,
                                       args.modality, args.pretrained, args.frozen, args.vote))
    experiment_dir.mkdir(parents=True, exist_ok=True)
    checkpoint_dir = Path(os.path.join(experiment_dir, 'checkpoints'))
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    '''Logging'''
    model_name = get_name(args)
    print(model_name, 'STARTED')

    logger = get_logger(experiment_dir, 'vote_analysis')

    '''Data Loading'''
    train_csv_path = f"./LUA_Dataset/CSV/{args.resolution}_{args.magnification}_6w_train.csv"
    test_csv_path = f"./LUA_Dataset/CSV/{args.resolution}_{args.magnification}_6w_test.csv"
    img_path = f"./LUA_Dataset/{args.resolution}/{args.magnification}/{args.modality}"

    # results_acc_1 = {}
    # results_acc_3 = {}
    # classes_num = 6
    BATCHSIZE = args.batch_size
    train_dataset = TribologyDataset(csv_path=train_csv_path, img_path=img_path)
    test_dataset = TribologyDataset(csv_path=test_csv_path, img_path=img_path)

    # prepare the data augmentation
    means, stds = train_dataset.get_statistics()
    train_dataset.prepare_transform(means, stds, mode='train')
    test_dataset.prepare_transform(means, stds, mode='test')

    VALID_RATIO = 0.1

    num_train = len(train_dataset)
    num_valid = int(VALID_RATIO * num_train)
    train_dataset, valid_dataset = data.random_split(train_dataset, [num_train - num_valid, num_valid])
    logger.info(f'Number of training samples: {len(train_dataset)}')
    logger.info(f'Number of validation samples: {len(valid_dataset)}')

    test_iterator = torch.utils.data.DataLoader(test_dataset,
                                                batch_size=BATCHSIZE,
                                                num_workers=4,
                                                shuffle=False,
                                                pin_memory=True,
                                                drop_last=False)
    print('DATA LOADED')

    # Define model
    model = get_model(args)
    print('MODEL LOADED')

    # Define device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    print('SETUP DONE')

    # No training happens here: load the previously trained checkpoint
    print('LOADING CHECKPOINT')

    model.load_state_dict(torch.load(checkpoint_dir / f'epoch{args.epochs}.pth'))
    logger.info('-------------------Beginning of Testing-------------------')
    print('TESTING STARTED')

    vote_accuracy, correct_case_accuracy, incorrect_case_accuracy, incorrect_most_common, novote_accuracy = evaluate_vote_analysis(model, test_iterator, device)
    logger.info(f'Test Acc @1: {vote_accuracy * 100:6.2f}%')
    logger.info(f'No Vote Accuracy @1: {novote_accuracy * 100:6.2f}%')
    logger.info(f'Correct Case Consistency @1: {correct_case_accuracy * 100:6.2f}%')
    logger.info(f'Incorrect Case Consistency @1: {incorrect_case_accuracy * 100:6.2f}%')
    logger.info(f'Incorrect Most Common: {incorrect_most_common * 100:6.2f}%')

    logger.info('-------------------End of Testing-------------------')
    print('TESTING DONE')


if __name__ == '__main__':
    args = get_args()
    main(args)
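evaluate_vote_analysis itself is defined in utils/experiment_utils.py earlier in this commit and is not reproduced here. The patch-level majority vote it reports on can be sketched in isolation as follows — an illustrative reimplementation under assumed semantics, not the repository's code:

# Illustrative sketch, not part of the commit.
from collections import Counter

def majority_vote(patch_predictions):
    # Collapse the per-patch predictions of one sample into a single label;
    # Counter.most_common breaks ties by first-insertion order.
    label, _ = Counter(patch_predictions).most_common(1)[0]
    return label

# e.g. five patches cut from the same surface image:
print(majority_vote([2, 2, 5, 2, 1]))  # -> 2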