File size: 2,524 Bytes
1c8f03b ce448ae 1c8f03b 0df470a 1c8f03b 0df470a 1c8f03b 0df470a 1c8f03b 0df470a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 |
import logging
from typing import List
from transformers import PreTrainedTokenizerFast
from tokenizers.decoders import Decoder
logger = logging.getLogger(__name__)


def _optional_field(key, fallback):
    """Build the Jinja stanza that prints ``messages[key]`` or a fallback.

    The generated snippet emits the literal *fallback* when *key* is absent
    from ``messages`` or its value is ``none``; otherwise it emits the value
    stored under *key*.
    """
    return "".join(
        (
            "{% if '", key, "' not in messages or messages['", key, "'] is none %}",
            "{{ '", fallback, "' }}",
            "{% else %}",
            "{{ messages['", key, "'] }}",
            "{% endif %}",
        )
    )


# Chat template reference:
# https://huggingface.co/docs/transformers/main/en/chat_templating
#
# The template indexes ``messages`` by key (rating, copyright, character,
# length, general); every key is optional and has a fallback.
PROMPT_TEMPLATE = "".join(
    [
        "{{ '<|bos|>' }}",
        # rating section, falls back to a safe-for-work general rating
        "{{ '<rating>' }}",
        _optional_field("rating", "rating:sfw, rating:general"),
        "{{ '</rating>' }}",
        # copyright section, empty when not provided
        "{{ '<copyright>' }}",
        _optional_field("copyright", ""),
        "{{ '</copyright>' }}",
        # character section, empty when not provided
        "{{ '<character>' }}",
        _optional_field("character", ""),
        "{{ '</character>' }}",
        # general section: a length token first, then the general text
        "{{ '<general>' }}",
        _optional_field("length", "<|long|>"),
        _optional_field("general", ""),
        "{{ '<|input_end|>' }}",
    ]
).strip()
class DartDecoder:
    """Decoder step that separates adjacent non-special tokens with ``", "``.

    A token is left untouched when it is the first token, when it is a
    special token, or when the token immediately before it is a special
    token. Every other token is prefixed with ``", "``.
    """

    def __init__(self, special_tokens: List[str]):
        """Store a private copy of *special_tokens*.

        Args:
            special_tokens: Token strings that must never receive (or cause
                the following token to receive) a ``", "`` prefix.
        """
        self.special_tokens = list(special_tokens)

    def decode_chain(self, tokens: List[str]) -> List[str]:
        """Return *tokens* with ``", "`` prefixes inserted where required.

        Args:
            tokens: Token strings in sequence order.

        Returns:
            A new list of the same length as *tokens*; the input list is not
            modified. An empty input yields an empty list.
        """
        # Hash-based membership instead of scanning the list for every token.
        # Built per call so changes made to ``self.special_tokens`` after
        # construction are still honoured.
        specials = frozenset(self.special_tokens)

        new_tokens: List[str] = []
        # Treat the start of the sequence as a boundary so that the first
        # token is never prefixed.
        prev_is_special = True
        for token in tokens:
            is_special = token in specials
            if is_special or prev_is_special:
                new_tokens.append(token)
            else:
                new_tokens.append(f", {token}")
            prev_is_special = is_special
        return new_tokens
class DartTokenizer(PreTrainedTokenizerFast):
    """Fast tokenizer for Dart (Danbooru Tags Transformer).

    On construction it installs a custom decoder built from the tokenizer's
    added vocabulary and makes ``PROMPT_TEMPLATE`` available as the chat
    template.
    """

    def __init__(self, **kwargs):
        """Initialise the base tokenizer, then attach template and decoder.

        Args:
            **kwargs: Forwarded unchanged to ``PreTrainedTokenizerFast``.
        """
        super().__init__(**kwargs)

        # NOTE(review): recent transformers releases are understood to no
        # longer fall back to ``default_chat_template``; setting the
        # instance-level template keeps ``apply_chat_template`` usable there.
        # A template supplied by the caller/config is left untouched.
        if getattr(self, "chat_template", None) is None:
            self.chat_template = PROMPT_TEMPLATE

        # Every added-vocabulary token is treated as special by the decoder.
        self._tokenizer.decoder = Decoder.custom(  # type: ignore
            DartDecoder(list(self.get_added_vocab()))
        )

    @property
    def default_chat_template(self):
        """Return the prompt template used to generate danbooru tags.

        Kept for transformers releases that still read the class-level
        default template.
        """
        return PROMPT_TEMPLATE
|