"""Functions to help with searching codes using regex."""
import pickle
import re

import numpy as np
import torch
from tqdm import tqdm
def load_dataset_cache(cache_base_path):
    """Load cache files required for dataset from `cache_base_path`.

    Args:
        cache_base_path: path prefix under which the `.npy` cache files live
            (file names are appended directly to this string).

    Returns:
        Tuple of (tokens_str, tokens_text, token_byte_pos) numpy arrays.
    """
    file_names = ("tokens_str.npy", "tokens_text.npy", "token_byte_pos.npy")
    tokens_str, tokens_text, token_byte_pos = (
        np.load(cache_base_path + file_name) for file_name in file_names
    )
    return tokens_str, tokens_text, token_byte_pos
def load_code_search_cache(cache_base_path):
    """Load cache files required for code search from `cache_base_path`.

    Args:
        cache_base_path: path prefix under which the cache files live
            (file names are appended directly to this string).

    Returns:
        Tuple of (cb_acts, act_count_ft_tkns, metrics).
    """
    # `metrics.npy` stores a pickled dict as a 0-d object array; unwrap it.
    metrics = np.load(cache_base_path + "metrics.npy", allow_pickle=True).item()

    def _unpickle(suffix):
        # Helper: load one pickle file relative to the cache prefix.
        with open(cache_base_path + suffix, "rb") as f:
            return pickle.load(f)

    cb_acts = _unpickle("cb_acts.pkl")
    act_count_ft_tkns = _unpickle("act_count_ft_tkns.pkl")
    return cb_acts, act_count_ft_tkns, metrics
def search_re(re_pattern, tokens_text, at_odd_even=-1):
    """Get list of (example_id, match_pos) where re_pattern matches in tokens_text.

    Args:
        re_pattern: regex pattern to search for. If the pattern contains no
            parenthesis, it is wrapped in a group so the match span can be read
            from group 1.
        tokens_text: list of example texts.
        at_odd_even: to limit matches to odd or even positions only.
            -1 (default): to not limit matches.
            0: to limit matches to even positions only.
            1: to limit matches to odd positions only.
            This is useful for the TokFSM dataset when searching for states
            since the first token of states are always at even positions.

    Returns:
        List of (example_id, start_pos) tuples, one per non-empty match of
        group 1 of the pattern.
    """
    # TODO: ensure that parentheses are not escaped
    assert at_odd_even in [-1, 0, 1], f"Invalid at_odd_even: {at_odd_even}"
    # The span of group 1 is used below, so guarantee the pattern has a group.
    if re_pattern.find("(") == -1:
        re_pattern = f"({re_pattern})"
    res = [
        (i, match.span(1)[0])
        for i, text in enumerate(tokens_text)
        for match in re.finditer(re_pattern, text)
        if match.span(1)[0] != match.span(1)[1]  # skip empty matches
    ]
    if at_odd_even != -1:
        # Keep only matches whose start position has the requested parity
        # (0 keeps even positions, 1 keeps odd positions).
        res = [r for r in res if r[1] % 2 == at_odd_even]
    return res
def byte_id_to_token_pos_id(example_byte_id, token_byte_pos):
    """Convert byte position (or character position in a text) to its token position.

    Used to convert the searched regex span to its token position.

    Args:
        example_byte_id: tuple of (example_id, byte_id) where byte_id is a
            character's position in the text.
        token_byte_pos: numpy array of shape (num_examples, seq_len) where
            `token_byte_pos[example_id][token_pos]` is the byte position of
            the token at `token_pos` in the example with `example_id`.

    Returns:
        (example_id, token_pos_id) tuple.
    """
    example_id, byte_id = example_byte_id
    # Right-sided binary search: the first token whose end byte lies past
    # `byte_id` is the token containing that byte.
    token_pos_id = np.searchsorted(
        token_byte_pos[example_id], byte_id, side="right"
    )
    return (example_id, token_pos_id)
def get_code_precision_and_recall(
    token_pos_ids, codebook_acts, cb_act_counts=None, recall_threshold=0.01
):
    """Search for the codes that activate on the given `token_pos_ids`.

    Args:
        token_pos_ids: list of (example_id, token_pos_id) tuples.
        codebook_acts: numpy array of activations of a codebook on a dataset with
            shape (num_examples, seq_len, k_codebook).
        cb_act_counts: optional array of shape (num_codes,) where
            `cb_act_counts[code]` is the number of times the code `code` is
            activated in the dataset. Required for precision; when None,
            precision is reported as zero and codes are sorted by recall.
        recall_threshold: codes whose recall is at or below this value are
            dropped (default 0.01, the original hard-coded cutoff).

    Returns:
        codes: numpy array of code ids sorted (descending) by their precision
            (or by recall when `cb_act_counts` is None) on the given `token_pos_ids`.
        prec: numpy array where `prec[i]` is the precision of the code
            `codes[i]` for the given `token_pos_ids`.
        recall: numpy array where `recall[i]` is the recall of the code
            `codes[i]` for the given `token_pos_ids`.
        code_acts: numpy array where `code_acts[i]` is the number of times
            the code `codes[i]` is activated in the dataset (zeros when
            `cb_act_counts` is None).
    """
    codes = np.array(
        [
            codebook_acts[example_id][token_pos_id]
            for example_id, token_pos_id in token_pos_ids
        ]
    )
    codes, counts = np.unique(codes, return_counts=True)
    # Recall: fraction of the matched positions on which each code fires.
    recall = counts / len(token_pos_ids)
    # Drop codes that fire on too small a fraction of the matches.
    idx = recall > recall_threshold
    codes, counts, recall = codes[idx], counts[idx], recall[idx]
    if cb_act_counts is not None:
        code_acts = np.array([cb_act_counts[code] for code in codes])
        # Precision: of all the code's firings in the dataset, the fraction
        # that land on the matched positions.
        prec = counts / code_acts
        sort_idx = np.argsort(prec)[::-1]
    else:
        # Without dataset-wide counts, precision is unknown; fall back to
        # sorting by recall.
        code_acts = np.zeros_like(codes)
        prec = np.zeros_like(codes)
        sort_idx = np.argsort(recall)[::-1]
    codes, prec, recall = codes[sort_idx], prec[sort_idx], recall[sort_idx]
    code_acts = code_acts[sort_idx]
    return codes, prec, recall, code_acts
def get_neuron_precision_and_recall(
    token_pos_ids, recall, neuron_acts_by_ex, neuron_sorted_acts
):
    """Get the neurons with the highest precision and recall for the given `token_pos_ids`.

    Args:
        token_pos_ids: list of token (example_id, token_pos_id) tuples from a dataset over which
            the neurons with the highest precision and recall are to be found.
        recall: recall threshold for the neurons (this determines their activation threshold).
        neuron_acts_by_ex: numpy array of activations of all the attention and mlp output neurons
            on a dataset with shape (num_examples, seq_len, num_layers, 2, dim_size).
            The third dimension is 2 because we consider neurons from both: attention and mlp.
        neuron_sorted_acts: numpy array of sorted activations of all the attention and mlp output neurons
            on a dataset with shape (num_layers, 2, dim_size, num_examples * seq_len).
            This should be obtained using the `neuron_acts_by_ex` array by rearranging the first two
            dimensions to the last dimensions and then sorting the last dimension.
            NOTE(review): `torch.searchsorted` is applied to it below, so it is
            presumably a torch.Tensor rather than a numpy array — confirm.

    Returns:
        best_prec: highest precision amongst all the neurons for the given `token_pos_ids`.
        best_neuron_acts: array of (example_id, token_pos) index pairs for every
            dataset position where the best neuron's activation reaches its
            threshold (the threshold being determined by the `recall` argument).
        best_neuron_idx: tuple of (layer, is_mlp, neuron_id) where `layer` is the layer number,
            `is_mlp` is 0 if the neuron is from attention and 1 if the neuron is from mlp,
            and `neuron_id` is the neuron's index in the layer.
    """
    # Gather every neuron's activations at the matched token positions and sort
    # them along the match axis, so a quantile threshold can be read by index.
    if isinstance(neuron_acts_by_ex, torch.Tensor):
        neuron_acts_on_pattern = torch.stack(
            [
                neuron_acts_by_ex[example_id, token_pos_id]
                for example_id, token_pos_id in token_pos_ids
            ],
            dim=-1,
        )  # (layers, 2, dim_size, matches)
        neuron_acts_on_pattern = torch.sort(neuron_acts_on_pattern, dim=-1).values
    else:
        neuron_acts_on_pattern = np.stack(
            [
                neuron_acts_by_ex[example_id, token_pos_id]
                for example_id, token_pos_id in token_pos_ids
            ],
            axis=-1,
        )  # (layers, 2, dim_size, matches)
        neuron_acts_on_pattern.sort(axis=-1)
        # Convert to torch so torch.searchsorted below works in both branches.
        neuron_acts_on_pattern = torch.from_numpy(neuron_acts_on_pattern)
    # Per-neuron activation threshold: the value such that a `recall` fraction
    # of the matched positions activate at or above it (taken from the sorted
    # tail of the match-axis).
    # NOTE(review): when int(recall * num_matches) == 0 the index is -0 == 0,
    # which selects the *smallest* matched activation — confirm this is the
    # intended behavior for very small match sets.
    act_thresh = neuron_acts_on_pattern[
        :, :, :, -int(recall * neuron_acts_on_pattern.shape[-1])
    ]
    assert neuron_sorted_acts.shape[:-1] == act_thresh.shape
    # searchsorted gives, per neuron, how many dataset activations fall below
    # its threshold; subtracting from the total yields the count at/above it.
    prec_den = torch.searchsorted(neuron_sorted_acts, act_thresh.unsqueeze(-1))
    prec_den = prec_den.squeeze(-1)
    prec_den = neuron_sorted_acts.shape[-1] - prec_den
    # Precision = (#matched positions above threshold) / (#dataset positions
    # above threshold), per neuron.
    prec = int(recall * neuron_acts_on_pattern.shape[-1]) / prec_den
    assert (
        prec.shape == neuron_acts_on_pattern.shape[:-1]
    ), f"{prec.shape} != {neuron_acts_on_pattern.shape[:-1]}"
    # Locate the single best (layer, is_mlp, neuron_id) by precision.
    best_neuron_idx = np.unravel_index(prec.argmax(), prec.shape)
    best_prec = prec[best_neuron_idx]
    best_neuron_act_thresh = act_thresh[best_neuron_idx].item()
    # All (example, pos) activations of the best neuron over the full dataset.
    best_neuron_acts = neuron_acts_by_ex[
        :, :, best_neuron_idx[0], best_neuron_idx[1], best_neuron_idx[2]
    ]
    best_neuron_acts = best_neuron_acts >= best_neuron_act_thresh
    best_neuron_acts = np.stack(np.where(best_neuron_acts), axis=-1)
    return best_prec, best_neuron_acts, best_neuron_idx
def convert_to_adv_name(name, cb_at, gcb=""):
    """Convert layer0_head0 to layer0_attn_preproj_gcb0.

    Args:
        name: base name, e.g. "layer0_head0" (grouped) or "layer0".
        cb_at: codebook attachment point, e.g. "attn_preproj".
        gcb: "_gcb" for grouped codebooks and "" for non-grouped codebooks.

    Returns:
        The advanced name string.
    """
    if gcb:
        layer, head = name.split("_")
        # head is e.g. "head0"; strip the "head" prefix to get the index.
        return layer + f"_{cb_at}_gcb" + head[4:]
    else:
        # Bug fix: `layer` was previously unbound in this branch (NameError).
        # For non-grouped codebooks the name is the layer component itself.
        layer = name.split("_")[0]
        return layer + "_" + cb_at
def convert_to_base_name(name, gcb=""):
    """Convert layer0_attn_preproj_gcb0 to layer0_head0.

    Args:
        name: advanced codebook name, e.g. "layer0_attn_preproj_gcb0".
        gcb: "_gcb" for grouped codebooks and "" for non-grouped codebooks
            (unused; grouping is detected from `name` itself).

    Returns:
        The base name string ("layerN_headM" for grouped, "layerN" otherwise).
    """
    parts = name.split("_")
    layer = parts[0]
    if "gcb" in name:
        # The trailing component is e.g. "gcb0"; strip the "gcb" prefix.
        head = parts[-1][3:]
        return f"{layer}_head{head}"
    return layer
def get_layer_head_from_base_name(name):
    """Convert layer0_head0 to 0, 0.

    Args:
        name: base name, e.g. "layer0_head0" or "layer0".

    Returns:
        (layer, head) tuple of ints; head is None when `name` has no
        head component.
    """
    parts = name.split("_")
    # "layerN" -> N
    layer = int(parts[0][len("layer"):])
    # "headM" -> M when a head component is present, otherwise None.
    head = int(parts[-1][len("head"):]) if len(parts) > 1 else None
    return layer, head
def get_layer_head_from_adv_name(name):
    """Convert layer0_attn_preproj_gcb0 to 0, 0.

    Args:
        name: advanced codebook name.

    Returns:
        (layer, head) tuple of ints (head is None for non-grouped names).
    """
    # Normalize to the base name, then parse out the numeric indices.
    return get_layer_head_from_base_name(convert_to_base_name(name))
def get_codes_from_pattern(
    re_pattern,
    tokens_text,
    token_byte_pos,
    cb_acts,
    act_count_ft_tkns,
    gcb="",
    topk=5,
    prec_threshold=0.5,
    at_odd_even=-1,
):
    """Fetch codes that activate on a given regex pattern.

    Retrieves at most `topk` codes per codebook that activate with precision
    above `prec_threshold`.

    Args:
        re_pattern: regex pattern to search for.
        tokens_text: list of example texts of a dataset.
        token_byte_pos: numpy array of shape (num_examples, seq_len) where
            `token_byte_pos[example_id][token_pos]` is the byte position of
            the token at `token_pos` in the example with `example_id`.
        cb_acts: dict of codebook activations.
        act_count_ft_tkns: dict over all codebooks of number of token activations on the dataset.
        gcb: "_gcb" for grouped codebooks and "" for non-grouped codebooks.
        topk: maximum number of codes to return per codebook.
        prec_threshold: minimum precision required for a code to be returned.
        at_odd_even: to limit matches to odd or even positions only.
            -1 (default): to not limit matches.
            0: to limit matches to odd positions only.
            1: to limit matches to even positions only.
            This is useful for the TokFSM dataset when searching for states
            since the first token of states are always at even positions.

    Returns:
        codebook_wise_codes: dict of codebook name to list of
            (code, prec, recall, code_acts) tuples.
        re_token_matches: number of tokens that match the regex pattern.
    """
    # Find all regex match positions, map them to token positions, and
    # de-duplicate (several byte positions can map to one token).
    byte_ids = search_re(re_pattern, tokens_text, at_odd_even=at_odd_even)
    token_pos_ids = np.unique(
        [byte_id_to_token_pos_id(byte_id, token_byte_pos) for byte_id in byte_ids],
        axis=0,
    )
    re_token_matches = len(token_pos_ids)
    codebook_wise_codes = {}
    for cb_name, cb in tqdm(cb_acts.items()):
        base_cb_name = convert_to_base_name(cb_name, gcb=gcb)
        codes, prec, recall, code_acts = get_code_precision_and_recall(
            token_pos_ids,
            cb,
            cb_act_counts=act_count_ft_tkns[base_cb_name],
        )
        # Keep at most `topk` codes, and of those only the ones whose
        # precision clears the threshold.
        keep = np.arange(min(topk, len(codes)))
        keep = keep[prec[:topk] > prec_threshold]
        codebook_wise_codes[base_cb_name] = list(
            zip(codes[keep], prec[keep], recall[keep], code_acts[keep])
        )
    return codebook_wise_codes, re_token_matches
def get_neurons_from_pattern(
    re_pattern,
    tokens_text,
    token_byte_pos,
    neuron_acts_by_ex,
    neuron_sorted_acts,
    recall_threshold,
    at_odd_even=-1,
):
    """Fetch the highest precision neurons that activate on a given regex pattern.

    The activation threshold for the neurons is determined by the `recall_threshold`.

    Args:
        re_pattern: regex pattern to search for.
        tokens_text: list of example texts of a dataset.
        token_byte_pos: numpy array of shape (num_examples, seq_len) where
            `token_byte_pos[example_id][token_pos]` is the byte position of
            the token at `token_pos` in the example with `example_id`.
        neuron_acts_by_ex: numpy array of activations of all the attention and mlp output neurons
            on a dataset with shape (num_examples, seq_len, num_layers, 2, dim_size).
            The third dimension is 2 because we consider neurons from both: attention and mlp.
        neuron_sorted_acts: numpy array of sorted activations of all the attention and mlp output neurons
            on a dataset with shape (num_layers, 2, dim_size, num_examples * seq_len).
            This should be obtained using the `neuron_acts_by_ex` array by rearranging the first two
            dimensions to the last dimensions and then sorting the last dimension.
        recall_threshold: recall threshold for the neurons (this determines their activation threshold).
        at_odd_even: to limit matches to odd or even positions only.
            -1 (default): to not limit matches.
            0: to limit matches to odd positions only.
            1: to limit matches to even positions only.
            This is useful for the TokFSM dataset when searching for states
            since the first token of states are always at even positions.

    Returns:
        best_prec: highest precision amongst all the neurons for the given `token_pos_ids`.
        best_neuron_acts: activations of the best neuron for the given `token_pos_ids`
            based on the threshold determined by the `recall_threshold` argument.
        best_neuron_idx: tuple of (layer, is_mlp, neuron_id) where `layer` is the layer number,
            `is_mlp` is 0 if the neuron is from attention and 1 if the neuron is from mlp,
            and `neuron_id` is the neuron's index in the layer.
        re_token_matches: number of tokens that match the regex pattern.
    """
    # Find regex match positions, convert them to token positions, and
    # de-duplicate before scoring the neurons.
    byte_ids = search_re(re_pattern, tokens_text, at_odd_even=at_odd_even)
    token_pos_ids = np.unique(
        [byte_id_to_token_pos_id(byte_id, token_byte_pos) for byte_id in byte_ids],
        axis=0,
    )
    best_prec, best_neuron_acts, best_neuron_idx = get_neuron_precision_and_recall(
        token_pos_ids,
        recall_threshold,
        neuron_acts_by_ex,
        neuron_sorted_acts,
    )
    return best_prec, best_neuron_acts, best_neuron_idx, len(token_pos_ids)
def compare_codes_with_neurons(
    best_codes_info,
    tokens_text,
    token_byte_pos,
    neuron_acts_by_ex,
    neuron_sorted_acts,
    at_odd_even=-1,
):
    """Compare codes with the highest precision neurons on the regex pattern of the code.

    Args:
        best_codes_info: list of CodeInfo objects.
        tokens_text: list of example texts of a dataset.
        token_byte_pos: numpy array of shape (num_examples, seq_len) where
            `token_byte_pos[example_id][token_pos]` is the byte position of
            the token at `token_pos` in the example with `example_id`.
        neuron_acts_by_ex: numpy array of activations of all the attention and mlp output neurons
            on a dataset with shape (num_examples, seq_len, num_layers, 2, dim_size).
            The third dimension is 2 because we consider neurons from both: attention and mlp.
        neuron_sorted_acts: numpy array of sorted activations of all the attention and mlp output neurons
            on a dataset with shape (num_layers, 2, dim_size, num_examples * seq_len).
            This should be obtained using the `neuron_acts_by_ex` array by rearranging the first two
            dimensions to the last dimensions and then sorting the last dimension.
        at_odd_even: to limit matches to odd or even positions only.
            -1 (default): to not limit matches.
            0: to limit matches to odd positions only.
            1: to limit matches to even positions only.
            This is useful for the TokFSM dataset when searching for states
            since the first token of states are always at even positions.

    Returns:
        codes_better_than_neurons: fraction of codes that have higher precision than the highest
            precision neuron on the regex pattern of the code.
        code_best_precs: array of the precision of each code in `best_codes_info`.
        neuron_best_prec: array of the highest-precision neuron's precision per regex pattern.
    """
    assert isinstance(neuron_acts_by_ex, np.ndarray)
    # For each code, find the best neuron on that code's regex pattern.
    per_code_results = [
        get_neurons_from_pattern(
            code_info.regex,
            tokens_text,
            token_byte_pos,
            neuron_acts_by_ex,
            neuron_sorted_acts,
            code_info.recall,
            at_odd_even=at_odd_even,
        )
        for code_info in tqdm(best_codes_info)
    ]
    # Transpose the per-code result tuples; only the precisions are needed.
    neuron_best_prec, _acts, _idxs, _matches = zip(*per_code_results, strict=True)
    neuron_best_prec = np.array(neuron_best_prec)
    code_best_precs = np.array([code_info.prec for code_info in best_codes_info])
    codes_better_than_neurons = code_best_precs > neuron_best_prec
    return codes_better_than_neurons.mean(), code_best_precs, neuron_best_prec