File size: 18,041 Bytes
7f9376c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63b5bc1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7f9376c
 
63b5bc1
7f9376c
 
 
 
 
63b5bc1
 
 
7f9376c
 
 
63b5bc1
 
 
 
 
 
 
 
 
 
 
 
 
 
7f9376c
 
 
 
 
63b5bc1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7f9376c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63b5bc1
 
7f9376c
63b5bc1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7f9376c
63b5bc1
7f9376c
 
 
 
 
 
63b5bc1
7f9376c
63b5bc1
7f9376c
 
 
 
 
 
63b5bc1
 
 
 
 
7f9376c
 
 
 
63b5bc1
7f9376c
63b5bc1
 
7f9376c
 
 
 
 
 
 
 
 
 
 
 
 
63b5bc1
 
 
7f9376c
63b5bc1
7f9376c
 
 
 
63b5bc1
 
7f9376c
 
63b5bc1
7f9376c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63b5bc1
7f9376c
 
 
 
 
 
 
 
 
 
 
63b5bc1
7f9376c
 
63b5bc1
7f9376c
63b5bc1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7f9376c
 
 
 
 
 
 
63b5bc1
 
7f9376c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63b5bc1
7f9376c
63b5bc1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7f9376c
 
 
 
 
63b5bc1
7f9376c
 
 
 
 
 
 
 
 
 
 
 
 
 
63b5bc1
7f9376c
63b5bc1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7f9376c
 
63b5bc1
7f9376c
 
 
 
 
 
63b5bc1
7f9376c
 
 
 
 
63b5bc1
7f9376c
63b5bc1
7f9376c
 
 
63b5bc1
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
"""Functions to help with searching codes using regex."""

import pickle
import re

import numpy as np
import torch
from tqdm import tqdm


def load_dataset_cache(cache_base_path):
    """Read the dataset token caches stored under `cache_base_path`.

    `cache_base_path` is used as a plain string prefix (it is concatenated with
    each file name rather than joined as a path), so it normally ends with a
    path separator.

    Returns:
        Tuple `(tokens_str, tokens_text, token_byte_pos)` loaded, in that order,
        from `tokens_str.npy`, `tokens_text.npy` and `token_byte_pos.npy`.
    """
    file_names = ("tokens_str.npy", "tokens_text.npy", "token_byte_pos.npy")
    tokens_str, tokens_text, token_byte_pos = (
        np.load(cache_base_path + file_name) for file_name in file_names
    )
    return tokens_str, tokens_text, token_byte_pos


def load_code_search_cache(cache_base_path):
    """Read the code-search caches stored under `cache_base_path`.

    `cache_base_path` is a string prefix concatenated with each file name.
    NOTE: `metrics.npy` is loaded with `allow_pickle=True` and the `.pkl` files
    are unpickled, so this must only be pointed at trusted cache files.

    Returns:
        Tuple `(cb_acts, act_count_ft_tkns, metrics)` where `metrics` is the
        object stored in `metrics.npy` (unwrapped with `.item()`), and the other
        two are the unpickled contents of `cb_acts.pkl` and
        `act_count_ft_tkns.pkl`.
    """

    def _unpickle(file_name):
        # Open, unpickle and close one cache file.
        with open(cache_base_path + file_name, "rb") as pkl_file:
            return pickle.load(pkl_file)

    metrics = np.load(cache_base_path + "metrics.npy", allow_pickle=True).item()
    cb_acts = _unpickle("cb_acts.pkl")
    act_count_ft_tkns = _unpickle("act_count_ft_tkns.pkl")
    return cb_acts, act_count_ft_tkns, metrics


def search_re(re_pattern, tokens_text, at_odd_even=-1):
    """Get list of (example_id, char_pos) where `re_pattern` matches in `tokens_text`.

    The reported position is the start offset (within the example's text) of
    the pattern's first capturing group. If the pattern has no capturing group,
    the whole pattern is wrapped in one. Matches whose first group is empty, or
    did not take part in the match, are skipped.

    Args:
        re_pattern: regex pattern to search for.
        tokens_text: list of example texts.
        at_odd_even: to limit matches to even or odd character offsets only.
            -1 (default): to not limit matches.
            0: to keep matches starting at even offsets only.
            1: to keep matches starting at odd offsets only.
            This is useful for the TokFSM dataset when searching for states
            since the first token of states are always at even positions.

    Returns:
        List of (example_id, char_pos) tuples, ordered by example and then by
        match position.
    """
    assert at_odd_even in [-1, 0, 1], f"Invalid at_odd_even: {at_odd_even}"
    # Count real capturing groups instead of looking for a "(" character:
    # escaped parentheses ("\("), character classes ("[(]") and non-capturing
    # groups ("(?:...)") do not create group 1, so `span(1)` used to raise
    # IndexError for such patterns.
    compiled = re.compile(re_pattern)
    if compiled.groups == 0:
        compiled = re.compile(f"({re_pattern})")
    res = []
    for i, text in enumerate(tokens_text):
        for match in compiled.finditer(text):
            start, end = match.span(1)
            # An empty or non-participating group ((-1, -1)) has start == end.
            if start != end:
                res.append((i, start))
    if at_odd_even != -1:
        res = [r for r in res if r[1] % 2 == at_odd_even]
    return res


def byte_id_to_token_pos_id(example_byte_id, token_byte_pos):
    """Map a character (byte) offset in an example's text to a token position.

    Used to turn the start of a regex match into a token index.

    Args:
        example_byte_id: pair `(example_id, byte_id)`, where `byte_id` is an
            offset into the text of example `example_id`.
        token_byte_pos: array of shape (num_examples, seq_len) whose row
            `token_byte_pos[example_id]` holds the byte positions of that
            example's tokens. The row must be sorted ascending, since it is
            searched with `np.searchsorted`.

    Returns:
        `(example_id, token_pos_id)` where `token_pos_id` is the number of
        entries in `token_byte_pos[example_id]` that are `<= byte_id`.
    """
    ex_id, char_offset = example_byte_id
    boundaries = token_byte_pos[ex_id]
    token_idx = np.searchsorted(boundaries, char_offset, side="right")
    return ex_id, token_idx


def get_code_precision_and_recall(
    token_pos_ids, codebook_acts, cb_act_counts=None, min_recall=0.01
):
    """Search for the codes that activate on the given `token_pos_ids`.

    Args:
        token_pos_ids: list of (example_id, token_pos_id) tuples.
        codebook_acts: numpy array of activations of a codebook on a dataset with
            shape (num_examples, seq_len, k_codebook), indexed as
            `codebook_acts[example_id][token_pos_id]`.
        cb_act_counts: mapping or array indexed by code, where
            `cb_act_counts[code]` is the number of times `code` is activated in
            the dataset. If None, precision is not computed (all zeros) and
            codes are sorted by recall instead.
        min_recall: codes whose recall on `token_pos_ids` is not strictly
            greater than this value are dropped (default 0.01, the previously
            hard-coded cutoff).

    Returns:
        codes: numpy array of code ids, sorted by descending precision (or by
            descending recall when `cb_act_counts` is None).
        prec: numpy array where `prec[i]` is the precision of the code
            `codes[i]` for the given `token_pos_ids`.
        recall: numpy array where `recall[i]` is the recall of the code
            `codes[i]` for the given `token_pos_ids`.
        code_acts: numpy array where `code_acts[i]` is the number of times
            the code `codes[i]` is activated in the dataset (zeros when
            `cb_act_counts` is None).
    """
    codes = np.array(
        [
            codebook_acts[example_id][token_pos_id]
            for example_id, token_pos_id in token_pos_ids
        ]
    )
    # np.unique flattens, so with k_codebook > 1 every active code at every
    # matched token is counted.
    codes, counts = np.unique(codes, return_counts=True)
    recall = counts / len(token_pos_ids)
    keep = recall > min_recall
    codes, counts, recall = codes[keep], counts[keep], recall[keep]
    if cb_act_counts is not None:
        code_acts = np.array([cb_act_counts[code] for code in codes])
        prec = counts / code_acts
        sort_idx = np.argsort(prec)[::-1]
    else:
        code_acts = np.zeros_like(codes)
        prec = np.zeros_like(codes)
        sort_idx = np.argsort(recall)[::-1]
    codes, prec, recall = codes[sort_idx], prec[sort_idx], recall[sort_idx]
    code_acts = code_acts[sort_idx]
    return codes, prec, recall, code_acts


def get_neuron_precision_and_recall(
    token_pos_ids, recall, neuron_acts_by_ex, neuron_sorted_acts
):
    """Get the neuron with the highest precision for the given `token_pos_ids`.

    For every neuron, an activation threshold is chosen so that
    `num_recalled = int(recall * num_matches)` (clamped to
    `[1, num_matches]`) of the pattern tokens activate at or above it. The
    neuron's precision is `num_recalled` divided by the number of tokens in the
    whole dataset whose activation is at or above that threshold.

    Args:
        token_pos_ids: non-empty list of token (example_id, token_pos_id) tuples
            from a dataset over which the neurons with the highest precision
            and recall are to be found.
        recall: recall level for the neurons (this determines their activation
            threshold). At least one pattern token is always recalled, even when
            `recall * len(token_pos_ids) < 1`.
        neuron_acts_by_ex: numpy array or torch tensor of activations of all the
            attention and mlp output neurons on a dataset with shape
            (num_examples, seq_len, num_layers, 2, dim_size). The fourth
            dimension is 2 because we consider neurons from both: attention and mlp.
        neuron_sorted_acts: torch tensor (it is passed to `torch.searchsorted`)
            of sorted activations of all the attention and mlp output neurons
            on a dataset with shape (num_layers, 2, dim_size, num_examples * seq_len).
            This should be obtained using the `neuron_acts_by_ex` array by
            rearranging the first two dimensions to the last dimensions and then
            sorting the last dimension.

    Returns:
        best_prec: highest precision amongst all the neurons for the given
            `token_pos_ids` (a 0-d torch tensor).
        best_neuron_acts: array of shape (num_hits, 2) listing the
            (example_id, token_pos) pairs where the best neuron's activation is
            at or above its threshold.
        best_neuron_idx: tuple of (layer, is_mlp, neuron_id) where `layer` is the
            layer number, `is_mlp` is 0 if the neuron is from attention and 1 if
            the neuron is from mlp, and `neuron_id` is the neuron's index in the layer.
    """
    if isinstance(neuron_acts_by_ex, torch.Tensor):
        neuron_acts_on_pattern = torch.stack(
            [
                neuron_acts_by_ex[example_id, token_pos_id]
                for example_id, token_pos_id in token_pos_ids
            ],
            dim=-1,
        )  # (layers, 2, dim_size, matches)
        neuron_acts_on_pattern = torch.sort(neuron_acts_on_pattern, dim=-1).values
    else:
        neuron_acts_on_pattern = np.stack(
            [
                neuron_acts_by_ex[example_id, token_pos_id]
                for example_id, token_pos_id in token_pos_ids
            ],
            axis=-1,
        )  # (layers, 2, dim_size, matches)
        neuron_acts_on_pattern.sort(axis=-1)
        neuron_acts_on_pattern = torch.from_numpy(neuron_acts_on_pattern)
    num_matches = neuron_acts_on_pattern.shape[-1]
    # Number of pattern tokens kept at or above the threshold, clamped to
    # [1, num_matches]. Without the lower clamp, `recall * num_matches < 1`
    # produced index -0 (== 0, the *smallest* activation) together with a zero
    # numerator, so every neuron got precision 0 and the "best" one was arbitrary.
    num_recalled = min(num_matches, max(1, int(recall * num_matches)))
    act_thresh = neuron_acts_on_pattern[:, :, :, -num_recalled]
    assert neuron_sorted_acts.shape[:-1] == act_thresh.shape
    # Left insertion point == number of dataset activations strictly below the
    # threshold, so the remainder counts activations >= threshold.
    prec_den = torch.searchsorted(neuron_sorted_acts, act_thresh.unsqueeze(-1))
    prec_den = prec_den.squeeze(-1)
    prec_den = neuron_sorted_acts.shape[-1] - prec_den
    prec = num_recalled / prec_den
    assert (
        prec.shape == neuron_acts_on_pattern.shape[:-1]
    ), f"{prec.shape} != {neuron_acts_on_pattern.shape[:-1]}"

    best_neuron_idx = np.unravel_index(prec.argmax(), prec.shape)
    best_prec = prec[best_neuron_idx]
    best_neuron_act_thresh = act_thresh[best_neuron_idx].item()
    best_neuron_acts = neuron_acts_by_ex[
        :, :, best_neuron_idx[0], best_neuron_idx[1], best_neuron_idx[2]
    ]
    best_neuron_acts = best_neuron_acts >= best_neuron_act_thresh
    best_neuron_acts = np.stack(np.where(best_neuron_acts), axis=-1)

    return best_prec, best_neuron_acts, best_neuron_idx


def convert_to_adv_name(name, cb_at, gcb=""):
    """Convert a base codebook name to its advanced name.

    Examples:
        ``convert_to_adv_name("layer0_head0", "attn_preproj", gcb="_gcb")``
        returns ``"layer0_attn_preproj_gcb0"``;
        ``convert_to_adv_name("layer0", "attn_preproj")`` returns
        ``"layer0_attn_preproj"``.

    Args:
        name: base name, ``layer<L>_head<H>`` for grouped codebooks or
            ``layer<L>`` otherwise.
        cb_at: codebook location suffix, e.g. ``"attn_preproj"``.
        gcb: any non-empty value (e.g. ``"_gcb"``) for grouped codebooks,
            ``""`` for non-grouped codebooks.
    """
    if gcb:
        layer, head = name.split("_")
        # head[4:] strips the "head" prefix, leaving the head index.
        return layer + f"_{cb_at}_gcb" + head[4:]
    # This branch previously used `layer` without ever assigning it, which
    # raised UnboundLocalError for every non-grouped name.
    layer = name.split("_")[0]
    return layer + "_" + cb_at


def convert_to_base_name(name, gcb=""):
    """Reduce an advanced codebook name to its base name.

    ``"layer0_attn_preproj_gcb0"`` becomes ``"layer0_head0"``. A name without
    ``"gcb"`` anywhere in it collapses to its first ``_``-separated component,
    e.g. ``"layer3_mlp"`` becomes ``"layer3"``.

    `gcb` is accepted for API compatibility but not used; grouping is detected
    from the name itself.
    """
    layer = name.partition("_")[0]
    if "gcb" not in name:
        return layer
    # The last component looks like "gcb<H>"; drop the 3-character prefix.
    head_id = name.rsplit("_", 1)[-1][3:]
    return f"{layer}_head{head_id}"


def get_layer_head_from_base_name(name):
    """Parse ``"layer<L>_head<H>"`` (or ``"layer<L>"``) into ``(L, H)``.

    The first 5 characters (``"layer"``) are dropped from the first
    ``_``-separated component. The first 4 characters (``"head"``) are dropped
    from the last component. `head` is ``None`` when the name has a single
    component.
    """
    first, *others = name.split("_")
    layer = int(first[5:])
    head = int(others[-1][4:]) if others else None
    return layer, head


def get_layer_head_from_adv_name(name):
    """Return ``(layer, head)`` parsed from an advanced codebook name.

    E.g. ``"layer0_attn_preproj_gcb0"`` -> ``(0, 0)``. The name is first reduced
    to its base form with `convert_to_base_name` and then parsed with
    `get_layer_head_from_base_name`. `head` is ``None`` for non-grouped names.
    """
    return get_layer_head_from_base_name(convert_to_base_name(name))


def get_codes_from_pattern(
    re_pattern,
    tokens_text,
    token_byte_pos,
    cb_acts,
    act_count_ft_tkns,
    gcb="",
    topk=5,
    prec_threshold=0.5,
    at_odd_even=-1,
):
    """Fetch codes that activate on a given regex pattern.

    For each codebook, retrieves at most `topk` codes (taken from the
    highest-precision ones) whose precision is strictly above `prec_threshold`.

    Args:
        re_pattern: regex pattern to search for.
        tokens_text: list of example texts of a dataset.
        token_byte_pos: numpy array of shape (num_examples, seq_len) where
            `token_byte_pos[example_id][token_pos]` is the byte position of
            the token at `token_pos` in the example with `example_id`.
        cb_acts: dict mapping codebook name to that codebook's activations,
            indexed as `cb[example_id][token_pos]`.
        act_count_ft_tkns: dict keyed by base codebook name; each value is
            indexed by code and gives that code's number of token activations
            on the dataset.
        gcb: "_gcb" for grouped codebooks and "" for non-grouped codebooks
            (passed through to `convert_to_base_name`).
        topk: maximum number of codes to return per codebook.
        prec_threshold: codes must have precision strictly greater than this.
        at_odd_even: to limit matches to even or odd character offsets only.
            -1 (default): to not limit matches.
            0: to keep matches starting at even offsets only.
            1: to keep matches starting at odd offsets only.
            This is useful for the TokFSM dataset when searching for states
            since the first token of states are always at even positions.

    Returns:
        codebook_wise_codes: dict of base codebook name to list of
            (code, prec, recall, code_acts) tuples.
        re_token_matches: number of distinct tokens that match the regex pattern.
    """
    byte_ids = search_re(re_pattern, tokens_text, at_odd_even=at_odd_even)
    token_pos_ids = [
        byte_id_to_token_pos_id(ex_byte_id, token_byte_pos) for ex_byte_id in byte_ids
    ]
    # Distinct character offsets can map to the same token; count each token once.
    token_pos_ids = np.unique(token_pos_ids, axis=0)
    re_token_matches = len(token_pos_ids)
    codebook_wise_codes = {}
    for cb_name, cb in tqdm(cb_acts.items()):
        base_cb_name = convert_to_base_name(cb_name, gcb=gcb)
        codes, prec, recall, code_acts = get_code_precision_and_recall(
            token_pos_ids,
            cb,
            cb_act_counts=act_count_ft_tkns[base_cb_name],
        )
        # Codes arrive sorted by descending precision: keep the first `topk`
        # of them that also clear the precision threshold.
        idx = np.arange(min(topk, len(codes)))
        idx = idx[prec[:topk] > prec_threshold]
        codes, prec, recall = codes[idx], prec[idx], recall[idx]
        code_acts = code_acts[idx]
        codes_pr = list(zip(codes, prec, recall, code_acts))
        codebook_wise_codes[base_cb_name] = codes_pr
    return codebook_wise_codes, re_token_matches


def get_neurons_from_pattern(
    re_pattern,
    tokens_text,
    token_byte_pos,
    neuron_acts_by_ex,
    neuron_sorted_acts,
    recall_threshold,
    at_odd_even=-1,
):
    """Fetch the highest precision neuron that activates on a given regex pattern.

    The activation threshold for the neurons is determined by the `recall_threshold`.

    Args:
        re_pattern: regex pattern to search for.
        tokens_text: list of example texts of a dataset.
        token_byte_pos: numpy array of shape (num_examples, seq_len) where
            `token_byte_pos[example_id][token_pos]` is the byte position of
            the token at `token_pos` in the example with `example_id`.
        neuron_acts_by_ex: numpy array (or torch tensor) of activations of all
            the attention and mlp output neurons on a dataset with shape
            (num_examples, seq_len, num_layers, 2, dim_size). The fourth
            dimension is 2 because we consider neurons from both: attention and mlp.
        neuron_sorted_acts: torch tensor of sorted activations of all the
            attention and mlp output neurons on a dataset with shape
            (num_layers, 2, dim_size, num_examples * seq_len).
            This should be obtained using the `neuron_acts_by_ex` array by
            rearranging the first two dimensions to the last dimensions and
            then sorting the last dimension.
        recall_threshold: recall level for the neurons (this determines their
            activation threshold).
        at_odd_even: to limit matches to even or odd character offsets only.
            -1 (default): to not limit matches.
            0: to keep matches starting at even offsets only.
            1: to keep matches starting at odd offsets only.
            This is useful for the TokFSM dataset when searching for states
            since the first token of states are always at even positions.

    Returns:
        best_prec: highest precision amongst all the neurons for the matched tokens.
        best_neuron_acts: array of shape (num_hits, 2) listing the
            (example_id, token_pos) pairs where the best neuron's activation is
            at or above the threshold determined by `recall_threshold`.
        best_neuron_idx: tuple of (layer, is_mlp, neuron_id) where `layer` is the
            layer number, `is_mlp` is 0 if the neuron is from attention and 1 if
            the neuron is from mlp, and `neuron_id` is the neuron's index in the layer.
        re_token_matches: number of distinct tokens that match the regex pattern.
    """
    byte_ids = search_re(re_pattern, tokens_text, at_odd_even=at_odd_even)
    token_pos_ids = [
        byte_id_to_token_pos_id(ex_byte_id, token_byte_pos) for ex_byte_id in byte_ids
    ]
    # Distinct character offsets can map to the same token; count each token once.
    token_pos_ids = np.unique(token_pos_ids, axis=0)
    re_token_matches = len(token_pos_ids)
    best_prec, best_neuron_acts, best_neuron_idx = get_neuron_precision_and_recall(
        token_pos_ids,
        recall_threshold,
        neuron_acts_by_ex,
        neuron_sorted_acts,
    )
    return best_prec, best_neuron_acts, best_neuron_idx, re_token_matches


def compare_codes_with_neurons(
    best_codes_info,
    tokens_text,
    token_byte_pos,
    neuron_acts_by_ex,
    neuron_sorted_acts,
    at_odd_even=-1,
):
    """Compare codes with the highest precision neurons on the regex pattern of the code.

    Args:
        best_codes_info: iterable of CodeInfo objects; each must expose
            `.regex`, `.recall` and `.prec`.
        tokens_text: list of example texts of a dataset.
        token_byte_pos: numpy array of shape (num_examples, seq_len) where
            `token_byte_pos[example_id][token_pos]` is the byte position of
            the token at `token_pos` in the example with `example_id`.
        neuron_acts_by_ex: numpy array of activations of all the attention and
            mlp output neurons on a dataset with shape
            (num_examples, seq_len, num_layers, 2, dim_size). The fourth
            dimension is 2 because we consider neurons from both: attention and mlp.
        neuron_sorted_acts: torch tensor of sorted activations of all the
            attention and mlp output neurons on a dataset with shape
            (num_layers, 2, dim_size, num_examples * seq_len).
            This should be obtained using the `neuron_acts_by_ex` array by
            rearranging the first two dimensions to the last dimensions and
            then sorting the last dimension.
        at_odd_even: to limit matches to even or odd character offsets only.
            -1 (default): to not limit matches.
            0: to keep matches starting at even offsets only.
            1: to keep matches starting at odd offsets only.
            This is useful for the TokFSM dataset when searching for states
            since the first token of states are always at even positions.

    Returns:
        codes_better_than_neurons: fraction of codes that have higher precision
            than the highest precision neuron on the regex pattern of the code
            (``nan`` when `best_codes_info` is empty).
        code_best_precs: array of the precision of each code in `best_codes_info`.
        neuron_best_prec: array of the highest neuron precision on each code's
            regex pattern.
    """
    assert isinstance(neuron_acts_by_ex, np.ndarray)
    # Materialize once: the codes are iterated twice below, which would
    # silently exhaust a generator argument.
    best_codes_info = list(best_codes_info)
    if not best_codes_info:
        # Previously an empty input crashed while unpacking zip(*[]).
        return np.nan, np.array([]), np.array([])
    neuron_best_prec = []
    for code_info in tqdm(best_codes_info):
        best_prec, _, _, _ = get_neurons_from_pattern(
            code_info.regex,
            tokens_text,
            token_byte_pos,
            neuron_acts_by_ex,
            neuron_sorted_acts,
            code_info.recall,
            at_odd_even=at_odd_even,
        )
        # `best_prec` may be a 0-d torch tensor; store a plain float so the
        # resulting array has a numeric dtype.
        neuron_best_prec.append(float(best_prec))
    neuron_best_prec = np.array(neuron_best_prec)
    code_best_precs = np.array([code_info.prec for code_info in best_codes_info])
    codes_better_than_neurons = code_best_precs > neuron_best_prec
    return codes_better_than_neurons.mean(), code_best_precs, neuron_best_prec