# This module is from [WeNet](https://github.com/wenet-e2e/wenet).

# ## Citations

# ```bibtex
# @inproceedings{yao2021wenet,
#   title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
#   author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
#   booktitle={Proc. Interspeech},
#   year={2021},
#   address={Brno, Czech Republic},
#   organization={IEEE}
# }
#
# @article{zhang2022wenet,
#   title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
#   author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
#   journal={arXiv preprint arXiv:2203.15455},
#   year={2022}
# }
# ```

import torch
from modules.wenet_extractor.transducer.joint import TransducerJoint
from modules.wenet_extractor.transducer.predictor import (
    ConvPredictor,
    EmbeddingPredictor,
    RNNPredictor,
)
from modules.wenet_extractor.transducer.transducer import Transducer
from modules.wenet_extractor.transformer.asr_model import ASRModel
from modules.wenet_extractor.transformer.cmvn import GlobalCMVN
from modules.wenet_extractor.transformer.ctc import CTC
from modules.wenet_extractor.transformer.decoder import (
    BiTransformerDecoder,
    TransformerDecoder,
)
from modules.wenet_extractor.transformer.encoder import (
    ConformerEncoder,
    TransformerEncoder,
)
from modules.wenet_extractor.squeezeformer.encoder import SqueezeformerEncoder
from modules.wenet_extractor.efficient_conformer.encoder import (
    EfficientConformerEncoder,
)
from modules.wenet_extractor.paraformer.paraformer import Paraformer
from modules.wenet_extractor.cif.predictor import Predictor
from modules.wenet_extractor.utils.cmvn import load_cmvn


def init_model(configs):
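    """Instantiate a WeNet ASR model from a parsed config dict.

    Depending on which keys are present, this builds a Transducer
    ("predictor" in configs), a Paraformer ("paraformer" in configs),
    or a joint CTC/attention ASRModel (the default).
    """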
    if configs["cmvn_file"] is not None:
        mean, istd = load_cmvn(configs["cmvn_file"], configs["is_json_cmvn"])
        global_cmvn = GlobalCMVN(
            torch.from_numpy(mean).float(), torch.from_numpy(istd).float()
        )
    else:
        global_cmvn = None

    input_dim = configs["input_dim"]
    vocab_size = configs["output_dim"]

    encoder_type = configs.get("encoder", "conformer")
    decoder_type = configs.get("decoder", "bitransformer")

    if encoder_type == "conformer":
        encoder = ConformerEncoder(
            input_dim, global_cmvn=global_cmvn, **configs["encoder_conf"]
        )
    elif encoder_type == "squeezeformer":
        encoder = SqueezeformerEncoder(
            input_dim, global_cmvn=global_cmvn, **configs["encoder_conf"]
        )
    elif encoder_type == "efficientConformer":
        encoder = EfficientConformerEncoder(
            input_dim,
            global_cmvn=global_cmvn,
            **configs["encoder_conf"],
            # Also flatten the optional nested efficient_conf into kwargs.
            **configs["encoder_conf"].get("efficient_conf", {}),
        )
    else:
        encoder = TransformerEncoder(
            input_dim, global_cmvn=global_cmvn, **configs["encoder_conf"]
        )
    if decoder_type == "transformer":
        decoder = TransformerDecoder(
            vocab_size, encoder.output_size(), **configs["decoder_conf"]
        )
    else:
        assert 0.0 < configs["model_conf"]["reverse_weight"] < 1.0
        assert configs["decoder_conf"]["r_num_blocks"] > 0
        decoder = BiTransformerDecoder(
            vocab_size, encoder.output_size(), **configs["decoder_conf"]
        )
    ctc = CTC(vocab_size, encoder.output_size())

    # Build the final model: Transducer, Paraformer, or joint CTC/attention ASRModel
    if "predictor" in configs:
        predictor_type = configs.get("predictor", "rnn")
        if predictor_type == "rnn":
            predictor = RNNPredictor(vocab_size, **configs["predictor_conf"])
        elif predictor_type == "embedding":
            predictor = EmbeddingPredictor(vocab_size, **configs["predictor_conf"])
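            # Embedding and conv predictors expose embed_size as their output
            # width; record it as output_size so the joint config below can
            # read pred_output_size uniformly.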
            configs["predictor_conf"]["output_size"] = configs["predictor_conf"][
                "embed_size"
            ]
        elif predictor_type == "conv":
            predictor = ConvPredictor(vocab_size, **configs["predictor_conf"])
            configs["predictor_conf"]["output_size"] = configs["predictor_conf"][
                "embed_size"
            ]
        else:
            raise NotImplementedError(
                "only rnn, embedding and conv predictor types are supported now"
            )
        configs["joint_conf"]["enc_output_size"] = configs["encoder_conf"][
            "output_size"
        ]
        configs["joint_conf"]["pred_output_size"] = configs["predictor_conf"][
            "output_size"
        ]
        joint = TransducerJoint(vocab_size, **configs["joint_conf"])
        model = Transducer(
            vocab_size=vocab_size,
            blank=0,  # index of the blank label used by the transducer loss
            predictor=predictor,
            encoder=encoder,
            attention_decoder=decoder,
            joint=joint,
            ctc=ctc,
            **configs["model_conf"],
        )
    elif "paraformer" in configs:
        predictor = Predictor(**configs["cif_predictor_conf"])
        model = Paraformer(
            vocab_size=vocab_size,
            encoder=encoder,
            decoder=decoder,
            ctc=ctc,
            predictor=predictor,
            **configs["model_conf"],
        )
    else:
        model = ASRModel(
            vocab_size=vocab_size,
            encoder=encoder,
            decoder=decoder,
            ctc=ctc,
            lfmmi_dir=configs.get("lfmmi_dir", ""),
            **configs["model_conf"],
        )
    return model
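

# A minimal usage sketch (not part of the original source): the config below
# is a hypothetical, hand-written stand-in for a WeNet-style YAML config, and
# the layer sizes / vocab size are illustrative assumptions, not a tested
# recipe. The "transformer" decoder is chosen so the bitransformer asserts on
# reverse_weight and r_num_blocks do not apply.
if __name__ == "__main__":
    example_configs = {
        "cmvn_file": None,  # skip GlobalCMVN feature normalization here
        "is_json_cmvn": True,
        "input_dim": 80,  # e.g. 80-dim fbank features
        "output_dim": 4233,  # output vocabulary size
        "encoder": "conformer",
        "encoder_conf": {
            "output_size": 256,
            "attention_heads": 4,
            "linear_units": 2048,
            "num_blocks": 12,
        },
        "decoder": "transformer",
        "decoder_conf": {
            "attention_heads": 4,
            "linear_units": 2048,
            "num_blocks": 6,
        },
        # No "predictor" or "paraformer" key, so a joint CTC/attention
        # ASRModel is built.
        "model_conf": {"ctc_weight": 0.3, "lsm_weight": 0.1},
    }
    model = init_model(example_configs)
    print(type(model).__name__)  # expected: ASRModel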