import os |
from peft import PeftModel, PeftConfig |
import torch |
from torch.cuda.amp import autocast |
from torch.utils.data import DataLoader |
from tqdm import tqdm |
import transformers |
from transformers import WhisperFeatureExtractor, WhisperTokenizer, WhisperProcessor, WhisperForConditionalGeneration, GenerationConfig |
from transformers import pipeline, AutomaticSpeechRecognitionPipeline |
import argparse |
import time |
from pathlib import Path |
import json |
import pandas as pd |
import csv |
def prepare_pipeline(model_type='large-v2',
                     model_dir="../models/whisat-1.2/",
                     use_stock_model=False,
                     generate_opts=None,
                     ):
    """Build an ``AutomaticSpeechRecognitionPipeline`` around a Whisper model.

    The tokenizer and feature extractor are always taken from the stock
    ``openai/whisper-<model_type>`` hub entry. The model weights come from
    one of three places:

    * ``use_stock_model=True``: the stock ``openai/whisper-<model_type>`` model;
    * ``<model_dir>/adapter_model/`` exists: the base model named in the
      adapter's PEFT config, wrapped with that adapter;
    * otherwise: a full checkpoint loaded directly from ``model_dir``.

    Args:
        model_type: Whisper size suffix used to form the hub id.
        model_dir: Checkpoint directory (``~`` is expanded). Ignored when
            ``use_stock_model`` is true.
        use_stock_model: Load the unmodified hub model instead of a checkpoint.
        generate_opts: Keyword arguments handed to the pipeline as
            ``generate_kwargs``. ``None`` selects ``max_new_tokens=112``,
            ``num_beams=1``, ``repetition_penalty=1``, ``do_sample=False``.

    Returns:
        The configured pipeline (30 s chunks, timestamps disabled).

    Side effects:
        Silences Python warnings and transformers logging process-wide and
        prints the selected torch device.
    """
    import warnings

    # A dict literal as a default argument is shared between calls; build a
    # fresh one per call and never alias the caller's dict.
    if generate_opts is None:
        generate_opts = {
            'max_new_tokens': 112,
            'num_beams': 1,
            'repetition_penalty': 1,
            'do_sample': False,
        }
    else:
        generate_opts = dict(generate_opts)

    lang = 'english'
    USE_INT8 = False

    warnings.filterwarnings("ignore")
    transformers.utils.logging.set_verbosity_error()

    init_from_hub_path = f"openai/whisper-{model_type}"
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print(device)

    feature_extractor = WhisperFeatureExtractor.from_pretrained(init_from_hub_path)
    tokenizer = WhisperTokenizer.from_pretrained(init_from_hub_path, language=lang, task="transcribe")

    if use_stock_model:
        model = WhisperForConditionalGeneration.from_pretrained(init_from_hub_path)
    else:
        checkpoint_dir = os.path.expanduser(model_dir)
        adapter_dir = os.path.join(checkpoint_dir, "adapter_model")
        has_adapter = os.path.isdir(adapter_dir)
        if has_adapter:
            print('...it looks like this model was tuned using PEFT, because adapter_model/ is present in ckpt dir')
            # The adapter config records which base model it was trained on.
            weights_source = PeftConfig.from_pretrained(adapter_dir).base_model_name_or_path
        else:
            weights_source = checkpoint_dir
        model = WhisperForConditionalGeneration.from_pretrained(
            weights_source,
            load_in_8bit=USE_INT8,
            device_map='auto',
            use_cache=False,
        )
        if has_adapter:
            model = PeftModel.from_pretrained(model, adapter_dir)

    model.eval()

    # NOTE(review): some transformers releases appear to reject an explicit
    # `device` when the model was loaded with `device_map` -- confirm against
    # the pinned transformers version before upgrading.
    pipe = AutomaticSpeechRecognitionPipeline(
        model=model,
        tokenizer=tokenizer,
        feature_extractor=feature_extractor,
        chunk_length_s=30,
        device=device,
        return_timestamps=False,
        generate_kwargs=generate_opts,
    )
    return pipe
def load_model(model_type='large-v2',
               model_dir="../models/whisat-1.2/"):
    """Load a fine-tuned Whisper model with its tokenizer and processor.

    The tokenizer and processor come from the stock
    ``openai/whisper-<model_type>`` hub entry. When
    ``<model_dir>/adapter_model/`` exists, the base model named in the
    adapter's PEFT config is loaded and wrapped with the adapter; otherwise
    ``model_dir`` is loaded as a full checkpoint, the same fallback
    ``prepare_pipeline`` uses.

    Args:
        model_type: Whisper size suffix used to form the hub id.
        model_dir: Checkpoint directory (``~`` is expanded).

    Returns:
        ``(model, tokenizer, processor)`` with the model switched to eval mode.

    Side effects:
        Silences Python warnings and transformers logging process-wide and
        prints the detected torch device.
    """
    import warnings

    lang = 'english'
    USE_INT8 = False

    warnings.filterwarnings("ignore")
    transformers.utils.logging.set_verbosity_error()

    init_from_hub_path = f"openai/whisper-{model_type}"
    # Only reported to the user; placement is left to device_map='auto' below.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(device)

    tokenizer = WhisperTokenizer.from_pretrained(init_from_hub_path, language=lang, task="transcribe")
    processor = WhisperProcessor.from_pretrained(init_from_hub_path, language=lang, task="transcribe")

    checkpoint_dir = os.path.expanduser(model_dir)
    adapter_dir = os.path.join(checkpoint_dir, "adapter_model")
    has_adapter = os.path.isdir(adapter_dir)
    if has_adapter:
        # The adapter config records which base model it was trained on.
        weights_source = PeftConfig.from_pretrained(adapter_dir).base_model_name_or_path
    else:
        weights_source = checkpoint_dir

    model = WhisperForConditionalGeneration.from_pretrained(
        weights_source,
        load_in_8bit=USE_INT8,
        device_map='auto',
        use_cache=False,
    )
    if has_adapter:
        model = PeftModel.from_pretrained(model, adapter_dir)

    model.eval()
    return (model, tokenizer, processor)
def ASRdirWhisat(
    audio_dir,
    files_to_include=None,
    out_dir='../whisat_results/',
    model_type='large-v2',
    model_name='whisat-1.2',
    model_dir="../models/whisat-1.2",
    use_stock_model=False,
    max_new_tokens=112,
    num_beams=1,
    do_sample=False,
    repetition_penalty=1,
):
    """Transcribe audio files under ``audio_dir`` and save text and JSON results.

    Two files are written per input, mirroring the input's sub-directory
    relative to ``audio_dir``:

    * ``<out_dir>/ASR_<model_name>/<subdir>/<stem>.asr.txt`` holds the transcript;
    * ``<out_dir>/JSON_<model_name>/<subdir>/<stem>.json`` holds the raw
      pipeline result.

    Files for which the pipeline raises ``ValueError`` are reported on stdout
    and skipped.

    Args:
        audio_dir: Root directory of the audio collection.
        files_to_include: Optional list of paths relative to ``audio_dir``.
            An entry that does not exist under ``audio_dir`` is used as
            given, so absolute or cwd-relative paths keep working. When
            empty or ``None``, ``audio_dir`` is searched recursively for
            files with a known audio/video extension (case-insensitive).
        out_dir: Root directory for the result trees.
        model_type: Whisper size suffix, forwarded to ``prepare_pipeline``.
        model_name: Label used in the output directory names; replaced by
            ``whisper_<model_type>_stock`` when ``use_stock_model`` is true.
        model_dir: Checkpoint directory, forwarded to ``prepare_pipeline``.
        use_stock_model: Use the unmodified hub model.
        max_new_tokens, num_beams, do_sample, repetition_penalty:
            Generation options forwarded to the pipeline.

    Raises:
        AssertionError: ``files_to_include`` is given but is not a list.
    """
    media_suffixes = {'.mov', '.wav', '.mp4', '.mp3', '.m4a',
                      '.aac', '.flac', '.alac', '.ogg'}
    audio_root = Path(audio_dir)

    # Validate and collect inputs first so a bad argument fails before the
    # (slow) model load.
    if files_to_include:
        assert isinstance(files_to_include, list), 'files_to_include should be a list of paths relative to audio_dir to transcribe'
        audio_files = []
        for entry in files_to_include:
            # Honour the "relative to audio_dir" contract, but keep the path
            # as given when nothing exists at the joined location.
            under_root = audio_root / entry
            audio_files.append(str(under_root) if under_root.is_file() else str(entry))
    else:
        audio_files = [
            str(f) for f in audio_root.rglob("*")
            if f.suffix.lower() in media_suffixes and f.is_file()
        ]

    asr_model = prepare_pipeline(
        model_type=model_type,
        model_dir=model_dir,
        use_stock_model=use_stock_model,
        generate_opts={'max_new_tokens': max_new_tokens,
                       'num_beams': num_beams,
                       'repetition_penalty': repetition_penalty,
                       'do_sample': do_sample,
                       },
    )
    if use_stock_model:
        model_name = 'whisper_' + model_type + '_stock'

    asrDir = os.path.join(out_dir, f'ASR_{model_name}')
    jsonDir = os.path.join(out_dir, f'JSON_{model_name}')
    os.makedirs(asrDir, exist_ok=True)
    os.makedirs(jsonDir, exist_ok=True)

    on_gpu = asr_model.device.type == "cuda"
    message = "This may take a while on CPU. Go make a cuppa" if asr_model.device.type == "cpu" else "Running on GPU"
    print(f'Running ASR for {len(audio_files)} files. {message} ...')

    resolved_root = audio_root.resolve()
    st = time.time()
    for audiofile in tqdm(audio_files):
        sessname = Path(audiofile).stem
        # Sub-directory of the file relative to audio_dir, reproduced under
        # both output trees.
        sesspath = os.path.relpath(os.path.dirname(Path(audiofile).resolve()), resolved_root)
        asrFullFile = os.path.join(asrDir, sesspath, f"{sessname}.asr.txt")
        jsonFile = os.path.join(jsonDir, sesspath, f"{sessname}.json")
        os.makedirs(os.path.join(asrDir, sesspath), exist_ok=True)
        os.makedirs(os.path.join(jsonDir, sesspath), exist_ok=True)

        try:
            # Mixed precision only on GPU; enabled=False makes this a no-op
            # on CPU.
            with torch.no_grad(), torch.autocast(device_type="cuda", enabled=on_gpu):
                result = asr_model(audiofile)
        except ValueError as e:
            print(f'{e}: {audiofile}')
            continue

        with open(jsonFile, "w", encoding="utf-8") as jf:
            json.dump(result, jf, indent=4)

        # Transcripts may contain non-ASCII text; do not depend on the
        # platform default encoding.
        with open(asrFullFile, 'w', encoding="utf-8") as outfile:
            outfile.write(result['text'])

    compute_time = time.time() - st
    print(f'...transcription complete in {compute_time:.1f} sec')
def ASRmanifestWhisat(
    manifest_csv,
    out_csv,
    corpora_root,
    model_type='large-v2',
    model_dir="../models/whisat-1.2",
    use_stock_model=False,
    max_new_tokens=112,
    num_beams=1,
    do_sample=False,
    repetition_penalty=1,
):
    """Transcribe every file listed in a CSV manifest and write an ``asr`` column.

    The manifest must have a ``wav`` column holding audio paths; the literal
    ``$DATAROOT`` in each path is replaced by ``corpora_root``. ``out_csv``
    receives all manifest columns plus ``asr``. Rows for which the pipeline
    raises ``ValueError`` are reported on stdout and written with an empty
    transcript.

    Args:
        manifest_csv: Input CSV path (read with ``keep_default_na=False``).
        out_csv: Output CSV path; overwritten.
        corpora_root: Replacement for ``$DATAROOT`` in the ``wav`` column.
        model_type: Whisper size suffix, forwarded to ``prepare_pipeline``.
        model_dir: Checkpoint directory, forwarded to ``prepare_pipeline``.
        use_stock_model: Use the unmodified hub model.
        max_new_tokens, num_beams, do_sample, repetition_penalty:
            Generation options forwarded to the pipeline.

    Raises:
        KeyError: the manifest has rows but no ``wav`` column.
    """
    df = pd.read_csv(manifest_csv, keep_default_na=False)

    # Fail before the (slow) model load and before out_csv is truncated.
    if not df.empty and 'wav' not in df.columns:
        raise KeyError('wav')

    fieldnames = list(df.columns)
    if 'asr' not in fieldnames:
        # Do not emit a duplicate header when the manifest already has 'asr'.
        fieldnames.append('asr')

    asr_model = prepare_pipeline(
        model_type=model_type,
        model_dir=model_dir,
        use_stock_model=use_stock_model,
        generate_opts={'max_new_tokens': max_new_tokens,
                       'num_beams': num_beams,
                       'repetition_penalty': repetition_penalty,
                       'do_sample': do_sample,
                       },
    )

    on_gpu = asr_model.device.type == "cuda"
    message = "This may take a while on CPU. Go make a cuppa " if asr_model.device.type == "cpu" else "Running on GPU"
    print(f'Running ASR for {len(df)} files. {message} ...')

    st = time.time()
    # Explicit UTF-8 so non-ASCII transcripts do not depend on the platform
    # default encoding.
    with open(out_csv, 'w', newline='', encoding='utf-8') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames, delimiter=',')
        writer.writeheader()
        for _, row in tqdm(df.iterrows(), total=df.shape[0]):
            audiofile = row['wav'].replace('$DATAROOT', corpora_root)
            try:
                # Mixed precision only on GPU; enabled=False makes this a
                # no-op on CPU.
                with torch.no_grad(), torch.autocast(device_type="cuda", enabled=on_gpu):
                    asrtext = asr_model(audiofile)['text']
            except ValueError as e:
                print(f'{e}: {audiofile}')
                asrtext = ''
            record = row.to_dict()
            record['asr'] = asrtext
            writer.writerow(record)

    compute_time = time.time() - st
    print(f'...transcription complete in {compute_time:.1f} sec')