Spaces:
Running
Running
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import julius | |
import pesq | |
import torch | |
import torchmetrics | |
class PesqMetric(torchmetrics.Metric):
    """Metric for Perceptual Evaluation of Speech Quality (PESQ).

    (https://doi.org/10.5281/zenodo.6549559)

    Signals are resampled to 16 kHz (the rate handed to ``pesq.pesq``) when
    ``sample_rate`` differs from it, and only channel 0 of each batch item is
    scored. The metric value is the mean PESQ over all scored items; items for
    which ``pesq`` raises ``NoUtterancesError`` are skipped and excluded from
    the mean.

    Args:
        sample_rate: Sample rate (Hz) of the ``preds`` and ``targets`` tensors
            passed to :meth:`update`.
    """

    sum_pesq: torch.Tensor  # running sum of per-item PESQ scores
    total: torch.Tensor  # number of items that contributed to ``sum_pesq``

    def __init__(self, sample_rate: int):
        super().__init__()
        self.sr = sample_rate
        self.add_state("sum_pesq", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, targets: torch.Tensor) -> None:
        """Accumulate PESQ scores for one batch.

        Args:
            preds: Degraded/estimated signals, indexed as ``[batch, channel, time]``;
                only channel 0 is scored.
            targets: Reference signals with the same layout and batch size as
                ``preds``.

        Raises:
            ValueError: If ``preds`` and ``targets`` have different batch sizes.
        """
        if preds.size(0) != targets.size(0):
            raise ValueError(
                "preds and targets must have the same batch size, "
                f"got {preds.size(0)} and {targets.size(0)}"
            )
        # Keep only channel 0 before resampling: it is the only channel scored,
        # and resampling acts independently along the last (time) axis.
        preds = preds[:, 0]
        targets = targets[:, 0]
        if self.sr != 16000:
            preds = julius.resample_frac(preds, self.sr, 16000)
            targets = julius.resample_frac(targets, self.sr, 16000)
        # A single device->host copy per tensor per batch, instead of two per item.
        preds_np = preds.detach().cpu().numpy()
        targets_np = targets.detach().cpu().numpy()
        for target, pred in zip(targets_np, preds_np):
            try:
                score = pesq.pesq(16000, target, pred)
            except pesq.NoUtterancesError:
                # This error can happen when a sample doesn't contain speech;
                # skip it so it does not count towards the average.
                continue
            self.sum_pesq += score
            self.total += 1

    def compute(self) -> torch.Tensor:
        """Return the mean PESQ over all scored items, or 0 if none were scored."""
        if self.total.item() == 0:
            # Same device/dtype as the accumulated state, so both branches agree.
            return torch.zeros_like(self.sum_pesq)
        return self.sum_pesq / self.total