Create run_short_form_eval.py
run_short_form_eval.py
ADDED
@@ -0,0 +1,125 @@
"""Compute CER/WER for Japanese ASR models."""
import json
import os
import argparse
from pprint import pprint

import torch
import pandas as pd
from transformers import pipeline
from transformers.models.whisper.english_normalizer import BasicTextNormalizer
from datasets import load_dataset
from evaluate import load

parser = argparse.ArgumentParser(description='Compute CER/WER for Japanese ASR model.')
parser.add_argument('-m', '--model', default="kotoba-tech/kotoba-whisper-v1.1", type=str)
parser.add_argument('-d', '--dataset', default="japanese-asr/ja_asr.jsut_basic5000", type=str)
parser.add_argument('-a', '--attn', default="sdpa", type=str)
parser.add_argument('-b', '--batch', default=16, type=int)
parser.add_argument('-c', '--chunk-length', default=15, type=int)
parser.add_argument('-o', '--output-dir', default="eval_pipeline", type=str)
parser.add_argument('-p', '--punctuator', action="store_true")
parser.add_argument('-s', '--stable-ts', action="store_true")
parser.add_argument('--pretty-table', action="store_true")
arg = parser.parse_args()
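# Example invocations (defaults shown above; --pretty-table only renders metrics
# saved by earlier runs and then exits):
#   python run_short_form_eval.py -m kotoba-tech/kotoba-whisper-v1.1 -d japanese-asr/ja_asr.jsut_basic5000
#   python run_short_form_eval.py --punctuator --stable-ts
#   python run_short_form_eval.py --pretty-table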

os.makedirs(arg.output_dir, exist_ok=True)
output_metric_file = f"{arg.output_dir}/metric.jsonl"

# display mode
if arg.pretty_table:
    with open(output_metric_file) as f:
        metrics = [json.loads(s) for s in f.read().split("\n") if len(s) > 0]
    df_metric = pd.DataFrame(metrics).round(1).sort_values(["dataset", "model"])
    df_metric["cer/wer (norm)"] = [f"{c}/{w}" for c, w in zip(df_metric["cer_norm"], df_metric["wer_norm"])]
    df_metric["cer/wer (raw)"] = [f"{c}/{w}" for c, w in zip(df_metric["cer_raw"], df_metric["wer_raw"])]

    def pretty(m, p, s):
        if p and s:
            return f"{m} (punctuator + stable-ts)"
        if s:
            return f"{m} (stable-ts)"
        if p:
            return f"{m} (punctuator)"
        return m

    df_metric["model"] = [pretty(m, p, s) for m, p, s in zip(df_metric["model"], df_metric["punctuator"], df_metric["stable_ts"])]
    df_metric = df_metric[["model", "dataset", "punctuator", "stable_ts", "cer/wer (raw)", "cer/wer (norm)"]]
    print(df_metric)
    df_metric = df_metric.drop_duplicates()
    print("\nNORM")
    print(df_metric.pivot(values="cer/wer (norm)", columns="dataset", index="model").to_markdown())
    print("\nRAW")
    print(df_metric.pivot(values="cer/wer (raw)", columns="dataset", index="model").to_markdown())
    exit()
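
# Each line of metric.jsonl is one JSON object written by the evaluation path below;
# illustrative shape (values here are placeholders, not real results):
#   {"model": "kotoba-tech/kotoba-whisper-v1.1", "dataset": "japanese-asr/ja_asr.jsut_basic5000",
#    "chunk_length_s": 15, "punctuator": false, "stable_ts": false,
#    "cer_raw": 0.0, "wer_raw": 0.0, "cer_norm": 0.0, "wer_norm": 0.0}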

# model config
torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
device = "cuda:0" if torch.cuda.is_available() else "cpu"
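# note: bfloat16 and the chosen attention implementation are applied only when CUDA
# is available; on CPU the pipeline falls back to float32 with default attention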
model_kwargs = {"attn_implementation": arg.attn} if torch.cuda.is_available() and arg.attn else {}
generate_kwargs = {"language": "japanese", "task": "transcribe"}
pipeline_config = dict(
    model=arg.model,
    torch_dtype=torch_dtype,
    device=device,
    model_kwargs=model_kwargs,
    chunk_length_s=arg.chunk_length,
    batch_size=arg.batch
)

# instantiate pipeline
metric = {"model": arg.model, "dataset": arg.dataset, "chunk_length_s": arg.chunk_length}
if arg.model in ["kotoba-tech/kotoba-whisper-v1.1"]:
    pipe = pipeline(trust_remote_code=True, punctuator=arg.punctuator, stable_ts=arg.stable_ts, **pipeline_config)
    stable_ts, punctuator = arg.stable_ts, arg.punctuator
else:
    pipe = pipeline("automatic-speech-recognition", **pipeline_config)
    stable_ts, punctuator = None, None
metric.update({"punctuator": punctuator, "stable_ts": stable_ts})
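# kotoba-whisper-v1.1 ships a custom pipeline (hence trust_remote_code) that can apply
# punctuation restoration and stable-ts post-processing; other models go through the
# stock automatic-speech-recognition pipeline, where those two flags do not apply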

# load the dataset and get prediction
dataset = load_dataset(arg.dataset, split="test")
output = pipe(dataset['audio'], generate_kwargs=generate_kwargs)
normalizer = BasicTextNormalizer()
prediction_norm = [normalizer(i['text']).replace(" ", "") for i in output]
references_norm = [normalizer(i).replace(" ", "") for i in dataset['transcription']]
prediction_raw = [i['text'].replace(" ", "") for i in output]
references_raw = [i.replace(" ", "") for i in dataset['transcription']]
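# spaces are stripped from both predictions and references because Japanese text is
# not whitespace-delimited, so incidental spacing (e.g. introduced by the normalizer
# or the decoder) should not be counted as errors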

# compute metrics
cer_metric = load("cer")
cer_norm = 100 * cer_metric.compute(predictions=prediction_norm, references=references_norm)
cer_raw = 100 * cer_metric.compute(predictions=prediction_raw, references=references_raw)
wer_metric = load("wer")
wer_norm = 100 * wer_metric.compute(predictions=prediction_norm, references=references_norm)
wer_raw = 100 * wer_metric.compute(predictions=prediction_raw, references=references_raw)
metric.update({"cer_raw": cer_raw, "wer_raw": wer_raw, "cer_norm": cer_norm, "wer_norm": wer_norm})
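# for intuition: CER is character-level edit distance divided by reference length, e.g.
#   load("cer").compute(predictions=["hallo"], references=["hello"])  # 0.2 (1 sub / 5 chars)
# the 100x factors above convert both metrics to percentages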

# save the results
metrics = []
if os.path.exists(output_metric_file):
    with open(output_metric_file) as f:
        metrics += [json.loads(s) for s in f.read().split("\n") if len(s) > 0]
output_prediction_file = f"{arg.output_dir}/prediction.csv"
dfs = None
if os.path.exists(output_prediction_file):
    dfs = pd.read_csv(output_prediction_file)  # saved with index=False below, so there is no index column to skip
metrics.append(metric)
pprint(metrics)
with open(output_metric_file, "w") as f:
    f.write("\n".join([json.dumps(s) for s in metrics]))

# save prediction
audio_id = [i["path"] for i in dataset['audio']]
df = pd.DataFrame(
    [audio_id, references_norm, prediction_norm, references_raw, prediction_raw],
    index=["id", "reference_norm", "prediction_norm", "reference_raw", "prediction_raw"]
).T
df["model"] = arg.model
df["dataset"] = arg.dataset
df["stable_ts"] = stable_ts
df["punctuator"] = punctuator
df["chunk_length_s"] = arg.chunk_length
dfs = df if dfs is None else pd.concat([dfs, df])
dfs.to_csv(output_prediction_file, index=False)
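# prediction.csv accumulates one row per utterance per run; a quick sanity check
# afterwards (hypothetical follow-up, not part of this script; assumes the default --output-dir):
#   pd.read_csv("eval_pipeline/prediction.csv").groupby(["model", "dataset"]).size()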