# Spaces:
# Running
# Running
#!/usr/bin/python3 | |
# -*- coding: utf-8 -*- | |
from typing import Union, Tuple | |
import numpy as np | |
import sherpa | |
import sherpa_onnx | |
import torch | |
import torchaudio | |
import wave | |
def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
    """Load a single-channel 16-bit PCM wave file as normalized float32 samples.

    :param wave_filename: Path to a wave file. It must be single channel and
      each sample must be 16-bit. Its sample rate does not need to be 16kHz.
    :return: Return a tuple containing:
      signal: A 1-D array of dtype np.float32 containing the samples, which
        are normalized to the range [-1, 1].
      sample_rate: sample rate of the wave file
    :raises AssertionError: If the file is not single channel or its samples
      are not 16-bit.
    """
    with wave.open(wave_filename) as f:
        num_channels = f.getnchannels()
        sample_width = f.getsampwidth()
        # Raise explicitly rather than using `assert`: assert statements are
        # removed under `python -O`, which would silently let unsupported
        # files through and produce garbage samples.
        if num_channels != 1:
            raise AssertionError(f"expected 1 channel, but got {num_channels}")
        if sample_width != 2:
            raise AssertionError(
                f"expected 2 bytes per sample, but got {sample_width}"
            )
        sample_rate = f.getframerate()
        raw_frames = f.readframes(f.getnframes())

    # WAV PCM data is little-endian; "<i2" decodes it correctly regardless of
    # the host byte order (it is identical to np.int16 on little-endian hosts).
    samples_int16 = np.frombuffer(raw_frames, dtype="<i2")
    # Dividing by 2**15 maps the int16 range onto [-1, 1).
    samples_float32 = samples_int16.astype(np.float32) / 32768
    return samples_float32, sample_rate
def decode_offline_recognizer(
    recognizer: sherpa.OfflineRecognizer,
    filename: str,
) -> str:
    """Transcribe an audio file with a sherpa offline recognizer.

    :param recognizer: The sherpa offline recognizer that creates and decodes
      the stream.
    :param filename: Path of the wave file passed to the stream.
    :return: The recognized text with surrounding whitespace removed,
      converted to lower case.
    """
    stream = recognizer.create_stream()
    stream.accept_wave_file(filename)
    recognizer.decode_stream(stream)
    return stream.result.text.strip().lower()
def decode_online_recognizer(
    recognizer: sherpa.OnlineRecognizer,
    filename: str,
    expected_sample_rate: int = 16000,
) -> str:
    """Transcribe an audio file with a sherpa online (streaming) recognizer.

    :param recognizer: The sherpa online recognizer that creates and decodes
      the stream.
    :param filename: Path of the audio file, loaded with ``torchaudio.load``.
    :param expected_sample_rate: Sample rate the file must have.
    :return: The recognized text with surrounding whitespace removed,
      converted to lower case.
    :raises AssertionError: If the file's sample rate differs from
      ``expected_sample_rate``.
    """
    waveform, actual_sample_rate = torchaudio.load(filename)
    if actual_sample_rate != expected_sample_rate:
        message = "expected sample rate: {}, but: actually: {}".format(
            expected_sample_rate, actual_sample_rate
        )
        raise AssertionError(message)

    stream = recognizer.create_stream()

    # Only index 0 of the loaded tensor is decoded
    # (presumably the first channel -- see torchaudio.load).
    stream.accept_waveform(expected_sample_rate, waveform[0].contiguous())

    # Append 0.3 seconds of silence after the audio before closing the input.
    num_padding_samples = int(expected_sample_rate * 0.3)
    stream.accept_waveform(
        expected_sample_rate,
        torch.zeros(num_padding_samples, dtype=torch.float32),
    )
    stream.input_finished()

    # Keep decoding for as long as the recognizer reports the stream as ready.
    while recognizer.is_ready(stream):
        recognizer.decode_stream(stream)

    result = recognizer.get_result(stream)
    return result.text.strip().lower()
def decode_offline_recognizer_sherpa_onnx(
    recognizer: sherpa_onnx.OfflineRecognizer,
    filename: str,
) -> str:
    """Transcribe a wave file with a sherpa-onnx offline recognizer.

    :param recognizer: The sherpa-onnx offline recognizer that creates and
      decodes the stream.
    :param filename: Path of the wave file; it is loaded with ``read_wave``,
      so it must be single channel with 16-bit samples.
    :return: The recognized text with surrounding whitespace removed,
      converted to lower case.
    """
    s = recognizer.create_stream()
    samples, sample_rate = read_wave(filename)
    s.accept_waveform(sample_rate, samples)
    recognizer.decode_stream(s)
    # Strip surrounding whitespace so the output format matches the other
    # decode_* helpers in this module, which all return stripped text.
    return s.result.text.strip().lower()
def decode_online_recognizer_sherpa_onnx(
    recognizer: sherpa_onnx.OnlineRecognizer,
    filename: str,
) -> str:
    """Transcribe a wave file with a sherpa-onnx online (streaming) recognizer.

    :param recognizer: The sherpa-onnx online recognizer that creates and
      decodes the stream.
    :param filename: Path of the wave file; it is loaded with ``read_wave``.
    :return: The value returned by ``recognizer.get_result`` for the stream,
      converted to lower case.
    """
    stream = recognizer.create_stream()

    audio, rate = read_wave(filename)
    stream.accept_waveform(rate, audio)

    # Follow the audio with 0.3 seconds of silence before closing the input.
    silence = np.zeros(int(0.3 * rate), dtype=np.float32)
    stream.accept_waveform(rate, silence)
    stream.input_finished()

    # Keep decoding for as long as the recognizer reports the stream as ready.
    while recognizer.is_ready(stream):
        recognizer.decode_stream(stream)

    result = recognizer.get_result(stream)
    return result.lower()
def decode_by_recognizer(
    recognizer: Union[
        sherpa.OfflineRecognizer,
        sherpa.OnlineRecognizer,
        sherpa_onnx.OfflineRecognizer,
        sherpa_onnx.OnlineRecognizer,
    ],
    filename: str,
) -> str:
    """Transcribe an audio file, picking the decoder that fits the recognizer.

    :param recognizer: A sherpa or sherpa-onnx recognizer, offline or online.
    :param filename: Path of the audio file to transcribe.
    :return: The text produced by the matching ``decode_*`` helper.
    :raises ValueError: If ``recognizer`` is none of the supported types.
    """
    # The checks run in the same order as before: sherpa first, then
    # sherpa-onnx; offline before online within each library.
    if isinstance(recognizer, sherpa.OfflineRecognizer):
        return decode_offline_recognizer(recognizer, filename)

    if isinstance(recognizer, sherpa.OnlineRecognizer):
        return decode_online_recognizer(recognizer, filename)

    if isinstance(recognizer, sherpa_onnx.OfflineRecognizer):
        return decode_offline_recognizer_sherpa_onnx(recognizer, filename)

    if isinstance(recognizer, sherpa_onnx.OnlineRecognizer):
        return decode_online_recognizer_sherpa_onnx(recognizer, filename)

    raise ValueError(f"Unknown recognizer type {type(recognizer)}")
if __name__ == "__main__":
    # Running this module directly does nothing; it only provides the
    # decode helpers defined above.
    pass