import os

import numpy as np
import onnxruntime as rt
import soundfile as sf
import torch
import torch.nn.functional as F
import torchaudio
from transformers import Wav2Vec2Processor, AutoConfig


class EndOfSpeechDetection:
    """End-of-speech detector: a Wav2Vec2 feature front end feeding an
    ONNX-exported classifier."""

    processor: Wav2Vec2Processor
    config: AutoConfig
    session: rt.InferenceSession

    def load_model(self, path, use_gpu=False):
        processor = Wav2Vec2Processor.from_pretrained(path)
        config = AutoConfig.from_pretrained(path)

        sess_options = rt.SessionOptions()
        sess_options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_ALL

        # Run on an AMD GPU via ROCm when requested; otherwise stay on CPU.
        providers = ["ROCMExecutionProvider"] if use_gpu else ["CPUExecutionProvider"]
        session = rt.InferenceSession(
            os.path.join(path, "model.onnx"), sess_options, providers=providers
        )
        return processor, config, session
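
    # Note (assumption, not in the original): on NVIDIA hardware the analogous
    # provider would be "CUDAExecutionProvider", available in onnxruntime-gpu.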

    def predict(self, segment, file_type="pcm"):
        if file_type == "pcm":
            # Raw PCM: memory-map the file as float32 samples, then copy it
            # out of the memmap into a regular in-memory array.
            speech_array = np.memmap(segment, dtype="float32", mode="r").astype(
                np.float32
            )
        else:
            # Container formats (e.g. WAV): decode and keep the first channel.
            speech_array, _ = torchaudio.load(segment)
            speech_array = speech_array[0].numpy()

        # The processor is called with sampling_rate=16000, so the input audio
        # is assumed to already be 16 kHz mono.
        features = self.processor(
            speech_array, sampling_rate=16000, return_tensors="pt", padding=True
        )
        input_values = features.input_values
        outputs = self.session.run(
            [self.session.get_outputs()[-1].name],
            {self.session.get_inputs()[-1].name: input_values.detach().cpu().numpy()},
        )[0]
        # Turn the raw logits into class probabilities.
        softmax_output = F.softmax(torch.tensor(outputs), dim=1)

        both_classes_with_prob = {
            self.config.id2label[i]: softmax_output[0][i].item()
            for i in range(len(softmax_output[0]))
        }
        return both_classes_with_prob
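
    # Sketch (not in the original): an in-memory variant of predict() that skips
    # the round trip through a file on disk. The name predict_array is
    # hypothetical; it reuses the same processor, session, and config as above.
    def predict_array(self, speech_array):
        features = self.processor(
            speech_array, sampling_rate=16000, return_tensors="pt", padding=True
        )
        input_values = features.input_values.detach().cpu().numpy()
        outputs = self.session.run(
            [self.session.get_outputs()[-1].name],
            {self.session.get_inputs()[-1].name: input_values},
        )[0]
        probs = F.softmax(torch.tensor(outputs), dim=1)[0]
        return {self.config.id2label[i]: probs[i].item() for i in range(len(probs))}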


if __name__ == "__main__":
    eos = EndOfSpeechDetection()
    eos.processor, eos.config, eos.session = eos.load_model("eos-model-onnx")

    # Split the input file into 700 ms windows; predict() assumes 16 kHz audio,
    # so the input file is expected to be 16 kHz as well.
    audio_file = "5sec_audio.wav"
    audio, sr = torchaudio.load(audio_file)
    audio = audio[0].numpy()
    audio_len = len(audio)
    segment_len = 700 * sr // 1000
    segments = []
    for i in range(0, audio_len, segment_len):
        if i + segment_len < audio_len:
            segment = audio[i : i + segment_len]
        else:
            # The final window keeps whatever samples remain.
            segment = audio[i:]
        segments.append(segment)

    os.makedirs("segments", exist_ok=True)
    for i, segment in enumerate(segments):
        sf.write(f"segments/segment_{i}.wav", segment, sr)
        print(eos.predict(f"segments/segment_{i}.wav", file_type="wav"))
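
    # Sketch (assumption, not in the original): exercising the "pcm" branch of
    # predict(). np.memmap reads raw float32 samples, which tofile() produces.
    segments[0].astype(np.float32).tofile("segments/segment_0.pcm")
    print(eos.predict("segments/segment_0.pcm", file_type="pcm"))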