import torch
from .base import Tokenizer
from .helper import get_stats, merge_batch_get_stats
from heapq import nlargest
import time

MANA_SPECIAL_TOKENS = {
    '<|end|>': 265712,
    '<|user|>': 265713,
    '<|assistant|>': 265714,
    '<|system|>': 265715
}


class ManaTokenizer(Tokenizer):

    def __init__(self, pattern=None, multiprocess=True, store_dict=False, stop_list_size=0, freq_cutoff=1):
        """
        - pattern: optional string to override the default (GPT-4 split pattern)
        - special tokens are registered automatically from MANA_SPECIAL_TOKENS,
          a str -> int dictionary, e.g. {'<|end|>': 265712}
        """
        super().__init__(pattern, multiprocess, store_dict, stop_list_size, freq_cutoff)
        self.register_special_tokens(MANA_SPECIAL_TOKENS)
        self.load("mana_tokenizer/mana.model")
        self.padding_side = "right"
        self.pad_token_id = self.special_tokens.get('<|end|>')

    @property
    def tokens(self):
        """Token IDs produced by the most recent call to encode()."""
        return self._tokens

    @property
    def attention_masks(self):
        """Attention mask produced by the most recent call to encode()."""
        return self._attention_masks

    def encode(self, text, allowed_special="none_raise"):
        """Override encode to also store an attention mask; returns self for chaining."""
        encoded_ids = super().encode(text, allowed_special=allowed_special)
        self._tokens = encoded_ids
        self._attention_masks = torch.ones(len(encoded_ids), dtype=torch.int32)
        return self

    def batch_encode(self, texts, padding=True):
        """
        Encode a list of texts with dynamic padding and attention masks.
        The padding side follows self.padding_side ("left" or "right").

        Parameters:
            texts (list of str): List of texts to encode.
            padding (bool): If True, pad sequences to the max length in the batch.

        Returns:
            dict: A dictionary containing input_ids and attention_mask tensors.
        """
        # Encode each text once and build per-text id and mask lists
        encoded_texts = []
        for text in texts:
            ids = self.encode(text).tokens
            encoded_texts.append({"input_ids": ids, "attention_mask": [1] * len(ids)})
        max_len = max(len(t["input_ids"]) for t in encoded_texts) if padding else None

        # Pad each sequence on the configured side
        input_ids = []
        attention_masks = []
        for encoding in encoded_texts:
            ids = encoding["input_ids"]
            attn_mask = encoding["attention_mask"]
            if padding and len(ids) < max_len:
                pad_len = max_len - len(ids)
                if self.padding_side == "left":
                    ids = [self.pad_token_id] * pad_len + ids
                    attn_mask = [0] * pad_len + attn_mask
                else:
                    ids = ids + [self.pad_token_id] * pad_len
                    attn_mask = attn_mask + [0] * pad_len
            input_ids.append(ids)
            attention_masks.append(attn_mask)

        # Convert to tensors
        input_ids = torch.tensor(input_ids, dtype=torch.long)
        attention_masks = torch.tensor(attention_masks, dtype=torch.long)
        return {"input_ids": input_ids, "attention_mask": attention_masks}

    def get_vocab(self):
        """Return the vocabulary dictionary."""
        return self.vocab

    @property
    def vocab_size(self):
        """Return the vocabulary size."""
        return len(self.vocab)

    def train(self, data, vocab_size, cap_divisor=2, max_batch_size=0, verbose=False):
        t0 = time.time()
        ids = self._import_data(data)  # [(bytes, int)] -> text chunks and their counts
        t1 = time.time()
        print(f'Time spent loading data: {t1-t0:.2f}')
        merges = self.merges  # {(int, int): int} -> token pair to new token
        vocab = self.vocab    # {int: bytes} -> token to its bytes representation
        batch_count = 0
        curr_vocab_size = len(vocab)
        num_merges = vocab_size - curr_vocab_size
        merges_remaining = num_merges
        if max_batch_size < 1:
            max_batch_size = num_merges
        stats = get_stats(ids)  # stats are later updated by merge_batch_get_stats
        start_time = time.time()
        while merges_remaining > 0:
            seen_first = set()  # tokens seen in the first position in pairs
            seen_last = set()   # tokens seen in the last position in pairs
            pairs_to_merge = {}
            num_pairs_to_search = min(merges_remaining // cap_divisor, len(vocab), max_batch_size) or 1
            top_pairs = nlargest(num_pairs_to_search, stats, key=stats.get)
            for first, last in top_pairs:  # pairs are (first, last) tuples
                if first in seen_last or last in seen_first:  # unsafe merge
                    seen_first.add(first)
                    seen_last.add(last)
                    continue  # skip this pair but keep looking for safe merges in top_pairs
                seen_first.add(first)
                seen_last.add(last)
                pairs_to_merge[(first, last)] = curr_vocab_size
                vocab[curr_vocab_size] = vocab[first] + vocab[last]
                curr_vocab_size += 1
            merges_remaining -= len(pairs_to_merge)
            merges.update(pairs_to_merge)  # save the merges
            batch_count += 1
            if merges_remaining:  # no need to merge the last batch
                stats = merge_batch_get_stats(ids, pairs_to_merge)  # replace pairs_to_merge keys in ids with their values
            if verbose:
                t2 = time.time()
                time_taken = t2 - start_time
                # estimate remaining time from the average time per completed merge
                merges_done = num_merges - merges_remaining
                avg_time_per_merge = time_taken / max(merges_done, 1)
                estimated_remaining_time = avg_time_per_merge * merges_remaining
                estimated_end_time = time.strftime("%H:%M:%S", time.localtime(time.time() + estimated_remaining_time))
                print(f"Batch {batch_count} merged {len(pairs_to_merge)} pairs in {t2-t1:.2f} sec. "
                      f"Merges remaining: {merges_remaining}. Estimated end time: {estimated_end_time}")
                t1 = t2
        self.merges = merges  # used in encode()
        self.vocab = vocab    # used in decode()
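

# Usage sketch (illustrative, not part of the library): assumes the package is
# importable and that mana_tokenizer/mana.model exists on disk, as loaded in
# __init__ above. Shows single-text encoding and batched encoding with padding
# and attention masks. Run as a module (python -m <package>.<module>) so the
# relative imports at the top resolve.
if __name__ == "__main__":
    tok = ManaTokenizer()

    # Single text: encode() returns the tokenizer itself, so ids and masks are
    # read back through the .tokens and .attention_masks properties.
    enc = tok.encode("hello world")
    print(enc.tokens)            # list[int] of token ids
    print(enc.attention_masks)   # torch.int32 tensor of ones, same length

    # Batch: shorter sequences are padded with '<|end|>' on tok.padding_side
    # ("right" by default); the mask marks real tokens with 1 and padding with 0.
    batch = tok.batch_encode(["hello", "hello world"], padding=True)
    print(batch["input_ids"].shape)     # (2, max_len)
    print(batch["attention_mask"][0])   # e.g. tensor([1, ..., 1, 0, 0]) for right padding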