prashant committed · Commit 2663a97 · Parent(s): fa191c0

retriever update and coherence

Files changed:
- appStore/coherence.py +63 -5
- appStore/keyword_search.py +3 -3
- paramconfig.cfg +18 -5
- utils/ndc_explorer.py +55 -0
- utils/semantic_search.py +49 -28
appStore/coherence.py
CHANGED

@@ -4,21 +4,42 @@ sys.path.append('../utils')
 
 import streamlit as st
 import ast
+import logging
+from utils.ndc_explorer import countrySpecificCCA, countrySpecificCCM
+from utils.checkconfig import getconfig
+from utils.semantic_search import runSemanticPreprocessingPipeline
+
 
 # Reading data and Declaring necessary variables
 with open('docStore/ndcs/countryList.txt') as dfile:
+    countryList = dfile.read()
 countryList = ast.literal_eval(countryList)
 countrynames = list(countryList.keys())
 
 with open('docStore/ndcs/cca.txt', encoding='utf-8', errors='ignore') as dfile:
+    cca_sent = dfile.read()
 cca_sent = ast.literal_eval(cca_sent)
 
 with open('docStore/ndcs/ccm.txt', encoding='utf-8', errors='ignore') as dfile:
     ccm_sent = dfile.read()
 ccm_sent = ast.literal_eval(ccm_sent)
 
+config = getconfig('paramconfig.cfg')
+split_by = config.get('coherence','SPLIT_BY')
+split_length = int(config.get('coherence','SPLIT_LENGTH'))
+split_overlap = int(config.get('coherence','SPLIT_OVERLAP'))
+split_respect_sentence_boundary = bool(int(config.get('coherence',
+                                    'RESPECT_SENTENCE_BOUNDARY')))
+remove_punc = bool(int(config.get('coherence','REMOVE_PUNC')))
+embedding_model = config.get('coherence','RETRIEVER')
+embedding_model_format = config.get('coherence','RETRIEVER_FORMAT')
+embedding_layer = int(config.get('coherence','RETRIEVER_EMB_LAYER'))
+embedding_dim = int(config.get('coherence','EMBEDDING_DIM'))
+retriever_top_k = int(config.get('coherence','RETRIEVER_TOP_K'))
+reader_model = config.get('coherence','READER')
+reader_top_k = int(config.get('coherence','RETRIEVER_TOP_K'))
+
+
 def app():
 
     #### APP INFO #####

@@ -55,6 +76,43 @@ def app():
                 indicator is based on vector similarities in which only paragraphs \
                 with similarity above 0.55 to the indicators are considered. """)
 
+    with st.sidebar:
+
+        option = st.selectbox('Select Country', (countrynames))
+        countryCode = countryList[option]
+        st.markdown("---")
+
+    with st.container():
+        if st.button("Check Coherence"):
+            sent_cca = countrySpecificCCA(cca_sent,1,countryCode)
+            sent_ccm = countrySpecificCCM(ccm_sent,1,countryCode)
+
+            if 'filepath' in st.session_state:
+                allDocuments = runSemanticPreprocessingPipeline(
+                        file_path= st.session_state['filepath'],
+                        file_name = st.session_state['filename'],
+                        split_by=split_by,
+                        split_length= split_length,
+                        split_overlap=split_overlap,
+                        removePunc= remove_punc,
+                        split_respect_sentence_boundary=split_respect_sentence_boundary)
+                genre = st.radio("Select Category", ('Climate Change Adaptation', 'Climate Change Mitigation'))
+                if genre == 'Climate Change Adaptation':
+                    sent_dict = sent_cca
+                else:
+                    sent_dict = sent_ccm
+                sent_labels = []
+                for key,sent in sent_dict.items():
+                    sent_labels.append(sent)
+                if len(allDocuments['documents']) > 100:
+                    warning_msg = ": This might take sometime, please sit back and relax."
+                else:
+                    warning_msg = ""
+                logging.info("starting Coherence analysis, country selected {}".format(option))
+                with st.spinner("Performing Similar/Contextual search{}".format(warning_msg)):
+                    pass
+
+            else:
+                st.info("🤔 No document found, please try to upload it at the sidebar!")
+                logging.warning("Terminated as no document provided")
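Note: the new module-level block reads every coherence parameter from paramconfig.cfg through getconfig and casts the string values by hand. Below is a minimal sketch of that pattern, assuming getconfig (defined in utils/checkconfig.py, which is not shown in this diff) is a thin wrapper over the standard-library configparser:

```python
# Sketch only: assumes utils/checkconfig.getconfig wraps configparser like this.
import configparser

def getconfig(configfile_path: str) -> configparser.ConfigParser:
    config = configparser.ConfigParser()
    config.read(configfile_path)   # yields an empty parser if the file is missing
    return config

config = getconfig('paramconfig.cfg')
# ConfigParser values are always strings, hence the explicit casts; the 0/1
# flags must go through bool(int(...)) because bool("0") would be True.
split_length = int(config.get('coherence', 'SPLIT_LENGTH'))
respect_boundary = bool(int(config.get('coherence', 'RESPECT_SENTENCE_BOUNDARY')))
```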
appStore/keyword_search.py
CHANGED

@@ -20,6 +20,7 @@ remove_punc = bool(int(config.get('semantic_search','REMOVE_PUNC')))
 embedding_model = config.get('semantic_search','RETRIEVER')
 embedding_model_format = config.get('semantic_search','RETRIEVER_FORMAT')
 embedding_layer = int(config.get('semantic_search','RETRIEVER_EMB_LAYER'))
+embedding_dim = int(config.get('semantic_search','EMBEDDING_DIM'))
 retriever_top_k = int(config.get('semantic_search','RETRIEVER_TOP_K'))
 reader_model = config.get('semantic_search','READER')
 reader_top_k = int(config.get('semantic_search','RETRIEVER_TOP_K'))

@@ -97,8 +98,7 @@ def app():
             logging.warning("Terminated as no keyword provided")
         else:
             if 'filepath' in st.session_state:
-
-
+
                 if searchtype:
                     allDocuments = runLexicalPreprocessingPipeline(
                         file_name=st.session_state['filename'],

@@ -137,7 +137,7 @@ def app():
                         embedding_layer=embedding_layer,
                         embedding_model_format=embedding_model_format,
                         reader_model=reader_model,reader_top_k=reader_top_k,
-                        retriever_top_k=retriever_top_k)
+                        retriever_top_k=retriever_top_k, embedding_dim=embedding_dim)
 
     else:
         st.info("🤔 No document found, please try to upload it at the sidebar!")
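Note: the embedding_dim read here flows down to the Haystack document store, which allocates fixed-size embedding slots. A sketch of why the dimension has to be stated explicitly, assuming the farm-haystack v1 API used elsewhere in this repo:

```python
# Sketch, assuming the farm-haystack v1 API.
from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes import EmbeddingRetriever

# The store defaults to 768-dimensional embeddings, and update_embeddings()
# will not resize it to match the retriever, so a 384-dim model such as
# all-MiniLM-L6-v2 needs the dimension stated up front, per its model card.
document_store = InMemoryDocumentStore(similarity='dot_product',
                                       embedding_dim=384)
retriever = EmbeddingRetriever(document_store=document_store,
                               embedding_model='all-MiniLM-L6-v2',
                               model_format='sentence_transformers')
```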
paramconfig.cfg
CHANGED

@@ -8,8 +8,9 @@ REMOVE_PUNC = 0
 [semantic_search]
 RETRIEVER_TOP_K = 10
 MAX_SEQ_LENGTH = 64
-RETRIEVER =
+RETRIEVER = multi-qa-distilbert-dot-v1
 RETRIEVER_FORMAT = sentence_transformers
+EMBEDDING_DIM = 768
 RETRIEVER_EMB_LAYER = -1
 READER = deepset/tinyroberta-squad2
 READER_TOP_K = 10

@@ -30,9 +31,21 @@ SPLIT_OVERLAP = 10
 RESPECT_SENTENCE_BOUNDARY = 1
 TOP_KEY = 15
 
-[preprocessor]
-SPLIT_OVERLAP_WORD = 10
-SPLIT_OVERLAP_SENTENCE = 1
-
 [tfidf]
 TOP_N = 20
+
+[coherence]
+RETRIEVER_TOP_K = 10
+MAX_SEQ_LENGTH = 64
+RETRIEVER = all-MiniLM-L6-v2
+RETRIEVER_FORMAT = sentence_transformers
+RETRIEVER_EMB_LAYER = -1
+EMBEDDING_DIM = 384
+READER = deepset/tinyroberta-squad2
+READER_TOP_K = 10
+THRESHOLD = 0.55
+SPLIT_BY = sentence
+SPLIT_LENGTH = 3
+SPLIT_OVERLAP = 0
+RESPECT_SENTENCE_BOUNDARY = 1
+REMOVE_PUNC = 0
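Note: THRESHOLD = 0.55 in the new [coherence] section matches the 0.55 similarity cut-off described in the app text. A sketch of that cut-off with the configured all-MiniLM-L6-v2 model, using sentence-transformers directly; the indicator and paragraph strings are made up for illustration:

```python
# Sketch of the THRESHOLD cut-off; model name taken from the [coherence] section.
from sentence_transformers import SentenceTransformer, util

model = SentenceTransformer('all-MiniLM-L6-v2')
indicator = "national adaptation plan for climate change"
paragraphs = ["The country will develop a national adaptation plan.",
              "GDP grew by three percent last year."]

emb_ind = model.encode(indicator, convert_to_tensor=True)
emb_par = model.encode(paragraphs, convert_to_tensor=True)
scores = util.cos_sim(emb_ind, emb_par)[0]   # one cosine score per paragraph
matches = [p for p, s in zip(paragraphs, scores) if s > 0.55]
```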
utils/ndc_explorer.py
ADDED

@@ -0,0 +1,55 @@
+
+import urllib.request
+import json
+
+link = "https://klimalog.die-gdi.de/ndc/open-data/dataset.json"
+def get_document(countryCode: str):
+    with urllib.request.urlopen(link) as urlfile:
+        data = json.loads(urlfile.read())
+    categoriesData = {}
+    categoriesData['categories'] = data['categories']
+    categoriesData['subcategories'] = data['subcategories']
+    keys_sub = categoriesData['subcategories'].keys()
+    documentType = 'NDCs'
+    if documentType in data.keys():
+        if countryCode in data[documentType].keys():
+            get_dict = {}
+            for key, value in data[documentType][countryCode].items():
+                if key not in ['country_name','region_id', 'region_name']:
+                    get_dict[key] = value['classification']
+                else:
+                    get_dict[key] = value
+        else:
+            return None
+    else:
+        return None
+
+    country = {}
+    for key in categoriesData['categories']:
+        country[key] = {}
+    for key, value in categoriesData['subcategories'].items():
+        country[value['category']][key] = get_dict[key]
+
+    return country
+
+# country_ndc = get_document('NDCs', countryList[option])
+
+def countrySpecificCCA(cca_sent, threshold, countryCode):
+    temp = {}
+    doc = get_document(countryCode)
+    for key, value in cca_sent.items():
+        id_ = doc['climate change adaptation'][key]['id']
+        if id_ > threshold:
+            temp[key] = value['id'][id_]
+    return temp
+
+
+def countrySpecificCCM(ccm_sent, threshold, countryCode):
+    temp = {}
+    doc = get_document(countryCode)
+    for key, value in ccm_sent.items():
+        id_ = doc['climate change mitigation'][key]['id']
+        if id_ > threshold:
+            temp[key] = value['id'][id_]
+
+    return temp
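Note: a hypothetical usage sketch for the new helpers follows; the country code value and the shape of cca_sent are assumptions inferred from the code above, not confirmed by this diff:

```python
# Hypothetical usage sketch; 'IND' and the cca_sent shape are assumptions.
from utils.ndc_explorer import get_document, countrySpecificCCA

country = get_document('IND')      # returns None if the code is not in the dataset
if country is not None:
    print(list(country.keys()))    # top-level NDC categories for that country

# cca_sent is expected to map adaptation subcategory names to labelled sentences;
# threshold filters subcategories by the country's classification level.
# sent_cca = countrySpecificCCA(cca_sent, threshold=1, countryCode='IND')
```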
utils/semantic_search.py
CHANGED

@@ -63,6 +63,7 @@ class QueryCheck(BaseComponent):
         else:
             output = {"query": "find all issues related to {}".format(query),
                       "query_type": 'statements/keyword'}
+        logging.info(output)
         return output, "output_1"
 
     def run_batch(self, query):

@@ -154,7 +155,8 @@ def loadRetriever(embedding_model:Text = None, embedding_model_format:Text = No
     return retriever
 
 @st.cache(hash_funcs={"builtins.SwigPyObject": lambda _: None},allow_output_mutation=True)
-def createDocumentStore(documents:List[Document], similarity:str = 'cosine'):
+def createDocumentStore(documents:List[Document], similarity:str = 'dot_product',
+                        embedding_dim:int = 768):
     """
     Creates the InMemory Document Store from haystack list of Documents.
     It is a mandatory component for the Retriever to work in the Haystack framework.

@@ -164,13 +166,17 @@ def createDocumentStore(documents:List[Document], similarity:str = 'cosine'):
     documents: List of haystack documents. If using the preprocessing pipeline,
     these can be fetched with key = 'documents' on the output of the preprocessing pipeline.
     similarity: scoring function, can be either 'cosine' or 'dot_product'
+    embedding_dim: The document store has a default embedding size of 768, and the
+    update_embeddings method of the document store cannot infer the embedding size
+    of the retriever automatically, therefore set this value as per the model card.
 
     Return
     -------
     document_store: InMemory Document Store object type.
 
     """
-    document_store = InMemoryDocumentStore(similarity = similarity)
+    document_store = InMemoryDocumentStore(similarity = similarity,
+                                           embedding_dim = embedding_dim)
     document_store.write_documents(documents)
 
     return document_store

@@ -178,9 +184,10 @@ def createDocumentStore(documents:List[Document], similarity:str = 'cosine'):
 
 @st.cache(hash_funcs={"builtins.SwigPyObject": lambda _: None},allow_output_mutation=True)
 def semanticSearchPipeline(documents:List[Document], embedding_model:Text = None,
-                           embedding_model_format:Text = None,
+                           useQueryCheck = True, embedding_model_format:Text = None,
                            embedding_layer:int = None, retriever_top_k:int = 10,
-                           reader_model:str = None, reader_top_k:int = 10):
+                           reader_model:str = None, reader_top_k:int = 10,
+                           embedding_dim:int = 768):
     """
     Creates the semantic search pipeline and document store object from the
     list of haystack documents. The top_k for the Reader and Retriever are kept

@@ -207,6 +214,11 @@ def semanticSearchPipeline(documents:List[Document], embedding_model:Text = Non
     reader_top_k: Reader will use retrieved results to further find better matches.
     As the purpose here is to use the reader to extract context, the value is the
     same as retriever_top_k.
+    useQueryCheck: Whether to use QueryCheck, which modifies the query, or not.
+    embedding_dim: The document store has a default embedding size of 768, and the
+    update_embeddings method of the document store cannot infer the embedding size
+    of the retriever automatically, therefore set this value as per the model card.
+
 
     Return
     ---------

@@ -219,7 +231,8 @@ def semanticSearchPipeline(documents:List[Document], embedding_model:Text = Non
     embeddings of each paragraph in document store.
 
     """
-    document_store = createDocumentStore(documents)
+    document_store = createDocumentStore(documents=documents,
+                                         embedding_dim=embedding_dim)
     retriever = loadRetriever(embedding_model = embedding_model,
                               embedding_model_format=embedding_model_format,
                               embedding_layer=embedding_layer,

@@ -227,17 +240,22 @@ def semanticSearchPipeline(documents:List[Document], embedding_model:Text = Non
                               document_store = document_store)
 
     document_store.update_embeddings(retriever)
-    querycheck = QueryCheck()
     reader = FARMReader(model_name_or_path=reader_model,
                         top_k = reader_top_k, use_gpu=True)
-
     semanticsearch_pipeline = Pipeline()
-    semanticsearch_pipeline.add_node(component = querycheck, name = "QueryCheck",
-                                     inputs = ["Query"])
-    semanticsearch_pipeline.add_node(component = retriever, name = "EmbeddingRetriever",
-                                     inputs = ["QueryCheck.output_1"])
-    semanticsearch_pipeline.add_node(component = reader, name = "FARMReader",
-                                     inputs = ["EmbeddingRetriever"])
+    if useQueryCheck:
+        querycheck = QueryCheck()
+        semanticsearch_pipeline.add_node(component = querycheck, name = "QueryCheck",
+                                         inputs = ["Query"])
+        semanticsearch_pipeline.add_node(component = retriever, name = "EmbeddingRetriever",
+                                         inputs = ["QueryCheck.output_1"])
+        semanticsearch_pipeline.add_node(component = reader, name = "FARMReader",
+                                         inputs = ["EmbeddingRetriever"])
+    else:
+        semanticsearch_pipeline.add_node(component = retriever, name = "EmbeddingRetriever",
+                                         inputs = ["Query"])
+        semanticsearch_pipeline.add_node(component = reader, name = "FARMReader",
+                                         inputs = ["EmbeddingRetriever"])
 
     return semanticsearch_pipeline, document_store
 

@@ -281,7 +299,8 @@ def semanticsearchAnnotator(matches: List[List[int]], document):
 def semantic_search(query:Text, documents:List[Document], embedding_model:Text,
                     embedding_model_format:Text,
                     embedding_layer:int, reader_model:str,
-                    retriever_top_k:int = 10, reader_top_k:int = 10):
+                    retriever_top_k:int = 10, reader_top_k:int = 10,
+                    return_results:bool = False, embedding_dim:int = 768):
     """
     Performs the semantic search on the list of haystack documents which is
     returned by the preprocessing pipeline.

@@ -297,22 +316,24 @@ def semantic_search(query:Text, documents:List[Document], embedding_model:Text,
                               embedding_layer= embedding_layer,
                               embedding_model_format= embedding_model_format,
                               reader_model= reader_model, retriever_top_k= retriever_top_k,
-                              reader_top_k= reader_top_k)
+                              reader_top_k= reader_top_k, embedding_dim=embedding_dim)
 
     results = semanticsearch_pipeline.run(query = query)
-    if check_streamlit:
-        st.markdown("##### Top few semantic search results #####")
-    else:
-        print("Top few semantic search results")
-    for i,answer in enumerate(results['answers']):
-        temp = answer.to_dict()
-        start_idx = temp['offsets_in_document'][0]['start']
-        end_idx = temp['offsets_in_document'][0]['end']
-        match = [[start_idx,end_idx]]
-        doc = doc_store.get_document_by_id(temp['document_id']).content
-        if check_streamlit:
-            st.write("Result {}".format(i+1))
-        else:
-            print("Result {}".format(i+1))
-        semanticsearchAnnotator(match, doc)
+    if return_results:
+        return results
+    else:
+        if check_streamlit:
+            st.markdown("##### Top few semantic search results #####")
+        else:
+            print("Top few semantic search results")
+        for i,answer in enumerate(results['answers']):
+            temp = answer.to_dict()
+            start_idx = temp['offsets_in_document'][0]['start']
+            end_idx = temp['offsets_in_document'][0]['end']
+            match = [[start_idx,end_idx]]
+            doc = doc_store.get_document_by_id(temp['document_id']).content
+            if check_streamlit:
+                st.write("Result {}".format(i+1))
+            else:
+                print("Result {}".format(i+1))
+            semanticsearchAnnotator(match, doc)
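Note: the two new keyword arguments, return_results and embedding_dim, change how semantic_search behaves at the call site. A usage sketch with illustrative values (model names mirror the [semantic_search] section of paramconfig.cfg; the sample document is made up):

```python
# Usage sketch; argument values are illustrative, taken from paramconfig.cfg.
from haystack.schema import Document
from utils.semantic_search import semantic_search

docs = [Document(content="The plan includes new adaptation measures for coastal regions.")]

# return_results=True skips the Streamlit/print rendering and hands back the raw
# pipeline output; embedding_dim must match the retriever model (768 here).
results = semantic_search(query="climate change adaptation measures",
                          documents=docs,
                          embedding_model='multi-qa-distilbert-dot-v1',
                          embedding_model_format='sentence_transformers',
                          embedding_layer=-1,
                          reader_model='deepset/tinyroberta-squad2',
                          retriever_top_k=10, reader_top_k=10,
                          return_results=True, embedding_dim=768)
print(results['answers'])
```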