NiranjanC committed on
Commit dc03094
1 Parent(s): 286bce6

Create benchmark_utils.py

Files changed (1):
  1. benchmark_utils.py +353 -0
benchmark_utils.py ADDED
#%% imports
import pandas as pd
import time
from tqdm import tqdm
import torch
from torch.cuda.amp import autocast
import transformers
from transformers import WhisperFeatureExtractor, WhisperTokenizer, WhisperProcessor, WhisperForConditionalGeneration, GenerationConfig
from transformers import pipeline, AutomaticSpeechRecognitionPipeline
from peft import PeftModel, PeftConfig
import warnings
import jiwer
from jiwer.process import WordOutput
import numpy as np
from pathlib import Path
import os
import math
from decimal import InvalidOperation
import contractions
from whisper.normalizers.english import EnglishTextNormalizer
from num2words import num2words
import csv
import re
import string

#%% define functions
def ASRmanifest(
    manifest_csv: str,
    out_csv: str,
    corpora_root: str,
    model_path: str,
):
    """Run Whisper ASR on a dataset specified in a manifest

    Args:
        manifest_csv (str): path to manifest csv listing files to transcribe
        out_csv (str): path to write output csv
        corpora_root (str): root path where audio files are, inserted in place of $DATAROOT in manifest
        model_path (str): path to model directory / huggingface model name
    """
    df = pd.read_csv(manifest_csv, keep_default_na=False)
    fieldnames = list(df.columns) + ['asr']

    asr_pipeline = prepare_pipeline(
        model_path=model_path,
        generate_opts={'max_new_tokens': 448,
                       'num_beams': 1,  # greedy
                       'repetition_penalty': 1,
                       'do_sample': False
                       }
    )

    message = "This may take a while on CPU." if asr_pipeline.device.type == "cpu" else "Using GPU"
    print(f'Running ASR for {len(df)} files. {message} ...')
    st = time.time()  # start time, for reporting total transcription time

    with open(out_csv, 'w', newline='') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames, delimiter=',')
        writer.writeheader()
        for _, row in tqdm(df.iterrows(), total=df.shape[0]):
            audiofile = row['wav'].replace('$DATAROOT', corpora_root)
            with torch.no_grad():
                with autocast():
                    try:
                        result = asr_pipeline(audiofile)
                        asrtext = result['text']
                    except (FileNotFoundError, ValueError) as e:
                        print(f'SKIPPED ({e}): {audiofile}')
                        continue
            row['asr'] = asrtext
            writer.writerow(row.to_dict())
    et = time.time()
    compute_time = (et - st)
    print(f'...transcription complete in {compute_time:.1f} sec')

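# Illustrative usage sketch (editor's addition, not part of the original file;
# all paths below are hypothetical). The manifest is assumed to contain a 'wav'
# column whose paths use the $DATAROOT placeholder:
#
#   ASRmanifest(
#       manifest_csv='manifests/test_set.csv',
#       out_csv='results/test_set_asr.csv',
#       corpora_root='/data/corpora',
#       model_path='openai/whisper-small',
#   )
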
def load_model(
        model_path: str,
        language='english',
        use_int8=False,
        device_map='auto'):

    warnings.filterwarnings("ignore")
    transformers.utils.logging.set_verbosity_error()

    try:
        model = WhisperForConditionalGeneration.from_pretrained(
            model_path,
            load_in_8bit=use_int8,
            device_map=device_map,
            use_cache=False,
        )
        try:
            processor = WhisperProcessor.from_pretrained(model_path, language=language, task="transcribe")
        except OSError:
            print('missing tokenizer and preprocessor config files in save dir, checking directory above...')
            processor = WhisperProcessor.from_pretrained(os.path.join(model_path, '..'), language=language, task="transcribe")

    except OSError as e:
        print(f'{e}: possibly missing model or config file in model path. Will check for adapter...')
        # check if PEFT
        if os.path.isdir(os.path.join(model_path, "adapter_model")):
            print('found adapter...loading PEFT model')
            # checkpoint dir needs an adapter_model subdir with adapter_model.bin and adapter_config.json
            peft_config = PeftConfig.from_pretrained(os.path.join(model_path, "adapter_model"))
            print(f'...loading and merging LoRA weights to base model {peft_config.base_model_name_or_path}')
            model = WhisperForConditionalGeneration.from_pretrained(
                peft_config.base_model_name_or_path,
                load_in_8bit=use_int8,
                device_map=device_map,
                use_cache=False,
            )
            model = PeftModel.from_pretrained(model, os.path.join(model_path, "adapter_model"))
            model = model.merge_and_unload()
            processor = WhisperProcessor.from_pretrained(peft_config.base_model_name_or_path, language=language, task="transcribe")
        else:
            raise e
    model.eval()
    return model, processor

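# Note on the expected checkpoint layout (an assumption inferred from the
# loading logic above, not documented in the original file): a LoRA/PEFT
# checkpoint directory should look like
#
#   <model_path>/adapter_model/adapter_model.bin
#   <model_path>/adapter_model/adapter_config.json
#
# while a plain model directory needs the usual Hugging Face files
# (config.json, model weights) plus tokenizer/preprocessor configs either
# alongside the weights or one directory up.
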
def prepare_pipeline(model_path, generate_opts):
    """Prepare a pipeline for ASR inference

    Args:
        model_path (str): path to model directory / huggingface model name
        generate_opts (dict): options to pass to pipeline

    Returns:
        pipeline: ASR pipeline
    """
    model, processor = load_model(model_path=model_path)

    asr_pipeline = pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        generate_kwargs=generate_opts,
    )
    return asr_pipeline

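# Illustrative usage sketch (editor's addition; model name and audio path are
# hypothetical): transcribe a single file with greedy decoding.
#
#   asr = prepare_pipeline(
#       model_path='openai/whisper-small',
#       generate_opts={'max_new_tokens': 448, 'num_beams': 1, 'do_sample': False},
#   )
#   print(asr('/data/corpora/example.wav')['text'])
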
#%% WER evaluation functions
def get_normalizer(text_norm_method='isat'):
    if text_norm_method == 'whisper':
        normalizer = whisper_norm_text_for_wer
    elif text_norm_method == 'whisper_keep_tags':
        normalizer = EnglishTextNormalizer()
    elif text_norm_method == 'isat':
        normalizer = norm_text_for_wer
    elif text_norm_method == 'levi':
        normalizer = levi_norm_text_for_wer
    else:
        raise NotImplementedError(f'unrecognized normalizer method: {text_norm_method}')
    return normalizer

def strip_punct(instr, keep_math=False):
    newstr = ''
    for word in instr.split():
        if keep_math:
            # keep math operators such as % ( ) * + - /
            word = word.strip('!"#$&\',.:;<=>?@[\\]^_`{|}~')
        else:
            # delete punct from start and end of word
            word = word.strip(string.punctuation)
            # delete commas inside numbers
            m = re.match(r'(\d*),(\d)', word)
            if m is not None:
                word = word.replace(',', '')
            # commas inside words become space
            word = re.sub(',', ' ', word)
        # hyphens inside words become space (unless keeping math expressions)
        if not keep_math:
            word = re.sub('-', ' ', word)
        word = word.strip()
        newstr += ' ' + word
    newstr = newstr.strip()
    return newstr

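# Examples (editor's addition, expected behaviour of strip_punct):
#   strip_punct('Hello, world!')        -> 'Hello world'
#   strip_punct('1,000')                -> '1000'
#   strip_punct('well-known')           -> 'well known'
#   strip_punct('3+4', keep_math=True)  -> '3+4'   (math operators preserved)
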
def remove_in_brackets(text):
    # removes any clause in brackets, parens, or angle brackets, and the brackets themselves
    return re.sub(r"[\(\[\<].*?[\)\]\>]+", " ", text)

def caught_num2words(text):
    # first do currency replacements #TODO: plurals vs singular
    if '$' in text:
        text = re.sub(r'\$([0-9]+)', r'\g<1> dollars', text)
    if '€' in text:
        text = re.sub(r'€([0-9]+)', r'\g<1> euro', text)
    if '£' in text:
        text = re.sub(r'£([0-9]+)', r'\g<1> pounds', text)
    if '%' in text:
        text = re.sub(r'([0-9]+)%', r'\g<1> percent', text)

    # strip punctuation
    text = strip_punct(text, keep_math=True)
    text = text.strip('*=/')
    # catch strings that might be converted to infinity or NaN and return as is...
    naughty_words = ['INF', 'Inf', 'inf', 'NAN', 'NaN', 'nan', 'NONE', 'None', 'none', 'Infinity', 'infinity']
    if text in naughty_words:
        return text
    try:
        if len(text.split()) > 1:
            # spell out each token separately
            return ' '.join([caught_num2words(word) for word in text.split()])
        else:
            return num2words(text)
    except (InvalidOperation, ValueError):
        return text

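# Examples (editor's addition, expected behaviour of caught_num2words):
#   caught_num2words('3')    -> 'three'
#   caught_num2words('$20')  -> 'twenty dollars'
#   caught_num2words('inf')  -> 'inf'   (guarded against num2words infinity)
#   caught_num2words('cat')  -> 'cat'   (non-numeric strings pass through)
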
def spell_math(text):
    # spell out mathematical expressions
    # numerals preceded by hyphen become negative
    text = re.sub(r'\-(\d+)', r'minus \g<1>', text)
    text = re.sub(r'(\d+\s?)\-(\s?\d?)', r'\g<1> minus \g<2>', text)
    text = re.sub(r'(\w+\s+)\-(\s?\w+)', r'\g<1> minus \g<2>', text)  # need to be more careful with - as this could be a hyphenated word not minus
    text = re.sub(r'(\w+\s?)\+(\s?\w+)', r'\g<1> plus \g<2>', text)
    text = re.sub(r'(\w+\s?)\*(\s?\w+)', r'\g<1> times \g<2>', text)
    text = re.sub(r'(\d+\s?)x(\s?\d)', r'\g<1> times \g<2>', text)  # need to be more careful with x as this could be a variable not times
    text = re.sub(r'(\w+\s?)\/(\s?\w+)', r'\g<1> divided by \g<2>', text)
    text = re.sub(r'(\w+\s?)\=(\s?\w+)', r'\g<1> equals \g<2>', text)
    return text

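# Example (editor's addition): spell_math feeds levi_norm_text_for_wer (defined
# below), which collapses the doubled spaces the substitutions leave behind:
#   spell_math('3 + 4 = 7')              -> '3  plus  4  equals  7'
#   levi_norm_text_for_wer('3 + 4 = 7')  -> 'three plus four equals seven'
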
def expand_contractions(text):
    expanded_words = []
    for wrd in text.split():
        expanded_words.append(contractions.fix(wrd))
    return ' '.join(expanded_words)

def norm_text_for_wer(text):
    # function to format text or lists of text (e.g. asr, transcript) for wer computation.
    # Converts from list to a single string and applies some text normalization operations.
    # note that the clean_REV_transcript function should be applied first to remove REV-specific keywords
    # and extract text from docx format tables

    if isinstance(text, list):
        text = ' '.join(text)
    text = str(text)
    text = text.replace('\n', ' ')  # replace newline with space
    text = remove_in_brackets(text)  # removes non-spoken annotations such as [inaudible]
    text = re.sub(r'%\w+', '', text)  # remove %HESITATION etc
    text = ' '.join([caught_num2words(word) for word in text.split(' ')])  # spell out numbers
    text = expand_contractions(text)
    text = strip_punct(text)
    text = text.lower()
    text = re.sub(r'\s+', ' ', text)  # replace multiple spaces with single
    return text

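# Example (editor's addition, expected behaviour of the 'isat' normalizer):
#   norm_text_for_wer("I'll pay $20 [laughs]") -> 'i will pay twenty dollars'
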
def levi_norm_text_for_wer(text):
    # function to format text or lists of text (e.g. asr, transcript) for wer computation.
    # specialized for math language

    if isinstance(text, list):
        text = ' '.join(text)
    text = str(text)
    text = text.replace('\n', ' ')  # replace newline with space
    text = remove_in_brackets(text)  # removes non-spoken annotations such as [inaudible]
    text = re.sub(r'%\w+', '', text)  # remove %HESITATION etc
    text = spell_math(text)
    text = ' '.join([caught_num2words(word) for word in text.split(' ')])  # spell out numbers
    text = expand_contractions(text)
    text = strip_punct(text, keep_math=True)
    text = text.lower()
    text = re.sub(r'\s+', ' ', text)  # replace multiple spaces with single
    return text

def whisper_norm_text_for_wer(text):
    # function to format text for wer computation.
    # uses the Whisper normalizer after stripping corpus-specific special tags

    if isinstance(text, list):
        text = ' '.join(text)
    text = str(text)
    text = text.replace('\n', ' ')  # replace newline with space
    text = re.sub(r'%\w+', '', text)  # remove %HESITATION etc
    text = remove_in_brackets(text)  # removes non-spoken annotations such as [inaudible]
    normalizer = EnglishTextNormalizer()
    text = normalizer(text)
    return text

def wer_from_df(
        df,
        refcol='ref',
        hypcol='hyp',
        return_alignments=False,
        normalise=True,
        text_norm_method='isat',
        printout=True):
    """Compute WER from a dataframe containing a ref col and a hyp col

    WER is computed on the edit operation counts over the whole df,
    not averaged over single utterances.

    Args:
        df (pandas DataFrame): containing rows per utterance
        refcol (str, optional): column name containing reference transcript. Defaults to 'ref'.
        hypcol (str, optional): column name containing hypothesis transcript. Defaults to 'hyp'.
        return_alignments (bool, optional): Return full word-level alignments. Defaults to False.
        normalise (bool, optional): Apply text normalisation to ref and hyp (see norm_text_for_wer). Defaults to True.
        text_norm_method (str, optional): which normalizer to apply (see get_normalizer). Defaults to 'isat'.
        printout (bool, optional): Print WER metrics. Defaults to True.
    """
    normalizer = get_normalizer(text_norm_method)

    refs = df[refcol].astype(str)
    hyps = df[hypcol].astype(str)
    if normalise:
        refs = refs.apply(normalizer)
        hyps = hyps.apply(normalizer)

    # drop rows whose reference is empty (after normalisation if applied),
    # since a hypothesis cannot be aligned against an empty reference
    if any(s == '' for s in list(refs)):
        nonempty = refs.str.len() > 0
        refs = refs[nonempty]
        hyps = hyps[nonempty]
        # print(f'{sum(~nonempty)} empty references removed (after normalisation if applied)')
    wer_meas = jiwer.compute_measures(list(refs), list(hyps))

    if not return_alignments:
        # remove alignments
        del wer_meas['ops']
        del wer_meas['truth']
        del wer_meas['hypothesis']
    wer_meas['word_count'] = wer_meas['substitutions'] + wer_meas['deletions'] + wer_meas['hits']
    wer_meas['sub_rate'] = wer_meas['substitutions'] / wer_meas['word_count']
    wer_meas['del_rate'] = wer_meas['deletions'] / wer_meas['word_count']
    wer_meas['ins_rate'] = wer_meas['insertions'] / wer_meas['word_count']

    if printout:
        for key in ['wer', 'sub_rate', 'del_rate', 'ins_rate']:
            print(f"{key}={100 * wer_meas[key]:.1f}")
        print(f"word_count={int(wer_meas['word_count'])}")
    return wer_meas


def wer_from_csv(
        csv_path,
        refcol='ref',
        hypcol='hyp',
        return_alignments=False,
        normalise=True,
        text_norm_method='isat',
        printout=True):

    res = pd.read_csv(csv_path).astype(str)

    wer_meas = wer_from_df(res,
                           refcol=refcol,
                           hypcol=hypcol,
                           return_alignments=return_alignments,
                           normalise=normalise,
                           text_norm_method=text_norm_method,
                           printout=printout)
    return wer_meas
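
# Illustrative usage sketch (editor's addition; the file path and the
# 'transcript' reference column are hypothetical): score the CSV written by
# ASRmanifest against its reference transcripts.
#
#   wer_meas = wer_from_csv(
#       'results/test_set_asr.csv',
#       refcol='transcript',
#       hypcol='asr',
#       text_norm_method='isat',
#   )
#   print(f"corpus WER: {100 * wer_meas['wer']:.1f}%")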