import base64
import json
import os

import regex as re
import tiktoken
from transformers import PreTrainedTokenizer


class BaseTokenizer(PreTrainedTokenizer):
    """Abstract base class for tokenizers."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @property
    def add_prefix_space(self):
        return False

    @property
    def vocab_size(self):
        raise NotImplementedError

    def tokenize(self, text):
        raise NotImplementedError

    def detokenize(self, token_ids, ignore_special_tokens=True):
        raise NotImplementedError

    def build_single_message(self, role, metadata, message):
        assert role in ["system", "user", "assistant", "observation"], role
        # A message is encoded as the <|role|> command token, a metadata line, then the message body.
        role_tokens = [self.get_command(f"<|{role}|>")] + self.tokenizer.encode(f"{metadata}\n")
        message_tokens = self.tokenizer.encode(message, disallowed_special=())
        tokens = role_tokens + message_tokens
        return tokens

    def build_chat_input(self, query, history=None, role="user", metadata=""):
        if history is None:
            history = []
        input_ids = []
        for item in history:
            content = item["content"]
            # Tool definitions attached to a system message are appended as pretty-printed JSON.
            if item["role"] == "system" and "tools" in item:
                content = content + "\n" + json.dumps(item["tools"], indent=4, ensure_ascii=False)
            input_ids.extend(self.build_single_message(item["role"], item.get("metadata", ""), content))
        input_ids.extend(self.build_single_message(role, metadata, query))
        input_ids.extend([self.get_command("<|assistant|>")])
        return self.batch_encode_plus([input_ids], return_tensors="pt", is_split_into_words=True)

    @property
    def eos_id(self):
        raise NotImplementedError

    def get_command(self, token):
        raise NotImplementedError


class TikTokenizer(BaseTokenizer):
    vocab_files_names = {"vocab_file": "tokenizer.tiktoken"}

    def __init__(self, vocab_file, **kwargs):
        pat_str = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
        self.pat_str = re.compile(pat_str)

        # Each vocab line holds a base64-encoded token and its merge rank: "<b64 token> <rank>".
        self.b64_vocab = {}
        mergeable_ranks = {}
        with open(vocab_file) as f:
            for line in f:
                token, rank = line.strip().split()
                rank = int(rank)
                token = base64.b64decode(token)
                mergeable_ranks[token] = rank
                self.b64_vocab['%s' % token] = rank

        # Special tokens are assigned ids directly after the regular vocabulary.
        self.special_tokens = ["<|endoftext|>", "[MASK]", "[gMASK]", "[sMASK]", "<sop>", "<eop>", "<|system|>",
                               "<|user|>", "<|assistant|>", "<|observation|>"]
        self.special_tokens = {
            token: idx for idx, token in enumerate(self.special_tokens, start=len(mergeable_ranks))
        }
        self.special_token_ids = {idx: token for token, idx in self.special_tokens.items()}

        self.tokenizer = tiktoken.Encoding(
            name="my_tokenizer",
            pat_str=pat_str,
            mergeable_ranks=mergeable_ranks,
            special_tokens=self.special_tokens
        )
        self.decoder = {rank: token for token, rank in mergeable_ranks.items()}
        self.n_words = len(self.decoder) + len(self.special_tokens)
        super().__init__(**kwargs)

    @property
    def add_prefix_space(self):
        return False

    def tokenize(self, text, add_special_tokens=True):
        ids = self.encode(text, add_special_tokens=add_special_tokens)
        return [self.convert_id_to_token(_id) for _id in ids]

    def detokenize(self, ids, ignore_special_tokens=True):
        if ignore_special_tokens:
            ids = [idx for idx in ids if idx not in self.special_token_ids]
        return self.tokenizer.decode(ids)

    def encode(self, text, add_special_tokens=True):
        ids = self.tokenizer.encode(text, disallowed_special=(), allowed_special="all")
        if add_special_tokens:
            # Prepend the [gMASK]<sop> prefix expected at the start of a sequence.
            ids = [self.special_tokens["[gMASK]"], self.special_tokens["<sop>"]] + ids
        return ids

    def decode(self, ids, skip_special_tokens=False, clean_up_tokenization_spaces=False):
        if isinstance(ids, int):
            ids = [ids]
        return self.detokenize(ids, ignore_special_tokens=skip_special_tokens)

    def encode_pieces(self, text):
        ids = self.tokenizer.encode(text, disallowed_special=())
        # Map each token id back to its surface string (token bytes decoded as UTF-8).
        return [self.decoder[x].decode('utf-8', errors='replace') for x in ids]

    @property
    def vocab_size(self):
        return self.n_words

    @property
    def eos_token_id(self):
        return self.special_tokens["<|endoftext|>"]

    def convert_token_to_id(self, token):
        """Converts a token (str) to an id using the vocab."""
        if token in self.special_tokens:
            return self.special_tokens[token]
        if token in self.b64_vocab:
            return self.b64_vocab[token]
        raise RuntimeError(f"{token} is not a single token")

    def _convert_token_to_id(self, token):
        return self.convert_token_to_id(token)

    def convert_id_to_token(self, index):
        if index in self.special_token_ids:
            return self.special_token_ids[index]
        return '%s' % self.decoder[index]

    def _convert_id_to_token(self, index):
        return self.convert_id_to_token(index)

    def get_command(self, token):
        return self.special_tokens[token]

    def get_vocab(self):
        vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}
        return vocab
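

if __name__ == "__main__":
    # Minimal usage sketch: assumes a tiktoken-format vocab file named
    # "tokenizer.tiktoken" exists in the working directory and that torch is
    # installed (build_chat_input returns PyTorch tensors).
    tokenizer = TikTokenizer(vocab_file="tokenizer.tiktoken")

    # Round-trip a plain string; skip_special_tokens drops the [gMASK]<sop> prefix.
    ids = tokenizer.encode("Hello, world!")
    print(ids)
    print(tokenizer.decode(ids, skip_special_tokens=True))

    # Build a chat prompt from a short history plus the current user query.
    inputs = tokenizer.build_chat_input(
        "What is the capital of France?",
        history=[{"role": "system", "content": "You are a helpful assistant."}],
        role="user",
    )
    print(inputs["input_ids"].shape)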