|
```python |
|
"""the interface to interact with wakeword model""" |
|
import pyaudio |
|
import threading |
|
import time |
|
import torchaudio |
|
import torch |
|
import numpy as np |
|
import queue |
|
from transformers import WavLMForSequenceClassification |
|
from transformers import AutoFeatureExtractor |
|
|
|
|
|
def int2float(sound):
    """Convert an integer PCM buffer to peak-normalised float32 samples.

    The samples are cast to ``float32`` and divided by the largest absolute
    value found in the buffer, so any non-silent buffer peaks at exactly 1.0.
    An all-zero (or empty) buffer is returned unscaled.

    Args:
        sound: Array of audio samples (the recorder in this file passes
            ``int16`` data obtained from ``np.frombuffer``).

    Returns:
        A new ``float32`` NumPy array with length-1 axes squeezed out. The
        input array is never modified.
    """
    # Cast BEFORE measuring the peak: on int16 data np.abs(-32768) wraps back
    # to -32768, which under-estimates the peak and lets the scaled output
    # leave the [-1, 1] range.
    samples = np.asarray(sound).astype('float32')
    if samples.size == 0:
        # .max() raises on an empty array; there is nothing to scale anyway.
        return samples.squeeze()

    peak = np.abs(samples).max()
    if peak > 0:
        samples /= peak
    return samples.squeeze()  # drop singleton axes, e.g. (1, N) -> (N,)
|
|
|
class RealtimeDecoder():
    """Run wake-word detection on microphone audio using background threads.

    ``start_recording`` builds a thread that reads 16 kHz mono int16 audio
    through PyAudio, normalises each chunk with ``int2float`` and pushes it to
    ``audio_queue`` as a tensor. ``start_decoding`` builds a thread that
    consumes that queue, keeps the most recent chunks and, whenever the Silero
    VAD reports speech in the last two chunks, scores them with the
    classifier.

    Both loops run until ``continue_recording`` is set. Despite its name,
    setting that event STOPS recording and decoding.
    """

    def __init__(self,
                 model,
                 feature_extractor=None,
                 threshold=0.6,
                 ) -> None:
        """Load the Silero VAD and prepare queues and state.

        Args:
            model: Classifier called as ``model(**features)``; its result must
                expose ``.logits``.
            feature_extractor: Object whose ``pad(...)`` turns a waveform into
                the keyword arguments for ``model``. When ``None`` the decoder
                falls back to a module-level ``feature_extractor`` name, which
                is how this class was originally wired up.
            threshold: A detection is reported when the softmax score at
                index 1 is strictly greater than this value.
        """
        self.model = model
        self.feature_extractor = feature_extractor
        self.threshold = threshold
        self.vad_model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
                                               model='silero_vad',
                                               force_reload=False,
                                               onnx=False)

        # Only the first helper is needed; indexing (instead of unpacking a
        # fixed 5-tuple) keeps working if the hub entry returns more helpers.
        self.get_speech_timestamps = utils[0]
        self.SAMPLE_RATE = 16000
        self.cache_output = self._empty_cache()
        # Stop flag shared by the recorder and the decoder (set == stop).
        self.continue_recording = threading.Event()
        # Length of one recorded chunk, in milliseconds.
        self.frame_duration_ms = 1000
        self.audio_queue = queue.SimpleQueue()
        # Created for callers; not used by the code in this class.
        self.speech_queue = queue.SimpleQueue()

    @staticmethod
    def _empty_cache():
        """Return a fresh, empty decoding cache."""
        return {
            # Placeholder tensor; not read by the code in this class.
            "cache": torch.zeros(0, 0, 0, dtype=torch.float),
            # Most recent waveform chunks, oldest first.
            "wavchunks": [],
        }

    def _resolve_feature_extractor(self):
        """Return the extractor given at construction, else the module global."""
        extractor = self.feature_extractor
        if extractor is None:
            # Backward compatibility: the class used to read a module-level
            # `feature_extractor` defined by the script entry point.
            extractor = globals().get("feature_extractor")
        if extractor is None:
            raise RuntimeError(
                "RealtimeDecoder needs a feature extractor: pass "
                "feature_extractor=... to the constructor")
        return extractor

    def start_recording(self, wait_enter_to_stop=True):
        """Build (but do not start) the recording threads.

        Args:
            wait_enter_to_stop: When true, also build a thread that blocks on
                ``input()`` and sets ``continue_recording`` once Enter is
                pressed. When false, the caller must set the event itself.

        Returns:
            ``(stop_listener_thread, recording_thread)``; the first item is
            ``None`` when ``wait_enter_to_stop`` is false.
        """
        def stop():
            input("Press Enter to stop the recording:\n\n")
            self.continue_recording.set()

        def record():
            audio = pyaudio.PyAudio()
            stream = None
            try:
                stream = audio.open(format=pyaudio.paInt16,
                                    channels=1,
                                    rate=self.SAMPLE_RATE,
                                    input=True,
                                    frames_per_buffer=int(self.SAMPLE_RATE / 10))
                while not self.continue_recording.is_set():
                    frames = int(self.SAMPLE_RATE * self.frame_duration_ms / 1000.0)
                    audio_chunk = stream.read(frames, exception_on_overflow=False)
                    audio_int16 = np.frombuffer(audio_chunk, np.int16)
                    self.audio_queue.put(torch.from_numpy(int2float(audio_int16)))
                print("Finish record")
            finally:
                # Release the device even if opening or reading failed.
                if stream is not None:
                    stream.close()
                audio.terminate()

        if wait_enter_to_stop:
            stop_listener_thread = threading.Thread(target=stop, daemon=False)
        else:
            stop_listener_thread = None
        recording_thread = threading.Thread(target=record, daemon=False)
        return stop_listener_thread, recording_thread

    def finish_realtime_decode(self):
        """Discard all cached chunks and reset the decoding cache."""
        self.cache_output = self._empty_cache()

    def start_decoding(self):
        """Build (but do not start) the wake-word decoding thread.

        The thread takes chunks from ``audio_queue`` until
        ``continue_recording`` is set; chunks still queued at that point are
        not processed. It keeps at most four chunks, concatenates the two most
        recent ones, and runs the classifier only when the VAD finds speech in
        that window. On a detection it prints ``"hey armar"`` with the scores
        and clears the cached chunks; otherwise it prints a row of dots whose
        length reflects the queue backlog.

        Returns:
            The ``threading.Thread`` that runs the decode loop.
        """
        def decode():
            while not self.continue_recording.is_set():
                try:
                    # Block briefly instead of polling qsize(); the timeout
                    # bounds how long a stop request can go unnoticed.
                    chunk = self.audio_queue.get(timeout=0.1)
                except queue.Empty:
                    continue
                if chunk is None:
                    continue

                self.cache_output['wavchunks'].append(chunk)
                self.cache_output['wavchunks'] = self.cache_output['wavchunks'][-4:]
                if len(self.cache_output['wavchunks']) <= 1:
                    # A single chunk is not scored; wait for a second one.
                    continue

                window = torch.cat(self.cache_output['wavchunks'][-2:], dim=-1)
                speech_timestamps = self.get_speech_timestamps(
                    window, self.vad_model, sampling_rate=self.SAMPLE_RATE)

                # Default scores used when the VAD finds no speech: index 1
                # is 0, so the threshold test below cannot fire.
                scores = [1, 0]
                if len(speech_timestamps) > 0:
                    extractor = self._resolve_feature_extractor()
                    input_features = extractor.pad(
                        [{"input_values": window}], padding=True, return_tensors="pt")
                    # Inference only: skip building the autograd graph.
                    with torch.no_grad():
                        scores = self.model(**input_features).logits.softmax(dim=-1).squeeze()

                if scores[1] > self.threshold:
                    print("hey armar", scores, window.size(-1) / self.SAMPLE_RATE)
                    # Start from scratch so the same audio is not reported twice.
                    self.cache_output['wavchunks'] = []
                else:
                    print('.'+'.'*self.audio_queue.qsize())
            print("KWS thread finish")

        kws_decode_thread = threading.Thread(target=decode, daemon=False)
        return kws_decode_thread
|
|
|
if __name__ == "__main__":
    # Script entry point: load the wake-word classifier and its feature
    # extractor, then run the recorder and the decoder side by side until the
    # recording is stopped.
    print("Model loading....")

    kws_model = WavLMForSequenceClassification.from_pretrained('nguyenvulebinh/heyarmar')
    # Keep this exact module-level name: the decoder below is constructed
    # without an extractor, so its decode loop relies on `feature_extractor`
    # being available as a global of this module.
    feature_extractor = AutoFeatureExtractor.from_pretrained('nguyenvulebinh/heyarmar')

    print("Model loaded....")

    # Offline check on a single recording, kept for reference:
    # file_wave = './99.wav'
    # wav, rate = torchaudio.load(file_wave)
    # input_features = feature_extractor.pad([{"input_values": item} for item in wav], padding=True, return_tensors="pt")
    # output = kws_model(**input_features)
    # print(output.logits.softmax(dim=-1))

    decoder = RealtimeDecoder(kws_model)

    # Build every thread first, then start them in the order
    # stop-listener, recorder, decoder, and wait for them in that same order.
    workers = [worker for worker in decoder.start_recording() if worker is not None]
    workers.append(decoder.start_decoding())

    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()
|
|
|
``` |