# Copyright 2021 The HuggingFace Evaluate Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Word Error Ratio (WER) metric. """ | |
import datasets | |
from jiwer import compute_measures | |
import evaluate | |
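
# NOTE: `evaluate.MetricInfo` requires `description` and `citation`, and
# `_KWARGS_DESCRIPTION` below is otherwise never wired in. The two constants
# here are a minimal reconstruction, not verbatim upstream text: the
# description paraphrases the Wikipedia article already listed under
# `reference_urls`, and the citation points to Morris et al. (2004), the
# paper conventionally cited for WER-style measures.
_CITATION = """\
@inproceedings{morris2004,
    author = {Morris, Andrew and Maier, Viktoria and Green, Phil},
    year = {2004},
    title = {From WER and RIL to MER and WIL: improved evaluation measures for connected speech recognition.}
}
"""

_DESCRIPTION = """\
Word error rate (WER) is a common metric of the performance of an automatic
speech recognition system. It works at the word level and is derived from the
Levenshtein distance:

    WER = (S + D + I) / N = (S + D + I) / (S + D + C)

where S is the number of substitutions, D the number of deletions, I the
number of insertions, C the number of correct words (hits), and N the number
of words in the reference (N = S + D + C). A lower score is better.
"""
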
_KWARGS_DESCRIPTION = """
Compute WER score of transcribed segments against references.

Args:
    references: List of references for each speech input.
    predictions: List of transcriptions to score.
    concatenate_texts (bool, default=False): Whether to concatenate all input texts or compute WER iteratively.

Returns:
    (float): the word error rate

Examples:

    >>> predictions = ["this is the prediction", "there is an other sample"]
    >>> references = ["this is the reference", "there is another one"]
    >>> wer = evaluate.load("wer")
    >>> wer_score = wer.compute(predictions=predictions, references=references)
    >>> print(wer_score)
    0.5
"""


class WER(evaluate.Metric):
    def _info(self):
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "predictions": datasets.Value("string", id="sequence"),
                    "references": datasets.Value("string", id="sequence"),
                }
            ),
            codebase_urls=["https://github.com/jitsi/jiwer/"],
            reference_urls=[
                "https://en.wikipedia.org/wiki/Word_error_rate",
            ],
        )

    def _compute(self, predictions=None, references=None, concatenate_texts=False):
        if concatenate_texts:
            # Hand the full lists to jiwer in a single call; it aligns each
            # pair and reports one corpus-level "wer" value.
            return compute_measures(references, predictions)["wer"]
        else:
            # Accumulate raw error counts pair by pair so the final ratio is
            # a corpus-level WER rather than an average of per-sentence WERs.
            incorrect = 0
            total = 0
            for prediction, reference in zip(predictions, references):
                # jiwer expects the ground truth first, then the hypothesis.
                measures = compute_measures(reference, prediction)
                incorrect += measures["substitutions"] + measures["deletions"] + measures["insertions"]
                total += measures["substitutions"] + measures["deletions"] + measures["hits"]
            return incorrect / total
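

if __name__ == "__main__":
    # Minimal usage sketch, not part of the published metric: it replays the
    # docstring example through jiwer directly, mirroring the iterative branch
    # of `_compute` above.
    predictions = ["this is the prediction", "there is an other sample"]
    references = ["this is the reference", "there is another one"]
    incorrect = 0
    total = 0
    for prediction, reference in zip(predictions, references):
        measures = compute_measures(reference, prediction)
        incorrect += measures["substitutions"] + measures["deletions"] + measures["insertions"]
        total += measures["substitutions"] + measures["deletions"] + measures["hits"]
    print(incorrect / total)  # expected: 0.5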