rosyvs commited on
Commit
25aaf66
1 Parent(s): dc03094

levi default wer norm

Browse files
Files changed (1) hide show
  1. benchmark_utils.py +8 -50
benchmark_utils.py CHANGED
@@ -66,8 +66,9 @@ def ASRmanifest(
66
  with torch.no_grad():
67
  with autocast():
68
  try:
69
- result = asr_pipeline(audiofile)
70
  asrtext = result['text']
 
71
  except (FileNotFoundError, ValueError) as e:
72
  print(f'SKIPPED: {audiofile}')
73
  continue
@@ -77,49 +78,6 @@ def ASRmanifest(
77
  compute_time = (et-st)
78
  print(f'...transcription complete in {compute_time:.1f} sec')
79
 
80
- def load_model(
81
- model_path:str,
82
- language='english',
83
- use_int8 = False,
84
- device_map='auto'):
85
-
86
- warnings.filterwarnings("ignore")
87
- transformers.utils.logging.set_verbosity_error()
88
-
89
- try:
90
- model = WhisperForConditionalGeneration.from_pretrained(
91
- model_path,
92
- load_in_8bit=use_int8,
93
- device_map=device_map,
94
- use_cache=False,
95
- )
96
- try:
97
- processor=WhisperProcessor.from_pretrained(model_path, language=language, task="transcribe")
98
- except OSError:
99
- print('missing tokenizer and preprocessor config files in save dir, checking directory above...')
100
- processor=WhisperProcessor.from_pretrained(os.path.join(model_path,'..'), language=language, task="transcribe")
101
-
102
- except OSError as e:
103
- print(f'{e}: possibly missing model or config file in model path. Will check for adapter...')
104
- # check if PEFT
105
- if os.path.isdir(os.path.join(model_path , "adapter_model")):
106
- print('found adapter...loading PEFT model')
107
- # checkpoint dir needs adapter model subdir with adapter_model.bin and adapter_confg.json
108
- peft_config = PeftConfig.from_pretrained(os.path.join(model_path , "adapter_model"))
109
- print(f'...loading and merging LORA weights to base model {peft_config.base_model_name_or_path}')
110
- model = WhisperForConditionalGeneration.from_pretrained(peft_config.base_model_name_or_path,
111
- load_in_8bit=use_int8,
112
- device_map=device_map,
113
- use_cache=False,
114
- )
115
- model = PeftModel.from_pretrained(model, os.path.join(model_path,"adapter_model"))
116
- model = model.merge_and_unload()
117
- processor = WhisperProcessor.from_pretrained(peft_config.base_model_name_or_path, language=language, task="transcribe")
118
- else:
119
- raise e
120
- model.eval()
121
- return(model, processor)
122
-
123
  def prepare_pipeline(model_path, generate_opts):
124
  """Prepare a pipeline for ASR inference
125
  Args:
@@ -128,16 +86,16 @@ def prepare_pipeline(model_path, generate_opts):
128
  Returns:
129
  pipeline: ASR pipeline
130
  """
131
- model, processor = load_model(
132
- model_path=model_path)
133
 
134
  asr_pipeline = pipeline(
135
  "automatic-speech-recognition",
136
- model=model,
137
  tokenizer=processor.tokenizer,
138
  feature_extractor=processor.feature_extractor,
139
  generate_kwargs=generate_opts,
140
- )
 
141
  return asr_pipeline
142
 
143
  #%% WER evaluation functions
@@ -285,7 +243,7 @@ def wer_from_df(
285
  hypcol='hyp',
286
  return_alignments=False,
287
  normalise = True,
288
- text_norm_method='isat',
289
  printout=True):
290
  """Compute WER from a dataframe containing a ref col and a hyp col
291
  WER is computed on the edit operation counts over the whole df,
@@ -338,7 +296,7 @@ def wer_from_csv(
338
  hypcol='hyp',
339
  return_alignments=False,
340
  normalise = True,
341
- text_norm_method='isat' ,
342
  printout=True):
343
 
344
  res = pd.read_csv(csv_path).astype(str)
 
66
  with torch.no_grad():
67
  with autocast():
68
  try:
69
+ result = asr_pipeline(audiofile )
70
  asrtext = result['text']
71
+ asr_pipeline.call_count = 0
72
  except (FileNotFoundError, ValueError) as e:
73
  print(f'SKIPPED: {audiofile}')
74
  continue
 
78
  compute_time = (et-st)
79
  print(f'...transcription complete in {compute_time:.1f} sec')
80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  def prepare_pipeline(model_path, generate_opts):
82
  """Prepare a pipeline for ASR inference
83
  Args:
 
86
  Returns:
87
  pipeline: ASR pipeline
88
  """
89
+ processor = WhisperProcessor.from_pretrained(model_path)
 
90
 
91
  asr_pipeline = pipeline(
92
  "automatic-speech-recognition",
93
+ model=model_path,
94
  tokenizer=processor.tokenizer,
95
  feature_extractor=processor.feature_extractor,
96
  generate_kwargs=generate_opts,
97
+ model_kwargs={"load_in_8bit": False},
98
+ device_map='auto')
99
  return asr_pipeline
100
 
101
  #%% WER evaluation functions
 
243
  hypcol='hyp',
244
  return_alignments=False,
245
  normalise = True,
246
+ text_norm_method='levi',
247
  printout=True):
248
  """Compute WER from a dataframe containing a ref col and a hyp col
249
  WER is computed on the edit operation counts over the whole df,
 
296
  hypcol='hyp',
297
  return_alignments=False,
298
  normalise = True,
299
+ text_norm_method='levi' ,
300
  printout=True):
301
 
302
  res = pd.read_csv(csv_path).astype(str)