jaimin's picture
Upload 78 files
bf53f45 verified
raw
history blame
8.38 kB
import argparse
import json
import logging
from typing import Tuple
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from eval_tools import Metrics, time_sync, write_results
from mivolo.data.dataset import build as build_data
from mivolo.model.mi_volo import MiVOLO
from timm.utils import setup_default_logging
_logger = logging.getLogger("inference")
LOG_FREQUENCY = 10
def get_parser():
parser = argparse.ArgumentParser(description="PyTorch MiVOLO Validation")
parser.add_argument("--dataset_images", default="", type=str, required=True, help="path to images")
parser.add_argument("--dataset_annotations", default="", type=str, required=True, help="path to annotations")
parser.add_argument(
"--dataset_name",
default=None,
type=str,
required=True,
choices=["utk", "imdb", "lagenda", "fairface", "adience", "agedb", "cacd"],
help="dataset name",
)
parser.add_argument("--split", default="validation", help="dataset splits separated by comma (default: validation)")
parser.add_argument("--checkpoint", default="", type=str, required=True, help="path to mivolo checkpoint")
parser.add_argument("--batch-size", default=64, type=int, help="batch size")
parser.add_argument(
"--workers", default=4, type=int, metavar="N", help="number of data loading workers (default: 4)"
)
parser.add_argument("--device", default="cuda", type=str, help="Device (accelerator) to use.")
parser.add_argument("--l-for-cs", type=int, default=5, help="L for CS (cumulative score)")
parser.add_argument("--half", action="store_true", default=False, help="use half-precision model")
parser.add_argument(
"--with-persons", action="store_true", default=False, help="If the model will run with persons, if available"
)
parser.add_argument(
"--disable-faces", action="store_true", default=False, help="If the model will use only persons if available"
)
parser.add_argument("--draw-hist", action="store_true", help="Draws the hist of error by age")
parser.add_argument(
"--results-file",
default="",
type=str,
metavar="FILENAME",
help="Output csv file for validation results (summary)",
)
parser.add_argument(
"--results-format", default="csv", type=str, help="Format for results file one of (csv, json) (default: csv)."
)
return parser
def process_batch(
mivolo_model: MiVOLO,
input: torch.tensor,
target: torch.tensor,
num_classes_gender: int = 2,
):
start = time_sync()
output = mivolo_model.inference(input)
# target with age == -1 and gender == -1 marks that sample is not valid
assert not (all(target[:, 0] == -1) and all(target[:, 1] == -1))
if not mivolo_model.meta.only_age:
gender_out = output[:, :num_classes_gender]
gender_target = target[:, 1]
age_out = output[:, num_classes_gender:]
else:
age_out = output
gender_out, gender_target = None, None
# measure elapsed time
process_time = time_sync() - start
age_target = target[:, 0].unsqueeze(1)
return age_out, age_target, gender_out, gender_target, process_time
def _filter_invalid_target(out: torch.tensor, target: torch.tensor):
# exclude samples where target gt == -1, that marks sample is not valid
mask = target != -1
return out[mask], target[mask]
def postprocess_gender(gender_out: torch.tensor, gender_target: torch.tensor) -> Tuple[torch.tensor, torch.tensor]:
if gender_target is None:
return gender_out, gender_target
return _filter_invalid_target(gender_out, gender_target)
def postprocess_age(age_out: torch.tensor, age_target: torch.tensor, dataset) -> Tuple[torch.tensor, torch.tensor]:
# Revert _norm_age() operation. Output is 2 float tensors
age_out, age_target = _filter_invalid_target(age_out, age_target)
age_out = age_out * (dataset.max_age - dataset.min_age) + dataset.avg_age
# clamp to 0 because age can be below zero
age_out = torch.clamp(age_out, min=0)
if dataset.age_classes is not None:
# classification case
age_out = torch.round(age_out)
if dataset._intervals.device != age_out.device:
dataset._intervals = dataset._intervals.to(age_out.device)
age_inds = torch.searchsorted(dataset._intervals, age_out, side="right") - 1
age_out = age_inds
else:
age_target = age_target * (dataset.max_age - dataset.min_age) + dataset.avg_age
return age_out, age_target
def validate(args):
if torch.cuda.is_available():
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True
mivolo_model = MiVOLO(
args.checkpoint,
args.device,
half=args.half,
use_persons=args.with_persons,
disable_faces=args.disable_faces,
verbose=True,
)
dataset, loader = build_data(
name=args.dataset_name,
images_path=args.dataset_images,
annotations_path=args.dataset_annotations,
split=args.split,
mivolo_model=mivolo_model, # to get meta information from model
workers=args.workers,
batch_size=args.batch_size,
)
d_stat = Metrics(args.l_for_cs, args.draw_hist, dataset.age_classes)
# warmup, reduce variability of first batch time, especially for comparing torchscript vs non
mivolo_model.warmup(args.batch_size)
preproc_end = time_sync()
for batch_idx, (input, target) in enumerate(loader):
preprocess_time = time_sync() - preproc_end
# get output and calculate loss
age_out, age_target, gender_out, gender_target, process_time = process_batch(
mivolo_model, input, target, dataset.num_classes_gender
)
gender_out, gender_target = postprocess_gender(gender_out, gender_target)
age_out, age_target = postprocess_age(age_out, age_target, dataset)
d_stat.update_gender_accuracy(gender_out, gender_target)
if d_stat.is_regression:
d_stat.update_regression_age_metrics(age_out, age_target)
else:
d_stat.update_age_accuracy(age_out, age_target)
d_stat.update_time(process_time, preprocess_time, input.shape[0])
if batch_idx % LOG_FREQUENCY == 0:
_logger.info(
"Test: [{0:>4d}/{1}] " "{2}".format(batch_idx, len(loader), d_stat.get_info_str(input.size(0)))
)
preproc_end = time_sync()
# model info
results = dict(
model=args.checkpoint,
dataset_name=args.dataset_name,
param_count=round(mivolo_model.param_count / 1e6, 2),
img_size=mivolo_model.input_size,
use_faces=mivolo_model.meta.use_face_crops,
use_persons=mivolo_model.meta.use_persons,
in_chans=mivolo_model.meta.in_chans,
batch=args.batch_size,
)
# metrics info
results.update(d_stat.get_result())
return results
def main():
parser = get_parser()
setup_default_logging()
args = parser.parse_args()
if torch.cuda.is_available():
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True
results = validate(args)
result_str = " * Age Acc@1 {:.3f} ({:.3f})".format(results["agetop1"], results["agetop1_err"])
if "gendertop1" in results:
result_str += " Gender Acc@1 1 {:.3f} ({:.3f})".format(results["gendertop1"], results["gendertop1_err"])
result_str += " Mean inference time {:.3f} ms Mean preprocessing time {:.3f}".format(
results["mean_inference_time"], results["mean_preprocessing_time"]
)
_logger.info(result_str)
if args.draw_hist and "per_age_error" in results:
err = [sum(v) / len(v) for k, v in results["per_age_error"].items()]
ages = list(results["per_age_error"].keys())
sns.scatterplot(x=ages, y=err, hue=err)
plt.legend([], [], frameon=False)
plt.xlabel("Age")
plt.ylabel("MAE")
plt.savefig("age_error.png", dpi=300)
if args.results_file:
write_results(args.results_file, results, format=args.results_format)
# output results in JSON to stdout w/ delimiter for runner script
print(f"--result\n{json.dumps(results, indent=4)}")
if __name__ == "__main__":
main()