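"""End-of-speech detection on short audio segments.

A Wav2Vec2 processor featurizes the audio and an ONNX Runtime session runs
the exported classifier; predict() returns one probability per label.
"""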
from transformers import Wav2Vec2Processor, AutoConfig
import onnxruntime as rt
import torch
import torch.nn.functional as F
import numpy as np
import os
import torchaudio
import soundfile as sf


class EndOfSpeechDetection:
    processor: Wav2Vec2Processor
    config: AutoConfig
    session: rt.InferenceSession

    def load_model(self, path, use_gpu=False):
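        """Load the Wav2Vec2 processor, model config, and ONNX session from `path`."""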
        processor = Wav2Vec2Processor.from_pretrained(path)
        config = AutoConfig.from_pretrained(path)

        sess_options = rt.SessionOptions()
        sess_options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_ALL

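        # use_gpu selects the ROCm execution provider; otherwise inference
        # stays on the CPU provider.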
        providers = ["ROCMExecutionProvider"] if use_gpu else ["CPUExecutionProvider"]
        session = rt.InferenceSession(
            os.path.join(path, "model.onnx"), sess_options, providers=providers
        )
        return processor, config, session

    def predict(self, segment, file_type="pcm"):
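        """Classify one segment file and return {label: probability}.

        `segment` is a path to a headerless float32 PCM dump
        (file_type="pcm") or a wav file; 16 kHz mono audio is expected.
        """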
        if file_type == "pcm":
            # Headerless float32 PCM dump: memory-map the file, then copy the
            # samples out of the memmap into a regular in-memory array.
            speech_array = np.memmap(segment, dtype="float32", mode="r").astype(
                np.float32
            )
        else:
            # Wav (anything torchaudio can decode): keep channel 0 only.
            speech_array, _ = torchaudio.load(segment)
            speech_array = speech_array[0].numpy()

        features = self.processor(
            speech_array, sampling_rate=16000, return_tensors="pt", padding=True
        )
        input_values = features.input_values
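        # Run the graph, feeding the processed audio to the model's last input
        # and reading back its last output.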
        outputs = self.session.run(
            [self.session.get_outputs()[-1].name],
            {self.session.get_inputs()[-1].name: input_values.detach().cpu().numpy()},
        )[0]
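        # The session returns raw class scores; softmax maps them to probabilities.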
        softmax_output = F.softmax(torch.tensor(outputs), dim=1)

        both_classes_with_prob = {
            self.config.id2label[i]: softmax_output[0][i].item()
            for i in range(len(softmax_output[0]))
        }

        return both_classes_with_prob


if __name__ == "__main__":
    eos = EndOfSpeechDetection()
    eos.processor, eos.config, eos.session = eos.load_model("eos-model-onnx")

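    # predict() always passes sampling_rate=16000 to the processor, so the
    # demo assumes a 16 kHz recording; resample first if the rate differs.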
    audio_file = "5sec_audio.wav"
    audio, sr = torchaudio.load(audio_file)
    audio = audio[0].numpy()
    audio_len = len(audio)
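    # Cut the audio into 700 ms chunks for per-chunk classification.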
    segment_len = 700 * sr // 1000
    segments = []
    for i in range(0, audio_len, segment_len):
        # Slicing clamps at the end of the array, so the final (possibly
        # shorter) chunk needs no special case.
        segments.append(audio[i : i + segment_len])

    os.makedirs("segments", exist_ok=True)
    for i, segment in enumerate(segments):
        sf.write(f"segments/segment_{i}.wav", segment, sr)
        print(eos.predict(f"segments/segment_{i}.wav", file_type="wav"))
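    # The raw-PCM path works the same way on a headerless float32 dump,
    # e.g. (hypothetical file): eos.predict("segment_0.pcm", file_type="pcm")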