import csv
import re
import string
import time
from decimal import InvalidOperation

import contractions
import jiwer
import pandas as pd
import torch
from num2words import num2words
from torch.cuda.amp import autocast
from tqdm import tqdm
from transformers import WhisperProcessor, pipeline
from whisper.normalizers.english import EnglishTextNormalizer


def ASRmanifest(
    manifest_csv: str,
    out_csv: str,
    corpora_root: str,
    model_path: str,
):
    """Run Whisper ASR on a dataset specified in a manifest.

    Args:
        manifest_csv (str): path to manifest csv listing files to transcribe
        out_csv (str): path to write output csv
        corpora_root (str): root path where audio files live, substituted for $DATAROOT in the manifest
        model_path (str): path to model directory / huggingface model name
    """
    df = pd.read_csv(manifest_csv, keep_default_na=False)
    fieldnames = list(df.columns) + ['asr']

    asr_pipeline = prepare_pipeline(
        model_path=model_path,
        generate_opts={
            'max_new_tokens': 448,
            'num_beams': 1,
            'repetition_penalty': 1,
            'do_sample': False,
        },
    )

    message = "This may take a while on CPU." if asr_pipeline.device.type == "cpu" else "Using GPU"
    print(f'Running ASR for {len(df)} files. {message} ...')

    st = time.time()
    with open(out_csv, 'w', newline='') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames, delimiter=',')
        writer.writeheader()
        for _, row in tqdm(df.iterrows(), total=df.shape[0]):
            audiofile = row['wav'].replace('$DATAROOT', corpora_root)
            with torch.no_grad():
                with autocast():
                    try:
                        result = asr_pipeline(audiofile)
                        asrtext = result['text']
                        # reset the pipeline's call counter so it does not warn
                        # about sequential (non-batched) usage
                        asr_pipeline.call_count = 0
                    except (FileNotFoundError, ValueError):
                        print(f'SKIPPED: {audiofile}')
                        continue
            row['asr'] = asrtext
            writer.writerow(row.to_dict())
    compute_time = time.time() - st
    print(f'...transcription complete in {compute_time:.1f} sec')
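
# Example manifest layout and call (a minimal sketch; the paths and the
# 'openai/whisper-tiny' model name are illustrative placeholders, not values
# from this repo):
#
#   manifest.csv:
#       wav,ref
#       $DATAROOT/session1/utt001.wav,hello world
#
#   ASRmanifest(
#       manifest_csv='manifest.csv',
#       out_csv='manifest_asr.csv',
#       corpora_root='/data/corpora',
#       model_path='openai/whisper-tiny',
#   )
#
# The output csv repeats the manifest columns and appends an 'asr' column.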


def prepare_pipeline(model_path, generate_opts):
    """Prepare a huggingface pipeline for ASR inference.

    Args:
        model_path (str): path to model directory / huggingface model name
        generate_opts (dict): generation kwargs passed through to model.generate

    Returns:
        pipeline: ASR pipeline
    """
    processor = WhisperProcessor.from_pretrained(model_path)

    asr_pipeline = pipeline(
        "automatic-speech-recognition",
        model=model_path,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        generate_kwargs=generate_opts,
        model_kwargs={"load_in_8bit": False},
        device_map='auto',
    )
    return asr_pipeline
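
# Usage sketch (model name and audio path are illustrative placeholders):
#
#   pipe = prepare_pipeline(
#       model_path='openai/whisper-tiny',
#       generate_opts={'max_new_tokens': 448, 'num_beams': 1, 'do_sample': False},
#   )
#   print(pipe('example.wav')['text'])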


def get_normalizer(text_norm_method='isat'):
    """Look up a text normalisation callable by name
    (see the *_norm_text_for_wer functions below)."""
    if text_norm_method == 'whisper':
        normalizer = whisper_norm_text_for_wer
    elif text_norm_method == 'whisper_keep_tags':
        normalizer = EnglishTextNormalizer()
    elif text_norm_method == 'isat':
        normalizer = norm_text_for_wer
    elif text_norm_method == 'levi':
        normalizer = levi_norm_text_for_wer
    else:
        raise NotImplementedError(f'unrecognized normalizer method: {text_norm_method}')
    return normalizer
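
# Example (expected output assumes the normalizers defined below):
#
#   norm = get_normalizer('levi')
#   norm('[laughs] I paid $5')   # -> 'i paid five dollars'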


def strip_punct(instr, keep_math=False):
    """Strip punctuation from each word, keeping math operators if requested."""
    newstr = ''
    for word in instr.split():
        if keep_math:
            # strip punctuation except math operators (% ( ) * + - /)
            word = word.strip('!"#$&\',.:;<=>?@[\\]^_`{|}~')
        else:
            # strip all punctuation from word edges
            word = word.strip(string.punctuation)

        # remove commas used as thousands separators (e.g. 1,000 -> 1000)
        m = re.match(r'(\d*),(\d)', word)
        if m is not None:
            word = word.replace(',', '')

        # treat any remaining commas as word separators
        word = re.sub(",", " ", word)

        if not keep_math:
            # split hyphenated words
            word = re.sub("-", " ", word)
        word = word.strip()
        newstr += ' ' + word
    newstr = newstr.strip()
    return newstr
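
# Examples:
#
#   strip_punct('"Hello," she said.')    # -> 'Hello she said'
#   strip_punct('3+4', keep_math=True)   # -> '3+4'  (operators kept)
#   strip_punct('1,000 well-known')      # -> '1000 well known'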


def remove_in_brackets(text):
    """Remove anything inside (), [], or <>, including the brackets."""
    return re.sub(r"[\(\[\<].*?[\)\]\>]+", " ", text)


def caught_num2words(text):
    """Convert numerals to words, leaving anything unconvertible unchanged."""
    # spell out currency and percent symbols before conversion
    if '$' in text:
        text = re.sub(r'\$([0-9]+)', r'\g<1> dollars', text)
    if '€' in text:
        text = re.sub(r'€([0-9]+)', r'\g<1> euro', text)
    if '£' in text:
        text = re.sub(r'£([0-9]+)', r'\g<1> pounds', text)
    if '%' in text:
        text = re.sub(r'([0-9]+)%', r'\g<1> percent', text)

    text = strip_punct(text, keep_math=True)
    text = text.strip('*=/')

    # tokens num2words would otherwise parse as numbers (inf, nan) stay as-is
    naughty_words = ['INF', 'Inf', 'inf', 'NAN', 'NaN', 'nan', 'NONE', 'None', 'none', 'Infinity', 'infinity']
    if text in naughty_words:
        return text
    try:
        if len(text.split()) > 1:
            return ' '.join([caught_num2words(word) for word in text.split()])
        else:
            return num2words(text)
    except (InvalidOperation, ValueError):
        return text
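
# Examples (output comes from num2words, so exact wording may vary by version):
#
#   caught_num2words('12')     # -> 'twelve'
#   caught_num2words('$5')     # -> 'five dollars'
#   caught_num2words('NaN')    # -> 'NaN'   (left unchanged)
#   caught_num2words('hello')  # -> 'hello' (unconvertible, returned as-is)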


def spell_math(text):
    """Spell out arithmetic operators (e.g. '3+4' -> '3 plus 4')."""
    text = re.sub(r'\-(\d+)', r'minus \g<1>', text)  # unary minus
    text = re.sub(r'(\d+\s?)\-(\s?\d?)', r'\g<1> minus \g<2>', text)
    text = re.sub(r'(\w+\s+)\-(\s?\w+)', r'\g<1> minus \g<2>', text)
    text = re.sub(r'(\w+\s?)\+(\s?\w+)', r'\g<1> plus \g<2>', text)
    text = re.sub(r'(\w+\s?)\*(\s?\w+)', r'\g<1> times \g<2>', text)
    text = re.sub(r'(\d+\s?)x(\s?\d)', r'\g<1> times \g<2>', text)
    text = re.sub(r'(\w+\s?)\/(\s?\w+)', r'\g<1> divided by \g<2>', text)
    text = re.sub(r'(\w+\s?)\=(\s?\w+)', r'\g<1> equals \g<2>', text)
    return text
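
# Examples:
#
#   spell_math('3+4=7')   # -> '3 plus 4 equals 7'
#   spell_math('10/2')    # -> '10 divided by 2'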


def expand_contractions(text):
    """Expand English contractions word by word (e.g. "don't" -> "do not")."""
    expanded_words = []
    for word in text.split():
        expanded_words.append(contractions.fix(word))
    return ' '.join(expanded_words)


def norm_text_for_wer(text):
    """Normalise text for WER scoring: remove bracketed tags, spell out
    numbers, expand contractions, strip punctuation, and lowercase."""
    if isinstance(text, list):
        text = ' '.join(text)
    text = str(text)
    text = text.replace('\n', ' ')
    text = remove_in_brackets(text)
    text = re.sub(r'%\w+', '', text)  # drop %-prefixed annotation codes
    text = ' '.join([caught_num2words(word) for word in text.split(' ')])
    text = expand_contractions(text)
    text = strip_punct(text)
    text = text.lower()
    text = re.sub(r'\s+', ' ', text)
    return text
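
# Example:
#
#   norm_text_for_wer("[cough] I've got 2 cats.")
#   # -> 'i have got two cats'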


def levi_norm_text_for_wer(text):
    """Like norm_text_for_wer, but spells out math operators and keeps them
    when stripping punctuation."""
    if isinstance(text, list):
        text = ' '.join(text)
    text = str(text)
    text = text.replace('\n', ' ')
    text = remove_in_brackets(text)
    text = re.sub(r'%\w+', '', text)  # drop %-prefixed annotation codes
    text = spell_math(text)
    text = ' '.join([caught_num2words(word) for word in text.split(' ')])
    text = expand_contractions(text)
    text = strip_punct(text, keep_math=True)
    text = text.lower()
    text = re.sub(r'\s+', ' ', text)
    return text


def whisper_norm_text_for_wer(text):
    """Normalise text for WER scoring using Whisper's EnglishTextNormalizer,
    after removing bracketed tags."""
    if isinstance(text, list):
        text = ' '.join(text)
    text = str(text)
    text = text.replace('\n', ' ')
    text = re.sub(r'%\w+', '', text)  # drop %-prefixed annotation codes
    text = remove_in_brackets(text)
    normalizer = EnglishTextNormalizer()
    text = normalizer(text)
    return text


def wer_from_df(
        df,
        refcol='ref',
        hypcol='hyp',
        return_alignments=False,
        normalise=True,
        text_norm_method='levi',
        printout=True):
    """Compute WER from a dataframe containing a ref col and a hyp col.

    WER is computed on the edit operation counts over the whole df,
    not averaged over single utterances.

    Args:
        df (pandas DataFrame): containing rows per utterance
        refcol (str, optional): column name containing reference transcript. Defaults to 'ref'.
        hypcol (str, optional): column name containing hypothesis transcript. Defaults to 'hyp'.
        return_alignments (bool, optional): return full word-level alignments. Defaults to False.
        normalise (bool, optional): apply text normalisation to ref and hyp (see get_normalizer). Defaults to True.
        text_norm_method (str, optional): normalisation method (see get_normalizer). Defaults to 'levi'.
        printout (bool, optional): print WER metrics. Defaults to True.
    """
    normalizer = get_normalizer(text_norm_method)

    refs = df[refcol].astype(str)
    hyps = df[hypcol].astype(str)
    if normalise:
        refs = refs.apply(normalizer)
        hyps = hyps.apply(normalizer)

    # jiwer cannot score empty references, so drop them (with their hypotheses)
    if any(s == '' for s in list(refs)):
        nonempty = refs.str.len() > 0
        refs = refs[nonempty]
        hyps = hyps[nonempty]

    wer_meas = jiwer.compute_measures(list(refs), list(hyps))

    if not return_alignments:
        # discard the (potentially large) word-level alignment info
        del wer_meas['ops']
        del wer_meas['truth']
        del wer_meas['hypothesis']
    wer_meas['word_count'] = wer_meas['substitutions'] + wer_meas['deletions'] + wer_meas['hits']
    wer_meas['sub_rate'] = wer_meas['substitutions'] / wer_meas['word_count']
    wer_meas['del_rate'] = wer_meas['deletions'] / wer_meas['word_count']
    wer_meas['ins_rate'] = wer_meas['insertions'] / wer_meas['word_count']

    if printout:
        for key in ['wer', 'sub_rate', 'del_rate', 'ins_rate']:
            print(f"{key}={100 * wer_meas[key]:.1f}")
        print(f"word_count={int(wer_meas['word_count'])}")
    return wer_meas
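
# Usage sketch (illustrative data):
#
#   df = pd.DataFrame({'ref': ['the cat sat'], 'hyp': ['the cat sat down']})
#   wer_meas = wer_from_df(df, printout=False)
#   print(wer_meas['wer'])   # -> 0.333... (one insertion over three reference words)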


def wer_from_csv(
        csv_path,
        refcol='ref',
        hypcol='hyp',
        return_alignments=False,
        normalise=True,
        text_norm_method='levi',
        printout=True):
    """Compute WER from a csv containing a ref col and a hyp col (see wer_from_df)."""
    res = pd.read_csv(csv_path).astype(str)

    wer_meas = wer_from_df(
        res,
        refcol=refcol,
        hypcol=hypcol,
        return_alignments=return_alignments,
        normalise=normalise,
        text_norm_method=text_norm_method,
        printout=printout)
    return wer_meas
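
# Usage sketch, e.g. scoring the output of ASRmanifest (column names are
# illustrative; pass refcol/hypcol to match your csv):
#
#   wer_from_csv('manifest_asr.csv', refcol='ref', hypcol='asr')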