import torch | |
from torchmetrics.text import CharErrorRate | |
class CERSimilarity:
    """Character error rate (CER) between decoded predictions and references.

    Both ID batches are decoded to text with ``tokenizer.batch_decode``
    (special tokens skipped) and scored with a ``CharErrorRate`` metric.

    Attributes:
        tokenizer: Object providing ``batch_decode`` and ``pad_token_id``
            (presumably a Hugging Face tokenizer -- confirm with callers).
        cer: The ``CharErrorRate`` metric instance used for scoring.
        ignore_index: Sentinel ID marking positions to ignore; such
            positions are replaced by ``tokenizer.pad_token_id`` before
            decoding.
    """

    # Default sentinel; matches the default ``ignore_index`` of
    # ``torch.nn.CrossEntropyLoss``. Overridable per instance via __init__.
    ignore_index = -100

    def __init__(self, tokenizer, ignore_index=-100):
        self.tokenizer = tokenizer
        self.cer = CharErrorRate()
        self.ignore_index = ignore_index

    def _replace_ignored(self, ids):
        """Return a copy of *ids* with ``ignore_index`` entries set to the pad ID.

        The caller's object is never modified (``masked_fill`` is
        out-of-place), so label tensors stay usable after scoring.
        """
        ids = torch.as_tensor(ids)
        return ids.masked_fill(ids == self.ignore_index, self.tokenizer.pad_token_id)

    def __call__(self, input_ids, target_ids):
        """Decode both ID batches and score them with the CER metric.

        Args:
            input_ids: Predicted token IDs, one row per sample (assumed to be
                a 2-D tensor/array -- TODO confirm with callers). May contain
                ``ignore_index`` padding. Not modified.
            target_ids: Reference token IDs in the same layout; may contain
                ``ignore_index`` padding. Not modified.

        Returns:
            The value returned by ``self.cer(preds, target)`` for this batch.
        """
        # The sentinel is sanitized on both sides: real token IDs are never
        # negative, and a negative ID cannot be decoded by the tokenizer.
        input_str = self.tokenizer.batch_decode(
            self._replace_ignored(input_ids), skip_special_tokens=True
        )
        target_str = self.tokenizer.batch_decode(
            self._replace_ignored(target_ids), skip_special_tokens=True
        )
        return self.cer(input_str, target_str)