File size: 5,403 Bytes
bf53f45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import csv
import json
import time
from collections import OrderedDict, defaultdict

import torch
from mivolo.data.misc import cumulative_error, cumulative_score
from timm.utils import AverageMeter, accuracy


def time_sync():
    """Return the current wall-clock time, synchronizing CUDA first if present.

    When a CUDA device is available, ``torch.cuda.synchronize()`` is called
    before the clock is read so that the timestamp is taken after the
    synchronization point rather than right after kernels were queued.

    Returns:
        float: the value of ``time.time()`` read after the optional sync.
    """
    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.synchronize()
    timestamp = time.time()
    return timestamp


def write_results(results_file, results, format="csv"):
    """Write evaluation results to ``results_file`` as JSON or CSV.

    Args:
        results_file: path of the file to create (it is truncated if present).
        results: for JSON, any object ``json.dump`` can serialize. For CSV,
            a mapping or a list/tuple of mappings; a single mapping is
            treated as one row.
        format: ``"json"`` selects JSON output (indented by 4 spaces); any
            other value selects CSV.

    CSV column names are taken from the keys of the first row. When the row
    sequence is empty nothing is written and the file is left empty.
    """
    as_json = format == "json"
    # The csv module emits its own line terminators and asks for the file to
    # be opened with newline=""; otherwise platforms that translate "\n"
    # produce "\r\r\n" and readers see blank rows. JSON keeps the default
    # newline handling so its output is unchanged.
    newline = None if as_json else ""
    with open(results_file, mode="w", newline=newline) as cf:
        if as_json:
            json.dump(results, cf, indent=4)
            return

        rows = results if isinstance(results, (list, tuple)) else [results]
        if not rows:
            return

        dw = csv.DictWriter(cf, fieldnames=list(rows[0].keys()))
        dw.writeheader()
        dw.writerows(rows)


class Metrics:
    """Accumulate running evaluation statistics for age and gender predictions.

    Timing, loss and accuracy values are stored in ``AverageMeter`` objects.
    The age metrics depend on the mode chosen at construction time:

    * regression mode (``age_classes is None``): mean absolute error, the
      value returned by ``cumulative_score`` at ``l_for_cs`` and the value
      returned by ``cumulative_error`` at 20 are tracked;
    * classification mode (``age_classes`` given): exact-match accuracy in
      percent is tracked.
    """

    def __init__(self, l_for_cs, draw_hist, age_classes=None):
        """Create empty meters.

        Args:
            l_for_cs: threshold forwarded to ``cumulative_score``; only
                stored and used in regression mode.
            draw_hist: when truthy, per-sample absolute age errors are
                collected per target age in regression mode.
            age_classes: ``None`` selects regression mode; any other value
                selects classification mode (the value itself is not stored).
        """
        self.batch_time = AverageMeter()
        self.preproc_batch_time = AverageMeter()
        self.seen = 0  # total number of samples reported via update_time

        self.losses = AverageMeter()
        self.top1_m_gender = AverageMeter()
        # Holds mean absolute error in regression mode, accuracy (%) otherwise.
        self.top1_m_age = AverageMeter()

        if age_classes is None:
            self.is_regression = True
            self.av_csl_age = AverageMeter()
            self.max_error = AverageMeter()
            # integer target age -> list of absolute errors for that age
            self.per_age_error = defaultdict(list)
            self.l_for_cs = l_for_cs
        else:
            self.is_regression = False

        self.draw_hist = draw_hist

    def update_regression_age_metrics(self, age_out, age_target):
        """Update the regression age meters with one batch.

        Args:
            age_out: tensor of predicted ages; its first dimension is the
                batch size.
            age_target: tensor of target ages, matched element-wise with
                ``age_out``.

        An empty batch is ignored, consistent with the other update methods.
        """
        batch_size = age_out.size(0)
        if batch_size == 0:
            # Without this guard the mean absolute error below is 0/0 (NaN)
            # and would be fed into the running meters.
            return

        age_abs_err = torch.abs(age_out - age_target)
        age_acc1 = torch.sum(age_abs_err) / batch_size
        age_csl = cumulative_score(age_out, age_target, self.l_for_cs)
        me = cumulative_error(age_out, age_target, 20)

        self.top1_m_age.update(age_acc1.item(), batch_size)
        self.av_csl_age.update(age_csl.item(), batch_size)
        self.max_error.update(me.item(), batch_size)

        if self.draw_hist:
            # Bucket each sample's absolute error under its truncated target age.
            for i in range(batch_size):
                self.per_age_error[int(age_target[i].item())].append(age_abs_err[i].item())

    def update_age_accuracy(self, age_out, age_target):
        """Update the age accuracy meter (classification mode) with one batch.

        Accuracy is the percentage of positions where ``age_out`` equals
        ``age_target``. An empty batch is ignored.
        """
        batch_size = age_out.size(0)
        if batch_size == 0:
            return
        correct = torch.sum(age_out == age_target)
        age_acc1 = correct * 100.0 / batch_size
        self.top1_m_age.update(age_acc1.item(), batch_size)

    def update_gender_accuracy(self, gender_out, gender_target):
        """Update the gender top-1 meter with one batch.

        ``gender_out`` may be ``None`` or empty, in which case nothing is
        recorded. The value is the first element returned by ``accuracy``
        called with ``topk=(1,)``.
        """
        if gender_out is None or gender_out.size(0) == 0:
            return
        batch_size = gender_out.size(0)
        gender_acc1 = accuracy(gender_out, gender_target, topk=(1,))[0]
        if gender_acc1 is not None:
            self.top1_m_gender.update(gender_acc1.item(), batch_size)

    def update_loss(self, loss, batch_size):
        """Record a scalar loss tensor weighted by ``batch_size``."""
        self.losses.update(loss.item(), batch_size)

    def update_time(self, process_time, preprocess_time, batch_size):
        """Record per-batch timings and add ``batch_size`` to the sample count."""
        self.seen += batch_size
        self.batch_time.update(process_time)
        self.preproc_batch_time.update(preprocess_time)

    def get_info_str(self, batch_size):
        """Return a one-line progress string with current and average values.

        Args:
            batch_size: batch size used to derive the average throughput
                (samples per second).

        Average time and throughput are reported as 0 when no timing has been
        recorded yet or the measured time is zero, instead of raising
        ``ZeroDivisionError``.
        """
        n_batches = self.batch_time.count
        total_time = self.preproc_batch_time.sum + self.batch_time.sum
        avg_time = total_time / n_batches if n_batches else 0.0
        rate_avg = batch_size / avg_time if avg_time > 0 else 0.0
        cur_time = self.batch_time.val + self.preproc_batch_time.val
        middle_info = (
            "Time: {cur_time:.3f}s ({avg_time:.3f}s, {rate_avg:>7.2f}/s)  "
            "Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  "
            "Gender Acc: {top1gender.val:>7.2f} ({top1gender.avg:>7.2f}) ".format(
                cur_time=cur_time,
                avg_time=avg_time,
                rate_avg=rate_avg,
                loss=self.losses,
                top1gender=self.top1_m_gender,
            )
        )

        if self.is_regression:
            age_info = (
                "Age CS@{l_for_cs}: {csl.val:>7.4f} ({csl.avg:>7.4f})  "
                "Age CE@20: {max_error.val:>7.4f} ({max_error.avg:>7.4f})  "
                "Age ME: {top1age.val:>7.2f} ({top1age.avg:>7.2f})".format(
                    top1age=self.top1_m_age, csl=self.av_csl_age, max_error=self.max_error, l_for_cs=self.l_for_cs
                )
            )
        else:
            age_info = "Age Acc: {top1age.val:>7.2f} ({top1age.avg:>7.2f})".format(top1age=self.top1_m_age)

        return middle_info + age_info

    def get_result(self):
        """Return the aggregated metrics as an ``OrderedDict``.

        Always present: ``mean_inference_time`` and
        ``mean_preprocessing_time`` (milliseconds per sample), ``agetop1``
        and ``agetop1_err``. Regression mode adds ``max_error``, ``csl`` and
        ``per_age_error``. ``gendertop1`` / ``gendertop1_err`` are added only
        when at least one gender batch was recorded.

        Per-sample times are reported as 0 when no samples were seen, instead
        of raising ``ZeroDivisionError``.
        """
        age_top1a = self.top1_m_age.avg
        gender_top1 = self.top1_m_gender.avg if self.top1_m_gender.count > 0 else None

        seen = self.seen
        mean_per_image_time = self.batch_time.sum / seen if seen else 0.0
        mean_preprocessing_time = self.preproc_batch_time.sum / seen if seen else 0.0

        # NOTE(review): in regression mode top1_m_age holds a mean absolute
        # error, so "100 - value" below is not an error rate; kept as-is for
        # output compatibility -- confirm what consumers expect.
        results = OrderedDict(
            mean_inference_time=mean_per_image_time * 1e3,
            mean_preprocessing_time=mean_preprocessing_time * 1e3,
            agetop1=round(age_top1a, 4),
            agetop1_err=round(100 - age_top1a, 4),
        )

        if self.is_regression:
            results.update(
                dict(
                    max_error=self.max_error.avg,
                    csl=self.av_csl_age.avg,
                    per_age_error=self.per_age_error,
                )
            )

        if gender_top1 is not None:
            results.update(dict(gendertop1=round(gender_top1, 4), gendertop1_err=round(100 - gender_top1, 4)))

        return results