"""the interface to interact with wakeword model"""
import queue
import threading
import time

import numpy as np
import pyaudio
import torch
from transformers import AutoFeatureExtractor, WavLMForSequenceClassification

def int2float(sound):
    """Convert an int16 numpy buffer to a float32 waveform normalized to [-1, 1]."""
    abs_max = np.abs(sound).max()
    sound = sound.astype('float32')
    if abs_max > 0:
        sound *= 1 / abs_max
    return sound.squeeze()

class RealtimeDecoder():
    def __init__(self,
                 model,
                 feature_extractor,
                 ) -> None:
        self.model = model
        self.feature_extractor = feature_extractor
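        # Load Silero VAD from torch.hub; it gates the wake-word classifier so the
        # model only runs on windows that actually contain speech.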
self.vad_model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
model='silero_vad',
force_reload=False,
onnx=False)
(self.get_speech_timestamps, _, _, _, _) = utils
self.SAMPLE_RATE = 16000
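        # Rolling cache: "wavchunks" keeps the most recent 1 s audio chunks for windowed decoding.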
        self.cache_output = {
            "cache": torch.zeros(0, 0, 0, dtype=torch.float),
            "wavchunks": [],
        }
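        # Event that, once set, signals both the recorder and decoder threads to stop.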
self.continue_recording = threading.Event()
self.frame_duration_ms = 1000
self.audio_queue = queue.SimpleQueue()
self.speech_queue = queue.SimpleQueue()
def start_recording(self, wait_enter_to_stop=True):
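        """Create (but do not start) the recorder thread and an optional Enter-to-stop listener."""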
def stop():
input("Press Enter to stop the recording:\n\n")
self.continue_recording.set()
def record():
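            # Open a 16 kHz, mono, 16-bit input stream; 100 ms frames keep capture latency low.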
audio = pyaudio.PyAudio()
stream = audio.open(format=pyaudio.paInt16,
channels=1,
rate=self.SAMPLE_RATE,
input=True,
frames_per_buffer=int(self.SAMPLE_RATE / 10))
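            # Read one 1 s chunk per iteration (frame_duration_ms) until asked to stop.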
while not self.continue_recording.is_set():
                audio_chunk = stream.read(int(self.SAMPLE_RATE * self.frame_duration_ms / 1000.0),
                                          exception_on_overflow=False)
audio_int16 = np.frombuffer(audio_chunk, np.int16)
audio_float32 = int2float(audio_int16)
waveform = torch.from_numpy(audio_float32)
self.audio_queue.put(waveform)
print("Finish record")
stream.close()
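        # Without the stop listener, callers must set continue_recording themselves to stop.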
if wait_enter_to_stop:
stop_listener_thread = threading.Thread(target=stop, daemon=False)
else:
stop_listener_thread = None
recording_thread = threading.Thread(target=record, daemon=False)
return stop_listener_thread, recording_thread
def finish_realtime_decode(self):
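        """Reset the rolling audio cache between decoding sessions."""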
        self.cache_output = {
            "cache": torch.zeros(0, 0, 0, dtype=torch.float),
            "wavchunks": [],
        }
    def start_decoding(self):
        """Create (but do not start) the keyword-spotting decoder thread."""
        def decode():
            while not self.continue_recording.is_set():
                if self.audio_queue.qsize() > 0:
                    current_waveform = self.audio_queue.get()
                    if current_waveform is not None:
                        # Keep only the most recent 4 chunks (~4 s) in the rolling cache.
                        self.cache_output['wavchunks'].append(current_waveform)
                        self.cache_output['wavchunks'] = self.cache_output['wavchunks'][-4:]
                        if len(self.cache_output['wavchunks']) > 1:
                            # Decode a sliding window built from the last two chunks (~2 s).
                            waveform = torch.cat(self.cache_output['wavchunks'][-2:], dim=-1)
                            speech_timestamps = self.get_speech_timestamps(waveform, self.vad_model,
                                                                           sampling_rate=self.SAMPLE_RATE)
                            # Default scores mean "no wake word"; index 1 is the wake-word class.
                            logits = torch.tensor([1.0, 0.0])
                            if len(speech_timestamps) > 0:
                                # Run the classifier only on windows where VAD reports speech.
                                input_features = self.feature_extractor.pad([{"input_values": waveform}],
                                                                            padding=True, return_tensors="pt")
                                with torch.no_grad():
                                    logits = self.model(**input_features).logits.softmax(dim=-1).squeeze()
                            if logits[1] > 0.6:
                                print("hey armar", logits, waveform.size(-1) / self.SAMPLE_RATE)
                                # Clear the cache so one utterance is not reported twice.
                                self.cache_output['wavchunks'] = []
                            else:
                                print('.' + '.' * self.audio_queue.qsize())
                else:
                    time.sleep(0.01)
            print("KWS thread finished")
kws_decode_thread = threading.Thread(target=decode, daemon=False)
return kws_decode_thread
if __name__ == "__main__":
print("Model loading....")
kws_model = WavLMForSequenceClassification.from_pretrained('nguyenvulebinh/heyarmar')
feature_extractor = AutoFeatureExtractor.from_pretrained('nguyenvulebinh/heyarmar')
print("Model loaded....")
    obj_decode = RealtimeDecoder(kws_model, feature_extractor)
recording_threads = obj_decode.start_recording()
kws_decode_thread = obj_decode.start_decoding()
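    # Start the recorder (and optional stop listener) first, then the decoder; join all on exit.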
for thread in recording_threads:
if thread is not None:
thread.start()
kws_decode_thread.start()
for thread in recording_threads:
if thread is not None:
thread.join()
kws_decode_thread.join()