"""Web app page for showing codes for different examples in the dataset.""" import streamlit as st from streamlit_extras.switch_page_button import switch_page import code_search_utils import webapp_utils webapp_utils.load_widget_state() if "cb_acts" not in st.session_state: switch_page("Code_Browser") total_examples = 2000 prec_threshold = 0.01 model_name = st.session_state["model_name_id"] seq_len = st.session_state["seq_len"] tokens_text = st.session_state["tokens_text"] tokens_str = st.session_state["tokens_str"] cb_acts = st.session_state["cb_acts"] act_count_ft_tkns = st.session_state["act_count_ft_tkns"] gcb = st.session_state["gcb"] def get_example_topic_codes(example_id): """Get topic codes for the given example id.""" token_pos_ids = [(example_id, i) for i in range(seq_len)] all_codes = [] for cb_name, cb in cb_acts.items(): base_cb_name = code_search_utils.convert_to_base_name(cb_name, gcb=gcb) codes, prec, rec, code_acts = code_search_utils.get_code_precision_and_recall( token_pos_ids, cb, act_count_ft_tkns[base_cb_name], ) prec_sat_idx = prec >= prec_threshold codes, prec, rec, code_acts = ( codes[prec_sat_idx], prec[prec_sat_idx], rec[prec_sat_idx], code_acts[prec_sat_idx], ) rec_sat_idx = rec >= recall_threshold codes, prec, rec, code_acts = ( codes[rec_sat_idx], prec[rec_sat_idx], rec[rec_sat_idx], code_acts[rec_sat_idx], ) codes_pr = list(zip(codes, prec, rec, code_acts)) all_codes.append((cb_name, codes_pr)) return all_codes def find_next_example(example_id): """Find the example after `example_id` that has topic codes.""" initial_example_id = example_id example_id += 1 while example_id != initial_example_id: all_codes = get_example_topic_codes(example_id) codes_found = sum([len(code_pr_infos) for _, code_pr_infos in all_codes]) if codes_found > 0: st.session_state["example_id"] = example_id return example_id = (example_id + 1) % total_examples st.error( f"No examples found at the specified recall threshold: {recall_threshold}.", icon="🚨", ) def redirect_to_main_with_code(code, layer, head): """Redirect to main page with the given code.""" st.session_state["ct_act_code"] = code st.session_state["ct_act_layer"] = layer if st.session_state["is_attn"]: st.session_state["ct_act_head"] = head switch_page("Code Browser") def show_examples_for_topic_code(code, layer, head, code_act_ratio=0.3): """Show examples that the code activates on.""" ex_acts, _ = webapp_utils.get_code_acts( model_name, tokens_str, code, layer, head, ctx_size=5, return_example_list=True, ) filt_ex_acts = [] for act_str, num_acts in ex_acts: if num_acts > seq_len * code_act_ratio: filt_ex_acts.append(act_str) st.markdown("#### Examples for Code") st.markdown( webapp_utils.escape_markdown("".join(filt_ex_acts)), unsafe_allow_html=True ) is_attn = st.session_state["is_attn"] st.markdown("## Topic Code") topic_code_description = ( "Topic codes are codes that activate many different times on passages that describe a particular" " topic or concept (e.g. “fire”). This interface provides a way to search for such codes by looking" " at different examples in the dataset (ExampleID) and finding codes that activate on some fraction" " of the tokens in that example (Recall Threshold). Decrease the Recall Threshold to view more possible" " topic codes and increase it to see fewer. Click “Find Next Example” to find the next example with at" " least one code firing on that example above the Recall Threshold.\n\n" "Topic codes are displayed for the codebook model selected on the Code Browser page. To view topic codes" " for a different model, go to the Code Browser page and select a different model." ) st.write(topic_code_description) ex_col, r_col, trunc_col, sort_col = st.columns([1, 1, 1, 1]) example_id = ex_col.number_input( "Example ID", 0, total_examples - 1, 0, key="example_id", ) recall_threshold = r_col.slider( "Recall Threshold", 0.0, 1.0, 0.2, key="recall", help="Recall Threshold is the minimum fraction of tokens in the example that the code must activate on.", ) example_truncation = trunc_col.number_input( "Max Output Chars", 0, 102400, 1024, key="max_chars" ) sort_by_options = ["Precision", "Recall", "Num Acts"] sort_by_name = sort_col.radio( "Sort By", sort_by_options, index=1, horizontal=True, help="Sorts the codes by the selected metric.", ) sort_by = sort_by_options.index(sort_by_name) button = st.button( "Find Next Example", key="find_next_example", on_click=find_next_example, args=(example_id,), help="Find an example which has codes above the recall threshold.", ) st.markdown("### Example Text") trunc_suffix = "..." if example_truncation < len(tokens_text[example_id]) else "" st.write(tokens_text[example_id][:example_truncation] + trunc_suffix) cols = st.columns(7 if is_attn else 6) cols[0].markdown("Search", help="Button to see token activations for the code.") cols[1].write("Layer") if is_attn: cols[2].write("Head") cols[-4].write("Code") cols[-3].write("Precision") cols[-2].write("Recall") cols[-1].markdown( "Num Acts", help="Number of tokens that the code activates on in the acts dataset.", ) all_codes = get_example_topic_codes(example_id) all_codes = [ (cb_name, code_pr_info) for cb_name, code_pr_infos in all_codes for code_pr_info in code_pr_infos ] all_codes = sorted(all_codes, key=lambda x: x[1][1 + sort_by], reverse=True) for cb_name, (code, p, r, acts) in all_codes: cols = st.columns(7 if is_attn else 6) code_button = cols[0].button( "🔍", key=f"ex-code-{code}-{cb_name}", ) layer, head = code_search_utils.get_layer_head_from_adv_name(cb_name) cols[1].write(str(layer)) if is_attn: cols[2].write(str(head)) cols[-4].write(code) cols[-3].write(f"{p*100:.2f}%") cols[-2].write(f"{r*100:.2f}%") cols[-1].write(str(acts)) if code_button: show_examples_for_topic_code( code, layer, head, code_act_ratio=recall_threshold, ) if len(all_codes) == 0: st.markdown( f"