codebook-features / pages /Topic_Code_Browser.py
taufeeque's picture
Add message on topic code model
50c3f87
raw
history blame
7.06 kB
"""Web app page for showing codes for different examples in the dataset."""
import streamlit as st
from streamlit_extras.switch_page_button import switch_page
import code_search_utils
import webapp_utils
webapp_utils.load_widget_state()

# This page relies on data loaded by the Code Browser page; send the user
# there when it has not been loaded into the session yet.
if "cb_acts" not in st.session_state:
    switch_page("Code_Browser")

# Minimum precision a code needs on an example to be listed as a topic code.
prec_threshold = 0.01

model_name = st.session_state["model_name_id"]
seq_len = st.session_state["seq_len"]
tokens_text = st.session_state["tokens_text"]
tokens_str = st.session_state["tokens_str"]
cb_acts = st.session_state["cb_acts"]
act_count_ft_tkns = st.session_state["act_count_ft_tkns"]
gcb = st.session_state["gcb"]

# Number of dataset examples browsable on this page (Example ID range and the
# wrap-around of "Find Next Example"). Capped at the number of loaded texts so
# no selectable ID can be out of range for `tokens_text[example_id]`.
total_examples = min(2000, len(tokens_text))
def get_example_topic_codes(example_id):
    """Get topic codes for the given example id.

    Every token position ``0 .. seq_len - 1`` of example ``example_id`` is
    scored against each codebook in ``cb_acts`` with
    ``code_search_utils.get_code_precision_and_recall``. A code is kept only
    when its precision is at least ``prec_threshold`` and its recall is at
    least ``recall_threshold``. ``recall_threshold`` is the module-level value
    set by the "Recall Threshold" slider.

    Args:
        example_id: Index of the dataset example to inspect.

    Returns:
        A list with one ``(cb_name, codes_pr)`` pair per codebook in
        ``cb_acts``, in iteration order. ``codes_pr`` is a (possibly empty)
        list of ``(code, precision, recall, code_acts)`` tuples.
    """
    positions = [(example_id, pos) for pos in range(seq_len)]
    per_codebook = []
    for cb_name, cb in cb_acts.items():
        base_name = code_search_utils.convert_to_base_name(cb_name, gcb=gcb)
        codes, prec, rec, code_acts = code_search_utils.get_code_precision_and_recall(
            positions, cb, act_count_ft_tkns[base_name]
        )
        # One combined mask instead of two successive filters. This assumes the
        # returned values are equal-length arrays supporting elementwise `&`
        # and boolean-mask indexing (the original boolean indexing implies
        # array types) -- TODO confirm.
        keep = (prec >= prec_threshold) & (rec >= recall_threshold)
        per_codebook.append(
            (
                cb_name,
                list(zip(codes[keep], prec[keep], rec[keep], code_acts[keep])),
            )
        )
    return per_codebook
def find_next_example(example_id):
    """Find the example after `example_id` that has topic codes.

    Scans forward from ``example_id + 1``, wrapping around at
    ``total_examples``. It stops at the first example for which
    ``get_example_topic_codes`` reports at least one code and stores that ID
    in ``st.session_state["example_id"]``, which is the key of the "Example ID"
    input. If the scan returns to ``example_id`` without a hit, an error
    message is shown and the session state is left unchanged.

    Registered as the ``on_click`` callback of the "Find Next Example" button.

    Args:
        example_id: The currently displayed example ID; the scan starts
            after it.
    """
    # Take the modulo on the very first step as well, so that starting from
    # the last example wraps to example 0. Without it, the scan would probe
    # the out-of-range ID `total_examples` and skip example 0.
    candidate = (example_id + 1) % total_examples
    while candidate != example_id:
        all_codes = get_example_topic_codes(candidate)
        if any(code_pr_infos for _, code_pr_infos in all_codes):
            st.session_state["example_id"] = candidate
            return
        candidate = (candidate + 1) % total_examples
    st.error(
        f"No examples found at the specified recall threshold: {recall_threshold}.",
        icon="🚨",
    )
def redirect_to_main_with_code(code, layer, head):
    """Redirect to main page with the given code.

    Stores ``code`` and ``layer`` in session state under ``ct_act_code`` and
    ``ct_act_layer``. ``head`` is stored under ``ct_act_head`` only when
    ``st.session_state["is_attn"]`` is truthy. Then switches to the
    "Code Browser" page.
    """
    state = st.session_state
    state["ct_act_code"] = code
    state["ct_act_layer"] = layer
    if state["is_attn"]:
        # Heads are only meaningful for attention codebooks.
        state["ct_act_head"] = head
    switch_page("Code Browser")
def show_examples_for_topic_code(code, layer, head, code_act_ratio=0.3):
    """Show examples that the code activates on.

    Fetches per-example activation strings for ``code`` at ``layer``/``head``
    via ``webapp_utils.get_code_acts`` (context size 5, as a list). Keeps only
    entries whose count (the second element of each pair) exceeds
    ``seq_len * code_act_ratio``. The kept strings are then rendered as one
    markdown/HTML block under an "Examples for Code" heading.

    Args:
        code: Code index to look up.
        layer: Layer of the codebook.
        head: Attention head of the codebook (as returned by
            ``get_layer_head_from_adv_name``).
        code_act_ratio: Fraction of ``seq_len`` the count must exceed.
    """
    example_acts, _ = webapp_utils.get_code_acts(
        model_name,
        tokens_str,
        code,
        layer,
        head,
        ctx_size=5,
        return_example_list=True,
    )
    # NOTE(review): strict `>` here, whereas get_example_topic_codes keeps
    # codes with `rec >= recall_threshold`. An example sitting exactly at the
    # threshold may therefore be hidden here -- confirm this is intended.
    min_acts = seq_len * code_act_ratio
    kept = [act_str for act_str, num_acts in example_acts if num_acts > min_acts]
    st.markdown("#### Examples for Code")
    st.markdown(webapp_utils.escape_markdown("".join(kept)), unsafe_allow_html=True)
# Truthy when the selected model's codebooks have heads; adds a "Head" column
# to the code table below.
is_attn = st.session_state["is_attn"]

st.markdown("## Topic Code")

# Paragraph 1: what topic codes are and how to search for them.
_topic_code_search_help = (
    "Topic codes are codes that activate many different times on passages that describe a particular"
    " topic or concept (e.g. “fire”). This interface provides a way to search for such codes by looking"
    " at different examples in the dataset (ExampleID) and finding codes that activate on some fraction"
    " of the tokens in that example (Recall Threshold). Decrease the Recall Threshold to view more possible"
    " topic codes and increase it to see fewer. Click “Find Next Example” to find the next example with at"
    " least one code firing on that example above the Recall Threshold."
)
# Paragraph 2: which model the page uses.
_topic_code_model_help = (
    "Topic codes are displayed for the codebook model selected on the Code Browser page. To view topic codes"
    " for a different model, go to the Code Browser page and select a different model."
)
topic_code_description = _topic_code_search_help + "\n\n" + _topic_code_model_help
st.write(topic_code_description)
# Four equal-width columns for the controls row.
ex_col, r_col, trunc_col, sort_col = st.columns(4)

# Keyed so find_next_example can move it via st.session_state["example_id"].
example_id = ex_col.number_input(
    "Example ID",
    min_value=0,
    max_value=total_examples - 1,
    value=0,
    key="example_id",
)

recall_threshold = r_col.slider(
    "Recall Threshold",
    min_value=0.0,
    max_value=1.0,
    value=0.2,
    key="recall",
    help="Recall Threshold is the minimum fraction of tokens in the example that the code must activate on.",
)

example_truncation = trunc_col.number_input(
    "Max Output Chars",
    min_value=0,
    max_value=102400,
    value=1024,
    key="max_chars",
)

sort_by_options = ["Precision", "Recall", "Num Acts"]
sort_by_name = sort_col.radio(
    "Sort By",
    sort_by_options,
    index=1,  # default: Recall
    horizontal=True,
    help="Sorts the codes by the selected metric.",
)
# Position of the chosen metric in sort_by_options; the code table uses it to
# pick the matching field of each (code, precision, recall, acts) tuple.
sort_by = sort_by_options.index(sort_by_name)

# The current example_id is bound now and handed to the callback on click.
button = st.button(
    "Find Next Example",
    key="find_next_example",
    on_click=find_next_example,
    args=(example_id,),
    help="Find an example which has codes above the recall threshold.",
)
st.markdown("### Example Text")
example_text = tokens_text[example_id]
# Append an ellipsis only when the text is actually cut off.
trunc_suffix = "..." if example_truncation < len(example_text) else ""
st.write(example_text[:example_truncation] + trunc_suffix)

# Header row: Search | Layer | [Head] | Code | Precision | Recall | Num Acts.
# "Head" exists only when is_attn, so the trailing columns are addressed from
# the end to work for both the 6- and 7-column layouts.
cols = st.columns(7 if is_attn else 6)
cols[0].markdown("Search", help="Button to see token activations for the code.")
cols[1].write("Layer")
if is_attn:
    cols[2].write("Head")
for header_col, label in zip(cols[-4:-1], ("Code", "Precision", "Recall")):
    header_col.write(label)
cols[-1].markdown(
    "Num Acts",
    help="Number of tokens that the code activates on in the acts dataset.",
)
# One row per (codebook name, (code, precision, recall, acts)) across all codebooks.
all_codes = [
    (cb_name, code_pr_info)
    for cb_name, code_pr_infos in get_example_topic_codes(example_id)
    for code_pr_info in code_pr_infos
]
# sort_by indexes (Precision, Recall, Num Acts); the +1 skips the code id at
# position 0 of each tuple. Stable sort, highest first.
all_codes.sort(key=lambda row: row[1][1 + sort_by], reverse=True)

for cb_name, (code, p, r, acts) in all_codes:
    row_cols = st.columns(7 if is_attn else 6)
    code_button = row_cols[0].button("🔍", key=f"ex-code-{code}-{cb_name}")
    layer, head = code_search_utils.get_layer_head_from_adv_name(cb_name)
    row_cols[1].write(str(layer))
    if is_attn:
        row_cols[2].write(str(head))
    row_cols[-4].write(code)
    row_cols[-3].write(f"{p*100:.2f}%")
    row_cols[-2].write(f"{r*100:.2f}%")
    row_cols[-1].write(str(acts))
    cols = row_cols
    if code_button:
        # Render the code's dataset examples right below its row.
        show_examples_for_topic_code(code, layer, head, code_act_ratio=recall_threshold)

if not all_codes:
    st.markdown(
        f"<div style='text-align:center'>No codes found at recall threshold = {recall_threshold}."
        " Consider decreasing the recall threshold.</div>",
        unsafe_allow_html=True,
    )