```python |
"""the interface to interact with wakeword model""" |
import pyaudio |
import threading |
import time |
import torchaudio |
import torch |
import numpy as np |
import queue |
from transformers import WavLMForSequenceClassification |
from transformers import AutoFeatureExtractor |
def int2float(sound):
    """Convert PCM samples to a float32 array peak-normalised to [-1, 1].

    The samples are divided by their own peak absolute value (not by the
    integer type's full-scale value), so quiet input is scaled up to full
    scale. All-zero or empty input is returned without scaling.

    Args:
        sound: array-like of samples, e.g. the int16 microphone buffer
            produced by ``np.frombuffer(..., np.int16)``.

    Returns:
        numpy.ndarray: float32 samples with size-1 dimensions squeezed out.
    """
    # Convert *before* taking the absolute value: for int16 input,
    # np.abs(-32768) wraps around to -32768, so the peak was computed from
    # the wrong sample and the output could land far outside [-1, 1].
    sound = np.asarray(sound).astype('float32')
    if sound.size > 0:  # .max() raises on an empty array
        abs_max = np.abs(sound).max()
        if abs_max > 0:
            sound *= 1 / abs_max
    return sound.squeeze()  # depends on the use case
class RealtimeDecoder():
    """Real-time wake-word detector fed by microphone audio.

    A recording thread (see ``start_recording``) reads fixed-length chunks
    from the microphone and puts them on ``audio_queue``. A decoding thread
    (see ``start_decoding``) concatenates the two most recent chunks into a
    window, runs Silero VAD on it and, only when speech is found, scores the
    window with the wake-word ``model``.

    ``continue_recording`` is, despite its name, a *stop* flag: every worker
    loop exits once the event is set.
    """

    def __init__(self,
                 model,
                 feature_extractor=None,
                 threshold=0.6,
                 ) -> None:
        """Create the decoder and load the Silero VAD model.

        Args:
            model: wake-word sequence classifier, called as
                ``model(**features).logits``. Class index 1 is treated as the
                wake-word class.
            feature_extractor: object whose ``pad(...)`` builds the model
                inputs (e.g. a Hugging Face ``AutoFeatureExtractor``). If None,
                a module-level global named ``feature_extractor`` is used at
                decode time (the name this script's ``__main__`` block defines).
            threshold: probability of class 1 above which a detection is
                reported.

        Note:
            ``torch.hub.load`` fetches ``snakers4/silero-vad`` if it is not
            already in the local hub cache.
        """
        self.model = model
        self.feature_extractor = feature_extractor
        self.threshold = threshold
        self.vad_model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
                                               model='silero_vad',
                                               force_reload=False,
                                               onnx=False)
        # Only get_speech_timestamps is used from the Silero helper tuple.
        (self.get_speech_timestamps, _, _, _, _) = utils
        self.SAMPLE_RATE = 16000
        # "wavchunks": most recent audio chunks (at most 4 are kept);
        # "cache": not read anywhere in this class.
        self.cache_output = self._empty_cache()
        # Stop flag shared by all worker threads (see class docstring).
        self.continue_recording = threading.Event()
        # Length of each recorded chunk, in milliseconds.
        self.frame_duration_ms = 1000
        self.audio_queue = queue.SimpleQueue()
        # Not used within this class.
        self.speech_queue = queue.SimpleQueue()

    @staticmethod
    def _empty_cache():
        """Return a fresh, empty decoding cache."""
        return {
            "cache": torch.zeros(0, 0, 0, dtype=torch.float),
            "wavchunks": [],
        }

    def _resolve_feature_extractor(self):
        """Return the feature extractor used to build model inputs.

        Raises:
            RuntimeError: if none was given and no module-level
                ``feature_extractor`` global exists.
        """
        extractor = self.feature_extractor
        if extractor is None:
            # Backward compatibility: the decoder originally read a global
            # named ``feature_extractor`` defined in the ``__main__`` block.
            extractor = globals().get("feature_extractor")
        if extractor is None:
            raise RuntimeError(
                "RealtimeDecoder needs a feature extractor: pass "
                "feature_extractor=... to the constructor")
        return extractor

    def _classify_window(self, wavform):
        """Return class probabilities for one audio window.

        Returns ``[1, 0]`` (treated as "not the wake word") without running
        the classifier when the VAD finds no speech in the window.
        """
        speech_timestamps = self.get_speech_timestamps(
            wavform, self.vad_model, sampling_rate=self.SAMPLE_RATE)
        if not speech_timestamps:
            return [1, 0]
        input_features = self._resolve_feature_extractor().pad(
            [{"input_values": wavform}], padding=True, return_tensors="pt")
        # Inference only: avoid building an autograd graph on every window.
        with torch.no_grad():
            return self.model(**input_features).logits.softmax(dim=-1).squeeze()

    def start_recording(self, wait_enter_to_stop=True):
        """Create (but do not start) the stop-listener and recording threads.

        Args:
            wait_enter_to_stop: if True, also create a thread that waits for
                Enter on stdin and then sets ``continue_recording``. If False,
                the caller is responsible for setting that event.

        Returns:
            ``(stop_listener_thread, recording_thread)``; the first element is
            None when ``wait_enter_to_stop`` is False.
        """
        def stop():
            input("Press Enter to stop the recording:\n\n")
            self.continue_recording.set()

        def record():
            audio = pyaudio.PyAudio()
            stream = None
            try:
                stream = audio.open(format=pyaudio.paInt16,
                                    channels=1,
                                    rate=self.SAMPLE_RATE,
                                    input=True,
                                    frames_per_buffer=int(self.SAMPLE_RATE / 10))
                chunk_frames = int(self.SAMPLE_RATE * self.frame_duration_ms / 1000.0)
                while not self.continue_recording.is_set():
                    audio_chunk = stream.read(chunk_frames, exception_on_overflow=False)
                    audio_int16 = np.frombuffer(audio_chunk, np.int16)
                    waveform = torch.from_numpy(int2float(audio_int16))
                    self.audio_queue.put(waveform)
                print("Finish record")
            except Exception:
                # Stop the decoding thread as well; otherwise it would keep
                # polling an audio queue that will never be filled again.
                self.continue_recording.set()
                raise
            finally:
                # Release the PortAudio stream and library handles.
                if stream is not None:
                    stream.stop_stream()
                    stream.close()
                audio.terminate()

        if wait_enter_to_stop:
            stop_listener_thread = threading.Thread(target=stop, daemon=False)
        else:
            stop_listener_thread = None
        recording_thread = threading.Thread(target=record, daemon=False)
        return stop_listener_thread, recording_thread

    def finish_realtime_decode(self):
        """Reset the decoding cache, discarding any buffered audio chunks."""
        self.cache_output = self._empty_cache()

    def start_decoding(self):
        """Create (but do not start) the wake-word decoding thread.

        The thread pops waveforms from ``audio_queue`` until
        ``continue_recording`` is set (``None`` items are skipped). It keeps
        the last 4 chunks and, once at least 2 are buffered, classifies the
        concatenation of the last 2. On a detection (class-1 probability above
        ``self.threshold``) it prints a message and clears the chunk buffer;
        otherwise it prints a line of dots whose length grows with the queue
        backlog.

        Returns:
            threading.Thread: the unstarted decoding thread.
        """
        def decode():
            while not self.continue_recording.is_set():
                try:
                    # Short timeout so the stop flag is re-checked regularly.
                    waveform = self.audio_queue.get(timeout=0.01)
                except queue.Empty:
                    continue
                if waveform is None:
                    continue
                self.cache_output['wavchunks'].append(waveform)
                self.cache_output['wavchunks'] = self.cache_output['wavchunks'][-4:]
                if len(self.cache_output['wavchunks']) < 2:
                    continue
                window = torch.cat(self.cache_output['wavchunks'][-2:], dim=-1)
                probs = self._classify_window(window)
                if probs[1] > self.threshold:
                    print("hey armar", probs, window.size(-1) / self.SAMPLE_RATE)
                    self.cache_output['wavchunks'] = []
                else:
                    print('.' + '.' * self.audio_queue.qsize())
            print("KWS thread finish")

        kws_decode_thread = threading.Thread(target=decode, daemon=False)
        return kws_decode_thread
if __name__ == "__main__":
    model_id = 'nguyenvulebinh/heyarmar'
    print("Model loading....")
    kws_model = WavLMForSequenceClassification.from_pretrained(model_id)
    # Keep this exact global name: RealtimeDecoder is constructed below
    # without an explicit extractor and reads ``feature_extractor`` from the
    # module namespace while decoding.
    feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
    print("Model loaded....")

    # To score a recorded file instead of the microphone:
    #   wav, rate = torchaudio.load('./99.wav')
    #   feats = feature_extractor.pad([{"input_values": item} for item in wav],
    #                                 padding=True, return_tensors="pt")
    #   print(kws_model(**feats).logits.softmax(dim=-1))

    decoder = RealtimeDecoder(kws_model)
    # start_recording() yields (stop listener or None, recorder) and
    # start_decoding() yields the KWS thread; start them all in that order,
    # then wait for each of them in the same order.
    workers = [worker
               for worker in (*decoder.start_recording(), decoder.start_decoding())
               if worker is not None]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()
``` |