|
import shutil |
|
import warnings |
|
from sklearn import metrics |
|
from sklearn.metrics import confusion_matrix |
|
from PIL import Image |
|
|
|
warnings.filterwarnings("ignore") |
|
import torch.utils.data as data |
|
import os |
|
import argparse |
|
from sklearn.metrics import f1_score, confusion_matrix |
|
from data_preprocessing.sam import SAM |
|
import torch.nn.parallel |
|
import torch.backends.cudnn as cudnn |
|
import torch.optim |
|
import torch.utils.data |
|
import torch.utils.data.distributed |
|
import matplotlib.pyplot as plt |
|
import torchvision.datasets as datasets |
|
import torchvision.transforms as transforms |
|
import numpy as np |
|
import datetime |
|
from torchsampler import ImbalancedDatasetSampler |
|
from models.PosterV2_7cls import pyramid_trans_expr2 |
|
|
|
|
|
warnings.filterwarnings("ignore", category=UserWarning) |
|
|
|
# Timestamp prefix used in checkpoint and log file names, e.g. "[05-17]-[14-02]-".
now = datetime.datetime.now()
time_str = now.strftime("[%m-%d]-[%H-%M]-")

# Pick the best available accelerator: Apple MPS first, then CUDA, else CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Using device: {device}")

# Command-line interface.
# NOTE(review): parsing happens at import time, so importing this module as a
# library consumes sys.argv — consider moving into main() if reused.
parser = argparse.ArgumentParser()
# Dataset root; expected to contain "train" and "test" subfolders.
parser.add_argument("--data", type=str, default=r"raf-db/DATASET")
parser.add_argument(
    "--data_type",
    default="RAF-DB",
    choices=["RAF-DB", "AffectNet-7", "CAER-S"],
    type=str,
    help="dataset option",
)
# Where the latest checkpoint (every epoch) is written.
parser.add_argument(
    "--checkpoint_path", type=str, default="./checkpoint/" + time_str + "model.pth"
)
# Where the best-so-far checkpoint is written.
parser.add_argument(
    "--best_checkpoint_path",
    type=str,
    default="./checkpoint/" + time_str + "model_best.pth",
)
parser.add_argument(
    "-j",
    "--workers",
    default=4,
    type=int,
    metavar="N",
    help="number of data loading workers",
)
parser.add_argument(
    "--epochs", default=200, type=int, metavar="N", help="number of total epochs to run"
)
parser.add_argument(
    "--start-epoch",
    default=0,
    type=int,
    metavar="N",
    help="manual epoch number (useful on restarts)",
)
parser.add_argument("-b", "--batch-size", default=2, type=int, metavar="N")
# Base optimizer wrapped by SAM in main(); "adamw" is also accepted there.
parser.add_argument(
    "--optimizer", type=str, default="adam", help="Optimizer, adam or sgd."
)

parser.add_argument(
    "--lr", "--learning-rate", default=0.000035, type=float, metavar="LR", dest="lr"
)
# NOTE(review): momentum and weight_decay are parsed but never forwarded to
# the optimizer in main() — confirm whether that is intentional.
parser.add_argument("--momentum", default=0.9, type=float, metavar="M")
parser.add_argument(
    "--wd", "--weight-decay", default=1e-4, type=float, metavar="W", dest="weight_decay"
)
parser.add_argument(
    "-p", "--print-freq", default=30, type=int, metavar="N", help="print frequency"
)
parser.add_argument(
    "--resume", default=None, type=str, metavar="PATH", help="path to checkpoint"
)
parser.add_argument(
    "-e", "--evaluate", default=None, type=str, help="evaluate model on test set"
)
parser.add_argument("--beta", type=float, default=0.6)
parser.add_argument("--gpu", type=str, default="0")

# Single-image prediction: -t supplies the checkpoint, -i the image path.
parser.add_argument(
    "-i", "--image", type=str, help="upload a single image to test the prediction"
)
parser.add_argument("-t", "--test", type=str, help="test model on single image")
args = parser.parse_args()
|
|
|
|
|
def main():
    """Train the POSTER V2 expression model, or evaluate / predict from a checkpoint.

    Behaviour is driven by the module-level ``args``:
      * default      -- full training loop with SAM and per-epoch validation
      * --evaluate   -- load a checkpoint and run validation only
      * --test       -- load a checkpoint and predict a single image (--image)
      * --resume     -- continue training from a saved checkpoint
    """
    # validate() is needed both by the --evaluate branch and by the training
    # loop below; importing it up front fixes a NameError that occurred on
    # every normal training run (it used to be imported only under --evaluate).
    from validation import validate

    best_acc = 0

    # Make sure output directories exist before anything tries to write there.
    os.makedirs(os.path.dirname(args.checkpoint_path) or ".", exist_ok=True)
    os.makedirs("./log", exist_ok=True)

    model = pyramid_trans_expr2(img_size=224, num_classes=7)
    model = torch.nn.DataParallel(model)
    model = model.to(device)

    criterion = torch.nn.CrossEntropyLoss()

    # SAM wraps one of these base optimizers; only the learning rate is
    # forwarded. NOTE(review): args.momentum / args.weight_decay are parsed
    # but unused — confirm whether they should be passed through here.
    if args.optimizer == "adamw":
        base_optimizer = torch.optim.AdamW
    elif args.optimizer == "adam":
        base_optimizer = torch.optim.Adam
    elif args.optimizer == "sgd":
        base_optimizer = torch.optim.SGD
    else:
        raise ValueError("Optimizer not supported.")

    optimizer = SAM(
        model.parameters(),
        base_optimizer,
        lr=args.lr,
        rho=0.05,
        adaptive=False,
    )
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.98)
    recorder = RecorderMeter(args.epochs)
    recorder1 = RecorderMeter1(args.epochs)

    # Optionally resume training state (weights, optimizer, recorders).
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint["epoch"]
            best_acc = checkpoint["best_acc"]
            recorder = checkpoint["recorder"]
            recorder1 = checkpoint["recorder1"]
            # .to() with no arguments is effectively a no-op; kept so best_acc
            # stays a tensor (the loop below calls best_acc.item()).
            best_acc = best_acc.to()
            model.load_state_dict(checkpoint["state_dict"])
            optimizer.load_state_dict(checkpoint["optimizer"])
            print(
                "=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint["epoch"]
                )
            )
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    cudnn.benchmark = True

    traindir = os.path.join(args.data, "train")
    valdir = os.path.join(args.data, "test")

    # Training data is only needed when we are actually going to train.
    if args.evaluate is None:
        if args.data_type == "RAF-DB":
            train_dataset = datasets.ImageFolder(
                traindir,
                transforms.Compose(
                    [
                        transforms.Resize((224, 224)),
                        transforms.RandomHorizontalFlip(),
                        transforms.ToTensor(),
                        transforms.Normalize(
                            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                        ),
                        transforms.RandomErasing(scale=(0.02, 0.1)),
                    ]
                ),
            )
        else:
            # Non-RAF datasets use a deterministic (p=1), fixed-size erasing.
            train_dataset = datasets.ImageFolder(
                traindir,
                transforms.Compose(
                    [
                        transforms.Resize((224, 224)),
                        transforms.RandomHorizontalFlip(),
                        transforms.ToTensor(),
                        transforms.Normalize(
                            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                        ),
                        transforms.RandomErasing(p=1, scale=(0.05, 0.05)),
                    ]
                ),
            )

        # AffectNet is heavily imbalanced, so it gets a balancing sampler
        # (a sampler is mutually exclusive with shuffle=True).
        if args.data_type == "AffectNet-7":
            train_loader = torch.utils.data.DataLoader(
                train_dataset,
                sampler=ImbalancedDatasetSampler(train_dataset),
                batch_size=args.batch_size,
                shuffle=False,
                num_workers=args.workers,
                pin_memory=True,
            )
        else:
            train_loader = torch.utils.data.DataLoader(
                train_dataset,
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=args.workers,
                pin_memory=True,
            )

    # Validation/test data is used by every mode.
    test_dataset = datasets.ImageFolder(
        valdir,
        transforms.Compose(
            [
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        ),
    )

    val_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
    )

    # --evaluate: load the given checkpoint and run validation only.
    if args.evaluate is not None:
        if os.path.isfile(args.evaluate):
            print("=> loading checkpoint '{}'".format(args.evaluate))
            checkpoint = torch.load(args.evaluate, map_location=device)
            best_acc = checkpoint["best_acc"]
            best_acc = best_acc.to()
            print(f"best_acc:{best_acc}")
            model.load_state_dict(checkpoint["state_dict"])
            print(
                "=> loaded checkpoint '{}' (epoch {})".format(
                    args.evaluate, checkpoint["epoch"]
                )
            )
        else:
            print("=> no checkpoint found at '{}'".format(args.evaluate))
        # NOTE(review): validation still runs on the fresh model even when the
        # checkpoint was missing — this matches the original control flow.
        validate(val_loader, model, criterion, args)
        return

    # --test: load a checkpoint and predict a single image (--image), then exit.
    if args.test is not None:
        from prediction import predict

        if os.path.isfile(args.test):
            print("=> loading checkpoint '{}'".format(args.test))
            checkpoint = torch.load(args.test, map_location=device)
            best_acc = checkpoint["best_acc"]
            best_acc = best_acc.to()
            print(f"best_acc:{best_acc}")
            model.load_state_dict(checkpoint["state_dict"])
            print(
                "=> loaded checkpoint '{}' (epoch {})".format(
                    args.test, checkpoint["epoch"]
                )
            )
        else:
            print("=> no checkpoint found at '{}'".format(args.test))
        # NOTE(review): args.image may be None if -i was not supplied;
        # predict() is assumed to handle that — confirm.
        predict(model, image_path=args.image)
        return

    # Confusion matrix of the best epoch seen so far.
    matrix = None

    for epoch in range(args.start_epoch, args.epochs):
        current_learning_rate = optimizer.state_dict()["param_groups"][0]["lr"]
        print("Current learning rate: ", current_learning_rate)
        txt_name = "./log/" + time_str + "log.txt"
        with open(txt_name, "a") as f:
            f.write("Current learning rate: " + str(current_learning_rate) + "\n")

        # One SAM training epoch, then a full validation pass.
        train_acc, train_los = train(
            train_loader, model, criterion, optimizer, epoch, args
        )
        val_acc, val_los, output, target, D = validate(
            val_loader, model, criterion, args
        )

        scheduler.step()

        recorder.update(epoch, train_los, train_acc, val_los, val_acc)
        recorder1.update(output, target)

        curve_name = time_str + "cnn.png"
        recorder.plot_curve(os.path.join("./log/", curve_name))

        is_best = val_acc > best_acc
        best_acc = max(val_acc, best_acc)

        print("Current best accuracy: ", best_acc.item())

        if is_best:
            # Keep the confusion matrix from the best-performing epoch.
            matrix = D

        print("Current best matrix: ", matrix)

        txt_name = "./log/" + time_str + "log.txt"
        with open(txt_name, "a") as f:
            f.write("Current best accuracy: " + str(best_acc.item()) + "\n")

        save_checkpoint(
            {
                "epoch": epoch + 1,
                "state_dict": model.state_dict(),
                "best_acc": best_acc,
                "optimizer": optimizer.state_dict(),
                "recorder1": recorder1,
                "recorder": recorder,
            },
            is_best,
            args,
        )
|
|
|
|
|
def train(train_loader, model, criterion, optimizer, epoch, args):
    """Run one epoch of SAM training and return (top-1 accuracy avg, loss avg).

    SAM (Sharpness-Aware Minimization) needs two forward/backward passes per
    batch: the first computes the gradient used to perturb the weights
    (first_step), the second computes the gradient at the perturbed point and
    applies the actual update (second_step).
    """
    losses = AverageMeter("Loss", ":.4f")
    top1 = AverageMeter("Accuracy", ":6.3f")
    progress = ProgressMeter(
        len(train_loader), [losses, top1], prefix="Epoch: [{}]".format(epoch)
    )

    model.train()

    for i, (images, target) in enumerate(train_loader):
        images = images.to(device)
        target = target.to(device)

        # --- SAM pass 1: gradient at the current weights ---
        output = model(images)
        loss = criterion(output, target)

        # topk=(1, 5) requires at least 5 classes; the model has 7.
        acc1, _ = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), images.size(0))
        top1.update(acc1[0], images.size(0))

        optimizer.zero_grad()
        loss.backward()

        # Perturb weights toward the sharpness-ascent direction.
        optimizer.first_step(zero_grad=True)
        # These .to(device) calls are redundant (already moved above) but harmless.
        images = images.to(device)
        target = target.to(device)

        # --- SAM pass 2: gradient at the perturbed weights ---
        output = model(images)
        loss = criterion(output, target)

        # NOTE(review): metrics are updated on both passes, so each batch is
        # counted twice in the running averages — presumably accepted here.
        acc1, _ = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), images.size(0))
        top1.update(acc1[0], images.size(0))

        optimizer.zero_grad()
        loss.backward()
        # Restore original weights and apply the base-optimizer update.
        optimizer.second_step(zero_grad=True)

        # Periodic console + log-file progress line.
        if i % args.print_freq == 0:
            progress.display(i)

    return top1.avg, losses.avg
|
|
|
|
|
def save_checkpoint(state, is_best, args):
    """Persist the latest training state; on a new best, save a slim copy too.

    The full state (model weights, optimizer, recorders, epoch) goes to
    args.checkpoint_path so training can resume. When ``is_best``, a copy
    *without* the bulky optimizer state is written to
    args.best_checkpoint_path for later evaluation/inference.

    Bug fixed: the original did ``best_state = state.pop("optimizer")`` and
    saved ``best_state`` — i.e. it saved only the optimizer state as the
    "best" checkpoint, which contains no model weights and would make the
    --evaluate path fail on ``checkpoint["state_dict"]``.
    """
    torch.save(state, args.checkpoint_path)
    if is_best:
        # Copy instead of pop: don't mutate the caller's dict.
        best_state = {k: v for k, v in state.items() if k != "optimizer"}
        torch.save(best_state, args.best_checkpoint_path)
|
|
|
|
|
class AverageMeter(object):
    """Tracks the most recent value and a weighted running average."""

    def __init__(self, name, fmt=":f"):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Clear the accumulator back to its empty state."""
        self.val = 0
        self.count = 0
        self.sum = 0
        self.avg = 0

    def update(self, val, n=1):
        """Record a new observation; ``n`` is its weight (e.g. batch size)."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

    def __str__(self):
        spec = self.fmt
        template = "{name} {val%s} ({avg%s})" % (spec, spec)
        return template.format(**vars(self))
|
|
|
|
|
class ProgressMeter(object):
    """Formats, prints, and appends batch-progress lines to the run's log file."""

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print one tab-joined progress line and append it to ./log/<time_str>log.txt."""
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print_txt = "\t".join(entries)
        print(print_txt)
        # Robustness fix: the log directory may not exist on a fresh run.
        os.makedirs("./log", exist_ok=True)
        txt_name = "./log/" + time_str + "log.txt"
        with open(txt_name, "a") as f:
            f.write(print_txt + "\n")

    def _get_batch_fmtstr(self, num_batches):
        """Build e.g. "[{:3d}/120]" so the batch index is padded to the total's width."""
        # Fix: the original computed len(str(num_batches // 1)); "// 1" is a no-op.
        num_digits = len(str(num_batches))
        fmt = "{:" + str(num_digits) + "d}"
        return "[" + fmt + "/" + fmt.format(num_batches) + "]"
|
|
|
|
|
def accuracy(output, target, topk=(1,)):
    """Return top-k accuracies (as percentages) for each k in ``topk``.

    ``output`` is a (batch, num_classes) score tensor, ``target`` a (batch,)
    tensor of class indices; each result is a 1-element tensor.
    """
    with torch.no_grad():
        k_max = max(topk)
        n_samples = target.size(0)

        # (batch, k_max) indices of the highest-scoring classes, transposed
        # to (k_max, batch) so row r holds everyone's rank-r prediction.
        _, top_idx = output.topk(k_max, 1, True, True)
        top_idx = top_idx.t()
        hits = top_idx.eq(target.view(1, -1).expand_as(top_idx))

        scores = []
        for k in topk:
            # A sample counts as correct if the target is among its top k.
            n_correct = hits[:k].reshape(-1).float().sum(0, keepdim=True)
            scores.append(n_correct * (100.0 / n_samples))
        return scores
|
|
|
|
|
# Tick labels for the confusion-matrix plot.
# NOTE(review): 13 entries, but the model predicts 7 classes — presumably only
# a prefix of this list is meaningful; confirm against the dataset's classes.
labels = ["A", "B", "C", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O"]
|
|
|
|
|
class RecorderMeter1(object):
    """Stores the last epoch's predictions/targets and renders a confusion matrix.

    Fixes over the original:
      * ``plot_confusion_matrix`` recursively called itself without ``self``
        (NameError at runtime) — now a single, non-recursive plotting pass.
      * ``np.arange(len(7))`` raised TypeError — now ``np.arange(n_classes)``.
      * ``cmap=plt.cm.binary`` default was evaluated at class-definition time;
        it is now resolved lazily inside the method.
      * ``matrix()`` computed arrays and discarded them (with a ``y_ture``
        typo) — it now returns the confusion matrix.
    """

    def __init__(self, total_epoch):
        self.reset(total_epoch)

    def reset(self, total_epoch):
        """Reset all buffers for a run of ``total_epoch`` epochs."""
        # Loss/accuracy buffers are kept for interface parity with
        # RecorderMeter, even though this class only uses y_pred / y_true.
        self.total_epoch = total_epoch
        self.current_epoch = 0
        self.epoch_losses = np.zeros(
            (self.total_epoch, 2), dtype=np.float32
        )
        self.epoch_accuracy = np.zeros(
            (self.total_epoch, 2), dtype=np.float32
        )

    def update(self, output, target):
        """Remember the latest predictions and ground-truth labels."""
        self.y_pred = output
        self.y_true = target

    def plot_confusion_matrix(self, cm=None, title="Confusion Matrix", cmap=None):
        """Draw the normalized confusion matrix of the last update() and save it.

        The matrix is always recomputed from self.y_true / self.y_pred; the
        ``cm`` argument is accepted for backward compatibility but ignored
        (the original overwrote it as well).
        """
        if cmap is None:
            # Resolved lazily so importing this module never touches matplotlib.
            cmap = plt.cm.binary

        cm = confusion_matrix(self.y_true, self.y_pred)
        n_classes = cm.shape[0]
        # ``labels`` has 13 entries but the model has 7 classes; use a prefix.
        tick_labels = labels[:n_classes]

        np.set_printoptions(precision=2)
        # Row-normalize so each true class sums to 1.
        cm_normalized = cm.astype("float") / cm.sum(axis=1)[:, np.newaxis]

        plt.figure(figsize=(12, 8), dpi=120)
        plt.imshow(cm_normalized, interpolation="nearest", cmap=cmap)
        plt.title(title)
        plt.colorbar()
        xlocations = np.arange(n_classes)
        plt.xticks(xlocations, tick_labels, rotation=90)
        plt.yticks(xlocations, tick_labels)
        plt.ylabel("True label")
        plt.xlabel("Predicted label")

        # Annotate each cell with its normalized value.
        ind_array = np.arange(n_classes)
        x, y = np.meshgrid(ind_array, ind_array)
        for x_val, y_val in zip(x.flatten(), y.flatten()):
            c = cm_normalized[y_val][x_val]
            if c > 0.01:
                plt.text(
                    x_val,
                    y_val,
                    "%0.2f" % (c,),
                    color="red",
                    fontsize=7,
                    va="center",
                    ha="center",
                )

        tick_marks = np.arange(n_classes)
        plt.gca().set_xticks(tick_marks, minor=True)
        plt.gca().set_yticks(tick_marks, minor=True)
        plt.gca().xaxis.set_ticks_position("none")
        plt.gca().yaxis.set_ticks_position("none")
        plt.grid(True, which="minor", linestyle="-")
        plt.gcf().subplots_adjust(bottom=0.15)

        plt.savefig("./log/confusion_matrix.png", format="png")
        print("Saved figure")
        plt.show()

    def matrix(self):
        """Return the confusion matrix of the stored predictions."""
        y_true = np.array(self.y_true).flatten()
        y_pred = np.array(self.y_pred).flatten()
        return confusion_matrix(y_true, y_pred)
|
|
|
|
|
class RecorderMeter(object):
    """Accumulates per-epoch train/val loss and accuracy and plots their curves."""

    def __init__(self, total_epoch):
        self.reset(total_epoch)

    def reset(self, total_epoch):
        """Allocate fresh buffers for ``total_epoch`` epochs."""
        # One row per epoch; column 0 = train, column 1 = validation.
        self.total_epoch = total_epoch
        self.current_epoch = 0
        self.epoch_losses = np.zeros((total_epoch, 2), dtype=np.float32)
        self.epoch_accuracy = np.zeros((total_epoch, 2), dtype=np.float32)

    def update(self, idx, train_loss, train_acc, val_loss, val_acc):
        """Record epoch ``idx``; losses are scaled by 30 to share the 0-100 axis."""
        self.epoch_losses[idx, 0] = train_loss * 30
        self.epoch_losses[idx, 1] = val_loss * 30
        self.epoch_accuracy[idx, 0] = train_acc
        self.epoch_accuracy[idx, 1] = val_acc
        self.current_epoch = idx + 1

    def plot_curve(self, save_path):
        """Draw the four curves and, when ``save_path`` is given, save the figure."""
        dpi = 80
        legend_fontsize = 10
        fig = plt.figure(figsize=(1800 / float(dpi), 800 / float(dpi)))

        plt.xlim(0, self.total_epoch)
        plt.ylim(0, 100)
        plt.xticks(np.arange(0, self.total_epoch + 5, 5))
        plt.yticks(np.arange(0, 105, 5))
        plt.grid()
        plt.title("the accuracy/loss curve of train/val", fontsize=20)
        plt.xlabel("the training epoch", fontsize=16)
        plt.ylabel("accuracy", fontsize=16)

        epochs = np.arange(self.total_epoch)
        curve_specs = (
            (self.epoch_accuracy[:, 0], "g", "-", "train-accuracy"),
            (self.epoch_accuracy[:, 1], "y", "-", "valid-accuracy"),
            (self.epoch_losses[:, 0], "g", ":", "train-loss-x30"),
            (self.epoch_losses[:, 1], "y", ":", "valid-loss-x30"),
        )
        for series, color, line_style, label in curve_specs:
            plt.plot(
                epochs,
                np.asarray(series, dtype=float),
                color=color,
                linestyle=line_style,
                label=label,
                lw=2,
            )
            plt.legend(loc=4, fontsize=legend_fontsize)

        if save_path is not None:
            fig.savefig(save_path, dpi=dpi, bbox_inches="tight")
            print("Saved figure")
        plt.close(fig)
|
|
|
|
|
# Script entry point; note that argument parsing already ran at import time.
if __name__ == "__main__":
    main()
|
|