Luca Foppiano commited on
Commit
b042214
2 Parent(s): 8d140dd a7d9efc

Merge pull request #26 from lfoppiano/add-pdf-viewer

Browse files
document_qa/document_qa_engine.py CHANGED
@@ -3,18 +3,87 @@ import os
3
  from pathlib import Path
4
  from typing import Union, Any
5
 
6
- from document_qa.grobid_processors import GrobidProcessor
7
  from grobid_client.grobid_client import GrobidClient
8
- from langchain.chains import create_extraction_chain, ConversationChain, ConversationalRetrievalChain
9
  from langchain.chains.question_answering import load_qa_chain, stuff_prompt, refine_prompts, map_reduce_prompt, \
10
  map_rerank_prompt
11
  from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
12
  from langchain.retrievers import MultiQueryRetriever
13
  from langchain.schema import Document
14
- from langchain.text_splitter import RecursiveCharacterTextSplitter
15
  from langchain.vectorstores import Chroma
16
  from tqdm import tqdm
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
 
20
  class DocumentQAEngine:
@@ -44,6 +113,7 @@ class DocumentQAEngine:
44
  self.llm = llm
45
  self.memory = memory
46
  self.chain = load_qa_chain(llm, chain_type=qa_chain_type)
 
47
 
48
  if embeddings_root_path is not None:
49
  self.embeddings_root_path = embeddings_root_path
@@ -57,7 +127,7 @@ class DocumentQAEngine:
57
  grobid_client = GrobidClient(
58
  grobid_server=self.grobid_url,
59
  batch_size=1000,
60
- coordinates=["p"],
61
  sleep_time=5,
62
  timeout=60,
63
  check_server=True
@@ -105,7 +175,7 @@ class DocumentQAEngine:
105
  if verbose:
106
  print(query)
107
 
108
- response = self._run_query(doc_id, query, context_size=context_size)
109
  response = response['output_text'] if 'output_text' in response else response
110
 
111
  if verbose:
@@ -116,17 +186,17 @@ class DocumentQAEngine:
116
  return self._parse_json(response, output_parser), response
117
  except Exception as oe:
118
  print("Failing to parse the response", oe)
119
- return None, response
120
  elif extraction_schema:
121
  try:
122
  chain = create_extraction_chain(extraction_schema, self.llm)
123
  parsed = chain.run(response)
124
- return parsed, response
125
  except Exception as oe:
126
  print("Failing to parse the response", oe)
127
- return None, response
128
  else:
129
- return None, response
130
 
131
  def query_storage(self, query: str, doc_id, context_size=4):
132
  documents = self._get_context(doc_id, query, context_size)
@@ -157,12 +227,15 @@ class DocumentQAEngine:
157
 
158
  def _run_query(self, doc_id, query, context_size=4):
159
  relevant_documents = self._get_context(doc_id, query, context_size)
 
 
 
160
  response = self.chain.run(input_documents=relevant_documents,
161
  question=query)
162
 
163
  if self.memory:
164
  self.memory.save_context({"input": query}, {"output": response})
165
- return response
166
 
167
  def _get_context(self, doc_id, query, context_size=4):
168
  db = self.embeddings_dict[doc_id]
@@ -188,14 +261,15 @@ class DocumentQAEngine:
188
  relevant_documents = multi_query_retriever.get_relevant_documents(query)
189
  return relevant_documents
190
 
191
- def get_text_from_document(self, pdf_file_path, chunk_size=-1, perc_overlap=0.1, include=(), verbose=False):
192
  """
193
  Extract text from documents using Grobid, if chunk_size is < 0 it keeps each paragraph separately
194
  """
195
  if verbose:
196
  print("File", pdf_file_path)
197
  filename = Path(pdf_file_path).stem
198
- structure = self.grobid_processor.process_structure(pdf_file_path)
 
199
 
200
  biblio = structure['biblio']
201
  biblio['filename'] = filename.replace(" ", "_")
@@ -207,48 +281,33 @@ class DocumentQAEngine:
207
  metadatas = []
208
  ids = []
209
 
210
- if chunk_size < 0:
211
- for passage in structure['passages']:
212
- biblio_copy = copy.copy(biblio)
213
- if len(str.strip(passage['text'])) > 0:
214
- texts.append(passage['text'])
215
 
216
- biblio_copy['type'] = passage['type']
217
- biblio_copy['section'] = passage['section']
218
- biblio_copy['subSection'] = passage['subSection']
219
- metadatas.append(biblio_copy)
220
 
221
- ids.append(passage['passage_id'])
222
- else:
223
- document_text = " ".join([passage['text'] for passage in structure['passages']])
224
- # text_splitter = CharacterTextSplitter.from_tiktoken_encoder(
225
- text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
226
- chunk_size=chunk_size,
227
- chunk_overlap=chunk_size * perc_overlap
228
- )
229
- texts = text_splitter.split_text(document_text)
230
- metadatas = [biblio for _ in range(len(texts))]
231
- ids = [id for id, t in enumerate(texts)]
232
-
233
- if "biblio" in include:
234
- biblio_metadata = copy.copy(biblio)
235
- biblio_metadata['type'] = "biblio"
236
- biblio_metadata['section'] = "header"
237
- for key in ['title', 'authors', 'publication_year']:
238
- if key in biblio_metadata:
239
- texts.append("{}: {}".format(key, biblio_metadata[key]))
240
- metadatas.append(biblio_metadata)
241
- ids.append(key)
242
 
243
  return texts, metadatas, ids
244
 
245
- def create_memory_embeddings(self, pdf_path, doc_id=None, chunk_size=500, perc_overlap=0.1, include_biblio=False):
246
- include = ["biblio"] if include_biblio else []
247
  texts, metadata, ids = self.get_text_from_document(
248
  pdf_path,
249
  chunk_size=chunk_size,
250
- perc_overlap=perc_overlap,
251
- include=include)
252
  if doc_id:
253
  hash = doc_id
254
  else:
 
3
  from pathlib import Path
4
  from typing import Union, Any
5
 
6
+ import tiktoken
7
  from grobid_client.grobid_client import GrobidClient
8
+ from langchain.chains import create_extraction_chain
9
  from langchain.chains.question_answering import load_qa_chain, stuff_prompt, refine_prompts, map_reduce_prompt, \
10
  map_rerank_prompt
11
  from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
12
  from langchain.retrievers import MultiQueryRetriever
13
  from langchain.schema import Document
 
14
  from langchain.vectorstores import Chroma
15
  from tqdm import tqdm
16
 
17
+ from document_qa.grobid_processors import GrobidProcessor
18
+
19
+
20
+ class TextMerger:
21
+ def __init__(self, model_name=None, encoding_name="gpt2"):
22
+ if model_name is not None:
23
+ self.enc = tiktoken.encoding_for_model(model_name)
24
+ else:
25
+ self.enc = tiktoken.get_encoding(encoding_name)
26
+
27
+ def encode(self, text, allowed_special=set(), disallowed_special="all"):
28
+ return self.enc.encode(
29
+ text,
30
+ allowed_special=allowed_special,
31
+ disallowed_special=disallowed_special,
32
+ )
33
+
34
+ def merge_passages(self, passages, chunk_size, tolerance=0.2):
35
+ new_passages = []
36
+ new_coordinates = []
37
+ current_texts = []
38
+ current_coordinates = []
39
+ for idx, passage in enumerate(passages):
40
+ text = passage['text']
41
+ coordinates = passage['coordinates']
42
+ current_texts.append(text)
43
+ current_coordinates.append(coordinates)
44
+
45
+ accumulated_text = " ".join(current_texts)
46
+
47
+ encoded_accumulated_text = self.encode(accumulated_text)
48
+
49
+ if len(encoded_accumulated_text) > chunk_size + chunk_size * tolerance:
50
+ if len(current_texts) > 1:
51
+ new_passages.append(current_texts[:-1])
52
+ new_coordinates.append(current_coordinates[:-1])
53
+ current_texts = [current_texts[-1]]
54
+ current_coordinates = [current_coordinates[-1]]
55
+ else:
56
+ new_passages.append(current_texts)
57
+ new_coordinates.append(current_coordinates)
58
+ current_texts = []
59
+ current_coordinates = []
60
+
61
+ elif chunk_size <= len(encoded_accumulated_text) < chunk_size + chunk_size * tolerance:
62
+ new_passages.append(current_texts)
63
+ new_coordinates.append(current_coordinates)
64
+ current_texts = []
65
+ current_coordinates = []
66
+
67
+ if len(current_texts) > 0:
68
+ new_passages.append(current_texts)
69
+ new_coordinates.append(current_coordinates)
70
+
71
+ new_passages_struct = []
72
+ for i, passages in enumerate(new_passages):
73
+ text = " ".join(passages)
74
+ coordinates = ";".join(new_coordinates[i])
75
+
76
+ new_passages_struct.append(
77
+ {
78
+ "text": text,
79
+ "coordinates": coordinates,
80
+ "type": "aggregated chunks",
81
+ "section": "mixed",
82
+ "subSection": "mixed"
83
+ }
84
+ )
85
+
86
+ return new_passages_struct
87
 
88
 
89
  class DocumentQAEngine:
 
113
  self.llm = llm
114
  self.memory = memory
115
  self.chain = load_qa_chain(llm, chain_type=qa_chain_type)
116
+ self.text_merger = TextMerger()
117
 
118
  if embeddings_root_path is not None:
119
  self.embeddings_root_path = embeddings_root_path
 
127
  grobid_client = GrobidClient(
128
  grobid_server=self.grobid_url,
129
  batch_size=1000,
130
+ coordinates=["p", "title", "persName"],
131
  sleep_time=5,
132
  timeout=60,
133
  check_server=True
 
175
  if verbose:
176
  print(query)
177
 
178
+ response, coordinates = self._run_query(doc_id, query, context_size=context_size)
179
  response = response['output_text'] if 'output_text' in response else response
180
 
181
  if verbose:
 
186
  return self._parse_json(response, output_parser), response
187
  except Exception as oe:
188
  print("Failing to parse the response", oe)
189
+ return None, response, coordinates
190
  elif extraction_schema:
191
  try:
192
  chain = create_extraction_chain(extraction_schema, self.llm)
193
  parsed = chain.run(response)
194
+ return parsed, response, coordinates
195
  except Exception as oe:
196
  print("Failing to parse the response", oe)
197
+ return None, response, coordinates
198
  else:
199
+ return None, response, coordinates
200
 
201
  def query_storage(self, query: str, doc_id, context_size=4):
202
  documents = self._get_context(doc_id, query, context_size)
 
227
 
228
  def _run_query(self, doc_id, query, context_size=4):
229
  relevant_documents = self._get_context(doc_id, query, context_size)
230
+ relevant_document_coordinates = [doc.metadata['coordinates'].split(";") if 'coordinates' in doc.metadata else []
231
+ for doc in
232
+ relevant_documents] # filter(lambda d: d['type'] == "sentence", relevant_documents)]
233
  response = self.chain.run(input_documents=relevant_documents,
234
  question=query)
235
 
236
  if self.memory:
237
  self.memory.save_context({"input": query}, {"output": response})
238
+ return response, relevant_document_coordinates
239
 
240
  def _get_context(self, doc_id, query, context_size=4):
241
  db = self.embeddings_dict[doc_id]
 
261
  relevant_documents = multi_query_retriever.get_relevant_documents(query)
262
  return relevant_documents
263
 
264
+ def get_text_from_document(self, pdf_file_path, chunk_size=-1, perc_overlap=0.1, verbose=False):
265
  """
266
  Extract text from documents using Grobid, if chunk_size is < 0 it keeps each paragraph separately
267
  """
268
  if verbose:
269
  print("File", pdf_file_path)
270
  filename = Path(pdf_file_path).stem
271
+ coordinates = True # if chunk_size == -1 else False
272
+ structure = self.grobid_processor.process_structure(pdf_file_path, coordinates=coordinates)
273
 
274
  biblio = structure['biblio']
275
  biblio['filename'] = filename.replace(" ", "_")
 
281
  metadatas = []
282
  ids = []
283
 
284
+ if chunk_size > 0:
285
+ new_passages = self.text_merger.merge_passages(structure['passages'], chunk_size=chunk_size)
286
+ else:
287
+ new_passages = structure['passages']
 
288
 
289
+ for passage in new_passages:
290
+ biblio_copy = copy.copy(biblio)
291
+ if len(str.strip(passage['text'])) > 0:
292
+ texts.append(passage['text'])
293
 
294
+ biblio_copy['type'] = passage['type']
295
+ biblio_copy['section'] = passage['section']
296
+ biblio_copy['subSection'] = passage['subSection']
297
+ biblio_copy['coordinates'] = passage['coordinates']
298
+ metadatas.append(biblio_copy)
299
+
300
+ # ids.append(passage['passage_id'])
301
+
302
+ ids = [id for id, t in enumerate(new_passages)]
 
 
 
 
 
 
 
 
 
 
 
 
303
 
304
  return texts, metadatas, ids
305
 
306
+ def create_memory_embeddings(self, pdf_path, doc_id=None, chunk_size=500, perc_overlap=0.1):
 
307
  texts, metadata, ids = self.get_text_from_document(
308
  pdf_path,
309
  chunk_size=chunk_size,
310
+ perc_overlap=perc_overlap)
 
311
  if doc_id:
312
  hash = doc_id
313
  else:
document_qa/grobid_processors.py CHANGED
@@ -131,13 +131,13 @@ class GrobidProcessor(BaseProcessor):
131
  # super().__init__()
132
  self.grobid_client = grobid_client
133
 
134
- def process_structure(self, input_path):
135
  pdf_file, status, text = self.grobid_client.process_pdf("processFulltextDocument",
136
  input_path,
137
  consolidate_header=True,
138
  consolidate_citations=False,
139
  segment_sentences=False,
140
- tei_coordinates=False,
141
  include_raw_citations=False,
142
  include_raw_affiliations=False,
143
  generateIDs=True)
@@ -145,7 +145,7 @@ class GrobidProcessor(BaseProcessor):
145
  if status != 200:
146
  return
147
 
148
- output_data = self.parse_grobid_xml(text)
149
  output_data['filename'] = Path(pdf_file).stem.replace(".tei", "")
150
 
151
  return output_data
@@ -159,7 +159,7 @@ class GrobidProcessor(BaseProcessor):
159
 
160
  return doc
161
 
162
- def parse_grobid_xml(self, text):
163
  output_data = OrderedDict()
164
 
165
  doc_biblio = grobid_tei_xml.parse_document_xml(text)
@@ -176,61 +176,115 @@ class GrobidProcessor(BaseProcessor):
176
  pass
177
 
178
  output_data['biblio'] = biblio
179
-
180
  passages = []
181
  output_data['passages'] = passages
182
- # if biblio['title'] is not None and len(biblio['title']) > 0:
183
- # passages.append({
184
- # "text": self.post_process(biblio['title']),
185
- # "type": "paragraph",
186
- # "section": "<header>",
187
- # "subSection": "<title>",
188
- # "passage_id": "title0"
189
- # })
190
-
191
- if doc_biblio.abstract is not None and len(doc_biblio.abstract) > 0:
192
- passages.append({
193
- "text": self.post_process(doc_biblio.abstract),
194
- "type": "paragraph",
195
- "section": "<header>",
196
- "subSection": "<abstract>",
197
- "passage_id": "abstract0"
198
- })
199
 
200
  soup = BeautifulSoup(text, 'xml')
201
- text_blocks_body = get_children_body(soup, verbose=False)
202
-
203
- passages.extend([
204
- {
205
- "text": self.post_process(''.join(text for text in sentence.find_all(text=True) if
206
- text.parent.name != "ref" or (
207
- text.parent.name == "ref" and text.parent.attrs[
208
- 'type'] != 'bibr'))),
209
- "type": "paragraph",
210
- "section": "<body>",
211
- "subSection": "<paragraph>",
212
- "passage_id": str(paragraph_id) + str(sentence_id)
213
- }
214
- for paragraph_id, paragraph in enumerate(text_blocks_body) for
215
- sentence_id, sentence in enumerate(paragraph)
216
- ])
217
-
218
- text_blocks_figures = get_children_figures(soup, verbose=False)
219
-
220
- passages.extend([
221
- {
222
- "text": self.post_process(''.join(text for text in sentence.find_all(text=True) if
223
- text.parent.name != "ref" or (
224
- text.parent.name == "ref" and text.parent.attrs[
225
- 'type'] != 'bibr'))),
226
- "type": "paragraph",
227
- "section": "<body>",
228
- "subSection": "<figure>",
229
- "passage_id": str(paragraph_id) + str(sentence_id)
230
- }
231
- for paragraph_id, paragraph in enumerate(text_blocks_figures) for
232
- sentence_id, sentence in enumerate(paragraph)
233
- ])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
 
235
  return output_data
236
 
@@ -526,6 +580,21 @@ class GrobidAggregationProcessor(GrobidProcessor, GrobidQuantitiesProcessor, Gro
526
  def extract_materials(self, text):
527
  return self.gmp.extract_materials(text)
528
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
529
  @staticmethod
530
  def prune_overlapping_annotations(entities: list) -> list:
531
  # Sorting by offsets
@@ -731,25 +800,40 @@ def get_children_list_grobid(soup: object, use_paragraphs: object = True, verbos
731
  return children
732
 
733
 
734
- def get_children_body(soup: object, use_paragraphs: object = True, verbose: object = False) -> object:
735
- children = []
736
- child_name = "p" if use_paragraphs else "s"
 
 
 
 
 
 
 
 
 
 
 
 
 
737
  for child in soup.TEI.children:
738
  if child.name == 'text':
739
- children.extend([subchild.find_all(child_name) for subchild in child.find_all("body")])
 
 
740
 
741
  if verbose:
742
- print(str(children))
743
 
744
- return children
745
 
746
 
747
- def get_children_figures(soup: object, use_paragraphs: object = True, verbose: object = False) -> object:
748
  children = []
749
- child_name = "p" if use_paragraphs else "s"
750
  for child in soup.TEI.children:
751
  if child.name == 'text':
752
- children.extend([subchild.find_all("figDesc") for subchild in child.find_all("body")])
 
753
 
754
  if verbose:
755
  print(str(children))
 
131
  # super().__init__()
132
  self.grobid_client = grobid_client
133
 
134
+ def process_structure(self, input_path, coordinates=False):
135
  pdf_file, status, text = self.grobid_client.process_pdf("processFulltextDocument",
136
  input_path,
137
  consolidate_header=True,
138
  consolidate_citations=False,
139
  segment_sentences=False,
140
+ tei_coordinates=coordinates,
141
  include_raw_citations=False,
142
  include_raw_affiliations=False,
143
  generateIDs=True)
 
145
  if status != 200:
146
  return
147
 
148
+ output_data = self.parse_grobid_xml(text, coordinates=coordinates)
149
  output_data['filename'] = Path(pdf_file).stem.replace(".tei", "")
150
 
151
  return output_data
 
159
 
160
  return doc
161
 
162
+ def parse_grobid_xml(self, text, coordinates=False):
163
  output_data = OrderedDict()
164
 
165
  doc_biblio = grobid_tei_xml.parse_document_xml(text)
 
176
  pass
177
 
178
  output_data['biblio'] = biblio
 
179
  passages = []
180
  output_data['passages'] = passages
181
+ passage_type = "paragraph"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
 
183
  soup = BeautifulSoup(text, 'xml')
184
+ blocks_header = get_xml_nodes_header(soup, use_paragraphs=True)
185
+
186
+ passages.append({
187
+ "text": f"authors: {biblio['authors']}",
188
+ "type": passage_type,
189
+ "section": "<header>",
190
+ "subSection": "<title>",
191
+ "passage_id": "htitle",
192
+ "coordinates": ";".join([node['coords'] if coordinates and node.has_attr('coords') else "" for node in
193
+ blocks_header['authors']])
194
+ })
195
+
196
+ passages.append({
197
+ "text": self.post_process(" ".join([node.text for node in blocks_header['title']])),
198
+ "type": passage_type,
199
+ "section": "<header>",
200
+ "subSection": "<title>",
201
+ "passage_id": "htitle",
202
+ "coordinates": ";".join([node['coords'] if coordinates and node.has_attr('coords') else "" for node in
203
+ blocks_header['title']])
204
+ })
205
+
206
+ passages.append({
207
+ "text": self.post_process(
208
+ ''.join(node.text for node in blocks_header['abstract'] for text in node.find_all(text=True) if
209
+ text.parent.name != "ref" or (
210
+ text.parent.name == "ref" and text.parent.attrs[
211
+ 'type'] != 'bibr'))),
212
+ "type": passage_type,
213
+ "section": "<header>",
214
+ "subSection": "<abstract>",
215
+ "passage_id": "habstract",
216
+ "coordinates": ";".join([node['coords'] if coordinates and node.has_attr('coords') else "" for node in
217
+ blocks_header['abstract']])
218
+ })
219
+
220
+ text_blocks_body = get_xml_nodes_body(soup, verbose=False, use_paragraphs=True)
221
+
222
+ use_paragraphs = True
223
+ if not use_paragraphs:
224
+ passages.extend([
225
+ {
226
+ "text": self.post_process(''.join(text for text in sentence.find_all(text=True) if
227
+ text.parent.name != "ref" or (
228
+ text.parent.name == "ref" and text.parent.attrs[
229
+ 'type'] != 'bibr'))),
230
+ "type": passage_type,
231
+ "section": "<body>",
232
+ "subSection": "<paragraph>",
233
+ "passage_id": str(paragraph_id),
234
+ "coordinates": paragraph['coords'] if coordinates and sentence.has_attr('coords') else ""
235
+ }
236
+ for paragraph_id, paragraph in enumerate(text_blocks_body) for
237
+ sentence_id, sentence in enumerate(paragraph)
238
+ ])
239
+ else:
240
+ passages.extend([
241
+ {
242
+ "text": self.post_process(''.join(text for text in paragraph.find_all(text=True) if
243
+ text.parent.name != "ref" or (
244
+ text.parent.name == "ref" and text.parent.attrs[
245
+ 'type'] != 'bibr'))),
246
+ "type": passage_type,
247
+ "section": "<body>",
248
+ "subSection": "<paragraph>",
249
+ "passage_id": str(paragraph_id),
250
+ "coordinates": paragraph['coords'] if coordinates and paragraph.has_attr('coords') else ""
251
+ }
252
+ for paragraph_id, paragraph in enumerate(text_blocks_body)
253
+ ])
254
+
255
+ text_blocks_figures = get_xml_nodes_figures(soup, verbose=False)
256
+
257
+ if not use_paragraphs:
258
+ passages.extend([
259
+ {
260
+ "text": self.post_process(''.join(text for text in sentence.find_all(text=True) if
261
+ text.parent.name != "ref" or (
262
+ text.parent.name == "ref" and text.parent.attrs[
263
+ 'type'] != 'bibr'))),
264
+ "type": passage_type,
265
+ "section": "<body>",
266
+ "subSection": "<figure>",
267
+ "passage_id": str(paragraph_id) + str(sentence_id),
268
+ "coordinates": sentence['coords'] if coordinates and 'coords' in sentence else ""
269
+ }
270
+ for paragraph_id, paragraph in enumerate(text_blocks_figures) for
271
+ sentence_id, sentence in enumerate(paragraph)
272
+ ])
273
+ else:
274
+ passages.extend([
275
+ {
276
+ "text": self.post_process(''.join(text for text in paragraph.find_all(text=True) if
277
+ text.parent.name != "ref" or (
278
+ text.parent.name == "ref" and text.parent.attrs[
279
+ 'type'] != 'bibr'))),
280
+ "type": passage_type,
281
+ "section": "<body>",
282
+ "subSection": "<figure>",
283
+ "passage_id": str(paragraph_id),
284
+ "coordinates": paragraph['coords'] if coordinates and paragraph.has_attr('coords') else ""
285
+ }
286
+ for paragraph_id, paragraph in enumerate(text_blocks_figures)
287
+ ])
288
 
289
  return output_data
290
 
 
580
  def extract_materials(self, text):
581
  return self.gmp.extract_materials(text)
582
 
583
+ @staticmethod
584
+ def box_to_dict(box, color=None, type=None):
585
+
586
+ if box is None or box == "" or len(box) < 5:
587
+ return {}
588
+
589
+ item = {"page": box[0], "x": box[1], "y": box[2], "width": box[3], "height": box[4]}
590
+ if color is not None:
591
+ item['color'] = color
592
+
593
+ if type:
594
+ item['type'] = type
595
+
596
+ return item
597
+
598
  @staticmethod
599
  def prune_overlapping_annotations(entities: list) -> list:
600
  # Sorting by offsets
 
800
  return children
801
 
802
 
803
+ def get_xml_nodes_header(soup: object, use_paragraphs: bool = True) -> list:
804
+ sub_tag = "p" if use_paragraphs else "s"
805
+
806
+ header_elements = {
807
+ "authors": [persNameNode for persNameNode in soup.teiHeader.find_all("persName")],
808
+ "abstract": [p_in_abstract for abstractNodes in soup.teiHeader.find_all("abstract") for p_in_abstract in
809
+ abstractNodes.find_all(sub_tag)],
810
+ "title": [soup.teiHeader.fileDesc.title]
811
+ }
812
+
813
+ return header_elements
814
+
815
+
816
+ def get_xml_nodes_body(soup: object, use_paragraphs: bool = True, verbose: bool = False) -> list:
817
+ nodes = []
818
+ tag_name = "p" if use_paragraphs else "s"
819
  for child in soup.TEI.children:
820
  if child.name == 'text':
821
+ # nodes.extend([subchild.find_all(tag_name) for subchild in child.find_all("body")])
822
+ nodes.extend(
823
+ [subsubchild for subchild in child.find_all("body") for subsubchild in subchild.find_all(tag_name)])
824
 
825
  if verbose:
826
+ print(str(nodes))
827
 
828
+ return nodes
829
 
830
 
831
+ def get_xml_nodes_figures(soup: object, verbose: bool = False) -> list:
832
  children = []
 
833
  for child in soup.TEI.children:
834
  if child.name == 'text':
835
+ children.extend(
836
+ [subchild for subchilds in child.find_all("body") for subchild in subchilds.find_all("figDesc")])
837
 
838
  if verbose:
839
  print(str(children))
requirements.txt CHANGED
@@ -19,7 +19,9 @@ chromadb==0.4.19
19
  tiktoken==0.4.0
20
  openai==0.27.7
21
  langchain==0.0.350
 
22
  typing-inspect==0.9.0
23
  typing_extensions==4.8.0
24
  pydantic==2.4.2
25
- sentence_transformers==2.2.2
 
 
19
  tiktoken==0.4.0
20
  openai==0.27.7
21
  langchain==0.0.350
22
+ langchain-core==0.1.0
23
  typing-inspect==0.9.0
24
  typing_extensions==4.8.0
25
  pydantic==2.4.2
26
+ sentence_transformers==2.2.2
27
+ streamlit-pdf-viewer
streamlit_app.py CHANGED
@@ -1,4 +1,3 @@
1
- import base64
2
  import os
3
  import re
4
  from hashlib import blake2b
@@ -8,6 +7,7 @@ import dotenv
8
  from grobid_quantities.quantities import QuantitiesAPI
9
  from langchain.llms.huggingface_hub import HuggingFaceHub
10
  from langchain.memory import ConversationBufferWindowMemory
 
11
 
12
  dotenv.load_dotenv(override=True)
13
 
@@ -70,6 +70,18 @@ if 'memory' not in st.session_state:
70
  if 'binary' not in st.session_state:
71
  st.session_state['binary'] = None
72
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  st.set_page_config(
74
  page_title="Scientific Document Insights Q/A",
75
  page_icon="📝",
@@ -216,7 +228,9 @@ with st.sidebar:
216
  st.session_state['model'] = model = st.selectbox(
217
  "Model:",
218
  options=OPENAI_MODELS + list(OPEN_MODELS.keys()),
219
- index=4,
 
 
220
  placeholder="Select model",
221
  help="Select the LLM model:",
222
  disabled=st.session_state['doc_id'] is not None or st.session_state['uploaded']
@@ -291,21 +305,44 @@ question = st.chat_input(
291
 
292
  with st.sidebar:
293
  st.header("Settings")
294
- mode = st.radio("Query mode", ("LLM", "Embeddings"), disabled=not uploaded_file, index=0, horizontal=True,
295
- help="LLM will respond the question, Embedding will show the "
296
- "paragraphs relevant to the question in the paper.")
297
- chunk_size = st.slider("Chunks size", 100, 2000, value=250,
298
- help="Size of chunks in which the document is partitioned",
 
 
 
 
 
 
 
 
 
 
 
299
  disabled=uploaded_file is not None)
300
- context_size = st.slider("Context size", 3, 10, value=4,
301
- help="Number of chunks to consider when answering a question",
302
- disabled=not uploaded_file)
 
 
 
 
 
303
 
304
  st.session_state['ner_processing'] = st.checkbox("Identify materials and properties.")
305
  st.markdown(
306
  'The LLM responses undergo post-processing to extract <span style="color:orange">physical quantities, measurements</span>, and <span style="color:green">materials</span> mentions.',
307
  unsafe_allow_html=True)
308
 
 
 
 
 
 
 
 
309
  st.divider()
310
 
311
  st.header("Documentation")
@@ -324,13 +361,6 @@ with st.sidebar:
324
  st.markdown(
325
  """If you switch the mode to "Embedding," the system will return specific chunks from the document that are semantically related to your query. This mode helps to test why sometimes the answers are not satisfying or incomplete. """)
326
 
327
-
328
- @st.cache_resource
329
- def get_pdf_display(binary):
330
- base64_pdf = base64.b64encode(binary).decode('utf-8')
331
- return F'<embed src="data:application/pdf;base64,{base64_pdf}" width="100%" height="700" type="application/pdf"></embed>'
332
-
333
-
334
  if uploaded_file and not st.session_state.loaded_embeddings:
335
  if model not in st.session_state['api_keys']:
336
  st.error("Before uploading a document, you must enter the API key. ")
@@ -345,16 +375,31 @@ if uploaded_file and not st.session_state.loaded_embeddings:
345
 
346
  st.session_state['doc_id'] = hash = st.session_state['rqa'][model].create_memory_embeddings(tmp_file.name,
347
  chunk_size=chunk_size,
348
- perc_overlap=0.1,
349
- include_biblio=True)
350
  st.session_state['loaded_embeddings'] = True
351
  st.session_state.messages = []
352
 
353
  # timestamp = datetime.utcnow()
354
 
355
- with left_column:
356
- if st.session_state['binary']:
357
- left_column.markdown(get_pdf_display(st.session_state['binary']), unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
358
 
359
  with right_column:
360
  # css = '''
@@ -398,8 +443,18 @@ with right_column:
398
  context_size=context_size)
399
  elif mode == "LLM":
400
  with st.spinner("Generating response..."):
401
- _, text_response = st.session_state['rqa'][model].query_document(question, st.session_state.doc_id,
402
- context_size=context_size)
 
 
 
 
 
 
 
 
 
 
403
 
404
  if not text_response:
405
  st.error("Something went wrong. Contact Luca Foppiano ([email protected]) to report the issue.")
@@ -418,11 +473,16 @@ with right_column:
418
  st.write(text_response)
419
  st.session_state.messages.append({"role": "assistant", "mode": mode, "content": text_response})
420
 
421
- # if len(st.session_state.messages) > 1:
422
- # last_answer = st.session_state.messages[len(st.session_state.messages)-1]
423
- # if last_answer['role'] == "assistant":
424
- # last_question = st.session_state.messages[len(st.session_state.messages)-2]
425
- # st.session_state.memory.save_context({"input": last_question['content']}, {"output": last_answer['content']})
426
-
427
  elif st.session_state.loaded_embeddings and st.session_state.doc_id:
428
  play_old_messages()
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import re
3
  from hashlib import blake2b
 
7
  from grobid_quantities.quantities import QuantitiesAPI
8
  from langchain.llms.huggingface_hub import HuggingFaceHub
9
  from langchain.memory import ConversationBufferWindowMemory
10
+ from streamlit_pdf_viewer import pdf_viewer
11
 
12
  dotenv.load_dotenv(override=True)
13
 
 
70
  if 'binary' not in st.session_state:
71
  st.session_state['binary'] = None
72
 
73
+ if 'annotations' not in st.session_state:
74
+ st.session_state['annotations'] = None
75
+
76
+ if 'should_show_annotations' not in st.session_state:
77
+ st.session_state['should_show_annotations'] = True
78
+
79
+ if 'pdf' not in st.session_state:
80
+ st.session_state['pdf'] = None
81
+
82
+ if 'pdf_rendering' not in st.session_state:
83
+ st.session_state['pdf_rendering'] = None
84
+
85
  st.set_page_config(
86
  page_title="Scientific Document Insights Q/A",
87
  page_icon="📝",
 
228
  st.session_state['model'] = model = st.selectbox(
229
  "Model:",
230
  options=OPENAI_MODELS + list(OPEN_MODELS.keys()),
231
+ index=(OPENAI_MODELS + list(OPEN_MODELS.keys())).index(
232
+ "zephyr-7b-beta") if "DEFAULT_MODEL" not in os.environ or not os.environ["DEFAULT_MODEL"] else (
233
+ OPENAI_MODELS + list(OPEN_MODELS.keys())).index(os.environ["DEFAULT_MODEL"]),
234
  placeholder="Select model",
235
  help="Select the LLM model:",
236
  disabled=st.session_state['doc_id'] is not None or st.session_state['uploaded']
 
305
 
306
  with st.sidebar:
307
  st.header("Settings")
308
+ mode = st.radio(
309
+ "Query mode",
310
+ ("LLM", "Embeddings"),
311
+ disabled=not uploaded_file,
312
+ index=0,
313
+ horizontal=True,
314
+ help="LLM will respond the question, Embedding will show the "
315
+ "paragraphs relevant to the question in the paper."
316
+ )
317
+
318
+ # Add a checkbox for showing annotations
319
+ # st.session_state['show_annotations'] = st.checkbox("Show annotations", value=True)
320
+ # st.session_state['should_show_annotations'] = st.checkbox("Show annotations", value=True)
321
+
322
+ chunk_size = st.slider("Text chunks size", -1, 2000, value=-1,
323
+ help="Size of chunks in which split the document. -1: use paragraphs, > 0 paragraphs are aggregated.",
324
  disabled=uploaded_file is not None)
325
+ if chunk_size == -1:
326
+ context_size = st.slider("Context size (paragraphs)", 3, 20, value=10,
327
+ help="Number of paragraphs to consider when answering a question",
328
+ disabled=not uploaded_file)
329
+ else:
330
+ context_size = st.slider("Context size (chunks)", 3, 10, value=4,
331
+ help="Number of chunks to consider when answering a question",
332
+ disabled=not uploaded_file)
333
 
334
  st.session_state['ner_processing'] = st.checkbox("Identify materials and properties.")
335
  st.markdown(
336
  'The LLM responses undergo post-processing to extract <span style="color:orange">physical quantities, measurements</span>, and <span style="color:green">materials</span> mentions.',
337
  unsafe_allow_html=True)
338
 
339
+ st.session_state['pdf_rendering'] = st.radio(
340
+ "PDF rendering mode",
341
+ {"PDF.JS", "Native browser engine"},
342
+ index=1,
343
+ disabled=not uploaded_file,
344
+ )
345
+
346
  st.divider()
347
 
348
  st.header("Documentation")
 
361
  st.markdown(
362
  """If you switch the mode to "Embedding," the system will return specific chunks from the document that are semantically related to your query. This mode helps to test why sometimes the answers are not satisfying or incomplete. """)
363
 
 
 
 
 
 
 
 
364
  if uploaded_file and not st.session_state.loaded_embeddings:
365
  if model not in st.session_state['api_keys']:
366
  st.error("Before uploading a document, you must enter the API key. ")
 
375
 
376
  st.session_state['doc_id'] = hash = st.session_state['rqa'][model].create_memory_embeddings(tmp_file.name,
377
  chunk_size=chunk_size,
378
+ perc_overlap=0.1)
 
379
  st.session_state['loaded_embeddings'] = True
380
  st.session_state.messages = []
381
 
382
  # timestamp = datetime.utcnow()
383
 
384
+
385
+ def rgb_to_hex(rgb):
386
+ return "#{:02x}{:02x}{:02x}".format(*rgb)
387
+
388
+
389
+ def generate_color_gradient(num_elements):
390
+ # Define warm and cold colors in RGB format
391
+ warm_color = (255, 165, 0) # Orange
392
+ cold_color = (0, 0, 255) # Blue
393
+
394
+ # Generate a linear gradient of colors
395
+ color_gradient = [
396
+ rgb_to_hex(tuple(int(warm * (1 - i / num_elements) + cold * (i / num_elements)) for warm, cold in
397
+ zip(warm_color, cold_color)))
398
+ for i in range(num_elements)
399
+ ]
400
+
401
+ return color_gradient
402
+
403
 
404
  with right_column:
405
  # css = '''
 
443
  context_size=context_size)
444
  elif mode == "LLM":
445
  with st.spinner("Generating response..."):
446
+ _, text_response, coordinates = st.session_state['rqa'][model].query_document(question,
447
+ st.session_state.doc_id,
448
+ context_size=context_size)
449
+
450
+ annotations = [[GrobidAggregationProcessor.box_to_dict([cs for cs in c.split(",")]) for c in coord_doc]
451
+ for coord_doc in coordinates]
452
+ gradients = generate_color_gradient(len(annotations))
453
+ for i, color in enumerate(gradients):
454
+ for annotation in annotations[i]:
455
+ annotation['color'] = color
456
+ st.session_state['annotations'] = [annotation for annotation_doc in annotations for annotation in
457
+ annotation_doc]
458
 
459
  if not text_response:
460
  st.error("Something went wrong. Contact Luca Foppiano ([email protected]) to report the issue.")
 
473
  st.write(text_response)
474
  st.session_state.messages.append({"role": "assistant", "mode": mode, "content": text_response})
475
 
 
 
 
 
 
 
476
  elif st.session_state.loaded_embeddings and st.session_state.doc_id:
477
  play_old_messages()
478
+
479
+ with left_column:
480
+ if st.session_state['binary']:
481
+ pdf_viewer(
482
+ input=st.session_state['binary'],
483
+ width=600,
484
+ height=800,
485
+ annotation_outline_size=2,
486
+ annotations=st.session_state['annotations'],
487
+ rendering='unwrap' if st.session_state['pdf_rendering'] == 'PDF.JS' else 'legacy_embed'
488
+ )
tests/__init__.py ADDED
File without changes
tests/conftest.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from pathlib import Path
3
+ from unittest.mock import MagicMock
4
+
5
+ import pytest
6
+ from _pytest._py.path import LocalPath
7
+
8
+ # derived from https://github.com/elifesciences/sciencebeam-trainer-delft/tree/develop/tests
9
+
10
+ LOGGER = logging.getLogger(__name__)
11
+
12
+
13
+ @pytest.fixture(scope='session', autouse=True)
14
+ def setup_logging():
15
+ logging.root.handlers = []
16
+ logging.basicConfig(level='INFO')
17
+ logging.getLogger('tests').setLevel('DEBUG')
18
+ # logging.getLogger('sciencebeam_trainer_delft').setLevel('DEBUG')
19
+
20
+
21
+ def _backport_assert_called(mock: MagicMock):
22
+ assert mock.called
23
+
24
+
25
+ @pytest.fixture(scope='session', autouse=True)
26
+ def patch_magicmock():
27
+ try:
28
+ MagicMock.assert_called
29
+ except AttributeError:
30
+ MagicMock.assert_called = _backport_assert_called
31
+
32
+
33
+ @pytest.fixture
34
+ def temp_dir(tmpdir: LocalPath):
35
+ # convert to standard Path
36
+ return Path(str(tmpdir))
37
+
tests/resources/2312.07559.paragraphs.tei.xml ADDED
The diff for this file is too large to render. See raw diff
 
tests/resources/2312.07559.sentences.tei.xml ADDED
The diff for this file is too large to render. See raw diff
 
tests/test_document_qa_engine.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from document_qa.document_qa_engine import TextMerger
2
+
3
+
4
+ def test_merge_passages_small_chunk():
5
+ merger = TextMerger()
6
+
7
+ passages = [
8
+ {
9
+ 'text': "The quick brown fox jumps over the tree",
10
+ 'coordinates': '1'
11
+ },
12
+ {
13
+ 'text': "and went straight into the mouth of a bear.",
14
+ 'coordinates': '2'
15
+ },
16
+ {
17
+ 'text': "The color of the colors is a color with colors",
18
+ 'coordinates': '3'
19
+ },
20
+ {
21
+ 'text': "the main colors are not the colorw we show",
22
+ 'coordinates': '4'
23
+ }
24
+ ]
25
+ new_passages = merger.merge_passages(passages, chunk_size=10, tolerance=0)
26
+
27
+ assert len(new_passages) == 4
28
+ assert new_passages[0]['coordinates'] == "1"
29
+ assert new_passages[0]['text'] == "The quick brown fox jumps over the tree"
30
+
31
+ assert new_passages[1]['coordinates'] == "2"
32
+ assert new_passages[1]['text'] == "and went straight into the mouth of a bear."
33
+
34
+ assert new_passages[2]['coordinates'] == "3"
35
+ assert new_passages[2]['text'] == "The color of the colors is a color with colors"
36
+
37
+ assert new_passages[3]['coordinates'] == "4"
38
+ assert new_passages[3]['text'] == "the main colors are not the colorw we show"
39
+
40
+
41
+ def test_merge_passages_big_chunk():
42
+ merger = TextMerger()
43
+
44
+ passages = [
45
+ {
46
+ 'text': "The quick brown fox jumps over the tree",
47
+ 'coordinates': '1'
48
+ },
49
+ {
50
+ 'text': "and went straight into the mouth of a bear.",
51
+ 'coordinates': '2'
52
+ },
53
+ {
54
+ 'text': "The color of the colors is a color with colors",
55
+ 'coordinates': '3'
56
+ },
57
+ {
58
+ 'text': "the main colors are not the colorw we show",
59
+ 'coordinates': '4'
60
+ }
61
+ ]
62
+ new_passages = merger.merge_passages(passages, chunk_size=20, tolerance=0)
63
+
64
+ assert len(new_passages) == 2
65
+ assert new_passages[0]['coordinates'] == "1;2"
66
+ assert new_passages[0][
67
+ 'text'] == "The quick brown fox jumps over the tree and went straight into the mouth of a bear."
68
+
69
+ assert new_passages[1]['coordinates'] == "3;4"
70
+ assert new_passages[1][
71
+ 'text'] == "The color of the colors is a color with colors the main colors are not the colorw we show"
tests/test_grobid_processors.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from bs4 import BeautifulSoup
2
+ from document_qa.grobid_processors import get_xml_nodes_body, get_xml_nodes_figures, get_xml_nodes_header
3
+
4
+
5
+ def test_get_xml_nodes_body_paragraphs():
6
+ with open("resources/2312.07559.paragraphs.tei.xml", 'r') as fo:
7
+ soup = BeautifulSoup(fo, 'xml')
8
+
9
+ nodes = get_xml_nodes_body(soup, use_paragraphs=True)
10
+
11
+ assert len(nodes) == 70
12
+
13
+
14
+ def test_get_xml_nodes_body_sentences():
15
+ with open("resources/2312.07559.sentences.tei.xml", 'r') as fo:
16
+ soup = BeautifulSoup(fo, 'xml')
17
+
18
+ children = get_xml_nodes_body(soup, use_paragraphs=False)
19
+
20
+ assert len(children) == 327
21
+
22
+
23
+ def test_get_xml_nodes_figures():
24
+ with open("resources/2312.07559.paragraphs.tei.xml", 'r') as fo:
25
+ soup = BeautifulSoup(fo, 'xml')
26
+
27
+ children = get_xml_nodes_figures(soup)
28
+
29
+ assert len(children) == 13
30
+
31
+
32
+ def test_get_xml_nodes_header_paragraphs():
33
+ with open("resources/2312.07559.paragraphs.tei.xml", 'r') as fo:
34
+ soup = BeautifulSoup(fo, 'xml')
35
+
36
+ children = get_xml_nodes_header(soup)
37
+
38
+ assert len(children) == 8
39
+
40
+ def test_get_xml_nodes_header_sentences():
41
+ with open("resources/2312.07559.sentences.tei.xml", 'r') as fo:
42
+ soup = BeautifulSoup(fo, 'xml')
43
+
44
+ children = get_xml_nodes_header(soup, use_paragraphs=False)
45
+
46
+ assert len(children) == 15