Commit: Update code

Files changed:
- .gitignore +3 -0
- Code_Browser.py +180 -140
- README.md +2 -2
- code_search_utils.py +201 -97
- pages/Concept_Code.py +5 -17
- utils.py +187 -232
- webapp_utils.py +21 -9
.gitignore (ADDED)

@@ -0,0 +1,3 @@
+__pycache__/
+hgf_webapp/
+.vscode/
Code_Browser.py (CHANGED)

@@ -1,15 +1,38 @@
 """Web App for the Codebook Features project."""
 
+import argparse
 import glob
 import os
 
 import streamlit as st
 
 import code_search_utils
+import utils
 import webapp_utils
 
-
+# --- Parse command line arguments ---
+
+parser = argparse.ArgumentParser()
+parser.add_argument(
+    "--deploy",
+    default=True,
+    help="Deploy mode.",
+)
+parser.add_argument(
+    "--cache_dir",
+    type=str,
+    default="cache/",
+    help="Path to directory containing cache for codebook models.",
+)
+try:
+    args = parser.parse_args()
+except SystemExit as e:
+    # This exception will be raised if --help or invalid command line arguments
+    # are used. Currently streamlit prevents the program from exiting normally
+    # so we have to do a hard exit.
+    os._exit(e.code if isinstance(e.code, int) else 1)
+
+deploy = args.deploy
 
 webapp_utils.load_widget_state()
 
@@ -20,14 +43,17 @@ st.set_page_config(
 
 st.title("Codebook Features")
 
+# --- Load model info and cache ---
+
 pretty_model_names = {
     "TinyStories-1Layer-21M#100ksteps_vcb_mlp": "TinyStories-1L-21M-MLP",
-    "TinyStories-1Layer-21M_ccb_attn_preproj": "TinyStories…
-    "TinyStories-33M_ccb_attn_preproj": "TinyStories…
+    "TinyStories-1Layer-21M_ccb_attn_preproj": "TinyStories 1 Layer Attention Codebook",
+    "TinyStories-33M_ccb_attn_preproj": "TinyStories 4 Layer Attention Codebook",
+    "TinyStories-1Layer-21M_vcb_mlp": "TinyStories 1 Layer MLP Codebook",
 }
 orig_model_name = {v: k for k, v in pretty_model_names.items()}
 
-base_cache_dir = …
+base_cache_dir = args.cache_dir
 dirs = glob.glob(base_cache_dir + "models/*/")
 model_name_options = [d.split("/")[-2].split("_")[:-2] for d in dirs]
 model_name_options = ["_".join(m) for m in model_name_options]

@@ -41,25 +67,23 @@ p_model_name = st.selectbox(
     key=webapp_utils.persist("model_name"),
 )
 model_name = orig_model_name.get(p_model_name, p_model_name)
-
-ccb = model_name.split("_")[1]
-ccb = "_ccb" if ccb == "ccb" else ""
-cb_at = "_".join(model_name.split("_")[2:])
-seq_len = 512 if "tinystories" in model_name.lower() else 1024
-st.session_state["seq_len"] = seq_len
+is_fsm = "FSM" in p_model_name
 
 codes_cache_path = base_cache_dir + f"models/{model_name}_*"
 dirs = glob.glob(codes_cache_path)
 dirs.sort(key=os.path.getmtime)
 
 # session states
-is_attn = "attn" in cb_at
 codes_cache_path = dirs[-1] + "/"
 
-model_info = …
+model_info = utils.ModelInfoForWebapp.load(codes_cache_path)
 num_codes = model_info.num_codes
 num_layers = model_info.n_layers
 num_heads = model_info.n_heads
+cb_at = model_info.cb_at
+gcb = model_info.gcb
+gcb = "_gcb" if gcb else ""
+is_attn = "attn" in cb_at
 dataset_cache_path = base_cache_dir + f"datasets/{model_info.dataset_name}/"
 
 (

@@ -70,9 +94,12 @@ dataset_cache_path = base_cache_dir + f"datasets/{model_info.dataset_name}/"
     act_count_ft_tkns,
     metrics,
 ) = webapp_utils.load_code_search_cache(codes_cache_path, dataset_cache_path)
+seq_len = len(tokens_str[0])
 metric_keys = ["eval_loss", "eval_accuracy", "eval_dead_code_fraction"]
 metrics = {k: v for k, v in metrics.items() if k.split("/")[0] in metric_keys}
 
+# --- Set the session states ---
+
 st.session_state["model_name_id"] = model_name
 st.session_state["cb_acts"] = cb_acts
 st.session_state["tokens_text"] = tokens_text

@@ -80,11 +107,13 @@ st.session_state["tokens_str"] = tokens_str
 st.session_state["act_count_ft_tkns"] = act_count_ft_tkns
 
 st.session_state["num_codes"] = num_codes
-st.session_state["…
+st.session_state["gcb"] = gcb
 st.session_state["cb_at"] = cb_at
 st.session_state["is_attn"] = is_attn
+st.session_state["seq_len"] = seq_len
 
-…
+
+if not deploy:
     st.markdown("## Metrics")
     # hide metrics by default
     if st.checkbox("Show Model Metrics"):

@@ -93,7 +122,7 @@ if not DEPLOY_MODE:
 st.markdown("## Demo Codes")
 demo_codes_desc = (
     "This section contains codes that we've found to be interpretable along "
-    "with a description of the feature we think they are capturing."
+    "with a description of the feature we think they are capturing. "
     "Click on the π search button for a code to see the tokens that code activates on."
 )
 st.write(demo_codes_desc)

@@ -144,7 +173,7 @@ if st.checkbox("Show Demo Codes"):
             continue
         if skip:
             continue
-        code_info = …
+        code_info = utils.CodeInfo.from_str(code_txt, regex=code_regex)
         comp_info = f"layer{code_info.layer}_{f'head{code_info.head}' if code_info.head is not None else ''}"
         button_key = (
             f"demo_search_code{code_info.code}_layer{code_info.layer}_desc-{code_info.description}"

@@ -167,150 +196,160 @@ if st.checkbox("Show Demo Codes"):
         cols[-1].write(code_desc)
         skip = True
 
+# --- Code Search ---
 
 st.markdown("## Code Search")
-…
-    "…
-…
-)
-# topk = st.slider("Top K", 1, 20, 10)
-prec_col, sort_col = st.columns(2)
-prec_threshold = prec_col.slider(
-    "Precision Threshold",
-    0.0,
-    1.0,
-    0.9,
-    help="Shows codes with precision on the regex pattern above the threshold.",
-)
-sort_by_options = ["Precision", "Recall", "Num Acts"]
-sort_by_name = sort_col.radio(
-    "Sort By",
-    sort_by_options,
-    index=0,
-    horizontal=True,
-    help="Sorts the codes by the selected metric.",
-)
-sort_by = sort_by_options.index(sort_by_name)
-
-
-@st.cache_data(ttl=3600)
-def get_codebook_wise_codes_for_regex(regex_pattern, prec_threshold, ccb, model_name):
-    """Get codebook wise codes for a given regex pattern."""
-    assert model_name is not None  # required for loading from correct cache data
-    return code_search_utils.get_codes_from_pattern(
-        regex_pattern,
-        tokens_text,
-        token_byte_pos,
-        cb_acts,
-        act_count_ft_tkns,
-        ccb=ccb,
-        topk=8,
-        prec_threshold=prec_threshold,
-    )
-
-…
-    model_name,
-)
-st.…
-…
-)
-…
-# st.markdown(button_height_style, unsafe_allow_html=True)
-
-cols[0].markdown("Search", help="Button to see token activations for the code.")
-cols[1].write("Layer")
-if is_attn:
-    cols[2].write("Head")
-cols[-4 - non_deploy_offset].write("Code")
-cols[-3 - non_deploy_offset].write("Precision")
-cols[-2 - non_deploy_offset].write("Recall")
-cols[-1 - non_deploy_offset].markdown(
-    "Num Acts",
-    help="Number of tokens that the code activates on in the acts dataset.",
-)
-…
-    f"head{head}" if head is not None else ""
-)
-cols = st.columns(num_search_cols)
-extra_args = {
-    "prec": prec,
-    "recall": rec,
-    "num_acts": code_acts,
-    "regex": regex_pattern,
-}
-button_clicked = cols[0].button("π", key=button_key)
-if button_clicked:
-    webapp_utils.set_ct_acts(code, layer, head, extra_args, is_attn)
-cols[1].write(layer)
-if is_attn:
-    cols[2].write(head)
-cols[-4 - non_deploy_offset].write(code)
-cols[-3 - non_deploy_offset].write(f"{prec*100:.2f}%")
-cols[-2 - non_deploy_offset].write(f"{rec*100:.2f}%")
-cols[-1 - non_deploy_offset].write(str(code_acts))
-if not DEPLOY_MODE:
-    webapp_utils.add_save_code_button(
-        demo_file_path,
-        num_acts=code_acts,
-        save_regex=True,
-        prec=prec,
-        recall=rec,
-        button_st_container=cols[-1],
-        button_key_suffix=f"_code{code}_layer{layer}_head{head}",
-    )
-
-if …
-    st.markdown(
-        f""…
-        <div style="font-size: 1.0rem; color: red;">
-        No codes found for pattern {regex_pattern} at precision threshold: {prec_threshold}
-        </div>
-        """,
-        unsafe_allow_html=True,
-    )
+code_search_desc = (
+    "If you want to find whether the codebooks model has captured a relevant features from the data,"
+    " you can specify a regex pattern for your feature and find whether any code activating on the regex pattern"
+    " exists. The first group in the regex pattern is the token that the code activates on. If the group contains"
+    " multiple tokens, we search for codes that will activate on the first token in the group followed by the"
+    " subsequent tokens in the group. For example, the search term 'New (York)' will try to find codes that"
+    " activate on the bigram feature 'New York' at the York token."
+)
+
+if st.checkbox("Search with Regex"):
+    st.write(code_search_desc)
+    regex_pattern = st.text_input(
+        "Enter a regex pattern",
+        help="Wrap code token in the first group. E.g. New (York)",
+        key="regex_pattern",
+    )
+    # topk = st.slider("Top K", 1, 20, 10)
+    prec_col, sort_col = st.columns(2)
+    prec_threshold = prec_col.slider(
+        "Precision Threshold",
+        0.0,
+        1.0,
+        0.9,
+        help="Shows codes with precision on the regex pattern above the threshold.",
+    )
+    sort_by_options = ["Precision", "Recall", "Num Acts"]
+    sort_by_name = sort_col.radio(
+        "Sort By",
+        sort_by_options,
+        index=0,
+        horizontal=True,
+        help="Sorts the codes by the selected metric.",
+    )
+    sort_by = sort_by_options.index(sort_by_name)
+
+    @st.cache_data(ttl=3600)
+    def get_codebook_wise_codes_for_regex(
+        regex_pattern, prec_threshold, gcb, model_name
+    ):
+        """Get codebook wise codes for a given regex pattern."""
+        assert model_name is not None  # required for loading from correct cache data
+        return code_search_utils.get_codes_from_pattern(
+            regex_pattern,
+            tokens_text,
+            token_byte_pos,
+            cb_acts,
+            act_count_ft_tkns,
+            gcb=gcb,
+            topk=8,
+            prec_threshold=prec_threshold,
+        )
+
+    if regex_pattern:
+        codebook_wise_codes, re_token_matches = get_codebook_wise_codes_for_regex(
+            regex_pattern,
+            prec_threshold,
+            gcb,
+            model_name,
+        )
+        st.markdown(
+            f"Found <span style='color:green;'>{re_token_matches}</span> matches",
+            unsafe_allow_html=True,
+        )
+        num_search_cols = 7 if is_attn else 6
+        non_deploy_offset = 0
+        if not deploy:
+            non_deploy_offset = 1
+            num_search_cols += non_deploy_offset
+
+        cols = st.columns(num_search_cols)
+
+        cols[0].markdown("Search", help="Button to see token activations for the code.")
+        cols[1].write("Layer")
+        if is_attn:
+            cols[2].write("Head")
+        cols[-4 - non_deploy_offset].write("Code")
+        cols[-3 - non_deploy_offset].write("Precision")
+        cols[-2 - non_deploy_offset].write("Recall")
+        cols[-1 - non_deploy_offset].markdown(
+            "Num Acts",
+            help="Number of tokens that the code activates on in the acts dataset.",
+        )
+        if not deploy:
+            cols[-1].markdown(
+                "Save to Demos",
+                help="Button to save the code to demos along with the regex pattern.",
+            )
+        all_codes = codebook_wise_codes.items()
+        all_codes = [
+            (cb_name, code_pr_info)
+            for cb_name, code_pr_infos in all_codes
+            for code_pr_info in code_pr_infos
+        ]
+        all_codes = sorted(all_codes, key=lambda x: x[1][1 + sort_by], reverse=True)
+        for cb_name, (code, prec, rec, code_acts) in all_codes:
+            layer_head = cb_name.split("_")
+            layer = layer_head[0][5:]
+            head = layer_head[1][4:] if len(layer_head) > 1 else None
+            button_key = f"search_code{code}_layer{layer}" + (
+                f"head{head}" if head is not None else ""
+            )
+            cols = st.columns(num_search_cols)
+            extra_args = {
+                "prec": prec,
+                "recall": rec,
+                "num_acts": code_acts,
+                "regex": regex_pattern,
+            }
+            button_clicked = cols[0].button("π", key=button_key)
+            if button_clicked:
+                webapp_utils.set_ct_acts(code, layer, head, extra_args, is_attn)
+            cols[1].write(layer)
+            if is_attn:
+                cols[2].write(head)
+            cols[-4 - non_deploy_offset].write(code)
+            cols[-3 - non_deploy_offset].write(f"{prec*100:.2f}%")
+            cols[-2 - non_deploy_offset].write(f"{rec*100:.2f}%")
+            cols[-1 - non_deploy_offset].write(str(code_acts))
+            if not deploy:
+                webapp_utils.add_save_code_button(
+                    demo_file_path,
+                    num_acts=code_acts,
+                    save_regex=True,
+                    prec=prec,
+                    recall=rec,
+                    button_st_container=cols[-1],
+                    button_key_suffix=f"_code{code}_layer{layer}_head{head}",
+                )
+
+        if len(all_codes) == 0:
+            st.markdown(
+                f"""
+                <div style="font-size: 1.0rem; color: red;">
+                No codes found for pattern {regex_pattern} at precision threshold: {prec_threshold}
+                </div>
+                """,
+                unsafe_allow_html=True,
+            )
 
+# --- Display Code Token Activations ---
 
 st.markdown("## Code Token Activations")
 
-filter_codes = st.checkbox("Show filters", key="filter_codes")
+filter_codes = st.checkbox("Show filters", key="filter_codes", value=True)
 act_range, layer_code_acts = None, None
 if filter_codes:
     act_range = st.slider(
-        "…
+        "Minimum number of activations",
         0,
         10_000,
-        …
+        100,
         key="ct_act_range",
         help="Filter codes by the number of tokens they activate on.",
     )

@@ -361,6 +400,7 @@ acts, acts_count = webapp_utils.get_code_acts(
     head,
     ctx_size,
     num_examples,
+    is_fsm=is_fsm,
 )
 
 st.write(

@@ -368,7 +408,7 @@ st.write(
     f"Activates on {acts_count[0]} tokens on the acts dataset",
 )
 
-if not …
+if not deploy:
     webapp_utils.add_save_code_button(
         demo_file_path,
         acts_count[0],
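A minimal sketch (invented values) of the flatten-and-sort step in the Code Search table above: each codebook name maps to (code, prec, recall, num_acts) tuples, so the sort key x[1][1 + sort_by] selects precision, recall, or activation count for sort_by = 0, 1, or 2.

    # Toy data; variable names mirror the diff above.
    codebook_wise_codes = {
        "layer0_head0": [(12, 0.95, 0.40, 310), (7, 0.91, 0.55, 120)],
        "layer0_head1": [(3, 0.99, 0.10, 45)],
    }
    sort_by = 0  # 0: Precision, 1: Recall, 2: Num Acts
    all_codes = [
        (cb_name, code_pr_info)
        for cb_name, code_pr_infos in codebook_wise_codes.items()
        for code_pr_info in code_pr_infos
    ]
    all_codes = sorted(all_codes, key=lambda x: x[1][1 + sort_by], reverse=True)
    print(all_codes[0])  # ('layer0_head1', (3, 0.99, 0.1, 45))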
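The new call utils.ModelInfoForWebapp.load(codes_cache_path) replaces the parse_model_info helper deleted from code_search_utils.py below; its implementation is not shown in this capture. A hedged sketch of the parsing it presumably performs, based on that removed helper (the info.txt format is an assumption carried over from the old code):

    # Assumption: info.txt holds one "key: value" pair per line, as in the
    # parse_model_info helper removed from code_search_utils.py below.
    def load_model_info(path):
        with open(path + "info.txt", "r") as f:
            fields = dict(line.strip().split(": ") for line in f)
        return fields  # e.g. {"num_codes": "1024", "n_layers": "4", ...}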
README.md (CHANGED)

@@ -1,8 +1,8 @@
 ---
 title: Codebook Features
-emoji: …
+emoji: π
 colorFrom: gray
-colorTo: …
+colorTo: blue
 sdk: streamlit
 sdk_version: 1.25.0
 app_file: Code_Browser.py
code_search_utils.py (CHANGED)

@@ -2,15 +2,11 @@
 
 import pickle
 import re
-from dataclasses import dataclass
-from typing import Optional
 
 import numpy as np
 import torch
 from tqdm import tqdm
 
-import utils
-
 
 def load_dataset_cache(cache_base_path):
     """Load cache files required for dataset from `cache_base_path`."""

@@ -31,28 +27,73 @@ def load_code_search_cache(cache_base_path):
     return cb_acts, act_count_ft_tkns, metrics
 
 
-def search_re(re_pattern, tokens_text):
-    """Get list of (example_id, token_pos) where re_pattern matches in tokens_text.
-…
+def search_re(re_pattern, tokens_text, at_odd_even=-1):
+    """Get list of (example_id, token_pos) where re_pattern matches in tokens_text.
+
+    Args:
+        re_pattern: regex pattern to search for.
+        tokens_text: list of example texts.
+        at_odd_even: to limit matches to odd or even positions only.
+            -1 (default): to not limit matches.
+            0: to limit matches to odd positions only.
+            1: to limit matches to even positions only.
+            This is useful for the TokFSM dataset when searching for states
+            since the first token of states are always at even positions.
+    """
+    # TODO: ensure that parentheses are not escaped
+    assert at_odd_even in [-1, 0, 1], f"Invalid at_odd_even: {at_odd_even}"
     if re_pattern.find("(") == -1:
         re_pattern = f"({re_pattern})"
-    …
+    res = [
         (i, finditer.span(1)[0])
         for i, text in enumerate(tokens_text)
         for finditer in re.finditer(re_pattern, text)
        if finditer.span(1)[0] != finditer.span(1)[1]
     ]
+    if at_odd_even != -1:
+        res = [r for r in res if r[1] % 2 == at_odd_even]
+    return res
 
 
 def byte_id_to_token_pos_id(example_byte_id, token_byte_pos):
-    """…
+    """Convert byte position (or character position in a text) to its token position.
+
+    Used to convert the searched regex span to its token position.
+
+    Args:
+        example_byte_id: tuple of (example_id, byte_id) where byte_id is a
+            character's position in the text.
+        token_byte_pos: numpy array of shape (num_examples, seq_len) where
+            `token_byte_pos[example_id][token_pos]` is the byte position of
+            the token at `token_pos` in the example with `example_id`.
+
+    Returns:
+        (example_id, token_pos_id) tuple.
+    """
     example_id, byte_id = example_byte_id
     index = np.searchsorted(token_byte_pos[example_id], byte_id, side="right")
     return (example_id, index)
 
 
-def get_code_pr(token_pos_ids, codebook_acts, cb_act_counts=None):
-    """…
+def get_code_precision_and_recall(token_pos_ids, codebook_acts, cb_act_counts=None):
+    """Search for the codes that activate on the given `token_pos_ids`.
+
+    Args:
+        token_pos_ids: list of (example_id, token_pos_id) tuples.
+        codebook_acts: numpy array of activations of a codebook on a dataset with
+            shape (num_examples, seq_len, k_codebook).
+        cb_act_counts: array of shape (num_codes,) where `cb_act_counts[cb_name][code]`
+            is the number of times the code `code` is activated in the dataset.
+
+    Returns:
+        codes: numpy array of code ids sorted by their precision on the given `token_pos_ids`.
+        prec: numpy array where `prec[i]` is the precision of the code
+            `codes[i]` for the given `token_pos_ids`.
+        recall: numpy array where `recall[i]` is the recall of the code
+            `codes[i]` for the given `token_pos_ids`.
+        code_acts: numpy array where `code_acts[i]` is the number of times
+            the code `codes[i]` is activated in the dataset.
+    """
     codes = np.array(
         [
             codebook_acts[example_id][token_pos_id]

@@ -76,46 +117,64 @@ def get_code_pr(token_pos_ids, codebook_acts, cb_act_counts=None):
     return codes, prec, recall, code_acts
 
 
-def get_neuron_pr(
+def get_neuron_precision_and_recall(
     token_pos_ids, recall, neuron_acts_by_ex, neuron_sorted_acts
 ):
-    """Get …
+    """Get the neurons with the highest precision and recall for the given `token_pos_ids`.
+
+    Args:
+        token_pos_ids: list of token (example_id, token_pos_id) tuples from a dataset over which
+            the neurons with the highest precision and recall are to be found.
+        recall: recall threshold for the neurons (this determines their activation threshold).
+        neuron_acts_by_ex: numpy array of activations of all the attention and mlp output neurons
+            on a dataset with shape (num_examples, seq_len, num_layers, 2, dim_size).
+            The third dimension is 2 because we consider neurons from both: attention and mlp.
+        neuron_sorted_acts: numpy array of sorted activations of all the attention and mlp output neurons
+            on a dataset with shape (num_layers, 2, dim_size, num_examples * seq_len).
+            This should be obtained using the `neuron_acts_by_ex` array by rearranging the first two
+            dimensions to the last dimensions and then sorting the last dimension.
+
+    Returns:
+        best_prec: highest precision amongst all the neurons for the given `token_pos_ids`.
+        best_neuron_acts: number of activations of the best neuron for the given `token_pos_ids`
+            based on the threshold determined by the `recall` argument.
+        best_neuron_idx: tuple of (layer, is_mlp, neuron_id) where `layer` is the layer number,
+            `is_mlp` is 0 if the neuron is from attention and 1 if the neuron is from mlp,
+            and `neuron_id` is the neuron's index in the layer.
+    """
     if isinstance(neuron_acts_by_ex, torch.Tensor):
-        …
+        neuron_acts_on_pattern = torch.stack(
             [
                 neuron_acts_by_ex[example_id, token_pos_id]
                 for example_id, token_pos_id in token_pos_ids
             ],
             dim=-1,
         )  # (layers, 2, dim_size, matches)
+        neuron_acts_on_pattern = torch.sort(neuron_acts_on_pattern, dim=-1).values
     else:
-        …
+        neuron_acts_on_pattern = np.stack(
             [
                 neuron_acts_by_ex[example_id, token_pos_id]
                 for example_id, token_pos_id in token_pos_ids
             ],
             axis=-1,
         )  # (layers, 2, dim_size, matches)
-    …
+        neuron_acts_on_pattern.sort(axis=-1)
+        neuron_acts_on_pattern = torch.from_numpy(neuron_acts_on_pattern)
+    act_thresh = neuron_acts_on_pattern[
+        :, :, :, -int(recall * neuron_acts_on_pattern.shape[-1])
+    ]
     # binary search act_thresh in neuron_sorted_acts
     assert neuron_sorted_acts.shape[:-1] == act_thresh.shape
     prec_den = torch.searchsorted(neuron_sorted_acts, act_thresh.unsqueeze(-1))
     prec_den = prec_den.squeeze(-1)
     prec_den = neuron_sorted_acts.shape[-1] - prec_den
-    prec = int(recall * …
+    prec = int(recall * neuron_acts_on_pattern.shape[-1]) / prec_den
     assert (
-        prec.shape == …
-    ), f"{prec.shape} != {…
+        prec.shape == neuron_acts_on_pattern.shape[:-1]
+    ), f"{prec.shape} != {neuron_acts_on_pattern.shape[:-1]}"
 
     best_neuron_idx = np.unravel_index(prec.argmax(), prec.shape)
     best_prec = prec[best_neuron_idx]
-    print("max prec:", best_prec)
     best_neuron_act_thresh = act_thresh[best_neuron_idx].item()
     best_neuron_acts = neuron_acts_by_ex[
         :, :, best_neuron_idx[0], best_neuron_idx[1], best_neuron_idx[2]

@@ -126,20 +185,20 @@ def get_neuron_pr(
     return best_prec, best_neuron_acts, best_neuron_idx
 
 
-def convert_to_adv_name(name, cb_at, …
-    """Convert layer0_head0 to …
-    if …
+def convert_to_adv_name(name, cb_at, gcb=""):
+    """Convert layer0_head0 to layer0_attn_preproj_gcb0."""
+    if gcb:
         layer, head = name.split("_")
-        return layer + f"_{cb_at}…
+        return layer + f"_{cb_at}_gcb" + head[4:]
     else:
         return layer + "_" + cb_at
 
 
-def convert_to_base_name(name, …
-    """Convert …
+def convert_to_base_name(name, gcb=""):
+    """Convert layer0_attn_preproj_gcb0 to layer0_head0."""
     split_name = name.split("_")
     layer, head = split_name[0], split_name[-1][3:]
-    if "…
+    if "gcb" in name:
         return layer + "_head" + head
     else:
         return layer

@@ -156,7 +215,7 @@ def get_layer_head_from_base_name(name):
 
 
 def get_layer_head_from_adv_name(name):
-    """Convert …
+    """Convert layer0_attn_preproj_gcb0 to 0, 0."""
     base_name = convert_to_base_name(name)
     layer, head = get_layer_head_from_base_name(base_name)
     return layer, head

@@ -168,12 +227,39 @@ def get_codes_from_pattern(
     token_byte_pos,
     cb_acts,
     act_count_ft_tkns,
-    …
+    gcb="",
     topk=5,
     prec_threshold=0.5,
+    at_odd_even=-1,
 ):
-    """Fetch codes …
-    …
+    """Fetch codes that activate on a given regex pattern.
+
+    Retrieves at most `top_k` codes that activate with precision above `prec_threshold`.
+
+    Args:
+        re_pattern: regex pattern to search for.
+        tokens_text: list of example texts of a dataset.
+        token_byte_pos: numpy array of shape (num_examples, seq_len) where
+            `token_byte_pos[example_id][token_pos]` is the byte position of
+            the token at `token_pos` in the example with `example_id`.
+        cb_acts: dict of codebook activations.
+        act_count_ft_tkns: dict over all codebooks of number of token activations on the dataset
+        gcb: "_gcb" for grouped codebooks and "" for non-grouped codebooks.
+        topk: maximum number of codes to return per codebook.
+        prec_threshold: minimum precision required for a code to be returned.
+        at_odd_even: to limit matches to odd or even positions only.
+            -1 (default): to not limit matches.
+            0: to limit matches to odd positions only.
+            1: to limit matches to even positions only.
+            This is useful for the TokFSM dataset when searching for states
+            since the first token of states are always at even positions.
+
+    Returns:
+        codebook_wise_codes: dict of codebook name to list of
+            (code, prec, recall, code_acts) tuples.
+        re_token_matches: number of tokens that match the regex pattern.
+    """
+    byte_ids = search_re(re_pattern, tokens_text, at_odd_even=at_odd_even)
     token_pos_ids = [
         byte_id_to_token_pos_id(ex_byte_id, token_byte_pos) for ex_byte_id in byte_ids
     ]

@@ -181,8 +267,8 @@ def get_codes_from_pattern(
     re_token_matches = len(token_pos_ids)
     codebook_wise_codes = {}
     for cb_name, cb in tqdm(cb_acts.items()):
-        base_cb_name = convert_to_base_name(cb_name, …
-        codes, prec, recall, code_acts = …
+        base_cb_name = convert_to_base_name(cb_name, gcb=gcb)
+        codes, prec, recall, code_acts = get_code_precision_and_recall(
             token_pos_ids,
             cb,
             cb_act_counts=act_count_ft_tkns[base_cb_name],

@@ -203,15 +289,49 @@ def get_neurons_from_pattern(
     neuron_acts_by_ex,
     neuron_sorted_acts,
     recall_threshold,
+    at_odd_even=-1,
 ):
-    """Fetch the …
-    …
+    """Fetch the highest precision neurons that activate on a given regex pattern.
+
+    The activation threshold for the neurons is determined by the `recall_threshold`.
+
+    Args:
+        re_pattern: regex pattern to search for.
+        tokens_text: list of example texts of a dataset.
+        token_byte_pos: numpy array of shape (num_examples, seq_len) where
+            `token_byte_pos[example_id][token_pos]` is the byte position of
+            the token at `token_pos` in the example with `example_id`.
+        neuron_acts_by_ex: numpy array of activations of all the attention and mlp output neurons
+            on a dataset with shape (num_examples, seq_len, num_layers, 2, dim_size).
+            The third dimension is 2 because we consider neurons from both: attention and mlp.
+        neuron_sorted_acts: numpy array of sorted activations of all the attention and mlp output neurons
+            on a dataset with shape (num_layers, 2, dim_size, num_examples * seq_len).
+            This should be obtained using the `neuron_acts_by_ex` array by rearranging the first two
+            dimensions to the last dimensions and then sorting the last dimension.
+        recall_threshold: recall threshold for the neurons (this determines their activation threshold).
+        at_odd_even: to limit matches to odd or even positions only.
+            -1 (default): to not limit matches.
+            0: to limit matches to odd positions only.
+            1: to limit matches to even positions only.
+            This is useful for the TokFSM dataset when searching for states
+            since the first token of states are always at even positions.
+
+    Returns:
+        best_prec: highest precision amongst all the neurons for the given `token_pos_ids`.
+        best_neuron_acts: number of activations of the best neuron for the given `token_pos_ids`
+            based on the threshold determined by the `recall` argument.
+        best_neuron_idx: tuple of (layer, is_mlp, neuron_id) where `layer` is the layer number,
+            `is_mlp` is 0 if the neuron is from attention and 1 if the neuron is from mlp,
+            and `neuron_id` is the neuron's index in the layer.
+        re_token_matches: number of tokens that match the regex pattern.
+    """
+    byte_ids = search_re(re_pattern, tokens_text, at_odd_even=at_odd_even)
     token_pos_ids = [
         byte_id_to_token_pos_id(ex_byte_id, token_byte_pos) for ex_byte_id in byte_ids
     ]
     token_pos_ids = np.unique(token_pos_ids, axis=0)
     re_token_matches = len(token_pos_ids)
-    best_prec, best_neuron_acts, best_neuron_idx = …
+    best_prec, best_neuron_acts, best_neuron_idx = get_neuron_precision_and_recall(
         token_pos_ids,
         recall_threshold,
         neuron_acts_by_ex,

@@ -226,74 +346,58 @@ def compare_codes_with_neurons(
     token_byte_pos,
     neuron_acts_by_ex,
     neuron_sorted_acts,
+    at_odd_even=-1,
 ):
-    """Compare codes with neurons.
-    …
+    """Compare codes with the highest precision neurons on the regex pattern of the code.
+
+    Args:
+        best_codes_info: list of CodeInfo objects.
+        tokens_text: list of example texts of a dataset.
+        token_byte_pos: numpy array of shape (num_examples, seq_len) where
+            `token_byte_pos[example_id][token_pos]` is the byte position of
+            the token at `token_pos` in the example with `example_id`.
+        neuron_acts_by_ex: numpy array of activations of all the attention and mlp output neurons
+            on a dataset with shape (num_examples, seq_len, num_layers, 2, dim_size).
+            The third dimension is 2 because we consider neurons from both: attention and mlp.
+        neuron_sorted_acts: numpy array of sorted activations of all the attention and mlp output neurons
+            on a dataset with shape (num_layers, 2, dim_size, num_examples * seq_len).
+            This should be obtained using the `neuron_acts_by_ex` array by rearranging the first two
+            dimensions to the last dimensions and then sorting the last dimension.
+        at_odd_even: to limit matches to odd or even positions only.
+            -1 (default): to not limit matches.
+            0: to limit matches to odd positions only.
+            1: to limit matches to even positions only.
+            This is useful for the TokFSM dataset when searching for states
+            since the first token of states are always at even positions.
+
+    Returns:
+        codes_better_than_neurons: fraction of codes that have higher precision than the highest
+            precision neuron on the regex pattern of the code.
+        code_best_precs: is an array of the precision of each code in `best_codes_info`.
+        all_best_prec: is an array of the highest precision neurons on the regex pattern.
+    """
     assert isinstance(neuron_acts_by_ex, np.ndarray)
     (
-        …
+        neuron_best_prec,
         all_best_neuron_acts,
         all_best_neuron_idxs,
         all_re_token_matches,
     ) = zip(
         *[
             get_neurons_from_pattern(
-                code_info.…
+                code_info.regex,
                 tokens_text,
                 token_byte_pos,
                 neuron_acts_by_ex,
                 neuron_sorted_acts,
                 code_info.recall,
+                at_odd_even=at_odd_even,
             )
-            for code_info in tqdm(…
+            for code_info in tqdm(best_codes_info)
         ],
         strict=True,
     )
-    …
-    codes_better_than_neurons …
-    return codes_better_than_neurons.mean()
+    neuron_best_prec = np.array(neuron_best_prec)
+    code_best_precs = np.array([code_info.prec for code_info in best_codes_info])
+    codes_better_than_neurons = code_best_precs > neuron_best_prec
+    return codes_better_than_neurons.mean(), code_best_precs, neuron_best_prec
 
 
-def get_code_info_pr_from_str(code_txt, regex):
-    """Extract code info fields from string."""
-    code_txt = code_txt.strip()
-    code_txt = code_txt.split(", ")
-    code_txt = dict(txt.split(": ") for txt in code_txt)
-    return utils.CodeInfo(**code_txt)
-
-
-@dataclass
-class ModelInfoForWebapp:
-    """Model info for webapp."""
-
-    model_name: str
-    pretrained_path: str
-    dataset_name: str
-    num_codes: int
-    cb_at: str
-    ccb: str
-    n_layers: int
-    n_heads: Optional[int] = None
-    seed: int = 42
-    max_samples: int = 2000
-
-    def __post_init__(self):
-        """Convert to correct types."""
-        self.num_codes = int(self.num_codes)
-        self.n_layers = int(self.n_layers)
-        if self.n_heads == "None":
-            self.n_heads = None
-        elif self.n_heads is not None:
-            self.n_heads = int(self.n_heads)
-        self.seed = int(self.seed)
-        self.max_samples = int(self.max_samples)
-
-
-def parse_model_info(path):
-    """Parse model info from path."""
-    with open(path + "info.txt", "r") as f:
-        lines = f.readlines()
-        lines = dict(line.strip().split(": ") for line in lines)
-        return ModelInfoForWebapp(**lines)
-    return ModelInfoForWebapp(**lines)
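A runnable toy of the search_re semantics documented above (the example text is invented): the start of capture group 1 is the reported match position, and at_odd_even keeps only matches at odd or even character offsets.

    import re

    def search_re(re_pattern, tokens_text, at_odd_even=-1):
        # Same logic as the function in the diff above, minus the assert.
        if re_pattern.find("(") == -1:
            re_pattern = f"({re_pattern})"
        res = [
            (i, m.span(1)[0])
            for i, text in enumerate(tokens_text)
            for m in re.finditer(re_pattern, text)
            if m.span(1)[0] != m.span(1)[1]
        ]
        if at_odd_even != -1:
            res = [r for r in res if r[1] % 2 == at_odd_even]
        return res

    print(search_re("New (York)", ["New York is in New York state"]))
    # [(0, 4), (0, 19)] -- (example_id, character offset of "York")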
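A worked toy instance of byte_id_to_token_pos_id above; the token boundaries are invented, and the assumption that token_byte_pos stores each token's end offset follows from the side="right" search.

    import numpy as np

    # One example tokenized as ["New", " York", " is", " in"] (assumed end offsets).
    token_byte_pos = np.array([[3, 8, 11, 14]])
    example_id, byte_id = 0, 4  # character offset where "York" matched
    print(np.searchsorted(token_byte_pos[example_id], byte_id, side="right"))
    # 1 -> the " York" token contains character 4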
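A hedged numeric sketch of the precision and recall defined in the get_code_precision_and_recall docstring above (the counting code itself is elided in this capture): recall is the fraction of regex-matched positions a code fires on, while precision divides that count by the code's dataset-wide activation count.

    import numpy as np

    codes_at_matches = np.array([5, 5, 7, 5])  # codes firing at 4 matched positions (toy)
    cb_act_counts = {5: 40, 7: 200}            # dataset-wide activation counts (toy)
    vals, counts = np.unique(codes_at_matches, return_counts=True)
    recall = counts / len(codes_at_matches)
    prec = counts / np.array([cb_act_counts[v] for v in vals])
    print(dict(zip(vals.tolist(), zip(prec, recall))))
    # {5: (0.075, 0.75), 7: (0.005, 0.25)}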
pages/Concept_Code.py (CHANGED)

@@ -21,7 +21,7 @@ tokens_text = st.session_state["tokens_text"]
 tokens_str = st.session_state["tokens_str"]
 cb_acts = st.session_state["cb_acts"]
 act_count_ft_tkns = st.session_state["act_count_ft_tkns"]
-…
+gcb = st.session_state["gcb"]
 
 
 def get_example_concept_codes(example_id):

@@ -29,8 +29,8 @@ def get_example_concept_codes(example_id):
     token_pos_ids = [(example_id, i) for i in range(seq_len)]
     all_codes = []
     for cb_name, cb in cb_acts.items():
-        base_cb_name = code_search_utils.convert_to_base_name(cb_name, …
-        codes, prec, rec, code_acts = code_search_utils.…
+        base_cb_name = code_search_utils.convert_to_base_name(cb_name, gcb=gcb)
+        codes, prec, rec, code_acts = code_search_utils.get_code_precision_and_recall(
             token_pos_ids,
             cb,
             act_count_ft_tkns[base_cb_name],

@@ -112,7 +112,6 @@ concept_code_description = (
 )
 st.write(concept_code_description)
 
-# ex_col, p_col, r_col, trunc_col, sort_col = st.columns([1, 2, 2, 1, 1])
 ex_col, r_col, trunc_col, sort_col = st.columns([1, 1, 1, 1])
 example_id = ex_col.number_input(
     "Example ID",

@@ -121,14 +120,6 @@ example_id = ex_col.number_input(
     0,
     key="example_id",
 )
-# prec_threshold = p_col.slider(
-#     "Precision Threshold",
-#     0.0,
-#     1.0,
-#     0.02,
-#     key="prec",
-#     help="Precision Threshold controls the specificity of the codes for the given example.",
-# )
 recall_threshold = r_col.slider(
     "Recall Threshold",
     0.0,

@@ -138,13 +129,13 @@ recall_threshold = r_col.slider(
     help="Recall Threshold is the minimum fraction of tokens in the example that the code must activate on.",
 )
 example_truncation = trunc_col.number_input(
-    "Max Output Chars", 0, …
+    "Max Output Chars", 0, 102400, 1024, key="max_chars"
 )
 sort_by_options = ["Precision", "Recall", "Num Acts"]
 sort_by_name = sort_col.radio(
     "Sort By",
     sort_by_options,
-    index=…
+    index=1,
     horizontal=True,
     help="Sorts the codes by the selected metric.",
 )

@@ -158,9 +149,6 @@ button = st.button(
     args=(example_id,),
     help="Find an example which has codes above the recall threshold.",
 )
-# if button:
-#     find_next_example(st.session_state["example_id"])
-
 
 st.markdown("### Example Text")
 trunc_suffix = "..." if example_truncation < len(tokens_text[example_id]) else ""
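To make the Recall Threshold slider above concrete, a toy version (invented positions) of the query in get_example_concept_codes: all positions of one example form the match set, so a code's recall here is simply the fraction of the example's tokens it fires on.

    example_id, seq_len = 0, 4
    token_pos_ids = [(example_id, i) for i in range(seq_len)]
    code_fires_at = {(0, 1), (0, 2), (1, 0)}  # toy activation positions of one code
    recall = sum(p in code_fires_at for p in token_pos_ids) / len(token_pos_ids)
    print(recall)  # 0.5 -> shown only if the slider threshold is <= 0.5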
utils.py
CHANGED
@@ -1,4 +1,6 @@
|
|
1 |
"""Util functions for codebook features."""
|
|
|
|
|
2 |
import re
|
3 |
import typing
|
4 |
from dataclasses import dataclass
|
@@ -57,11 +59,6 @@ class CodeInfo:
|
|
57 |
if self.regex is not None:
|
58 |
assert self.prec is not None and self.recall is not None
|
59 |
|
60 |
-
def check_patch_info(self):
|
61 |
-
"""Check if the patch info is present."""
|
62 |
-
# TODO: pos can be none for patching
|
63 |
-
assert self.pos is not None and self.code_pos is not None
|
64 |
-
|
65 |
def __repr__(self):
|
66 |
"""Return the string representation."""
|
67 |
repr = f"CodeInfo(code={self.code}, layer={self.layer}, head={self.head}, cb_at={self.cb_at}"
|
@@ -76,6 +73,57 @@ class CodeInfo:
|
|
76 |
repr += ")"
|
77 |
return repr
|
78 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
|
80 |
def logits_to_pred(logits, tokenizer, k=5):
|
81 |
"""Convert logits to top-k predictions."""
|
@@ -88,53 +136,6 @@ def logits_to_pred(logits, tokenizer, k=5):
|
|
88 |
return [(topk_preds[i], probs[:, -1, i].item()) for i in range(len(topk_preds))]
|
89 |
|
90 |
|
91 |
-
def patch_codebook_ids(
|
92 |
-
corrupted_codebook_ids, hook, pos, cache, cache_pos=None, code_idx=None
|
93 |
-
):
|
94 |
-
"""Patch codebook ids with cached ids."""
|
95 |
-
if cache_pos is None:
|
96 |
-
cache_pos = pos
|
97 |
-
if code_idx is None:
|
98 |
-
corrupted_codebook_ids[:, pos] = cache[hook.name][:, cache_pos]
|
99 |
-
else:
|
100 |
-
for code_id in range(32):
|
101 |
-
if code_id in code_idx:
|
102 |
-
corrupted_codebook_ids[:, pos, code_id] = cache[hook.name][
|
103 |
-
:, cache_pos, code_id
|
104 |
-
]
|
105 |
-
else:
|
106 |
-
corrupted_codebook_ids[:, pos, code_id] = -1
|
107 |
-
|
108 |
-
return corrupted_codebook_ids
|
109 |
-
|
110 |
-
|
111 |
-
def logits_to_ave_logit_diff(logits, answer_tokens, per_prompt=False):
|
112 |
-
"""Calculate the average logit difference between the answer and the other token."""
|
113 |
-
# Only the final logits are relevant for the answer
|
114 |
-
final_logits = logits[:, -1, :]
|
115 |
-
answer_logits = final_logits.gather(dim=-1, index=answer_tokens)
|
116 |
-
answer_logit_diff = answer_logits[:, 0] - answer_logits[:, 1]
|
117 |
-
if per_prompt:
|
118 |
-
return answer_logit_diff
|
119 |
-
else:
|
120 |
-
return answer_logit_diff.mean()
|
121 |
-
|
122 |
-
|
123 |
-
def normalize_patched_logit_diff(
|
124 |
-
patched_logit_diff,
|
125 |
-
base_average_logit_diff,
|
126 |
-
corrupted_average_logit_diff,
|
127 |
-
):
|
128 |
-
"""Normalize the patched logit difference."""
|
129 |
-
# Subtract corrupted logit diff to measure the improvement,
|
130 |
-
# divide by the total improvement from clean to corrupted to normalise
|
131 |
-
# 0 means zero change, negative means actively made worse,
|
132 |
-
# 1 means totally recovered clean performance, >1 means actively *improved* on clean performance
|
133 |
-
return (patched_logit_diff - corrupted_average_logit_diff) / (
|
134 |
-
base_average_logit_diff - corrupted_average_logit_diff
|
135 |
-
)
|
136 |
-
|
137 |
-
|
138 |
def features_to_tokens(cb_key, cb_acts, num_codes, code=None):
|
139 |
"""Return the set of token ids each codebook feature activates on."""
|
140 |
codebook_ids = cb_acts[cb_key]
|
@@ -154,7 +155,6 @@ def features_to_tokens(cb_key, cb_acts, num_codes, code=None):
|
|
154 |
|
155 |
def color_str(s: str, html: bool, color: Optional[str] = None):
|
156 |
"""Color the string for html or terminal."""
|
157 |
-
|
158 |
if html:
|
159 |
color = "DeepSkyBlue" if color is None else color
|
160 |
return f"<span style='color:{color}'>{s}</span>"
|
@@ -163,7 +163,7 @@ def color_str(s: str, html: bool, color: Optional[str] = None):
|
|
163 |
return colored(s, color)
|
164 |
|
165 |
|
166 |
-
def
|
167 |
"""Separate states with a dash and color red the tokens in color_idx."""
|
168 |
ret_string = ""
|
169 |
itr_over_color_idx = 0
|
@@ -224,31 +224,48 @@ def prepare_example_print(
|
|
224 |
return example_output
|
225 |
|
226 |
|
227 |
-
def
|
228 |
-
|
229 |
tokens,
|
230 |
-
|
231 |
n=3,
|
232 |
max_examples=100,
|
233 |
randomize=False,
|
234 |
html=False,
|
235 |
return_example_list=False,
|
236 |
):
|
237 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
238 |
if randomize:
|
239 |
raise NotImplementedError("Randomize not yet implemented.")
|
240 |
-
indices = range(len(
|
241 |
print_output = [] if return_example_list else ""
|
242 |
-
curr_ex =
|
243 |
total_examples = 0
|
244 |
tokens_to_color = []
|
245 |
-
color_fn =
|
246 |
for idx in indices:
|
247 |
if total_examples > max_examples:
|
248 |
break
|
249 |
-
i, j =
|
250 |
|
251 |
if i != curr_ex and curr_ex >= 0:
|
|
|
252 |
curr_ex_output = prepare_example_print(
|
253 |
curr_ex,
|
254 |
tokens[curr_ex],
|
@@ -275,17 +292,16 @@ def tkn_print(
|
|
275 |
print_output.append((curr_ex_output, len(tokens_to_color)))
|
276 |
else:
|
277 |
print_output += curr_ex_output
|
278 |
-
|
279 |
-
print_output += color_str(asterisk_str, html, "green")
|
280 |
total_examples += 1
|
281 |
|
282 |
return print_output
|
283 |
|
284 |
|
285 |
-
def
|
286 |
ft_tkns,
|
287 |
tokens,
|
288 |
-
|
289 |
n=3,
|
290 |
start=0,
|
291 |
stop=1000,
|
@@ -301,17 +317,17 @@ def print_ft_tkns(
|
|
301 |
num_tokens = len(tokens) * len(tokens[0])
|
302 |
codes, token_act_freqs, token_acts = [], [], []
|
303 |
for i in indices:
|
304 |
-
|
305 |
-
freq = (len(
|
306 |
if freq_filter is not None and freq[1] > freq_filter:
|
307 |
continue
|
308 |
codes.append(i)
|
309 |
token_act_freqs.append(freq)
|
310 |
-
if len(
|
311 |
-
tkn_acts =
|
312 |
-
|
313 |
tokens,
|
314 |
-
|
315 |
n=n,
|
316 |
max_examples=max_examples,
|
317 |
randomize=randomize,
|
@@ -340,149 +356,59 @@ def patch_in_codes(run_cb_ids, hook, pos, code, code_pos=None):
340 |       return run_cb_ids
341 |  
342 |  
343 | - def get_cb_layer_name(cb_at, layer_idx, head_idx=None):
344 |       """Get the layer name used to store hooks/cache."""
345 | -     if head_idx is None:
346 | -         return f"blocks.{layer_idx}.{cb_at}.codebook_layer.hook_codebook_ids"
347 | -     else:
348 | -         return f"blocks.{layer_idx}.{cb_at}.codebook_layer.codebook.{head_idx}.hook_codebook_ids"
349 | -
350 | -
351 | - def get_cb_layer_names(layer, patch_types, n_heads):
352 | -     """Get the layer names used to store hooks/cache."""
353 | -     layer_names = []
354 | -     attn_added, mlp_added = False, False
355 | -     if "attn_out" in patch_types:
356 | -         attn_added = True
357 | -         for head in range(n_heads):
358 | -             layer_names.append(
359 | -                 f"blocks.{layer}.attn.codebook_layer.codebook.{head}.hook_codebook_ids"
360 | -             )
361 | -     if "mlp_out" in patch_types:
362 | -         mlp_added = True
363 | -         layer_names.append(f"blocks.{layer}.mlp.codebook_layer.hook_codebook_ids")
364 | -
365 | -     for patch_type in patch_types:
366 | -         # match patch_type of the pattern attn_\d_head_\d
367 | -         attn_head = re.match(r"attn_(\d)_head_(\d)", patch_type)
368 | -         if (not attn_added) and attn_head and attn_head[1] == str(layer):
369 | -             layer_names.append(
370 | -                 f"blocks.{layer}.attn.codebook_layer.codebook.{attn_head[2]}.hook_codebook_ids"
371 | -             )
372 | -         mlp = re.match(r"mlp_(\d)", patch_type)
373 | -         if (not mlp_added) and mlp and mlp[1] == str(layer):
374 | -             layer_names.append(f"blocks.{layer}.mlp.codebook_layer.hook_codebook_ids")
375 | -
376 | -     return layer_names
377 | -
378 | -
379 | - def cb_layer_name_to_info(layer_name):
380 | -     """Get the layer info from the layer name."""
381 | -     layer_name_split = layer_name.split(".")
382 | -     layer_idx = int(layer_name_split[1])
383 | -     cb_at = layer_name_split[2]
384 | -     if cb_at == "mlp":
385 | -         head_idx = None
386 |       else:
387 | -         …
388 | -     return cb_at, layer_idx, head_idx
389 | -
390 | -
391 | - def get_hooks(code, cb_at, layer_idx, head_idx=None, pos=None):
392 | -     """Get the hooks for the codebook features."""
393 | -     hook_fns = [
394 | -         partial(patch_in_codes, pos=pos, code=code[i]) for i in range(len(code))
395 | -     ]
396 | -     return [
397 | -         (get_cb_layer_name(cb_at[i], layer_idx[i], head_idx[i]), hook_fns[i])
398 | -         for i in range(len(code))
399 | -     ]
400 | -
401 | -
402 | - def run_with_codes(
403 | -     input, cb_model, code, cb_at, layer_idx, head_idx=None, pos=None, prepend_bos=True
404 | - ):
405 | -     """Run the model with the codebook features patched in."""
406 | -     hook_fns = [
407 | -         partial(patch_in_codes, pos=pos, code=code[i]) for i in range(len(code))
408 | -     ]
409 | -     cb_model.reset_codebook_metrics()
410 | -     cb_model.reset_hook_kwargs()
411 | -     fwd_hooks = [
412 | -         (get_cb_layer_name(cb_at[i], layer_idx[i], head_idx[i]), hook_fns[i])
413 | -         for i in range(len(cb_at))
414 | -     ]
415 | -     with cb_model.hooks(fwd_hooks, [], True, False) as hooked_model:
416 | -         patched_logits, patched_cache = hooked_model.run_with_cache(
417 | -             input, prepend_bos=prepend_bos
418 | -         )
419 | -     return patched_logits, patched_cache
420 | -
421 | -
422 | - def in_hook_list(list_of_arg_tuples, layer, head=None):
423 | -     """Check if the component specified by `layer` and `head` is in the `list_of_arg_tuples`."""
424 | -     # if head is not provided, then checks in MLP
425 | -     for arg_tuple in list_of_arg_tuples:
426 | -         if head is None:
427 | -             if arg_tuple.cb_at == "mlp" and arg_tuple.layer == layer:
428 | -                 return True
429 | -         else:
430 | -             if (
431 | -                 arg_tuple.cb_at == "attn"
432 | -                 and arg_tuple.layer == layer
433 | -                 and arg_tuple.head == head
434 | -             ):
435 | -                 return True
436 | -     return False
437 |  
438 |  
439 | -
440 | - def generate_with_codes(
441 |       input,
442 |       cb_model,
443 |       list_of_code_infos=(),
444 | -     disable_other_comps=False,
445 | -     automata=None,
446 | -     generate_kwargs=None,
447 |   ):
448 | -     """…
449 | -     …
450 | -     …
451 |       hook_fns = [
452 | -         partial(patch_in_codes, pos=tupl.pos, code=tupl.code)
453 |           for tupl in list_of_code_infos
454 |       ]
455 |       fwd_hooks = [
456 | -         (…
457 |           for i, tupl in enumerate(list_of_code_infos)
458 |       ]
459 |       cb_model.reset_hook_kwargs()
460 | -     if disable_other_comps:
461 | -         for layer, cb in cb_model.all_codebooks.items():
462 | -             for head_idx, head in enumerate(cb[0].codebook):
463 | -                 if not in_hook_list(list_of_code_infos, layer, head_idx):
464 | -                     head.set_hook_kwargs(
465 | -                         disable_topk=1, disable_for_tkns=[-1], keep_k_codes=False
466 | -                     )
467 | -             if not in_hook_list(list_of_code_infos, layer):
468 | -                 cb[1].set_hook_kwargs(
469 | -                     disable_topk=1, disable_for_tkns=[-1], keep_k_codes=False
470 | -                 )
471 |       with cb_model.hooks(fwd_hooks, [], True, False) as hooked_model:
472 | -         …
473 | -     return …
474 |  
475 |  
476 | - def …
477 | -     …
478 | -     …
479 | -     …
480 | -     …
481 | -     …
482 | -     …
483 | -     …
484 | -     …
485 |   )
486 |  
487 |  
488 |   def JSD(logits1, logits2, pos=-1, reduction="batchmean"):
@@ -511,11 +437,27 @@ def JSD(logits1, logits2, pos=-1, reduction="batchmean"):
511 |       return 0.5 * loss
512 |  
513 |  
514 | - def …
515 | -     """…
516 | -     …
517 | -     …
518 | -     …
519 |  
520 |  
521 |   def find_code_changes(cache1, cache2, pos=None):

@@ -525,8 +467,8 @@ def find_code_changes(cache1, cache2, pos=None):
525 |           c1 = cache1[k][0, pos]
526 |           c2 = cache2[k][0, pos]
527 |           if not torch.all(c1 == c2):
528 | -             print(…
529 | -             print(…
530 |  
531 |  
532 |   def common_codes_in_cache(cache_codes, threshold=0.0):

@@ -541,39 +483,52 @@ def common_codes_in_cache(cache_codes, threshold=0.0):
541 |       return codes, counts
542 |  
543 |  
544 | - def …(
545 | -     info_str: str,
546 | -     …
547 | -     …
548 | -     …
549 | -     …
550 | -     …
551 | -     """…
552 | -     code, layer, head, occ_freq, train_act_freq = info_str.split(", ")
553 | -     code = int(code.split(": ")[1])
554 | -     layer = int(layer.split(": ")[1])
555 | -     head = int(head.split(": ")[1]) if head else None
556 | -     occ_freq = float(occ_freq.split(": ")[1])
557 | -     train_act_freq = float(train_act_freq.split(": ")[1])
558 | -     return CodeInfo(code, layer, head, pos=pos, code_pos=code_pos, cb_at=cb_at)
559 | -
560 | -
561 | - def parse_concept_codes_string(info_str: str, pos=None, code_append=False):
562 | -     """Parse the concept codes string."""
563 |       code_info_strs = info_str.strip().split("\n")
564 | -     …
565 |       layer, head = None, None
566 | -     …
567 |       for code_info_str in code_info_strs:
568 | -         …
569 | -         …
570 |           )
571 | -         if code_append:
572 |               continue
573 | -         if layer == …
574 | -             code_pos -= 1
575 |           else:
576 |               code_pos = -1
577 | -         …
578 | -         layer, head = …
579 | -     return …
1 |   """Util functions for codebook features."""
2 | +
3 | + import pathlib
4 |   import re
5 |   import typing
6 |   from dataclasses import dataclass

59 |           if self.regex is not None:
60 |               assert self.prec is not None and self.recall is not None
61 |  
62 |       def __repr__(self):
63 |           """Return the string representation."""
64 |           repr = f"CodeInfo(code={self.code}, layer={self.layer}, head={self.head}, cb_at={self.cb_at}"

73 |           repr += ")"
74 |           return repr
75 |  
76 | +     @classmethod
77 | +     def from_str(cls, code_txt, *args, **kwargs):
78 | +         """Extract code info fields from string."""
79 | +         code_txt = code_txt.strip().lower()
80 | +         code_txt = code_txt.split(", ")
81 | +         code_txt = dict(txt.split(": ") for txt in code_txt)
82 | +         return cls(*args, **code_txt, **kwargs)
83 | +
84 | +
+
@dataclass
|
86 |
+
class ModelInfoForWebapp:
|
87 |
+
"""Model info for webapp."""
|
88 |
+
|
89 |
+
model_name: str
|
90 |
+
pretrained_path: str
|
91 |
+
dataset_name: str
|
92 |
+
num_codes: int
|
93 |
+
cb_at: str
|
94 |
+
gcb: str
|
95 |
+
n_layers: int
|
96 |
+
n_heads: Optional[int] = None
|
97 |
+
seed: int = 42
|
98 |
+
max_samples: int = 2000
|
99 |
+
|
100 |
+
def __post_init__(self):
|
101 |
+
"""Convert to correct types."""
|
102 |
+
self.num_codes = int(self.num_codes)
|
103 |
+
self.n_layers = int(self.n_layers)
|
104 |
+
if self.n_heads == "None":
|
105 |
+
self.n_heads = None
|
106 |
+
elif self.n_heads is not None:
|
107 |
+
self.n_heads = int(self.n_heads)
|
108 |
+
self.seed = int(self.seed)
|
109 |
+
self.max_samples = int(self.max_samples)
|
110 |
+
|
111 |
+
@classmethod
|
112 |
+
def load(cls, path):
|
113 |
+
"""Parse model info from path."""
|
114 |
+
path = pathlib.Path(path)
|
115 |
+
with open(path / "info.txt", "r") as f:
|
116 |
+
lines = f.readlines()
|
117 |
+
lines = dict(line.strip().split(": ") for line in lines)
|
118 |
+
return cls(**lines)
|
119 |
+
|
120 |
+
def save(self, path):
|
121 |
+
"""Save model info to path."""
|
122 |
+
path = pathlib.Path(path)
|
123 |
+
with open(path / "info.txt", "w") as f:
|
124 |
+
for k, v in self.__dict__.items():
|
125 |
+
f.write(f"{k}: {v}\n")
|
126 |
+
|
127 |
|
128 |
def logits_to_pred(logits, tokenizer, k=5):
|
129 |
"""Convert logits to top-k predictions."""
|
|
|
136 |
return [(topk_preds[i], probs[:, -1, i].item()) for i in range(len(topk_preds))]
|
137 |
|
138 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
139 |   def features_to_tokens(cb_key, cb_acts, num_codes, code=None):
140 |       """Return the set of token ids each codebook feature activates on."""
141 |       codebook_ids = cb_acts[cb_key]

155 |  
156 |   def color_str(s: str, html: bool, color: Optional[str] = None):
157 |       """Color the string for html or terminal."""
158 |       if html:
159 |           color = "DeepSkyBlue" if color is None else color
160 |           return f"<span style='color:{color}'>{s}</span>"

163 |       return colored(s, color)
164 |  
165 |  
166 | + def color_tokens_tokfsm(tokens, color_idx, html=False):
167 |       """Separate states with a dash and color red the tokens in color_idx."""
168 |       ret_string = ""
169 |       itr_over_color_idx = 0

224 |       return example_output
225 |  
226 |  
227 | + def print_token_activations_of_code(
228 | +     code_act_by_pos,
229 |       tokens,
230 | +     is_fsm=False,
231 |       n=3,
232 |       max_examples=100,
233 |       randomize=False,
234 |       html=False,
235 |       return_example_list=False,
236 |   ):
237 | +     """Print the context with the tokens that a code activates on.
238 | +
239 | +     Args:
240 | +         code_act_by_pos: list of (example_id, token_pos_id) tuples specifying
241 | +             the token positions that a code activates on in a dataset.
242 | +         tokens: list of tokens of a dataset.
243 | +         is_fsm: whether the dataset is the TokFSM dataset.
244 | +         n: context to print around each side of a token that the code activates on.
245 | +         max_examples: maximum number of examples to print.
246 | +         randomize: whether to randomize the order of examples.
247 | +         html: Format the printing style for html or terminal.
248 | +         return_example_list: whether to return the printed string by examples or as a single string.
249 | +
250 | +     Returns:
251 | +         string of all examples formatted if `return_example_list` is False otherwise
252 | +         list of (example_string, num_tokens_colored) tuples for each example.
253 | +     """
254 |       if randomize:
255 |           raise NotImplementedError("Randomize not yet implemented.")
256 | +     indices = range(len(code_act_by_pos))
257 |       print_output = [] if return_example_list else ""
258 | +     curr_ex = code_act_by_pos[0][0]
259 |       total_examples = 0
260 |       tokens_to_color = []
261 | +     color_fn = color_tokens_tokfsm if is_fsm else partial(color_tokens, n=n)
262 |       for idx in indices:
263 |           if total_examples > max_examples:
264 |               break
265 | +         i, j = code_act_by_pos[idx]
266 |  
267 |           if i != curr_ex and curr_ex >= 0:
268 | +             # got new example so print the previous one
269 |               curr_ex_output = prepare_example_print(
270 |                   curr_ex,
271 |                   tokens[curr_ex],

292 |               print_output.append((curr_ex_output, len(tokens_to_color)))
293 |           else:
294 |               print_output += curr_ex_output
295 | +             print_output += color_str("*" * 50, html, "green")
296 |           total_examples += 1
297 |  
298 |       return print_output
299 |  
300 |  
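A hedged usage sketch with invented data: if a code fires at (example 0, position 3) and (example 1, position 7) of a tokenized dataset such as `tokens_str`, its contexts can be rendered with

    acts = [(0, 3), (1, 7)]   # (example_id, token_pos_id) pairs
    out = print_token_activations_of_code(acts, tokens_str, n=2, html=True)
    # `out` is one HTML string; with return_example_list=True it would instead
    # be a list of (example_string, num_tokens_colored) tuples.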
301 | + def print_token_activations_of_codes(
302 |       ft_tkns,
303 |       tokens,
304 | +     is_fsm=False,
305 |       n=3,
306 |       start=0,
307 |       stop=1000,

317 |       num_tokens = len(tokens) * len(tokens[0])
318 |       codes, token_act_freqs, token_acts = [], [], []
319 |       for i in indices:
320 | +         tkns_of_code = ft_tkns[i]
321 | +         freq = (len(tkns_of_code), 100 * len(tkns_of_code) / num_tokens)
322 |           if freq_filter is not None and freq[1] > freq_filter:
323 |               continue
324 |           codes.append(i)
325 |           token_act_freqs.append(freq)
326 | +         if len(tkns_of_code) > 0:
327 | +             tkn_acts = print_token_activations_of_code(
328 | +                 tkns_of_code,
329 |                   tokens,
330 | +                 is_fsm,
331 |                   n=n,
332 |                   max_examples=max_examples,
333 |                   randomize=randomize,
356 |       return run_cb_ids
357 |  
358 |  
359 | + def get_cb_hook_key(cb_at: str, layer_idx: int, gcb_idx: Optional[int] = None):
360 |       """Get the layer name used to store hooks/cache."""
361 | +     comp_name = "attn" if "attn" in cb_at else "mlp"
362 | +     if gcb_idx is None:
363 | +         return f"blocks.{layer_idx}.{comp_name}.codebook_layer.hook_codebook_ids"
364 |       else:
365 | +         return f"blocks.{layer_idx}.{comp_name}.codebook_layer.codebook.{gcb_idx}.hook_codebook_ids"
366 |  
367 |  
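The two key shapes this produces, following the f-strings above:

    get_cb_hook_key("mlp", 0)              # "blocks.0.mlp.codebook_layer.hook_codebook_ids"
    get_cb_hook_key("attn_preproj", 3, 5)  # "blocks.3.attn.codebook_layer.codebook.5.hook_codebook_ids"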
368 | + def run_model_fn_with_codes(
369 |       input,
370 |       cb_model,
371 | +     fn_name,
372 | +     fn_kwargs=None,
373 |       list_of_code_infos=(),
374 |   ):
375 | +     """Run the `cb_model`'s `fn_name` method while activating the codes in `list_of_code_infos`.
376 | +
377 | +     Common use case includes running the `run_with_cache` method while activating the codes.
378 | +     For running the `generate` method, use `generate_with_codes` instead.
379 | +     """
380 | +     if fn_kwargs is None:
381 | +         fn_kwargs = {}
382 |       hook_fns = [
383 | +         partial(patch_in_codes, pos=tupl.pos, code=tupl.code, code_pos=tupl.code_pos)
384 |           for tupl in list_of_code_infos
385 |       ]
386 |       fwd_hooks = [
387 | +         (get_cb_hook_key(tupl.cb_at, tupl.layer, tupl.head), hook_fns[i])
388 |           for i, tupl in enumerate(list_of_code_infos)
389 |       ]
390 |       cb_model.reset_hook_kwargs()
391 |       with cb_model.hooks(fwd_hooks, [], True, False) as hooked_model:
392 | +         ret = hooked_model.__getattribute__(fn_name)(input, **fn_kwargs)
393 | +         return ret
394 |  
395 |  
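A hedged sketch of the common case named in the docstring, running `run_with_cache` while a code is forced on (the CodeInfo values are invented):

    code = CodeInfo(code=123, layer=2, head=5, cb_at="attn")
    logits, cache = run_model_fn_with_codes(
        "Once upon a time", cb_model, "run_with_cache",
        list_of_code_infos=[code],
    )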
396 | + def generate_with_codes(
397 | +     input,
398 | +     cb_model,
399 | +     list_of_code_infos=(),
400 | +     tokfsm=None,
401 | +     generate_kwargs=None,
402 | + ):
403 | +     """Sample from the language model while activating the codes in `list_of_code_infos`."""
404 | +     gen = run_model_fn_with_codes(
405 | +         input,
406 | +         cb_model,
407 | +         "generate",
408 | +         generate_kwargs,
409 | +         list_of_code_infos,
410 |       )
411 | +     return tokfsm.seq_to_traj(gen) if tokfsm is not None else gen
412 |  
413 |  
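Usage sketch (arguments invented): sampling with the same pinned code, where passing a `tokfsm` object additionally maps the generated token sequence back to FSM state trajectories via `seq_to_traj`:

    gen = generate_with_codes(
        "Once upon a time", cb_model, [code],
        generate_kwargs={"max_new_tokens": 20},
    )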
414 |   def JSD(logits1, logits2, pos=-1, reduction="batchmean"):

437 |       return 0.5 * loss
438 |  
439 |  
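For context, `JSD` computes the Jensen-Shannon divergence between the two sets of logits' next-token distributions at position `pos`, e.g. to quantify how much patching a code shifted the model's predictions:

    div = JSD(patched_logits, original_logits)  # both of shape [batch, seq, vocab]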
440 | + def cb_hook_key_to_info(layer_hook_key: str):
441 | +     """Get the layer info from the codebook layer hook key.
442 | +
443 | +     Args:
444 | +         layer_hook_key: the hook key of the codebook layer.
445 | +             E.g. `blocks.3.attn.codebook_layer.hook_codebook_ids`
446 | +
447 | +     Returns:
448 | +         comp_name: the name of the component the codebook is applied at.
449 | +         layer_idx: the layer index.
450 | +         gcb_idx: the codebook index if the codebook layer is grouped, otherwise None.
451 | +     """
452 | +     layer_search = re.search(r"blocks\.(\d+)\.(\w+)\.", layer_hook_key)
453 | +     assert layer_search is not None
454 | +     layer_idx, comp_name = int(layer_search.group(1)), layer_search.group(2)
455 | +     gcb_idx_search = re.search(r"codebook\.(\d+)", layer_hook_key)
456 | +     if gcb_idx_search is not None:
457 | +         gcb_idx = int(gcb_idx_search.group(1))
458 | +     else:
459 | +         gcb_idx = None
460 | +     return comp_name, layer_idx, gcb_idx
461 |  
462 |  
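Per the two regexes, this parses both key shapes back into their components:

    cb_hook_key_to_info("blocks.3.attn.codebook_layer.hook_codebook_ids")
    # ("attn", 3, None)
    cb_hook_key_to_info("blocks.3.attn.codebook_layer.codebook.5.hook_codebook_ids")
    # ("attn", 3, 5)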
463 |   def find_code_changes(cache1, cache2, pos=None):

467 |           c1 = cache1[k][0, pos]
468 |           c2 = cache2[k][0, pos]
469 |           if not torch.all(c1 == c2):
470 | +             print(cb_hook_key_to_info(k), c1.tolist(), c2.tolist())
471 | +             print(cb_hook_key_to_info(k), c1.tolist(), c2.tolist())
472 |  
473 |  
474 |   def common_codes_in_cache(cache_codes, threshold=0.0):

483 |       return codes, counts
484 |  
485 |  
486 | + def parse_topic_codes_string(
487 | +     info_str: str,
488 | +     pos: Optional[int] = None,
489 | +     code_append: Optional[bool] = False,
490 | +     **code_info_kwargs,
491 | + ):
492 | +     """Parse the topic codes string."""
493 |       code_info_strs = info_str.strip().split("\n")
494 | +     code_info_strs = [e.strip() for e in code_info_strs if e]
495 | +     topic_codes = []
496 |       layer, head = None, None
497 | +     if code_append is None:
498 | +         code_pos = None
499 | +     else:
500 | +         code_pos = "append" if code_append else -1
501 |       for code_info_str in code_info_strs:
502 | +         topic_codes.append(
503 | +             CodeInfo.from_str(
504 | +                 code_info_str,
505 | +                 pos=pos,
506 | +                 code_pos=code_pos,
507 | +                 **code_info_kwargs,
508 | +             )
509 |           )
510 | +         if code_append is None or code_append:
511 |               continue
512 | +         if layer == topic_codes[-1].layer and head == topic_codes[-1].head:
513 | +             code_pos -= 1  # type: ignore
514 |           else:
515 |               code_pos = -1
516 | +         topic_codes[-1].code_pos = code_pos
517 | +         layer, head = topic_codes[-1].layer, topic_codes[-1].head
518 | +     return topic_codes
519 | +
520 | +
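Sketch of the expected input: one code per line, in the same `key: value` format consumed by `CodeInfo.from_str` (values invented). With the default `code_append=False`, consecutive codes on the same layer/head are assigned `code_pos` -1, -2, ... so that several codes can be patched at the trailing positions:

    topic_str = """code: 123, layer: 2, head: 5
    code: 456, layer: 2, head: 5"""
    codes = parse_topic_codes_string(topic_str, cb_at="attn")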
521 | + def find_similar_codes(cb_model, code_info, n=8):
522 | +     """Find the `n` most similar codes to the given code using cosine similarity.
523 | +
524 | +     Useful for finding related codes for interpretability.
525 | +     """
526 | +     codebook = cb_model.get_codebook(code_info)
527 | +     device = codebook.weight.device
528 | +     code = codebook(torch.tensor(code_info.code).to(device))
529 | +     code = code.to(device)
530 | +     logits = torch.matmul(code, codebook.weight.T)
531 | +     _, indices = torch.topk(logits, n)
532 | +     assert indices[0] == code_info.code
533 | +     assert torch.allclose(logits[indices[0]], torch.tensor(1.0))
534 | +     return indices[1:], logits[indices[1:]].tolist()
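Usage sketch: the top match is always the code itself (self-similarity 1.0 for the normalized codebook vectors, per the asserts), so the function returns the next n-1 codes and their similarity scores:

    similar_ids, sims = find_similar_codes(cb_model, code, n=8)  # 7 neighbours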
webapp_utils.py
CHANGED
@@ -1,6 +1,9 @@
1 |   """Utility functions for running webapp using streamlit."""
2 |  
3 |  
4 |   import streamlit as st
5 |   from streamlit.components.v1 import html
6 |  

@@ -61,10 +64,10 @@ def load_ft_tkns(model_id, layer, head=None, code=None):
61 |       # model_id required to not mix cache_data for different models
62 |       assert model_id is not None
63 |       cb_at = st.session_state["cb_at"]
64 | -     …
65 |       cb_acts = st.session_state["cb_acts"]
66 |       if head is not None:
67 | -         cb_name = f"layer{layer}_{cb_at}{…
68 |       else:
69 |           cb_name = f"layer{layer}_{cb_at}"
70 |       return utils.features_to_tokens(

@@ -84,11 +87,12 @@ def get_code_acts(
84 |       ctx_size=5,
85 |       num_examples=100,
86 |       return_example_list=False,
87 |   ):
88 |       """Get the token activations for a given code."""
89 |       ft_tkns = load_ft_tkns(model_id, layer, head, code)
90 |       ft_tkns = [ft_tkns]
91 | -     _, freqs, acts = utils.print_ft_tkns(
92 |           ft_tkns,
93 |           tokens=tokens_str,
94 |           indices=[0],

@@ -96,6 +100,7 @@ def get_code_acts(
96 |           n=ctx_size,
97 |           max_examples=num_examples,
98 |           return_example_list=return_example_list,
99 |       )
100 |       return acts[0], freqs[0]
101 |  

@@ -122,8 +127,16 @@ def find_next_code(code, layer_code_acts, act_range=None):
122 |       """Find the next code that has activations in the given range."""
123 |       if act_range is None:
124 |           return code
125 |       for code_iter, code_act_count in enumerate(layer_code_acts[code:]):
126 | -         if code_act_count >= …
127 |               code += code_iter
128 |               break
129 |       return code

@@ -161,8 +174,8 @@ def add_save_code_button(
161 |       demo_file_path: str,
162 |       num_acts: int,
163 |       save_regex: bool = False,
164 | -     prec: float = None,
165 | -     recall: float = None,
166 |       button_st_container=st,
167 |       button_text: bool = False,
168 |       button_key_suffix: str = "",

@@ -176,12 +189,12 @@ def add_save_code_button(
176 |       if save_button:
177 |           description = st.text_input(
178 |               "Write a description for the code",
179 | -             key="save_code_desc",
180 |           )
181 |           if not description:
182 |               return
183 |  
184 | -     description = st.session_state.get("save_code_desc", None)
185 |       if description:
186 |           layer = st.session_state["ct_act_layer"]
187 |           is_attn = st.session_state["is_attn"]

@@ -207,4 +220,3 @@ def add_save_code_button(
207 |       saved = add_code_to_demo_file(code_info, demo_file_path)
208 |       if saved:
209 |           st.success("Code saved!", icon="π")
210 | -         st.success("Code saved!", icon="π")
1 |   """Utility functions for running webapp using streamlit."""
2 |  
3 |  
4 | + from typing import Optional
5 | +
6 | + import numpy as np
7 |   import streamlit as st
8 |   from streamlit.components.v1 import html
9 |  

64 |       # model_id required to not mix cache_data for different models
65 |       assert model_id is not None
66 |       cb_at = st.session_state["cb_at"]
67 | +     gcb = st.session_state["gcb"]
68 |       cb_acts = st.session_state["cb_acts"]
69 |       if head is not None:
70 | +         cb_name = f"layer{layer}_{cb_at}{gcb}{head}"
71 |       else:
72 |           cb_name = f"layer{layer}_{cb_at}"
73 |       return utils.features_to_tokens(

87 |       ctx_size=5,
88 |       num_examples=100,
89 |       return_example_list=False,
90 | +     is_fsm=False,
91 |   ):
92 |       """Get the token activations for a given code."""
93 |       ft_tkns = load_ft_tkns(model_id, layer, head, code)
94 |       ft_tkns = [ft_tkns]
95 | +     _, freqs, acts = utils.print_token_activations_of_codes(
96 |           ft_tkns,
97 |           tokens=tokens_str,
98 |           indices=[0],

100 |           n=ctx_size,
101 |           max_examples=num_examples,
102 |           return_example_list=return_example_list,
103 | +         is_fsm=is_fsm,
104 |       )
105 |       return acts[0], freqs[0]
106 |  

127 |       """Find the next code that has activations in the given range."""
128 |       if act_range is None:
129 |           return code
130 | +     min_act, max_act = 0, np.inf
131 | +     if isinstance(act_range, tuple):
132 | +         if len(act_range) == 2:
133 | +             min_act, max_act = act_range
134 | +         else:
135 | +             min_act = act_range[0]
136 | +     elif isinstance(act_range, int):
137 | +         min_act = act_range
138 |       for code_iter, code_act_count in enumerate(layer_code_acts[code:]):
139 | +         if code_act_count >= min_act and code_act_count <= max_act:
140 |               code += code_iter
141 |               break
142 |       return code
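Per the branches above, `act_range` may be a single int (a minimum activation count) or a `(min, max)` tuple; e.g. this skips ahead to the next code that fired between 10 and 100 times:

    code = find_next_code(code, layer_code_acts, act_range=(10, 100))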
174 |       demo_file_path: str,
175 |       num_acts: int,
176 |       save_regex: bool = False,
177 | +     prec: Optional[float] = None,
178 | +     recall: Optional[float] = None,
179 |       button_st_container=st,
180 |       button_text: bool = False,
181 |       button_key_suffix: str = "",

189 |       if save_button:
190 |           description = st.text_input(
191 |               "Write a description for the code",
192 | +             key=f"save_code_desc{button_key_suffix}",
193 |           )
194 |           if not description:
195 |               return
196 |  
197 | +     description = st.session_state.get(f"save_code_desc{button_key_suffix}", None)
198 |       if description:
199 |           layer = st.session_state["ct_act_layer"]
200 |           is_attn = st.session_state["is_attn"]

220 |       saved = add_code_to_demo_file(code_info, demo_file_path)
221 |       if saved:
222 |           st.success("Code saved!", icon="π")