|
import argparse
import csv
import json
import os
import time
from pathlib import Path

import pandas as pd
import torch
import transformers
from peft import PeftModel, PeftConfig
from torch.cuda.amp import autocast
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (WhisperFeatureExtractor, WhisperTokenizer, WhisperProcessor,
                          WhisperForConditionalGeneration, GenerationConfig)
from transformers import pipeline, AutomaticSpeechRecognitionPipeline
|
|
|
def prepare_pipeline(model_type='large-v2',
                     model_dir="../models/whisat-1.2/",
                     use_stock_model=False,
                     generate_opts={'max_new_tokens': 112,
                                    'num_beams': 1,
                                    'repetition_penalty': 1,
                                    'do_sample': False}
                     ):
    """Build an ASR pipeline from either a stock Whisper checkpoint or a local
    (optionally PEFT/LoRA-adapted) fine-tuned checkpoint."""

    lang = 'english'
    USE_INT8 = False

    import warnings
    warnings.filterwarnings("ignore")
    transformers.utils.logging.set_verbosity_error()

    init_from_hub_path = f"openai/whisper-{model_type}"
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print(device)
    feature_extractor = WhisperFeatureExtractor.from_pretrained(init_from_hub_path)
    tokenizer = WhisperTokenizer.from_pretrained(init_from_hub_path, language=lang, task="transcribe")
    processor = WhisperProcessor.from_pretrained(init_from_hub_path, language=lang, task="transcribe")

    if use_stock_model:
        model = WhisperForConditionalGeneration.from_pretrained(init_from_hub_path)
    else:
        checkpoint_dir = os.path.expanduser(model_dir)

        if os.path.isdir(os.path.join(checkpoint_dir, "adapter_model")):
            # adapter_model/ in the checkpoint dir means the model was tuned with PEFT:
            # load the base model first, then attach the adapter weights.
            print('...it looks like this model was tuned using PEFT, because adapter_model/ is present in ckpt dir')

            peft_config = PeftConfig.from_pretrained(os.path.join(checkpoint_dir, "adapter_model"))
            model = WhisperForConditionalGeneration.from_pretrained(peft_config.base_model_name_or_path,
                                                                    load_in_8bit=USE_INT8,
                                                                    device_map='auto',
                                                                    use_cache=False,
                                                                    )
            model = PeftModel.from_pretrained(model, os.path.join(checkpoint_dir, "adapter_model"))
        else:
            # Fully fine-tuned (non-PEFT) checkpoint: load it directly.
            model = WhisperForConditionalGeneration.from_pretrained(checkpoint_dir,
                                                                    load_in_8bit=USE_INT8,
                                                                    device_map='auto',
                                                                    use_cache=False,
                                                                    )
    model.eval()

    pipe = AutomaticSpeechRecognitionPipeline(
        model=model,
        tokenizer=tokenizer,
        feature_extractor=feature_extractor,
        chunk_length_s=30,
        device=device,
        return_timestamps=False,
        generate_kwargs=generate_opts,
    )

    return pipe
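
# Example (sketch): transcribe a single file with the pipeline returned above.
# The audio path is hypothetical; anything ffmpeg can decode should work.
#
#   pipe = prepare_pipeline(model_type='large-v2', use_stock_model=True)
#   with torch.no_grad():
#       print(pipe('/path/to/audio.wav')['text'])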
|
|
|
def load_model(model_type='large-v2',
               model_dir="../models/whisat-1.2/"):
    """Load a PEFT/LoRA-adapted Whisper checkpoint and return (model, tokenizer, processor)."""

    lang = 'english'
    USE_INT8 = False

    import warnings
    warnings.filterwarnings("ignore")
    transformers.utils.logging.set_verbosity_error()

    init_from_hub_path = f"openai/whisper-{model_type}"
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(device)
    feature_extractor = WhisperFeatureExtractor.from_pretrained(init_from_hub_path)
    tokenizer = WhisperTokenizer.from_pretrained(init_from_hub_path, language=lang, task="transcribe")
    processor = WhisperProcessor.from_pretrained(init_from_hub_path, language=lang, task="transcribe")

    checkpoint_dir = os.path.expanduser(model_dir)

    # Unlike prepare_pipeline(), this loader always expects a PEFT checkpoint
    # (an adapter_model/ directory must exist inside model_dir).
    peft_config = PeftConfig.from_pretrained(os.path.join(checkpoint_dir, "adapter_model"))
    model = WhisperForConditionalGeneration.from_pretrained(peft_config.base_model_name_or_path,
                                                            load_in_8bit=USE_INT8,
                                                            device_map='auto',
                                                            use_cache=False,
                                                            )
    model = PeftModel.from_pretrained(model, os.path.join(checkpoint_dir, "adapter_model"))
    model.eval()
    return model, tokenizer, processor
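
# Example (sketch): run the adapted model directly, without the pipeline wrapper.
# librosa and the audio path are illustrative assumptions, not part of this script.
#
#   import librosa
#   model, tokenizer, processor = load_model()
#   speech, _ = librosa.load('/path/to/audio.wav', sr=16000)          # 16 kHz mono
#   feats = processor(speech, sampling_rate=16000, return_tensors='pt').input_features
#   ids = model.generate(input_features=feats.to(model.device), max_new_tokens=112)
#   print(tokenizer.batch_decode(ids, skip_special_tokens=True)[0])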
|
|
|
def ASRdirWhisat(
        audio_dir,
        files_to_include=None,
        out_dir='../whisat_results/',
        model_type='large-v2',
        model_name='whisat-1.2',
        model_dir="../models/whisat-1.2",
        use_stock_model=False,
        max_new_tokens=112,
        num_beams=1,
        do_sample=False,
        repetition_penalty=1,
        ):
    """Transcribe every audio file under audio_dir (or only files_to_include) and write,
    per input file, a plain-text transcript (.asr.txt) and the raw pipeline output (.json),
    mirroring the input directory structure under out_dir."""

    asr_model = prepare_pipeline(
        model_type=model_type,
        model_dir=model_dir,
        use_stock_model=use_stock_model,
        generate_opts={'max_new_tokens': max_new_tokens,
                       'num_beams': num_beams,
                       'repetition_penalty': repetition_penalty,
                       'do_sample': do_sample
                       }
        )

    if use_stock_model:
        model_name = 'whisper_' + model_type + '_stock'

    if files_to_include:
        assert isinstance(files_to_include, list), 'files_to_include should be a list of paths relative to audio_dir to transcribe'
        audio_files = files_to_include
    else:
        # Recursively collect anything with a common audio/video extension.
        audio_exts = ['MOV', 'mov', 'WAV', 'wav', 'mp4', 'mp3', 'm4a', 'aac', 'flac', 'alac', 'ogg']
        audio_files = [str(f) for f in Path(audio_dir).rglob("*")
                       if (str(f).rsplit('.', maxsplit=1)[-1] in audio_exts and f.is_file())]

    asrDir = os.path.join(out_dir, f'ASR_{model_name}')
    jsonDir = os.path.join(out_dir, f'JSON_{model_name}')
    os.makedirs(asrDir, exist_ok=True)
    os.makedirs(jsonDir, exist_ok=True)

    message = "This may take a while on CPU. Go make a cuppa" if asr_model.device.type == "cpu" else "Running on GPU"
    print(f'Running ASR for {len(audio_files)} files. {message} ...')
    compute_time = 0
    total_audio_dur = 0

    st = time.time()

    for audiofile in tqdm(audio_files):
        # Mirror each file's position relative to audio_dir inside the output directories.
        sessname = Path(audiofile).stem
        sesspath = os.path.relpath(os.path.dirname(Path(audiofile).resolve()), Path(audio_dir).resolve())
        asrFullFile = os.path.join(asrDir, sesspath, f"{sessname}.asr.txt")
        jsonFile = os.path.join(jsonDir, sesspath, f"{sessname}.json")
        os.makedirs(os.path.join(asrDir, sesspath), exist_ok=True)
        os.makedirs(os.path.join(jsonDir, sesspath), exist_ok=True)

        with torch.no_grad():
            with autocast():
                try:
                    result = asr_model(audiofile)
                except ValueError as e:
                    # Skip files the pipeline cannot decode and keep going.
                    print(f'{e}: {audiofile}')
                    continue

        with open(jsonFile, "w") as jf:
            json.dump(result, jf, indent=4)

        asrtext = result['text']
        with open(asrFullFile, 'w') as outfile:
            outfile.write(asrtext)

    et = time.time()
    compute_time = (et - st)
    print(f'...transcription complete in {compute_time:.1f} sec')
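
# Example (sketch): batch-transcribe a directory. Paths are hypothetical.
#
#   ASRdirWhisat('/data/my_recordings',
#                out_dir='../whisat_results/',
#                use_stock_model=True)   # or point model_dir at a fine-tuned checkpoint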
|
|
|
|
|
def ASRmanifestWhisat(
        manifest_csv,
        out_csv,
        corpora_root,
        model_type='large-v2',
        model_dir="../models/whisat-1.2",
        use_stock_model=False,
        max_new_tokens=112,
        num_beams=1,
        do_sample=False,
        repetition_penalty=1,
        ):
    """Transcribe every file listed in a manifest CSV (audio path in the 'wav' column,
    with '$DATAROOT' standing in for corpora_root) and write a copy of the manifest
    to out_csv with an added 'asr' column."""

    df = pd.read_csv(manifest_csv, keep_default_na=False)
    fieldnames = list(df.columns) + ['asr']

    asr_model = prepare_pipeline(
        model_type=model_type,
        model_dir=model_dir,
        use_stock_model=use_stock_model,
        generate_opts={'max_new_tokens': max_new_tokens,
                       'num_beams': num_beams,
                       'repetition_penalty': repetition_penalty,
                       'do_sample': do_sample
                       }
        )

    message = "This may take a while on CPU. Go make a cuppa" if asr_model.device.type == "cpu" else "Running on GPU"
    print(f'Running ASR for {len(df)} files. {message} ...')
    compute_time = 0
    total_audio_dur = 0

    st = time.time()

    with open(out_csv, 'w', newline='') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames, delimiter=',')
        writer.writeheader()

        for i, row in tqdm(df.iterrows(), total=df.shape[0]):
            audiofile = row['wav'].replace('$DATAROOT', corpora_root)
            with torch.no_grad():
                with autocast():
                    try:
                        result = asr_model(audiofile)
                        asrtext = result['text']
                    except ValueError as e:
                        # Record an empty transcript for unreadable/invalid audio and keep going.
                        print(f'{e}: {audiofile}')
                        asrtext = ''

            row['asr'] = asrtext
            writer.writerow(row.to_dict())

    et = time.time()
    compute_time = (et - st)
    print(f'...transcription complete in {compute_time:.1f} sec')
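
# Example (sketch): transcribe from a manifest. The manifest needs a 'wav' column;
# any '$DATAROOT' prefix in it is replaced with corpora_root. Paths are hypothetical.
#
#   ASRmanifestWhisat(manifest_csv='manifests/eval.csv',
#                     out_csv='results/eval_asr.csv',
#                     corpora_root='/data/corpora')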
|
|
|
|
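if __name__ == "__main__":
    # Minimal CLI sketch: flag names and defaults here are illustrative assumptions.
    parser = argparse.ArgumentParser(description='Transcribe a directory of audio with whisat/Whisper.')
    parser.add_argument('audio_dir', help='directory containing audio files to transcribe')
    parser.add_argument('--out_dir', default='../whisat_results/')
    parser.add_argument('--model_type', default='large-v2')
    parser.add_argument('--model_dir', default='../models/whisat-1.2')
    parser.add_argument('--use_stock_model', action='store_true',
                        help='use the stock openai/whisper checkpoint instead of a fine-tuned one')
    args = parser.parse_args()

    ASRdirWhisat(args.audio_dir,
                 out_dir=args.out_dir,
                 model_type=args.model_type,
                 model_dir=args.model_dir,
                 use_stock_model=args.use_stock_model)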