#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""Download sherpa ASR models from the Hugging Face Hub and build recognizers."""
from enum import Enum
from functools import lru_cache
import os
from typing import Tuple

import huggingface_hub
import sherpa


class EnumDecodingMethod(Enum):
    """Decoding strategies that can be requested from a recognizer."""

    greedy_search = "greedy_search"
    modified_beam_search = "modified_beam_search"


class EnumRecognizerType(Enum):
    """Recognizer back-ends, identified by the dotted name of their class."""

    sherpa_offline_recognizer = "sherpa.OfflineRecognizer"
    sherpa_online_recognizer = "sherpa.OnlineRecognizer"
    sherpa_onnx_offline_recognizer = "sherpa_onnx.OfflineRecognizer"
    sherpa_onnx_online_recognizer = "sherpa_onnx.OnlineRecognizer"


# Known models, grouped by language. Each entry names the Hub repository and
# the files inside it that a recognizer needs.
model_map = {
    "Chinese": [
        {
            "repo_id": "csukuangfj/wenet-chinese-model",
            "model_file": "final.zip",
            "tokens_file": "units.txt",
            "subfolder": ".",
        }
    ]
}


def download_model(repo_id: str,
                   nn_model_filename: str,
                   tokens_filename: str,
                   sub_folder: str,
                   local_model_dir: str,
                   ) -> Tuple[str, str]:
    """Download the model file and the tokens file of one Hub repository.

    Args:
        repo_id: Hugging Face Hub repository id.
        nn_model_filename: Name of the neural-network model file in the repo.
        tokens_filename: Name of the tokens file in the repo.
        sub_folder: Folder inside the repo that holds both files.
        local_model_dir: Local directory handed to ``hf_hub_download`` as
            ``local_dir``.

    Returns:
        ``(nn_model_path, tokens_path)`` exactly as returned by
        ``huggingface_hub.hf_hub_download``.
    """
    nn_model_filename = huggingface_hub.hf_hub_download(
        repo_id=repo_id,
        filename=nn_model_filename,
        subfolder=sub_folder,
        local_dir=local_model_dir,
    )
    tokens_filename = huggingface_hub.hf_hub_download(
        repo_id=repo_id,
        filename=tokens_filename,
        subfolder=sub_folder,
        local_dir=local_model_dir,
    )
    return nn_model_filename, tokens_filename


@lru_cache(maxsize=10)
def load_sherpa_offline_recognizer(nn_model_file: str,
                                   tokens_file: str,
                                   sample_rate: int = 16000,
                                   num_active_paths: int = 2,
                                   decoding_method: EnumDecodingMethod = EnumDecodingMethod.greedy_search,
                                   num_mel_bins: int = 80,
                                   frame_dither: int = 0,
                                   ):
    """Build a CPU ``sherpa.OfflineRecognizer``; results are memoized.

    The ``lru_cache`` keeps up to 10 recognizers keyed on the full argument
    tuple, so repeated calls with the same arguments reuse one instance.

    Args:
        nn_model_file: Path to the neural-network model file.
        tokens_file: Path to the tokens file.
        sample_rate: Sampling frequency written to the fbank frame options.
        num_active_paths: Forwarded to ``OfflineRecognizerConfig``.
        decoding_method: An ``EnumDecodingMethod`` member or its string value.
        num_mel_bins: Number of mel bins written to the fbank mel options.
        frame_dither: Dither value written to the fbank frame options.

    Returns:
        The constructed ``sherpa.OfflineRecognizer``.

    Raises:
        ValueError: If ``decoding_method`` is not a known decoding method.
    """
    feat_config = sherpa.FeatureConfig()
    feat_config.fbank_opts.frame_opts.samp_freq = sample_rate
    feat_config.fbank_opts.mel_opts.num_bins = num_mel_bins
    feat_config.fbank_opts.frame_opts.dither = frame_dither

    # The sherpa config takes the decoding method as a plain string, so the
    # enum member is unwrapped here. Going through the enum constructor also
    # accepts a raw string value and rejects unknown methods early.
    decoding_method_name = EnumDecodingMethod(decoding_method).value

    config = sherpa.OfflineRecognizerConfig(
        nn_model=nn_model_file,
        tokens=tokens_file,
        use_gpu=False,
        feat_config=feat_config,
        decoding_method=decoding_method_name,
        num_active_paths=num_active_paths,
    )
    recognizer = sherpa.OfflineRecognizer(config)
    return recognizer


def load_recognizer(repo_id: str,
                    nn_model_filename: str,
                    tokens_filename: str,
                    sub_folder: str,
                    local_model_dir: str,
                    recognizer_type: EnumRecognizerType,
                    decoding_method: EnumDecodingMethod = EnumDecodingMethod.greedy_search,
                    ):
    """Make sure the model files are available locally, then build a recognizer.

    Args:
        repo_id: Hugging Face Hub repository id.
        nn_model_filename: Name of the neural-network model file in the repo.
        tokens_filename: Name of the tokens file in the repo.
        sub_folder: Folder inside the repo that holds both files.
        local_model_dir: Local directory the files are stored in.
        recognizer_type: An ``EnumRecognizerType`` member or its string value.
        decoding_method: An ``EnumDecodingMethod`` member or its string value.

    Returns:
        The recognizer built for ``recognizer_type``.

    Raises:
        ValueError: If ``recognizer_type`` or ``decoding_method`` is unknown.
        NotImplementedError: If ``recognizer_type`` has no loader in this
            module (only ``sherpa.OfflineRecognizer`` is supported).
    """
    # Expected on-disk locations: hf_hub_download replicates the repository
    # layout below ``local_dir``. If that assumption is ever wrong the files
    # are simply not found here and the paths returned by download_model are
    # used instead, so this check can only cost an extra Hub call.
    nn_model_file = os.path.normpath(
        os.path.join(local_model_dir, sub_folder or "", nn_model_filename)
    )
    tokens_file = os.path.normpath(
        os.path.join(local_model_dir, sub_folder or "", tokens_filename)
    )

    # Check the files themselves rather than the directory, so that a
    # partially populated directory still triggers a download.
    if not (os.path.isfile(nn_model_file) and os.path.isfile(tokens_file)):
        nn_model_file, tokens_file = download_model(
            repo_id=repo_id,
            nn_model_filename=nn_model_filename,
            tokens_filename=tokens_filename,
            sub_folder=sub_folder,
            local_model_dir=local_model_dir,
        )

    recognizer_type = EnumRecognizerType(recognizer_type)
    if recognizer_type is EnumRecognizerType.sherpa_offline_recognizer:
        return load_sherpa_offline_recognizer(
            nn_model_file=nn_model_file,
            tokens_file=tokens_file,
            decoding_method=EnumDecodingMethod(decoding_method),
        )

    raise NotImplementedError(
        "recognizer_type not supported: {}".format(recognizer_type.value)
    )


if __name__ == "__main__":
    pass