File size: 5,271 Bytes
13e8963
 
 
f2d4743
13e8963
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e698b42
 
dfae691
 
 
13e8963
 
 
f2d4743
13e8963
 
c28665f
13e8963
f2d4743
c28665f
 
f2d4743
c28665f
13e8963
f2d4743
 
 
 
 
 
 
 
 
13e8963
 
 
 
 
 
 
 
 
f2d4743
13e8963
 
 
 
 
c28665f
 
 
 
 
 
 
 
 
 
13e8963
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ca2e2c2
13e8963
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5e531ec
f2d4743
5e531ec
13e8963
 
 
 
 
f2d4743
 
c28665f
f2d4743
 
13e8963
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import gradio as gr  # type: ignore
import plotly.express as px  # type: ignore

from backend.data import load_cot_data
from backend.envs import API, REPO_ID, TOKEN

# Logo images shown under the dashboard title, served from the cot-eval repository.
logo1_url = "https://raw.githubusercontent.com/logikon-ai/cot-eval/main/assets/AI2_Logo_Square.png"
logo2_url = "https://raw.githubusercontent.com/logikon-ai/cot-eval/main/assets/logo_logikon_notext_withborder.png"

# Centered row of two linked logos (AI2, Logikon). Written as adjacent string
# literals purely for readability; the concatenated markup is one HTML string.
LOGOS = (
    '<div style="display: flex; justify-content: center;">'
    f'<a href="https://allenai.org/"><img src="{logo1_url}" alt="AI2" '
    'style="width: 30vw; min-width: 20px; max-width: 60px;"></a> '
    f'<a href="https://logikon.ai"><img src="{logo2_url}" alt="Logikon AI" '
    'style="width: 30vw; min-width: 20px; max-width: 60px; margin-left: 10px;"></a>'
    '</div>'
)

# Page header: centered title followed by the logo row.
TITLE = '<h1 align="center" id="space-title"> Open CoT Dashboard</h1> ' + LOGOS

# Markdown blurb rendered directly below the title.
INTRODUCTION_TEXT = """
Baseline accuracies and marginal accuracy gains for specific models and CoT regimes from the [Open CoT Leaderboard](https://huggingface.co/spaces/logikon/open_cot_leaderboard).
"""

def restart_space():
    """Request a restart of this Space via the Hub API client.

    Uses the module-level ``API`` client together with the configured
    ``REPO_ID`` and ``TOKEN``. Returns ``None``.
    """
    API.restart_space(token=TOKEN, repo_id=REPO_ID)

try:
    df_cot_err, df_cot_regimes = load_cot_data()
except Exception:
    import time
    import traceback

    # Print the full traceback; the bare message alone hides where loading failed.
    traceback.print_exc()
    # Sleep for 10 seconds before requesting a restart of the Space.
    time.sleep(10)
    restart_space()
    # Without the data frames the UI below cannot be built (it would otherwise
    # fail later with a misleading NameError on df_cot_err), so surface the
    # original loading error instead of continuing.
    raise


def plot_evals_init(model_id, regex_model_filter, plotly_mode, request: gr.Request):
    """Render the initial plot, honouring an optional ``?model=`` URL parameter.

    When the request carries a ``model`` query parameter that names a model in
    ``df_cot_err``, that model replaces ``model_id``; otherwise the dropdown's
    value is used unchanged. Returns the result of ``plot_evals``.
    """
    query = request.query_params if request else {}
    if "model" in query:
        requested = query["model"]
        known_models = df_cot_err.model.to_list()
        if requested in known_models:
            model_id = requested
    return plot_evals(model_id, regex_model_filter, plotly_mode)


def plot_evals(model_id, regex_model_filter, plotly_mode):
    """Build the faceted scatter plot of base accuracy vs. marginal accuracy gain.

    Args:
        model_id: Model to highlight (first colour, "Orange"). It is always
            plotted, even when it does not match ``regex_model_filter``.
        regex_model_filter: Regular expression matched against the ``model``
            column to choose which models are plotted. An empty/None pattern
            shows all models; a pattern that cannot be applied shows a warning
            and falls back to showing all models.
        plotly_mode: ``"dark"`` selects the ``plotly_dark`` template; any other
            value uses the default ``plotly`` template.

    Returns:
        Tuple ``(fig, model_id)``: the Plotly figure and the highlighted model
        id, echoed back so the model dropdown stays in sync.
    """
    df = df_cot_err.copy()
    is_selected = df.model.eq(model_id)
    df["selected"] = is_selected.map({True: "selected", False: "-"})

    pattern = regex_model_filter or ".*"
    try:
        # na=False: rows without a model name never match the filter, keeping
        # the mask strictly boolean.
        df_filter = df.model.str.contains(pattern, na=False)
    except Exception as err:  # best-effort: bad user input must not break the UI
        gr.Warning("Failed to apply regex filter", duration=4)
        print(f"Failed to apply regex filter: {err}")
        df_filter = df.model.notna()  # fall back to showing every model
    df = df[df_filter | is_selected]

    template = "plotly_dark" if plotly_mode == "dark" else "plotly"
    # NOTE: sorting df beforehand has no effect with px.scatter; the colour
    # order is controlled via category_orders below.
    fig = px.scatter(
        df,
        x="base accuracy",
        y="marginal acc. gain",
        color="selected",
        symbol="model",
        facet_col="task",
        facet_col_wrap=3,
        category_orders={"selected": ["selected", "-"]},
        color_discrete_sequence=["Orange", "Gray"],
        template=template,
        error_y="acc_gain-err",
        hover_data=["model", "cot accuracy"],
        width=1200,
        height=700,
    )
    fig.update_layout(title={"automargin": True})
    return fig, model_id


def styled_model_table_init(model_id, request: gr.Request):
    """Render the initial model table, honouring an optional ``?model=`` parameter.

    A ``model`` query parameter naming a model present in ``df_cot_regimes``
    overrides ``model_id``; otherwise the dropdown value is kept. Returns the
    result of ``styled_model_table``.
    """
    has_override = bool(request) and "model" in request.query_params
    if has_override:
        candidate = request.query_params["model"]
        available = df_cot_regimes.model.to_list()
        model_id = candidate if candidate in available else model_id
    return styled_model_table(model_id)


def styled_model_table(model_id):
    """Return a styled table of the CoT-regime results for one model.

    Rows of ``df_cot_regimes`` whose ``model`` equals ``model_id`` are reduced
    to the displayed columns. ``temperature`` is renamed to ``temp`` and the
    chain name ``ReflectBeforeRun`` is shortened to ``Reflect``. Rows are then
    sorted by task and chain.

    Args:
        model_id: Model whose CoT regimes are listed. An id that matches no
            rows yields an empty table.

    Returns:
        A pandas ``Styler`` with the index hidden, one-decimal formatting,
        colour gradients on the accuracy columns and a bold ``task`` column.
    """

    def make_pretty(styler):
        # Configure the Styler in place and hand it back for ``.pipe``.
        styler.hide(axis="index")
        styler.format(precision=1)
        # Fixed 20..100 scale so colours mean the same thing for every model.
        styler.background_gradient(
            axis=None,
            subset=["acc_base", "acc_cot"],
            vmin=20, vmax=100, cmap="YlGnBu",
        )
        # Gains use a diverging colormap centred on 0 (range -20..20).
        styler.background_gradient(
            axis=None,
            subset=["acc_gain"],
            vmin=-20, vmax=20, cmap="coolwarm",
        )
        styler.set_table_styles(
            {"task": [{"selector": "", "props": [("font-weight", "bold")]}]},
            overwrite=False,
        )
        return styler

    columns = [
        "task", "cot_chain", "best_of", "temperature", "top_k", "top_p",
        "acc_base", "acc_cot", "acc_gain",
    ]
    df_cot_model = (
        df_cot_regimes.loc[df_cot_regimes.model.eq(model_id), columns]
        .rename(columns={"temperature": "temp"})
        .replace({"cot_chain": "ReflectBeforeRun"}, "Reflect")
        .sort_values(["task", "cot_chain"])
        .reset_index(drop=True)
    )
    return df_cot_model.style.pipe(make_pretty)


with gr.Blocks() as demo:

    gr.HTML(TITLE)
    gr.Markdown(INTRODUCTION_TEXT)

    # Controls row: model picker, plot filter regex, plot theme and refresh button.
    with gr.Row():
        selected_model = gr.Dropdown(
            list(df_cot_err.model.unique()),
            value="allenai/tulu-2-70b",
            label="Model",
            info="with performance details below",
            scale=2,
        )
        regex_model_filter = gr.Textbox(
            ".*", label="Regex", info="to filter models shown in plots", scale=2
        )
        plotly_mode = gr.Radio(
            ["dark", "light"], value="light", label="Theme", info="of plots", scale=1
        )
        submit = gr.Button("Update", scale=1)

    table = gr.DataFrame()
    plot = gr.Plot(label="evals")

    plot_inputs = [selected_model, regex_model_filter, plotly_mode]
    plot_outputs = [plot, selected_model]

    # Clicking "Update" re-renders both the plot and the per-model table.
    submit.click(plot_evals, plot_inputs, plot_outputs)
    submit.click(styled_model_table, selected_model, table)

    # On page load the *_init variants additionally honour a ?model= parameter.
    demo.load(plot_evals_init, plot_inputs, plot_outputs)
    demo.load(styled_model_table_init, selected_model, table)

demo.launch()