Hecheng0625's picture
Upload 409 files
c968fc3 verified
raw
history blame
5.7 kB
# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
# ## Citations
# ```bibtex
# @inproceedings{yao2021wenet,
# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
# booktitle={Proc. Interspeech},
# year={2021},
# address={Brno, Czech Republic },
# organization={IEEE}
# }
# @article{zhang2022wenet,
# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
# journal={arXiv preprint arXiv:2203.15455},
# year={2022}
# }
#
import torch
from modules.wenet_extractor.transducer.joint import TransducerJoint
from modules.wenet_extractor.transducer.predictor import (
ConvPredictor,
EmbeddingPredictor,
RNNPredictor,
)
from modules.wenet_extractor.transducer.transducer import Transducer
from modules.wenet_extractor.transformer.asr_model import ASRModel
from modules.wenet_extractor.transformer.cmvn import GlobalCMVN
from modules.wenet_extractor.transformer.ctc import CTC
from modules.wenet_extractor.transformer.decoder import (
BiTransformerDecoder,
TransformerDecoder,
)
from modules.wenet_extractor.transformer.encoder import (
ConformerEncoder,
TransformerEncoder,
)
from modules.wenet_extractor.squeezeformer.encoder import SqueezeformerEncoder
from modules.wenet_extractor.efficient_conformer.encoder import (
EfficientConformerEncoder,
)
from modules.wenet_extractor.paraformer.paraformer import Paraformer
from modules.wenet_extractor.cif.predictor import Predictor
from modules.wenet_extractor.utils.cmvn import load_cmvn
def _build_global_cmvn(configs):
    """Build the global CMVN layer described by ``configs``, if any.

    Returns ``None`` when ``configs["cmvn_file"]`` is missing or ``None``.
    Otherwise loads ``(mean, istd)`` via ``load_cmvn`` and wraps them, as
    float tensors, in a ``GlobalCMVN`` module.
    """
    cmvn_file = configs.get("cmvn_file")
    if cmvn_file is None:
        return None
    mean, istd = load_cmvn(cmvn_file, configs["is_json_cmvn"])
    return GlobalCMVN(
        torch.from_numpy(mean).float(), torch.from_numpy(istd).float()
    )


def _build_encoder(configs, input_dim, global_cmvn):
    """Instantiate the encoder named by ``configs["encoder"]``.

    Recognised values are "conformer" (the default), "squeezeformer" and
    "efficientConformer". Any other value builds a ``TransformerEncoder``.
    ``configs["encoder_conf"]`` is forwarded as keyword arguments.
    """
    encoder_type = configs.get("encoder", "conformer")
    encoder_conf = configs["encoder_conf"]
    if encoder_type == "conformer":
        return ConformerEncoder(input_dim, global_cmvn=global_cmvn, **encoder_conf)
    if encoder_type == "squeezeformer":
        return SqueezeformerEncoder(
            input_dim, global_cmvn=global_cmvn, **encoder_conf
        )
    if encoder_type == "efficientConformer":
        # The nested "efficient_conf" dict is passed twice: once under its own
        # key (it is part of encoder_conf) and once flattened into keyword
        # arguments. This mirrors the original call.
        return EfficientConformerEncoder(
            input_dim,
            global_cmvn=global_cmvn,
            **encoder_conf,
            **(
                encoder_conf["efficient_conf"]
                if "efficient_conf" in encoder_conf
                else {}
            ),
        )
    return TransformerEncoder(input_dim, global_cmvn=global_cmvn, **encoder_conf)


def _build_decoder(configs, vocab_size, encoder_output_size):
    """Instantiate the attention decoder named by ``configs["decoder"]``.

    "transformer" builds a ``TransformerDecoder``. Any other value, including
    the default "bitransformer", builds a ``BiTransformerDecoder``.

    Raises:
        AssertionError: for a bidirectional decoder, when
            ``model_conf["reverse_weight"]`` is not strictly inside (0, 1) or
            ``decoder_conf["r_num_blocks"]`` is not positive.
    """
    decoder_type = configs.get("decoder", "bitransformer")
    decoder_conf = configs["decoder_conf"]
    if decoder_type == "transformer":
        return TransformerDecoder(vocab_size, encoder_output_size, **decoder_conf)
    # NOTE(review): these are config validations made with `assert`, so they
    # are skipped when Python runs with -O.
    reverse_weight = configs["model_conf"]["reverse_weight"]
    assert 0.0 < reverse_weight < 1.0, (
        f"bitransformer decoder needs 0 < reverse_weight < 1, got {reverse_weight!r}"
    )
    r_num_blocks = decoder_conf["r_num_blocks"]
    assert r_num_blocks > 0, (
        f"bitransformer decoder needs r_num_blocks > 0, got {r_num_blocks!r}"
    )
    return BiTransformerDecoder(vocab_size, encoder_output_size, **decoder_conf)


def _build_transducer_predictor(configs, vocab_size):
    """Instantiate the transducer predictor named by ``configs["predictor"]``.

    For the "embedding" and "conv" predictors, this also sets
    ``configs["predictor_conf"]["output_size"]`` to its "embed_size" value,
    modifying the dict in place. The joint network reads that key later.

    Raises:
        NotImplementedError: for predictor types other than "rnn",
            "embedding" and "conv".
    """
    predictor_type = configs["predictor"]
    predictor_conf = configs["predictor_conf"]
    if predictor_type == "rnn":
        return RNNPredictor(vocab_size, **predictor_conf)
    if predictor_type == "embedding":
        predictor = EmbeddingPredictor(vocab_size, **predictor_conf)
    elif predictor_type == "conv":
        predictor = ConvPredictor(vocab_size, **predictor_conf)
    else:
        raise NotImplementedError("only rnn, embedding and conv type support now")
    # Set only after construction, so "output_size" is not passed to the
    # predictor constructor above.
    predictor_conf["output_size"] = predictor_conf["embed_size"]
    return predictor


def init_model(configs):
    """Build an ASR model from a WeNet-style configuration dict.

    Selection rules:
    - If ``configs`` has a "predictor" key, returns a ``Transducer`` (which
      also holds the attention decoder and CTC).
    - Otherwise, if it has a "paraformer" key, returns a ``Paraformer`` with
      a CIF ``Predictor``.
    - Otherwise returns an ``ASRModel`` (joint CTC/attention).

    Args:
        configs: the parsed configuration. Keys read here include "input_dim",
            "output_dim" (used as the vocabulary size), "encoder_conf",
            "decoder_conf" and "model_conf". Optional keys are "cmvn_file",
            "encoder", "decoder", "predictor", "paraformer" and "lfmmi_dir".

    Note:
        In the transducer branch, ``configs`` is modified in place:
        "enc_output_size" and "pred_output_size" are written into
        ``configs["joint_conf"]``. Embedding and conv predictors may also
        write ``configs["predictor_conf"]["output_size"]``.
    """
    global_cmvn = _build_global_cmvn(configs)
    input_dim = configs["input_dim"]
    vocab_size = configs["output_dim"]

    encoder = _build_encoder(configs, input_dim, global_cmvn)
    encoder_output_size = encoder.output_size()
    decoder = _build_decoder(configs, vocab_size, encoder_output_size)
    ctc = CTC(vocab_size, encoder_output_size)

    # Init joint CTC/Attention or Transducer model
    if "predictor" in configs:
        predictor = _build_transducer_predictor(configs, vocab_size)
        joint_conf = configs["joint_conf"]
        # Take the size from the built encoder, the same value the decoder
        # and CTC use, rather than from encoder_conf, which may omit
        # "output_size" when the encoder default is used.
        joint_conf["enc_output_size"] = encoder_output_size
        joint_conf["pred_output_size"] = configs["predictor_conf"]["output_size"]
        joint = TransducerJoint(vocab_size, **joint_conf)
        return Transducer(
            vocab_size=vocab_size,
            blank=0,
            predictor=predictor,
            encoder=encoder,
            attention_decoder=decoder,
            joint=joint,
            ctc=ctc,
            **configs["model_conf"],
        )

    if "paraformer" in configs:
        predictor = Predictor(**configs["cif_predictor_conf"])
        return Paraformer(
            vocab_size=vocab_size,
            encoder=encoder,
            decoder=decoder,
            ctc=ctc,
            predictor=predictor,
            **configs["model_conf"],
        )

    return ASRModel(
        vocab_size=vocab_size,
        encoder=encoder,
        decoder=decoder,
        ctc=ctc,
        lfmmi_dir=configs.get("lfmmi_dir", ""),
        **configs["model_conf"],
    )