asr / decode.py
HoneyTian's picture
update
2267fac
raw
history blame
4.12 kB
#!/usr/bin/python3
# -*- coding: utf-8 -*-
from typing import Union, Tuple
import numpy as np
import sherpa
import sherpa_onnx
import torch
import torchaudio
import wave
def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
    """
    Read a single-channel, 16-bit PCM wave file into a normalized float array.

    :param wave_filename: Path to a wave file. It should be single channel and each sample should be 16-bit.
    Its sample rate does not need to be 16kHz.
    :return: Return a tuple containing:
    signal: A 1-D array of dtype np.float32 containing the samples, scaled by 1 / 32768
    so that they lie in the range [-1, 1).
    sample_rate: sample rate of the wave file
    :raises AssertionError: If the file is not single channel or its samples are not 16-bit.
    """
    with wave.open(wave_filename) as f:
        # Raise explicitly instead of using ``assert`` statements so that the
        # validation still runs when Python is started with ``-O``.
        num_channels = f.getnchannels()
        if num_channels != 1:
            raise AssertionError(
                "expected a single channel wave file, but got {} channels".format(num_channels)
            )

        sample_width = f.getsampwidth()  # measured in bytes
        if sample_width != 2:
            raise AssertionError(
                "expected 16-bit samples (2 bytes), but got {} byte(s) per sample".format(sample_width)
            )

        sample_rate = f.getframerate()
        raw_bytes = f.readframes(f.getnframes())

    # WAV PCM data is little-endian regardless of the host, so spell the byte
    # order out rather than relying on the native order of np.int16.
    samples_int16 = np.frombuffer(raw_bytes, dtype="<i2")
    samples_float32 = samples_int16.astype(np.float32) / 32768
    return samples_float32, sample_rate
def decode_offline_recognizer(recognizer: sherpa.OfflineRecognizer,
                              filename: str,
                              ) -> str:
    """
    Decode a wave file with a non-streaming sherpa recognizer.

    :param recognizer: The offline recognizer used to create and decode the stream.
    :param filename: Path to the wave file; it is handed to the stream as-is.
    :return: The recognized text with surrounding whitespace removed, in lower case.
    """
    stream = recognizer.create_stream()
    stream.accept_wave_file(filename)
    recognizer.decode_stream(stream)

    # The result is only read after decode_stream() has run on the stream.
    return stream.result.text.strip().lower()
def decode_online_recognizer(recognizer: sherpa.OnlineRecognizer,
                             filename: str,
                             expected_sample_rate: int = 16000,
                             ) -> str:
    """
    Decode a wave file with a streaming sherpa recognizer.

    :param recognizer: The online recognizer used to create and decode the stream.
    :param filename: Path to the audio file, loaded with torchaudio.
    :param expected_sample_rate: Sample rate the audio must have; no resampling is done.
    :return: The recognized text with surrounding whitespace removed, in lower case.
    :raises AssertionError: If the sample rate of the file differs from expected_sample_rate.
    """
    waveform, sample_rate = torchaudio.load(filename)
    if sample_rate != expected_sample_rate:
        message = "expected sample rate: {}, but: actually: {}".format(expected_sample_rate, sample_rate)
        raise AssertionError(message)

    # Only the first row of the loaded tensor is decoded.
    first_channel = waveform[0].contiguous()

    # 0.3 seconds of zeros fed after the audio
    # (presumably so the model can emit its final tokens - TODO confirm).
    silence = torch.zeros(int(expected_sample_rate * 0.3), dtype=torch.float32)

    stream = recognizer.create_stream()
    for chunk in (first_channel, silence):
        stream.accept_waveform(expected_sample_rate, chunk)
    stream.input_finished()

    # Keep decoding for as long as the recognizer reports the stream as ready.
    while recognizer.is_ready(stream):
        recognizer.decode_stream(stream)

    result = recognizer.get_result(stream)
    return result.text.strip().lower()
def decode_offline_recognizer_sherpa_onnx(recognizer: sherpa_onnx.OfflineRecognizer,
                                          filename: str,
                                          ) -> str:
    """
    Decode a wave file with a non-streaming sherpa_onnx recognizer.

    :param recognizer: The offline recognizer used to create and decode the stream.
    :param filename: Path to a wave file. It is loaded with read_wave, so it must be
    single channel with 16-bit samples.
    :return: The recognized text with surrounding whitespace removed, in lower case.
    """
    s = recognizer.create_stream()
    samples, sample_rate = read_wave(filename)
    s.accept_waveform(sample_rate, samples)
    recognizer.decode_stream(s)

    # Strip surrounding whitespace, as the sherpa decoders in this file do, so
    # the output of decode_by_recognizer does not depend on the backend.
    return s.result.text.strip().lower()
def decode_online_recognizer_sherpa_onnx(recognizer: sherpa_onnx.OnlineRecognizer,
                                         filename: str,
                                         ) -> str:
    """
    Decode a wave file with a streaming sherpa_onnx recognizer.

    :param recognizer: The online recognizer used to create and decode the stream.
    :param filename: Path to a wave file. It is loaded with read_wave, so it must be
    single channel with 16-bit samples.
    :return: The value of recognizer.get_result() for the stream, in lower case.
    """
    audio, rate = read_wave(filename)

    # 0.3 seconds of zeros fed after the audio
    # (presumably so the model can emit its final tokens - TODO confirm).
    silence = np.zeros(int(0.3 * rate), dtype=np.float32)

    stream = recognizer.create_stream()
    for chunk in (audio, silence):
        stream.accept_waveform(rate, chunk)
    stream.input_finished()

    # Keep decoding for as long as the recognizer reports the stream as ready.
    while recognizer.is_ready(stream):
        recognizer.decode_stream(stream)

    # NOTE(review): get_result() is lower-cased directly, so it is presumably
    # already a plain string here - confirm against the sherpa_onnx API.
    recognized = recognizer.get_result(stream)
    return recognized.lower()
def decode_by_recognizer(
    recognizer: Union[
        sherpa.OfflineRecognizer,
        sherpa.OnlineRecognizer,
        sherpa_onnx.OfflineRecognizer,
        sherpa_onnx.OnlineRecognizer,
    ],
    filename: str,
) -> str:
    """
    Decode a wave file with whichever supported recognizer is given.

    :param recognizer: A sherpa or sherpa_onnx recognizer, offline or online.
    :param filename: Path to the wave file to decode.
    :return: The text returned by the decode function matching the recognizer type.
    :raises ValueError: If the recognizer is not an instance of a supported type.
    """
    # Candidates are tried in this order; the first matching type wins.
    dispatch = (
        (sherpa.OfflineRecognizer, decode_offline_recognizer),
        (sherpa.OnlineRecognizer, decode_online_recognizer),
        (sherpa_onnx.OfflineRecognizer, decode_offline_recognizer_sherpa_onnx),
        (sherpa_onnx.OnlineRecognizer, decode_online_recognizer_sherpa_onnx),
    )
    for recognizer_type, decode in dispatch:
        if isinstance(recognizer, recognizer_type):
            return decode(recognizer, filename)

    raise ValueError(f"Unknown recognizer type {type(recognizer)}")
if __name__ == "__main__":
    # Running this file as a script intentionally does nothing.
    pass