"""Streamlit entry point for the ExplaiNER error-analysis app.

Builds the sidebar navigation menu from the available subpages and renders
whichever page the user selected.
"""
import pandas as pd
import streamlit as st
from streamlit_option_menu import option_menu

from load import load_context
from subpages import (
    DebugPage,
    FindDuplicatesPage,
    HomePage,
    LossesPage,
    LossySamplesPage,
    MetricsPage,
    MisclassifiedPage,
    Page,
    ProbingPage,
    RandomSamplesPage,
    RawDataPage,
)
from subpages.attention import AttentionPage
from subpages.hidden_states import HiddenStatesPage
from subpages.inspect import InspectPage
from utils import classmap

sts = st.sidebar

st.set_page_config(
    layout="wide",
    page_title="Error Analysis",
    page_icon="🏷️",
)


def _show_menu(pages: list[Page]) -> int:
    """Render the sidebar navigation menu and return the selected page's index.

    The selected page *name* is also stored in ``st.session_state.active_page``.

    Args:
        pages: Pages to offer in the menu; each contributes its ``name`` and ``icon``.

    Returns:
        Index into ``pages`` of the page chosen in the menu.
    """
    page_names = [p.name for p in pages]
    page_icons = [p.icon for p in pages]
    with st.sidebar:
        selected_menu_item = st.session_state.active_page = option_menu(
            menu_title="ExplaiNER",
            options=page_names,
            icons=page_icons,
            menu_icon="layout-wtf",
            default_index=0,
        )
    return page_names.index(selected_menu_item)


def _initialize_session_state(pages: list[Page]) -> None:
    """Seed per-page widget defaults once, then re-store session state.

    Widget defaults from every page are written only while ``active_page`` is
    not yet in session state, i.e. before ``_show_menu`` has run once.
    """
    if "active_page" not in st.session_state:
        for page in pages:
            st.session_state.update(**page.get_widget_defaults())
    # Re-assigning session state to itself on every rerun: presumably the
    # Streamlit workaround that keeps widget values alive while the page
    # owning the widget is not rendered — TODO confirm.
    st.session_state.update(st.session_state)


def _write_color_legend(context) -> None:
    """Write a label/colour legend table to the sidebar.

    Labels come from ``context.labels``. A label containing ``-`` is reduced to
    its second ``-``-separated field (e.g. ``B-PER`` -> ``PER``). The labels are
    then de-duplicated and sorted so the legend order is stable across runs.
    Each row shows ``classmap[label]`` on a background colour read from
    ``st.session_state["color_<label>"]`` (default ``#000000``).
    """
    labels = sorted(
        {lbl.split("-")[1] if "-" in lbl else lbl for lbl in context.labels}
    )
    colors = [st.session_state.get(f"color_{lbl}", "#000000") for lbl in labels]

    def style(_column):
        # One CSS declaration per row, aligned with the order of ``labels``.
        return [f"background-color: {rgb}; opacity: 1;" for rgb in colors]

    color_legend_df = pd.DataFrame(
        [classmap[lbl] for lbl in labels], columns=["label"], index=labels
    )
    st.sidebar.write(
        color_legend_df.style.apply(style, axis=0).set_properties(
            **{"color": "white", "text-align": "center"}
        )
    )


def main() -> None:
    """Build the page list, show the menu and render the selected page.

    ``HomePage`` is rendered without a context. Every other page requires
    ``model_name`` in session state. It receives the context returned by
    ``load_context(**st.session_state)``, and the colour legend is drawn first.
    """
    pages: list[Page] = [
        HomePage(),
        AttentionPage(),
        HiddenStatesPage(),
        ProbingPage(),
        MetricsPage(),
        LossySamplesPage(),
        LossesPage(),
        MisclassifiedPage(),
        RandomSamplesPage(),
        FindDuplicatesPage(),
        InspectPage(),
        RawDataPage(),
        DebugPage(),
    ]

    _initialize_session_state(pages)
    selected_page = pages[_show_menu(pages)]

    if isinstance(selected_page, HomePage):
        # Presumably the home page performs the setup the other pages rely on.
        selected_page.render()
        return

    if "model_name" not in st.session_state:
        # this can happen if someone loads another page directly (without going through home)
        st.error(
            "Setup not complete. Please click on 'Home / Setup' in the left menu bar"
        )
        return

    context = load_context(**st.session_state)
    _write_color_legend(context)
    selected_page.render(context)


if __name__ == "__main__":
    main()