auto_avsr / eval.py
mpc001's picture
Upload 125 files
09481f3
raw
history blame
2.47 kB
#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 Imperial College London (Pingchuan Ma)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import os
import torch
import hydra
from pipelines.metrics.measures import get_wer
from pipelines.metrics.measures import get_cer
from pipelines.pipeline import InferencePipeline
class AverageMeter:
    """Computes and stores the average and current value.

    Attributes (reset to 0 by :meth:`reset`):
        val: the most recent value passed to :meth:`update`.
        total: running weighted sum of all values (``sum(val * n)``).
        count: running total weight (``sum(n)``).
        avg: weighted mean ``total / count``; stays 0 until some weight
            has been recorded.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero all statistics."""
        self.val = 0
        self.avg = 0
        self.total = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average.

        Args:
            val: the new value to accumulate.
            n: weight of ``val`` in the average. A weight of 0 still records
                ``val`` as the current value but leaves the average unchanged.
        """
        self.val = val
        self.total += val * n
        self.count += n
        # A zero total weight (e.g. an empty reference transcript) would
        # otherwise raise ZeroDivisionError; keep the previous average instead.
        if self.count:
            self.avg = self.total / self.count
def benchmark_inference(inference_pipeline, data_dir, landmarks_dir, lines, data_ext=".mp4", landmarks_ext=".pkl"):
    """Run ``inference_pipeline`` over a labelled list and report WER/CER.

    Each non-blank entry of ``lines`` has the form
    ``"<basename> <reference words...>"``. The basename is combined with
    ``data_dir`` + ``data_ext`` to build the input path and, when
    ``landmarks_dir`` is truthy, with ``landmarks_dir`` + ``landmarks_ext``
    to build the landmarks path (otherwise ``None`` is passed). Hypothesis,
    reference and running error rates are printed after every entry.

    Args:
        inference_pipeline: callable invoked as
            ``inference_pipeline(data_filename, landmarks_filename)``;
            presumably returns the hypothesis transcript string — confirm
            against ``InferencePipeline``.
        data_dir: directory holding the input files.
        landmarks_dir: directory holding landmark files, or a falsy value.
        lines: sequence of label lines.
        data_ext: extension appended to each basename for the input file.
        landmarks_ext: extension appended to each basename for landmarks.

    Returns:
        ``(avg_wer, avg_cer)``: WER averaged with weights equal to the
        reference word count, CER averaged with weights equal to the reference
        character count (spaces included). Both are 0 if nothing was scored.
    """
    wer, cer = AverageMeter(), AverageMeter()
    # Blank lines carry no sample; skipping them avoids an IndexError when
    # taking the basename and keeps the progress denominator honest.
    entries = [line for line in lines if line.strip()]
    for idx, line in enumerate(entries):
        basename, *words = line.split()
        groundtruth = " ".join(words)
        data_filename = os.path.join(data_dir, f"{basename}{data_ext}")
        landmarks_filename = os.path.join(landmarks_dir, f"{basename}{landmarks_ext}") if landmarks_dir else None
        output = inference_pipeline(data_filename, landmarks_filename)
        print(f"hyp: {output}\nref: {groundtruth}")
        # Error rates are undefined for an empty reference (and a zero weight
        # would divide by zero), so only entries with a transcript are scored.
        if words:
            wer.update(get_wer(output, groundtruth), len(words))
            cer.update(get_cer(output, groundtruth), len(groundtruth))
        print(f"progress: {idx+1}/{len(entries)}\tcur WER: {wer.val*100:.2f}\tcur CER: {cer.val*100:.2f}\tavg WER: {wer.avg*100:.2f}\tavg CER: {cer.avg*100:.2f}")
    return wer.avg, cer.avg
@hydra.main(version_base=None, config_path="hydra_configs", config_name="default")
def main(cfg):
    """Hydra entry point: build the inference pipeline and benchmark it.

    Reads from ``cfg``: ``gpu_idx``, ``config_filename``, ``detector``,
    ``landmarks_filename``, ``landmarks_dir``, ``data_dir``,
    ``labels_filename``, ``data_ext`` and ``landmarks_ext``.

    Raises:
        NotADirectoryError: if ``cfg.data_dir`` is not an existing directory.
        FileNotFoundError: if ``cfg.labels_filename`` is not an existing file.
    """
    # Validate paths before building the (potentially expensive) pipeline so
    # bad input fails fast. Raise explicitly: asserts vanish under `python -O`.
    if not os.path.isdir(cfg.data_dir):
        raise NotADirectoryError(f"{cfg.data_dir} is not a directory.")
    if not os.path.isfile(cfg.labels_filename):
        raise FileNotFoundError(f"{cfg.labels_filename} does not exist.")
    # Context manager guarantees the labels file handle is closed.
    with open(cfg.labels_filename, encoding="utf-8") as labels_file:
        lines = labels_file.read().splitlines()

    # Use the configured GPU only when CUDA is available and gpu_idx >= 0.
    device = torch.device(f"cuda:{cfg.gpu_idx}") if torch.cuda.is_available() and cfg.gpu_idx >= 0 else "cpu"
    # Face tracking is enabled only when no precomputed landmarks are supplied.
    inference_pipeline = InferencePipeline(config_filename=cfg.config_filename, detector=cfg.detector, face_track=not cfg.landmarks_filename and not cfg.landmarks_dir, device=device)
    benchmark_inference(inference_pipeline, cfg.data_dir, cfg.landmarks_dir, lines, cfg.data_ext, cfg.landmarks_ext)
if __name__ == "__main__":
    # Script entry point: the @hydra.main decorator on main() builds and
    # injects the `cfg` argument, so main is called without arguments here.
    main()