File size: 3,209 Bytes
2267fac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
#!/usr/bin/python3
# -*- coding: utf-8 -*-
from enum import Enum
from functools import lru_cache
import os

import huggingface_hub
import sherpa


class EnumDecodingMethod(Enum):
    """Decoding (search) strategies that can be requested from a recognizer.

    Each member's value is the same string as its name.
    """

    greedy_search = "greedy_search"
    modified_beam_search = "modified_beam_search"


class EnumRecognizerType(Enum):
    """Recognizer implementations, keyed by their dotted ``module.Class`` name."""

    # sherpa (PyTorch-based) recognizers
    sherpa_offline_recognizer = "sherpa.OfflineRecognizer"
    sherpa_online_recognizer = "sherpa.OnlineRecognizer"
    # sherpa-onnx recognizers
    sherpa_onnx_offline_recognizer = "sherpa_onnx.OfflineRecognizer"
    sherpa_onnx_online_recognizer = "sherpa_onnx.OnlineRecognizer"


# Available models, grouped by language. The keys of each entry presumably
# correspond to download_model()'s repo_id / nn_model_filename /
# tokens_filename / sub_folder arguments -- confirm against callers.
model_map = {
    "Chinese": [
        dict(
            repo_id="csukuangfj/wenet-chinese-model",
            model_file="final.zip",
            tokens_file="units.txt",
            subfolder=".",
        ),
    ],
}


def download_model(repo_id: str,
                   nn_model_filename: str,
                   tokens_filename: str,
                   sub_folder: str,
                   local_model_dir: str,
                   ):
    """Fetch the network file and the tokens file of ``repo_id`` into ``local_model_dir``.

    Both files are requested from the same ``sub_folder`` of the repository,
    the network file first.

    Returns:
        ``(nn_model_path, tokens_path)``: the local paths returned by
        ``huggingface_hub.hf_hub_download`` for the two files.
    """
    def _fetch(remote_name):
        # Same repository, sub-folder and destination for every file.
        return huggingface_hub.hf_hub_download(
            repo_id=repo_id,
            filename=remote_name,
            subfolder=sub_folder,
            local_dir=local_model_dir,
        )

    model_path = _fetch(nn_model_filename)
    tokens_path = _fetch(tokens_filename)
    return model_path, tokens_path


@lru_cache(maxsize=10)
def load_sherpa_offline_recognizer(nn_model_file: str,
                                   tokens_file: str,
                                   sample_rate: int = 16000,
                                   num_active_paths: int = 2,
                                   decoding_method: EnumDecodingMethod = EnumDecodingMethod.greedy_search,
                                   num_mel_bins: int = 80,
                                   frame_dither: float = 0,
                                   use_gpu: bool = False,
                                   ):
    """Build a ``sherpa.OfflineRecognizer`` for the given model and tokens files.

    Results are memoized with ``lru_cache`` (up to 10 distinct argument
    combinations), so repeated calls with identical arguments reuse the same
    recognizer. Note that ``lru_cache`` keys on the call pattern, so passing an
    argument positionally vs. by keyword yields separate cache entries.

    Args:
        nn_model_file: Path to the neural-network model file.
        tokens_file: Path to the tokens file.
        sample_rate: Sampling frequency configured for the fbank front end.
        num_active_paths: Forwarded to the recognizer config as ``num_active_paths``.
        decoding_method: An ``EnumDecodingMethod`` member or its string value.
        num_mel_bins: Number of mel bins for the fbank features.
        frame_dither: Dither amount applied to the fbank frames.
        use_gpu: Forwarded to the recognizer config (previously fixed to False).

    Returns:
        The constructed ``sherpa.OfflineRecognizer``.
    """
    # The sherpa config binding expects decoding_method as a plain string;
    # passing the Enum member itself is rejected. Accept either form.
    if isinstance(decoding_method, EnumDecodingMethod):
        decoding_method = decoding_method.value

    feat_config = sherpa.FeatureConfig()
    feat_config.fbank_opts.frame_opts.samp_freq = sample_rate
    feat_config.fbank_opts.mel_opts.num_bins = num_mel_bins
    feat_config.fbank_opts.frame_opts.dither = frame_dither

    config = sherpa.OfflineRecognizerConfig(
        nn_model=nn_model_file,
        tokens=tokens_file,
        use_gpu=use_gpu,
        feat_config=feat_config,
        decoding_method=decoding_method,
        num_active_paths=num_active_paths,
    )

    return sherpa.OfflineRecognizer(config)


def load_recognizer(
                    repo_id: str,
                    nn_model_filename: str,
                    tokens_filename: str,
                    sub_folder: str,
                    local_model_dir: str,
                    recognizer_type: EnumRecognizerType,
                    decoding_method: EnumDecodingMethod = EnumDecodingMethod.greedy_search,
                    ):
    """Ensure the model files are present locally, then build a recognizer.

    The files are expected at ``local_model_dir/sub_folder/<filename>``. If
    either file is missing they are (re)downloaded via ``download_model``.

    Args:
        repo_id: Hugging Face Hub repository holding the model.
        nn_model_filename: Name of the network file inside the repository.
        tokens_filename: Name of the tokens file inside the repository.
        sub_folder: Repository sub-folder containing both files.
        local_model_dir: Local directory the files are stored in.
        recognizer_type: Which recognizer to build (member or its string value).
        decoding_method: Decoding strategy (member or its string value).

    Returns:
        The constructed recognizer.

    Raises:
        ValueError: ``recognizer_type`` / ``decoding_method`` is not a known value.
        NotImplementedError: ``recognizer_type`` is not supported yet; only
            ``EnumRecognizerType.sherpa_offline_recognizer`` is implemented.
    """
    # Calling the Enum with a member returns that member; with a string value it
    # looks the member up. This lets callers pass either form.
    recognizer_type = EnumRecognizerType(recognizer_type)
    decoding_method = EnumDecodingMethod(decoding_method)

    model_dir = os.path.join(local_model_dir, sub_folder or "")
    nn_model_file = os.path.join(model_dir, nn_model_filename)
    tokens_file = os.path.join(model_dir, tokens_filename)

    # Check the files themselves rather than the directory: an existing but
    # incomplete directory (e.g. an interrupted download) must still trigger
    # a download.
    if not (os.path.isfile(nn_model_file) and os.path.isfile(tokens_file)):
        nn_model_file, tokens_file = download_model(
            repo_id=repo_id,
            nn_model_filename=nn_model_filename,
            tokens_filename=tokens_filename,
            sub_folder=sub_folder,
            local_model_dir=local_model_dir,
        )

    # Normalize so the lru_cache on the loader sees the same key whether the
    # paths came from os.path.join or from the download.
    nn_model_file = os.path.abspath(nn_model_file)
    tokens_file = os.path.abspath(tokens_file)

    if recognizer_type is EnumRecognizerType.sherpa_offline_recognizer:
        return load_sherpa_offline_recognizer(
            nn_model_file=nn_model_file,
            tokens_file=tokens_file,
            # Pass the plain string value, which the sherpa config accepts.
            decoding_method=decoding_method.value,
        )

    raise NotImplementedError(
        f"recognizer_type {recognizer_type.value!r} is not supported yet"
    )


if __name__ == "__main__":
    # Running this module directly does nothing yet; it only defines helpers.
    ...