|
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC |
|
from optimum.bettertransformer import BetterTransformer |
|
import torch |
|
import librosa |
|
|
|
# Prefer the GPU when one is visible; otherwise fall back to the CPU.
# NOTE(review): `device` is not referenced anywhere else in this part of the
# file; the model below is never moved to it. Confirm whether that is intended.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed torch's global RNG so repeated runs are reproducible.
torch.manual_seed(0)

# Processor and CTC model must be loaded from the same pretrained checkpoint.
_CHECKPOINT = "facebook/wav2vec2-base-960h"
processor = Wav2Vec2Processor.from_pretrained(_CHECKPOINT)
# Convert the freshly loaded model with optimum's BetterTransformer.
model = BetterTransformer.transform(Wav2Vec2ForCTC.from_pretrained(_CHECKPOINT))
|
|
|
|
|
def load_audio(audio_path, processor, sampling_rate=16000):
    """Load an audio file and convert it into model-ready input values.

    The waveform is resampled to *sampling_rate* while loading, so the rate
    reported to the processor always matches the actual samples.

    Args:
        audio_path: Path to an audio file that ``librosa.load`` can read.
        processor: A ``Wav2Vec2Processor`` (or compatible callable) whose
            output exposes an ``input_values`` field.
        sampling_rate: Target sampling rate in Hz. Defaults to 16000, the rate
            this function has always reported to the processor.

    Returns:
        The processor's ``input_values`` tensor (built with
        ``return_tensors="pt"``).
    """
    # Resample on load. With sr=None the file's native rate would be kept while
    # the processor is still told `sampling_rate`, silently feeding audio at
    # the wrong rate.
    audio, _ = librosa.load(audio_path, sr=sampling_rate)
    features = processor(audio, sampling_rate=sampling_rate, return_tensors="pt")
    # Wav2Vec2 processors return the waveform under `input_values`;
    # `input_features` is not a key they produce.
    return features.input_values
|
|
|
@torch.inference_mode()
def get_emissions(input_values, model):
    """Run *model* on *input_values* and return its ``logits``.

    Executes under ``torch.inference_mode`` so no autograd graph is recorded.

    Args:
        input_values: Input tensor (e.g. the output of ``load_audio``).
        model: Callable returning an object with a ``logits`` attribute
            (e.g. ``Wav2Vec2ForCTC``). If it exposes ``parameters()``, the
            inputs are moved to the device of its first parameter.

    Returns:
        The ``logits`` tensor, on whatever device the model produced it.
    """
    # Inputs must live on the same device as the weights; otherwise a model
    # placed on the GPU would raise a device-mismatch error in its first layer.
    parameters = getattr(model, "parameters", None)
    first_param = next(parameters(), None) if callable(parameters) else None
    if first_param is not None:
        input_values = input_values.to(first_param.device)
    return model(input_values).logits
|
|
|
def score_audio(audio_path, true_result):
    """Transcribe *audio_path* and check it against the expected answer(s).

    Uses the module-level ``processor`` and ``model``.

    Args:
        audio_path: Path to an audio file readable by ``load_audio``.
        true_result: One or more acceptable answers separated by ``'/'``
            (e.g. ``"yes/yeah"``). Matching is case-insensitive, and an answer
            counts as found when it occurs as a substring of the transcription.
            Surrounding whitespace inside an answer is kept, so padding such
            as ``" no "`` can be used to require word boundaries.

    Returns:
        dict with ``'transcription'`` (the lower-cased decoded text) and
        ``'score'`` (``1`` if any non-empty answer was found, else ``0``).
    """
    # The transcription is lower-cased below, so the references must be too,
    # or any upper-case reference could never match. Empty alternatives (from
    # "", "a//b" or a trailing '/') are dropped: '' is a substring of every
    # string and would always score 1.
    answers = [answer for answer in true_result.lower().split('/') if answer]

    input_values = load_audio(audio_path, processor)
    logits = get_emissions(input_values, model).cpu()
    # Greedy decoding: highest-scoring id along the last axis (presumably the
    # vocabulary - confirm against the model's output layout).
    predicted_ids = torch.argmax(logits, dim=-1)
    # Only one file is transcribed, so take the first (only) decoded string.
    transcription = processor.batch_decode(predicted_ids)[0].lower()

    return {
        'transcription': transcription,
        'score': int(any(answer in transcription for answer in answers)),
    }
|
|
|
|
|
|