import numpy as np
from datasets import load_metric
# Metric object loaded once at import time and reused by every evaluation call.
# "cer" is presumably the character error rate metric — confirm against the
# metric card.
# NOTE(review): `datasets.load_metric` is deprecated upstream in favour of the
# separate `evaluate` package — confirm the installed `datasets` version still
# provides it.
cer_metric = load_metric("cer")
def compute_metrics(pred):
    """Decode predicted and reference token ids and score them with ``cer_metric``.

    Args:
        pred: Object exposing ``predictions`` and ``label_ids`` attributes that
            hold token ids (presumably a ``transformers.EvalPrediction`` with
            NumPy arrays — confirm against the caller).

    Returns:
        dict: ``{"cer": value}`` where ``value`` is whatever
        ``cer_metric.compute`` returns for the decoded strings.
    """
    pad_token_id = processor.tokenizer.pad_token_id

    pred_str = processor.batch_decode(pred.predictions, skip_special_tokens=True)

    # Positions equal to -100 (presumably the loss ignore index — confirm) are
    # not valid token ids, so they are swapped for the pad token id before
    # decoding. np.where builds a new array, so the caller's pred.label_ids is
    # left untouched instead of being overwritten in place.
    labels_ids = np.where(pred.label_ids == -100, pad_token_id, pred.label_ids)
    label_str = processor.batch_decode(labels_ids, skip_special_tokens=True)

    cer = cer_metric.compute(predictions=pred_str, references=label_str)
    return {"cer": cer}