Spaces:
Running
Running
Merge pull request #26 from lfoppiano/add-pdf-viewer
Browse files- document_qa/document_qa_engine.py +105 -46
- document_qa/grobid_processors.py +148 -64
- requirements.txt +3 -1
- streamlit_app.py +90 -30
- tests/__init__.py +0 -0
- tests/conftest.py +37 -0
- tests/resources/2312.07559.paragraphs.tei.xml +0 -0
- tests/resources/2312.07559.sentences.tei.xml +0 -0
- tests/test_document_qa_engine.py +71 -0
- tests/test_grobid_processors.py +46 -0
document_qa/document_qa_engine.py
CHANGED
@@ -3,18 +3,87 @@ import os
|
|
3 |
from pathlib import Path
|
4 |
from typing import Union, Any
|
5 |
|
6 |
-
|
7 |
from grobid_client.grobid_client import GrobidClient
|
8 |
-
from langchain.chains import create_extraction_chain
|
9 |
from langchain.chains.question_answering import load_qa_chain, stuff_prompt, refine_prompts, map_reduce_prompt, \
|
10 |
map_rerank_prompt
|
11 |
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
|
12 |
from langchain.retrievers import MultiQueryRetriever
|
13 |
from langchain.schema import Document
|
14 |
-
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
15 |
from langchain.vectorstores import Chroma
|
16 |
from tqdm import tqdm
|
17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
|
19 |
|
20 |
class DocumentQAEngine:
|
@@ -44,6 +113,7 @@ class DocumentQAEngine:
|
|
44 |
self.llm = llm
|
45 |
self.memory = memory
|
46 |
self.chain = load_qa_chain(llm, chain_type=qa_chain_type)
|
|
|
47 |
|
48 |
if embeddings_root_path is not None:
|
49 |
self.embeddings_root_path = embeddings_root_path
|
@@ -57,7 +127,7 @@ class DocumentQAEngine:
|
|
57 |
grobid_client = GrobidClient(
|
58 |
grobid_server=self.grobid_url,
|
59 |
batch_size=1000,
|
60 |
-
coordinates=["p"],
|
61 |
sleep_time=5,
|
62 |
timeout=60,
|
63 |
check_server=True
|
@@ -105,7 +175,7 @@ class DocumentQAEngine:
|
|
105 |
if verbose:
|
106 |
print(query)
|
107 |
|
108 |
-
response = self._run_query(doc_id, query, context_size=context_size)
|
109 |
response = response['output_text'] if 'output_text' in response else response
|
110 |
|
111 |
if verbose:
|
@@ -116,17 +186,17 @@ class DocumentQAEngine:
|
|
116 |
return self._parse_json(response, output_parser), response
|
117 |
except Exception as oe:
|
118 |
print("Failing to parse the response", oe)
|
119 |
-
return None, response
|
120 |
elif extraction_schema:
|
121 |
try:
|
122 |
chain = create_extraction_chain(extraction_schema, self.llm)
|
123 |
parsed = chain.run(response)
|
124 |
-
return parsed, response
|
125 |
except Exception as oe:
|
126 |
print("Failing to parse the response", oe)
|
127 |
-
return None, response
|
128 |
else:
|
129 |
-
return None, response
|
130 |
|
131 |
def query_storage(self, query: str, doc_id, context_size=4):
|
132 |
documents = self._get_context(doc_id, query, context_size)
|
@@ -157,12 +227,15 @@ class DocumentQAEngine:
|
|
157 |
|
158 |
def _run_query(self, doc_id, query, context_size=4):
|
159 |
relevant_documents = self._get_context(doc_id, query, context_size)
|
|
|
|
|
|
|
160 |
response = self.chain.run(input_documents=relevant_documents,
|
161 |
question=query)
|
162 |
|
163 |
if self.memory:
|
164 |
self.memory.save_context({"input": query}, {"output": response})
|
165 |
-
return response
|
166 |
|
167 |
def _get_context(self, doc_id, query, context_size=4):
|
168 |
db = self.embeddings_dict[doc_id]
|
@@ -188,14 +261,15 @@ class DocumentQAEngine:
|
|
188 |
relevant_documents = multi_query_retriever.get_relevant_documents(query)
|
189 |
return relevant_documents
|
190 |
|
191 |
-
def get_text_from_document(self, pdf_file_path, chunk_size=-1, perc_overlap=0.1,
|
192 |
"""
|
193 |
Extract text from documents using Grobid, if chunk_size is < 0 it keeps each paragraph separately
|
194 |
"""
|
195 |
if verbose:
|
196 |
print("File", pdf_file_path)
|
197 |
filename = Path(pdf_file_path).stem
|
198 |
-
|
|
|
199 |
|
200 |
biblio = structure['biblio']
|
201 |
biblio['filename'] = filename.replace(" ", "_")
|
@@ -207,48 +281,33 @@ class DocumentQAEngine:
|
|
207 |
metadatas = []
|
208 |
ids = []
|
209 |
|
210 |
-
if chunk_size
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
texts.append(passage['text'])
|
215 |
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
metadatas = [biblio for _ in range(len(texts))]
|
231 |
-
ids = [id for id, t in enumerate(texts)]
|
232 |
-
|
233 |
-
if "biblio" in include:
|
234 |
-
biblio_metadata = copy.copy(biblio)
|
235 |
-
biblio_metadata['type'] = "biblio"
|
236 |
-
biblio_metadata['section'] = "header"
|
237 |
-
for key in ['title', 'authors', 'publication_year']:
|
238 |
-
if key in biblio_metadata:
|
239 |
-
texts.append("{}: {}".format(key, biblio_metadata[key]))
|
240 |
-
metadatas.append(biblio_metadata)
|
241 |
-
ids.append(key)
|
242 |
|
243 |
return texts, metadatas, ids
|
244 |
|
245 |
-
def create_memory_embeddings(self, pdf_path, doc_id=None, chunk_size=500, perc_overlap=0.1
|
246 |
-
include = ["biblio"] if include_biblio else []
|
247 |
texts, metadata, ids = self.get_text_from_document(
|
248 |
pdf_path,
|
249 |
chunk_size=chunk_size,
|
250 |
-
perc_overlap=perc_overlap
|
251 |
-
include=include)
|
252 |
if doc_id:
|
253 |
hash = doc_id
|
254 |
else:
|
|
|
3 |
from pathlib import Path
|
4 |
from typing import Union, Any
|
5 |
|
6 |
+
import tiktoken
|
7 |
from grobid_client.grobid_client import GrobidClient
|
8 |
+
from langchain.chains import create_extraction_chain
|
9 |
from langchain.chains.question_answering import load_qa_chain, stuff_prompt, refine_prompts, map_reduce_prompt, \
|
10 |
map_rerank_prompt
|
11 |
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
|
12 |
from langchain.retrievers import MultiQueryRetriever
|
13 |
from langchain.schema import Document
|
|
|
14 |
from langchain.vectorstores import Chroma
|
15 |
from tqdm import tqdm
|
16 |
|
17 |
+
from document_qa.grobid_processors import GrobidProcessor
|
18 |
+
|
19 |
+
|
20 |
+
class TextMerger:
|
21 |
+
def __init__(self, model_name=None, encoding_name="gpt2"):
|
22 |
+
if model_name is not None:
|
23 |
+
self.enc = tiktoken.encoding_for_model(model_name)
|
24 |
+
else:
|
25 |
+
self.enc = tiktoken.get_encoding(encoding_name)
|
26 |
+
|
27 |
+
def encode(self, text, allowed_special=set(), disallowed_special="all"):
|
28 |
+
return self.enc.encode(
|
29 |
+
text,
|
30 |
+
allowed_special=allowed_special,
|
31 |
+
disallowed_special=disallowed_special,
|
32 |
+
)
|
33 |
+
|
34 |
+
def merge_passages(self, passages, chunk_size, tolerance=0.2):
|
35 |
+
new_passages = []
|
36 |
+
new_coordinates = []
|
37 |
+
current_texts = []
|
38 |
+
current_coordinates = []
|
39 |
+
for idx, passage in enumerate(passages):
|
40 |
+
text = passage['text']
|
41 |
+
coordinates = passage['coordinates']
|
42 |
+
current_texts.append(text)
|
43 |
+
current_coordinates.append(coordinates)
|
44 |
+
|
45 |
+
accumulated_text = " ".join(current_texts)
|
46 |
+
|
47 |
+
encoded_accumulated_text = self.encode(accumulated_text)
|
48 |
+
|
49 |
+
if len(encoded_accumulated_text) > chunk_size + chunk_size * tolerance:
|
50 |
+
if len(current_texts) > 1:
|
51 |
+
new_passages.append(current_texts[:-1])
|
52 |
+
new_coordinates.append(current_coordinates[:-1])
|
53 |
+
current_texts = [current_texts[-1]]
|
54 |
+
current_coordinates = [current_coordinates[-1]]
|
55 |
+
else:
|
56 |
+
new_passages.append(current_texts)
|
57 |
+
new_coordinates.append(current_coordinates)
|
58 |
+
current_texts = []
|
59 |
+
current_coordinates = []
|
60 |
+
|
61 |
+
elif chunk_size <= len(encoded_accumulated_text) < chunk_size + chunk_size * tolerance:
|
62 |
+
new_passages.append(current_texts)
|
63 |
+
new_coordinates.append(current_coordinates)
|
64 |
+
current_texts = []
|
65 |
+
current_coordinates = []
|
66 |
+
|
67 |
+
if len(current_texts) > 0:
|
68 |
+
new_passages.append(current_texts)
|
69 |
+
new_coordinates.append(current_coordinates)
|
70 |
+
|
71 |
+
new_passages_struct = []
|
72 |
+
for i, passages in enumerate(new_passages):
|
73 |
+
text = " ".join(passages)
|
74 |
+
coordinates = ";".join(new_coordinates[i])
|
75 |
+
|
76 |
+
new_passages_struct.append(
|
77 |
+
{
|
78 |
+
"text": text,
|
79 |
+
"coordinates": coordinates,
|
80 |
+
"type": "aggregated chunks",
|
81 |
+
"section": "mixed",
|
82 |
+
"subSection": "mixed"
|
83 |
+
}
|
84 |
+
)
|
85 |
+
|
86 |
+
return new_passages_struct
|
87 |
|
88 |
|
89 |
class DocumentQAEngine:
|
|
|
113 |
self.llm = llm
|
114 |
self.memory = memory
|
115 |
self.chain = load_qa_chain(llm, chain_type=qa_chain_type)
|
116 |
+
self.text_merger = TextMerger()
|
117 |
|
118 |
if embeddings_root_path is not None:
|
119 |
self.embeddings_root_path = embeddings_root_path
|
|
|
127 |
grobid_client = GrobidClient(
|
128 |
grobid_server=self.grobid_url,
|
129 |
batch_size=1000,
|
130 |
+
coordinates=["p", "title", "persName"],
|
131 |
sleep_time=5,
|
132 |
timeout=60,
|
133 |
check_server=True
|
|
|
175 |
if verbose:
|
176 |
print(query)
|
177 |
|
178 |
+
response, coordinates = self._run_query(doc_id, query, context_size=context_size)
|
179 |
response = response['output_text'] if 'output_text' in response else response
|
180 |
|
181 |
if verbose:
|
|
|
186 |
return self._parse_json(response, output_parser), response
|
187 |
except Exception as oe:
|
188 |
print("Failing to parse the response", oe)
|
189 |
+
return None, response, coordinates
|
190 |
elif extraction_schema:
|
191 |
try:
|
192 |
chain = create_extraction_chain(extraction_schema, self.llm)
|
193 |
parsed = chain.run(response)
|
194 |
+
return parsed, response, coordinates
|
195 |
except Exception as oe:
|
196 |
print("Failing to parse the response", oe)
|
197 |
+
return None, response, coordinates
|
198 |
else:
|
199 |
+
return None, response, coordinates
|
200 |
|
201 |
def query_storage(self, query: str, doc_id, context_size=4):
|
202 |
documents = self._get_context(doc_id, query, context_size)
|
|
|
227 |
|
228 |
def _run_query(self, doc_id, query, context_size=4):
|
229 |
relevant_documents = self._get_context(doc_id, query, context_size)
|
230 |
+
relevant_document_coordinates = [doc.metadata['coordinates'].split(";") if 'coordinates' in doc.metadata else []
|
231 |
+
for doc in
|
232 |
+
relevant_documents] # filter(lambda d: d['type'] == "sentence", relevant_documents)]
|
233 |
response = self.chain.run(input_documents=relevant_documents,
|
234 |
question=query)
|
235 |
|
236 |
if self.memory:
|
237 |
self.memory.save_context({"input": query}, {"output": response})
|
238 |
+
return response, relevant_document_coordinates
|
239 |
|
240 |
def _get_context(self, doc_id, query, context_size=4):
|
241 |
db = self.embeddings_dict[doc_id]
|
|
|
261 |
relevant_documents = multi_query_retriever.get_relevant_documents(query)
|
262 |
return relevant_documents
|
263 |
|
264 |
+
def get_text_from_document(self, pdf_file_path, chunk_size=-1, perc_overlap=0.1, verbose=False):
|
265 |
"""
|
266 |
Extract text from documents using Grobid, if chunk_size is < 0 it keeps each paragraph separately
|
267 |
"""
|
268 |
if verbose:
|
269 |
print("File", pdf_file_path)
|
270 |
filename = Path(pdf_file_path).stem
|
271 |
+
coordinates = True # if chunk_size == -1 else False
|
272 |
+
structure = self.grobid_processor.process_structure(pdf_file_path, coordinates=coordinates)
|
273 |
|
274 |
biblio = structure['biblio']
|
275 |
biblio['filename'] = filename.replace(" ", "_")
|
|
|
281 |
metadatas = []
|
282 |
ids = []
|
283 |
|
284 |
+
if chunk_size > 0:
|
285 |
+
new_passages = self.text_merger.merge_passages(structure['passages'], chunk_size=chunk_size)
|
286 |
+
else:
|
287 |
+
new_passages = structure['passages']
|
|
|
288 |
|
289 |
+
for passage in new_passages:
|
290 |
+
biblio_copy = copy.copy(biblio)
|
291 |
+
if len(str.strip(passage['text'])) > 0:
|
292 |
+
texts.append(passage['text'])
|
293 |
|
294 |
+
biblio_copy['type'] = passage['type']
|
295 |
+
biblio_copy['section'] = passage['section']
|
296 |
+
biblio_copy['subSection'] = passage['subSection']
|
297 |
+
biblio_copy['coordinates'] = passage['coordinates']
|
298 |
+
metadatas.append(biblio_copy)
|
299 |
+
|
300 |
+
# ids.append(passage['passage_id'])
|
301 |
+
|
302 |
+
ids = [id for id, t in enumerate(new_passages)]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
303 |
|
304 |
return texts, metadatas, ids
|
305 |
|
306 |
+
def create_memory_embeddings(self, pdf_path, doc_id=None, chunk_size=500, perc_overlap=0.1):
|
|
|
307 |
texts, metadata, ids = self.get_text_from_document(
|
308 |
pdf_path,
|
309 |
chunk_size=chunk_size,
|
310 |
+
perc_overlap=perc_overlap)
|
|
|
311 |
if doc_id:
|
312 |
hash = doc_id
|
313 |
else:
|
document_qa/grobid_processors.py
CHANGED
@@ -131,13 +131,13 @@ class GrobidProcessor(BaseProcessor):
|
|
131 |
# super().__init__()
|
132 |
self.grobid_client = grobid_client
|
133 |
|
134 |
-
def process_structure(self, input_path):
|
135 |
pdf_file, status, text = self.grobid_client.process_pdf("processFulltextDocument",
|
136 |
input_path,
|
137 |
consolidate_header=True,
|
138 |
consolidate_citations=False,
|
139 |
segment_sentences=False,
|
140 |
-
tei_coordinates=
|
141 |
include_raw_citations=False,
|
142 |
include_raw_affiliations=False,
|
143 |
generateIDs=True)
|
@@ -145,7 +145,7 @@ class GrobidProcessor(BaseProcessor):
|
|
145 |
if status != 200:
|
146 |
return
|
147 |
|
148 |
-
output_data = self.parse_grobid_xml(text)
|
149 |
output_data['filename'] = Path(pdf_file).stem.replace(".tei", "")
|
150 |
|
151 |
return output_data
|
@@ -159,7 +159,7 @@ class GrobidProcessor(BaseProcessor):
|
|
159 |
|
160 |
return doc
|
161 |
|
162 |
-
def parse_grobid_xml(self, text):
|
163 |
output_data = OrderedDict()
|
164 |
|
165 |
doc_biblio = grobid_tei_xml.parse_document_xml(text)
|
@@ -176,61 +176,115 @@ class GrobidProcessor(BaseProcessor):
|
|
176 |
pass
|
177 |
|
178 |
output_data['biblio'] = biblio
|
179 |
-
|
180 |
passages = []
|
181 |
output_data['passages'] = passages
|
182 |
-
|
183 |
-
# passages.append({
|
184 |
-
# "text": self.post_process(biblio['title']),
|
185 |
-
# "type": "paragraph",
|
186 |
-
# "section": "<header>",
|
187 |
-
# "subSection": "<title>",
|
188 |
-
# "passage_id": "title0"
|
189 |
-
# })
|
190 |
-
|
191 |
-
if doc_biblio.abstract is not None and len(doc_biblio.abstract) > 0:
|
192 |
-
passages.append({
|
193 |
-
"text": self.post_process(doc_biblio.abstract),
|
194 |
-
"type": "paragraph",
|
195 |
-
"section": "<header>",
|
196 |
-
"subSection": "<abstract>",
|
197 |
-
"passage_id": "abstract0"
|
198 |
-
})
|
199 |
|
200 |
soup = BeautifulSoup(text, 'xml')
|
201 |
-
|
202 |
-
|
203 |
-
passages.
|
204 |
-
{
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
for
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
234 |
|
235 |
return output_data
|
236 |
|
@@ -526,6 +580,21 @@ class GrobidAggregationProcessor(GrobidProcessor, GrobidQuantitiesProcessor, Gro
|
|
526 |
def extract_materials(self, text):
|
527 |
return self.gmp.extract_materials(text)
|
528 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
529 |
@staticmethod
|
530 |
def prune_overlapping_annotations(entities: list) -> list:
|
531 |
# Sorting by offsets
|
@@ -731,25 +800,40 @@ def get_children_list_grobid(soup: object, use_paragraphs: object = True, verbos
|
|
731 |
return children
|
732 |
|
733 |
|
734 |
-
def
|
735 |
-
|
736 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
737 |
for child in soup.TEI.children:
|
738 |
if child.name == 'text':
|
739 |
-
|
|
|
|
|
740 |
|
741 |
if verbose:
|
742 |
-
print(str(
|
743 |
|
744 |
-
return
|
745 |
|
746 |
|
747 |
-
def
|
748 |
children = []
|
749 |
-
child_name = "p" if use_paragraphs else "s"
|
750 |
for child in soup.TEI.children:
|
751 |
if child.name == 'text':
|
752 |
-
children.extend(
|
|
|
753 |
|
754 |
if verbose:
|
755 |
print(str(children))
|
|
|
131 |
# super().__init__()
|
132 |
self.grobid_client = grobid_client
|
133 |
|
134 |
+
def process_structure(self, input_path, coordinates=False):
|
135 |
pdf_file, status, text = self.grobid_client.process_pdf("processFulltextDocument",
|
136 |
input_path,
|
137 |
consolidate_header=True,
|
138 |
consolidate_citations=False,
|
139 |
segment_sentences=False,
|
140 |
+
tei_coordinates=coordinates,
|
141 |
include_raw_citations=False,
|
142 |
include_raw_affiliations=False,
|
143 |
generateIDs=True)
|
|
|
145 |
if status != 200:
|
146 |
return
|
147 |
|
148 |
+
output_data = self.parse_grobid_xml(text, coordinates=coordinates)
|
149 |
output_data['filename'] = Path(pdf_file).stem.replace(".tei", "")
|
150 |
|
151 |
return output_data
|
|
|
159 |
|
160 |
return doc
|
161 |
|
162 |
+
def parse_grobid_xml(self, text, coordinates=False):
|
163 |
output_data = OrderedDict()
|
164 |
|
165 |
doc_biblio = grobid_tei_xml.parse_document_xml(text)
|
|
|
176 |
pass
|
177 |
|
178 |
output_data['biblio'] = biblio
|
|
|
179 |
passages = []
|
180 |
output_data['passages'] = passages
|
181 |
+
passage_type = "paragraph"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
182 |
|
183 |
soup = BeautifulSoup(text, 'xml')
|
184 |
+
blocks_header = get_xml_nodes_header(soup, use_paragraphs=True)
|
185 |
+
|
186 |
+
passages.append({
|
187 |
+
"text": f"authors: {biblio['authors']}",
|
188 |
+
"type": passage_type,
|
189 |
+
"section": "<header>",
|
190 |
+
"subSection": "<title>",
|
191 |
+
"passage_id": "htitle",
|
192 |
+
"coordinates": ";".join([node['coords'] if coordinates and node.has_attr('coords') else "" for node in
|
193 |
+
blocks_header['authors']])
|
194 |
+
})
|
195 |
+
|
196 |
+
passages.append({
|
197 |
+
"text": self.post_process(" ".join([node.text for node in blocks_header['title']])),
|
198 |
+
"type": passage_type,
|
199 |
+
"section": "<header>",
|
200 |
+
"subSection": "<title>",
|
201 |
+
"passage_id": "htitle",
|
202 |
+
"coordinates": ";".join([node['coords'] if coordinates and node.has_attr('coords') else "" for node in
|
203 |
+
blocks_header['title']])
|
204 |
+
})
|
205 |
+
|
206 |
+
passages.append({
|
207 |
+
"text": self.post_process(
|
208 |
+
''.join(node.text for node in blocks_header['abstract'] for text in node.find_all(text=True) if
|
209 |
+
text.parent.name != "ref" or (
|
210 |
+
text.parent.name == "ref" and text.parent.attrs[
|
211 |
+
'type'] != 'bibr'))),
|
212 |
+
"type": passage_type,
|
213 |
+
"section": "<header>",
|
214 |
+
"subSection": "<abstract>",
|
215 |
+
"passage_id": "habstract",
|
216 |
+
"coordinates": ";".join([node['coords'] if coordinates and node.has_attr('coords') else "" for node in
|
217 |
+
blocks_header['abstract']])
|
218 |
+
})
|
219 |
+
|
220 |
+
text_blocks_body = get_xml_nodes_body(soup, verbose=False, use_paragraphs=True)
|
221 |
+
|
222 |
+
use_paragraphs = True
|
223 |
+
if not use_paragraphs:
|
224 |
+
passages.extend([
|
225 |
+
{
|
226 |
+
"text": self.post_process(''.join(text for text in sentence.find_all(text=True) if
|
227 |
+
text.parent.name != "ref" or (
|
228 |
+
text.parent.name == "ref" and text.parent.attrs[
|
229 |
+
'type'] != 'bibr'))),
|
230 |
+
"type": passage_type,
|
231 |
+
"section": "<body>",
|
232 |
+
"subSection": "<paragraph>",
|
233 |
+
"passage_id": str(paragraph_id),
|
234 |
+
"coordinates": paragraph['coords'] if coordinates and sentence.has_attr('coords') else ""
|
235 |
+
}
|
236 |
+
for paragraph_id, paragraph in enumerate(text_blocks_body) for
|
237 |
+
sentence_id, sentence in enumerate(paragraph)
|
238 |
+
])
|
239 |
+
else:
|
240 |
+
passages.extend([
|
241 |
+
{
|
242 |
+
"text": self.post_process(''.join(text for text in paragraph.find_all(text=True) if
|
243 |
+
text.parent.name != "ref" or (
|
244 |
+
text.parent.name == "ref" and text.parent.attrs[
|
245 |
+
'type'] != 'bibr'))),
|
246 |
+
"type": passage_type,
|
247 |
+
"section": "<body>",
|
248 |
+
"subSection": "<paragraph>",
|
249 |
+
"passage_id": str(paragraph_id),
|
250 |
+
"coordinates": paragraph['coords'] if coordinates and paragraph.has_attr('coords') else ""
|
251 |
+
}
|
252 |
+
for paragraph_id, paragraph in enumerate(text_blocks_body)
|
253 |
+
])
|
254 |
+
|
255 |
+
text_blocks_figures = get_xml_nodes_figures(soup, verbose=False)
|
256 |
+
|
257 |
+
if not use_paragraphs:
|
258 |
+
passages.extend([
|
259 |
+
{
|
260 |
+
"text": self.post_process(''.join(text for text in sentence.find_all(text=True) if
|
261 |
+
text.parent.name != "ref" or (
|
262 |
+
text.parent.name == "ref" and text.parent.attrs[
|
263 |
+
'type'] != 'bibr'))),
|
264 |
+
"type": passage_type,
|
265 |
+
"section": "<body>",
|
266 |
+
"subSection": "<figure>",
|
267 |
+
"passage_id": str(paragraph_id) + str(sentence_id),
|
268 |
+
"coordinates": sentence['coords'] if coordinates and 'coords' in sentence else ""
|
269 |
+
}
|
270 |
+
for paragraph_id, paragraph in enumerate(text_blocks_figures) for
|
271 |
+
sentence_id, sentence in enumerate(paragraph)
|
272 |
+
])
|
273 |
+
else:
|
274 |
+
passages.extend([
|
275 |
+
{
|
276 |
+
"text": self.post_process(''.join(text for text in paragraph.find_all(text=True) if
|
277 |
+
text.parent.name != "ref" or (
|
278 |
+
text.parent.name == "ref" and text.parent.attrs[
|
279 |
+
'type'] != 'bibr'))),
|
280 |
+
"type": passage_type,
|
281 |
+
"section": "<body>",
|
282 |
+
"subSection": "<figure>",
|
283 |
+
"passage_id": str(paragraph_id),
|
284 |
+
"coordinates": paragraph['coords'] if coordinates and paragraph.has_attr('coords') else ""
|
285 |
+
}
|
286 |
+
for paragraph_id, paragraph in enumerate(text_blocks_figures)
|
287 |
+
])
|
288 |
|
289 |
return output_data
|
290 |
|
|
|
580 |
def extract_materials(self, text):
|
581 |
return self.gmp.extract_materials(text)
|
582 |
|
583 |
+
@staticmethod
|
584 |
+
def box_to_dict(box, color=None, type=None):
|
585 |
+
|
586 |
+
if box is None or box == "" or len(box) < 5:
|
587 |
+
return {}
|
588 |
+
|
589 |
+
item = {"page": box[0], "x": box[1], "y": box[2], "width": box[3], "height": box[4]}
|
590 |
+
if color is not None:
|
591 |
+
item['color'] = color
|
592 |
+
|
593 |
+
if type:
|
594 |
+
item['type'] = type
|
595 |
+
|
596 |
+
return item
|
597 |
+
|
598 |
@staticmethod
|
599 |
def prune_overlapping_annotations(entities: list) -> list:
|
600 |
# Sorting by offsets
|
|
|
800 |
return children
|
801 |
|
802 |
|
803 |
+
def get_xml_nodes_header(soup: object, use_paragraphs: bool = True) -> list:
|
804 |
+
sub_tag = "p" if use_paragraphs else "s"
|
805 |
+
|
806 |
+
header_elements = {
|
807 |
+
"authors": [persNameNode for persNameNode in soup.teiHeader.find_all("persName")],
|
808 |
+
"abstract": [p_in_abstract for abstractNodes in soup.teiHeader.find_all("abstract") for p_in_abstract in
|
809 |
+
abstractNodes.find_all(sub_tag)],
|
810 |
+
"title": [soup.teiHeader.fileDesc.title]
|
811 |
+
}
|
812 |
+
|
813 |
+
return header_elements
|
814 |
+
|
815 |
+
|
816 |
+
def get_xml_nodes_body(soup: object, use_paragraphs: bool = True, verbose: bool = False) -> list:
|
817 |
+
nodes = []
|
818 |
+
tag_name = "p" if use_paragraphs else "s"
|
819 |
for child in soup.TEI.children:
|
820 |
if child.name == 'text':
|
821 |
+
# nodes.extend([subchild.find_all(tag_name) for subchild in child.find_all("body")])
|
822 |
+
nodes.extend(
|
823 |
+
[subsubchild for subchild in child.find_all("body") for subsubchild in subchild.find_all(tag_name)])
|
824 |
|
825 |
if verbose:
|
826 |
+
print(str(nodes))
|
827 |
|
828 |
+
return nodes
|
829 |
|
830 |
|
831 |
+
def get_xml_nodes_figures(soup: object, verbose: bool = False) -> list:
|
832 |
children = []
|
|
|
833 |
for child in soup.TEI.children:
|
834 |
if child.name == 'text':
|
835 |
+
children.extend(
|
836 |
+
[subchild for subchilds in child.find_all("body") for subchild in subchilds.find_all("figDesc")])
|
837 |
|
838 |
if verbose:
|
839 |
print(str(children))
|
requirements.txt
CHANGED
@@ -19,7 +19,9 @@ chromadb==0.4.19
|
|
19 |
tiktoken==0.4.0
|
20 |
openai==0.27.7
|
21 |
langchain==0.0.350
|
|
|
22 |
typing-inspect==0.9.0
|
23 |
typing_extensions==4.8.0
|
24 |
pydantic==2.4.2
|
25 |
-
sentence_transformers==2.2.2
|
|
|
|
19 |
tiktoken==0.4.0
|
20 |
openai==0.27.7
|
21 |
langchain==0.0.350
|
22 |
+
langchain-core==0.1.0
|
23 |
typing-inspect==0.9.0
|
24 |
typing_extensions==4.8.0
|
25 |
pydantic==2.4.2
|
26 |
+
sentence_transformers==2.2.2
|
27 |
+
streamlit-pdf-viewer
|
streamlit_app.py
CHANGED
@@ -1,4 +1,3 @@
|
|
1 |
-
import base64
|
2 |
import os
|
3 |
import re
|
4 |
from hashlib import blake2b
|
@@ -8,6 +7,7 @@ import dotenv
|
|
8 |
from grobid_quantities.quantities import QuantitiesAPI
|
9 |
from langchain.llms.huggingface_hub import HuggingFaceHub
|
10 |
from langchain.memory import ConversationBufferWindowMemory
|
|
|
11 |
|
12 |
dotenv.load_dotenv(override=True)
|
13 |
|
@@ -70,6 +70,18 @@ if 'memory' not in st.session_state:
|
|
70 |
if 'binary' not in st.session_state:
|
71 |
st.session_state['binary'] = None
|
72 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
st.set_page_config(
|
74 |
page_title="Scientific Document Insights Q/A",
|
75 |
page_icon="📝",
|
@@ -216,7 +228,9 @@ with st.sidebar:
|
|
216 |
st.session_state['model'] = model = st.selectbox(
|
217 |
"Model:",
|
218 |
options=OPENAI_MODELS + list(OPEN_MODELS.keys()),
|
219 |
-
index=
|
|
|
|
|
220 |
placeholder="Select model",
|
221 |
help="Select the LLM model:",
|
222 |
disabled=st.session_state['doc_id'] is not None or st.session_state['uploaded']
|
@@ -291,21 +305,44 @@ question = st.chat_input(
|
|
291 |
|
292 |
with st.sidebar:
|
293 |
st.header("Settings")
|
294 |
-
mode = st.radio(
|
295 |
-
|
296 |
-
|
297 |
-
|
298 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
299 |
disabled=uploaded_file is not None)
|
300 |
-
|
301 |
-
|
302 |
-
|
|
|
|
|
|
|
|
|
|
|
303 |
|
304 |
st.session_state['ner_processing'] = st.checkbox("Identify materials and properties.")
|
305 |
st.markdown(
|
306 |
'The LLM responses undergo post-processing to extract <span style="color:orange">physical quantities, measurements</span>, and <span style="color:green">materials</span> mentions.',
|
307 |
unsafe_allow_html=True)
|
308 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
309 |
st.divider()
|
310 |
|
311 |
st.header("Documentation")
|
@@ -324,13 +361,6 @@ with st.sidebar:
|
|
324 |
st.markdown(
|
325 |
"""If you switch the mode to "Embedding," the system will return specific chunks from the document that are semantically related to your query. This mode helps to test why sometimes the answers are not satisfying or incomplete. """)
|
326 |
|
327 |
-
|
328 |
-
@st.cache_resource
|
329 |
-
def get_pdf_display(binary):
|
330 |
-
base64_pdf = base64.b64encode(binary).decode('utf-8')
|
331 |
-
return F'<embed src="data:application/pdf;base64,{base64_pdf}" width="100%" height="700" type="application/pdf"></embed>'
|
332 |
-
|
333 |
-
|
334 |
if uploaded_file and not st.session_state.loaded_embeddings:
|
335 |
if model not in st.session_state['api_keys']:
|
336 |
st.error("Before uploading a document, you must enter the API key. ")
|
@@ -345,16 +375,31 @@ if uploaded_file and not st.session_state.loaded_embeddings:
|
|
345 |
|
346 |
st.session_state['doc_id'] = hash = st.session_state['rqa'][model].create_memory_embeddings(tmp_file.name,
|
347 |
chunk_size=chunk_size,
|
348 |
-
perc_overlap=0.1
|
349 |
-
include_biblio=True)
|
350 |
st.session_state['loaded_embeddings'] = True
|
351 |
st.session_state.messages = []
|
352 |
|
353 |
# timestamp = datetime.utcnow()
|
354 |
|
355 |
-
|
356 |
-
|
357 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
358 |
|
359 |
with right_column:
|
360 |
# css = '''
|
@@ -398,8 +443,18 @@ with right_column:
|
|
398 |
context_size=context_size)
|
399 |
elif mode == "LLM":
|
400 |
with st.spinner("Generating response..."):
|
401 |
-
_, text_response = st.session_state['rqa'][model].query_document(question,
|
402 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
403 |
|
404 |
if not text_response:
|
405 |
st.error("Something went wrong. Contact Luca Foppiano ([email protected]) to report the issue.")
|
@@ -418,11 +473,16 @@ with right_column:
|
|
418 |
st.write(text_response)
|
419 |
st.session_state.messages.append({"role": "assistant", "mode": mode, "content": text_response})
|
420 |
|
421 |
-
# if len(st.session_state.messages) > 1:
|
422 |
-
# last_answer = st.session_state.messages[len(st.session_state.messages)-1]
|
423 |
-
# if last_answer['role'] == "assistant":
|
424 |
-
# last_question = st.session_state.messages[len(st.session_state.messages)-2]
|
425 |
-
# st.session_state.memory.save_context({"input": last_question['content']}, {"output": last_answer['content']})
|
426 |
-
|
427 |
elif st.session_state.loaded_embeddings and st.session_state.doc_id:
|
428 |
play_old_messages()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import os
|
2 |
import re
|
3 |
from hashlib import blake2b
|
|
|
7 |
from grobid_quantities.quantities import QuantitiesAPI
|
8 |
from langchain.llms.huggingface_hub import HuggingFaceHub
|
9 |
from langchain.memory import ConversationBufferWindowMemory
|
10 |
+
from streamlit_pdf_viewer import pdf_viewer
|
11 |
|
12 |
dotenv.load_dotenv(override=True)
|
13 |
|
|
|
70 |
if 'binary' not in st.session_state:
|
71 |
st.session_state['binary'] = None
|
72 |
|
73 |
+
if 'annotations' not in st.session_state:
|
74 |
+
st.session_state['annotations'] = None
|
75 |
+
|
76 |
+
if 'should_show_annotations' not in st.session_state:
|
77 |
+
st.session_state['should_show_annotations'] = True
|
78 |
+
|
79 |
+
if 'pdf' not in st.session_state:
|
80 |
+
st.session_state['pdf'] = None
|
81 |
+
|
82 |
+
if 'pdf_rendering' not in st.session_state:
|
83 |
+
st.session_state['pdf_rendering'] = None
|
84 |
+
|
85 |
st.set_page_config(
|
86 |
page_title="Scientific Document Insights Q/A",
|
87 |
page_icon="📝",
|
|
|
228 |
st.session_state['model'] = model = st.selectbox(
|
229 |
"Model:",
|
230 |
options=OPENAI_MODELS + list(OPEN_MODELS.keys()),
|
231 |
+
index=(OPENAI_MODELS + list(OPEN_MODELS.keys())).index(
|
232 |
+
"zephyr-7b-beta") if "DEFAULT_MODEL" not in os.environ or not os.environ["DEFAULT_MODEL"] else (
|
233 |
+
OPENAI_MODELS + list(OPEN_MODELS.keys())).index(os.environ["DEFAULT_MODEL"]),
|
234 |
placeholder="Select model",
|
235 |
help="Select the LLM model:",
|
236 |
disabled=st.session_state['doc_id'] is not None or st.session_state['uploaded']
|
|
|
305 |
|
306 |
with st.sidebar:
|
307 |
st.header("Settings")
|
308 |
+
mode = st.radio(
|
309 |
+
"Query mode",
|
310 |
+
("LLM", "Embeddings"),
|
311 |
+
disabled=not uploaded_file,
|
312 |
+
index=0,
|
313 |
+
horizontal=True,
|
314 |
+
help="LLM will respond the question, Embedding will show the "
|
315 |
+
"paragraphs relevant to the question in the paper."
|
316 |
+
)
|
317 |
+
|
318 |
+
# Add a checkbox for showing annotations
|
319 |
+
# st.session_state['show_annotations'] = st.checkbox("Show annotations", value=True)
|
320 |
+
# st.session_state['should_show_annotations'] = st.checkbox("Show annotations", value=True)
|
321 |
+
|
322 |
+
chunk_size = st.slider("Text chunks size", -1, 2000, value=-1,
|
323 |
+
help="Size of chunks in which split the document. -1: use paragraphs, > 0 paragraphs are aggregated.",
|
324 |
disabled=uploaded_file is not None)
|
325 |
+
if chunk_size == -1:
|
326 |
+
context_size = st.slider("Context size (paragraphs)", 3, 20, value=10,
|
327 |
+
help="Number of paragraphs to consider when answering a question",
|
328 |
+
disabled=not uploaded_file)
|
329 |
+
else:
|
330 |
+
context_size = st.slider("Context size (chunks)", 3, 10, value=4,
|
331 |
+
help="Number of chunks to consider when answering a question",
|
332 |
+
disabled=not uploaded_file)
|
333 |
|
334 |
st.session_state['ner_processing'] = st.checkbox("Identify materials and properties.")
|
335 |
st.markdown(
|
336 |
'The LLM responses undergo post-processing to extract <span style="color:orange">physical quantities, measurements</span>, and <span style="color:green">materials</span> mentions.',
|
337 |
unsafe_allow_html=True)
|
338 |
|
339 |
+
st.session_state['pdf_rendering'] = st.radio(
|
340 |
+
"PDF rendering mode",
|
341 |
+
{"PDF.JS", "Native browser engine"},
|
342 |
+
index=1,
|
343 |
+
disabled=not uploaded_file,
|
344 |
+
)
|
345 |
+
|
346 |
st.divider()
|
347 |
|
348 |
st.header("Documentation")
|
|
|
361 |
st.markdown(
|
362 |
"""If you switch the mode to "Embedding," the system will return specific chunks from the document that are semantically related to your query. This mode helps to test why sometimes the answers are not satisfying or incomplete. """)
|
363 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
364 |
if uploaded_file and not st.session_state.loaded_embeddings:
|
365 |
if model not in st.session_state['api_keys']:
|
366 |
st.error("Before uploading a document, you must enter the API key. ")
|
|
|
375 |
|
376 |
st.session_state['doc_id'] = hash = st.session_state['rqa'][model].create_memory_embeddings(tmp_file.name,
|
377 |
chunk_size=chunk_size,
|
378 |
+
perc_overlap=0.1)
|
|
|
379 |
st.session_state['loaded_embeddings'] = True
|
380 |
st.session_state.messages = []
|
381 |
|
382 |
# timestamp = datetime.utcnow()
|
383 |
|
384 |
+
|
385 |
+
def rgb_to_hex(rgb):
|
386 |
+
return "#{:02x}{:02x}{:02x}".format(*rgb)
|
387 |
+
|
388 |
+
|
389 |
+
def generate_color_gradient(num_elements):
|
390 |
+
# Define warm and cold colors in RGB format
|
391 |
+
warm_color = (255, 165, 0) # Orange
|
392 |
+
cold_color = (0, 0, 255) # Blue
|
393 |
+
|
394 |
+
# Generate a linear gradient of colors
|
395 |
+
color_gradient = [
|
396 |
+
rgb_to_hex(tuple(int(warm * (1 - i / num_elements) + cold * (i / num_elements)) for warm, cold in
|
397 |
+
zip(warm_color, cold_color)))
|
398 |
+
for i in range(num_elements)
|
399 |
+
]
|
400 |
+
|
401 |
+
return color_gradient
|
402 |
+
|
403 |
|
404 |
with right_column:
|
405 |
# css = '''
|
|
|
443 |
context_size=context_size)
|
444 |
elif mode == "LLM":
|
445 |
with st.spinner("Generating response..."):
|
446 |
+
_, text_response, coordinates = st.session_state['rqa'][model].query_document(question,
|
447 |
+
st.session_state.doc_id,
|
448 |
+
context_size=context_size)
|
449 |
+
|
450 |
+
annotations = [[GrobidAggregationProcessor.box_to_dict([cs for cs in c.split(",")]) for c in coord_doc]
|
451 |
+
for coord_doc in coordinates]
|
452 |
+
gradients = generate_color_gradient(len(annotations))
|
453 |
+
for i, color in enumerate(gradients):
|
454 |
+
for annotation in annotations[i]:
|
455 |
+
annotation['color'] = color
|
456 |
+
st.session_state['annotations'] = [annotation for annotation_doc in annotations for annotation in
|
457 |
+
annotation_doc]
|
458 |
|
459 |
if not text_response:
|
460 |
st.error("Something went wrong. Contact Luca Foppiano ([email protected]) to report the issue.")
|
|
|
473 |
st.write(text_response)
|
474 |
st.session_state.messages.append({"role": "assistant", "mode": mode, "content": text_response})
|
475 |
|
|
|
|
|
|
|
|
|
|
|
|
|
476 |
elif st.session_state.loaded_embeddings and st.session_state.doc_id:
|
477 |
play_old_messages()
|
478 |
+
|
479 |
+
with left_column:
|
480 |
+
if st.session_state['binary']:
|
481 |
+
pdf_viewer(
|
482 |
+
input=st.session_state['binary'],
|
483 |
+
width=600,
|
484 |
+
height=800,
|
485 |
+
annotation_outline_size=2,
|
486 |
+
annotations=st.session_state['annotations'],
|
487 |
+
rendering='unwrap' if st.session_state['pdf_rendering'] == 'PDF.JS' else 'legacy_embed'
|
488 |
+
)
|
tests/__init__.py
ADDED
File without changes
|
tests/conftest.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from pathlib import Path
|
3 |
+
from unittest.mock import MagicMock
|
4 |
+
|
5 |
+
import pytest
|
6 |
+
from _pytest._py.path import LocalPath
|
7 |
+
|
8 |
+
# derived from https://github.com/elifesciences/sciencebeam-trainer-delft/tree/develop/tests
|
9 |
+
|
10 |
+
LOGGER = logging.getLogger(__name__)
|
11 |
+
|
12 |
+
|
13 |
+
@pytest.fixture(scope='session', autouse=True)
|
14 |
+
def setup_logging():
|
15 |
+
logging.root.handlers = []
|
16 |
+
logging.basicConfig(level='INFO')
|
17 |
+
logging.getLogger('tests').setLevel('DEBUG')
|
18 |
+
# logging.getLogger('sciencebeam_trainer_delft').setLevel('DEBUG')
|
19 |
+
|
20 |
+
|
21 |
+
def _backport_assert_called(mock: MagicMock):
|
22 |
+
assert mock.called
|
23 |
+
|
24 |
+
|
25 |
+
@pytest.fixture(scope='session', autouse=True)
|
26 |
+
def patch_magicmock():
|
27 |
+
try:
|
28 |
+
MagicMock.assert_called
|
29 |
+
except AttributeError:
|
30 |
+
MagicMock.assert_called = _backport_assert_called
|
31 |
+
|
32 |
+
|
33 |
+
@pytest.fixture
|
34 |
+
def temp_dir(tmpdir: LocalPath):
|
35 |
+
# convert to standard Path
|
36 |
+
return Path(str(tmpdir))
|
37 |
+
|
tests/resources/2312.07559.paragraphs.tei.xml
ADDED
The diff for this file is too large to render.
See raw diff
|
|
tests/resources/2312.07559.sentences.tei.xml
ADDED
The diff for this file is too large to render.
See raw diff
|
|
tests/test_document_qa_engine.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from document_qa.document_qa_engine import TextMerger
|
2 |
+
|
3 |
+
|
4 |
+
def test_merge_passages_small_chunk():
|
5 |
+
merger = TextMerger()
|
6 |
+
|
7 |
+
passages = [
|
8 |
+
{
|
9 |
+
'text': "The quick brown fox jumps over the tree",
|
10 |
+
'coordinates': '1'
|
11 |
+
},
|
12 |
+
{
|
13 |
+
'text': "and went straight into the mouth of a bear.",
|
14 |
+
'coordinates': '2'
|
15 |
+
},
|
16 |
+
{
|
17 |
+
'text': "The color of the colors is a color with colors",
|
18 |
+
'coordinates': '3'
|
19 |
+
},
|
20 |
+
{
|
21 |
+
'text': "the main colors are not the colorw we show",
|
22 |
+
'coordinates': '4'
|
23 |
+
}
|
24 |
+
]
|
25 |
+
new_passages = merger.merge_passages(passages, chunk_size=10, tolerance=0)
|
26 |
+
|
27 |
+
assert len(new_passages) == 4
|
28 |
+
assert new_passages[0]['coordinates'] == "1"
|
29 |
+
assert new_passages[0]['text'] == "The quick brown fox jumps over the tree"
|
30 |
+
|
31 |
+
assert new_passages[1]['coordinates'] == "2"
|
32 |
+
assert new_passages[1]['text'] == "and went straight into the mouth of a bear."
|
33 |
+
|
34 |
+
assert new_passages[2]['coordinates'] == "3"
|
35 |
+
assert new_passages[2]['text'] == "The color of the colors is a color with colors"
|
36 |
+
|
37 |
+
assert new_passages[3]['coordinates'] == "4"
|
38 |
+
assert new_passages[3]['text'] == "the main colors are not the colorw we show"
|
39 |
+
|
40 |
+
|
41 |
+
def test_merge_passages_big_chunk():
|
42 |
+
merger = TextMerger()
|
43 |
+
|
44 |
+
passages = [
|
45 |
+
{
|
46 |
+
'text': "The quick brown fox jumps over the tree",
|
47 |
+
'coordinates': '1'
|
48 |
+
},
|
49 |
+
{
|
50 |
+
'text': "and went straight into the mouth of a bear.",
|
51 |
+
'coordinates': '2'
|
52 |
+
},
|
53 |
+
{
|
54 |
+
'text': "The color of the colors is a color with colors",
|
55 |
+
'coordinates': '3'
|
56 |
+
},
|
57 |
+
{
|
58 |
+
'text': "the main colors are not the colorw we show",
|
59 |
+
'coordinates': '4'
|
60 |
+
}
|
61 |
+
]
|
62 |
+
new_passages = merger.merge_passages(passages, chunk_size=20, tolerance=0)
|
63 |
+
|
64 |
+
assert len(new_passages) == 2
|
65 |
+
assert new_passages[0]['coordinates'] == "1;2"
|
66 |
+
assert new_passages[0][
|
67 |
+
'text'] == "The quick brown fox jumps over the tree and went straight into the mouth of a bear."
|
68 |
+
|
69 |
+
assert new_passages[1]['coordinates'] == "3;4"
|
70 |
+
assert new_passages[1][
|
71 |
+
'text'] == "The color of the colors is a color with colors the main colors are not the colorw we show"
|
tests/test_grobid_processors.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from bs4 import BeautifulSoup
|
2 |
+
from document_qa.grobid_processors import get_xml_nodes_body, get_xml_nodes_figures, get_xml_nodes_header
|
3 |
+
|
4 |
+
|
5 |
+
def test_get_xml_nodes_body_paragraphs():
|
6 |
+
with open("resources/2312.07559.paragraphs.tei.xml", 'r') as fo:
|
7 |
+
soup = BeautifulSoup(fo, 'xml')
|
8 |
+
|
9 |
+
nodes = get_xml_nodes_body(soup, use_paragraphs=True)
|
10 |
+
|
11 |
+
assert len(nodes) == 70
|
12 |
+
|
13 |
+
|
14 |
+
def test_get_xml_nodes_body_sentences():
|
15 |
+
with open("resources/2312.07559.sentences.tei.xml", 'r') as fo:
|
16 |
+
soup = BeautifulSoup(fo, 'xml')
|
17 |
+
|
18 |
+
children = get_xml_nodes_body(soup, use_paragraphs=False)
|
19 |
+
|
20 |
+
assert len(children) == 327
|
21 |
+
|
22 |
+
|
23 |
+
def test_get_xml_nodes_figures():
|
24 |
+
with open("resources/2312.07559.paragraphs.tei.xml", 'r') as fo:
|
25 |
+
soup = BeautifulSoup(fo, 'xml')
|
26 |
+
|
27 |
+
children = get_xml_nodes_figures(soup)
|
28 |
+
|
29 |
+
assert len(children) == 13
|
30 |
+
|
31 |
+
|
32 |
+
def test_get_xml_nodes_header_paragraphs():
|
33 |
+
with open("resources/2312.07559.paragraphs.tei.xml", 'r') as fo:
|
34 |
+
soup = BeautifulSoup(fo, 'xml')
|
35 |
+
|
36 |
+
children = get_xml_nodes_header(soup)
|
37 |
+
|
38 |
+
assert len(children) == 8
|
39 |
+
|
40 |
+
def test_get_xml_nodes_header_sentences():
|
41 |
+
with open("resources/2312.07559.sentences.tei.xml", 'r') as fo:
|
42 |
+
soup = BeautifulSoup(fo, 'xml')
|
43 |
+
|
44 |
+
children = get_xml_nodes_header(soup, use_paragraphs=False)
|
45 |
+
|
46 |
+
assert len(children) == 15
|