import os
import argparse
import csv
import json
import time
from pathlib import Path

import pandas as pd
import torch
from torch.cuda.amp import autocast
from torch.utils.data import DataLoader
from tqdm import tqdm
import transformers
from transformers import (
    WhisperFeatureExtractor,
    WhisperTokenizer,
    WhisperProcessor,
    WhisperForConditionalGeneration,
    GenerationConfig,
    pipeline,
    AutomaticSpeechRecognitionPipeline,
)
from peft import PeftModel, PeftConfig


def prepare_pipeline(model_path, generate_kwargs):
    """Prepare a pipeline for ASR inference.

    Args:
        model_path (str): path to model directory / Hugging Face model name
        generate_kwargs (dict): generation options passed through to the pipeline

    Returns:
        pipeline: ASR pipeline
    """
    processor = WhisperProcessor.from_pretrained(model_path)
    asr_pipeline = pipeline(
        "automatic-speech-recognition",
        model=model_path,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        generate_kwargs=generate_kwargs,
        model_kwargs={"load_in_8bit": False},
        device_map="auto",
    )
    return asr_pipeline


def ASRdirWhisat(
    audio_dir,
    out_dir="../whisat_results/",
    model_dir=".",
    max_new_tokens=112,
    num_beams=1,
    do_sample=False,
    repetition_penalty=1,
):
    ## ASR using fine-tuned Transformers Whisper
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Transcribe each file in the specified folder separately.
    # Whisper takes 30-second input: anything shorter is zero-padded; longer
    # audio is transcribed in successive 30-second windows and the text
    # concatenated. Output is saved in the same directory structure as the
    # input, under the specified top-level folder.
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    asr_model = prepare_pipeline(
        model_path=model_dir,
        generate_kwargs={
            "max_new_tokens": max_new_tokens,
            "num_beams": num_beams,
            "repetition_penalty": repetition_penalty,
            "do_sample": do_sample,
        },
    )

    audio_files = [
        str(f)
        for f in Path(audio_dir).rglob("*")
        if (
            str(f).rsplit(".", maxsplit=1)[-1]
            in ["MOV", "mov", "WAV", "wav", "mp4", "mp3", "m4a", "aac", "flac", "alac", "ogg"]
            and f.is_file()
        )
    ]

    os.makedirs(out_dir, exist_ok=True)

    message = "This may take a while on CPU." if asr_model.device.type == "cpu" else "Running on GPU"
    print(f"Running ASR for {len(audio_files)} files. {message} ...")

    # get the start time
    st = time.time()
    asrDir = out_dir
    for audiofile in tqdm(audio_files):
        sessname = Path(audiofile).stem
        sesspath = os.path.relpath(
            os.path.dirname(Path(audiofile).resolve()), Path(audio_dir).resolve()
        )
        # full-session ASR results file, mirroring the input directory layout
        asrFullFile = os.path.join(asrDir, sesspath, f"{sessname}.asr.txt")
        os.makedirs(os.path.join(asrDir, sesspath), exist_ok=True)

        with torch.no_grad():
            with autocast():
                try:
                    result = asr_model(audiofile)
                except ValueError as e:
                    print(f"{e}: {audiofile}")
                    continue
        asrtext = result["text"]

        with open(asrFullFile, "w") as outfile:
            outfile.write(asrtext)

    et = time.time()
    compute_time = et - st
    print(f"...transcription complete in {compute_time:.1f} sec")
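

# A minimal sketch of an alternative pipeline setup using the Transformers
# ASR pipeline's chunked long-form decoding, for cases where sequential
# 30-second windows are too slow on long recordings. chunk_length_s and
# batch_size are standard ASR-pipeline arguments; the helper name and the
# default values here are illustrative assumptions, not tuned settings from
# this project.
def prepare_chunked_pipeline(model_path, generate_kwargs, chunk_length_s=30, batch_size=8):
    """Like prepare_pipeline, but decodes long audio in batched chunks."""
    processor = WhisperProcessor.from_pretrained(model_path)
    return pipeline(
        "automatic-speech-recognition",
        model=model_path,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        generate_kwargs=generate_kwargs,
        chunk_length_s=chunk_length_s,  # split audio into overlapping 30 s chunks
        batch_size=batch_size,          # decode several chunks per forward pass
        device_map="auto",
    )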
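

# Hedged usage sketch: a minimal command-line entry point for ASRdirWhisat.
# The flag names below are illustrative assumptions, not an established CLI
# for this script; generation options not exposed here fall back to the
# function's defaults.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Transcribe a directory of audio files with a Whisper model"
    )
    parser.add_argument("audio_dir", help="top-level directory to search for audio files")
    parser.add_argument("--out_dir", default="../whisat_results/", help="directory for .asr.txt output")
    parser.add_argument("--model_dir", default=".", help="model directory or Hugging Face model name")
    parser.add_argument("--max_new_tokens", type=int, default=112)
    parser.add_argument("--num_beams", type=int, default=1)
    args = parser.parse_args()

    ASRdirWhisat(
        args.audio_dir,
        out_dir=args.out_dir,
        model_dir=args.model_dir,
        max_new_tokens=args.max_new_tokens,
        num_beams=args.num_beams,
    )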