Spaces:
Sleeping
Sleeping
File size: 8,377 Bytes
bf53f45 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 |
import argparse
import json
import logging
from typing import Optional, Tuple

import matplotlib.pyplot as plt
import seaborn as sns
import torch
from timm.utils import setup_default_logging

from eval_tools import Metrics, time_sync, write_results
from mivolo.data.dataset import build as build_data
from mivolo.model.mi_volo import MiVOLO
# A progress line is logged once every LOG_FREQUENCY batches during validation.
LOG_FREQUENCY = 10
# Module-level logger shared by validate() and main(); registered under the name "inference".
_logger = logging.getLogger("inference")
def get_parser():
    """Build the command-line parser for evaluating a pretrained MiVOLO checkpoint.

    Options are declared below as ``(flag, add_argument keyword arguments)`` pairs and
    registered in that order, so ``--help`` lists them in the same sequence.
    """
    supported_datasets = ["utk", "imdb", "lagenda", "fairface", "adience", "agedb", "cacd"]

    option_specs = [
        # --- data ---
        ("--dataset_images", dict(default="", type=str, required=True, help="path to images")),
        ("--dataset_annotations", dict(default="", type=str, required=True, help="path to annotations")),
        (
            "--dataset_name",
            dict(default=None, type=str, required=True, choices=supported_datasets, help="dataset name"),
        ),
        ("--split", dict(default="validation", help="dataset splits separated by comma (default: validation)")),
        # --- model / runtime ---
        ("--checkpoint", dict(default="", type=str, required=True, help="path to mivolo checkpoint")),
        ("--batch-size", dict(default=64, type=int, help="batch size")),
        (
            "--workers",
            dict(default=4, type=int, metavar="N", help="number of data loading workers (default: 4)"),
        ),
        ("--device", dict(default="cuda", type=str, help="Device (accelerator) to use.")),
        ("--l-for-cs", dict(type=int, default=5, help="L for CS (cumulative score)")),
        ("--half", dict(action="store_true", default=False, help="use half-precision model")),
        (
            "--with-persons",
            dict(action="store_true", default=False, help="If the model will run with persons, if available"),
        ),
        (
            "--disable-faces",
            dict(action="store_true", default=False, help="If the model will use only persons if available"),
        ),
        ("--draw-hist", dict(action="store_true", help="Draws the hist of error by age")),
        # --- output ---
        (
            "--results-file",
            dict(
                default="",
                type=str,
                metavar="FILENAME",
                help="Output csv file for validation results (summary)",
            ),
        ),
        (
            "--results-format",
            dict(default="csv", type=str, help="Format for results file one of (csv, json) (default: csv)."),
        ),
    ]

    parser = argparse.ArgumentParser(description="PyTorch MiVOLO Validation")
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser
def process_batch(
    mivolo_model: MiVOLO,
    input: torch.Tensor,
    target: torch.Tensor,
    num_classes_gender: int = 2,
):
    """Run the model on one batch and split its output into gender and age parts.

    Args:
        mivolo_model: loaded model; ``inference`` produces the raw output and
            ``meta.only_age`` tells whether that output also carries gender scores.
        input: batch of model inputs as yielded by the data loader.
        target: per-sample labels; column 0 is age and column 1 is gender, with
            ``-1`` marking a label that is not valid.
        num_classes_gender: number of leading output columns that hold gender scores.

    Returns:
        ``(age_out, age_target, gender_out, gender_target, process_time)``.
        ``age_target`` is ``target[:, 0]`` reshaped to ``(B, 1)``. The gender parts are
        ``None`` for age-only models. ``process_time`` covers only the
        ``mivolo_model.inference`` call.

    Raises:
        ValueError: if every age label and every gender label in the batch is ``-1``.
    """
    # Validate before the clock starts so the check does not inflate process_time.
    # An explicit raise is used because an ``assert`` would be stripped under ``python -O``.
    # ``and`` short-circuits, so the gender column is only read when no age is valid.
    if bool((target[:, 0] == -1).all()) and bool((target[:, 1] == -1).all()):
        raise ValueError("batch has no valid age or gender targets (all labels are -1)")

    start = time_sync()
    output = mivolo_model.inference(input)
    # Stop timing right after the forward pass; the slicing below is bookkeeping only.
    process_time = time_sync() - start

    if not mivolo_model.meta.only_age:
        # Output columns: the first num_classes_gender are gender scores, the rest are age.
        gender_out = output[:, :num_classes_gender]
        gender_target = target[:, 1]
        age_out = output[:, num_classes_gender:]
    else:
        age_out = output
        gender_out, gender_target = None, None

    age_target = target[:, 0].unsqueeze(1)
    return age_out, age_target, gender_out, gender_target, process_time
def _filter_invalid_target(out: torch.tensor, target: torch.tensor):
# exclude samples where target gt == -1, that marks sample is not valid
mask = target != -1
return out[mask], target[mask]
def postprocess_gender(
    gender_out: Optional[torch.Tensor], gender_target: Optional[torch.Tensor]
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Remove samples whose gender label is the ``-1`` "not valid" marker.

    When ``gender_target`` is ``None`` (``process_batch`` yields ``None`` gender parts for
    age-only models), both inputs are returned unchanged. Otherwise the masking is
    delegated to ``_filter_invalid_target``.
    """
    if gender_target is None:
        return gender_out, gender_target
    return _filter_invalid_target(gender_out, gender_target)
def postprocess_age(
    age_out: torch.Tensor, age_target: torch.Tensor, dataset
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Drop invalid samples and map age predictions/targets back to comparable units.

    Args:
        age_out: normalized age predictions.
        age_target: age labels, with ``-1`` marking an invalid sample.
        dataset: dataset object (presumably the one returned by ``build_data``) that
            provides ``max_age``, ``min_age``, ``avg_age``, ``age_classes`` and ``_intervals``.

    Returns:
        ``(age_out, age_target)``, with invalid samples removed. In regression mode
        (``dataset.age_classes is None``) both are denormalized ages. In classification
        mode ``age_out`` holds bin indices computed from ``dataset._intervals``, and
        ``age_target`` is returned without denormalization.
    """
    age_out, age_target = _filter_invalid_target(age_out, age_target)

    # Revert the dataset's _norm_age() normalization: value * (max_age - min_age) + avg_age.
    age_scale = dataset.max_age - dataset.min_age
    age_out = age_out * age_scale + dataset.avg_age
    # Denormalized predictions can come out negative; an age cannot, so clamp at zero.
    age_out = torch.clamp(age_out, min=0)

    if dataset.age_classes is not None:
        # Classification: round each predicted age, then find its bin.
        age_out = torch.round(age_out)
        if dataset._intervals.device != age_out.device:
            # Store the moved bin edges back on the dataset so the copy happens once, not per batch.
            dataset._intervals = dataset._intervals.to(age_out.device)
        # searchsorted(side="right") - 1 is the index of the last edge <= value.
        # _intervals must be sorted ascending; presumably it holds bin lower edges — confirm.
        age_out = torch.searchsorted(dataset._intervals, age_out, side="right") - 1
    else:
        age_target = age_target * age_scale + dataset.avg_age
    return age_out, age_target
def validate(args):
    """Evaluate a MiVOLO checkpoint on one dataset split and return a summary dict.

    Builds the model and data loader from ``args`` and runs every batch through
    ``process_batch``, ``postprocess_gender`` and ``postprocess_age``. Metrics and
    timings are accumulated in a ``Metrics`` instance. The returned dict holds model
    info (checkpoint, parameter count in millions, input size, crop usage, channels,
    batch size) followed by the entries of ``Metrics.get_result()``.
    """
    if torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.benchmark = True

    model = MiVOLO(
        args.checkpoint,
        args.device,
        half=args.half,
        use_persons=args.with_persons,
        disable_faces=args.disable_faces,
        verbose=True,
    )
    dataset, loader = build_data(
        name=args.dataset_name,
        images_path=args.dataset_images,
        annotations_path=args.dataset_annotations,
        split=args.split,
        mivolo_model=model,  # the dataset reads meta information from the model
        workers=args.workers,
        batch_size=args.batch_size,
    )
    metrics = Metrics(args.l_for_cs, args.draw_hist, dataset.age_classes)

    # Warm up first so the first timed batch is not an outlier (this matters when
    # comparing torchscript with non-torchscript runs).
    model.warmup(args.batch_size)

    last_batch_end = time_sync()
    for step, (images, targets) in enumerate(loader):
        # Time between the end of the previous iteration and receiving this batch.
        preprocess_time = time_sync() - last_batch_end

        # Forward pass, then drop invalid labels and bring outputs into metric units.
        age_out, age_target, gender_out, gender_target, process_time = process_batch(
            model, images, targets, dataset.num_classes_gender
        )
        gender_out, gender_target = postprocess_gender(gender_out, gender_target)
        age_out, age_target = postprocess_age(age_out, age_target, dataset)

        metrics.update_gender_accuracy(gender_out, gender_target)
        if metrics.is_regression:
            metrics.update_regression_age_metrics(age_out, age_target)
        else:
            metrics.update_age_accuracy(age_out, age_target)
        metrics.update_time(process_time, preprocess_time, images.shape[0])

        if not step % LOG_FREQUENCY:
            _logger.info(
                "Test: [{0:>4d}/{1}] " "{2}".format(step, len(loader), metrics.get_info_str(images.size(0)))
            )
        last_batch_end = time_sync()

    summary = {
        "model": args.checkpoint,
        "dataset_name": args.dataset_name,
        "param_count": round(model.param_count / 1e6, 2),
        "img_size": model.input_size,
        "use_faces": model.meta.use_face_crops,
        "use_persons": model.meta.use_persons,
        "in_chans": model.meta.in_chans,
        "batch": args.batch_size,
    }
    summary.update(metrics.get_result())
    return summary
def _save_age_error_plot(per_age_error) -> None:
    """Scatter-plot the mean error for each age and save it as ``age_error.png``.

    ``per_age_error`` maps an age to a sequence of errors. Ages whose sequence is empty
    are skipped because they have no mean.
    """
    ages, mean_errors = [], []
    for age, errors in per_age_error.items():
        if not errors:
            continue
        ages.append(age)
        mean_errors.append(sum(errors) / len(errors))
    sns.scatterplot(x=ages, y=mean_errors, hue=mean_errors)
    plt.legend([], [], frameon=False)
    plt.xlabel("Age")
    plt.ylabel("MAE")
    plt.savefig("age_error.png", dpi=300)
    # Release the figure so repeated calls do not draw on top of each other.
    plt.close()


def main():
    """Command-line entry point: evaluate a checkpoint and report the results.

    Logs a one-line summary. Optionally saves the per-age error plot (``--draw-hist``)
    and writes ``--results-file``. Finally prints the full results as JSON after a
    ``--result`` delimiter line, for a runner script to parse.
    """
    parser = get_parser()
    setup_default_logging()
    args = parser.parse_args()
    # validate() enables TF32 / cuDNN benchmark itself, so that setup is not repeated here.
    results = validate(args)

    result_str = " * Age Acc@1 {:.3f} ({:.3f})".format(results["agetop1"], results["agetop1_err"])
    if "gendertop1" in results:
        result_str += " Gender Acc@1 {:.3f} ({:.3f})".format(results["gendertop1"], results["gendertop1_err"])
    result_str += " Mean inference time {:.3f} ms Mean preprocessing time {:.3f}".format(
        results["mean_inference_time"], results["mean_preprocessing_time"]
    )
    _logger.info(result_str)

    if args.draw_hist and "per_age_error" in results:
        _save_age_error_plot(results["per_age_error"])

    if args.results_file:
        write_results(args.results_file, results, format=args.results_format)

    # output results in JSON to stdout w/ delimiter for runner script
    print(f"--result\n{json.dumps(results, indent=4)}")


if __name__ == "__main__":
    main()
|