import base64
import json
import os

import regex as re
import tiktoken
from transformers import PreTrainedTokenizer


class BaseTokenizer(PreTrainedTokenizer):
    """Abstract base class for tokenizers."""

    def __init__(self, **kwargs):
        super().__init__()

    @property
    def add_prefix_space(self):
        return False

    @property
    def vocab_size(self):
        raise NotImplementedError

    def tokenize(self, text):
        raise NotImplementedError

    def detokenize(self, token_ids, ignore_special_tokens=True):
        raise NotImplementedError

    def build_single_message(self, role, metadata, message):
        """Encode one chat turn as <|role|> + metadata + message token ids."""
        assert role in ["system", "user", "assistant", "observation"], role
        role_tokens = [self.get_command(f"<|{role}|>")] + self.tokenizer.encode(f"{metadata}\n")
        message_tokens = self.tokenizer.encode(message, disallowed_special=())
        return role_tokens + message_tokens

    def build_chat_input(self, query, history=None, role="user", metadata=""):
        """Assemble a chat prompt from the history plus the current query."""
        if history is None:
            history = []
        input_ids = []
        for item in history:
            content = item["content"]
            # Append tool definitions to the system prompt when present.
            if item["role"] == "system" and "tools" in item:
                content = content + "\n" + json.dumps(item["tools"], indent=4, ensure_ascii=False)
            input_ids.extend(self.build_single_message(item["role"], item.get("metadata", ""), content))
        input_ids.extend(self.build_single_message(role, metadata, query))
        input_ids.extend([self.get_command("<|assistant|>")])
        return self.batch_encode_plus([input_ids], return_tensors="pt", is_split_into_words=True)

    @property
    def eos_id(self):
        raise NotImplementedError

    def get_command(self, token):
        raise NotImplementedError


class TikTokenizer(BaseTokenizer):
    vocab_files_names = {"vocab_file": "tokenizer.tiktoken"}

    def __init__(self, vocab_file, **kwargs):
        pat_str = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
        self.pat_str = re.compile(pat_str)
        self.b64_vocab = {}

        # The vocab file stores one "<base64-encoded token> <rank>" pair per line.
        mergeable_ranks = {}
        with open(vocab_file) as f:
            for line in f:
                token, rank = line.strip().split()
                rank = int(rank)
                token = base64.b64decode(token)
                mergeable_ranks[token] = rank
                # Also key the rank by the str() form of the raw bytes so that
                # convert_token_to_id can look tokens up by their printable form.
                self.b64_vocab['%s' % token] = rank

        self.special_tokens = [
            "<|endoftext|>", "[MASK]", "[gMASK]", "[sMASK]", "<sop>", "<eop>",
            "<|system|>", "<|user|>", "<|assistant|>", "<|observation|>",
        ]
        self.special_tokens = {
            token: idx for idx, token in enumerate(self.special_tokens, start=len(mergeable_ranks))
        }
        self.special_token_ids = {idx: token for token, idx in self.special_tokens.items()}

        self.tokenizer = tiktoken.Encoding(
            name="my_tokenizer",
            pat_str=pat_str,
            mergeable_ranks=mergeable_ranks,
            special_tokens=self.special_tokens,
        )
        self.decoder = {rank: token for token, rank in mergeable_ranks.items()}
        self.n_words = len(self.decoder) + len(self.special_tokens)
        super().__init__()

    @property
    def add_prefix_space(self):
        return False

    def tokenize(self, text, add_special_tokens=True):
        ids = self.encode(text, add_special_tokens=add_special_tokens)
        return [self.convert_id_to_token(_id) for _id in ids]

    def detokenize(self, ids, ignore_special_tokens=True):
        if ignore_special_tokens:
            ids = [idx for idx in ids if idx not in self.special_token_ids]
        return self.tokenizer.decode(ids)

    def encode(self, text, add_special_tokens=True):
        ids = self.tokenizer.encode(text, disallowed_special=(), allowed_special="all")
        if add_special_tokens:
            # Prefix with [gMASK] and <sop>, matching the model's prompt format.
            ids = [self.special_tokens["[gMASK]"], self.special_tokens["<sop>"]] + ids
        return ids

    def decode(self, ids, skip_special_tokens=False, clean_up_tokenization_spaces=False):
        if isinstance(ids, int):
            ids = [ids]
        return self.detokenize(ids, ignore_special_tokens=skip_special_tokens)

    def encode_pieces(self, text):
        ids = self.tokenizer.encode(text, disallowed_special=())
        return [self.decoder[x].decode('utf-8', errors='replace') for x in ids]

    @property
    def vocab_size(self):
        return self.n_words

    @property
    def eos_token_id(self):
        return self.special_tokens["<|endoftext|>"]

    def convert_token_to_id(self, token):
        """Converts a token (str) to an id using the vocab."""
        if token in self.special_tokens:
            return self.special_tokens[token]
        # Regular tokens are keyed by the str() form of their raw bytes.
        if token in self.b64_vocab:
            return self.b64_vocab[token]
        raise RuntimeError(f"{token} is not a single token")

    def _convert_token_to_id(self, token):
        return self.convert_token_to_id(token)

    def convert_id_to_token(self, index):
        if index in self.special_token_ids:
            return self.special_token_ids[index]
        # Return the str() form of the raw bytes so the result round-trips
        # through convert_token_to_id via b64_vocab, even for byte sequences
        # that are not valid UTF-8 on their own.
        return '%s' % self.decoder[index]

    def _convert_id_to_token(self, index):
        return self.convert_id_to_token(index)

    def get_command(self, token):
        return self.special_tokens[token]

    def get_vocab(self):
        vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}
        return vocab
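

if __name__ == "__main__":
    # Minimal usage sketch, not part of the library proper. It assumes a
    # GLM-style "tokenizer.tiktoken" vocab file is available locally; adjust
    # the path for your checkpoint. Requires torch for return_tensors="pt".
    tokenizer = TikTokenizer(vocab_file="tokenizer.tiktoken")

    # Plain round trip: encode() prepends the [gMASK]<sop> prefix, and
    # decode() drops special tokens when skip_special_tokens=True.
    ids = tokenizer.encode("Hello, world!")
    print(ids)
    print(tokenizer.decode(ids, skip_special_tokens=True))

    # Chat-style prompt built from a short history plus the current query.
    batch = tokenizer.build_chat_input(
        "What is the capital of France?",
        history=[{"role": "system", "content": "You are a helpful assistant."}],
        role="user",
    )
    print(batch["input_ids"].shape)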