Raghuan committed on
Commit 8b0a3c9
1 Parent(s): da58bde

Upload 4 files

Files changed (4)
  1. app.py +480 -0
  2. packages.txt +4 -0
  3. requirements.txt +15 -0
  4. utils.py +53 -0
app.py ADDED
@@ -0,0 +1,480 @@
+ import base64
+ import chromadb
+ import gc
+ import gradio as gr
+ import io
+ import numpy as np
+ import os
+ import pandas as pd
+ import pymupdf
+ from pypdf import PdfReader
+ import spaces
+ import torch
+ from PIL import Image
+ from chromadb.utils import embedding_functions
+ from chromadb.utils.data_loaders import ImageLoader
+ from doctr.io import DocumentFile
+ from doctr.models import ocr_predictor
+ from gradio.themes.utils import sizes
+ from langchain.prompts import PromptTemplate
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_community.llms import HuggingFaceEndpoint
+ from transformers import LlavaNextForConditionalGeneration, LlavaNextProcessor
+ from utils import *
+
+
+ def result_to_text(result, as_text=False) -> str or list:
+     full_doc = []
+     for _, page in enumerate(result.pages, start=1):
+         text = ""
+         for block in page.blocks:
+             text += "\n\t"
+             for line in block.lines:
+                 for word in line.words:
+                     text += word.value + " "
+
+         full_doc.append(clean_text(text) + "\n\n")
+
+     return "\n".join(full_doc) if as_text else full_doc
+
+
+ ocr_model = ocr_predictor(
+     "db_resnet50",
+     "crnn_mobilenet_v3_large",
+     pretrained=True,
+     assume_straight_pages=True,
+ )
+
+
+ if torch.cuda.is_available():
+     processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
+     vision_model = LlavaNextForConditionalGeneration.from_pretrained(
+         "llava-hf/llava-v1.6-mistral-7b-hf",
+         torch_dtype=torch.float16,
+         low_cpu_mem_usage=True,
+         load_in_4bit=True,
+     )
+
+
+ @spaces.GPU()
+ def get_image_description(image):
+     torch.cuda.empty_cache()
+     gc.collect()
+
+     # n = len(prompt)
+     prompt = "[INST] <image>\nDescribe the image in a sentence [/INST]"
+
+     inputs = processor(prompt, image, return_tensors="pt").to("cuda:0")
+     output = vision_model.generate(**inputs, max_new_tokens=100)
+     return processor.decode(output[0], skip_special_tokens=True)
+
+
+ CSS = """
+ #table_col {background-color: rgb(33, 41, 54);}
+ """
+
+
+ # def get_vectordb(text, images, tables):
+ def get_vectordb(text, images, img_doc_files):
+     client = chromadb.EphemeralClient()
+     loader = ImageLoader()
+     sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
+         model_name="multi-qa-mpnet-base-dot-v1"
+     )
+     if "text_db" in [i.name for i in client.list_collections()]:
+         client.delete_collection("text_db")
+     if "image_db" in [i.name for i in client.list_collections()]:
+         client.delete_collection("image_db")
+
+     text_collection = client.get_or_create_collection(
+         name="text_db",
+         embedding_function=sentence_transformer_ef,
+         data_loader=loader,
+     )
+     image_collection = client.get_or_create_collection(
+         name="image_db",
+         embedding_function=sentence_transformer_ef,
+         data_loader=loader,
+         metadata={"hnsw:space": "cosine"},
+     )
+     descs = []
+     for i in range(len(images)):
+         try:
+             descs.append(img_doc_files[i] + "\n" + get_image_description(images[i]))
+         except Exception:
+             descs.append("Could not generate image description due to some error")
+         print(descs[-1])
+         print()
+
+     # image_descriptions = get_image_descriptions(images)
+     image_dict = [{"image": image_to_bytes(img)} for img in images]
+
+     if len(images) > 0:
+         image_collection.add(
+             ids=[str(i) for i in range(len(images))],
+             documents=descs,
+             metadatas=image_dict,
+         )
+
+     splitter = RecursiveCharacterTextSplitter(
+         chunk_size=500,
+         chunk_overlap=10,
+     )
+
+     if len(text.replace(" ", "").replace("\n", "")) == 0:
+         raise gr.Error("No text found in documents")
+     else:
+         docs = splitter.create_documents([text])
+         doc_texts = [i.page_content for i in docs]
+         text_collection.add(
+             ids=[str(i) for i in list(range(len(doc_texts)))], documents=doc_texts
+         )
+     return client
+
+
+ def extract_only_text(reader):
+     text = ""
+     for _, page in enumerate(reader.pages):
+         text += page.extract_text()
+     return text.strip()
+
+
+ def extract_data_from_pdfs(
+     docs, session, include_images, do_ocr, progress=gr.Progress()
+ ):
+     if len(docs) == 0:
+         raise gr.Error("No documents to process")
+     progress(0, "Extracting Images")
+
+     # images = extract_images(docs)
+
+     progress(0.25, "Extracting Text")
+
+     all_text = ""
+
+     images = []
+     img_docs = []
+     for doc in docs:
+         if do_ocr == "Get Text With OCR":
+             pdf_doc = DocumentFile.from_pdf(doc)
+             result = ocr_model(pdf_doc)
+             all_text += result_to_text(result, as_text=True) + "\n\n"
+         else:
+             reader = PdfReader(doc)
+             all_text += extract_only_text(reader) + "\n\n"
+
+         if include_images == "Include Images":
+             imgs = extract_images([doc])
+             images.extend(imgs)
+             img_docs.extend([doc.split("/")[-1] for _ in range(len(imgs))])
+
+     progress(
+         0.6, "Generating image descriptions and inserting everything into vectorDB"
+     )
+     vectordb = get_vectordb(all_text, images, img_docs)
+
+     progress(1, "Completed")
+     session["processed"] = True
+     return (
+         vectordb,
+         session,
+         gr.Column(visible=True),
+         all_text[:2000] + "...",
+         # display,
+         images[:2],
+         "<h1 style='text-align: center'>Completed</h1>",
+         # image_descriptions
+     )
+
+
+ sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
+     model_name="multi-qa-mpnet-base-dot-v1"
+ )
+
+
+ def conversation(
+     vectordb_client,
+     msg,
+     num_context,
+     img_context,
+     history,
+     temperature,
+     max_new_tokens,
+     hf_token,
+     model_path,
+ ):
+     if hf_token.strip() != "" and model_path.strip() != "":
+         llm = HuggingFaceEndpoint(
+             repo_id=model_path,
+             temperature=temperature,
+             max_new_tokens=max_new_tokens,
+             huggingfacehub_api_token=hf_token,
+         )
+     else:
+         llm = HuggingFaceEndpoint(
+             repo_id="meta-llama/Meta-Llama-3-8B-Instruct",
+             temperature=temperature,
+             max_new_tokens=max_new_tokens,
+             huggingfacehub_api_token=os.getenv("P_HF_TOKEN", "None"),
+         )
+
+     text_collection = vectordb_client.get_collection(
+         "text_db", embedding_function=sentence_transformer_ef
+     )
+     image_collection = vectordb_client.get_collection(
+         "image_db", embedding_function=sentence_transformer_ef
+     )
+
+     results = text_collection.query(
+         query_texts=[msg], include=["documents"], n_results=num_context
+     )["documents"][0]
+     similar_images = image_collection.query(
+         query_texts=[msg],
+         include=["metadatas", "distances", "documents"],
+         n_results=img_context,
+     )
+     img_links = [i["image"] for i in similar_images["metadatas"][0]]
+
+     images_and_locs = [
+         Image.open(io.BytesIO(base64.b64decode(i[1])))
+         for i in zip(similar_images["distances"][0], img_links)
+     ]
+     img_desc = "\n".join(similar_images["documents"][0])
+     if len(img_links) == 0:
+         img_desc = "No Images Are Provided"
+     template = """
+ Context:
+ {context}
+
+ Included Images:
+ {images}
+
+ Question:
+ {question}
+
+ Answer:
+
+ """
+     prompt = PromptTemplate(template=template, input_variables=["context", "question", "images"])
+     context = "\n\n".join(results)
+     # references = [gr.Textbox(i, visible=True, interactive=False) for i in results]
+     response = llm(prompt.format(context=context, question=msg, images=img_desc))
+     return history + [(msg, response)], results, images_and_locs
+
+
+ def check_validity_and_llm(session_states):
+     if session_states.get("processed", False) == True:
+         return gr.Tabs(selected=2)
+     raise gr.Error("Please extract data first")
+
+
+ with gr.Blocks(css=CSS, theme=gr.themes.Soft(text_size=sizes.text_md)) as demo:
+     vectordb = gr.State()
+     doc_collection = gr.State(value=[])
+     session_states = gr.State(value={})
+     references = gr.State(value=[])
+
+     gr.Markdown(
+         """<h2><center>Multimodal PDF Chatbot</center></h2>
+ <h3><center><b>Interact With Your PDF Documents</b></center></h3>"""
+     )
+     gr.Markdown(
+         """<center><h3><b>Note:</b> This application leverages advanced Retrieval-Augmented Generation (RAG) techniques to provide context-aware responses from your PDF documents.</h3></center><br>
+ <center>Utilizing multimodal capabilities, this chatbot can interpret and answer queries based on both textual and visual information within your PDFs.</center>"""
+     )
+     gr.Markdown(
+         """
+ <center><b>Warning:</b> Extracting text and images from your document and generating embeddings may take some time due to the use of OCR and multimodal LLMs for image description.</center>
+ """
+     )
+     with gr.Tabs() as tabs:
+         with gr.TabItem("Upload PDFs", id=0) as pdf_tab:
+             with gr.Row():
+                 with gr.Column():
+                     documents = gr.File(
+                         file_count="multiple",
+                         file_types=["pdf"],
+                         interactive=True,
+                         label="Upload your PDF file/s",
+                     )
+                     pdf_btn = gr.Button(value="Next", elem_id="button1")
+
+         with gr.TabItem("Extract Data", id=1) as preprocess:
+             with gr.Row():
+                 with gr.Column():
+                     back_p1 = gr.Button(value="Back")
+                 with gr.Column():
+                     embed = gr.Button(value="Extract Data")
+                 with gr.Column():
+                     next_p1 = gr.Button(value="Next")
+             with gr.Row():
+                 include_images = gr.Radio(
+                     ["Include Images", "Exclude Images"],
+                     value="Include Images",
+                     label="Include/ Exclude Images",
+                     interactive=True,
+                 )
+                 do_ocr = gr.Radio(
+                     ["Get Text With OCR", "Get Available Text Only"],
+                     value="Get Text With OCR",
+                     label="OCR/ No OCR",
+                     interactive=True,
+                 )
+
+             with gr.Row(equal_height=True, variant="panel") as row:
+                 selected = gr.Dataframe(
+                     interactive=False,
+                     col_count=(1, "fixed"),
+                     headers=["Selected Files"],
+                 )
+                 prog = gr.HTML(
+                     value="<h1 style='text-align: center'>Click the 'Extract' button to extract data from PDFs</h1>"
+                 )
+
+             with gr.Accordion("See Parts of Extracted Data", open=False):
+                 with gr.Column(visible=True) as sample_data:
+                     with gr.Row():
+                         with gr.Column():
+                             ext_text = gr.Textbox(
+                                 label="Sample Extracted Text", lines=15
+                             )
+                         with gr.Column():
+                             images = gr.Gallery(
+                                 label="Sample Extracted Images", columns=1, rows=2
+                             )
+
+         with gr.TabItem("Chat", id=2) as chat_tab:
+             with gr.Accordion("Config (Advanced) (Optional)", open=False):
+                 with gr.Row(variant="panel", equal_height=True):
+                     choice = gr.Radio(
+                         ["chromaDB"],
+                         value="chromaDB",
+                         label="Vector Database",
+                         interactive=True,
+                     )
+                     with gr.Accordion("Use your own model (optional)", open=False):
+                         hf_token = gr.Textbox(
+                             label="HuggingFace Token", interactive=True
+                         )
+                         model_path = gr.Textbox(label="Model Path", interactive=True)
+                 with gr.Row(variant="panel", equal_height=True):
+                     num_context = gr.Slider(
+                         label="Number of text context elements",
+                         minimum=1,
+                         maximum=20,
+                         step=1,
+                         interactive=True,
+                         value=3,
+                     )
+                     img_context = gr.Slider(
+                         label="Number of image context elements",
+                         minimum=1,
+                         maximum=10,
+                         step=1,
+                         interactive=True,
+                         value=2,
+                     )
+                 with gr.Row(variant="panel", equal_height=True):
+                     temp = gr.Slider(
+                         label="Temperature",
+                         minimum=0.1,
+                         maximum=1,
+                         step=0.1,
+                         interactive=True,
+                         value=0.4,
+                     )
+                     max_tokens = gr.Slider(
+                         label="Max Tokens",
+                         minimum=10,
+                         maximum=2000,
+                         step=10,
+                         interactive=True,
+                         value=500,
+                     )
+             with gr.Row():
+                 with gr.Column():
+                     ret_images = gr.Gallery(label="Similar Images", columns=1, rows=2)
+                 with gr.Column():
+                     chatbot = gr.Chatbot(height=400)
+             with gr.Accordion("Text References", open=False):
+                 # text_context = gr.Row()
+
+                 @gr.render(inputs=references)
+                 def gen_refs(references):
+                     # print(references)
+                     n = len(references)
+                     for i in range(n):
+                         gr.Textbox(
+                             label=f"Reference-{i+1}", value=references[i], lines=3
+                         )
+
+             with gr.Row():
+                 msg = gr.Textbox(
+                     placeholder="Type your question here (e.g. 'What is this document about?')",
+                     interactive=True,
+                     container=True,
+                 )
+             with gr.Row():
+                 submit_btn = gr.Button("Submit message")
+                 clear_btn = gr.ClearButton([msg, chatbot], value="Clear conversation")
+
+     pdf_btn.click(
+         fn=extract_pdfs,
+         inputs=[documents, doc_collection],
+         outputs=[doc_collection, tabs, selected],
+     )
+     embed.click(
+         extract_data_from_pdfs,
+         inputs=[doc_collection, session_states, include_images, do_ocr],
+         outputs=[
+             vectordb,
+             session_states,
+             sample_data,
+             ext_text,
+             images,
+             prog,
+         ],
+     )
+
+     submit_btn.click(
+         conversation,
+         [
+             vectordb,
+             msg,
+             num_context,
+             img_context,
+             chatbot,
+             temp,
+             max_tokens,
+             hf_token,
+             model_path,
+         ],
+         [chatbot, references, ret_images],
+     )
+     msg.submit(
+         conversation,
+         [
+             vectordb,
+             msg,
+             num_context,
+             img_context,
+             chatbot,
+             temp,
+             max_tokens,
+             hf_token,
+             model_path,
+         ],
+         [chatbot, references, ret_images],
+     )
+
+     documents.change(
+         lambda: "<h1 style='text-align: center'>Click the 'Extract' button to extract data from PDFs</h1>",
+         None,
+         prog,
+     )
+
+     back_p1.click(lambda: gr.Tabs(selected=0), None, tabs)
+
+     next_p1.click(check_validity_and_llm, session_states, tabs)
+ if __name__ == "__main__":
+     demo.launch()
packages.txt ADDED
@@ -0,0 +1,4 @@
+ poppler-utils
+ tesseract-ocr
+ libtesseract-dev
+ ghostscript
requirements.txt ADDED
@@ -0,0 +1,15 @@
+ chromadb==0.5.3
+ langchain==0.2.5
+ langchain_community==0.2.5
+ langchain-huggingface
+ numpy<2.0.0
+ pandas==2.2.2
+ Pillow==10.3.0
+ pymupdf==1.24.5
+ sentence_transformers==3.0.1
+ accelerate
+ bitsandbytes
+ tf2onnx
+ clean-text[gpl]
+ python-doctr[torch]
+ pypdf
utils.py ADDED
@@ -0,0 +1,53 @@
+ import pymupdf
+ from PIL import Image
+ import io
+ import gradio as gr
+ import base64
+ import pandas as pd
+
+
+ def image_to_bytes(image):
+     img_byte_arr = io.BytesIO()
+     image.save(img_byte_arr, format="PNG")
+     return base64.b64encode(img_byte_arr.getvalue()).decode("utf-8")
+
+
+ def extract_pdfs(docs, doc_collection):
+     if docs:
+         doc_collection = []
+         doc_collection.extend(docs)
+     return (
+         doc_collection,
+         gr.Tabs(selected=1),
+         pd.DataFrame([i.split("/")[-1] for i in list(docs)], columns=["Filename"]),
+     )
+
+
+ def extract_images(docs):
+     images = []
+     for doc_path in docs:
+         doc = pymupdf.open(doc_path)
+
+         for page_index in range(len(doc)):
+             page = doc[page_index]
+             image_list = page.get_images()
+
+             for _, img in enumerate(image_list, start=1):
+                 xref = img[0]
+                 pix = pymupdf.Pixmap(doc, xref)
+
+                 if pix.n - pix.alpha > 3:
+                     pix = pymupdf.Pixmap(pymupdf.csRGB, pix)
+
+                 images.append(Image.open(io.BytesIO(pix.pil_tobytes("JPEG"))))
+     return images
+
+
+ def clean_text(text):
+     text = text.strip()
+     cleaned_text = text.replace("\n", " ")
+     cleaned_text = cleaned_text.replace("\t", " ")
+     cleaned_text = cleaned_text.replace("  ", " ")
+     cleaned_text = cleaned_text.strip()
+     return cleaned_text