File size: 478 Bytes
e06b649 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 |
import numpy as np
from datasets import load_metric
# CER (character error rate) metric object, loaded once at module import and
# reused by compute_metrics on every evaluation call.
# NOTE(review): `datasets.load_metric` is deprecated in recent `datasets`
# releases (metrics moved to the separate `evaluate` library) — confirm this
# still works with the pinned `datasets` version.
cer_metric = load_metric("cer")
def compute_metrics(pred):
    """Compute the character error rate (CER) between decoded predictions and labels.

    Presumably used as the ``compute_metrics`` callback of a Hugging Face
    ``Trainer`` (``pred`` looks like an ``EvalPrediction``) — confirm at caller.

    Args:
        pred: Object exposing ``predictions`` (predicted token ids, or a tuple
            whose first element is the token ids) and ``label_ids`` (reference
            token ids). The value ``-100`` marks ignored positions; it is not a
            real token id and cannot be decoded. ``pred`` is not modified.

    Returns:
        dict: ``{"cer": cer}``, where ``cer`` is the value returned by
        ``cer_metric.compute``.
    """
    pred_ids = pred.predictions
    # Some models return a tuple (token ids plus extra outputs); the ids are
    # the first element.
    if isinstance(pred_ids, tuple):
        pred_ids = pred_ids[0]

    pad_token_id = processor.tokenizer.pad_token_id
    # -100 is the ignore index. It appears in labels and may also appear in
    # predictions padded to a common length across batches. Swap it for the
    # pad token so batch_decode can handle it. np.where builds new arrays, so
    # the caller's label_ids / predictions are left untouched (the previous
    # in-place assignment mutated pred.label_ids).
    pred_ids = np.where(pred_ids != -100, pred_ids, pad_token_id)
    label_ids = np.where(pred.label_ids != -100, pred.label_ids, pad_token_id)

    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.batch_decode(label_ids, skip_special_tokens=True)

    cer = cer_metric.compute(predictions=pred_str, references=label_str)
    return {"cer": cer}