import csv
import json
import time
from collections import OrderedDict, defaultdict

import torch
from timm.utils import AverageMeter, accuracy

from mivolo.data.misc import cumulative_error, cumulative_score


def time_sync():
    """Return ``time.time()`` after waiting for pending CUDA work.

    Synchronizing first (when CUDA is available) makes wall-clock timings
    include asynchronously queued GPU kernels.
    """
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time()


def write_results(results_file, results, format="csv"):
    """Write evaluation results to ``results_file``.

    Args:
        results_file: Destination path; the file is overwritten.
        results: A mapping, or a list/tuple of mappings. For CSV output every
            mapping becomes a row and the keys of the first one form the header.
        format: ``"json"`` to dump ``results`` as indented JSON; any other
            value writes CSV.
    """
    is_json = format == "json"
    # The csv module requires newline="" so it controls line endings itself
    # (otherwise "\r\r\n" rows appear on Windows); JSON keeps default handling.
    with open(results_file, mode="w", newline=None if is_json else "") as cf:
        if is_json:
            json.dump(results, cf, indent=4)
            return
        rows = results if isinstance(results, (list, tuple)) else [results]
        if not rows:
            # Nothing to write: leave an empty file, as before.
            return
        writer = csv.DictWriter(cf, fieldnames=rows[0].keys())
        writer.writeheader()
        writer.writerows(rows)


class Metrics:
    """Accumulate loss, timing, gender accuracy and age metrics over batches.

    Age is evaluated either as regression (``age_classes is None``) or as
    classification accuracy (otherwise).

    Args:
        l_for_cs: Threshold passed to ``cumulative_score`` and reported as
            ``CS@<l_for_cs>`` (regression mode only).
        draw_hist: When True (regression mode), per-sample absolute errors
            are collected into ``per_age_error`` keyed by integer target age.
        age_classes: ``None`` selects regression mode; any other value
            selects classification mode.
    """

    def __init__(self, l_for_cs, draw_hist, age_classes=None):
        self.batch_time = AverageMeter()
        self.preproc_batch_time = AverageMeter()
        self.seen = 0  # total number of samples timed so far

        self.losses = AverageMeter()
        self.top1_m_gender = AverageMeter()
        # Regression mode: mean absolute error; classification mode: accuracy (%).
        self.top1_m_age = AverageMeter()

        if age_classes is None:
            self.is_regression = True
            self.av_csl_age = AverageMeter()
            # Holds cumulative_error(..., 20), reported as "Age CE@20".
            self.max_error = AverageMeter()
            self.per_age_error = defaultdict(list)
            self.l_for_cs = l_for_cs
        else:
            self.is_regression = False
        self.draw_hist = draw_hist

    def update_regression_age_metrics(self, age_out, age_target):
        """Update MAE, CS@L and CE@20 meters with one batch of age predictions.

        Empty batches are ignored (they would otherwise divide by zero and
        poison the running averages).
        """
        batch_size = age_out.size(0)
        if batch_size == 0:
            return

        age_abs_err = torch.abs(age_out - age_target)
        age_mae = torch.sum(age_abs_err) / batch_size
        # NOTE(review): cumulative_score / cumulative_error semantics live in
        # mivolo.data.misc; they are assumed to return scalar tensors here.
        age_csl = cumulative_score(age_out, age_target, self.l_for_cs)
        me = cumulative_error(age_out, age_target, 20)

        self.top1_m_age.update(age_mae.item(), batch_size)
        self.av_csl_age.update(age_csl.item(), batch_size)
        self.max_error.update(me.item(), batch_size)

        if self.draw_hist:
            # Move to host once instead of forcing a device sync per element.
            targets = age_target.detach().cpu()
            errors = age_abs_err.detach().cpu()
            for i in range(batch_size):
                self.per_age_error[int(targets[i].item())].append(errors[i].item())

    def update_age_accuracy(self, age_out, age_target):
        """Update age classification accuracy (percent) with one batch."""
        batch_size = age_out.size(0)
        if batch_size == 0:
            return
        correct = torch.sum(age_out == age_target)
        age_acc1 = correct * 100.0 / batch_size
        self.top1_m_age.update(age_acc1.item(), batch_size)

    def update_gender_accuracy(self, gender_out, gender_target):
        """Update top-1 gender accuracy; skips missing or empty outputs."""
        if gender_out is None or gender_out.size(0) == 0:
            return
        batch_size = gender_out.size(0)
        gender_acc1 = accuracy(gender_out, gender_target, topk=(1,))[0]
        self.top1_m_gender.update(gender_acc1.item(), batch_size)

    def update_loss(self, loss, batch_size):
        """Record a batch loss weighted by ``batch_size``."""
        self.losses.update(loss.item(), batch_size)

    def update_time(self, process_time, preprocess_time, batch_size):
        """Record per-batch processing and preprocessing times."""
        self.seen += batch_size
        self.batch_time.update(process_time)
        self.preproc_batch_time.update(preprocess_time)

    def get_info_str(self, batch_size):
        """Return a one-line progress summary for logging.

        Safe to call before any timing was recorded (times/rates show as 0).
        """
        n_batches = self.batch_time.count
        total_time = self.preproc_batch_time.sum + self.batch_time.sum
        avg_time = total_time / n_batches if n_batches else 0.0
        cur_time = self.batch_time.val + self.preproc_batch_time.val
        rate_avg = batch_size / avg_time if avg_time > 0 else 0.0

        middle_info = (
            "Time: {cur_time:.3f}s ({avg_time:.3f}s, {rate_avg:>7.2f}/s)  "
            "Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  "
            "Gender Acc: {top1gender.val:>7.2f} ({top1gender.avg:>7.2f}) ".format(
                cur_time=cur_time,
                avg_time=avg_time,
                rate_avg=rate_avg,
                loss=self.losses,
                top1gender=self.top1_m_gender,
            )
        )

        if self.is_regression:
            age_info = (
                "Age CS@{l_for_cs}: {csl.val:>7.4f} ({csl.avg:>7.4f}) "
                "Age CE@20: {max_error.val:>7.4f} ({max_error.avg:>7.4f}) "
                "Age ME: {top1age.val:>7.2f} ({top1age.avg:>7.2f})".format(
                    top1age=self.top1_m_age, csl=self.av_csl_age, max_error=self.max_error, l_for_cs=self.l_for_cs
                )
            )
        else:
            age_info = "Age Acc: {top1age.val:>7.2f} ({top1age.avg:>7.2f})".format(top1age=self.top1_m_age)

        return middle_info + age_info

    def get_result(self):
        """Return the aggregated results as an ``OrderedDict``.

        Times are mean per-sample values in milliseconds (0.0 if nothing was
        timed). Gender keys are present only if gender accuracy was recorded;
        regression-only keys only in regression mode.
        """
        age_top1a = self.top1_m_age.avg
        gender_top1 = self.top1_m_gender.avg if self.top1_m_gender.count > 0 else None

        seen = self.seen
        mean_per_image_time = self.batch_time.sum / seen if seen else 0.0
        mean_preprocessing_time = self.preproc_batch_time.sum / seen if seen else 0.0

        # NOTE(review): in regression mode agetop1 is the mean absolute error,
        # so agetop1_err (100 - MAE) is not meaningful; kept for compatibility.
        results = OrderedDict(
            mean_inference_time=mean_per_image_time * 1e3,
            mean_preprocessing_time=mean_preprocessing_time * 1e3,
            agetop1=round(age_top1a, 4),
            agetop1_err=round(100 - age_top1a, 4),
        )
        if self.is_regression:
            results.update(
                dict(
                    max_error=self.max_error.avg,
                    csl=self.av_csl_age.avg,
                    per_age_error=self.per_age_error,
                )
            )

        if gender_top1 is not None:
            results.update(dict(gendertop1=round(gender_top1, 4), gendertop1_err=round(100 - gender_top1, 4)))

        return results