"""End-of-speech detection with a Wav2Vec2 processor and an ONNX Runtime model.

Run as a script, this module splits ``5sec_audio.wav`` into fixed-length
segments, writes each segment into ``segments/`` and prints the per-class
probabilities predicted for it.
"""

from transformers import Wav2Vec2Processor, AutoConfig
import onnxruntime as rt
import torch
import torch.nn.functional as F
import numpy as np
import os
import torchaudio
import soundfile as sf

# Sampling rate (Hz) that ``predict`` declares to the processor.
SAMPLE_RATE = 16000
# Length of each segment produced by the demo driver, in milliseconds.
SEGMENT_MS = 700
# Directory where the demo driver writes the segment files.
SEGMENTS_DIR = "segments"


class EndOfSpeechDetection:
    """Classify audio segments with an ONNX model exported next to its processor."""

    # Populated by ``load_model``; ``predict`` reads all three.
    processor: Wav2Vec2Processor
    config: AutoConfig
    session: rt.InferenceSession

    def load_model(self, path, use_gpu=False):
        """Load the processor, config and ONNX session found under ``path``.

        The loaded objects are stored on the instance so ``predict`` can be
        called right away, and are also returned for callers that assign them
        explicitly.

        Args:
            path: Directory holding the pretrained processor/config files and
                a ``model.onnx`` file.
            use_gpu: When true, request ``ROCMExecutionProvider``; otherwise
                ``CPUExecutionProvider``.

        Returns:
            A ``(processor, config, session)`` tuple.
        """
        processor = Wav2Vec2Processor.from_pretrained(path)
        config = AutoConfig.from_pretrained(path)

        sess_options = rt.SessionOptions()
        sess_options.graph_optimization_level = (
            rt.GraphOptimizationLevel.ORT_ENABLE_ALL
        )
        providers = ["ROCMExecutionProvider"] if use_gpu else ["CPUExecutionProvider"]
        session = rt.InferenceSession(
            os.path.join(path, "model.onnx"), sess_options, providers=providers
        )

        # Keep the instance usable without requiring the caller to unpack the
        # returned tuple back onto it.
        self.processor = processor
        self.config = config
        self.session = session
        return processor, config, session

    def predict(self, segment, file_type="pcm"):
        """Return the probability of each class label for one audio segment.

        Args:
            segment: Path of the audio file to classify.
            file_type: ``"pcm"`` reads the file as raw float32 samples (no
                header, so the data is assumed to already be at
                ``SAMPLE_RATE``); any other value loads the file with
                torchaudio and uses its first channel.

        Returns:
            A dict mapping each ``config.id2label`` label to its softmax
            probability for the first (only) item in the batch.
        """
        if file_type == "pcm":
            # Raw samples: copy out of the memory map into a regular array.
            speech_array = np.memmap(segment, dtype="float32", mode="r").astype(
                np.float32
            )
        else:
            waveform, sample_rate = torchaudio.load(segment)
            # The processor is told the audio is SAMPLE_RATE below, so bring
            # files recorded at another rate to it instead of mislabelling them.
            if sample_rate != SAMPLE_RATE:
                waveform = torchaudio.functional.resample(
                    waveform, sample_rate, SAMPLE_RATE
                )
            speech_array = waveform[0].numpy()

        features = self.processor(
            speech_array,
            sampling_rate=SAMPLE_RATE,
            return_tensors="pt",
            padding=True,
        )
        input_values = features.input_values

        # Feed the last declared input and fetch the last declared output.
        # NOTE(review): presumably that output holds the classification
        # logits with shape (batch, num_labels) — confirm against the export.
        output_name = self.session.get_outputs()[-1].name
        input_name = self.session.get_inputs()[-1].name
        outputs = self.session.run(
            [output_name],
            {input_name: input_values.detach().cpu().numpy()},
        )[0]

        probabilities = F.softmax(torch.tensor(outputs), dim=1)[0]
        return {
            self.config.id2label[index]: probability.item()
            for index, probability in enumerate(probabilities)
        }


def main():
    """Split the demo recording into segments and print a prediction for each."""
    eos = EndOfSpeechDetection()
    eos.processor, eos.config, eos.session = eos.load_model("eos-model-onnx")

    audio_file = "5sec_audio.wav"
    audio, sr = torchaudio.load(audio_file)
    audio = audio[0].numpy()

    segment_len = SEGMENT_MS * sr // 1000
    # Slicing clamps at the end of the array, so the final (possibly shorter)
    # segment needs no special case.
    segments = [
        audio[start : start + segment_len]
        for start in range(0, len(audio), segment_len)
    ]

    os.makedirs(SEGMENTS_DIR, exist_ok=True)
    for i, segment in enumerate(segments):
        sf.write(f"segments/segment_{i}.wav", segment, sr)
        print(eos.predict(f"segments/segment_{i}.wav", file_type="wav"))


if __name__ == "__main__":
    main()