Automatic Speech Recognition
Transformers
Safetensors
Japanese
whisper
audio
hf-asr-leaderboard
Inference Endpoints
asahi417 commited on
Commit
8ad1a53
1 Parent(s): 7dd6751

Create run_short_form_eval.py

Browse files
Files changed (1) hide show
  1. run_short_form_eval.py +125 -0
run_short_form_eval.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Compute CER/WER for Japanese ASR models."""
2
+ import json
3
+ import os
4
+ import argparse
5
+ from pprint import pprint
6
+
7
+ import torch
8
+ import pandas as pd
9
+ from transformers import pipeline
10
+ from transformers.models.whisper.english_normalizer import BasicTextNormalizer
11
+ from datasets import load_dataset
12
+ from evaluate import load
13
+
14
+ parser = argparse.ArgumentParser(description='Compute CER/WER for Japanese ASR model.')
15
+ parser.add_argument('-m', '--model', default="kotoba-tech/kotoba-whisper-v1.1", type=str)
16
+ parser.add_argument('-d', '--dataset', default="japanese-asr/ja_asr.jsut_basic5000", type=str)
17
+ parser.add_argument('-a', '--attn', default="sdpa", type=str)
18
+ parser.add_argument('-b', '--batch', default=16, type=int)
19
+ parser.add_argument('-c', '--chunk-length', default=15, type=int)
20
+ parser.add_argument('-o', '--output-dir', default="eval_pipeline", type=str)
21
+ parser.add_argument('-p', '--punctuator', action="store_true")
22
+ parser.add_argument('-s', '--stable-ts', action="store_true")
23
+ parser.add_argument('--pretty-table', action="store_true")
24
+ arg = parser.parse_args()
25
+
26
+ os.makedirs(arg.output_dir, exist_ok=True)
27
+ output_metric_file = f"{arg.output_dir}/metric.jsonl"
28
+
29
+ # display mode
30
+ if arg.pretty_table:
31
+ with open(output_metric_file) as f:
32
+ metrics = [json.loads(s) for s in f.read().split("\n") if len(s) > 0]
33
+ df_metric = pd.DataFrame(metrics).round(1).sort_values(["dataset", "model"])
34
+ df_metric["cer/wer (norm)"] = [f"{c}/{w}" for c, w in zip(df_metric["cer_norm"], df_metric["wer_norm"])]
35
+ df_metric["cer/wer (raw)"] = [f"{c}/{w}" for c, w in zip(df_metric["cer_raw"], df_metric["wer_raw"])]
36
+
37
+ def pretty(m, p, s):
38
+ if p and s:
39
+ return f"{m} (punctuator + stable-ts)"
40
+ if s:
41
+ return f"{m} (stable-ts)"
42
+ if p:
43
+ return f"{m} (punctuator)"
44
+ return m
45
+
46
+ df_metric["model"] = [pretty(m, p, s) for m, p, s in zip(df_metric["model"], df_metric["punctuator"], df_metric["stable_ts"])]
47
+ df_metric = df_metric[["model", "dataset", "punctuator", "stable_ts", "cer/wer (raw)", "cer/wer (norm)"]]
48
+ print(df_metric)
49
+ df_metric = df_metric.drop_duplicates()
50
+ print("\nNORM")
51
+ print(df_metric.pivot(values="cer/wer (norm)", columns="dataset", index="model").to_markdown())
52
+ print("\nRAW")
53
+ print(df_metric.pivot(values="cer/wer (raw)", columns="dataset", index="model").to_markdown())
54
+ exit()
55
+
56
+ # model config
57
+ torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
58
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
59
+ model_kwargs = {"attn_implementation": arg.attn} if torch.cuda.is_available() and arg.attn else {}
60
+ generate_kwargs = {"language": "japanese", "task": "transcribe"}
61
+ pipeline_config = dict(
62
+ model=arg.model,
63
+ torch_dtype=torch_dtype,
64
+ device=device,
65
+ model_kwargs=model_kwargs,
66
+ chunk_length_s=arg.chunk_length,
67
+ batch_size=arg.batch
68
+ )
69
+
70
+ # instantiate pipeline
71
+ metric = {"model": arg.model, "dataset": arg.dataset, "chunk_length_s": arg.chunk_length}
72
+ if arg.model in ["kotoba-tech/kotoba-whisper-v1.1"]:
73
+ pipe = pipeline(trust_remote_code=True, punctuator=arg.punctuator, stable_ts=arg.stable_ts, **pipeline_config)
74
+ stable_ts, punctuator = arg.stable_ts, arg.punctuator
75
+ else:
76
+ pipe = pipeline("automatic-speech-recognition", **pipeline_config)
77
+ stable_ts, punctuator = None, None
78
+ metric.update({"punctuator": punctuator, "stable_ts": stable_ts})
79
+
80
+ # load the dataset and get prediction
81
+ dataset = load_dataset(arg.dataset, split="test")
82
+ output = pipe(dataset['audio'], generate_kwargs=generate_kwargs)
83
+ normalizer = BasicTextNormalizer()
84
+ prediction_norm = [normalizer(i['text']).replace(" ", "") for i in output]
85
+ references_norm = [normalizer(i).replace(" ", "") for i in dataset['transcription']]
86
+ prediction_raw = [i['text'].replace(" ", "") for i in output]
87
+ references_raw = [i.replace(" ", "") for i in dataset['transcription']]
88
+
89
+ # compute metrics
90
+ cer_metric = load("cer")
91
+ cer_norm = 100 * cer_metric.compute(predictions=prediction_norm, references=references_norm)
92
+ cer_raw = 100 * cer_metric.compute(predictions=prediction_raw, references=references_raw)
93
+ wer_metric = load("wer")
94
+ wer_norm = 100 * wer_metric.compute(predictions=prediction_norm, references=references_norm)
95
+ wer_raw = 100 * wer_metric.compute(predictions=prediction_raw, references=references_raw)
96
+ metric.update({"cer_raw": cer_raw, "wer_raw": wer_raw, "cer_norm": cer_norm, "wer_norm": wer_norm})
97
+
98
+ # save the results
99
+ metrics = []
100
+ if os.path.exists(output_metric_file):
101
+ with open(output_metric_file) as f:
102
+ metrics += [json.loads(s) for s in f.read().split("\n") if len(s) > 0]
103
+ output_prediction_file = f"{arg.output_dir}/prediction.csv"
104
+ dfs = None
105
+ if os.path.exists(output_prediction_file):
106
+ dfs = pd.read_csv(output_prediction_file, index_col=0)
107
+ metrics.append(metric)
108
+ pprint(metrics)
109
+ with open(output_metric_file, "w") as f:
110
+ f.write("\n".join([json.dumps(s) for s in metrics]))
111
+
112
+ # save prediction
113
+ audio_id = [i["path"] for i in dataset['audio']]
114
+ df = pd.DataFrame(
115
+ [audio_id, references_norm, prediction_norm, references_raw, prediction_raw],
116
+ index=["id", "reference_norm", "prediction_norm", "reference_raw", "prediction_raw"]
117
+ ).T
118
+ df["model"] = arg.model
119
+ df["dataset"] = arg.dataset
120
+ df["stable_ts"] = stable_ts
121
+ df["punctuator"] = punctuator
122
+ df["chunk_length_s"] = arg.chunk_length
123
+ dfs = df if dfs is None else pd.concat([dfs, df])
124
+ dfs.to_csv(output_prediction_file, index=False)
125
+