Spaces:
Running
Running
"""Web App for the Codebook Features project.""" | |
import argparse
import glob
import html
import os

import streamlit as st

import code_search_utils
import utils
import webapp_utils
# --- Parse command line arguments ---


def _str_to_bool(value):
    """Convert a command-line string such as "true"/"false" into a bool.

    argparse hands option values over as strings, so without this conversion
    ``--deploy False`` would produce the truthy string ``"False"`` and deploy
    mode could never be switched off from the command line.

    Args:
        value: The raw option value (or an already-boolean default/const).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if normalized in {"0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser()
parser.add_argument(
    "--deploy",
    type=_str_to_bool,
    # A bare "--deploy" (no value) enables deploy mode.
    nargs="?",
    const=True,
    default=True,
    help="Deploy mode (true/false).",
)
parser.add_argument(
    "--cache_dir",
    type=str,
    default="cache/",
    help="Path to directory containing cache for codebook models.",
)
try:
    args = parser.parse_args()
except SystemExit as e:
    # This exception will be raised if --help or invalid command line arguments
    # are used. Currently streamlit prevents the program from exiting normally
    # so we have to do a hard exit.
    os._exit(e.code if isinstance(e.code, int) else 1)
deploy = args.deploy
# Restore widget state kept via webapp_utils.persist (used for the model
# selector below) before any widgets are created.
webapp_utils.load_widget_state()
st.set_page_config(
    page_title="Codebook Features",
    # NOTE(review): the original icon was destroyed by a mis-encoding (it read
    # "๐", the leading byte of a UTF-8 emoji decoded as cp874); 📚 is a
    # best-guess restoration — confirm against the upstream project.
    page_icon="📚",
)
st.title("Codebook Features")
# --- Load model info and cache ---
pretty_model_names = {
    "TinyStories-1Layer-21M#100ksteps_vcb_mlp": "TinyStories-1L-21M-MLP",
    "TinyStories-1Layer-21M_ccb_attn_preproj": "TinyStories 1 Layer Attention Codebook",
    "TinyStories-33M_ccb_attn_preproj": "TinyStories 4 Layer Attention Codebook",
    "TinyStories-1Layer-21M_vcb_mlp": "TinyStories 1 Layer MLP Codebook",
}
orig_model_name = {v: k for k, v in pretty_model_names.items()}
# Guarantee a trailing separator so "--cache_dir cache" and "--cache_dir cache/"
# both produce valid paths when sub-paths are appended below.
base_cache_dir = os.path.join(args.cache_dir, "")


def _cache_dir_model_name(cache_dir):
    """Return the model name encoded in a cached model directory's name.

    The last two ``_``-separated parts of the directory name are dropped and the
    remainder is the model name.
    """
    dir_name = os.path.basename(os.path.normpath(cache_dir))
    return "_".join(dir_name.split("_")[:-2])


# Group every cached run directory under the exact model name it belongs to.
# (A prefix glob like "models/{name}_*" would also match other models whose
# names merely start with "{name}_".)
model_cache_dirs = {}
for model_dir in glob.glob(base_cache_dir + "models/*/"):
    model_cache_dirs.setdefault(_cache_dir_model_name(model_dir), []).append(model_dir)
model_name_options = sorted(model_cache_dirs)
if not model_name_options:
    st.error(f"No cached codebook models found in {base_cache_dir}models/.")
    st.stop()
# Prefer the first attention-codebook model; fall back to the first option.
def_model_idx = next(
    (idx for idx, name in enumerate(model_name_options) if "attn" in name.lower()),
    0,
)
p_model_name = st.selectbox(
    "Model",
    [pretty_model_names.get(m, m) for m in model_name_options],
    index=def_model_idx,
    key=webapp_utils.persist("model_name"),
)
model_name = orig_model_name.get(p_model_name, p_model_name)
is_fsm = "FSM" in p_model_name
# Use the most recently modified cache directory for the selected model. The
# glob pattern ends in "/", so each directory path already has a trailing slash.
codes_cache_path = sorted(model_cache_dirs[model_name], key=os.path.getmtime)[-1]
model_info = utils.ModelInfoForWebapp.load(codes_cache_path)
num_codes = model_info.num_codes
num_layers = model_info.n_layers
num_heads = model_info.n_heads
cb_at = model_info.cb_at
# String form of the gcb flag: "_gcb" when set, "" otherwise.
gcb = "_gcb" if model_info.gcb else ""
is_attn = "attn" in cb_at
dataset_cache_path = base_cache_dir + f"datasets/{model_info.dataset_name}/"
# Load the precomputed token/activation cache for the selected model and dataset.
tokens_str, tokens_text, token_byte_pos, cb_acts, act_count_ft_tkns, metrics = (
    webapp_utils.load_code_search_cache(codes_cache_path, dataset_cache_path)
)
# Sequence length is read off the first tokenized example (presumably all
# examples share it — confirm against the cache builder).
seq_len = len(tokens_str[0])
# Keep only metrics whose family (the part of the key before "/") is listed here.
metric_keys = ["eval_loss", "eval_accuracy", "eval_dead_code_fraction"]
metrics = {
    metric_name: metric_value
    for metric_name, metric_value in metrics.items()
    if metric_name.partition("/")[0] in metric_keys
}
# --- Set the session states ---
# Expose the loaded model/cache values through st.session_state, presumably so
# webapp_utils helpers and other pages can read them — confirm.
for state_key, state_value in (
    ("model_name_id", model_name),
    ("cb_acts", cb_acts),
    ("tokens_text", tokens_text),
    ("tokens_str", tokens_str),
    ("act_count_ft_tkns", act_count_ft_tkns),
    ("num_codes", num_codes),
    ("gcb", gcb),
    ("cb_at", cb_at),
    ("is_attn", is_attn),
    ("seq_len", seq_len),
):
    st.session_state[state_key] = state_value
# The metrics section is only rendered outside deploy mode, and even then the
# raw metrics dict stays collapsed until the user ticks the checkbox.
if not deploy:
    st.markdown("## Metrics")
    show_model_metrics = st.checkbox("Show Model Metrics")
    if show_model_metrics:
        st.write(metrics)
st.markdown("## Demo Codes")
demo_codes_desc = (
    "This section contains codes that we've found to be interpretable along "
    "with a description of the feature we think they are capturing. "
    "Click on the 🔍 search button for a code to see the tokens that code activates on."
)
st.write(demo_codes_desc)
demo_file_path = codes_cache_path + "demo_codes.txt"
if st.checkbox("Show Demo Codes"):
    # Line format of demo_codes.txt, as parsed below:
    #   "## ..."        -> code lines after it are skipped until the next "#" header
    #   "# desc: regex" -> description/regex applied to the next code line
    #   anything else  -> a code line parsed by utils.CodeInfo.from_str
    try:
        with open(demo_file_path, "r", encoding="utf-8") as f:
            demo_codes = f.readlines()
    except FileNotFoundError:
        demo_codes = []
    code_desc, code_regex = "", ""
    demo_codes = [code.strip() for code in demo_codes if code.strip()]
    num_cols = 6 if is_attn else 5
    cols = st.columns([1] * (num_cols - 1) + [2])
    cols[0].markdown("Search", help="Button to see token activations for the code.")
    cols[1].write("Code")
    cols[2].write("Layer")
    if is_attn:
        cols[3].write("Head")
    cols[-2].markdown(
        "Num Acts",
        help="Number of tokens that the code activates on in the acts dataset.",
    )
    cols[-1].markdown("Description", help="Interpreted description of the code.")
    if not demo_codes:
        st.markdown(
            f"""
            <div style="font-size: 1.0rem; color: red;">
            No demo codes found in file {demo_file_path}
            </div>
            """,
            unsafe_allow_html=True,
        )
    skip = True
    for code_txt in demo_codes:
        if code_txt.startswith("##"):
            skip = True
            continue
        if code_txt.startswith("#"):
            # Split on the first ":" only, so a regex containing ":" (e.g. "(?:...)")
            # and a header without a regex part are both handled.
            desc_part, _, regex_part = code_txt[1:].partition(":")
            code_desc, code_regex = desc_part.strip(), regex_part.strip()
            skip = False
            continue
        if skip:
            continue
        code_info = utils.CodeInfo.from_str(code_txt, regex=code_regex)
        # act_count_ft_tkns is keyed "layer{L}" or "layer{L}_head{H}" (the same
        # format the token-activation section uses); the previous form produced
        # "layer{L}_" for head-less codes and raised KeyError.
        comp_info = f"layer{code_info.layer}" + (
            f"_head{code_info.head}" if code_info.head is not None else ""
        )
        button_key = (
            f"demo_search_code{code_info.code}_layer{code_info.layer}_desc-{code_info.description}"
            + (f"head{code_info.head}" if code_info.head is not None else "")
        )
        cols = st.columns([1] * (num_cols - 1) + [2])
        button_clicked = cols[0].button(
            "🔍",
            key=button_key,
        )
        if button_clicked:
            webapp_utils.set_ct_acts(
                code_info.code, code_info.layer, code_info.head, None, is_attn
            )
        cols[1].write(code_info.code)
        cols[2].write(str(code_info.layer))
        if is_attn:
            cols[3].write(str(code_info.head))
        cols[-2].write(str(act_count_ft_tkns[comp_info][code_info.code]))
        cols[-1].write(code_desc)
        # Each "#" header applies to a single code line; later lines are skipped
        # until the next header.
        skip = True
# --- Code Search ---
st.markdown("## Code Search")
code_search_desc = (
    "To find whether the codebooks model has captured a relevant feature from the data (e.g. pronouns),"
    " you can specify a regex pattern for your feature (e.g. “he|she|they”) and find whether any code"
    " activating on the regex pattern exists.\n\n"
    "Since strings can contain several tokens, you can specify the token you want a code to fire on by"
    " using a capture group. For example, the search term “New (York)” will try to find codes that"
    " activate on the bigram feature “New York” at the York token"
)
if st.checkbox("Search with Regex"):
    st.write(code_search_desc)
    regex_pattern = st.text_input(
        "Enter a regex pattern",
        help="Wrap code token in the first group. E.g. New (York)",
        key="regex_pattern",
    )
    prec_col, sort_col = st.columns(2)
    prec_threshold = prec_col.slider(
        "Precision Threshold",
        0.0,
        1.0,
        0.9,
        help="Shows codes with precision on the regex pattern above the threshold.",
    )
    sort_by_options = ["Precision", "Recall", "Num Acts"]
    sort_by_name = sort_col.radio(
        "Sort By",
        sort_by_options,
        index=0,
        horizontal=True,
        help="Sorts the codes by the selected metric.",
    )
    sort_by = sort_by_options.index(sort_by_name)

    def get_codebook_wise_codes_for_regex(
        regex_pattern, prec_threshold, gcb, model_name
    ):
        """Get codebook wise codes for a given regex pattern.

        Delegates to ``code_search_utils.get_codes_from_pattern`` over the
        module-level token/activation cache of the selected model, requesting
        at most ``topk=8`` codes.

        Args:
            regex_pattern: User pattern; per the input help text, the first
                capture group marks the token a code should fire on.
            prec_threshold: Minimum precision forwarded to the search.
            gcb: Forwarded unchanged as the ``gcb`` keyword.
            model_name: Not used by the search call itself; per the inline
                comment it ties the result to the correct cache data.

        Returns:
            ``(codebook_wise_codes, re_token_matches)`` as unpacked below.
        """
        # NOTE(review): no Streamlit cache decorator is applied here, so the
        # search reruns on every interaction and model_name has no effect beyond
        # this check — consider @st.cache_data if searches are slow.
        assert model_name is not None  # required for loading from correct cache data
        return code_search_utils.get_codes_from_pattern(
            regex_pattern,
            tokens_text,
            token_byte_pos,
            cb_acts,
            act_count_ft_tkns,
            gcb=gcb,
            topk=8,
            prec_threshold=prec_threshold,
        )

    if regex_pattern:
        codebook_wise_codes, re_token_matches = get_codebook_wise_codes_for_regex(
            regex_pattern,
            prec_threshold,
            gcb,
            model_name,
        )
        st.markdown(
            f"Found <span style='color:green;'>{re_token_matches}</span> matches",
            unsafe_allow_html=True,
        )
        num_search_cols = 7 if is_attn else 6
        # Outside deploy mode an extra trailing column holds the "Save to Demos" button.
        non_deploy_offset = 0 if deploy else 1
        num_search_cols += non_deploy_offset
        cols = st.columns(num_search_cols)
        cols[0].markdown("Search", help="Button to see token activations for the code.")
        cols[1].write("Layer")
        if is_attn:
            cols[2].write("Head")
        cols[-4 - non_deploy_offset].write("Code")
        cols[-3 - non_deploy_offset].write("Precision")
        cols[-2 - non_deploy_offset].write("Recall")
        cols[-1 - non_deploy_offset].markdown(
            "Num Acts",
            help="Number of tokens that the code activates on in the acts dataset.",
        )
        if not deploy:
            cols[-1].markdown(
                "Save to Demos",
                help="Button to save the code to demos along with the regex pattern.",
            )
        # Flatten {codebook_name: [(code, prec, recall, num_acts), ...]} into
        # (codebook_name, entry) pairs, then sort by the chosen metric; the
        # "+ 1" skips entry[0], which is the code id.
        all_codes = [
            (cb_name, code_pr_info)
            for cb_name, code_pr_infos in codebook_wise_codes.items()
            for code_pr_info in code_pr_infos
        ]
        all_codes = sorted(all_codes, key=lambda x: x[1][1 + sort_by], reverse=True)
        for cb_name, (code, prec, rec, code_acts) in all_codes:
            # Codebook names look like "layer{L}" or "layer{L}_head{H}".
            layer_head = cb_name.split("_")
            layer = layer_head[0][5:]
            head = layer_head[1][4:] if len(layer_head) > 1 else None
            button_key = f"search_code{code}_layer{layer}" + (
                f"head{head}" if head is not None else ""
            )
            cols = st.columns(num_search_cols)
            extra_args = {
                "prec": prec,
                "recall": rec,
                "num_acts": code_acts,
                "regex": regex_pattern,
            }
            button_clicked = cols[0].button("🔍", key=button_key)
            if button_clicked:
                webapp_utils.set_ct_acts(code, layer, head, extra_args, is_attn)
            cols[1].write(layer)
            if is_attn:
                cols[2].write(head)
            cols[-4 - non_deploy_offset].write(code)
            cols[-3 - non_deploy_offset].write(f"{prec*100:.2f}%")
            cols[-2 - non_deploy_offset].write(f"{rec*100:.2f}%")
            cols[-1 - non_deploy_offset].write(str(code_acts))
            if not deploy:
                webapp_utils.add_save_code_button(
                    demo_file_path,
                    num_acts=code_acts,
                    save_regex=True,
                    prec=prec,
                    recall=rec,
                    button_st_container=cols[-1],
                    button_key_suffix=f"_code{code}_layer{layer}_head{head}",
                )
        if not all_codes:
            # The pattern is user input rendered inside raw HTML: escape it so
            # patterns such as "<(\w+)>" are displayed rather than parsed as markup.
            st.markdown(
                f"""
                <div style="font-size: 1.0rem; color: red;">
                No codes found for pattern {html.escape(regex_pattern)} at precision threshold: {prec_threshold}
                </div>
                """,
                unsafe_allow_html=True,
            )
# --- Display Code Token Activations ---
st.markdown("## Code Token Activations")

filter_codes = st.checkbox("Show filters", key="filter_codes", value=True)
act_range = None
layer_code_acts = None
if filter_codes:
    # Threshold handed to webapp_utils.find_next_code when choosing the default code.
    act_range = st.slider(
        "Minimum number of activations",
        min_value=0,
        max_value=10_000,
        value=100,
        key="ct_act_range",
        help="Filter codes by the number of tokens they activate on.",
    )

# Selector row: Layer, [Head,] Code, Max Results, Context Size.
cols = st.columns(4 + int(is_attn))
layer = cols[0].number_input(
    "Layer", min_value=0, max_value=num_layers - 1, value=0, key="ct_act_layer"
)
head = None
if is_attn:
    head = cols[1].number_input(
        "Head", min_value=0, max_value=num_heads - 1, value=0, key="ct_act_head"
    )

def_code = st.session_state.get("ct_act_code", 0)
if filter_codes:
    layer_key = f"layer{layer}" if head is None else f"layer{layer}_head{head}"
    layer_code_acts = act_count_ft_tkns[layer_key]
    def_code = webapp_utils.find_next_code(def_code, layer_code_acts, act_range)
    if "ct_act_code" in st.session_state:
        # Push the filtered default into the already-existing Code widget.
        st.session_state["ct_act_code"] = def_code
code = cols[-3].number_input(
    "Code",
    min_value=0,
    max_value=num_codes - 1,
    value=def_code,
    key="ct_act_code",
)
num_examples = cols[-2].number_input(
    "Max Results",
    min_value=-1,
    # Capped at 1000 for efficiency, even though a code may have more examples.
    max_value=1000,
    value=100,
    help="Number of examples to show in the results. Set to -1 to show all examples.",
)
ctx_size = cols[-1].number_input(
    "Context Size",
    min_value=1,
    max_value=10,
    value=5,
    help="Number of tokens to show before and after the code token.",
)

acts, acts_count = webapp_utils.get_code_acts(
    model_name,
    tokens_str,
    code,
    layer,
    head,
    ctx_size,
    num_examples,
    is_fsm=is_fsm,
)
head_label = "" if head is None else f" Head {head}"
st.write(
    f"Token Activations for Layer {layer}{head_label} Code {code} | "
    f"Activates on {acts_count[0]} tokens on the acts dataset",
)
if not deploy:
    webapp_utils.add_save_code_button(
        demo_file_path,
        acts_count[0],
        save_regex=False,
        button_text=True,
        button_key_suffix="_token_acts",
    )
st.markdown(webapp_utils.escape_markdown(acts), unsafe_allow_html=True)