# -*- coding: utf-8 -*-
"""
@author:cb
@contact:chenbo@bat100.net
@time:2023/5/30 14:21
@filename:tokenization.py
@software:PyCharm
@description:
"""
import re

from transformers import FSMTTokenizer as fsmt


class FSMTTokenizer(fsmt):
    def __init__(self, *args, **kwargs):
        super(FSMTTokenizer, self).__init__(*args, **kwargs)
        # Matches the spaces surrounding non-alphanumeric characters so they can be stripped on detokenization.
        self.space_re = re.compile(r'\s*(?=[^a-zA-Z0-9 ]+)\s*')
        # When True, the translation direction is reversed (zh -> en instead of en -> zh).
        self.reversal = False

    def moses_tokenize(self, text, lang):
        if lang not in self.cache_moses_tokenizer:
            moses_tokenizer = self.sm.MosesTokenizer(lang=lang)
            self.cache_moses_tokenizer[lang] = moses_tokenizer
        return self.cache_moses_tokenizer[lang].tokenize(
            text, aggressive_dash_splits=True, return_str=False, escape=False
        )

    def _switch_to_input_mode(self):
        if self.reversal:
            self.lang_prefix, self.lang_prefix_id = 'zh', 64870
        else:
            self.lang_prefix, self.lang_prefix_id = 'en', 64812

    def _switch_to_target_mode(self):
        if self.reversal:
            self.lang_prefix, self.lang_prefix_id = 'en', 64812
        else:
            self.lang_prefix, self.lang_prefix_id = 'zh', 64870

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """
        Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating
        and adding special tokens. A FAIRSEQ Transformer sequence has the following format:

        - single sequence: `<lang_prefix> X </s>`
        - pair of sequences: `<lang_prefix> A </s> B </s>`

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """
        sep = [self.sep_token_id]
        # No bos is used in fairseq; prepend the language-prefix token instead.
        token_ids_0 = [self.lang_prefix_id] + token_ids_0
        if token_ids_1 is None:
            return token_ids_0 + sep
        return token_ids_0 + sep + token_ids_1 + sep

    def moses_pipeline(self, text, lang):
        # Only apply punctuation normalization; skip the rest of the parent pipeline.
        text = self.moses_punct_norm(text, lang)
        return text

    def _tokenize(self, text, lang="en", bypass_tokenizer=False):
        """
        The original FSMTTokenizer converts Chinese punctuation into its English equivalents,
        so this method is overridden to avoid that.
        :param text:
        :param lang:
        :param bypass_tokenizer:
        :return:
        """
        if self.do_lower_case:
            text = text.lower()

        if bypass_tokenizer:
            text = text.split()
        else:
            text = self.moses_pipeline(text, lang=self.lang_prefix)
            text = self.moses_tokenize(text, lang=self.lang_prefix)

        split_tokens = []
        for token in text:
            if token:
                split_tokens.extend(list(self.bpe(token).split(" ")))

        return split_tokens

    def convert_tokens_to_string(self, tokens):
        """
        Remove the spaces before and after non-English characters; this suits the downstream use case better.
        :param tokens:
        :return:
        """
        tokens = super(FSMTTokenizer, self).convert_tokens_to_string(tokens)
        tokens = self.space_re.sub('', tokens)
        return tokens


if __name__ == '__main__':
    tokenizer = FSMTTokenizer.from_pretrained(r'./')
    r = tokenizer(['hello'], text_target=['你好朋友'])
    print(r)
    tokenizer.reversal = True
    r = tokenizer(['你好朋友'], text_target=['hello'])
    # # r['input_ids'] += r['labels']
    # # r['labels'] += r['input_ids']
    print(r)