import shutil |
import warnings |
from sklearn import metrics |
from sklearn.metrics import confusion_matrix |
from PIL import Image |
warnings.filterwarnings("ignore") |
import torch.utils.data as data |
import os |
import argparse |
from sklearn.metrics import f1_score, confusion_matrix |
from data_preprocessing.sam import SAM |
import torch.nn.parallel |
import torch.backends.cudnn as cudnn |
import torch.optim |
import torch.utils.data |
import torch.utils.data.distributed |
import matplotlib.pyplot as plt |
import torchvision.datasets as datasets |
import torchvision.transforms as transforms |
import numpy as np |
import datetime |
from torchsampler import ImbalancedDatasetSampler |
from models.PosterV2_7cls import pyramid_trans_expr2 |
warnings.filterwarnings("ignore", category=UserWarning) |
# Timestamp prefix shared by the default checkpoint and log file names of this run.
now = datetime.datetime.now()
time_str = now.strftime("[%m-%d]-[%H-%M]-")

# Accelerator probes in priority order; the first that reports availability wins,
# and later probes are not called once one succeeds. Falls back to the CPU.
_DEVICE_PROBES = (
    ("mps", torch.backends.mps.is_available),
    ("cuda", torch.cuda.is_available),
)
device = next((name for name, probe in _DEVICE_PROBES if probe()), "cpu")
print(f"Using device: {device}")
# Command-line interface. Parsed once at import time into the module-level `args`
# that main(), train() and save_checkpoint() read.
parser = argparse.ArgumentParser()
parser.add_argument("--data", type=str, default=r"raf-db/DATASET")
parser.add_argument(
    "--data_type",
    default="RAF-DB",
    choices=["RAF-DB", "AffectNet-7", "CAER-S"],
    type=str,
    help="dataset option",
)
# Default checkpoint names carry the run timestamp so runs do not overwrite each other.
parser.add_argument(
    "--checkpoint_path", type=str, default="./checkpoint/" + time_str + "model.pth"
)
parser.add_argument(
    "--best_checkpoint_path",
    type=str,
    default="./checkpoint/" + time_str + "model_best.pth",
)
parser.add_argument(
    "-j",
    "--workers",
    default=4,
    type=int,
    metavar="N",
    help="number of data loading workers",
)
parser.add_argument(
    "--epochs", default=200, type=int, metavar="N", help="number of total epochs to run"
)
parser.add_argument(
    "--start-epoch",
    default=0,
    type=int,
    metavar="N",
    help="manual epoch number (useful on restarts)",
)
parser.add_argument("-b", "--batch-size", default=2, type=int, metavar="N")
# main() only maps these three names to a base optimizer; rejecting anything else
# here gives a usage error at parse time instead of a ValueError after model setup.
parser.add_argument(
    "--optimizer",
    type=str,
    default="adam",
    choices=["adam", "adamw", "sgd"],
    help="Optimizer: adam, adamw or sgd.",
)
parser.add_argument(
    "--lr", "--learning-rate", default=0.000035, type=float, metavar="LR", dest="lr"
)
# NOTE(review): --momentum, --wd, --beta and --gpu are not read by the code in this
# file (the SAM optimizer is built with lr/rho only) — confirm whether they are used elsewhere.
parser.add_argument("--momentum", default=0.9, type=float, metavar="M")
parser.add_argument(
    "--wd", "--weight-decay", default=1e-4, type=float, metavar="W", dest="weight_decay"
)
parser.add_argument(
    "-p", "--print-freq", default=30, type=int, metavar="N", help="print frequency"
)
parser.add_argument(
    "--resume", default=None, type=str, metavar="PATH", help="path to checkpoint"
)
parser.add_argument(
    "-e", "--evaluate", default=None, type=str, help="evaluate model on test set"
)
parser.add_argument("--beta", type=float, default=0.6)
parser.add_argument("--gpu", type=str, default="0")
parser.add_argument(
    "-i",
    "--image",
    type=str,
    help="path of a single image to run the prediction on (used with --test)",
)
parser.add_argument(
    "-t",
    "--test",
    type=str,
    help="checkpoint used to predict on the single image given by --image",
)
args = parser.parse_args()
def _load_checkpoint(path):
    """Load the checkpoint at *path* onto the selected ``device``.

    Checkpoints written by ``save_checkpoint`` also pickle the RecorderMeter objects.
    Newer PyTorch releases refuse those under ``torch.load``'s ``weights_only=True``
    default, so full unpickling is requested when the argument exists.
    SECURITY: this executes pickle code — only load checkpoints from trusted sources.
    """
    import inspect

    load_kwargs = {"map_location": device}
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    return torch.load(path, **load_kwargs)


def _load_weights(model, path):
    """Load the ``state_dict`` of the checkpoint at *path* into *model* for inference.

    Prints the stored best accuracy and epoch. If *path* is not a file, only a
    message is printed and *model* keeps its current weights.
    """
    if not os.path.isfile(path):
        print("=> no checkpoint found at '{}'".format(path))
        return
    print("=> loading checkpoint '{}'".format(path))
    checkpoint = _load_checkpoint(path)
    print(f"best_acc:{checkpoint['best_acc']}")
    model.load_state_dict(checkpoint["state_dict"])
    print("=> loaded checkpoint '{}' (epoch {})".format(path, checkpoint["epoch"]))


def main():
    """Build the model and SAM optimizer, then evaluate, predict or train.

    * ``--evaluate PATH``: load PATH, run ``validation.validate`` on the test split, return.
    * ``--test PATH``: load PATH, run ``prediction.predict`` on ``--image``, return.
    * otherwise: train from ``args.start_epoch`` to ``args.epochs``. After every
      epoch, log the learning rate and best accuracy to ``./log/``, redraw the
      curves, and save a checkpoint (plus a best checkpoint when validation
      accuracy improves).
    """
    best_acc = 0
    model = pyramid_trans_expr2(img_size=224, num_classes=7)
    model = torch.nn.DataParallel(model)
    model = model.to(device)
    criterion = torch.nn.CrossEntropyLoss()

    base_optimizers = {
        "adamw": torch.optim.AdamW,
        "adam": torch.optim.Adam,
        "sgd": torch.optim.SGD,
    }
    if args.optimizer not in base_optimizers:
        raise ValueError("Optimizer not supported.")
    optimizer = SAM(
        model.parameters(),
        base_optimizers[args.optimizer],
        lr=args.lr,
        rho=0.05,
        adaptive=False,
    )
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.98)
    recorder = RecorderMeter(args.epochs)
    recorder1 = RecorderMeter1(args.epochs)

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            # map_location matches the evaluate/test paths so GPU-saved checkpoints
            # can be resumed on another device.
            checkpoint = _load_checkpoint(args.resume)
            args.start_epoch = checkpoint["epoch"]
            best_acc = checkpoint["best_acc"]
            recorder = checkpoint["recorder"]
            recorder1 = checkpoint["recorder1"]
            model.load_state_dict(checkpoint["state_dict"])
            optimizer.load_state_dict(checkpoint["optimizer"])
            print(
                "=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint["epoch"]
                )
            )
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    cudnn.benchmark = True

    traindir = os.path.join(args.data, "train")
    valdir = os.path.join(args.data, "test")
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )
    if args.evaluate is None:
        # The train pipelines differ only in random erasing. RAF-DB uses RandomErasing's
        # default probability over 2-10% of the image. The other datasets always
        # erase a fixed 5% patch.
        if args.data_type == "RAF-DB":
            erasing = transforms.RandomErasing(scale=(0.02, 0.1))
        else:
            erasing = transforms.RandomErasing(p=1, scale=(0.05, 0.05))
        train_dataset = datasets.ImageFolder(
            traindir,
            transforms.Compose(
                [
                    transforms.Resize((224, 224)),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    normalize,
                    erasing,
                ]
            ),
        )
        if args.data_type == "AffectNet-7":
            # The sampler replaces shuffling, so shuffle must stay False here.
            train_loader = torch.utils.data.DataLoader(
                train_dataset,
                sampler=ImbalancedDatasetSampler(train_dataset),
                batch_size=args.batch_size,
                shuffle=False,
                num_workers=args.workers,
                pin_memory=True,
            )
        else:
            train_loader = torch.utils.data.DataLoader(
                train_dataset,
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=args.workers,
                pin_memory=True,
            )

    test_dataset = datasets.ImageFolder(
        valdir,
        transforms.Compose(
            [
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                normalize,
            ]
        ),
    )
    val_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
    )

    if args.evaluate is not None:
        from validation import validate

        _load_weights(model, args.evaluate)
        validate(val_loader, model, criterion, args)
        return

    if args.test is not None:
        from prediction import predict

        _load_weights(model, args.test)
        predict(model, image_path=args.image)
        return

    # The training loop also needs validate(). A function-scope import makes the name
    # local to main, so it must be bound on this path as well, not only in the
    # --evaluate branch.
    from validation import validate

    # Log files, curves and checkpoints are written into these directories.
    for directory in (
        "./log",
        os.path.dirname(args.checkpoint_path),
        os.path.dirname(args.best_checkpoint_path),
    ):
        if directory:
            os.makedirs(directory, exist_ok=True)
    txt_name = "./log/" + time_str + "log.txt"

    matrix = None
    for epoch in range(args.start_epoch, args.epochs):
        current_learning_rate = optimizer.state_dict()["param_groups"][0]["lr"]
        print("Current learning rate: ", current_learning_rate)
        with open(txt_name, "a") as f:
            f.write("Current learning rate: " + str(current_learning_rate) + "\n")

        train_acc, train_los = train(
            train_loader, model, criterion, optimizer, epoch, args
        )
        val_acc, val_los, output, target, D = validate(
            val_loader, model, criterion, args
        )
        scheduler.step()

        recorder.update(epoch, train_los, train_acc, val_los, val_acc)
        recorder1.update(output, target)
        recorder.plot_curve(os.path.join("./log/", time_str + "cnn.png"))

        is_best = val_acc > best_acc
        best_acc = max(val_acc, best_acc)
        # float() works whether best_acc is a tensor or a plain number.
        print("Current best accuracy: ", float(best_acc))
        if is_best:
            matrix = D
        print("Current best matrix: ", matrix)
        with open(txt_name, "a") as f:
            f.write("Current best accuracy: " + str(float(best_acc)) + "\n")

        save_checkpoint(
            {
                "epoch": epoch + 1,
                "state_dict": model.state_dict(),
                "best_acc": best_acc,
                "optimizer": optimizer.state_dict(),
                "recorder1": recorder1,
                "recorder": recorder,
            },
            is_best,
            args,
        )
def train(train_loader, model, criterion, optimizer, epoch, args):
    """Run one training epoch with the SAM optimizer.

    Each batch takes two forward/backward passes. ``optimizer.first_step``
    perturbs the weights using the gradient of the first pass.
    ``optimizer.second_step`` then applies the update from the gradient computed
    at the perturbed point.

    Returns:
        (top1.avg, losses.avg): sample-weighted mean top-1 accuracy (in percent)
        and mean loss over the epoch, measured on the first (unperturbed) pass.
    """
    losses = AverageMeter("Loss", ":.4f")
    top1 = AverageMeter("Accuracy", ":6.3f")
    progress = ProgressMeter(
        len(train_loader), [losses, top1], prefix="Epoch: [{}]".format(epoch)
    )
    model.train()
    for i, (images, target) in enumerate(train_loader):
        images = images.to(device)
        target = target.to(device)
        batch_size = images.size(0)

        # First SAM pass at the current weights. Only this pass feeds the meters,
        # so each sample is counted once. Perturbed-weight statistics would not
        # describe the model actually being trained.
        output = model(images)
        loss = criterion(output, target)
        (acc1,) = accuracy(output, target, topk=(1,))
        losses.update(loss.item(), batch_size)
        top1.update(acc1[0], batch_size)
        optimizer.zero_grad()
        loss.backward()
        optimizer.first_step(zero_grad=True)

        # Second SAM pass at the perturbed weights; its gradient drives the real update.
        optimizer.zero_grad()
        criterion(model(images), target).backward()
        optimizer.second_step(zero_grad=True)

        if i % args.print_freq == 0:
            progress.display(i)
    return top1.avg, losses.avg
def save_checkpoint(state, is_best, args):
    """Save the latest training state, plus a best-model copy when *is_best*.

    Args:
        state: dict with the training state; its ``"optimizer"`` entry is only
            written to the latest checkpoint.
        is_best: when truthy, also write *state* without ``"optimizer"`` to
            ``args.best_checkpoint_path``.
        args: namespace providing ``checkpoint_path`` and ``best_checkpoint_path``.

    *state* is not modified. Parent directories are created as needed.
    """
    os.makedirs(os.path.dirname(args.checkpoint_path) or ".", exist_ok=True)
    torch.save(state, args.checkpoint_path)
    if is_best:
        # Drop only the optimizer state. The evaluate/test paths of main() read
        # "state_dict", "best_acc" and "epoch" from this file, so those must stay.
        best_state = {key: value for key, value in state.items() if key != "optimizer"}
        os.makedirs(os.path.dirname(args.best_checkpoint_path) or ".", exist_ok=True)
        torch.save(best_state, args.best_checkpoint_path)
class AverageMeter(object):
    """Track the latest value of a metric and its sample-weighted running mean.

    ``fmt`` is a format spec (e.g. ``":.4f"``) used by ``str()`` for both the
    latest value and the mean.
    """

    def __init__(self, name, fmt=":f"):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Zero the latest value, mean, sum and count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record *val* as the mean of *n* new samples and refresh the running mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

    def __str__(self):
        template = "".join(("{name} {val", self.fmt, "} ({avg", self.fmt, "})"))
        return template.format(**vars(self))
class ProgressMeter(object):
    """Format a per-batch progress line from a set of meters.

    ``display`` prints the line and appends it to ``./log/<time_str>log.txt``.
    """

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print and log the progress line for batch index *batch*."""
        parts = [self.prefix + self.batch_fmtstr.format(batch)]
        parts.extend(map(str, self.meters))
        line = "\t".join(parts)
        print(line)
        with open("./log/" + time_str + "log.txt", "a") as log_file:
            log_file.write(line + "\n")

    def _get_batch_fmtstr(self, num_batches):
        """Return a ``"[{:Nd}/total]"`` template whose field width fits *num_batches*."""
        width = len(str(num_batches // 1))
        field = "{:" + str(width) + "d}"
        return f"[{field}/{field.format(num_batches)}]"
def accuracy(output, target, topk=(1,)):
    """Return top-k accuracies (in percent) of *output* scores against *target* labels.

    One single-element tensor is returned per entry of *topk*, in the same order.
    *output* is indexed as (batch, classes) and *target* holds one label per row.
    """
    with torch.no_grad():
        n_samples = target.size(0)
        # Indices of the max(topk) best classes, transposed to (k, batch).
        _, ranked = output.topk(max(topk), 1, True, True)
        ranked = ranked.t()
        hits = ranked == target.view(1, -1)
        scale = 100.0 / n_samples
        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True) * scale for k in topk
        ]
# Default tick labels for RecorderMeter1.plot_confusion_matrix; only the first N are
# used for an N-class matrix.
# NOTE(review): these look like placeholders (13 entries, while main builds a 7-class
# model) — replace with the dataset's real class names.
labels = ["A", "B", "C", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O"]


class RecorderMeter1(object):
    """Keep the latest (predicted, true) labels and build/plot their confusion matrix.

    main() stores instances in checkpoints under "recorder1", so the class name and
    instance attributes are kept stable for unpickling.
    """

    def __init__(self, total_epoch):
        self.reset(total_epoch)

    def reset(self, total_epoch):
        self.total_epoch = total_epoch
        self.current_epoch = 0
        # NOTE(review): these per-epoch arrays are never written by this class; they
        # are kept so pickled instances keep the same attributes.
        self.epoch_losses = np.zeros((self.total_epoch, 2), dtype=np.float32)
        self.epoch_accuracy = np.zeros((self.total_epoch, 2), dtype=np.float32)

    def update(self, output, target):
        """Store the latest predictions and targets.

        Presumably integer class-index arrays as returned by ``validate`` — confirm.
        """
        self.y_pred = output
        self.y_true = target

    def matrix(self, num_classes=None):
        """Return confusion-matrix counts (rows: true label, columns: predicted label).

        Both stored label sets are flattened and cast to integers. CPU arrays or
        tensors are assumed. *num_classes* defaults to the largest label seen plus one.
        """
        y_true = np.asarray(self.y_true).astype(np.int64).ravel()
        y_pred = np.asarray(self.y_pred).astype(np.int64).ravel()
        if num_classes is None:
            num_classes = int(max(y_true.max(initial=-1), y_pred.max(initial=-1))) + 1
        cm = np.zeros((num_classes, num_classes), dtype=np.int64)
        # Unbuffered add so repeated (true, pred) pairs are all counted.
        np.add.at(cm, (y_true, y_pred), 1)
        return cm

    def plot_confusion_matrix(
        self,
        cm=None,
        title="Confusion Matrix",
        cmap=None,
        save_path="./log/confusion_matrix.png",
        class_names=None,
    ):
        """Plot a row-normalised confusion matrix, save it and show it.

        Args:
            cm: square matrix of counts (or already row-normalised values). When
                None, it is computed from the stored labels with ``matrix()``.
            title: figure title.
            cmap: matplotlib colormap; ``plt.cm.binary`` when None.
            save_path: PNG output path (parent directory created); None skips saving.
            class_names: tick labels; defaults to the first N module-level
                ``labels`` entries, or to class indices when there are too few.
        """
        if cm is None:
            cm = self.matrix()
        cm = np.asarray(cm, dtype=np.float64)
        n = cm.shape[0]
        row_totals = cm.sum(axis=1, keepdims=True)
        # Rows without samples would divide by zero; leave them all-zero instead.
        cm_normalized = np.divide(
            cm, row_totals, out=np.zeros_like(cm), where=row_totals > 0
        )
        if class_names is None:
            class_names = labels[:n] if len(labels) >= n else [str(i) for i in range(n)]
        ticks = np.arange(n)

        fig = plt.figure(figsize=(12, 8), dpi=120)
        plt.imshow(
            cm_normalized,
            interpolation="nearest",
            cmap=plt.cm.binary if cmap is None else cmap,
        )
        plt.title(title)
        plt.colorbar()
        plt.xticks(ticks, class_names, rotation=90)
        plt.yticks(ticks, class_names)
        plt.ylabel("True label")
        plt.xlabel("Predicted label")
        # Annotate every cell holding more than 1% of its row.
        for y_val, x_val in np.ndindex(n, n):
            c = cm_normalized[y_val, x_val]
            if c > 0.01:
                plt.text(
                    x_val,
                    y_val,
                    "%0.2f" % (c,),
                    color="red",
                    fontsize=7,
                    va="center",
                    ha="center",
                )
        # Minor ticks halfway between cells, so the minor grid outlines each cell.
        ax = plt.gca()
        ax.set_xticks(ticks + 0.5, minor=True)
        ax.set_yticks(ticks + 0.5, minor=True)
        ax.xaxis.set_ticks_position("none")
        ax.yaxis.set_ticks_position("none")
        plt.grid(True, which="minor", linestyle="-")
        fig.subplots_adjust(bottom=0.15)
        if save_path is not None:
            os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
            fig.savefig(save_path, format="png")
            print("Saved figure")
        plt.show()
class RecorderMeter(object):
    """Record per-epoch train/val loss and accuracy and plot them as curves.

    main() stores instances in checkpoints under "recorder", so the class name and
    instance attributes are kept stable for unpickling.
    """

    # Losses are stored multiplied by this factor so they share the 0-100 axis
    # with accuracy; the plot legend reports the same factor.
    loss_scale = 30

    def __init__(self, total_epoch):
        self.reset(total_epoch)

    def reset(self, total_epoch):
        self.total_epoch = total_epoch
        self.current_epoch = 0
        # Column 0: train, column 1: validation; one row per epoch.
        self.epoch_losses = np.zeros((self.total_epoch, 2), dtype=np.float32)
        self.epoch_accuracy = np.zeros((self.total_epoch, 2), dtype=np.float32)

    def update(self, idx, train_loss, train_acc, val_loss, val_acc):
        """Store epoch *idx*'s metrics (losses scaled by ``loss_scale``).

        float() accepts Python numbers, NumPy scalars and single-element tensors on
        any device, so accelerator tensors are not routed through NumPy's array
        conversion.
        """
        scale = self.loss_scale
        self.epoch_losses[idx, 0] = float(train_loss) * scale
        self.epoch_losses[idx, 1] = float(val_loss) * scale
        self.epoch_accuracy[idx, 0] = float(train_acc)
        self.epoch_accuracy[idx, 1] = float(val_acc)
        self.current_epoch = idx + 1

    def plot_curve(self, save_path):
        """Draw accuracy and scaled-loss curves; save to *save_path* unless it is None."""
        title = "the accuracy/loss curve of train/val"
        dpi = 80
        width, height = 1800, 800
        legend_fontsize = 10
        interval_x = interval_y = 5
        fig = plt.figure(figsize=(width / float(dpi), height / float(dpi)))
        try:
            x_axis = np.arange(self.total_epoch)
            plt.xlim(0, self.total_epoch)
            plt.ylim(0, 100)
            plt.xticks(np.arange(0, self.total_epoch + interval_x, interval_x))
            plt.yticks(np.arange(0, 100 + interval_y, interval_y))
            plt.grid()
            plt.title(title, fontsize=20)
            plt.xlabel("the training epoch", fontsize=16)
            plt.ylabel("accuracy", fontsize=16)
            scale = self.loss_scale
            curves = (
                (self.epoch_accuracy[:, 0], "g", "-", "train-accuracy"),
                (self.epoch_accuracy[:, 1], "y", "-", "valid-accuracy"),
                (self.epoch_losses[:, 0], "g", ":", f"train-loss-x{scale}"),
                (self.epoch_losses[:, 1], "y", ":", f"valid-loss-x{scale}"),
            )
            for values, color, style, label in curves:
                plt.plot(x_axis, values, color=color, linestyle=style, label=label, lw=2)
            # One legend call after all curves exist.
            plt.legend(loc=4, fontsize=legend_fontsize)
            if save_path is not None:
                fig.savefig(save_path, dpi=dpi, bbox_inches="tight")
                print("Saved figure")
        finally:
            # Always release the figure, even if plotting or saving fails.
            plt.close(fig)
if __name__ == "__main__":
    # Start training/evaluation only when run as a script, not when imported.
    main()