hadrakey's picture
Training in progress, step 1000
e06b649 verified
raw
history blame
556 Bytes
import torch
from torchmetrics.text import CharErrorRate
class CERSimilarity:
    """Character error rate (CER) between decoded predictions and targets.

    Wraps ``torchmetrics.text.CharErrorRate`` so it can be called directly on
    batches of token ids: both batches are decoded with ``tokenizer``
    (special tokens skipped) and the resulting strings are passed to the
    metric.
    """

    # Positions holding this value are ignored, not real tokens (the
    # PyTorch / Hugging Face label-padding convention). It is not a valid
    # vocabulary id, so it must be removed before decoding.
    IGNORE_INDEX = -100

    def __init__(self, tokenizer):
        """Store the tokenizer and create the CER metric.

        Args:
            tokenizer: object exposing ``batch_decode(ids, skip_special_tokens=...)``;
                presumably a Hugging Face tokenizer — TODO confirm with callers.
        """
        self.tokenizer = tokenizer
        self.cer = CharErrorRate()

    def _decode(self, ids):
        """Decode a batch of token ids to strings, dropping ignore-index positions.

        Works on a copy, so the caller's ``ids`` are never modified. Dropping
        the positions (rather than substituting the pad id) also means the
        tokenizer does not need a ``pad_token_id``.

        Args:
            ids: 2-D batch of token ids (tensor, ndarray, or nested lists).

        Returns:
            List of decoded strings, one per row.
        """
        # ``tolist`` gives plain Python ints for both torch tensors and numpy
        # arrays; nested lists are used as-is.
        rows = ids.tolist() if hasattr(ids, "tolist") else ids
        cleaned = [[tok for tok in row if tok != self.IGNORE_INDEX] for row in rows]
        return self.tokenizer.batch_decode(cleaned, skip_special_tokens=True)

    def __call__(self, input_ids, target_ids):
        """Return the CER of decoded ``input_ids`` against decoded ``target_ids``.

        Args:
            input_ids: batch of predicted token ids; ``-100`` entries are ignored.
            target_ids: batch of reference token ids; ``-100`` entries are
                ignored. Not modified.

        Returns:
            Whatever ``self.cer`` returns for the decoded batch.
        """
        input_str = self._decode(input_ids)
        target_str = self._decode(target_ids)
        return self.cer(input_str, target_str)