File size: 556 Bytes
e06b649 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 |
import torch
from torchmetrics.text import CharErrorRate
class CERSimilarity:
    """Character error rate (CER) between decoded predictions and decoded labels.

    Wraps a tokenizer (anything exposing ``batch_decode`` and ``pad_token_id``)
    and a torchmetrics ``CharErrorRate`` instance.
    """

    # Label value marking positions that cannot be decoded as-is; they are
    # replaced with the tokenizer's pad id (stripped by skip_special_tokens).
    IGNORE_INDEX = -100

    def __init__(self, tokenizer):
        """Store *tokenizer* and create the ``CharErrorRate`` metric."""
        self.tokenizer = tokenizer
        self.cer = CharErrorRate()

    def __call__(self, input_ids, target_ids):
        """Return the CER of decoded ``input_ids`` against decoded ``target_ids``.

        Args:
            input_ids: Predicted token ids, one row per sequence.
            target_ids: Reference token ids (tensor or array-like); entries
                equal to ``IGNORE_INDEX`` are decoded as padding. The caller's
                object is NOT modified.

        Returns:
            The value returned by ``self.cer(predictions, references)``.
            NOTE(review): with torchmetrics, calling a metric also accumulates
            its internal state across calls — confirm this is intended.
        """
        predictions = self.tokenizer.batch_decode(input_ids, skip_special_tokens=True)

        # Build a fresh tensor instead of assigning in place: callers may reuse
        # the same labels elsewhere (e.g. for a loss) and rely on -100 surviving.
        labels = torch.as_tensor(target_ids)
        decodable = labels.masked_fill(
            labels == self.IGNORE_INDEX, self.tokenizer.pad_token_id
        )
        references = self.tokenizer.batch_decode(decodable, skip_special_tokens=True)

        return self.cer(predictions, references)
|