Audio Classification
ONNX
stefanpp committed on
Commit d5d7d09
1 Parent(s): 2085273

added test inference script

Files changed (1)
  1. inference.py +80 -0
inference.py ADDED
@@ -0,0 +1,80 @@
from transformers import Wav2Vec2Processor, AutoConfig
import onnxruntime as rt
import torch
import torch.nn.functional as F
import numpy as np
import os
import torchaudio
import soundfile as sf


class EndOfSpeechDetection:
    processor: Wav2Vec2Processor
    config: AutoConfig
    session: rt.InferenceSession

    def load_model(self, path, use_gpu=False):
        # Load the feature extractor and the model config from the checkpoint directory.
        processor = Wav2Vec2Processor.from_pretrained(path)
        config = AutoConfig.from_pretrained(path)

        sess_options = rt.SessionOptions()
        sess_options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_ALL

        # Run on ROCm when a GPU is requested, otherwise fall back to CPU.
        providers = ["ROCMExecutionProvider"] if use_gpu else ["CPUExecutionProvider"]
        session = rt.InferenceSession(
            os.path.join(path, "model.onnx"), sess_options, providers=providers
        )
        return processor, config, session

    def predict(self, segment, file_type="pcm"):
        if file_type == "pcm":
            # Raw PCM file: memory-map the float32 samples and copy them into memory.
            speech_array = np.memmap(segment, dtype="float32", mode="r").astype(
                np.float32
            )
        else:
            # WAV file (or any format torchaudio can decode): keep only the first channel.
            speech_array, _ = torchaudio.load(segment)
            speech_array = speech_array[0].numpy()

        features = self.processor(
            speech_array, sampling_rate=16000, return_tensors="pt", padding=True
        )
        input_values = features.input_values
        # Run the ONNX model on the padded waveform and take its last declared output.
        outputs = self.session.run(
            [self.session.get_outputs()[-1].name],
            {self.session.get_inputs()[-1].name: input_values.detach().cpu().numpy()},
        )[0]
        softmax_output = F.softmax(torch.tensor(outputs), dim=1)

        # Map every class id to its label and softmax probability.
        both_classes_with_prob = {
            self.config.id2label[i]: softmax_output[0][i].item()
            for i in range(len(softmax_output[0]))
        }

        return both_classes_with_prob


if __name__ == "__main__":
    eos = EndOfSpeechDetection()
    eos.processor, eos.config, eos.session = eos.load_model("eos-model-onnx")

    # Split the test file into 700 ms segments.
    audio_file = "5sec_audio.wav"
    audio, sr = torchaudio.load(audio_file)
    audio = audio[0].numpy()
    audio_len = len(audio)
    segment_len = 700 * sr // 1000
    segments = []
    for i in range(0, audio_len, segment_len):
        if i + segment_len < audio_len:
            segment = audio[i : i + segment_len]
        else:
            segment = audio[i:]

        segments.append(segment)

    # Write each segment to disk and run end-of-speech classification on it.
    if not os.path.exists("segments"):
        os.makedirs("segments")
    for i, segment in enumerate(segments):
        sf.write(f"segments/segment_{i}.wav", segment, sr)
        print(eos.predict(f"segments/segment_{i}.wav", file_type="wav"))
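
The __main__ block above only exercises the WAV path. For reference, predict also accepts raw PCM input through its default file_type="pcm" branch; below is a minimal sketch of that path, assuming the class above is in scope. The file name raw_segment.pcm is hypothetical, and the sketch assumes the file contains 16 kHz mono float32 samples, matching what the processor expects.

    # Minimal sketch of the PCM branch (hypothetical file name; assumes 16 kHz mono float32 samples).
    eos = EndOfSpeechDetection()
    eos.processor, eos.config, eos.session = eos.load_model("eos-model-onnx")
    print(eos.predict("raw_segment.pcm", file_type="pcm"))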