Audio Classification
ONNX
stefanpp committed on
Commit
dc2d03d
1 Parent(s): d5d7d09

added usage part to README.md

Browse files
Files changed (1) hide show
  1. README.md +65 -0
README.md CHANGED
@@ -37,6 +37,71 @@ The model is trained at 700 and 704ms (11x64ms) inputs of raw audio. The sample
37
 
38
  The model classifies each audio input into 2 classes - eos (id: 0) and not_eos (id: 1).
39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
  # Latency (& Memory) Optimization
42
  - Knowledge Distillation
 
37
 
38
  The model classifies each audio input into 2 classes - eos (id: 0) and not_eos (id: 1).
39
 
40
+ # Usage
41
+
42
+ ```python
43
+ from transformers import Wav2Vec2Processor, AutoConfig
44
+ import onnxruntime as rt
45
+ import torch
46
+ import torch.nn.functional as F
47
+ import numpy as np
48
+ import os
49
+ import torchaudio
50
+
51
+
52
+ class EndOfSpeechDetection:
53
+ processor: Wav2Vec2Processor
54
+ config: AutoConfig
55
+ session: rt.InferenceSession
56
+
57
+ def load_model(self, path, use_gpu=False):
58
+ processor = Wav2Vec2Processor.from_pretrained(path)
59
+ config = AutoConfig.from_pretrained(path)
60
+
61
+ sess_options = rt.SessionOptions()
62
+ sess_options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_ALL
63
+
64
+ providers = ["ROCMExecutionProvider"] if use_gpu else ["CPUExecutionProvider"]
65
+ session = rt.InferenceSession(
66
+ os.path.join(path, "model.onnx"), sess_options, providers=providers
67
+ )
68
+ return processor, config, session
69
+
70
+ def predict(self, segment, file_type="pcm"):
71
+ if file_type == "pcm":
72
+ # pcm files
73
+ speech_array = np.memmap(segment, dtype="float32", mode="r").astype(
74
+ np.float32
75
+ )
76
+ else:
77
+ # wave files
78
+ speech_array, _ = torchaudio.load(segment)
79
+ speech_array = speech_array[0].numpy()
80
+
81
+ features = self.processor(
82
+ speech_array, sampling_rate=16000, return_tensors="pt", padding=True
83
+ )
84
+ input_values = features.input_values
85
+ outputs = self.session.run(
86
+ [self.session.get_outputs()[-1].name],
87
+ {self.session.get_inputs()[-1].name: input_values.detach().cpu().numpy()},
88
+ )[0]
89
+ softmax_output = F.softmax(torch.tensor(outputs), dim=1)
90
+
91
+ both_classes_with_prob = {
92
+ self.config.id2label[i]: softmax_output[0][i].item()
93
+ for i in range(len(softmax_output[0]))
94
+ }
95
+
96
+ return both_classes_with_prob
97
+
98
+
99
+ if __name__ == "__main__":
100
+ eos = EndOfSpeechDetection()
101
+ eos.processor, eos.config, eos.session = eos.load_model("eos-model-onnx")
102
+ print(eos.predict("some.wav", file_type="wav"))
103
+
104
+ ```
105
 
106
  # Latency (& Memory) Optimization
107
  - Knowledge Distillation