import torch
from torchmetrics.text import CharErrorRate


class CERSimilarity:
    """Character error rate between decoded predictions and decoded targets.

    Combines a tokenizer with a ``CharErrorRate`` metric so the metric can be
    called directly on batches of token ids.
    """

    def __init__(self, tokenizer):
        """Store the tokenizer and create the underlying metric.

        Args:
            tokenizer: Object exposing ``batch_decode(ids,
                skip_special_tokens=...)`` and a ``pad_token_id`` attribute.
        """
        self.tokenizer = tokenizer
        self.cer = CharErrorRate()

    def __call__(self, input_ids, target_ids):
        """Decode both id batches and return their character error rate.

        Args:
            input_ids: Batch of predicted token ids, passed to
                ``tokenizer.batch_decode``.
            target_ids: Batch of reference token ids. Entries equal to -100
                are replaced with ``tokenizer.pad_token_id`` before decoding
                (presumably -100 is the loss ignore index; confirm with the
                caller). The caller's object is left unmodified.

        Returns:
            The value returned by ``self.cer`` for the decoded prediction and
            target strings.
        """
        input_str = self.tokenizer.batch_decode(
            input_ids, skip_special_tokens=True
        )

        # Substitute on a copy: writing into ``target_ids`` directly would
        # erase the -100 markers from the caller's labels as a side effect.
        if isinstance(target_ids, torch.Tensor):
            labels = target_ids.clone()
        else:
            labels = target_ids.copy()
        labels[labels == -100] = self.tokenizer.pad_token_id

        target_str = self.tokenizer.batch_decode(
            labels, skip_special_tokens=True
        )
        return self.cer(input_str, target_str)