added usage part to README.md
README.md
The model is trained on 700 and 704ms (11x64ms) inputs of raw audio. The sample rate is 16kHz.
The model classifies each audio input into two classes: eos (id: 0) and not_eos (id: 1).
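At 16kHz, the 704ms (11x64ms) input window corresponds to 16000 * 0.704 = 11264 samples. A minimal sketch of slicing such a window from a raw float32 buffer (the buffer here is an illustrative stand-in for your audio stream, not part of this repo):

```python
import numpy as np

SAMPLE_RATE = 16000                               # Hz, matches the processor call in Usage
WINDOW_MS = 704                                   # 11 x 64 ms model input
WINDOW_SAMPLES = SAMPLE_RATE * WINDOW_MS // 1000  # 11264 samples

# Illustrative buffer standing in for a live audio stream.
audio_buffer = np.zeros(SAMPLE_RATE * 5, dtype=np.float32)  # e.g. 5 s of audio
window = audio_buffer[-WINDOW_SAMPLES:]                     # most recent 704 ms
assert window.shape == (WINDOW_SAMPLES,)
```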
# Usage
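The example below loads the ONNX-exported model with onnxruntime and returns per-class probabilities for a single PCM or wav segment.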
```python
from transformers import Wav2Vec2Processor, AutoConfig
import onnxruntime as rt
import torch
import torch.nn.functional as F
import numpy as np
import os
import torchaudio


class EndOfSpeechDetection:
    processor: Wav2Vec2Processor
    config: AutoConfig
    session: rt.InferenceSession

    def load_model(self, path, use_gpu=False):
        """Load the processor, config, and ONNX session from a local model directory."""
        processor = Wav2Vec2Processor.from_pretrained(path)
        config = AutoConfig.from_pretrained(path)

        sess_options = rt.SessionOptions()
        sess_options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_ALL

        providers = ["ROCMExecutionProvider"] if use_gpu else ["CPUExecutionProvider"]
        session = rt.InferenceSession(
            os.path.join(path, "model.onnx"), sess_options, providers=providers
        )
        return processor, config, session

    def predict(self, segment, file_type="pcm"):
        """Return {label: probability} for a single audio segment."""
        if file_type == "pcm":
            # headerless float32 PCM files
            speech_array = np.memmap(segment, dtype="float32", mode="r").astype(
                np.float32
            )
        else:
            # wav files: take the first channel
            speech_array, _ = torchaudio.load(segment)
            speech_array = speech_array[0].numpy()

        features = self.processor(
            speech_array, sampling_rate=16000, return_tensors="pt", padding=True
        )
        input_values = features.input_values
        outputs = self.session.run(
            [self.session.get_outputs()[-1].name],
            {self.session.get_inputs()[-1].name: input_values.detach().cpu().numpy()},
        )[0]
        softmax_output = F.softmax(torch.tensor(outputs), dim=1)

        both_classes_with_prob = {
            self.config.id2label[i]: softmax_output[0][i].item()
            for i in range(len(softmax_output[0]))
        }

        return both_classes_with_prob


if __name__ == "__main__":
    eos = EndOfSpeechDetection()
    eos.processor, eos.config, eos.session = eos.load_model("eos-model-onnx")
    # "some.wav" is a wav file, so pass file_type="wav" (the default "pcm"
    # would memmap the wav header as float32 samples).
    print(eos.predict("some.wav", file_type="wav"))
```
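`load_model` expects the directory to contain the usual processor/config files plus a `model.onnx` (see the `os.path.join(path, "model.onnx")` call above). If you start from a PyTorch checkpoint, one way to produce the ONNX file is `torch.onnx.export`; the sketch below assumes the checkpoint is a `Wav2Vec2ForSequenceClassification` model, which is an assumption rather than something this repo states:

```python
import torch
from transformers import Wav2Vec2ForSequenceClassification

# Assumption: the fine-tuned checkpoint is a sequence-classification head on Wav2Vec2.
model = Wav2Vec2ForSequenceClassification.from_pretrained("eos-model-onnx")
model.eval()

# 704 ms of 16 kHz audio -> 11264 samples; batch of 1.
dummy_input = torch.randn(1, 11264)

torch.onnx.export(
    model,
    dummy_input,
    "eos-model-onnx/model.onnx",
    input_names=["input_values"],
    output_names=["logits"],
    dynamic_axes={"input_values": {0: "batch", 1: "samples"}},
    opset_version=14,
)
```

Note also that the `pcm` branch of `predict` reads headerless float32 samples; 16-bit PCM would need converting to float32 first.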
# Latency (& Memory) Optimization
- Knowledge Distillation
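Distillation here means training a smaller student on this model's soft outputs. The repo doesn't ship distillation code, so the following is only a generic sketch of the standard soft-target distillation loss (the temperature `T` and weight `alpha` are illustrative assumptions):

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, T=2.0, alpha=0.5):
    """Soft-target distillation: KL(student || teacher) at temperature T,
    blended with the usual cross-entropy on the hard labels."""
    soft_targets = F.softmax(teacher_logits / T, dim=-1)
    log_student = F.log_softmax(student_logits / T, dim=-1)
    # T^2 rescales gradients so the soft term stays comparable across temperatures.
    kd = F.kl_div(log_student, soft_targets, reduction="batchmean") * (T * T)
    ce = F.cross_entropy(student_logits, labels)
    return alpha * kd + (1 - alpha) * ce
```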