import logging
from typing import List

from transformers import PreTrainedTokenizerFast
from tokenizers.decoders import Decoder

logger = logging.getLogger(__name__)

# fmt: off
# Chat-template syntax reference:
# https://huggingface.co/docs/transformers/main/en/chat_templating
PROMPT_TEMPLATE = (
    "{{ '<|bos|>' }}"
    "{{ '' }}"
    # rating token (defaults to safe ratings when absent)
    "{% if 'rating' not in messages or messages['rating'] is none %}"
    "{{ 'rating:sfw, rating:general' }}"
    "{% else %}"
    "{{ messages['rating'] }}"
    "{% endif %}"
    "{{ '' }}"
    "{{ '' }}"
    # copyright token
    "{% if 'copyright' not in messages or messages['copyright'] is none %}"
    "{{ '' }}"
    "{% else %}"
    "{{ messages['copyright'] }}"
    "{% endif %}"
    "{{ '' }}"
    "{{ '' }}"
    # character token
    "{% if 'character' not in messages or messages['character'] is none %}"
    "{{ '' }}"
    "{% else %}"
    "{{ messages['character'] }}"
    "{% endif %}"
    "{{ '' }}"
    "{{ '' }}"
    # length token
    "{% if 'length' not in messages or messages['length'] is none %}"
    "{{ '<|long|>' }}"
    "{% else %}"
    "{{ messages['length'] }}"
    "{% endif %}"
    # general token
    "{% if 'general' not in messages or messages['general'] is none %}"
    "{{ '' }}"
    "{% else %}"
    "{{ messages['general'] }}"
    "{% endif %}"
    "{{ '<|input_end|>' }}"
).strip()
# fmt: on


class DartDecoder:
    """Custom decode step that re-joins non-special tokens with ", ".

    Instances are handed to ``tokenizers.decoders.Decoder.custom``, which
    calls :meth:`decode_chain` on the raw token strings during decoding.
    """

    def __init__(self, special_tokens: List[str]):
        # Defensive copy so later mutation of the caller's list has no effect.
        self.special_tokens = list(special_tokens)

    def decode_chain(self, tokens: List[str]) -> List[str]:
        """Prefix each token with ", " unless it or its predecessor is special.

        The first token is always emitted unchanged; a separator is only
        inserted between two consecutive non-special tokens.
        """
        decoded: List[str] = []
        previous_was_special = False
        for index, token in enumerate(tokens):
            current_is_special = token in self.special_tokens
            if index == 0 or current_is_special or previous_was_special:
                decoded.append(token)
            else:
                decoded.append(f", {token}")
            previous_was_special = current_is_special
        return decoded


class DartTokenizer(PreTrainedTokenizerFast):
    """Dart tokenizer"""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Install the comma-joining decoder; special tokens are every entry
        # of the added vocabulary.
        self._tokenizer.decoder = Decoder.custom(  # type: ignore
            DartDecoder(list(self.get_added_vocab().keys()))
        )

    @property
    def default_chat_template(self):
        """
        Danbooru Tags Transformer uses special format prompt to generate danbooru tags.
        """
        return PROMPT_TEMPLATE