#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 Imperial College London (Pingchuan Ma)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import os | |
import torch | |
import hydra | |
from pipelines.metrics.measures import get_wer | |
from pipelines.metrics.measures import get_cer | |
from pipelines.pipeline import InferencePipeline | |
class AverageMeter:
    """Track the most recent value and a running weighted average.

    Public attributes:
        val:   the value passed to the latest ``update`` call.
        sum:   the weighted total, i.e. the sum of ``val * n`` over all updates.
        count: the total weight, i.e. the sum of ``n`` over all updates.
        avg:   ``sum / count``, or 0 while ``count`` is still 0.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Reset every statistic to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average.

        Args:
            val: the new measurement.
            n: the weight of the measurement (defaults to 1).
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero total weight (e.g. the first update arrives with n=0) would
        # otherwise raise ZeroDivisionError; report an average of 0 instead.
        self.avg = self.sum / self.count if self.count else 0
def benchmark_inference(inference_pipeline, data_dir, landmarks_dir, lines, data_ext=".mp4", landmarks_ext=".pkl"):
    """Run the pipeline over every labelled sample and print running WER/CER.

    Each entry of ``lines`` is split on whitespace: the first token is the
    sample basename, the remaining tokens (re-joined with single spaces) are
    the reference transcript.

    Args:
        inference_pipeline: callable invoked as
            ``inference_pipeline(data_filename, landmarks_filename)``; its
            return value is used as the hypothesis.
        data_dir: directory containing ``<basename><data_ext>`` files.
        landmarks_dir: directory containing ``<basename><landmarks_ext>``
            files; when falsy, ``None`` is passed as the landmarks filename.
        lines: sequence of label lines (must support ``len``).
        data_ext: extension appended to the basename for the data file.
        landmarks_ext: extension appended to the basename for the landmarks
            file.
    """
    wer, cer = AverageMeter(), AverageMeter()
    total = len(lines)
    for idx, line in enumerate(lines):
        tokens = line.split()
        if not tokens:
            # Blank line (e.g. a trailing newline in the labels file): there
            # is no basename to process.
            continue
        basename, ref_words = tokens[0], tokens[1:]
        groundtruth = " ".join(ref_words)

        data_filename = os.path.join(data_dir, f"{basename}{data_ext}")
        landmarks_filename = (
            os.path.join(landmarks_dir, f"{basename}{landmarks_ext}")
            if landmarks_dir
            else None
        )
        output = inference_pipeline(data_filename, landmarks_filename)

        if not groundtruth:
            # No reference transcript on this line: show the hypothesis but
            # skip scoring, since a zero-length reference carries no weight.
            print(f"hyp: {output}")
            continue

        print(f"hyp: {output}\nref: {groundtruth}")
        # Weight each sample by its reference length (words for WER,
        # characters for CER) so the running average is length-weighted.
        # NOTE(review): presumably get_wer/get_cer return per-sample rates
        # as fractions -- confirm in pipelines.metrics.measures.
        wer.update(get_wer(output, groundtruth), len(ref_words))
        cer.update(get_cer(output, groundtruth), len(groundtruth))
        print(f"progress: {idx+1}/{total}\tcur WER: {wer.val*100:.2f}\tcur CER: {cer.val*100:.2f}\tavg WER: {wer.avg*100:.2f}\tavg CER: {cer.avg*100:.2f}")
# NOTE(review): the script entry point calls main() with no arguments, so cfg
# has to be injected by Hydra. The decorator arguments below are a
# reconstruction -- confirm config_path/config_name against the repository's
# config directory (version_base requires Hydra >= 1.2).
@hydra.main(version_base=None, config_path="hydra_configs", config_name="default")
def main(cfg):
    """Build the inference pipeline from ``cfg`` and benchmark it.

    Reads ``gpu_idx``, ``config_filename``, ``detector``,
    ``landmarks_filename``, ``landmarks_dir``, ``data_dir``,
    ``labels_filename``, ``data_ext`` and ``landmarks_ext`` from ``cfg``.

    Raises:
        NotADirectoryError: if ``cfg.data_dir`` is not a directory.
        FileNotFoundError: if ``cfg.labels_filename`` is not a file.
    """
    # Validate inputs before the pipeline is constructed so a bad path fails
    # immediately; explicit raises also survive ``python -O``, unlike assert.
    if not os.path.isdir(cfg.data_dir):
        raise NotADirectoryError(f"{cfg.data_dir} is not a directory.")
    if not os.path.isfile(cfg.labels_filename):
        raise FileNotFoundError(f"{cfg.labels_filename} does not exist.")

    use_cuda = torch.cuda.is_available() and cfg.gpu_idx >= 0
    device = torch.device(f"cuda:{cfg.gpu_idx}") if use_cuda else "cpu"

    # Face tracking is requested only when no landmarks are supplied.
    needs_face_track = not cfg.landmarks_filename and not cfg.landmarks_dir
    inference_pipeline = InferencePipeline(
        config_filename=cfg.config_filename,
        detector=cfg.detector,
        face_track=needs_face_track,
        device=device,
    )

    # Context manager so the labels file handle is closed deterministically.
    with open(cfg.labels_filename) as labels_file:
        lines = labels_file.read().splitlines()

    benchmark_inference(
        inference_pipeline,
        cfg.data_dir,
        cfg.landmarks_dir,
        lines,
        cfg.data_ext,
        cfg.landmarks_ext,
    )
# Run the benchmark only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()