# mana_tokenizer/helper.py
from collections import Counter, defaultdict
import re  # needed by the _process_* helpers below; use `import regex as re` instead if the split pattern relies on \p{...} classes
import unicodedata
def get_stats(ids):
"""
Given `ids`, a list of 2-tuples of iterables of ints and int values,
returns a defaultdict with the counts of occurrences of all the consecutive
pairs of integers within each bytes object, multiplied by the integer value
associated with each key. This function does not count pairs between the last
element of one key the first element of the next key. The integer value
associated with each key serves as a multiplier for the count of each pair
within that object. Consecutive identical pairs within the same bytes object
are counted only once to avoid overcounting repeat characters.
Example:
get_stats({b'abc': 2, b'bcd': 1, b'eee': 1})
-> defaultdict(<class 'int'>, {(97, 98): 1, (98, 99): 2, (99, 100): 1, (101, 101): 1})
"""
    counts = defaultdict(int)
    for chunk, num in ids:
        # count every adjacent pair in this chunk, weighted by the chunk's count
        for i in range(len(chunk) - 1):
            counts[(chunk[i], chunk[i + 1])] += num
    return counts
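
# Illustrative check of get_stats (hand-verified): [1, 2, 3] occurs twice and
# [2, 3] once, so the pair (2, 3) is counted 2 + 1 = 3 times.
#   >>> get_stats([([1, 2, 3], 2), ([2, 3], 1)])
#   defaultdict(<class 'int'>, {(1, 2): 2, (2, 3): 3})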
def merge_batch_get_stats(ids, pairs):
    """
    Fused merge-and-count pass. For each (chunk, count) 2-tuple in `ids`,
    apply every merge in `pairs` (a dict mapping an int pair to its new token
    id) to the chunk in place, while accumulating the pair counts of the
    *merged* chunks, weighted by each chunk's count. The result is what
    get_stats would return after running merge for every pair, except that
    tokens created during this pass are not themselves merge candidates.
    Chunks must be mutable sequences (lists, not bytes).
    """
    counts = defaultdict(int)
    for chunk, num in ids:
        last_index = len(chunk) - 1
        i = 0
        while i < last_index:
            j = i + 1
            token = pairs.get((chunk[i], chunk[j]))
            if token is not None:
                # collapse the pair into its new token and shrink the chunk
                chunk[i] = token
                del chunk[j]
                last_index -= 1
            # count the pair ending at position i; everything up to i is final
            if i:
                counts[(chunk[i-1], chunk[i])] += num
            i = j
        # if the loop's last step did not merge, the final pair is still uncounted
        if i and i == last_index:
            counts[(chunk[-2], chunk[i])] += num
    return counts
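
# Hand-checked sketch of the fused pass: merging (1, 2) -> 4 inside
# [1, 2, 3, 1, 2] yields [4, 3, 4], whose pair stats are (4, 3) and (3, 4).
#   >>> chunk = [1, 2, 3, 1, 2]
#   >>> merge_batch_get_stats([(chunk, 1)], {(1, 2): 4})
#   defaultdict(<class 'int'>, {(4, 3): 1, (3, 4): 1})
#   >>> chunk
#   [4, 3, 4]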
def merge(ids, pair, idx, len_ids):
    """
    In the list of integers `ids` (with logical length `len_ids`), replace
    every consecutive occurrence of `pair` with the new token `idx`, in
    place, and return the new length.
    Example: ids=[1, 2, 3, 1, 2], pair=(1, 2), idx=4 -> ids becomes [4, 3, 4], returns 3
    """
    i = 0
    while i + 1 < len_ids:
        j = i + 1
        if ids[i] == pair[0] and ids[j] == pair[1]:
            # collapse the pair into the new token and shrink the list
            ids[i] = idx
            del ids[j]
            len_ids -= 1
        i = j
    return len_ids
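
# Usage sketch: merge mutates ids in place and returns the new length.
#   >>> ids = [1, 2, 3, 1, 2]
#   >>> merge(ids, (1, 2), 4, len(ids))
#   3
#   >>> ids
#   [4, 3, 4]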
def replace_control_characters(s: str) -> str:
# we don't want to print control characters
# which distort the output (e.g. \n or much worse)
# https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python/19016117#19016117
# http://www.unicode.org/reports/tr44/#GC_Values_Table
chars = []
for ch in s:
if unicodedata.category(ch)[0] != "C":
chars.append(ch) # this character is ok
else:
chars.append(f"\\u{ord(ch):04x}") # escape
return "".join(chars)
def render_token(t: bytes) -> str:
# pretty print a token, escaping control characters
s = t.decode('utf-8', errors='replace')
s = replace_control_characters(s)
return s
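
# Example: the token's bytes are decoded, then control characters escaped.
#   >>> render_token(b'hi\nthere')
#   'hi\\u000athere'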
def _process_dicts(batch, compiled_pattern):  # for a raw datasets.Dataset batch whose rows are plain strings
    counter = Counter()
    for item in batch:
        counter.update(re.findall(compiled_pattern, item))
    return counter
def _process_string_scalar(batch, compiled_pattern):  # for Arrow-backed batches whose rows are pyarrow string scalars
    counter = Counter()
    for item in batch:
        # each row is an Arrow scalar; .as_py() converts it to a Python str
        counter.update(re.findall(compiled_pattern, item.as_py()))
    return counter
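
# Usage sketch (the split pattern here is hypothetical; the real one belongs
# to the tokenizer): both helpers return a Counter of regex matches, differing
# only in whether rows arrive as plain strs or as Arrow scalars.
#   >>> pat = re.compile(r"\w+|\S")
#   >>> _process_dicts(["salam donya", "salam"], pat)
#   Counter({'salam': 2, 'donya': 1})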