JMuscatello commited on
Commit
d1b7a20
β€’
2 Parent(s): 60d2d8a 5436ef7

Merge branch 'add_find_demo' into main

Browse files
Files changed (2) hide show
  1. pages/6_πŸ”Ž_Find_Demo.py +178 -1
  2. requirements.txt +4 -1
pages/6_πŸ”Ž_Find_Demo.py CHANGED
@@ -1,11 +1,116 @@
1
  import os
 
 
 
 
2
 
3
  import streamlit as st
4
  import streamlit_analytics
 
 
 
5
  from utils import add_logo_to_sidebar, add_footer, add_email_signup_form
6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  streamlit_analytics.start_tracking()
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  st.set_page_config(
10
  page_title="Find Demo",
11
  page_icon="πŸ”Ž",
@@ -22,7 +127,79 @@ add_logo_to_sidebar()
22
  st.sidebar.success("πŸ‘† Select a demo above.")
23
 
24
  st.title('πŸ”Ž Find Demo')
25
- st.markdown("πŸ— This demo is currently under construction. Please visit back soon.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  add_email_signup_form()
28
 
 
1
  import os
2
+ from io import StringIO
3
+ import re
4
+
5
+ import pandas as pd
6
 
7
  import streamlit as st
8
  import streamlit_analytics
9
+
10
+ import streamlit_toggle as tog
11
+
12
  from utils import add_logo_to_sidebar, add_footer, add_email_signup_form
13
 
14
+ from huggingface_hub import snapshot_download
15
+
16
+ from haystack.document_stores import InMemoryDocumentStore
17
+ from haystack.nodes import BM25Retriever, EmbeddingRetriever
18
+
19
+ HF_TOKEN = os.environ.get("HF_TOKEN")
20
+ DATA_REPO_ID = "simplexico/cuad-qa-answers"
21
+ DATA_FILENAME = "cuad_questions_answers.json"
22
+ EMBEDDING_MODEL = "sentence-transformers/paraphrase-MiniLM-L3-v2"
23
+ if EMBEDDING_MODEL == "sentence-transformers/multi-qa-MiniLM-L6-cos-v1" or EMBEDDING_MODEL == "sentence-transformers/paraphrase-MiniLM-L3-v2":
24
+ EMBEDDING_DIM = 384
25
+ else:
26
+ EMBEDDING_DIM = 768
27
+
28
+ EXAMPLE_TEXT = "the governing law is the State of Texas"
29
+
30
  streamlit_analytics.start_tracking()
31
 
32
+ @st.cache(allow_output_mutation=True)
33
+ def load_dataset():
34
+ snapshot_download(repo_id=DATA_REPO_ID, token=HF_TOKEN, local_dir='./', repo_type='dataset')
35
+ df = pd.read_json(DATA_FILENAME)
36
+ return df
37
+
38
+ @st.cache(allow_output_mutation=True)
39
+ def generate_document_store(df):
40
+ """Create haystack document store using contract clause data
41
+ """
42
+ document_dicts = []
43
+
44
+ for idx, row in df.iterrows():
45
+ document_dicts.append(
46
+ {
47
+ 'content': row['paragraph'],
48
+ 'meta': {'contract_title': row['contract_title']}
49
+ }
50
+ )
51
+
52
+ document_store = InMemoryDocumentStore(use_bm25=True, embedding_dim=EMBEDDING_DIM, similarity='cosine')
53
+
54
+ document_store.write_documents(document_dicts)
55
+
56
+ return document_store
57
+
58
+ def files_to_dataframe(uploaded_files, limit=10):
59
+ texts = []
60
+ titles = []
61
+ for uploaded_file in uploaded_files[:limit]:
62
+
63
+ stringio = StringIO(uploaded_file.getvalue().decode("utf-8"))
64
+
65
+ text = stringio.read().strip()
66
+ paragraphs = text.split("\n\n")
67
+ paragraphs = [p.strip() for p in paragraphs if len(p.split()) > 10]
68
+ texts.extend(paragraphs)
69
+ titles.extend([uploaded_file.name]*len(paragraphs))
70
+
71
+ return pd.DataFrame({'paragraph': texts, 'contract_title': titles})
72
+
73
+ @st.cache(allow_output_mutation=True)
74
+ def generate_bm25_retriever(document_store):
75
+ return BM25Retriever(document_store)
76
+
77
+ @st.cache(allow_output_mutation=True)
78
+ def generate_embeddings(embedding_model, document_store):
79
+ embedding_retriever = EmbeddingRetriever(
80
+ embedding_model=embedding_model,
81
+ document_store=document_store,
82
+ model_format="sentence_transformers",
83
+ scale_score=True
84
+ )
85
+ document_store.update_embeddings(embedding_retriever)
86
+ return embedding_retriever
87
+
88
+ def process_query(query, retriever):
89
+ """Generates dataframe with top ten results"""
90
+ texts = []
91
+ contract_titles = []
92
+ scores = []
93
+ ranking = []
94
+ candidate_documents = retriever.retrieve(
95
+ query=query,
96
+ top_k=10,
97
+ )
98
+
99
+ for idx, document in enumerate(candidate_documents):
100
+ texts.append(document.content)
101
+ contract_titles.append(document.meta["contract_title"])
102
+ scores.append(str(round(document.score, 2)))
103
+ ranking.append(idx + 1)
104
+
105
+ return pd.DataFrame(
106
+ {
107
+ "Ranking": ranking,
108
+ "Text": texts,
109
+ "Source Contract": contract_titles,
110
+ "Similarity": scores
111
+ }
112
+ )
113
+
114
  st.set_page_config(
115
  page_title="Find Demo",
116
  page_icon="πŸ”Ž",
 
127
  st.sidebar.success("πŸ‘† Select a demo above.")
128
 
129
  st.title('πŸ”Ž Find Demo')
130
+
131
+ st.write("""
132
+ This demo shows how a set of clauses can be searched.
133
+ Upload a set of contracts on the left and the paragraphs can be searched using **keywords** or using **semantic search**.
134
+ Semantic search leverages an AI model which matches on clauses with a similar meaning to the input text.
135
+ """)
136
+ st.write("**πŸ‘ˆ Upload a set of contracts on the left** to start the demo")
137
+
138
+
139
+ #df = load_dataset()
140
+
141
+ #document_store = generate_document_store(df)
142
+
143
+ #bm25_retriever = generate_bm25_retriever(document_store)
144
+
145
+ #embedding_retriever = generate_embeddings(EMBEDDING_MODEL, document_store)
146
+ col1, col2, col3, col4, col5 = st.columns(5)
147
+
148
+ uploaded_files = st.sidebar.file_uploader("Select contracts to search **(upload up to 10 files)**", accept_multiple_files=True)
149
+
150
+ if uploaded_files:
151
+ with col1:
152
+ st.write("Toggle between **keyword** or **semantic** search:")
153
+ value = tog.st_toggle_switch(
154
+ label="Keyword/Semantic",
155
+ label_after=True,
156
+ inactive_color='#D3D3D3',
157
+ active_color="#11567f",
158
+ track_color="#29B5E8"
159
+ )
160
+ if value:
161
+ search_type = "semantic"
162
+ else:
163
+ search_type = "keyword"
164
+
165
+ print(value)
166
+
167
+ df = files_to_dataframe(uploaded_files)
168
+ document_store = generate_document_store(df)
169
+ bm25_retriever = generate_bm25_retriever(document_store)
170
+ st.write("**πŸ‘‡ Enter search query below** and hit the button **Find Clauses** to see the demo in action")
171
+ query = st.text_area(label='Enter Search Query', value=EXAMPLE_TEXT, height=250)
172
+ button = st.button('**Find Clauses**', type='primary', use_container_width=True)
173
+
174
+ if button:
175
+
176
+ hide_dataframe_row_index = """
177
+ <style>
178
+ .row_heading.level0 {display:none}
179
+ .blank {display:none}
180
+ </style>
181
+ """
182
+
183
+ st.subheader(f'Search Results ({search_type}):')
184
+ # Inject CSS with Markdown
185
+ st.markdown(hide_dataframe_row_index, unsafe_allow_html=True)
186
+
187
+ if search_type == "keyword":
188
+ df_bm25 = process_query(query, bm25_retriever)
189
+ st.table(df_bm25)
190
+
191
+ if search_type == "semantic":
192
+ embedding_retriever = generate_embeddings(EMBEDDING_MODEL, document_store)
193
+ df_embed = process_query(query, embedding_retriever)
194
+ st.table(df_embed)
195
+
196
+ # with col2:
197
+
198
+ # st.subheader('Semantic Search Results:')
199
+ # # Inject CSS with Markdown
200
+ # st.markdown(hide_dataframe_row_index, unsafe_allow_html=True)
201
+ # df_embed = process_query(query, embedding_retriever)
202
+ # st.table(df_embed)
203
 
204
  add_email_signup_form()
205
 
requirements.txt CHANGED
@@ -8,6 +8,7 @@ click==8.1.3
8
  cloudpickle==2.2.1
9
  decorator==5.1.1
10
  entrypoints==0.4
 
11
  filelock==3.10.0
12
  gitdb==4.0.10
13
  GitPython==3.1.31
@@ -25,7 +26,8 @@ matplotlib==3.7.1
25
  mdurl==0.1.2
26
  nltk==3.8.1
27
  numba==0.56.4
28
- numpy==1.24.2
 
29
  packaging==23.0
30
  pandas==1.5.3
31
  Pillow==9.4.0
@@ -53,6 +55,7 @@ https://huggingface.co/spacy/en_core_web_md/resolve/main/en_core_web_md-any-py3-
53
  smmap==5.0.0
54
  streamlit==1.20.0
55
  streamlit-analytics==0.4.1
 
56
  threadpoolctl==3.1.0
57
  toml==0.10.2
58
  toolz==0.12.0
 
8
  cloudpickle==2.2.1
9
  decorator==5.1.1
10
  entrypoints==0.4
11
+ farm-haystack==1.15.1
12
  filelock==3.10.0
13
  gitdb==4.0.10
14
  GitPython==3.1.31
 
26
  mdurl==0.1.2
27
  nltk==3.8.1
28
  numba==0.56.4
29
+ numpy==1.23.5
30
+ >>>>>>> add_find_demo
31
  packaging==23.0
32
  pandas==1.5.3
33
  Pillow==9.4.0
 
55
  smmap==5.0.0
56
  streamlit==1.20.0
57
  streamlit-analytics==0.4.1
58
+ streamlit-toggle-switch==1.0.2
59
  threadpoolctl==3.1.0
60
  toml==0.10.2
61
  toolz==0.12.0