Spaces:
Runtime error
Runtime error
# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
# ## Citations
# ```bibtex
# @inproceedings{yao2021wenet,
# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
# booktitle={Proc. Interspeech},
# year={2021},
# address={Brno, Czech Republic},
# organization={IEEE}
# }
# @article{zhang2022wenet,
# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
# journal={arXiv preprint arXiv:2203.15455},
# year={2022}
# }
#
import torch | |
from modules.wenet_extractor.transducer.joint import TransducerJoint | |
from modules.wenet_extractor.transducer.predictor import ( | |
ConvPredictor, | |
EmbeddingPredictor, | |
RNNPredictor, | |
) | |
from modules.wenet_extractor.transducer.transducer import Transducer | |
from modules.wenet_extractor.transformer.asr_model import ASRModel | |
from modules.wenet_extractor.transformer.cmvn import GlobalCMVN | |
from modules.wenet_extractor.transformer.ctc import CTC | |
from modules.wenet_extractor.transformer.decoder import ( | |
BiTransformerDecoder, | |
TransformerDecoder, | |
) | |
from modules.wenet_extractor.transformer.encoder import ( | |
ConformerEncoder, | |
TransformerEncoder, | |
) | |
from modules.wenet_extractor.squeezeformer.encoder import SqueezeformerEncoder | |
from modules.wenet_extractor.efficient_conformer.encoder import ( | |
EfficientConformerEncoder, | |
) | |
from modules.wenet_extractor.paraformer.paraformer import Paraformer | |
from modules.wenet_extractor.cif.predictor import Predictor | |
from modules.wenet_extractor.utils.cmvn import load_cmvn | |
def _build_global_cmvn(configs):
    """Return a ``GlobalCMVN`` instance built from ``configs["cmvn_file"]``.

    Returns None (global CMVN disabled) when ``cmvn_file`` is absent or None.
    When a file is given, ``configs["is_json_cmvn"]`` must also be present.
    """
    cmvn_file = configs.get("cmvn_file")
    if cmvn_file is None:
        return None
    mean, istd = load_cmvn(cmvn_file, configs["is_json_cmvn"])
    return GlobalCMVN(
        torch.from_numpy(mean).float(), torch.from_numpy(istd).float()
    )


def _build_encoder(configs, input_dim, global_cmvn):
    """Instantiate the encoder selected by ``configs["encoder"]``.

    Recognised values are ``"conformer"`` (the default), ``"squeezeformer"``
    and ``"efficientConformer"``; any other value falls back to
    ``TransformerEncoder``. ``configs["encoder_conf"]`` is forwarded as
    keyword arguments.
    """
    encoder_type = configs.get("encoder", "conformer")
    encoder_conf = configs["encoder_conf"]
    if encoder_type == "conformer":
        return ConformerEncoder(input_dim, global_cmvn=global_cmvn, **encoder_conf)
    if encoder_type == "squeezeformer":
        return SqueezeformerEncoder(
            input_dim, global_cmvn=global_cmvn, **encoder_conf
        )
    if encoder_type == "efficientConformer":
        # The optional nested ``efficient_conf`` section is forwarded both as
        # its own keyword (it is a key of encoder_conf) and flattened into
        # individual keyword arguments.
        return EfficientConformerEncoder(
            input_dim,
            global_cmvn=global_cmvn,
            **encoder_conf,
            **encoder_conf.get("efficient_conf", {}),
        )
    return TransformerEncoder(input_dim, global_cmvn=global_cmvn, **encoder_conf)


def _build_decoder(configs, vocab_size, encoder_output_size):
    """Instantiate the attention decoder selected by ``configs["decoder"]``.

    ``"transformer"`` builds a ``TransformerDecoder``. Any other value
    (default ``"bitransformer"``) builds a ``BiTransformerDecoder``, which
    requires ``0 < model_conf["reverse_weight"] < 1`` and
    ``decoder_conf["r_num_blocks"] > 0``.

    Raises:
        AssertionError: if the bitransformer constraints above are violated.
    """
    decoder_type = configs.get("decoder", "bitransformer")
    decoder_conf = configs["decoder_conf"]
    if decoder_type == "transformer":
        return TransformerDecoder(vocab_size, encoder_output_size, **decoder_conf)
    reverse_weight = configs["model_conf"]["reverse_weight"]
    assert 0.0 < reverse_weight < 1.0, (
        f"bitransformer decoder requires 0 < reverse_weight < 1, "
        f"got {reverse_weight}"
    )
    assert decoder_conf["r_num_blocks"] > 0, (
        "bitransformer decoder requires decoder_conf['r_num_blocks'] > 0"
    )
    return BiTransformerDecoder(vocab_size, encoder_output_size, **decoder_conf)


def _build_transducer_predictor(configs, vocab_size):
    """Build the transducer predictor and report its output width.

    Returns:
        A ``(predictor, output_size)`` tuple. ``output_size`` is
        ``predictor_conf["output_size"]`` for ``"rnn"`` and
        ``predictor_conf["embed_size"]`` for ``"embedding"`` / ``"conv"``.

    Raises:
        NotImplementedError: for any other ``configs["predictor"]`` value.
    """
    predictor_type = configs["predictor"]
    predictor_conf = configs["predictor_conf"]
    if predictor_type == "rnn":
        predictor = RNNPredictor(vocab_size, **predictor_conf)
        return predictor, predictor_conf["output_size"]
    if predictor_type == "embedding":
        predictor = EmbeddingPredictor(vocab_size, **predictor_conf)
        return predictor, predictor_conf["embed_size"]
    if predictor_type == "conv":
        predictor = ConvPredictor(vocab_size, **predictor_conf)
        return predictor, predictor_conf["embed_size"]
    raise NotImplementedError("only rnn, embedding and conv type support now")


def init_model(configs):
    """Build an ASR model from a configuration dictionary.

    The model family is selected by which top-level keys are present:

    * ``"predictor"`` -> ``Transducer`` (encoder, predictor and joint
      network, plus the attention decoder and CTC head);
    * ``"paraformer"`` -> ``Paraformer`` with a CIF ``Predictor`` built from
      ``configs["cif_predictor_conf"]``;
    * otherwise -> ``ASRModel`` (joint CTC/attention).

    Args:
        configs: Configuration mapping. Keys read here:

            * ``input_dim`` and ``output_dim`` (used as the vocabulary size);
            * ``encoder``/``encoder_conf`` and ``decoder``/``decoder_conf``;
            * ``model_conf``;
            * optionally ``cmvn_file``/``is_json_cmvn``;
            * depending on the model family, ``predictor``/``predictor_conf``/
              ``joint_conf``, ``cif_predictor_conf`` or ``lfmmi_dir``.

            The mapping is NOT modified.

    Returns:
        The constructed model instance.

    Raises:
        KeyError: if a required configuration key is missing.
        NotImplementedError: if ``configs["predictor"]`` is not one of
            ``"rnn"``, ``"embedding"`` or ``"conv"``.
        AssertionError: if the bitransformer decoder settings are invalid.
    """
    global_cmvn = _build_global_cmvn(configs)
    input_dim = configs["input_dim"]
    vocab_size = configs["output_dim"]

    encoder = _build_encoder(configs, input_dim, global_cmvn)
    # Use the encoder's actual output width everywhere (decoder, CTC and
    # joint) so that all heads agree. This holds even when encoder_conf
    # omits "output_size" and the encoder falls back to its own default.
    encoder_output_size = encoder.output_size()
    decoder = _build_decoder(configs, vocab_size, encoder_output_size)
    ctc = CTC(vocab_size, encoder_output_size)

    # Init joint CTC/Attention or Transducer model
    if "predictor" in configs:
        predictor, pred_output_size = _build_transducer_predictor(
            configs, vocab_size
        )
        # Build the joint kwargs in a fresh dict instead of writing into the
        # caller's config. Mutating configs would leak keys such as an injected
        # predictor_conf["output_size"] into later calls with the same dict,
        # and the embedding/conv predictors may not accept that keyword.
        joint_conf = {
            **configs["joint_conf"],
            "enc_output_size": encoder_output_size,
            "pred_output_size": pred_output_size,
        }
        joint = TransducerJoint(vocab_size, **joint_conf)
        return Transducer(
            vocab_size=vocab_size,
            blank=0,
            predictor=predictor,
            encoder=encoder,
            attention_decoder=decoder,
            joint=joint,
            ctc=ctc,
            **configs["model_conf"],
        )

    if "paraformer" in configs:
        predictor = Predictor(**configs["cif_predictor_conf"])
        return Paraformer(
            vocab_size=vocab_size,
            encoder=encoder,
            decoder=decoder,
            ctc=ctc,
            predictor=predictor,
            **configs["model_conf"],
        )

    return ASRModel(
        vocab_size=vocab_size,
        encoder=encoder,
        decoder=decoder,
        ctc=ctc,
        lfmmi_dir=configs.get("lfmmi_dir", ""),
        **configs["model_conf"],
    )