File size: 8,377 Bytes
bf53f45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
import argparse
import json
import logging
from typing import Tuple

import matplotlib.pyplot as plt
import seaborn as sns
import torch
from eval_tools import Metrics, time_sync, write_results
from mivolo.data.dataset import build as build_data
from mivolo.model.mi_volo import MiVOLO
from timm.utils import setup_default_logging

# Module logger; all messages from this script are emitted under the "inference" logger name.
_logger = logging.getLogger("inference")
# validate() logs a progress line every LOG_FREQUENCY batches.
LOG_FREQUENCY = 10


def get_parser():
    """Build the command-line parser for MiVOLO validation.

    Returns:
        argparse.ArgumentParser: parser with dataset/checkpoint selection, runtime
        (batch size, workers, device, CS threshold), model input toggles and
        result-reporting options.
    """
    dataset_names = ["utk", "imdb", "lagenda", "fairface", "adience", "agedb", "cacd"]

    # (flag, add_argument keyword arguments); registered in this exact order so that
    # --help lists the options as declared here.
    option_specs = [
        # dataset / checkpoint selection
        ("--dataset_images", dict(default="", type=str, required=True, help="path to images")),
        ("--dataset_annotations", dict(default="", type=str, required=True, help="path to annotations")),
        (
            "--dataset_name",
            dict(default=None, type=str, required=True, choices=dataset_names, help="dataset name"),
        ),
        ("--split", dict(default="validation", help="dataset splits separated by comma (default: validation)")),
        ("--checkpoint", dict(default="", type=str, required=True, help="path to mivolo checkpoint")),
        # runtime
        ("--batch-size", dict(default=64, type=int, help="batch size")),
        (
            "--workers",
            dict(default=4, type=int, metavar="N", help="number of data loading workers (default: 4)"),
        ),
        ("--device", dict(default="cuda", type=str, help="Device (accelerator) to use.")),
        ("--l-for-cs", dict(type=int, default=5, help="L for CS (cumulative score)")),
        # model input toggles
        ("--half", dict(action="store_true", default=False, help="use half-precision model")),
        (
            "--with-persons",
            dict(action="store_true", default=False, help="If the model will run with persons, if available"),
        ),
        (
            "--disable-faces",
            dict(action="store_true", default=False, help="If the model will use only persons if available"),
        ),
        # reporting
        ("--draw-hist", dict(action="store_true", help="Draws the hist of error by age")),
        (
            "--results-file",
            dict(default="", type=str, metavar="FILENAME", help="Output csv file for validation results (summary)"),
        ),
        (
            "--results-format",
            dict(default="csv", type=str, help="Format for results file one of (csv, json) (default: csv)."),
        ),
    ]

    parser = argparse.ArgumentParser(description="PyTorch MiVOLO Validation")
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser


def process_batch(
    mivolo_model: MiVOLO,
    input: torch.Tensor,
    target: torch.Tensor,
    num_classes_gender: int = 2,
):
    """Run one batch through the model and split the raw output into age and gender parts.

    Args:
        mivolo_model: model wrapper; ``inference`` produces the raw outputs and
            ``meta.only_age`` tells whether the output contains a gender head.
        input: batch passed unchanged to ``mivolo_model.inference``.
        target: per-sample labels; column 0 is age, column 1 is gender, and -1 marks
            a missing/invalid label.
        num_classes_gender: number of leading output columns that belong to the gender head.

    Returns:
        ``(age_out, age_target, gender_out, gender_target, process_time)`` where
        ``age_target`` is ``target[:, 0]`` as a column, ``gender_out``/``gender_target``
        are ``None`` for age-only models, and ``process_time`` covers only the
        ``inference`` call.

    Raises:
        ValueError: if every sample in the batch has both its age and gender label set to -1.
    """
    # Check labels before starting the timer so validation does not inflate the measured
    # inference time; raise instead of assert so the check survives ``python -O``.
    # (all(age == -1) and all(gender == -1)) is equivalent to all(age == -1 & gender == -1).
    all_invalid = torch.all((target[:, 0] == -1) & (target[:, 1] == -1))
    if bool(all_invalid):
        raise ValueError("batch contains no valid age or gender targets (all labels are -1)")

    start = time_sync()
    output = mivolo_model.inference(input)
    # NOTE(review): accurate GPU timing relies on time_sync synchronizing the device — confirm in eval_tools
    process_time = time_sync() - start

    if mivolo_model.meta.only_age:
        age_out = output
        gender_out, gender_target = None, None
    else:
        # output layout: [gender logits (num_classes_gender columns) | age prediction]
        gender_out = output[:, :num_classes_gender]
        gender_target = target[:, 1]
        age_out = output[:, num_classes_gender:]

    age_target = target[:, 0].unsqueeze(1)

    return age_out, age_target, gender_out, gender_target, process_time


def _filter_invalid_target(out: torch.tensor, target: torch.tensor):
    # exclude samples where target gt == -1, that marks sample is not valid
    mask = target != -1
    return out[mask], target[mask]


def postprocess_gender(gender_out: torch.tensor, gender_target: torch.tensor) -> Tuple[torch.tensor, torch.tensor]:
    """Keep only the samples that carry a valid (not -1) gender label.

    When ``gender_target`` is ``None`` (process_batch returns that for age-only models)
    there is nothing to filter and both inputs are passed through unchanged.
    """
    has_gender_head = gender_target is not None
    if has_gender_head:
        return _filter_invalid_target(gender_out, gender_target)
    return gender_out, gender_target


def postprocess_age(age_out: torch.Tensor, age_target: torch.Tensor, dataset) -> Tuple[torch.Tensor, torch.Tensor]:
    """Revert the dataset's ``_norm_age()`` on predictions and drop samples with an invalid (-1) age.

    Predictions are mapped back with ``out * (max_age - min_age) + avg_age`` and clamped at 0.

    Args:
        age_out: raw age predictions from the model.
        age_target: age labels; -1 marks an invalid sample.
        dataset: provides ``min_age``, ``max_age``, ``avg_age``, ``age_classes`` and, for
            classification datasets, the bin edges ``_intervals`` (assumed sorted ascending,
            as ``torch.searchsorted`` requires).

    Returns:
        ``(age_out, age_target)`` restricted to valid samples:

        * regression (``dataset.age_classes is None``): both tensors denormalised back to ages;
        * classification: ``age_out`` holds integer bin indices produced by
          ``torch.searchsorted`` and ``age_target`` is returned as-is (not denormalised).
    """
    age_out, age_target = _filter_invalid_target(age_out, age_target)

    age_range = dataset.max_age - dataset.min_age
    age_out = age_out * age_range + dataset.avg_age
    # clamp to 0 because the denormalised prediction can be below zero
    age_out = torch.clamp(age_out, min=0)

    if dataset.age_classes is None:
        # regression: targets are normalised too, so bring them back to the same scale
        age_target = age_target * age_range + dataset.avg_age
        return age_out, age_target

    # classification: round to whole years, then take the index of the largest bin edge <= age
    if dataset._intervals.device != age_out.device:
        # cache the moved edges on the dataset so the transfer happens once, not every batch
        dataset._intervals = dataset._intervals.to(age_out.device)
    age_inds = torch.searchsorted(dataset._intervals, torch.round(age_out), side="right") - 1
    return age_inds, age_target


def validate(args):
    """Evaluate a MiVOLO checkpoint on one dataset split and collect metrics.

    Args:
        args: namespace produced by ``get_parser().parse_args()``.

    Returns:
        dict: run description (checkpoint, dataset name, parameter count in millions,
        input size, face/person usage, input channels, batch size) updated with the
        entries returned by ``Metrics.get_result()``.
    """
    if torch.cuda.is_available():
        # speed-oriented CUDA settings: allow TF32 matmuls and let cuDNN autotune kernels
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.benchmark = True

    model = MiVOLO(
        args.checkpoint,
        args.device,
        half=args.half,
        use_persons=args.with_persons,
        disable_faces=args.disable_faces,
        verbose=True,
    )
    dataset, loader = build_data(
        name=args.dataset_name,
        images_path=args.dataset_images,
        annotations_path=args.dataset_annotations,
        split=args.split,
        mivolo_model=model,  # the dataset reads meta information from the model
        workers=args.workers,
        batch_size=args.batch_size,
    )
    stats = Metrics(args.l_for_cs, args.draw_hist, dataset.age_classes)

    # warmup so the first timed batch is not an outlier (matters when comparing torchscript vs eager)
    model.warmup(args.batch_size)

    batch_ready_at = time_sync()
    for step, (inputs, targets) in enumerate(loader):
        # time spent waiting for the loader to deliver this batch
        preprocess_time = time_sync() - batch_ready_at

        age_out, age_target, gender_out, gender_target, process_time = process_batch(
            model, inputs, targets, dataset.num_classes_gender
        )
        gender_out, gender_target = postprocess_gender(gender_out, gender_target)
        age_out, age_target = postprocess_age(age_out, age_target, dataset)

        stats.update_gender_accuracy(gender_out, gender_target)
        if stats.is_regression:
            stats.update_regression_age_metrics(age_out, age_target)
        else:
            stats.update_age_accuracy(age_out, age_target)
        stats.update_time(process_time, preprocess_time, inputs.shape[0])

        if step % LOG_FREQUENCY == 0:
            _logger.info(f"Test: [{step:>4d}/{len(loader)}]  {stats.get_info_str(inputs.size(0))}")

        batch_ready_at = time_sync()

    results = {
        "model": args.checkpoint,
        "dataset_name": args.dataset_name,
        "param_count": round(model.param_count / 1e6, 2),
        "img_size": model.input_size,
        "use_faces": model.meta.use_face_crops,
        "use_persons": model.meta.use_persons,
        "in_chans": model.meta.in_chans,
        "batch": args.batch_size,
    }
    # append the accumulated metrics after the run description
    results.update(stats.get_result())
    return results


def main():
    """Command-line entry point: parse arguments, run validation and report the results.

    Logs a one-line summary. With ``--draw-hist`` (and per-age errors available) it saves a
    scatter plot of mean error per age to ``age_error.png``. With ``--results-file`` it
    writes the results there. Finally it prints the results as JSON to stdout after a
    ``--result`` delimiter line for runner scripts.
    """
    parser = get_parser()
    setup_default_logging()
    args = parser.parse_args()

    # CUDA backend flags (TF32 matmul, cudnn.benchmark) are configured at the start of validate().
    results = validate(args)

    result_str = " * Age Acc@1 {:.3f} ({:.3f})".format(results["agetop1"], results["agetop1_err"])
    if "gendertop1" in results:
        result_str += " Gender Acc@1 {:.3f} ({:.3f})".format(results["gendertop1"], results["gendertop1_err"])
    result_str += " Mean inference time {:.3f} ms Mean preprocessing time {:.3f}".format(
        results["mean_inference_time"], results["mean_preprocessing_time"]
    )
    _logger.info(result_str)

    if args.draw_hist and "per_age_error" in results:
        per_age_error = results["per_age_error"]
        ages = list(per_age_error.keys())
        # mean error per age; keys() and values() iterate in the same order
        err = [sum(errors) / len(errors) for errors in per_age_error.values()]
        sns.scatterplot(x=ages, y=err, hue=err)
        plt.legend([], [], frameon=False)
        plt.xlabel("Age")
        plt.ylabel("MAE")
        plt.savefig("age_error.png", dpi=300)
        plt.close()  # release the figure once it has been written to disk

    if args.results_file:
        write_results(args.results_file, results, format=args.results_format)

    # output results in JSON to stdout w/ delimiter for runner script
    print(f"--result\n{json.dumps(results, indent=4)}")


# Run the validation CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()