File size: 4,557 Bytes
6923ebd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6289538
6923ebd
 
 
 
 
6289538
6923ebd
 
 
6289538
6923ebd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6289538
6923ebd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6289538
6923ebd
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import random
import time
import logging
from json import JSONDecodeError

import streamlit as st

from app_utils.backend_utils import load_statements, query
from app_utils.frontend_utils import (
    set_state_if_absent,
    reset_results,
    entailment_html_messages,
    create_df_for_relevant_snippets,
    create_ternary_plot,
    build_sidebar,
)
from app_utils.config import RETRIEVER_TOP_K


def _pick_random_statement(statements, current):
    """Return a random entry of ``statements`` that differs from ``current``.

    A plain "pick again while equal" loop never terminates when every entry
    equals ``current`` (e.g. a one-item list). Sampling uniformly from the
    pre-filtered candidates gives the same distribution without that risk.

    Args:
        statements: Sequence of candidate statements.
        current: The statement currently held in session state.

    Returns:
        A statement different from ``current`` when one exists, otherwise
        ``current`` itself.

    Raises:
        Nothing for an empty ``statements``; ``current`` is returned instead.
    """
    candidates = [s for s in statements if s != current]
    if not candidates:
        return current
    return random.choice(candidates)


def _request_rerun():
    """Ask Streamlit to stop this script run and start a new one.

    The public ``st.rerun`` (or the older ``st.experimental_rerun``) is tried
    first. The private script-runner exception classes are used only when
    neither public API exists (older releases; adapted for Streamlit>=1.12.0
    via the ``st.runtime`` path).
    """
    public_rerun = getattr(st, "rerun", None) or getattr(st, "experimental_rerun", None)
    if public_rerun is not None:
        # Expected to raise Streamlit's internal rerun exception; returning
        # afterwards is only a guard in case it does not.
        public_rerun()
        return
    if hasattr(st, "scriptrunner"):
        raise st.scriptrunner.script_runner.RerunException(
            st.scriptrunner.script_requests.RerunData(widget_states=None)
        )
    raise st.runtime.scriptrunner.script_runner.RerunException(
        st.runtime.scriptrunner.script_requests.RerunData(widget_states=None)
    )


def _display_results(results):
    """Render the entailment verdict, aggregate plot and relevant snippets.

    Args:
        results: Mapping returned by ``query``; this function reads its
            ``"documents"`` and ``"aggregate_entailment_info"`` keys.
    """
    docs = results["documents"]
    agg_entailment_info = results["aggregate_entailment_info"]

    # Show the message associated with the highest-scoring entailment label.
    max_key = max(agg_entailment_info, key=agg_entailment_info.get)
    message = entailment_html_messages[max_key]
    st.markdown(f"<br/><h4>{message}</h4>", unsafe_allow_html=True)

    st.markdown("###### Aggregate entailment information:")
    col1, col2 = st.columns([2, 1])
    fig = create_ternary_plot(agg_entailment_info)
    with col1:
        st.plotly_chart(fig, use_container_width=True)
    with col2:
        st.write(agg_entailment_info)

    st.markdown("###### Most Relevant snippets:")
    df, urls = create_df_for_relevant_snippets(docs)
    st.dataframe(df)
    # One markdown link per source document, each followed by a space.
    links = "".join(f"[{doc}]({url}) " for doc, url in urls.items())
    st.markdown("Data: " + links)


def main():
    """Streamlit entry point: read a statement, run the model, show results.

    Session-state keys used: ``statement``, ``answer``, ``results``,
    ``raw_json`` and ``random_statement_requested``.
    """
    statements = load_statements()
    build_sidebar()

    # Persistent state
    set_state_if_absent("statement", "Referral bonus can only be given if your friend joins Newton School on your behalf")
    set_state_if_absent("answer", "")
    set_state_if_absent("results", None)
    set_state_if_absent("raw_json", None)
    set_state_if_absent("random_statement_requested", False)

    st.write("Referral Mis-Sell")
    st.write()
    st.markdown(
        """
    ##### Enter statement
    """
    )
    # Search bar
    statement = st.text_input(
        "", value=st.session_state.statement, max_chars=100, on_change=reset_results
    )
    col1, col2 = st.columns(2)
    # Full-width buttons. The same style element is added to both columns,
    # presumably so the two buttons stay vertically aligned.
    col1.markdown(
        "<style>.stButton button {width:100%;}</style>", unsafe_allow_html=True
    )
    col2.markdown(
        "<style>.stButton button {width:100%;}</style>", unsafe_allow_html=True
    )
    # Run button
    run_pressed = col1.button("Run")
    # Random statement button
    if col2.button("Random statement"):
        reset_results()
        # Avoid picking the same statement twice (the change is not visible on the UI)
        st.session_state.statement = _pick_random_statement(
            statements, st.session_state.statement
        )
        st.session_state.random_statement_requested = True
        # Re-run the script so the random statement becomes the textbox value;
        # necessary because the Random statement button is _below_ the textbox.
        _request_rerun()
        return
    st.session_state.random_statement_requested = False

    # Query when "Run" is pressed or the text was edited, but not on the
    # run triggered by the random-statement button.
    run_query = (
        run_pressed or statement != st.session_state.statement
    ) and not st.session_state.random_statement_requested

    # Get results for query
    if run_query and statement:
        # Monotonic clock: elapsed time is unaffected by wall-clock changes.
        time_start = time.perf_counter()
        reset_results()
        st.session_state.statement = statement
        with st.spinner("🧠 &nbsp;&nbsp; Running Model..."):
            try:
                st.session_state.results = query(statement, RETRIEVER_TOP_K)
            except JSONDecodeError:
                logging.exception("Could not decode the query response")
                st.error(
                    "👓 &nbsp;&nbsp; An error occurred reading the results. Is the document store working?"
                )
                return
            except Exception:
                # UI boundary: log the traceback and show a generic error.
                logging.exception("Query request failed")
                st.error("🐞 &nbsp;&nbsp; An error occurred during the request.")
                return
            else:
                elapsed = time.perf_counter() - time_start
                print(f"S: {statement}")
                print(time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime()))
                print(f"elapsed time: {elapsed}")

    # Display results
    if st.session_state.results:
        _display_results(st.session_state.results)


# Streamlit executes the script as ``__main__``, so this still runs under
# ``streamlit run`` while keeping plain imports of the module side-effect free.
if __name__ == "__main__":
    main()