set 'levi' as default WER normalizer
benchmark_utils.py  +8 -50

benchmark_utils.py  CHANGED
@@ -66,8 +66,9 @@ def ASRmanifest(
     with torch.no_grad():
         with autocast():
             try:
-                result = asr_pipeline(audiofile)
+                result = asr_pipeline(audiofile )
                 asrtext = result['text']
+                asr_pipeline.call_count = 0
             except (FileNotFoundError, ValueError) as e:
                 print(f'SKIPPED: {audiofile}')
                 continue
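Note: the added asr_pipeline.call_count = 0 appears intended to reset the transformers sequential-call counter after every file, so the "using the pipelines sequentially on GPU" warning is not raised on long manifests. A minimal sketch of that pattern, with the surrounding loop assumed (ASRmanifest's real loop, timing, and output handling are not shown in this hunk):

import torch
from torch.cuda.amp import autocast

def transcribe_files(asr_pipeline, audiofiles):
    # Assumed helper, not part of benchmark_utils.py: run an ASR pipeline over
    # a list of audio paths one file at a time.
    texts = {}
    with torch.no_grad():
        with autocast():
            for audiofile in audiofiles:
                try:
                    result = asr_pipeline(audiofile)
                    texts[audiofile] = result['text']
                    # transformers increments Pipeline.call_count on every call and
                    # warns after ~10 sequential GPU calls; zeroing it silences that.
                    asr_pipeline.call_count = 0
                except (FileNotFoundError, ValueError):
                    print(f'SKIPPED: {audiofile}')
    return texts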
@@ -77,49 +78,6 @@ def ASRmanifest(
     compute_time = (et-st)
     print(f'...transcription complete in {compute_time:.1f} sec')
 
-def load_model(
-        model_path:str,
-        language='english',
-        use_int8 = False,
-        device_map='auto'):
-
-    warnings.filterwarnings("ignore")
-    transformers.utils.logging.set_verbosity_error()
-
-    try:
-        model = WhisperForConditionalGeneration.from_pretrained(
-            model_path,
-            load_in_8bit=use_int8,
-            device_map=device_map,
-            use_cache=False,
-        )
-        try:
-            processor=WhisperProcessor.from_pretrained(model_path, language=language, task="transcribe")
-        except OSError:
-            print('missing tokenizer and preprocessor config files in save dir, checking directory above...')
-            processor=WhisperProcessor.from_pretrained(os.path.join(model_path,'..'), language=language, task="transcribe")
-
-    except OSError as e:
-        print(f'{e}: possibly missing model or config file in model path. Will check for adapter...')
-        # check if PEFT
-        if os.path.isdir(os.path.join(model_path , "adapter_model")):
-            print('found adapter...loading PEFT model')
-            # checkpoint dir needs adapter model subdir with adapter_model.bin and adapter_confg.json
-            peft_config = PeftConfig.from_pretrained(os.path.join(model_path , "adapter_model"))
-            print(f'...loading and merging LORA weights to base model {peft_config.base_model_name_or_path}')
-            model = WhisperForConditionalGeneration.from_pretrained(peft_config.base_model_name_or_path,
-                load_in_8bit=use_int8,
-                device_map=device_map,
-                use_cache=False,
-            )
-            model = PeftModel.from_pretrained(model, os.path.join(model_path,"adapter_model"))
-            model = model.merge_and_unload()
-            processor = WhisperProcessor.from_pretrained(peft_config.base_model_name_or_path, language=language, task="transcribe")
-        else:
-            raise e
-    model.eval()
-    return(model, processor)
-
 def prepare_pipeline(model_path, generate_opts):
     """Prepare a pipeline for ASR inference
     Args:
@@ -128,16 +86,16 @@ def prepare_pipeline(model_path, generate_opts):
     Returns:
         pipeline: ASR pipeline
     """
-    model, processor = load_model(
-        model_path=model_path)
+    processor = WhisperProcessor.from_pretrained(model_path)
 
     asr_pipeline = pipeline(
         "automatic-speech-recognition",
-        model=model,
+        model=model_path,
         tokenizer=processor.tokenizer,
         feature_extractor=processor.feature_extractor,
         generate_kwargs=generate_opts,
-        )
+        model_kwargs={"load_in_8bit": False},
+        device_map='auto')
     return asr_pipeline
 
 #%% WER evaluation functions
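For reference, a hypothetical call to the updated prepare_pipeline; the checkpoint id, audio path, and generation options below are illustrative placeholders, not values from this repo:

from benchmark_utils import prepare_pipeline

# Assumed generation options: anything WhisperForConditionalGeneration.generate()
# accepts can be passed through generate_kwargs.
generate_opts = {"max_new_tokens": 225, "language": "english", "task": "transcribe"}

asr_pipeline = prepare_pipeline(
    model_path="openai/whisper-small",   # placeholder: local checkpoint dir or hub id
    generate_opts=generate_opts,
)

result = asr_pipeline("example.wav")     # placeholder audio file
print(result["text"])

With model= given as a path, the pipeline loads the checkpoint itself, so no separate model-loading helper is needed.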
@@ -285,7 +243,7 @@ def wer_from_df(
     hypcol='hyp',
     return_alignments=False,
     normalise = True,
-    text_norm_method='
+    text_norm_method='levi',
     printout=True):
     """Compute WER from a dataframe containing a ref col and a hyp col
     WER is computed on the edit operation counts over the whole df,
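The docstring's point that WER is computed from edit-operation counts pooled over the whole dataframe (rather than averaging per-row WERs) matters when utterance lengths vary. A hypothetical call on a small dataframe; the refcol keyword is assumed by symmetry with hypcol, and the structure of the return value is not shown in this diff:

import pandas as pd
from benchmark_utils import wer_from_df

df = pd.DataFrame({
    'ref': ['the cat sat on the mat', 'hello world'],
    'hyp': ['the cat sat on a mat', 'hello word'],
})

# Edit operations are pooled across all rows before dividing by the total number
# of reference words, so longer utterances carry more weight than short ones.
result = wer_from_df(df, refcol='ref', hypcol='hyp', printout=True)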
@@ -338,7 +296,7 @@ def wer_from_csv(
     hypcol='hyp',
     return_alignments=False,
     normalise = True,
-    text_norm_method='
+    text_norm_method='levi' ,
     printout=True):
 
     res = pd.read_csv(csv_path).astype(str)
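With 'levi' as the default text_norm_method, a caller only needs to pass the argument to pick a different normaliser. A hypothetical call; the csv path is a placeholder, csv_path is assumed to be the first parameter, and only hypcol, normalise, text_norm_method and printout appear in this hunk:

from benchmark_utils import wer_from_csv

# 'results.csv' is a placeholder file holding reference and hypothesis text columns.
# With normalise=True (the default), both columns are normalised with the 'levi'
# method before scoring unless another text_norm_method is passed explicitly.
wer_result = wer_from_csv(
    'results.csv',
    hypcol='hyp',
    printout=True,
)   # return value assumed; its exact structure is not shown in this diff

# The same call with the normaliser spelled out explicitly:
wer_result = wer_from_csv('results.csv', text_norm_method='levi')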