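"""Web app page for finding topic codes in different examples of the dataset.

Topic codes are codes that activate on a large fraction of the tokens of an
example (at least the Recall Threshold chosen on the page); the page lists
them for one selected example at a time.
"""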
import streamlit as st | |
from streamlit_extras.switch_page_button import switch_page | |
import code_search_utils | |
import webapp_utils | |
# Load saved widget state, then redirect to the Code Browser page when the
# codebook activations ("cb_acts") are not yet in session state.
webapp_utils.load_widget_state()
if "cb_acts" not in st.session_state:
    switch_page("Code_Browser")

# Number of examples selectable on this page (hard-coded; presumably no larger
# than the dataset -- TODO confirm) and the minimum precision a code needs to
# be listed as a topic code.
total_examples = 2000
prec_threshold = 0.01

# Values read from session state (presumably populated by the Code Browser
# page, which this page redirects to when they are missing).
(
    model_name,
    seq_len,
    tokens_text,
    tokens_str,
    cb_acts,
    act_count_ft_tkns,
    gcb,
) = (
    st.session_state[state_key]
    for state_key in (
        "model_name_id",
        "seq_len",
        "tokens_text",
        "tokens_str",
        "cb_acts",
        "act_count_ft_tkns",
        "gcb",
    )
)
def get_example_topic_codes(example_id):
    """Collect the topic codes of one dataset example, grouped by codebook.

    For every codebook in ``cb_acts``, precision/recall statistics are
    computed by ``code_search_utils.get_code_precision_and_recall`` over the
    ``seq_len`` token positions ``(example_id, 0) .. (example_id, seq_len - 1)``.
    Only codes with precision >= ``prec_threshold`` and recall >=
    ``recall_threshold`` are kept.

    Note: ``recall_threshold`` is a module-level name that is bound later in
    the page script (by the "Recall Threshold" slider), so this function must
    not be called before that point.

    Args:
        example_id: Index of the dataset example to inspect.

    Returns:
        A list with one ``(cb_name, codes_pr)`` pair per codebook, in
        ``cb_acts`` iteration order. ``codes_pr`` is a (possibly empty) list of
        ``(code, precision, recall, code_acts)`` tuples for the kept codes.
    """
    example_positions = [(example_id, pos) for pos in range(seq_len)]
    topic_codes = []
    for cb_name, cb in cb_acts.items():
        codes, prec, rec, code_acts = code_search_utils.get_code_precision_and_recall(
            example_positions,
            cb,
            act_count_ft_tkns[
                code_search_utils.convert_to_base_name(cb_name, gcb=gcb)
            ],
        )
        # A single combined mask keeps exactly the entries that filtering on
        # precision first and then on recall would keep, in the same order.
        keep = (prec >= prec_threshold) & (rec >= recall_threshold)
        topic_codes.append(
            (
                cb_name,
                list(zip(codes[keep], prec[keep], rec[keep], code_acts[keep])),
            )
        )
    return topic_codes
def find_next_example(example_id):
    """Select the next example (after ``example_id``) that has topic codes.

    Used as the ``on_click`` callback of the "Find Next Example" button.

    Scanning starts just after ``example_id`` and proceeds in increasing order,
    wrapping from ``total_examples - 1`` back to 0. It stops at the first
    example for which ``get_example_topic_codes`` reports at least one code,
    and that id is written to ``st.session_state["example_id"]``.

    If the scan wraps all the way back to ``example_id`` without a hit, an
    error is shown instead and the selected example is left unchanged.

    Args:
        example_id: The currently selected example id; scanning starts after it.
    """
    # Wrap the very first step as well: otherwise, starting from the last
    # example, the out-of-range id ``total_examples`` would be probed and
    # possibly written into the widget state.
    candidate = (example_id + 1) % total_examples
    while candidate != example_id:
        all_codes = get_example_topic_codes(candidate)
        # Each per-codebook list holds the codes that passed both thresholds.
        if any(code_pr_infos for _, code_pr_infos in all_codes):
            st.session_state["example_id"] = candidate
            return
        candidate = (candidate + 1) % total_examples
    st.error(
        f"No examples found at the specified recall threshold: {recall_threshold}.",
        icon="🚨",
    )
def redirect_to_main_with_code(code, layer, head):
    """Store the chosen code in session state and switch to the Code Browser.

    The code and layer are always stored. The head is stored only when
    ``st.session_state["is_attn"]`` is truthy, mirroring the "Head" column
    that this page shows only in that case.

    Args:
        code: Code to pre-select on the Code Browser page.
        layer: Layer of the codebook containing ``code``.
        head: Attention head of the codebook (ignored when not ``is_attn``).
    """
    state = st.session_state
    state["ct_act_code"] = code
    state["ct_act_layer"] = layer
    if state["is_attn"]:
        state["ct_act_head"] = head
    switch_page("Code Browser")
def show_examples_for_topic_code(code, layer, head, code_act_ratio=0.3):
    """Render the examples on which ``code`` fires on many of the tokens.

    ``webapp_utils.get_code_acts`` (called with ``ctx_size=5`` and
    ``return_example_list=True``) yields ``(act_str, num_acts)`` pairs,
    presumably one per example -- TODO confirm. Only snippets whose
    ``num_acts`` strictly exceeds ``seq_len * code_act_ratio`` are kept.

    The kept snippets are concatenated, passed through
    ``webapp_utils.escape_markdown``, and rendered with HTML enabled below an
    "Examples for Code" heading.

    Args:
        code: Code to look up.
        layer: Layer of the codebook containing ``code``.
        head: Attention head of the codebook, passed through to get_code_acts.
        code_act_ratio: Fraction of ``seq_len`` that an example's activation
            count must exceed for the example to be shown.
    """
    ex_acts, _ = webapp_utils.get_code_acts(
        model_name,
        tokens_str,
        code,
        layer,
        head,
        ctx_size=5,
        return_example_list=True,
    )
    dense_snippets = [
        act_str
        for act_str, num_acts in ex_acts
        if num_acts > seq_len * code_act_ratio
    ]
    st.markdown("#### Examples for Code")
    rendered = webapp_utils.escape_markdown("".join(dense_snippets))
    st.markdown(rendered, unsafe_allow_html=True)
# Session-state flag; when truthy, the code table below gets an extra "Head" column.
is_attn = st.session_state["is_attn"]

st.markdown("## Topic Code")

# Two-paragraph explanation shown at the top of the page.
topic_code_description = "\n\n".join(
    (
        "Topic codes are codes that activate many different times on passages that describe a particular"
        " topic or concept (e.g. “fire”). This interface provides a way to search for such codes by looking"
        " at different examples in the dataset (ExampleID) and finding codes that activate on some fraction"
        " of the tokens in that example (Recall Threshold). Decrease the Recall Threshold to view more possible"
        " topic codes and increase it to see fewer. Click “Find Next Example” to find the next example with at"
        " least one code firing on that example above the Recall Threshold.",
        "Topic codes are displayed for the codebook model selected on the Code Browser page. To view topic codes"
        " for a different model, go to the Code Browser page and select a different model.",
    )
)
st.write(topic_code_description)
# Four equal-width columns: example selector, recall threshold, output
# truncation, and sort order.
ex_col, r_col, trunc_col, sort_col = st.columns(4)

example_id = ex_col.number_input(
    label="Example ID",
    min_value=0,
    max_value=total_examples - 1,
    value=0,
    key="example_id",
)
recall_threshold = r_col.slider(
    label="Recall Threshold",
    min_value=0.0,
    max_value=1.0,
    value=0.2,
    key="recall",
    help="Recall Threshold is the minimum fraction of tokens in the example that the code must activate on.",
)
example_truncation = trunc_col.number_input(
    label="Max Output Chars",
    min_value=0,
    max_value=102400,
    value=1024,
    key="max_chars",
)

sort_by_options = ["Precision", "Recall", "Num Acts"]
sort_by_name = sort_col.radio(
    "Sort By",
    sort_by_options,
    index=1,  # default: sort by Recall
    horizontal=True,
    help="Sorts the codes by the selected metric.",
)
# 0 = Precision, 1 = Recall, 2 = Num Acts.
sort_by = sort_by_options.index(sort_by_name)

# On click, find_next_example(example_id) is invoked; it may set
# st.session_state["example_id"] to the next example that has topic codes.
button = st.button(
    "Find Next Example",
    key="find_next_example",
    on_click=find_next_example,
    args=(example_id,),
    help="Find an example which has codes above the recall threshold.",
)
# Show the selected example's text, truncated to the "Max Output Chars" setting
# and suffixed with "..." when anything was cut off.
st.markdown("### Example Text")
example_text = tokens_text[example_id]
if example_truncation < len(example_text):
    trunc_suffix = "..."
else:
    trunc_suffix = ""
st.write(example_text[:example_truncation] + trunc_suffix)

# Header row of the topic-code table; the optional "Head" column shifts the
# remaining headers, so they are addressed from the right.
cols = st.columns(7 if is_attn else 6)
cols[0].markdown("Search", help="Button to see token activations for the code.")
cols[1].write("Layer")
if is_attn:
    cols[2].write("Head")
for header_col, header_title in zip(cols[-4:-1], ("Code", "Precision", "Recall")):
    header_col.write(header_title)
cols[-1].markdown(
    "Num Acts",
    help="Number of tokens that the code activates on in the acts dataset.",
)
# Flatten the per-codebook results into one row per (codebook, code) pair.
all_codes = [
    (cb_name, code_pr_info)
    for cb_name, code_pr_infos in get_example_topic_codes(example_id)
    for code_pr_info in code_pr_infos
]
# code_pr_info is (code, precision, recall, num_acts); sort_by 0/1/2 picks
# precision/recall/num_acts, hence the +1 offset past the code itself.
all_codes.sort(key=lambda row: row[1][sort_by + 1], reverse=True)

for cb_name, (code, p, r, acts) in all_codes:
    cols = st.columns(7 if is_attn else 6)
    code_button = cols[0].button("🔍", key=f"ex-code-{code}-{cb_name}")
    layer, head = code_search_utils.get_layer_head_from_adv_name(cb_name)
    cols[1].write(str(layer))
    if is_attn:
        cols[2].write(str(head))
    cols[-4].write(code)
    cols[-3].write(f"{p*100:.2f}%")
    cols[-2].write(f"{r*100:.2f}%")
    cols[-1].write(str(acts))
    # Clicking the search button renders the code's examples under this row.
    if code_button:
        show_examples_for_topic_code(
            code, layer, head, code_act_ratio=recall_threshold
        )

if not all_codes:
    st.markdown(
        f"<div style='text-align:center'>No codes found at recall threshold = {recall_threshold}."
        " Consider decreasing the recall threshold.</div>",
        unsafe_allow_html=True,
    )