File size: 5,145 Bytes
7a8d5af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
```python
"""the interface to interact with wakeword model"""
import pyaudio
import threading
import time
import torchaudio
import torch
import numpy as np
import queue
from transformers import WavLMForSequenceClassification
from transformers import AutoFeatureExtractor


def int2float(sound):
    """Convert integer PCM samples to peak-normalised float32 samples.

    The input is copied to ``float32`` and, unless it is empty or entirely
    zero, divided by its largest absolute value so that the loudest sample
    has magnitude 1. Length-1 axes are removed from the result.

    Args:
        sound: NumPy array of audio samples (the caller in this file passes
            a 1-D ``int16`` array). It is not modified.

    Returns:
        A new ``float32`` NumPy array, squeezed.
    """
    # Cast BEFORE taking the absolute value: on int16 data np.abs(-32768)
    # wraps around to -32768, which would hide the true peak and let the
    # normalised signal leave the [-1, 1] range.
    samples = sound.astype('float32')
    if samples.size == 0:
        # .max() of an empty array raises; there is nothing to normalise.
        return samples.squeeze()
    abs_max = np.abs(samples).max()
    if abs_max > 0:
        samples *= 1 / abs_max
    return samples.squeeze()  # depends on the use case

class RealtimeDecoder():
    """Run a keyword-spotting model on live microphone audio.

    One thread reads fixed-length chunks from a PyAudio input stream and
    pushes them onto ``audio_queue``; another thread pops them, joins the
    two most recent chunks, runs silero-vad on the result and, if speech was
    found, asks the classification model whether the wake word is present.

    Note that ``continue_recording`` is a *stop* flag despite its name:
    both worker loops run until the event is set.
    """

    def __init__(self,
        model,
        feature_extractor=None,
    ) -> None:
        """Load the VAD model and initialise queues and state.

        Args:
            model: Classification model; it is called with the feature
                extractor's output as keyword arguments and must return an
                object with a ``logits`` attribute.
            feature_extractor: Object providing ``pad(...)`` as used in
                ``start_decoding``. When omitted, a module-level global
                named ``feature_extractor`` is used if one exists, which
                keeps the script entry point of this file working.
        """
        self.model = model
        self.feature_extractor = feature_extractor
        self.vad_model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
                              model='silero_vad',
                              force_reload=False,
                              onnx=False)

        # Only the first helper is needed; indexing avoids depending on how
        # many helpers the hub entry point returns.
        self.get_speech_timestamps = utils[0]
        self.SAMPLE_RATE = 16000
        self.cache_output = self._empty_cache()
        # Set this event to make the recording and decoding loops exit.
        self.continue_recording = threading.Event()
        self.frame_duration_ms = 1000
        self.audio_queue = queue.SimpleQueue()
        self.speech_queue = queue.SimpleQueue()

    @staticmethod
    def _empty_cache():
        """Return a fresh, empty decoding cache."""
        return {
            "cache" : torch.zeros(0, 0, 0, dtype=torch.float),
            "wavchunks": [],
        }

    def _resolve_feature_extractor(self):
        """Return the feature extractor to use, or raise if none is available."""
        extractor = self.feature_extractor
        if extractor is None:
            # Backward compatibility: older callers define the extractor as
            # a module-level global instead of passing it to __init__.
            extractor = globals().get("feature_extractor")
        if extractor is None:
            raise RuntimeError(
                "No feature extractor available: pass feature_extractor= "
                "to RealtimeDecoder()"
            )
        return extractor

    def stop(self):
        """Ask the recording and decoding loops to finish."""
        self.continue_recording.set()

    def start_recording(self, wait_enter_to_stop=True):
        """Create (but do not start) the recording threads.

        Args:
            wait_enter_to_stop: When true, also create a thread that blocks
                on ``input()`` and sets the stop event once Enter is pressed.

        Returns:
            ``(stop_listener_thread, recording_thread)``; the first item is
            ``None`` when ``wait_enter_to_stop`` is false. Both threads are
            non-daemon and unstarted.
        """
        def stop():
            input("Press Enter to stop the recording:\n\n")
            self.continue_recording.set()
        def record():
            audio = pyaudio.PyAudio()
            try:
                stream = audio.open(format=pyaudio.paInt16,
                        channels=1,
                        rate=self.SAMPLE_RATE,
                        input=True,
                        frames_per_buffer=int(self.SAMPLE_RATE / 10))
                try:
                    while not self.continue_recording.is_set():
                        audio_chunk = stream.read(int(self.SAMPLE_RATE * self.frame_duration_ms / 1000.0), exception_on_overflow = False)
                        audio_int16 = np.frombuffer(audio_chunk, np.int16)
                        audio_float32 = int2float(audio_int16)
                        waveform = torch.from_numpy(audio_float32)
                        self.audio_queue.put(waveform)
                    print("Finish record")
                finally:
                    # Release the device even if reading failed.
                    stream.stop_stream()
                    stream.close()
            finally:
                audio.terminate()
        if wait_enter_to_stop:
            stop_listener_thread = threading.Thread(target=stop, daemon=False)
        else:
            stop_listener_thread = None
        recording_thread = threading.Thread(target=record, daemon=False)
        return stop_listener_thread, recording_thread

    def finish_realtime_decode(self):
        """Reset the decoding cache to its initial empty state."""
        self.cache_output = self._empty_cache()

    def start_decoding(self):
        """Create (but do not start) the keyword-spotting thread.

        The thread consumes ``audio_queue`` until the stop event is set. For
        every chunk it keeps at most the four latest chunks, concatenates the
        last two, and prints either a detection line (score above 0.6) or a
        row of dots whose length reflects the queue backlog.

        Returns:
            The unstarted, non-daemon ``threading.Thread``.
        """
        def decode():
            while not self.continue_recording.is_set():
                try:
                    # Block briefly instead of polling qsize(); the timeout
                    # lets the loop notice the stop event.
                    current_waveform = self.audio_queue.get(timeout=0.1)
                except queue.Empty:
                    continue

                chunks = self.cache_output['wavchunks']
                if current_waveform is not None:
                    chunks.append(current_waveform)
                    chunks = chunks[-4:]
                    self.cache_output['wavchunks'] = chunks

                if len(chunks) <= 1:
                    continue

                waveform = torch.cat(chunks[-2:], dim=-1)
                speech_timestamps = self.get_speech_timestamps(waveform, self.vad_model, sampling_rate=self.SAMPLE_RATE)
                # Default scores when no speech is found: index 1 is the
                # entry compared against the detection threshold below.
                probs = [1, 0]
                if len(speech_timestamps) > 0:
                    extractor = self._resolve_feature_extractor()
                    input_features = extractor.pad([{"input_values": waveform}], padding=True, return_tensors="pt")
                    # Inference only: do not record an autograd graph.
                    with torch.no_grad():
                        probs = self.model(**input_features).logits.softmax(dim=-1).squeeze()
                if probs[1] > 0.6:
                    print("hey armar", probs, waveform.size(-1) / self.SAMPLE_RATE)
                    # Drop the audio that triggered the detection.
                    self.cache_output['wavchunks'] = []
                else:
                    print('.'+'.'*self.audio_queue.qsize())
            print("KWS thread finish")
        kws_decode_thread = threading.Thread(target=decode, daemon=False)
        return kws_decode_thread

if __name__ == "__main__":
    # Script entry point: load the wake-word model, then listen on the
    # microphone until Enter is pressed.
    print("Model loading....")

    # NOTE: `feature_extractor` has to stay a module-level name because the
    # decoding thread looks it up as a global.
    pretrained_id = 'nguyenvulebinh/heyarmar'
    kws_model = WavLMForSequenceClassification.from_pretrained(pretrained_id)
    feature_extractor = AutoFeatureExtractor.from_pretrained(pretrained_id)

    print("Model loaded....")

    # Offline sanity check on a single file, kept for reference:
    #   wav, rate = torchaudio.load('./99.wav')
    #   input_features = feature_extractor.pad(
    #       [{"input_values": item} for item in wav], padding=True, return_tensors="pt")
    #   print(kws_model(**input_features).logits.softmax(dim=-1))

    obj_decode = RealtimeDecoder(kws_model)
    recording_threads = obj_decode.start_recording()
    kws_decode_thread = obj_decode.start_decoding()

    # Recording threads first (skipping the optional one when absent), then
    # the keyword-spotting thread; they are joined in the same order.
    workers = [worker for worker in recording_threads if worker is not None]
    workers.append(kws_decode_thread)

    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()

```