File size: 3,696 Bytes
38524c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
# -*- coding: utf-8 -*-
"""
@author:cb
@contact:[email protected]
@time:2023/5/30 14:21
@filename:tokenization.py
@software:PyCharm
@description:
"""
import re
from transformers import FSMTTokenizer as fsmt


class FSMTTokenizer(fsmt):
    """FSMT tokenizer customised for an en<->zh translation model.

    Differences from the parent ``FSMTTokenizer`` implemented here:

    * Every encoded sequence starts with a language-tag token id
      (``lang_prefix_id``); there is still no BOS token.
    * Pre-processing only runs Moses punctuation normalisation, and the Moses
      tokenizer is called with ``escape=False``. The original author notes the
      stock tokenizer turns Chinese punctuation into English punctuation.
    * Decoding removes whitespace that precedes any character that is not an
      ASCII letter, digit or space.

    Translation direction is chosen by ``reversal``: ``False`` means English
    source / Chinese target, ``True`` means Chinese source / English target.
    """

    # (Moses language code, vocabulary id of the language-tag token).
    # NOTE(review): ids are hard-coded; they must match the checkpoint's vocab.
    _EN_PREFIX = ('en', 64812)
    _ZH_PREFIX = ('zh', 64870)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Raw string: '\s' in a plain literal is an invalid escape sequence
        # (DeprecationWarning, SyntaxWarning since Python 3.12). The pattern
        # matches whitespace directly in front of a character that is not an
        # ASCII letter, digit or space, e.g. "你 好 ," -> "你好,".
        self.space_re = re.compile(r'\s*(?=[^a-zA-Z0-9 ]+)\s*')
        # False: en -> zh, True: zh -> en. Read by the _switch_to_*_mode
        # hooks, so a change only takes effect at the next mode switch.
        self.reversal = False
        # Give lang_prefix / lang_prefix_id a value up front; otherwise they
        # only exist after a _switch_to_*_mode hook has run, and e.g. a direct
        # tokenize()/encode() would raise AttributeError.
        self._switch_to_input_mode()

    def moses_tokenize(self, text, lang):
        """Tokenize ``text`` with a per-language cached Moses tokenizer.

        ``escape=False`` stops sacremoses from XML-escaping special characters
        (e.g. ``&`` -> ``&amp;``); ``return_str=False`` yields a list of words.
        """
        tokenizer = self.cache_moses_tokenizer.get(lang)
        if tokenizer is None:
            tokenizer = self.sm.MosesTokenizer(lang=lang)
            self.cache_moses_tokenizer[lang] = tokenizer
        return tokenizer.tokenize(
            text, aggressive_dash_splits=True, return_str=False, escape=False
        )

    def _switch_to_input_mode(self):
        """Select the language tag used when encoding source text."""
        self.lang_prefix, self.lang_prefix_id = (
            self._ZH_PREFIX if self.reversal else self._EN_PREFIX
        )

    def _switch_to_target_mode(self):
        """Select the language tag used when encoding target text."""
        self.lang_prefix, self.lang_prefix_id = (
            self._EN_PREFIX if self.reversal else self._ZH_PREFIX
        )

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """
        Build model inputs from a sequence or a pair of sequences by adding the
        language tag and separator tokens:

        - single sequence: `<lang> X </s>`
        - pair of sequences: `<lang> A </s> B </s>`

        where `<lang>` is ``self.lang_prefix_id``.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of input IDs with the special tokens added.
        """
        sep = [self.sep_token_id]
        # No BOS is used in fairseq; the language tag takes the first slot.
        first = [self.lang_prefix_id] + token_ids_0 + sep
        if token_ids_1 is None:
            return first
        return first + token_ids_1 + sep

    def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False):
        """
        Return a mask with 1 for tokens added by build_inputs_with_special_tokens
        (language tag and separators) and 0 for sequence tokens.

        Overridden so the mask has the same length as the ids produced by
        build_inputs_with_special_tokens, which prepends the language tag.

        Args:
            token_ids_0 (`List[int]`): first sequence of ids.
            token_ids_1 (`List[int]`, *optional*): second sequence of ids.
            already_has_special_tokens (`bool`): if True the ids already contain
                special tokens and the base-class lookup is used unchanged.

        Returns:
            `List[int]`: the special-tokens mask.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0,
                token_ids_1=token_ids_1,
                already_has_special_tokens=True,
            )
        mask = [1] + [0] * len(token_ids_0) + [1]
        if token_ids_1 is not None:
            mask += [0] * len(token_ids_1) + [1]
        return mask

    def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None):
        """
        Return token type ids aligned with build_inputs_with_special_tokens.

        The first segment (language tag, sequence A and its separator) gets 0.
        The optional second segment (sequence B and its separator) gets 1.

        Args:
            token_ids_0 (`List[int]`): first sequence of ids.
            token_ids_1 (`List[int]`, *optional*): second sequence of ids.

        Returns:
            `List[int]`: token type ids.
        """
        type_ids = [0] * (len(token_ids_0) + 2)  # lang tag + A + </s>
        if token_ids_1 is not None:
            type_ids += [1] * (len(token_ids_1) + 1)  # B + </s>
        return type_ids

    def moses_pipeline(self, text, lang):
        """Pre-tokenization normalisation: Moses punctuation normalisation only."""
        return self.moses_punct_norm(text, lang)

    def _tokenize(self, text, lang="en", bypass_tokenizer=False):
        """
        Split ``text`` into BPE sub-word tokens.

        Overridden because the original FSMTTokenizer converts Chinese
        punctuation into English punctuation.

        :param text: raw input string.
        :param lang: accepted for signature compatibility but ignored; the Moses
            language comes from ``self.lang_prefix``.
        :param bypass_tokenizer: if True, skip Moses and split on whitespace.
        :return: list of BPE tokens.
        """
        if self.do_lower_case:
            text = text.lower()
        if bypass_tokenizer:
            words = text.split()
        else:
            normalized = self.moses_pipeline(text, lang=self.lang_prefix)
            words = self.moses_tokenize(normalized, lang=self.lang_prefix)

        split_tokens = []
        for word in words:
            if word:  # skip empty strings; one word may yield several BPE pieces
                split_tokens.extend(self.bpe(word).split(" "))
        return split_tokens

    def convert_tokens_to_string(self, tokens):
        """
        Detokenize with the parent implementation, then delete whitespace that
        directly precedes any character that is not an ASCII letter, digit or
        space. For example, "hello 你好" becomes "hello你好" and "你 好 ," becomes
        "你好,". Whitespace *after* such a character is kept when an ASCII letter
        or digit follows, so "你好 world" is unchanged.
        The author notes this cleanup might be better done in business logic.

        :param tokens: list of BPE tokens.
        :return: decoded string.
        """
        text = super().convert_tokens_to_string(tokens)
        return self.space_re.sub('', text)


if __name__ == '__main__':
    # Manual smoke test against a checkpoint in the current directory:
    # encode one en -> zh pair, then flip direction and encode zh -> en.
    demo_tokenizer = FSMTTokenizer.from_pretrained(r'./')
    demo_cases = (
        (False, ['hello'], ['你好朋友']),
        (True, ['你好朋友'], ['hello']),
    )
    for reversed_direction, sources, targets in demo_cases:
        demo_tokenizer.reversal = reversed_direction
        encoded = demo_tokenizer(sources, text_target=targets)
        print(encoded)