#%% imports
import csv
import functools
import math
import os
import re
import string
import time
import warnings
from decimal import InvalidOperation
from pathlib import Path

import pandas as pd
from tqdm import tqdm
import torch
from torch.cuda.amp import autocast
import transformers
from transformers import WhisperFeatureExtractor, WhisperTokenizer, WhisperProcessor, WhisperForConditionalGeneration, GenerationConfig
from transformers import pipeline, AutomaticSpeechRecognitionPipeline
from peft import PeftModel, PeftConfig
import jiwer
from jiwer.process import WordOutput
import numpy as np
import contractions
from whisper.normalizers.english import EnglishTextNormalizer
from num2words import num2words


#%% define functions
def ASRmanifest(
        manifest_csv: str,
        out_csv: str,
        corpora_root: str,
        model_path: str,
        ):
    """Run Whisper ASR on a dataset specified in a manifest.

    Every manifest row is written to ``out_csv`` with an extra ``asr`` column
    holding the transcript. Rows whose audio makes the pipeline raise
    FileNotFoundError or ValueError are reported as SKIPPED and omitted from
    the output.

    Args:
        manifest_csv (str): path to manifest csv listing files to transcribe;
            the audio path is read from its ``wav`` column
        out_csv (str): path to write output csv
        corpora_root (str): root path where audio files are, inserted in place
            of $DATAROOT in manifest
        model_path (str): path to model directory / huggingface model name
    """
    # keep_default_na=False keeps empty / "NA"-like cells as plain strings
    df = pd.read_csv(manifest_csv, keep_default_na=False)
    fieldnames = list(df.columns) + ['asr']

    asr_pipeline = prepare_pipeline(
        model_path=model_path,
        generate_opts={
            'max_new_tokens': 448,
            'num_beams': 1,  # greedy
            'repetition_penalty': 1,
            'do_sample': False,
        },
    )

    message = "This may take a while on CPU." if asr_pipeline.device.type == "cpu" else "Using GPU"
    print(f'Running ASR for {len(df)} files. {message} ...')

    st = time.time()
    # explicit utf-8: transcripts can contain non-ASCII text and the platform
    # default encoding is not guaranteed to represent it
    with open(out_csv, 'w', newline='', encoding='utf-8') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames, delimiter=',')
        writer.writeheader()
        for _, row in tqdm(df.iterrows(), total=df.shape[0]):
            audiofile = row['wav'].replace('$DATAROOT', corpora_root)
            with torch.no_grad(), autocast():
                try:
                    result = asr_pipeline(audiofile)
                    asrtext = result['text']
                except (FileNotFoundError, ValueError):
                    print(f'SKIPPED: {audiofile}')
                    continue
            row['asr'] = asrtext
            writer.writerow(row.to_dict())

    compute_time = time.time() - st
    print(f'...transcription complete in {compute_time:.1f} sec')


def load_model(
        model_path: str,
        language='english',
        use_int8=False,
        device_map='auto'):
    """Load a Whisper model and processor, falling back to a PEFT/LoRA adapter.

    First tries a full model at ``model_path``. The processor is taken from
    ``model_path`` or, if that raises OSError, from its parent directory. If
    loading raises OSError and ``model_path/adapter_model`` is a directory,
    the adapter's base model is loaded, the adapter weights are merged into
    it, and the processor is taken from the base model.

    Side effect: silences all Python warnings and sets transformers logging
    to errors only, for the whole process.

    Args:
        model_path (str): path to model directory / huggingface model name
        language (str): language passed to the WhisperProcessor
        use_int8 (bool): passed as ``load_in_8bit`` to ``from_pretrained``
        device_map: passed as ``device_map`` to ``from_pretrained``

    Returns:
        tuple: (model, processor), with the model switched to eval mode.

    Raises:
        OSError: if the model cannot be loaded and no adapter subdirectory exists.
    """
    warnings.filterwarnings("ignore")
    transformers.utils.logging.set_verbosity_error()
    try:
        model = WhisperForConditionalGeneration.from_pretrained(
            model_path,
            load_in_8bit=use_int8,
            device_map=device_map,
            use_cache=False,
        )
        try:
            processor = WhisperProcessor.from_pretrained(model_path, language=language, task="transcribe")
        except OSError:
            print('missing tokenizer and preprocessor config files in save dir, checking directory above...')
            processor = WhisperProcessor.from_pretrained(os.path.join(model_path, '..'), language=language, task="transcribe")
    except OSError as e:
        print(f'{e}: possibly missing model or config file in model path. Will check for adapter...')
        adapter_dir = os.path.join(model_path, "adapter_model")
        if not os.path.isdir(adapter_dir):
            raise  # not a PEFT checkpoint either: propagate the original OSError
        print('found adapter...loading PEFT model')
        # checkpoint dir needs adapter model subdir with adapter_model.bin and adapter_config.json
        peft_config = PeftConfig.from_pretrained(adapter_dir)
        print(f'...loading and merging LORA weights to base model {peft_config.base_model_name_or_path}')
        model = WhisperForConditionalGeneration.from_pretrained(
            peft_config.base_model_name_or_path,
            load_in_8bit=use_int8,
            device_map=device_map,
            use_cache=False,
        )
        model = PeftModel.from_pretrained(model, adapter_dir)
        model = model.merge_and_unload()
        processor = WhisperProcessor.from_pretrained(peft_config.base_model_name_or_path, language=language, task="transcribe")
    model.eval()
    return model, processor


def prepare_pipeline(model_path, generate_opts):
    """Prepare a pipeline for ASR inference

    Args:
        model_path (str): path to model directory / huggingface model name
        generate_opts (dict): options to pass to pipeline as generate_kwargs

    Returns:
        pipeline: ASR pipeline
    """
    model, processor = load_model(model_path=model_path)
    asr_pipeline = pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        generate_kwargs=generate_opts,
    )
    return asr_pipeline


#%% WER evaluation functions
def get_normalizer(text_norm_method='isat'):
    """Return the text-normalisation callable selected by name.

    Args:
        text_norm_method (str): 'isat' -> norm_text_for_wer;
            'levi' -> levi_norm_text_for_wer (math-aware);
            'whisper' -> whisper_norm_text_for_wer (strips tags, then applies
            Whisper's English normaliser);
            'whisper_keep_tags' -> Whisper's English normaliser alone.

    Raises:
        NotImplementedError: for any other name.
    """
    if text_norm_method == 'whisper':
        normalizer = whisper_norm_text_for_wer
    elif text_norm_method == 'whisper_keep_tags':
        normalizer = EnglishTextNormalizer()
    elif text_norm_method == 'isat':
        normalizer = norm_text_for_wer
    elif text_norm_method == 'levi':
        normalizer = levi_norm_text_for_wer
    else:
        raise NotImplementedError(f'unrecognized normalizer method: {text_norm_method}')
    return normalizer


def strip_punct(instr, keep_math=False):
    """Strip punctuation from every whitespace-separated word and rejoin them.

    Per word:
      - leading/trailing punctuation is removed; with ``keep_math`` the
        characters ``%()*+-/`` are kept;
      - in a word that starts like a number with a separator ("1,000"), all
        commas are removed;
      - any remaining commas become spaces;
      - hyphens become spaces unless ``keep_math``.

    The result is joined with single spaces and stripped.
    """
    words = []
    for word in instr.split():
        if keep_math:
            word = word.strip('!"#$&\',.:;<=>?@[\\]^_`{|}~')
        else:
            # delete punct from start and end of word
            word = word.strip(string.punctuation)
        # delete commas inside numbers
        if re.match(r'(\d*),(\d)', word) is not None:
            word = word.replace(',', '')
        # commas inside words become space
        word = word.replace(',', ' ')
        # hyphens inside words become space
        if not keep_math:
            word = word.replace('-', ' ')
        words.append(word.strip())
    return ' '.join(words).strip()


def remove_in_brackets(text):
    """Replace every clause in (), [] or <> (brackets included) with a space."""
    return re.sub(r"[\(\[\<].*?[\)\]\>]+", " ", text)


# strings num2words might turn into infinity/NaN; returned unchanged instead
_NUM2WORDS_PASSTHROUGH = frozenset(
    ['INF', 'Inf', 'inf', 'NAN', 'NaN', 'nan', 'NONE', 'None', 'none', 'Infinity', 'infinity'])


def caught_num2words(text):
    """Spell out numbers in ``text`` with num2words, leaving other text unchanged.

    Integer amounts prefixed by $, € or £ and suffixed by % are first expanded
    to "<n> dollars/euro/pounds/percent". Punctuation is then stripped (math
    symbols kept). Multi-word text is converted word by word. Anything
    num2words cannot parse is returned as it stands after this preprocessing.
    """
    # first do currency replacements
    # TODO: plurals vs singular
    if '$' in text:
        text = re.sub(r'\$([0-9]+)', r'\g<1> dollars', text)
    if '€' in text:
        text = re.sub(r'€([0-9]+)', r'\g<1> euro', text)
    if '£' in text:
        text = re.sub(r'£([0-9]+)', r'\g<1> pounds', text)
    if '%' in text:
        text = re.sub(r'([0-9]+)%', r'\g<1> percent', text)
    # strip punctuation
    text = strip_punct(text, keep_math=True)
    text = text.strip('*=/')
    if text in _NUM2WORDS_PASSTHROUGH:
        return text
    try:
        if len(text.split()) > 1:
            return ' '.join(caught_num2words(word) for word in text.split())
        return num2words(text)
    except (InvalidOperation, ValueError):
        return text


def spell_math(text):
    """Spell out simple mathematical operators (-, +, *, x, /, =) as words."""
    # numerals preceded by a hyphen become negative, but only when the hyphen
    # does not directly follow a word character: otherwise "5-3" would become
    # "5minus 3" and the binary-minus rule below could never fire
    text = re.sub(r'(?<!\w)-(\d+)', r'minus \g<1>', text)
    text = re.sub(r'(\d+\s?)-(\s?\d?)', r'\g<1> minus \g<2>', text)
    # need to be more careful with - as this could be a hyphenated word not minus
    text = re.sub(r'(\w+\s+)-(\s?\w+)', r'\g<1> minus \g<2>', text)
    text = re.sub(r'(\w+\s?)\+(\s?\w+)', r'\g<1> plus \g<2>', text)
    text = re.sub(r'(\w+\s?)\*(\s?\w+)', r'\g<1> times \g<2>', text)
    # need to be more careful with x as this could be a variable not times
    text = re.sub(r'(\d+\s?)x(\s?\d)', r'\g<1> times \g<2>', text)
    text = re.sub(r'(\w+\s?)/(\s?\w+)', r'\g<1> divided by \g<2>', text)
    text = re.sub(r'(\w+\s?)=(\s?\w+)', r'\g<1> equals \g<2>', text)
    return text


def expand_contractions(str):
    """Apply ``contractions.fix`` to each whitespace-separated word; rejoin with single spaces."""
    # NOTE: the parameter name shadows the builtin ``str``; kept for keyword callers.
    return ' '.join(contractions.fix(wrd) for wrd in str.split())


def _to_single_line(text):
    """Join a list of strings with spaces, coerce to str, and turn newlines into spaces."""
    if isinstance(text, list):
        text = ' '.join(text)
    return str(text).replace('\n', ' ')


def norm_text_for_wer(text):
    """Normalise text (or a list of text) for WER computation ('isat' method).

    Removes bracketed annotations and %TAGS, spells out numbers, expands
    contractions, strips punctuation, lowercases, and collapses whitespace.
    Note that the clean_REV_transcript function should be applied first to
    remove REV-specific keywords and extract text from docx format tables.
    """
    text = _to_single_line(text)
    text = remove_in_brackets(text)  # removes non-spoken annotations such as [inaudible]
    text = re.sub(r'%\w+', '', text)  # remove %HESITATION etc
    text = ' '.join(caught_num2words(token) for token in text.split(' '))  # spell out numbers
    text = expand_contractions(text)
    text = strip_punct(text)
    text = text.lower()
    text = re.sub(r'\s+', ' ', text)  # replace multiple space with single
    return text


def levi_norm_text_for_wer(text):
    """Normalise text (or a list of text) for WER computation, specialised for math language.

    Like norm_text_for_wer, but spells out math operators first and keeps
    math symbols when stripping punctuation.
    """
    text = _to_single_line(text)
    text = remove_in_brackets(text)  # removes non-spoken annotations such as [inaudible]
    text = re.sub(r'%\w+', '', text)  # remove %HESITATION etc
    text = spell_math(text)
    text = ' '.join(caught_num2words(token) for token in text.split(' '))  # spell out numbers
    text = expand_contractions(text)
    text = strip_punct(text, keep_math=True)
    text = text.lower()
    text = re.sub(r'\s+', ' ', text)  # replace multiple space with single
    return text


@functools.lru_cache(maxsize=1)
def _whisper_english_normalizer():
    """Return one shared EnglishTextNormalizer instead of building one per call."""
    return EnglishTextNormalizer()


def whisper_norm_text_for_wer(text):
    """Normalise text for WER with Whisper's English normaliser.

    Corpus-specific %TAGS and bracketed annotations are stripped first.
    """
    text = _to_single_line(text)
    text = re.sub(r'%\w+', '', text)  # remove %HESITATION etc
    text = remove_in_brackets(text)  # removes non-spoken annotations such as [inaudible]
    return _whisper_english_normalizer()(text)


def wer_from_df(
        df,
        refcol='ref',
        hypcol='hyp',
        return_alignments=False,
        normalise=True,
        text_norm_method='isat',
        printout=True):
    """Compute WER from a dataframe containing a ref col and a hyp col

    WER is computed on the edit operation counts over the whole df, not
    averaged over single utterances. Rows whose reference is empty (after
    normalisation, if applied) are dropped. Missing (NaN) cells are treated
    as empty text.

    Args:
        df (pandas DataFrame): containing rows per utterance
        refcol (str, optional): column name containing reference transcript. Defaults to 'ref'.
        hypcol (str, optional): column name containing hypothesis transcript. Defaults to 'hyp'.
        return_alignments (bool, optional): Return full word-level alignments. Defaults to False.
        normalise (bool, optional): Apply text normalisation to ref and hyp. Defaults to True.
        text_norm_method (str, optional): normaliser name, see get_normalizer. Defaults to 'isat'.
        printout (bool, optional): Print WER metrics. Defaults to True.

    Returns:
        dict: jiwer.compute_measures output (without 'ops', 'truth' and
        'hypothesis' unless return_alignments) plus 'word_count', 'sub_rate',
        'del_rate' and 'ins_rate'.
    """
    normalizer = get_normalizer(text_norm_method)
    # fillna('') so missing cells count as empty text, not the literal word "nan"
    refs = df[refcol].fillna('').astype(str)
    hyps = df[hypcol].fillna('').astype(str)
    if normalise:
        refs = refs.apply(normalizer)
        hyps = hyps.apply(normalizer)

    nonempty = refs != ''
    refs = refs[nonempty]
    hyps = hyps[nonempty]

    wer_meas = jiwer.compute_measures(list(refs), list(hyps))

    if not return_alignments:
        # remove alignments
        del wer_meas['ops']
        del wer_meas['truth']
        del wer_meas['hypothesis']
    wer_meas['word_count'] = wer_meas['substitutions'] + wer_meas['deletions'] + wer_meas['hits']
    wer_meas['sub_rate'] = wer_meas['substitutions'] / wer_meas['word_count']
    wer_meas['del_rate'] = wer_meas['deletions'] / wer_meas['word_count']
    wer_meas['ins_rate'] = wer_meas['insertions'] / wer_meas['word_count']

    if printout:
        for key in ['wer', 'sub_rate', 'del_rate', 'ins_rate']:
            print(f"{key}={100 * wer_meas[key]:.1f}")
        print(f"word_count={int(wer_meas['word_count'])}")

    return wer_meas


def wer_from_csv(
        csv_path,
        refcol='ref',
        hypcol='hyp',
        return_alignments=False,
        normalise=True,
        text_norm_method='isat',
        printout=True):
    """Compute WER from a csv file with a ref and a hyp column; see wer_from_df for arguments."""
    # keep_default_na=False: empty cells (e.g. an empty ASR output) stay '' instead
    # of becoming NaN -> "nan", and words such as "NA" or "null" are kept as text
    res = pd.read_csv(csv_path, keep_default_na=False).astype(str)
    wer_meas = wer_from_df(
        res,
        refcol=refcol,
        hypcol=hypcol,
        return_alignments=return_alignments,
        normalise=normalise,
        text_norm_method=text_norm_method,
        printout=printout)
    return wer_meas