# asr/models.py: Hugging Face Space file by HoneyTian
# (commit 2267fac, message "update", 3.21 kB).
#!/usr/bin/python3
# -*- coding: utf-8 -*-
from enum import Enum
from functools import lru_cache
import os
import huggingface_hub
import sherpa
class EnumDecodingMethod(Enum):
    """Names of the decoding strategies this module can request.

    Each member's value is the bare method name as a string.
    """

    # Keep declaration order: iteration over the enum follows it.
    greedy_search = "greedy_search"
    modified_beam_search = "modified_beam_search"
class EnumRecognizerType(Enum):
    """Recognizer backends, keyed by the dotted class name each one refers to."""

    # sherpa (TorchScript-based package) backends
    sherpa_offline_recognizer = "sherpa.OfflineRecognizer"
    sherpa_online_recognizer = "sherpa.OnlineRecognizer"

    # sherpa_onnx backends
    sherpa_onnx_offline_recognizer = "sherpa_onnx.OfflineRecognizer"
    sherpa_onnx_online_recognizer = "sherpa_onnx.OnlineRecognizer"
# Language name -> list of model descriptors. Each descriptor names the Hub
# repository plus the model file, token file and subfolder inside that repo.
_chinese_wenet_model = dict(
    repo_id="csukuangfj/wenet-chinese-model",
    model_file="final.zip",
    tokens_file="units.txt",
    subfolder=".",
)
model_map = dict(Chinese=[_chinese_wenet_model])
def download_model(repo_id: str,
                   nn_model_filename: str,
                   tokens_filename: str,
                   sub_folder: str,
                   local_model_dir: str,
                   ):
    """Fetch the model file and the tokens file of a Hub repo into a local dir.

    Both files are requested with ``huggingface_hub.hf_hub_download`` using the
    same repo, subfolder and ``local_dir``. The model file is fetched first.

    :param repo_id: Hugging Face Hub repository id.
    :param nn_model_filename: model file name inside ``sub_folder``.
    :param tokens_filename: tokens file name inside ``sub_folder``.
    :param sub_folder: subfolder of the repo that holds both files.
    :param local_model_dir: directory the files are downloaded into.
    :return: ``(model_path, tokens_path)`` as returned by ``hf_hub_download``.
    """
    def _fetch(file_name: str) -> str:
        # Every argument except the file name is shared by both downloads.
        return huggingface_hub.hf_hub_download(
            repo_id=repo_id,
            filename=file_name,
            subfolder=sub_folder,
            local_dir=local_model_dir,
        )

    model_path = _fetch(nn_model_filename)
    tokens_path = _fetch(tokens_filename)
    return model_path, tokens_path
@lru_cache(maxsize=10)
def load_sherpa_offline_recognizer(nn_model_file: str,
                                   tokens_file: str,
                                   sample_rate: int = 16000,
                                   num_active_paths: int = 2,
                                   decoding_method: EnumDecodingMethod = EnumDecodingMethod.greedy_search,
                                   num_mel_bins: int = 80,
                                   frame_dither: float = 0,
                                   use_gpu: bool = False,
                                   ):
    """Build a ``sherpa.OfflineRecognizer`` from a model file and a tokens file.

    Results are memoized with ``lru_cache(maxsize=10)``. Repeated calls with the
    same arguments return the same recognizer object, so all arguments must be
    hashable.

    :param nn_model_file: path to the neural-network model file.
    :param tokens_file: path to the tokens file.
    :param sample_rate: sampling frequency written into the fbank frame options.
    :param num_active_paths: forwarded to ``sherpa.OfflineRecognizerConfig``.
        Presumably only relevant for ``modified_beam_search``; TODO confirm.
    :param decoding_method: an ``EnumDecodingMethod`` member or its string value.
    :param num_mel_bins: number of mel bins for the fbank features.
    :param frame_dither: dither amount applied to fbank frames.
    :param use_gpu: forwarded to ``sherpa.OfflineRecognizerConfig``. Defaults to
        ``False``, the value this function previously hard-coded.
    :return: the constructed ``sherpa.OfflineRecognizer``.
    """
    # sherpa's config takes the decoding method name as a plain string
    # ("greedy_search" / "modified_beam_search"), not a Python Enum member.
    # The default argument is an Enum member, so it must be unwrapped here.
    if isinstance(decoding_method, EnumDecodingMethod):
        decoding_method = decoding_method.value

    feat_config = sherpa.FeatureConfig()
    feat_config.fbank_opts.frame_opts.samp_freq = sample_rate
    feat_config.fbank_opts.mel_opts.num_bins = num_mel_bins
    feat_config.fbank_opts.frame_opts.dither = frame_dither

    config = sherpa.OfflineRecognizerConfig(
        nn_model=nn_model_file,
        tokens=tokens_file,
        use_gpu=use_gpu,
        feat_config=feat_config,
        decoding_method=decoding_method,
        num_active_paths=num_active_paths,
    )
    return sherpa.OfflineRecognizer(config)
def load_recognizer(
        repo_id: str,
        nn_model_filename: str,
        tokens_filename: str,
        sub_folder: str,
        local_model_dir: str,
        recognizer_type: EnumRecognizerType,
        decoding_method: EnumDecodingMethod = EnumDecodingMethod.greedy_search,
):
    """Make sure the model files are available locally, then build a recognizer.

    The files are looked up at ``local_model_dir/sub_folder/<file name>``, the
    layout ``download_model`` writes into. If either file is missing, both are
    (re)downloaded from ``repo_id``.

    :param repo_id: Hugging Face Hub repository id.
    :param nn_model_filename: model file name inside ``sub_folder``.
    :param tokens_filename: tokens file name inside ``sub_folder``.
    :param sub_folder: subfolder of the repo that holds both files.
    :param local_model_dir: local directory used as the download target.
    :param recognizer_type: an ``EnumRecognizerType`` member or its string
        value, e.g. ``"sherpa.OfflineRecognizer"``.
    :param decoding_method: decoding method passed to the recognizer loader.
    :return: the constructed recognizer.
    :raises ValueError: if ``recognizer_type`` is not an ``EnumRecognizerType``
        member or value.
    :raises NotImplementedError: if no loader exists yet for
        ``recognizer_type``.
    """
    # Accept either an enum member or its string value.
    recognizer_type = EnumRecognizerType(recognizer_type)

    model_dir = os.path.join(local_model_dir, sub_folder or "")
    nn_model_file = os.path.join(model_dir, nn_model_filename)
    tokens_file = os.path.join(model_dir, tokens_filename)

    # Check the files themselves rather than only the directory, so a
    # directory left behind by an interrupted download is still repaired.
    if not (os.path.isfile(nn_model_file) and os.path.isfile(tokens_file)):
        nn_model_file, tokens_file = download_model(
            repo_id=repo_id,
            nn_model_filename=nn_model_filename,
            tokens_filename=tokens_filename,
            sub_folder=sub_folder,
            local_model_dir=local_model_dir,
        )

    if recognizer_type is EnumRecognizerType.sherpa_offline_recognizer:
        return load_sherpa_offline_recognizer(
            nn_model_file=nn_model_file,
            tokens_file=tokens_file,
            decoding_method=decoding_method,
        )

    raise NotImplementedError(
        "recognizer_type not supported: {}".format(recognizer_type.value)
    )
if __name__ == "__main__":
    # Import-only module: running it directly intentionally does nothing.
    ...