|
|
|
""" |
|
@author:cb |
|
@contact:[email protected] |
|
@time:2023/5/30 14:21 |
|
@filename:tokenization.py |
|
@software:PyCharm |
|
@description: |
|
""" |
|
import re |
|
from transformers import FSMTTokenizer as fsmt |
|
|
|
|
|
class FSMTTokenizer(fsmt):
    """FSMT tokenizer for an en<->zh translation model.

    Compared with the ``transformers`` base class this subclass:

    * prefixes every encoded sequence with a language-tag token id
      (``build_inputs_with_special_tokens``). The tag is chosen by the
      current input/target mode together with the ``reversal`` flag.
    * keeps Chinese punctuation intact by limiting ``moses_pipeline`` to
      Moses punctuation normalisation, and disables Moses escaping
      (``moses_tokenize``).
    * removes whitespace in front of non-alphanumeric characters when
      decoding (``convert_tokens_to_string``).

    Set ``reversal = True`` to switch from en -> zh to zh -> en.
    """

    # (language code, language-tag token id) pairs. The ids are specific to
    # the checkpoint's vocabulary -- NOTE(review): re-check them if the
    # vocab changes.
    _EN_PREFIX = ('en', 64812)
    _ZH_PREFIX = ('zh', 64870)

    # Whitespace that is followed by a character other than an ASCII letter,
    # digit or space (e.g. CJK characters, punctuation). Compiled once and
    # shared by all instances; still reachable as ``self.space_re``.
    space_re = re.compile(r'\s*(?=[^a-zA-Z0-9 ]+)\s*')

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # False: source is English and target is Chinese; True: the reverse.
        self.reversal = False
        # Give lang_prefix / lang_prefix_id a value right away. Otherwise
        # tokenize()/encode() called before any mode switch would fail on
        # an unset attribute.
        self._switch_to_input_mode()

    def moses_tokenize(self, text, lang):
        """Tokenize *text* with a Moses tokenizer cached per language.

        ``escape=False`` asks Sacremoses not to XML-escape characters such
        as ``&`` or ``<``. ``return_str=False`` makes it return a list of
        tokens.
        """
        if lang not in self.cache_moses_tokenizer:
            self.cache_moses_tokenizer[lang] = self.sm.MosesTokenizer(lang=lang)
        return self.cache_moses_tokenizer[lang].tokenize(
            text, aggressive_dash_splits=True, return_str=False, escape=False
        )

    def _switch_to_input_mode(self):
        """Select the language tag used for source-side sequences."""
        self.lang_prefix, self.lang_prefix_id = (
            self._ZH_PREFIX if self.reversal else self._EN_PREFIX
        )

    def _switch_to_target_mode(self):
        """Select the language tag used for target-side sequences (e.g. ``text_target``)."""
        self.lang_prefix, self.lang_prefix_id = (
            self._EN_PREFIX if self.reversal else self._ZH_PREFIX
        )

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """
        Build model inputs from a sequence or a pair of sequences by prepending the current
        language tag and appending separator tokens. No BOS token is used:

        - single sequence: `<lang> X </s>`
        - pair of sequences: `<lang> A </s> B </s>`

        where `<lang>` is `self.lang_prefix_id` (set by the input/target mode hooks).

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of input IDs with the appropriate special tokens. The input lists
            are not modified.
        """
        sep = [self.sep_token_id]
        token_ids_0 = [self.lang_prefix_id] + token_ids_0

        if token_ids_1 is None:
            return token_ids_0 + sep
        return token_ids_0 + sep + token_ids_1 + sep

    def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False):
        """Return a 0/1 mask marking the special tokens of ``build_inputs_with_special_tokens``.

        The mask follows the `<lang> A </s> [B </s>]` layout, so it always has
        the same length as the built input ids. The leading language tag and
        every `</s>` are flagged with 1.

        Args:
            token_ids_0: IDs of the first sequence.
            token_ids_1: optional IDs of the second sequence.
            already_has_special_tokens: if True, the ids already contain
                special tokens; the base-class implementation is used.

        Returns:
            ``List[int]`` with 1 for special tokens and 0 for sequence tokens.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )
        mask = [1] + [0] * len(token_ids_0) + [1]
        if token_ids_1 is not None:
            mask += [0] * len(token_ids_1) + [1]
        return mask

    def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None):
        """Return token type ids aligned with ``build_inputs_with_special_tokens``.

        The first segment (language tag, ``token_ids_0`` and its `</s>`) gets
        0. The optional second segment (``token_ids_1`` and its `</s>`) gets 1.
        """
        first_segment = [0] * (len(token_ids_0) + 2)  # <lang> + ids + </s>
        if token_ids_1 is None:
            return first_segment
        return first_segment + [1] * (len(token_ids_1) + 1)  # ids + </s>

    def moses_pipeline(self, text, lang):
        """Normalise *text* before Moses tokenization.

        This is deliberately limited to ``moses_punct_norm``, so that Chinese
        punctuation is not rewritten into English punctuation.
        """
        return self.moses_punct_norm(text, lang)

    def _tokenize(self, text, lang="en", bypass_tokenizer=False):
        """
        Split *text* into BPE sub-tokens.

        Overridden because the stock FSMTTokenizer converts Chinese punctuation
        into English punctuation.

        :param text: raw input string.
        :param lang: ignored. The language selected by the current input/target
            mode (``self.lang_prefix``) is used instead.
        :param bypass_tokenizer: if true, split on whitespace instead of running
            the Moses normalisation + tokenization pipeline.
        :return: list of BPE tokens.
        """
        if self.do_lower_case:
            text = text.lower()

        if bypass_tokenizer:
            words = text.split()
        else:
            normalized = self.moses_pipeline(text, lang=self.lang_prefix)
            words = self.moses_tokenize(normalized, lang=self.lang_prefix)

        split_tokens = []
        for word in words:
            if word:  # skip empty strings produced by tokenization
                split_tokens.extend(self.bpe(word).split(" "))
        return split_tokens

    def convert_tokens_to_string(self, tokens):
        """
        Decode *tokens* with the base class and then delete whitespace placed
        before any character that is not an ASCII letter, digit or space.

        For example "你好 朋友 。" becomes "你好朋友。" and "hello , world"
        becomes "hello, world". A space before an English word or a digit is kept.
        (Original author's note: this post-processing might fit better in
        business-layer code.)

        :param tokens: list of BPE tokens.
        :return: decoded string.
        """
        text = super().convert_tokens_to_string(tokens)
        return self.space_re.sub('', text)
|
|
|
|
|
if __name__ == '__main__':
    # Smoke test: load the checkpoint from the working directory, then encode
    # one en -> zh pair followed by the mirrored zh -> en pair.
    tokenizer = FSMTTokenizer.from_pretrained(r'./')
    demo_cases = (
        (False, ['hello'], ['你好朋友']),
        (True, ['你好朋友'], ['hello']),
    )
    for reverse_direction, source_texts, target_texts in demo_cases:
        tokenizer.reversal = reverse_direction
        encoded = tokenizer(source_texts, text_target=target_texts)
        print(encoded)