notSoNLPnerd commited on
Commit
c24940a
β€’
1 Parent(s): d26c2ca

made changes

Browse files
Files changed (3) hide show
  1. .streamlit/config.toml +10 -0
  2. app.py +50 -28
  3. backend_utils.py +13 -17
.streamlit/config.toml CHANGED
@@ -1,3 +1,13 @@
1
  [theme]
2
  base = "light"
3
  font="monospace"
 
 
 
 
 
 
 
 
 
 
 
1
  [theme]
2
  base = "light"
3
  font="monospace"
4
+ [global]
5
+
6
+ # By default, Streamlit checks if the Python watchdog module is available and, if not, prints a warning asking for you to install it. The watchdog module is not required, but highly recommended. It improves Streamlit's ability to detect changes to files in your filesystem.
7
+ # If you'd like to turn off this warning, set this to True.
8
+ # Default: false
9
+ disableWatchdogWarning = true
10
+
11
+ # If True, will show a warning when you run a Streamlit-enabled script via "python my_script.py".
12
+ # Default: true
13
+ showWarningOnDirectExecution = false
app.py CHANGED
@@ -1,59 +1,81 @@
1
  import streamlit as st
2
- from backend_utils import app_init, set_q1, set_q2, set_q3, set_q4, set_q5
 
3
 
4
- st.markdown("<center> <h1> Haystack Demo </h1> </center>", unsafe_allow_html=True)
 
 
5
 
6
- if st.session_state.get('pipelines_loaded', False):
7
- with st.spinner('Loading pipelines...'):
8
- p1, p2, p3 = app_init()
9
- st.success('Pipelines are loaded', icon="βœ…")
10
- st.session_state['pipelines_loaded'] = True
 
 
 
 
11
 
12
  placeholder = st.empty()
13
  with placeholder:
14
  search_bar, button = st.columns([3, 1])
15
  with search_bar:
16
- username = st.text_area(f"", max_chars=200, key='query')
17
 
18
  with button:
19
- st.write("")
20
- st.write("")
21
  run_pressed = st.button("Run")
22
 
23
- st.radio("Type", ("Retrieval Augmented", "Retrieval Augmented with Web Search"), key="query_type")
24
-
25
- # st.sidebar.selectbox(
26
- # "Example Questions:",
27
- # QUERIES,
28
- # key='q_drop_down', on_change=set_question)
29
 
 
 
30
  c1, c2, c3, c4, c5 = st.columns(5)
31
  with c1:
32
- st.button('Example Q1', on_click=set_q1)
33
  with c2:
34
- st.button('Example Q2', on_click=set_q2)
35
  with c3:
36
- st.button('Example Q3', on_click=set_q3)
37
  with c4:
38
- st.button('Example Q4', on_click=set_q4)
39
  with c5:
40
- st.button('Example Q5', on_click=set_q5)
 
 
 
 
 
 
 
 
41
 
42
- st.markdown("<h4> Answer with PLAIN GPT </h4>", unsafe_allow_html=True)
43
  placeholder_plain_gpt = st.empty()
44
- st.text("")
45
- st.text("")
46
- st.markdown(f"<h4> Answer with {st.session_state['query_type'].upper()} </h4>", unsafe_allow_html=True)
47
  placeholder_retrieval_augmented = st.empty()
48
 
49
  if st.session_state.get('query') and run_pressed:
50
  input = st.session_state['query']
51
- p1, p2, p3 = app_init()
52
- answers = p1.run(input)
 
 
 
53
  placeholder_plain_gpt.markdown(answers['results'][0])
54
 
55
  if st.session_state.get("query_type", "Retrieval Augmented") == "Retrieval Augmented":
56
- answers_2 = p2.run(input)
 
 
 
 
 
 
57
  else:
 
58
  answers_2 = p3.run(input)
59
  placeholder_retrieval_augmented.markdown(answers_2['results'][0])
 
1
  import streamlit as st
2
+ from backend_utils import (get_plain_pipeline, get_retrieval_augmented_pipeline,
3
+ get_web_retrieval_augmented_pipeline, set_q1, set_q2, set_q3, set_q4, set_q5, QUERIES)
4
 
5
+ st.set_page_config(
6
+ page_title="Retrieval Augmentation with Haystack",
7
+ )
8
 
9
+ st.markdown("<center> <h2> Reduce Hallucinations with Retrieval Augmentation </h2> </center>", unsafe_allow_html=True)
10
+
11
+ st.markdown("Ask a question about the collapse of the Silicon Valley Bank (SVB).", unsafe_allow_html=True)
12
+
13
+ # if not st.session_state.get('pipelines_loaded', False):
14
+ # with st.spinner('Loading pipelines... \n This may take a few mins and might also fail if OpenAI API server is down.'):
15
+ # p1, p2, p3 = app_init()
16
+ # st.success('Pipelines are loaded', icon="βœ…")
17
+ # st.session_state['pipelines_loaded'] = True
18
 
19
  placeholder = st.empty()
20
  with placeholder:
21
  search_bar, button = st.columns([3, 1])
22
  with search_bar:
23
+ username = st.text_area(f" ", max_chars=200, key='query')
24
 
25
  with button:
26
+ st.write(" ")
27
+ st.write(" ")
28
  run_pressed = st.button("Run")
29
 
30
+ st.markdown("<center> <h5> Example questions </h5> </center>", unsafe_allow_html=True)
 
 
 
 
 
31
 
32
+ st.write(" ")
33
+ st.write(" ")
34
  c1, c2, c3, c4, c5 = st.columns(5)
35
  with c1:
36
+ st.button(QUERIES[0], on_click=set_q1)
37
  with c2:
38
+ st.button(QUERIES[1], on_click=set_q2)
39
  with c3:
40
+ st.button(QUERIES[2], on_click=set_q3)
41
  with c4:
42
+ st.button(QUERIES[3], on_click=set_q4)
43
  with c5:
44
+ st.button(QUERIES[4], on_click=set_q5)
45
+
46
+ st.write(" ")
47
+ st.radio("Answer Type:", ("Retrieval Augmented (Static news dataset)", "Retrieval Augmented with Web Search"), key="query_type")
48
+
49
+ # st.sidebar.selectbox(
50
+ # "Example Questions:",
51
+ # QUERIES,
52
+ # key='q_drop_down', on_change=set_question)
53
 
54
+ st.markdown("<h5> Answer with GPT's Internal Knowledge </h5>", unsafe_allow_html=True)
55
  placeholder_plain_gpt = st.empty()
56
+ st.text(" ")
57
+ st.text(" ")
58
+ st.markdown(f"<h5> Answer with {st.session_state['query_type']} </h5>", unsafe_allow_html=True)
59
  placeholder_retrieval_augmented = st.empty()
60
 
61
  if st.session_state.get('query') and run_pressed:
62
  input = st.session_state['query']
63
+ with st.spinner('Loading pipelines... \n This may take a few mins and might also fail if OpenAI API server is down.'):
64
+ p1 = get_plain_pipeline()
65
+ with st.spinner('Fetching answers from GPT\'s internal knowledge... '
66
+ '\n This may take a few mins and might also fail if OpenAI API server is down.'):
67
+ answers = p1.run(input)
68
  placeholder_plain_gpt.markdown(answers['results'][0])
69
 
70
  if st.session_state.get("query_type", "Retrieval Augmented") == "Retrieval Augmented":
71
+ with st.spinner(
72
+ 'Loading Retrieval Augmented pipeline... \
73
+ n This may take a few mins and might also fail if OpenAI API server is down.'):
74
+ p2 = get_retrieval_augmented_pipeline()
75
+ with st.spinner('Fetching relevant documents from documented stores and calculating answers... '
76
+ '\n This may take a few mins and might also fail if OpenAI API server is down.'):
77
+ answers_2 = p2.run(input)
78
  else:
79
+ p3 = get_web_retrieval_augmented_pipeline()
80
  answers_2 = p3.run(input)
81
  placeholder_retrieval_augmented.markdown(answers_2['results'][0])
backend_utils.py CHANGED
@@ -1,5 +1,3 @@
1
- import os
2
-
3
  import streamlit as st
4
  from haystack import Pipeline
5
  from haystack.document_stores import FAISSDocumentStore
@@ -15,14 +13,8 @@ QUERIES = [
15
  "When did SVB collapse?"
16
  ]
17
 
18
- def ChangeWidgetFontSize(wgt_txt, wch_font_size = '12px'):
19
- htmlstr = """<script>var elements = window.parent.document.querySelectorAll('*'), i;
20
- for (i = 0; i < elements.length; ++i) { if (elements[i].innerText == |wgt_txt|)
21
- { elements[i].style.fontSize='""" + wch_font_size + """';} } </script> """
22
-
23
- htmlstr = htmlstr.replace('|wgt_txt|', "'" + wgt_txt + "'")
24
-
25
 
 
26
  def get_plain_pipeline():
27
  prompt_open_ai = PromptModel(model_name_or_path="text-davinci-003", api_key=st.secrets["OPENAI_API_KEY"])
28
  # Now let make one PromptNode use the default model and the other one the OpenAI model:
@@ -33,6 +25,7 @@ def get_plain_pipeline():
33
  return pipeline
34
 
35
 
 
36
  def get_retrieval_augmented_pipeline():
37
  ds = FAISSDocumentStore(faiss_index_path="data/my_faiss_index.faiss",
38
  faiss_config_path="data/my_faiss_index.json")
@@ -62,6 +55,7 @@ def get_retrieval_augmented_pipeline():
62
  return pipeline
63
 
64
 
 
65
  def get_web_retrieval_augmented_pipeline():
66
  search_key = st.secrets["WEBRET_API_KEY"]
67
  web_retriever = WebRetriever(api_key=search_key, search_engine_provider="SerperDev")
@@ -82,13 +76,16 @@ def get_web_retrieval_augmented_pipeline():
82
  return pipeline
83
 
84
 
85
- @st.cache_resource(show_spinner=False)
86
- def app_init():
87
- os.environ["OPENAI_API_KEY"] = st.secrets["OPENAI_API_KEY"]
88
- p1 = get_plain_pipeline()
89
- p2 = get_retrieval_augmented_pipeline()
90
- p3 = get_web_retrieval_augmented_pipeline()
91
- return p1, p2, p3
 
 
 
92
 
93
 
94
  if 'query' not in st.session_state:
@@ -117,4 +114,3 @@ def set_q4():
117
 
118
  def set_q5():
119
  st.session_state['query'] = QUERIES[4]
120
-
 
 
 
1
  import streamlit as st
2
  from haystack import Pipeline
3
  from haystack.document_stores import FAISSDocumentStore
 
13
  "When did SVB collapse?"
14
  ]
15
 
 
 
 
 
 
 
 
16
 
17
+ @st.cache_resource(show_spinner=False)
18
  def get_plain_pipeline():
19
  prompt_open_ai = PromptModel(model_name_or_path="text-davinci-003", api_key=st.secrets["OPENAI_API_KEY"])
20
  # Now let make one PromptNode use the default model and the other one the OpenAI model:
 
25
  return pipeline
26
 
27
 
28
+ @st.cache_resource(show_spinner=False)
29
  def get_retrieval_augmented_pipeline():
30
  ds = FAISSDocumentStore(faiss_index_path="data/my_faiss_index.faiss",
31
  faiss_config_path="data/my_faiss_index.json")
 
55
  return pipeline
56
 
57
 
58
+ @st.cache_resource(show_spinner=False)
59
  def get_web_retrieval_augmented_pipeline():
60
  search_key = st.secrets["WEBRET_API_KEY"]
61
  web_retriever = WebRetriever(api_key=search_key, search_engine_provider="SerperDev")
 
76
  return pipeline
77
 
78
 
79
+ # @st.cache_resource(show_spinner=False)
80
+ # def app_init():
81
+ # print("Loading Pipelines...")
82
+ # p1 = get_plain_pipeline()
83
+ # print("Loaded Plain Pipeline")
84
+ # p2 = get_retrieval_augmented_pipeline()
85
+ # print("Loaded Retrieval Augmented Pipeline")
86
+ # p3 = get_web_retrieval_augmented_pipeline()
87
+ # print("Loaded Web Retrieval Augmented Pipeline")
88
+ # return p1, p2, p3
89
 
90
 
91
  if 'query' not in st.session_state:
 
114
 
115
  def set_q5():
116
  st.session_state['query'] = QUERIES[4]