File size: 556 Bytes
e06b649
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import torch
from torchmetrics.text import CharErrorRate

class CERSimilarity:
    """Callable that scores decoded predictions against decoded targets.

    Both id batches are turned into strings with the tokenizer's
    ``batch_decode`` and handed to a ``CharErrorRate`` metric instance.
    NOTE(review): despite the class name, the returned value is whatever
    ``CharErrorRate`` yields (an error rate, presumably lower-is-better) --
    confirm callers do not treat it as a similarity.
    """

    def __init__(self, tokenizer):
        """Store the tokenizer and create the character-error-rate metric.

        Args:
            tokenizer: Object providing ``batch_decode(ids,
                skip_special_tokens=...)`` and a ``pad_token_id`` attribute.
        """
        self.tokenizer = tokenizer

        self.cer = CharErrorRate()

    def __call__(self, input_ids, target_ids):
        """Decode both id batches and return the metric computed on them.

        Args:
            input_ids: Batch of predicted token ids accepted by
                ``tokenizer.batch_decode``.
            target_ids: Batch of reference token ids; entries equal to
                ``-100`` (presumably ignore-index label padding -- confirm
                against the caller) are replaced by the pad token id before
                decoding. The caller's object is left unmodified.

        Returns:
            The result of ``self.cer(decoded_inputs, decoded_targets)``.
        """
        pad_id = self.tokenizer.pad_token_id

        # Substitute on a copy: writing into ``target_ids`` directly would
        # destroy the -100 sentinels in the caller's labels.
        if isinstance(target_ids, torch.Tensor):
            cleaned_targets = target_ids.masked_fill(target_ids == -100, pad_id)
        else:
            # Non-tensor array-likes (e.g. NumPy arrays) expose ``.copy()``.
            cleaned_targets = target_ids.copy()
            cleaned_targets[cleaned_targets == -100] = pad_id

        input_str = self.tokenizer.batch_decode(
            input_ids, skip_special_tokens=True
        )
        target_str = self.tokenizer.batch_decode(
            cleaned_targets, skip_special_tokens=True
        )

        return self.cer(input_str, target_str)