|
import os |
|
from peft import PeftModel, PeftConfig |
|
import torch |
|
from torch.cuda.amp import autocast |
|
from torch.utils.data import DataLoader |
|
from tqdm import tqdm |
|
import transformers |
|
from transformers import WhisperFeatureExtractor, WhisperTokenizer, WhisperProcessor, WhisperForConditionalGeneration, GenerationConfig |
|
from transformers import pipeline, AutomaticSpeechRecognitionPipeline |
|
import argparse |
|
import time |
|
from pathlib import Path |
|
import json |
|
import pandas as pd |
|
import csv |
|
|
|
def prepare_pipeline(model_path, generate_kwargs):
    """Build a Hugging Face automatic-speech-recognition pipeline for Whisper.

    Args:
        model_path (str): path to model directory / huggingface model name
        generate_kwargs (dict): generation options forwarded to the pipeline
            (e.g. max_new_tokens, num_beams, do_sample, repetition_penalty)

    Returns:
        pipeline: ASR pipeline ready for inference
    """
    # The processor bundles the tokenizer and feature extractor that the
    # pipeline needs alongside the model weights.
    whisper_processor = WhisperProcessor.from_pretrained(model_path)

    return pipeline(
        "automatic-speech-recognition",
        model=model_path,
        tokenizer=whisper_processor.tokenizer,
        feature_extractor=whisper_processor.feature_extractor,
        generate_kwargs=generate_kwargs,
        model_kwargs={"load_in_8bit": False},  # load at full precision
        device_map='auto',  # let accelerate choose CPU/GPU placement
    )
|
|
|
def ASRdirWhisat(
    audio_dir,
    out_dir = '../whisat_results/',
    model_dir=".",
    max_new_tokens=112,
    num_beams=1,
    do_sample=False,
    repetition_penalty=1,
    ):
    """Transcribe every supported media file under ``audio_dir``.

    Recursively finds audio/video files, runs Whisper ASR on each one, and
    writes a ``<name>.asr.txt`` transcript per input file, mirroring the
    input directory structure under ``out_dir``.

    Args:
        audio_dir (str): root directory searched recursively for media files
        out_dir (str): output root for transcripts (created if missing)
        model_dir (str): path to model directory / huggingface model name
        max_new_tokens (int): cap on generated tokens per chunk
        num_beams (int): beam-search width (1 = greedy decoding)
        do_sample (bool): sample instead of deterministic decoding
        repetition_penalty (float): penalty applied to repeated tokens

    Returns:
        None. Transcripts are written to disk as a side effect.
    """
    # BUG FIX: the original call passed keyword arguments (model_type,
    # use_stock_model, ...) that prepare_pipeline(model_path, generate_kwargs)
    # does not accept — and model_type / use_stock_model were never defined
    # in this scope, so the function always raised NameError.
    asr_model = prepare_pipeline(
        model_path=model_dir,
        generate_kwargs={'max_new_tokens': max_new_tokens,
                         'num_beams': num_beams,
                         'repetition_penalty': repetition_penalty,
                         'do_sample': do_sample,
                         },
        )

    # Supported media extensions; matched case-insensitively so e.g. .MP3
    # is accepted alongside the original mixed-case MOV/WAV entries.
    media_exts = {'mov', 'wav', 'mp4', 'mp3', 'm4a', 'aac', 'flac', 'alac', 'ogg'}
    audio_files = [str(f) for f in Path(audio_dir).rglob("*")
                   if f.is_file() and f.suffix.lstrip('.').lower() in media_exts]

    os.makedirs(out_dir, exist_ok=True)

    message = "This may take a while on CPU." if asr_model.device.type=="cpu" else "Running on GPU"
    print(f'Running ASR for {len(audio_files)} files. {message} ...')

    st = time.time()
    for audiofile in tqdm(audio_files):
        sessname = Path(audiofile).stem
        # Path of this file relative to audio_dir, so the output tree
        # mirrors the input tree.
        sesspath = os.path.relpath(os.path.dirname(Path(audiofile).resolve()),
                                   Path(audio_dir).resolve())
        sess_out_dir = os.path.join(out_dir, sesspath)
        os.makedirs(sess_out_dir, exist_ok=True)
        asrFullFile = os.path.join(sess_out_dir, f"{sessname}.asr.txt")

        with torch.no_grad():
            with autocast():
                try:
                    result = asr_model(audiofile)
                except ValueError as e:
                    # Unreadable / unsupported media: report and keep going.
                    print(f'{e}: {audiofile}')
                    continue

        # Explicit UTF-8 avoids platform-default encoding errors on
        # non-ASCII transcripts.
        with open(asrFullFile, 'w', encoding='utf-8') as outfile:
            outfile.write(result['text'])

    et = time.time()
    compute_time = (et-st)
    print(f'...transcription complete in {compute_time:.1f} sec')
|
|
|
|