|
import random |
|
import time |
|
import logging |
|
from json import JSONDecodeError |
|
|
|
import streamlit as st |
|
|
|
from app_utils.backend_utils import load_statements, query |
|
from app_utils.frontend_utils import ( |
|
set_state_if_absent, |
|
reset_results, |
|
entailment_html_messages, |
|
create_df_for_relevant_snippets, |
|
create_ternary_plot, |
|
build_sidebar, |
|
) |
|
from app_utils.config import RETRIEVER_TOP_K |
|
|
|
|
|
def _pick_different_statement(statements, current):
    """Return a random entry of *statements* that is not equal to *current*.

    Drawing from the pre-filtered candidates gives the same distribution as
    re-drawing until the pick differs, but cannot spin forever: when there is
    no different entry (empty collection, or every entry equals *current*),
    *current* is returned unchanged.
    """
    candidates = [candidate for candidate in statements if candidate != current]
    if not candidates:
        return current
    return random.choice(candidates)


def _rerun_script():
    """Ask Streamlit to stop this run and execute the script again.

    The public rerun function is used when the installed Streamlit exposes
    one; otherwise the version-dependent internal exception paths are raised
    directly.
    """
    public_rerun = getattr(st, "rerun", None) or getattr(
        st, "experimental_rerun", None
    )
    if public_rerun is not None:
        public_rerun()
        # Only reached if the call above did not interrupt the script.
        return

    # Fallback: the internal modules live in different places depending on the
    # Streamlit version, hence the attribute probe.
    if hasattr(st, "scriptrunner"):
        raise st.scriptrunner.script_runner.RerunException(
            st.scriptrunner.script_requests.RerunData(widget_states=None)
        )
    raise st.runtime.scriptrunner.script_runner.RerunException(
        st.runtime.scriptrunner.script_requests.RerunData(widget_states=None)
    )


def _render_results(results):
    """Render the verdict, the aggregate entailment plot and the snippets.

    *results* is indexed with the keys ``"documents"`` and
    ``"aggregate_entailment_info"``; the latter is used as a mapping whose
    key with the largest value selects the headline message.
    """
    docs = results["documents"]
    agg_entailment_info = results["aggregate_entailment_info"]

    # Headline: the message registered for the highest-valued entailment key.
    max_key = max(agg_entailment_info, key=agg_entailment_info.get)
    message = entailment_html_messages[max_key]
    st.markdown(f"<br/><h4>{message}</h4>", unsafe_allow_html=True)

    st.markdown("###### Aggregate entailment information:")
    plot_col, values_col = st.columns([2, 1])
    fig = create_ternary_plot(agg_entailment_info)
    with plot_col:
        st.plotly_chart(fig, use_container_width=True)
    with values_col:
        st.write(agg_entailment_info)

    st.markdown("###### Most Relevant snippets:")
    df, urls = create_df_for_relevant_snippets(docs)
    st.dataframe(df)
    # One markdown link per (name, url) pair, each followed by a space.
    links = "".join(f"[{doc}]({url}) " for doc, url in urls.items())
    st.markdown("Data: " + links)


def main():
    """Build the "Referral Mis-Sell" page and handle one script run.

    Loads the statements, draws the sidebar and the statement input with its
    "Run" and "Random statement" buttons, queries the backend when a run is
    requested, and renders whatever results are stored in the session state.
    Returns ``None``; backend failures are reported on the page instead of
    being raised.
    """
    statements = load_statements()
    build_sidebar()

    # Session-state defaults (presumably only applied when the key is not
    # set yet, as the helper's name suggests - confirm in frontend_utils).
    set_state_if_absent(
        "statement",
        "Referral bonus can only be given if your friend joins Newton School on your behalf",
    )
    set_state_if_absent("answer", "")
    set_state_if_absent("results", None)
    set_state_if_absent("raw_json", None)
    set_state_if_absent("random_statement_requested", False)

    st.write("Referral Mis-Sell")
    st.write()
    st.markdown(
        """
    ##### Enter statement
    """
    )

    # Statement input, pre-filled from the session state.
    statement = st.text_input(
        "", value=st.session_state.statement, max_chars=100, on_change=reset_results
    )
    col1, col2 = st.columns(2)
    col1.markdown(
        "<style>.stButton button {width:100%;}</style>", unsafe_allow_html=True
    )
    col2.markdown(
        "<style>.stButton button {width:100%;}</style>", unsafe_allow_html=True
    )

    run_pressed = col1.button("Run")

    if col2.button("Random statement"):
        reset_results()
        statement = _pick_different_statement(
            statements, st.session_state.statement
        )
        st.session_state.statement = statement
        st.session_state.random_statement_requested = True
        # The text input was created before this button was evaluated, so the
        # script has to run again for the new statement to reach the input.
        _rerun_script()
    else:
        st.session_state.random_statement_requested = False

    # Query when "Run" was pressed or the typed text differs from the stored
    # statement, but never on the run in which a random one was requested.
    run_query = (
        run_pressed or statement != st.session_state.statement
    ) and not st.session_state.random_statement_requested

    if run_query and statement:
        time_start = time.time()
        reset_results()
        st.session_state.statement = statement
        # NOTE(review): the leading characters of the spinner and error texts
        # below look like a mis-encoded emoji - confirm the file's encoding.
        with st.spinner("π§ Running Model..."):
            try:
                st.session_state.results = query(statement, RETRIEVER_TOP_K)
                print(f"S: {statement}")
                time_end = time.time()
                print(time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime()))
                print(f"elapsed time: {time_end - time_start}")
            except JSONDecodeError:
                st.error(
                    "π An error occurred reading the results. Is the document store working?"
                )
                return
            except Exception as e:
                # Page-level boundary: log the traceback, show a generic error.
                logging.exception(e)
                st.error("π An error occurred during the request.")
                return

    if st.session_state.results:
        _render_results(st.session_state.results)
|
|
|
|
|
# Script entry point: the page is rendered by calling main() at module level.
# NOTE(review): the call is unconditional, so importing this module also runs it.
main()
|
|