# -*- coding: utf-8 -*-
"""
@author:cb
@contact:chenbo@bat100.net
@time:2023/5/30 14:21
@filename:tokenization.py
@software:PyCharm
@description: Custom FSMTTokenizer for en-zh translation that keeps Chinese
    punctuation intact and prepends a language-prefix id to every sequence.
"""
import re

from transformers import FSMTTokenizer as fsmt


class FSMTTokenizer(fsmt):
    # Strips the whitespace that moses/BPE detokenization leaves in front of
    # non-alphanumeric characters (CJK characters, punctuation).
    space_re = re.compile(r'\s*(?=[^a-zA-Z0-9 ]+)\s*')

    def moses_tokenize(self, text, lang):
        if lang not in self.cache_moses_tokenizer:
            moses_tokenizer = self.sm.MosesTokenizer(lang=lang)
            self.cache_moses_tokenizer[lang] = moses_tokenizer
        # escape=False keeps characters such as & and quotes from being turned into HTML entities
        return self.cache_moses_tokenizer[lang].tokenize(
            text, aggressive_dash_splits=True, return_str=False, escape=False
        )

    def _switch_to_input_mode(self):
        # 64812 is the id of the 'en' language-prefix token in this checkpoint's vocabulary
        self.lang_prefix, self.lang_prefix_id = 'en', 64812

    def _switch_to_target_mode(self):
        # 64870 is the id of the 'zh' language-prefix token in this checkpoint's vocabulary
        self.lang_prefix, self.lang_prefix_id = 'zh', 64870

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """
        Build model inputs from a sequence or a pair of sequences by concatenating and adding
        special tokens. A sequence built by this tokenizer has the following format:

        - single sequence: `<lang_prefix_id> X </s>`
        - pair of sequences: `<lang_prefix_id> A </s> B </s>`

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """
        sep = [self.sep_token_id]
        token_ids_0 = [self.lang_prefix_id] + token_ids_0
        # no bos used in fairseq
        if token_ids_1 is None:
            return token_ids_0 + sep
        return token_ids_0 + sep + token_ids_1 + sep

    def moses_pipeline(self, text, lang):
        # Only punctuation normalization is applied; the parent pipeline additionally
        # replaces unicode punctuation, which would anglicize Chinese punctuation.
        text = self.moses_punct_norm(text, lang)
        return text

    def _tokenize(self, text, lang="en", bypass_tokenizer=False):
        """
        Overridden because the stock FSMTTokenizer converts Chinese punctuation into its
        English equivalents.

        :param text: text to tokenize
        :param lang: kept for signature compatibility; the language actually used is `self.lang_prefix`
        :param bypass_tokenizer: if True, skip moses and only split on whitespace
        :return: list of BPE tokens
        """
        if self.do_lower_case:
            text = text.lower()

        if bypass_tokenizer:
            text = text.split()
        else:
            text = self.moses_pipeline(text, lang=self.lang_prefix)
            text = self.moses_tokenize(text, lang=self.lang_prefix)

        split_tokens = []
        for token in text:
            if token:
                split_tokens.extend(list(self.bpe(token).split(" ")))

        return split_tokens

    def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
        """
        Switches the language prefix before tokenization.

        :param text: text to prepare
        :param is_split_into_words: forwarded to the parent implementation
        :param kwargs: may contain `src` (bool, default True); source text gets the 'en'
            prefix, target text the 'zh' prefix
        :return: `(text, kwargs)` as returned by the parent implementation
        """
        # pop() so the custom `src` flag does not trigger an "unrecognized kwargs" warning upstream
        if kwargs.pop('src', True):
            self._switch_to_input_mode()
        else:
            self._switch_to_target_mode()
        return super(FSMTTokenizer, self).prepare_for_tokenization(
            text, is_split_into_words=is_split_into_words, **kwargs
        )

    def convert_tokens_to_string(self, tokens):
        """
        Removes the spaces around non-alphanumeric characters, which suits the downstream
        use case better.

        :param tokens: list of tokens to join
        :return: detokenized string
        """
        tokens = super(FSMTTokenizer, self).convert_tokens_to_string(tokens)
        tokens = FSMTTokenizer.space_re.sub('', tokens)
        return tokens


if __name__ == '__main__':
    tokenizer = FSMTTokenizer.from_pretrained(r'./')
    # tokenize() expects a single string, not a list of strings
    r = tokenizer.tokenize('hello, hi')
    print(r)
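
# Minimal usage sketch (hypothetical checkpoint path and sentences; assumes an
# en->zh FSMT checkpoint whose vocabulary contains the prefix ids 64812 for 'en'
# and 64870 for 'zh' is available in the current directory):
#
#     tokenizer = FSMTTokenizer.from_pretrained('./')
#     src = tokenizer('How are you?')        # input_ids start with the 'en' prefix 64812
#     tgt = tokenizer('你好吗?', src=False)   # input_ids start with the 'zh' prefix 64870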