from dataclasses import dataclass, field

from fairseq.dataclass import FairseqDataclass
from fairseq.scoring import BaseScorer, register_scorer
from fairseq.scoring.tokenizer import EvaluationTokenizer


@dataclass
class WerScorerConfig(FairseqDataclass):
    wer_tokenizer: EvaluationTokenizer.ALL_TOKENIZER_TYPES = field(
        default="none", metadata={"help": "sacreBLEU tokenizer to use for evaluation"}
    )
    wer_remove_punct: bool = field(
        default=False, metadata={"help": "remove punctuation"}
    )
    wer_char_level: bool = field(
        default=False, metadata={"help": "evaluate at character level"}
    )
    wer_lowercase: bool = field(default=False, metadata={"help": "lowercasing"})


@register_scorer("wer", dataclass=WerScorerConfig)
class WerScorer(BaseScorer):
    def __init__(self, cfg):
        super().__init__(cfg)
        self.reset()
        try:
            import editdistance as ed
        except ImportError:
            raise ImportError("Please install editdistance to use WER scorer")
        self.ed = ed
        self.tokenizer = EvaluationTokenizer(
            tokenizer_type=self.cfg.wer_tokenizer,
            lowercase=self.cfg.wer_lowercase,
            punctuation_removal=self.cfg.wer_remove_punct,
            character_tokenization=self.cfg.wer_char_level,
        )

    def reset(self):
        self.distance = 0
        self.ref_length = 0

    def add_string(self, ref, pred):
        ref_items = self.tokenizer.tokenize(ref).split()
        pred_items = self.tokenizer.tokenize(pred).split()
        self.distance += self.ed.eval(ref_items, pred_items)
        self.ref_length += len(ref_items)

    def result_string(self):
        return f"WER: {self.score():.2f}"

    def score(self):
        return 100.0 * self.distance / self.ref_length if self.ref_length > 0 else 0
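
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the scorer implementation).
# It assumes `editdistance` and `sacrebleu` are installed, since WerScorer and
# EvaluationTokenizer above depend on them, and it assumes this module lives at
# fairseq/scoring/wer.py as in upstream fairseq; run it from a separate script
# that imports this module:
#
#     from fairseq.scoring.wer import WerScorer, WerScorerConfig
#
#     cfg = WerScorerConfig()  # defaults: "none" tokenizer, keep case and punctuation
#     scorer = WerScorer(cfg)
#     scorer.add_string("the cat sat on the mat", "the cat sat on a mat")
#     print(scorer.score())          # 100 * edit_distance / ref_length
#     print(scorer.result_string())  # "WER: 16.67" (1 substitution over 6 ref words)
# ---------------------------------------------------------------------------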