|
|
|
from collections import Counter, defaultdict |
|
import unicodedata |
|
|
|
def get_stats(ids):
    """
    Count consecutive integer pairs across a batch of weighted chunks.

    `ids` is an iterable of (chunk, num) 2-tuples, where `chunk` is a
    sequence of ints (e.g. a bytes object or a list of token ids) and
    `num` is how many times that chunk occurs. Every consecutive pair
    inside a chunk contributes `num` to its count. Pairs are never
    counted across the boundary between two chunks.

    Returns a defaultdict(int) mapping (int, int) pairs to weighted counts.

    Example:
        get_stats([(b'abc', 2), (b'bcd', 1), (b'eee', 1)])
        -> defaultdict(<class 'int'>,
                       {(97, 98): 2, (98, 99): 3, (99, 100): 1, (101, 101): 2})
    """
    counts = defaultdict(int)
    for chunk, num in ids:
        # zip of the chunk against itself shifted by one yields every
        # consecutive pair; for bytes objects the elements are ints.
        for pair in zip(chunk, chunk[1:]):
            counts[pair] += num
    return counts
|
|
|
def merge_batch_get_stats(ids, pairs):
    """
    Apply a batch of pair merges to each chunk in `ids` (in place) and
    tally pair counts around the merge points.

    `ids` is an iterable of (chunk, num) tuples where `chunk` is a
    MUTABLE sequence of ints (it is modified in place) and `num` is the
    chunk's occurrence count. `pairs` maps an (int, int) pair to its
    replacement token id. Wherever a mapped pair occurs, the two
    elements are collapsed into the replacement token, and the pair
    formed between the new token and its LEFT neighbour is added to the
    returned counts, weighted by `num`.

    Returns a defaultdict(int) of pair -> weighted count.

    NOTE(review): pairs formed with the RIGHT neighbour of a merged
    token are only picked up by the trailing end-of-chunk check below,
    and that check also fires for chunks in which no merge occurred at
    all (counting the last pair of every chunk of length >= 2 whose
    final pair was not itself merged) — confirm this is the intended
    contract with the caller.
    """
    counts = defaultdict(int)
    for chunk, num in ids:
        last_index = len(chunk) - 1
        i = 0
        while i < last_index:
            j = i + 1
            token = pairs.get((chunk[i], chunk[j]))
            if token is not None:
                # Collapse chunk[i] and chunk[j] into the merged token.
                chunk[i] = token
                del chunk[j]
                last_index -= 1  # the chunk just shrank by one element
                if i:
                    # Count the (left neighbour, merged token) pair.
                    counts[(chunk[i-1], chunk[i])] += num
            # Advance by one position. After a merge this skips the
            # element that slid into position j, so a freshly created
            # boundary is not immediately re-examined for another merge.
            i = j
        if i and i == last_index:
            # Loop stopped exactly on the last element: count the final
            # trailing pair of the chunk.
            counts[(chunk[-2], chunk[i])] += num
    return counts
|
|
|
def merge(ids, pair, idx, len_ids):
    """
    Collapse every consecutive occurrence of `pair` in the list `ids`
    (modified in place) into the single new token `idx`, and return the
    updated length of `ids`.

    Example: ids=[1, 2, 3, 1, 2], pair=(1, 2), idx=4 -> [4, 3, 4]
    """
    first, second = pair
    pos = 0
    while pos + 1 < len_ids:
        if ids[pos] == first and ids[pos + 1] == second:
            # Replace the left element with the merged token and drop
            # the right one; the list shrinks by one.
            ids[pos] = idx
            del ids[pos + 1]
            len_ids -= 1
        pos += 1
    return len_ids
|
|
|
def replace_control_characters(s: str) -> str:
    """
    Return `s` with every Unicode control character (general category
    starting with "C") replaced by its escaped form \\uXXXX; all other
    characters pass through unchanged.
    """
    def escape(ch: str) -> str:
        # Control characters (Cc, Cf, ...) would distort printed output,
        # so render them as an escape sequence instead.
        if unicodedata.category(ch).startswith("C"):
            return f"\\u{ord(ch):04x}"
        return ch

    return "".join(map(escape, s))
|
|
|
def render_token(t: bytes) -> str:
    """
    Decode the token bytes as UTF-8 (invalid sequences become U+FFFD via
    errors='replace') and escape any control characters for safe display.
    """
    decoded = t.decode('utf-8', errors='replace')
    return replace_control_characters(decoded)
|
|
|
def _process_dicts(batch, compiled_pattern): |
|
counter = Counter() |
|
for item in batch: |
|
counter.update(re.findall(compiled_pattern, item)) |
|
return counter |
|
|
|
def _process_string_scalar(batch, compiled_pattern):
    """
    Tally regex matches across a batch of string scalars.

    `batch` is an iterable of scalar objects exposing `.as_py()`
    (presumably pyarrow string scalars — each converts to a Python str);
    `compiled_pattern` is a compiled regex pattern whose findall()
    results are accumulated into a single Counter.

    Returns a Counter mapping each match to its total occurrence count.
    """
    counter = Counter()
    for item in batch:
        # Bug fix: the original called `re.findall(...)`, but this module
        # never imports `re`, raising NameError at runtime. Calling the
        # compiled pattern's own findall() is equivalent and requires no
        # module-level import.
        counter.update(compiled_pattern.findall(item.as_py()))
    return counter