Spaces:
Sleeping
Sleeping
import argparse | |
import json | |
import logging | |
from typing import Tuple | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
import torch | |
from eval_tools import Metrics, time_sync, write_results | |
from mivolo.data.dataset import build as build_data | |
from mivolo.model.mi_volo import MiVOLO | |
from timm.utils import setup_default_logging | |
# Module-level logger; records are emitted under the "inference" logger name.
_logger = logging.getLogger("inference")
# A progress line is logged for every batch whose index is a multiple of this value.
LOG_FREQUENCY = 10
def get_parser():
    """Build the command-line parser for the validation script.

    The options are declared as an ordered table of ``(flag, keyword arguments)``
    pairs and registered in that order, so the generated ``--help`` output lists
    them in the same sequence as the table.

    Returns:
        argparse.ArgumentParser: parser with every validation option registered.
    """
    option_table = (
        ("--dataset_images", dict(default="", type=str, required=True, help="path to images")),
        ("--dataset_annotations", dict(default="", type=str, required=True, help="path to annotations")),
        (
            "--dataset_name",
            dict(
                default=None,
                type=str,
                required=True,
                choices=["utk", "imdb", "lagenda", "fairface", "adience", "agedb", "cacd"],
                help="dataset name",
            ),
        ),
        ("--split", dict(default="validation", help="dataset splits separated by comma (default: validation)")),
        ("--checkpoint", dict(default="", type=str, required=True, help="path to mivolo checkpoint")),
        ("--batch-size", dict(default=64, type=int, help="batch size")),
        (
            "--workers",
            dict(default=4, type=int, metavar="N", help="number of data loading workers (default: 4)"),
        ),
        ("--device", dict(default="cuda", type=str, help="Device (accelerator) to use.")),
        ("--l-for-cs", dict(type=int, default=5, help="L for CS (cumulative score)")),
        ("--half", dict(action="store_true", default=False, help="use half-precision model")),
        (
            "--with-persons",
            dict(action="store_true", default=False, help="If the model will run with persons, if available"),
        ),
        (
            "--disable-faces",
            dict(action="store_true", default=False, help="If the model will use only persons if available"),
        ),
        ("--draw-hist", dict(action="store_true", help="Draws the hist of error by age")),
        (
            "--results-file",
            dict(
                default="",
                type=str,
                metavar="FILENAME",
                help="Output csv file for validation results (summary)",
            ),
        ),
        (
            "--results-format",
            dict(default="csv", type=str, help="Format for results file one of (csv, json) (default: csv)."),
        ),
    )

    cli = argparse.ArgumentParser(description="PyTorch MiVOLO Validation")
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli
def process_batch(
    mivolo_model: MiVOLO,
    input: torch.Tensor,
    target: torch.Tensor,
    num_classes_gender: int = 2,
):
    """Run the model on one batch and split its output into age and gender parts.

    Args:
        mivolo_model: model wrapper; its ``inference`` method and its
            ``meta.only_age`` flag are used.
        input: batch handed unchanged to ``mivolo_model.inference``.
        target: ground truth indexed as ``target[:, 0]`` (age) and
            ``target[:, 1]`` (gender); a value of ``-1`` marks a missing label.
        num_classes_gender: number of leading output columns that hold the
            gender prediction when the model is not age-only.

    Returns:
        The tuple ``(age_out, age_target, gender_out, gender_target, process_time)``.
        ``gender_out`` and ``gender_target`` are both ``None`` for an age-only
        model. ``age_target`` is ``target[:, 0]`` with a trailing dimension
        added. ``process_time`` is the difference of two ``time_sync()``
        readings taken around inference, the validity check and the output
        slicing.

    Raises:
        AssertionError: if every age label and every gender label in the batch
            equals ``-1``.
    """
    start = time_sync()
    output = mivolo_model.inference(input)

    # target with age == -1 and gender == -1 marks that sample is not valid.
    # Reduce with the tensor-level ``all()`` (one call per column) rather than the
    # Python builtin, which would iterate over the batch element by element.
    no_age_labels = bool((target[:, 0] == -1).all())
    no_gender_labels = bool((target[:, 1] == -1).all())
    assert not (no_age_labels and no_gender_labels), "batch has neither age nor gender labels"

    if not mivolo_model.meta.only_age:
        # The first ``num_classes_gender`` columns are gender, the remainder is age.
        gender_out = output[:, :num_classes_gender]
        gender_target = target[:, 1]
        age_out = output[:, num_classes_gender:]
    else:
        age_out = output
        gender_out, gender_target = None, None

    # measure elapsed time
    process_time = time_sync() - start

    age_target = target[:, 0].unsqueeze(1)
    return age_out, age_target, gender_out, gender_target, process_time
def _filter_invalid_target(out: torch.tensor, target: torch.tensor):
    """Drop the entries whose ground truth carries the ``-1`` "not valid" marker.

    One boolean mask is derived from ``target`` and used to index both tensors;
    the filtered ``(out, target)`` pair is returned.
    """
    is_valid = ~(target == -1)
    kept_out = out[is_valid]
    kept_target = target[is_valid]
    return kept_out, kept_target
def postprocess_gender(gender_out: torch.tensor, gender_target: torch.tensor) -> Tuple[torch.tensor, torch.tensor]:
    """Remove samples without a valid gender label from the gender predictions.

    When ``gender_target`` is ``None`` both arguments are handed back untouched;
    otherwise the pair is passed through ``_filter_invalid_target``.
    """
    has_gender_labels = gender_target is not None
    if has_gender_labels:
        return _filter_invalid_target(gender_out, gender_target)
    return gender_out, gender_target
def postprocess_age(age_out: torch.tensor, age_target: torch.tensor, dataset) -> Tuple[torch.tensor, torch.tensor]:
    """Filter invalid samples and bring the age predictions back to the age scale.

    Revert _norm_age() operation. Output is 2 float tensors.

    Args:
        age_out: age predictions from the model.
        age_target: age ground truth; entries equal to ``-1`` are dropped together
            with the matching predictions.
        dataset: object providing ``min_age``, ``max_age``, ``avg_age``,
            ``age_classes`` and, when ``age_classes`` is set, ``_intervals``.

    Returns:
        ``(age_out, age_target)``. If ``dataset.age_classes`` is ``None`` both
        tensors are de-normalized. Otherwise ``age_out`` holds, for each rounded
        predicted age, the ``searchsorted`` position (``side="right"``) in
        ``dataset._intervals`` minus one, and ``age_target`` is only filtered.
    """
    age_out, age_target = _filter_invalid_target(age_out, age_target)

    age_span = dataset.max_age - dataset.min_age
    # clamp to 0 because age can be below zero
    age_out = torch.clamp(age_out * age_span + dataset.avg_age, min=0)

    if dataset.age_classes is None:
        # regression case: the target is de-normalized the same way as the output
        return age_out, age_target * age_span + dataset.avg_age

    # classification case
    rounded_age = torch.round(age_out)
    intervals = dataset._intervals
    if intervals.device != rounded_age.device:
        # keep the moved copy on the dataset so later batches reuse it
        intervals = intervals.to(rounded_age.device)
        dataset._intervals = intervals
    age_out = torch.searchsorted(intervals, rounded_age, side="right") - 1
    return age_out, age_target
def validate(args):
    """Evaluate a MiVOLO checkpoint on a dataset and collect the metrics.

    Builds the model and the data loader from ``args``, warms the model up,
    then feeds every batch through ``process_batch`` and the two
    post-processing helpers while a ``Metrics`` object accumulates the
    statistics.

    Args:
        args: parsed command-line namespace as produced by ``get_parser()``.

    Returns:
        dict: model/run description (checkpoint, dataset name, parameter count in
        millions, input size, face/person usage, input channels, batch size)
        updated with whatever ``Metrics.get_result()`` returns.
    """
    if torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.benchmark = True

    model_options = dict(
        half=args.half,
        use_persons=args.with_persons,
        disable_faces=args.disable_faces,
        verbose=True,
    )
    mivolo_model = MiVOLO(args.checkpoint, args.device, **model_options)

    dataset, loader = build_data(
        name=args.dataset_name,
        images_path=args.dataset_images,
        annotations_path=args.dataset_annotations,
        split=args.split,
        mivolo_model=mivolo_model,  # to get meta information from model
        workers=args.workers,
        batch_size=args.batch_size,
    )

    d_stat = Metrics(args.l_for_cs, args.draw_hist, dataset.age_classes)

    # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
    mivolo_model.warmup(args.batch_size)

    # Everything between the end of one iteration and the start of the next
    # (i.e. fetching the batch from the loader) counts as preprocessing time.
    last_batch_done = time_sync()
    for step, (images, labels) in enumerate(loader):
        preprocess_time = time_sync() - last_batch_done

        # get output and calculate loss
        age_out, age_target, gender_out, gender_target, process_time = process_batch(
            mivolo_model, images, labels, dataset.num_classes_gender
        )

        gender_out, gender_target = postprocess_gender(gender_out, gender_target)
        age_out, age_target = postprocess_age(age_out, age_target, dataset)

        d_stat.update_gender_accuracy(gender_out, gender_target)
        if not d_stat.is_regression:
            d_stat.update_age_accuracy(age_out, age_target)
        else:
            d_stat.update_regression_age_metrics(age_out, age_target)
        d_stat.update_time(process_time, preprocess_time, images.shape[0])

        if not step % LOG_FREQUENCY:
            progress = "Test: [{0:>4d}/{1}] " "{2}".format(step, len(loader), d_stat.get_info_str(images.size(0)))
            _logger.info(progress)

        last_batch_done = time_sync()

    # model info
    results = {
        "model": args.checkpoint,
        "dataset_name": args.dataset_name,
        "param_count": round(mivolo_model.param_count / 1e6, 2),
        "img_size": mivolo_model.input_size,
        "use_faces": mivolo_model.meta.use_face_crops,
        "use_persons": mivolo_model.meta.use_persons,
        "in_chans": mivolo_model.meta.in_chans,
        "batch": args.batch_size,
    }
    # metrics info
    results.update(d_stat.get_result())
    return results
def main():
    """Command-line entry point: validate, log a summary and export the results.

    Parses the arguments, runs ``validate``, logs a one-line summary, optionally
    saves a per-age error scatter plot to ``age_error.png`` and a results file,
    and finally prints the results as JSON after a ``--result`` delimiter line.
    """
    parser = get_parser()
    setup_default_logging()
    args = parser.parse_args()

    if torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.benchmark = True

    results = validate(args)

    # Assemble the summary line piece by piece; the gender part is optional.
    summary_parts = [" * Age Acc@1 {:.3f} ({:.3f})".format(results["agetop1"], results["agetop1_err"])]
    if "gendertop1" in results:
        summary_parts.append(
            " Gender Acc@1 1 {:.3f} ({:.3f})".format(results["gendertop1"], results["gendertop1_err"])
        )
    summary_parts.append(
        " Mean inference time {:.3f} ms Mean preprocessing time {:.3f}".format(
            results["mean_inference_time"], results["mean_preprocessing_time"]
        )
    )
    _logger.info("".join(summary_parts))

    if args.draw_hist and "per_age_error" in results:
        per_age_error = results["per_age_error"]
        # mean error of each age bucket, in the same order as the keys
        err = [sum(errors) / len(errors) for errors in per_age_error.values()]
        ages = list(per_age_error.keys())
        sns.scatterplot(x=ages, y=err, hue=err)
        plt.legend([], [], frameon=False)
        plt.xlabel("Age")
        plt.ylabel("MAE")
        plt.savefig("age_error.png", dpi=300)

    if args.results_file:
        write_results(args.results_file, results, format=args.results_format)

    # output results in JSON to stdout w/ delimiter for runner script
    print("--result\n" + json.dumps(results, indent=4))
# Run the validation CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()