File size: 4,557 Bytes
6923ebd 6289538 6923ebd 6289538 6923ebd 6289538 6923ebd 6289538 6923ebd 6289538 6923ebd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
import random
import time
import logging
from json import JSONDecodeError
import streamlit as st
from app_utils.backend_utils import load_statements, query
from app_utils.frontend_utils import (
set_state_if_absent,
reset_results,
entailment_html_messages,
create_df_for_relevant_snippets,
create_ternary_plot,
build_sidebar,
)
from app_utils.config import RETRIEVER_TOP_K
def _pick_random_statement(statements, current):
    """Return a random entry of *statements* that differs from *current*.

    Duplicates in *statements* keep their weight, so the choice is uniform over
    all entries that differ from *current*. When no such entry exists (empty
    list, or every entry equals *current*) *current* is returned, so this never
    loops forever.
    """
    candidates = [candidate for candidate in statements if candidate != current]
    if not candidates:
        return current
    return random.choice(candidates)


def _rerun():
    """Ask Streamlit to restart the script from the top.

    Needed because the "Random statement" button is rendered *below* the text
    box: the newly chosen statement only appears in the box on the next run.
    """
    # Prefer the public APIs (``st.rerun`` on newer releases, the older
    # ``st.experimental_rerun`` otherwise); they are documented to stop the
    # current run and start a new one.
    for api_name in ("rerun", "experimental_rerun"):
        rerun_fn = getattr(st, api_name, None)
        if callable(rerun_fn):
            rerun_fn()
            return
    # Fallback for releases without a public rerun API: raise Streamlit's
    # internal rerun exception directly (module path differs across versions;
    # the second form was adapted for Streamlit>=1.12.0).
    if hasattr(st, "scriptrunner"):
        raise st.scriptrunner.script_runner.RerunException(
            st.scriptrunner.script_requests.RerunData(widget_states=None)
        )
    raise st.runtime.scriptrunner.script_runner.RerunException(
        st.runtime.scriptrunner.script_requests.RerunData(widget_states=None)
    )


def _run_query(statement):
    """Query the backend for *statement* and store the response in session state.

    Returns True on success. On failure the error is logged, a message is shown
    in the UI, and False is returned.
    """
    time_start = time.perf_counter()
    reset_results()
    st.session_state.statement = statement
    with st.spinner("🧠 Running Model..."):
        # Keep only the backend call inside the try so unrelated failures
        # are not reported as request errors.
        try:
            st.session_state.results = query(statement, RETRIEVER_TOP_K)
        except JSONDecodeError:
            logging.exception("Could not decode the query response")
            st.error(
                "👓 An error occurred reading the results. Is the document store working?"
            )
            return False
        except Exception:  # UI boundary: report any failure instead of crashing
            logging.exception("Query failed for statement %r", statement)
            st.error("🐞 An error occurred during the request.")
            return False
    elapsed = time.perf_counter() - time_start
    print(f"S: {statement}")
    print(time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime()))
    print(f"elapsed time: {elapsed}")
    return True


def _render_results(results):
    """Render the verdict, the entailment plot/table and the supporting snippets.

    Reads the ``"documents"`` and ``"aggregate_entailment_info"`` keys of
    *results* (the value previously returned by ``query``).
    """
    docs = results["documents"]
    agg_entailment_info = results["aggregate_entailment_info"]

    # Headline message for the entailment key with the largest aggregate value.
    max_key = max(agg_entailment_info, key=agg_entailment_info.get)
    message = entailment_html_messages[max_key]
    st.markdown(f"<br/><h4>{message}</h4>", unsafe_allow_html=True)

    st.markdown("###### Aggregate entailment information:")
    plot_col, table_col = st.columns([2, 1])
    fig = create_ternary_plot(agg_entailment_info)
    with plot_col:
        st.plotly_chart(fig, use_container_width=True)
    with table_col:
        st.write(agg_entailment_info)

    st.markdown("###### Most Relevant snippets:")
    df, urls = create_df_for_relevant_snippets(docs)
    st.dataframe(df)

    # One markdown link per source document, each followed by a space.
    links = "".join(f"[{doc}]({url}) " for doc, url in urls.items())
    st.markdown("Data: " + links)


def main():
    """Render the Referral Mis-Sell page and run the entailment check on demand.

    A query runs when the user presses "Run" or edits the statement, but not
    on the rerun triggered by the "Random statement" button.
    """
    statements = load_statements()
    build_sidebar()

    # Persistent state
    set_state_if_absent("statement", "Referral bonus can only be given if your friend joins Newton School on your behalf")
    set_state_if_absent("answer", "")
    set_state_if_absent("results", None)
    set_state_if_absent("raw_json", None)
    set_state_if_absent("random_statement_requested", False)

    st.write("Referral Mis-Sell")
    st.write()
    st.markdown(
        """
##### Enter statement
"""
    )

    # Search bar
    statement = st.text_input(
        "", value=st.session_state.statement, max_chars=100, on_change=reset_results
    )

    col1, col2 = st.columns(2)
    # The same global CSS (full-width buttons) is emitted in both columns.
    col1.markdown(
        "<style>.stButton button {width:100%;}</style>", unsafe_allow_html=True
    )
    col2.markdown(
        "<style>.stButton button {width:100%;}</style>", unsafe_allow_html=True
    )

    # Run button
    run_pressed = col1.button("Run")

    # Random statement button
    if col2.button("Random statement"):
        reset_results()
        # Pick a different statement, otherwise the change is not visible on the UI.
        st.session_state.statement = _pick_random_statement(
            statements, st.session_state.statement
        )
        st.session_state.random_statement_requested = True
        _rerun()
        # Only reached if no rerun could be triggered; stop this run like the
        # rerun would have.
        return
    st.session_state.random_statement_requested = False

    run_query = (
        run_pressed or statement != st.session_state.statement
    ) and not st.session_state.random_statement_requested

    # Get results for query
    if run_query and statement:
        if not _run_query(statement):
            return

    # Display results
    if st.session_state.results:
        _render_results(st.session_state.results)


main()
|