Spaces:
Running
Running
File size: 3,209 Bytes
2267fac |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 |
#!/usr/bin/python3
# -*- coding: utf-8 -*-
from enum import Enum
from functools import lru_cache
import os
import huggingface_hub
import sherpa
class EnumDecodingMethod(Enum):
    """Decoding strategies a recognizer can be configured with.

    Each member's value is the method name as a plain string.
    """

    # Keep the declaration order: iteration over the enum follows it.
    greedy_search = "greedy_search"
    modified_beam_search = "modified_beam_search"
class EnumRecognizerType(Enum):
    """Kinds of recognizer this module can be asked for.

    Each value is a dotted ``<package>.<ClassName>`` string naming the
    recognizer class of that kind.
    """

    # k2-fsa "sherpa" package variants.
    sherpa_offline_recognizer = "sherpa.OfflineRecognizer"
    sherpa_online_recognizer = "sherpa.OnlineRecognizer"
    # "sherpa_onnx" package variants.
    sherpa_onnx_offline_recognizer = "sherpa_onnx.OfflineRecognizer"
    sherpa_onnx_online_recognizer = "sherpa_onnx.OnlineRecognizer"
# Model catalogue keyed by language name (only "Chinese" so far). Each entry
# names a Hugging Face repo and the model / tokens files inside its "subfolder".
model_map = dict(
    Chinese=[
        dict(
            repo_id="csukuangfj/wenet-chinese-model",
            model_file="final.zip",
            tokens_file="units.txt",
            subfolder=".",
        ),
    ],
)
def download_model(repo_id: str,
                   nn_model_filename: str,
                   tokens_filename: str,
                   sub_folder: str,
                   local_model_dir: str,
                   ):
    """Fetch a model file and its tokens file from the Hugging Face Hub.

    Both files are requested from ``sub_folder`` of ``repo_id`` and stored
    under ``local_model_dir``. The model file is fetched first.

    Returns:
        ``(model_path, tokens_path)``: the local paths reported by
        ``huggingface_hub.hf_hub_download`` for the two files.
    """
    def fetch(filename: str) -> str:
        # Every file of this model shares repo, sub-folder and destination.
        return huggingface_hub.hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            subfolder=sub_folder,
            local_dir=local_model_dir,
        )

    # Generator is consumed left to right: model first, then tokens.
    return tuple(fetch(name) for name in (nn_model_filename, tokens_filename))
@lru_cache(maxsize=10)
def load_sherpa_offline_recognizer(nn_model_file: str,
                                   tokens_file: str,
                                   sample_rate: int = 16000,
                                   num_active_paths: int = 2,
                                   decoding_method: EnumDecodingMethod = EnumDecodingMethod.greedy_search,
                                   num_mel_bins: int = 80,
                                   frame_dither: float = 0,
                                   ):
    """Build a CPU-only ``sherpa.OfflineRecognizer`` for the given model files.

    Results are memoised by ``lru_cache`` (at most 10 distinct argument
    combinations), so repeated calls with identical arguments return the
    already-built recognizer instead of loading the model again.

    Args:
        nn_model_file: path to the neural-network model file.
        tokens_file: path to the tokens file.
        sample_rate: sampling frequency written into the fbank frame options.
        num_active_paths: forwarded to ``sherpa.OfflineRecognizerConfig``
            (presumably only relevant for beam search; confirm with sherpa docs).
        decoding_method: an ``EnumDecodingMethod`` member or its string value.
        num_mel_bins: number of mel bins of the fbank features.
        frame_dither: dither applied during fbank frame extraction.

    Returns:
        The constructed ``sherpa.OfflineRecognizer``.
    """
    # The sherpa config field is a plain string; a Python Enum member (such as
    # this function's own default) must be unwrapped to its value first.
    if isinstance(decoding_method, Enum):
        decoding_method = decoding_method.value

    feat_config = sherpa.FeatureConfig()
    feat_config.fbank_opts.frame_opts.samp_freq = sample_rate
    feat_config.fbank_opts.mel_opts.num_bins = num_mel_bins
    feat_config.fbank_opts.frame_opts.dither = frame_dither

    config = sherpa.OfflineRecognizerConfig(
        nn_model=nn_model_file,
        tokens=tokens_file,
        use_gpu=False,
        feat_config=feat_config,
        decoding_method=decoding_method,
        num_active_paths=num_active_paths,
    )
    recognizer = sherpa.OfflineRecognizer(config)
    return recognizer
def load_recognizer(
    repo_id: str,
    nn_model_filename: str,
    tokens_filename: str,
    sub_folder: str,
    local_model_dir: str,
    recognizer_type: EnumRecognizerType,
    decoding_method: EnumDecodingMethod = EnumDecodingMethod.greedy_search,
):
    """Ensure the model files are available locally, then build a recognizer.

    The model and tokens files are looked up under
    ``<local_model_dir>/<sub_folder>/``. They are downloaded with
    ``download_model`` when either file is missing.

    Args:
        repo_id: Hugging Face repo holding the model.
        nn_model_filename: model file name inside ``sub_folder`` of the repo.
        tokens_filename: tokens file name inside ``sub_folder`` of the repo.
        sub_folder: sub-folder of the repo containing both files.
        local_model_dir: local directory the repo files are stored under.
        recognizer_type: an ``EnumRecognizerType`` member or its string value.
        decoding_method: an ``EnumDecodingMethod`` member or its string value.

    Returns:
        The recognizer built by ``load_sherpa_offline_recognizer``.

    Raises:
        ValueError: if ``recognizer_type`` is not a valid ``EnumRecognizerType``.
        NotImplementedError: for recognizer types other than
            ``sherpa.OfflineRecognizer``, which this module cannot build yet.
    """
    # Accept either an enum member or its string value. Validate before any
    # download so an unsupported type fails fast.
    recognizer_type = EnumRecognizerType(recognizer_type)
    if recognizer_type is not EnumRecognizerType.sherpa_offline_recognizer:
        raise NotImplementedError(
            f"recognizer_type {recognizer_type.value!r} is not supported yet"
        )

    # hf_hub_download(..., local_dir=...) keeps the repo layout, so files land
    # in <local_model_dir>/<sub_folder>/; normpath folds a "." sub-folder away.
    model_dir = os.path.normpath(os.path.join(local_model_dir, sub_folder or ""))
    nn_model_file = os.path.join(model_dir, nn_model_filename)
    tokens_file = os.path.join(model_dir, tokens_filename)

    # Check the files themselves rather than just the directory. An existing
    # but incomplete directory would otherwise never be repaired.
    if not (os.path.isfile(nn_model_file) and os.path.isfile(tokens_file)):
        nn_model_file, tokens_file = download_model(
            repo_id=repo_id,
            nn_model_filename=nn_model_filename,
            tokens_filename=tokens_filename,
            sub_folder=sub_folder,
            local_model_dir=local_model_dir,
        )

    # The sherpa config takes the decoding method as a plain string.
    if isinstance(decoding_method, Enum):
        decoding_method = decoding_method.value

    return load_sherpa_offline_recognizer(
        nn_model_file=nn_model_file,
        tokens_file=tokens_file,
        decoding_method=decoding_method,
    )
if __name__ == "__main__":
    # No command-line entry point yet: running the file directly does nothing
    # beyond executing the module-level definitions above.
    ...
|