Upload tokenization_dart.py
Browse files- tokenization_dart.py +54 -60
tokenization_dart.py
CHANGED
@@ -1,26 +1,60 @@
|
|
1 |
import logging
|
2 |
-
import os
|
3 |
import json
|
4 |
-
from typing import
|
5 |
from pydantic.dataclasses import dataclass
|
6 |
|
7 |
-
import numpy as np
|
8 |
-
from numpy.typing import NDArray
|
9 |
-
|
10 |
from transformers import PreTrainedTokenizerFast
|
11 |
from tokenizers.decoders import Decoder
|
12 |
|
13 |
logger = logging.getLogger(__name__)
|
14 |
|
15 |
-
VOCAB_FILES_NAMES = {
|
16 |
-
"category_config": "category_config.json",
|
17 |
-
}
|
18 |
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
}
|
23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
|
25 |
|
26 |
@dataclass
|
@@ -71,57 +105,17 @@ class DartDecoder:
|
|
71 |
class DartTokenizer(PreTrainedTokenizerFast):
|
72 |
"""Dart tokenizer"""
|
73 |
|
74 |
-
|
75 |
-
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
76 |
-
|
77 |
-
def __init__(self, category_config, **kwargs):
|
78 |
super().__init__(**kwargs)
|
79 |
|
80 |
self._tokenizer.decoder = Decoder.custom( # type: ignore
|
81 |
DartDecoder(list(self.get_added_vocab().keys()))
|
82 |
)
|
83 |
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
tokens,
|
90 |
-
) in self.category_config.category_to_token_ids.items():
|
91 |
-
self._id_to_category_map[tokens] = int(category_id)
|
92 |
-
|
93 |
-
def create_vocab_mask(self, value: int = 1):
|
94 |
-
"""Create an array of vocab size filled with specified value"""
|
95 |
-
return np.full(self.vocab_size, value).astype("uint8")
|
96 |
-
|
97 |
-
def get_token_ids_in_category(self, category_id: Union[int, str]):
|
98 |
-
"""Get token ids in the specified category"""
|
99 |
-
return self.category_config.category_to_token_ids[str(category_id)]
|
100 |
-
|
101 |
-
def get_category(self, category_id: Union[int, str]):
|
102 |
-
"""Get the specified category config"""
|
103 |
-
return self.category_config.categories[str(category_id)]
|
104 |
-
|
105 |
-
def convert_ids_to_category_ids(self, token_ids: Union[int, List[int]]):
|
106 |
-
"""Get the category ids of specified tokens"""
|
107 |
-
return self._id_to_category_map[token_ids]
|
108 |
-
|
109 |
-
def get_banned_tokens_mask(self, tokens: Union[str, List[str], int, List[int]]):
|
110 |
-
if isinstance(tokens, str):
|
111 |
-
tokens = [tokens]
|
112 |
-
elif isinstance(tokens, int):
|
113 |
-
tokens = [tokens]
|
114 |
-
elif isinstance(tokens, list):
|
115 |
-
tokens = [ # type: ignore
|
116 |
-
self.convert_tokens_to_ids(token) if isinstance(token, str) else token
|
117 |
-
for token in tokens
|
118 |
-
]
|
119 |
-
|
120 |
-
assert isinstance(tokens, list) and all(
|
121 |
-
[isinstance(token, int) for token in tokens]
|
122 |
-
)
|
123 |
-
|
124 |
-
mask = self.create_vocab_mask(value=1)
|
125 |
-
mask[tokens] = 0
|
126 |
|
127 |
-
return
|
|
|
1 |
import logging
|
|
|
2 |
import json
|
3 |
+
from typing import Dict, List
|
4 |
from pydantic.dataclasses import dataclass
|
5 |
|
|
|
|
|
|
|
6 |
from transformers import PreTrainedTokenizerFast
|
7 |
from tokenizers.decoders import Decoder
|
8 |
|
9 |
logger = logging.getLogger(__name__)
|
10 |
|
|
|
|
|
|
|
11 |
|
12 |
+
# fmt: off
|
13 |
+
# https://huggingface.co/docs/transformers/main/en/chat_templating
|
14 |
+
PROMPT_TEMPLATE = (
|
15 |
+
"{{ '<|bos|>' }}"
|
16 |
+
|
17 |
+
"{{ '<rating>' }}"
|
18 |
+
"{% if 'rating' not in messages or messages['rating'] is none %}"
|
19 |
+
"{{ 'rating:sfw, rating:general' }}"
|
20 |
+
"{% else %}"
|
21 |
+
"{{ messages['rating'] }}"
|
22 |
+
"{% endif %}"
|
23 |
+
"{{ '</rating>' }}"
|
24 |
+
|
25 |
+
"{{ '<copyright>' }}"
|
26 |
+
"{% if 'copyright' not in messages or messages['copyright'] is none %}"
|
27 |
+
"{{ '' }}"
|
28 |
+
"{% else %}"
|
29 |
+
"{{ messages['copyright'] }}"
|
30 |
+
"{% endif %}"
|
31 |
+
"{{ '</copyright>' }}"
|
32 |
+
|
33 |
+
"{{ '<character>' }}"
|
34 |
+
"{% if 'character' not in messages or messages['character'] is none %}"
|
35 |
+
"{{ '' }}"
|
36 |
+
"{% else %}"
|
37 |
+
"{{ messages['character'] }}"
|
38 |
+
"{% endif %}"
|
39 |
+
"{{ '</character>' }}"
|
40 |
+
|
41 |
+
"{{ '<general>' }}"
|
42 |
+
# length token
|
43 |
+
"{% if 'length' not in messages or messages['length'] is none %}"
|
44 |
+
"{{ '<|long|>' }}"
|
45 |
+
"{% else %}"
|
46 |
+
"{{ messages['length'] }}"
|
47 |
+
"{% endif %}"
|
48 |
+
|
49 |
+
# general token
|
50 |
+
"{% if 'general' not in messages or messages['general'] is none %}"
|
51 |
+
"{{ '' }}"
|
52 |
+
"{% else %}"
|
53 |
+
"{{ messages['general'] }}"
|
54 |
+
"{% endif %}"
|
55 |
+
"{{ '<|input_end|>' }}"
|
56 |
+
).strip()
|
57 |
+
# fmt: on
|
58 |
|
59 |
|
60 |
@dataclass
|
|
|
105 |
class DartTokenizer(PreTrainedTokenizerFast):
|
106 |
"""Dart tokenizer"""
|
107 |
|
108 |
+
def __init__(self, **kwargs):
|
|
|
|
|
|
|
109 |
super().__init__(**kwargs)
|
110 |
|
111 |
self._tokenizer.decoder = Decoder.custom( # type: ignore
|
112 |
DartDecoder(list(self.get_added_vocab().keys()))
|
113 |
)
|
114 |
|
115 |
+
@property
|
116 |
+
def default_chat_template(self):
|
117 |
+
"""
|
118 |
+
Danbooru Tags Transformer uses special format prompt to generate danbooru tags.
|
119 |
+
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
120 |
|
121 |
+
return PROMPT_TEMPLATE
|