# DocumentQA / app.py
# Epoching's picture
# Update app.py
# 95b0c4c
# Copyright (c) 2022, Lawrence Livermore National Security, LLC.
# All rights reserved.
# See the top-level LICENSE and NOTICE files for details.
# LLNL-CODE-838964
# SPDX-License-Identifier: Apache-2.0-with-LLVM-exception
import torch
import gradio as gr
from pathlib import Path
from torchvision.transforms import ToPILImage, ToTensor
# Module-level converter instances (tensor -> PIL image, PIL image -> tensor).
tensor_to_image = ToPILImage()
image_to_tensor = ToTensor()
import sys
# These search-path entries are added BEFORE the project imports below, so the
# order of this section matters. Presumably the directories contain
# dit_runner, sentence_extractor, cross_encoder and demo_QA -- TODO confirm
# against the repository layout.
sys.path.append('DiT_Extractor/')
sys.path.append('CrossEncoder/')
sys.path.append('UnifiedQA/')
import dit_runner
import sentence_extractor
import cross_encoder
import demo_QA
# NOTE(review): the next two lines repeat the ToPILImage import and the
# tensor_to_image assignment already made above; they are redundant.
from torchvision.transforms import ToPILImage
tensor_to_image = ToPILImage()
def run_fn(pdf_file_obj, question_text, input_topk):
    """Run the document-QA pipeline on one uploaded PDF.

    The stages are called in this order: ``dit_runner.get_dit_preds``,
    ``sentence_extractor.get_contexts``, ``cross_encoder.get_ranked_contexts``
    and ``demo_QA.get_qa_results``. The stages are linked only by JSON file
    names derived from the PDF file name; presumably each stage writes the
    file the next one reads (verify in the respective modules).

    Args:
        pdf_file_obj: Uploaded file object. Only its ``name`` attribute is
            read and used as the path of the PDF.
        question_text: Question forwarded to the cross-encoder ranking step.
        input_topk: Top-k value forwarded to ``demo_QA.get_qa_results``;
            coerced with ``int()``.

    Returns:
        A 4-tuple ``(viz_images, contexts_json, ranked_contexts_json,
        history)``: the value returned by ``dit_runner.get_dit_preds``, the
        two derived JSON file names, and a list of
        ``(scored context text, answer)`` pairs in reverse of the order
        reported by ``demo_QA.get_qa_results``.

    Raises:
        ValueError: If ``pdf_file_obj`` is ``None`` (no file was supplied),
            or if ``input_topk`` cannot be converted by ``int()``.
    """
    # Fail fast with a readable message instead of an AttributeError on
    # ``None.name`` when the handler is triggered without an uploaded file.
    if pdf_file_obj is None:
        raise ValueError('No PDF file was provided; upload a PDF before running QA.')

    # Coerce before the pipeline starts so a bad value is reported
    # immediately rather than after the expensive stages have already run.
    top_k = int(input_topk)

    pdf = pdf_file_obj.name
    print('Running PDF: {0}'.format(pdf))

    viz_images = dit_runner.get_dit_preds(pdf, score_threshold=0.5)

    # Drops the last four characters of the file name (i.e. assumes a
    # '.pdf'-style suffix). Kept as-is because the resulting name presumably
    # has to match the file produced by dit_runner -- TODO confirm.
    entity_json = '{0}.json'.format(Path(pdf).name[:-4])
    sentence_extractor.get_contexts(entity_json)

    contexts_json = 'contexts_{0}'.format(entity_json)
    cross_encoder.get_ranked_contexts(contexts_json, question_text)

    ranked_contexts_json = 'ranked_{0}'.format(contexts_json)
    qa_results = demo_QA.get_qa_results(contexts_json, ranked_contexts_json, top_k)

    history = [
        ('<<< [Retrieval Score: {0:.02f}] >>> {1}'.format(score, context), answer)
        for context, score, answer in zip(
            qa_results['contexts'],
            qa_results['context_scores'],
            qa_results['answers'],
        )
    ]
    # Show in ascending order of score, since the results box is already
    # scrolled down (assumes qa_results lists the best contexts first).
    history.reverse()

    return viz_images, contexts_json, ranked_contexts_json, history
# Static content for the demo page. The long Markdown bodies are kept at module
# level so their text is passed to gr.Markdown exactly as written here.
_SUMMARY_MARKDOWN = '''
- <u>**DiT - Document Image Transformer**</u>: PDF -> converted into a list of images -> each image receives Entity Predictions
- Note that using this computer vision approach allows us to ignore things like *page numbers, footnotes, references*, etc
- <u>**Paragraph-based Text Extraction**</u>: DiT Bounding Boxes -> Convert into PDF-Space Coordinates -> Text Extraction using PDFMiner6 -> Tokenize & Sentence Split if tokenizer max length is exceeded
- <u>**CrossEncoder Context Retrieval**</u>: All Contexts + Question -> Top K Relevant Contexts best suited for answering question
- <u>**UnifiedQA**</u>: Most Relevant Contexts + Supplied Question -> Predict Set of Probable Answers
'''

_GALLERY_NOTES_MARKDOWN = '''
- The `DiT predicted Entities` output box is scrollable! Scroll to see different page predictions. Note that predictions with confidence scores < 0.5 are not passed forward for text extraction.
- If an image is clicked, the output box will switch to a gallery view. To view these outputs in much higher resolution, right-click and choose "open image in new tab"
'''

# Each example row is [PDF path, question, top-k], matching the three inputs
# handed to gr.Examples below.
examples = [
    ['examples/1909.00694.pdf', 'What is the seed lexicon?', 5],
    ['examples/1909.00694.pdf', 'How big is seed lexicon used for training?', 5],
    ['examples/1810.04805.pdf', 'What is this paper about?', 5],
    ['examples/1810.04805.pdf', 'What is the model size?', 5],
    ['examples/2105.03011.pdf', 'How many questions are in this dataset?', 5],
    ['examples/1909.00694.pdf', 'How are relations used to propagate polarity?', 5],
]

# NOTE(review): the original source had lost its indentation; the nesting of
# the `with` blocks below is reconstructed from the statement order (inputs in
# one column, outputs in the other). Confirm against the deployed layout.
with gr.Blocks() as demo:
    # ---- Page header ----
    gr.Markdown("<h1><center>Detect-Retrieve-Comprehend for Document-Level QA</center></h1>")
    gr.Markdown("<center>This is a supplemental demo for our recent paper, expected to be publically available around October: <b>Detect, Retrieve, Comprehend: A Flexible Framework for Zero-Shot Document-Level Question Answering</b>. In this system, our input is a PDF file with a specific question of interest. The output is a set of most probable answers. There are 4 main components in our deployed pipeline: (1) DiT Layout Analysis (2) Context Extraction (3) Cross-Encoder Retrieval (4) UnifiedQA. See below for example uses with further explanation. Note that demo runtimes may be between 2-8 minutes, since this is currently cpu-based Space.</center>")

    with gr.Row():
        # ---- First column: inputs, pipeline summary and examples ----
        with gr.Column():
            with gr.Row():
                input_pdf_file = gr.File(file_count='single', label='PDF File')
            with gr.Row():
                input_question_text = gr.Textbox(label='Question')
            with gr.Row():
                input_k_percent = gr.Slider(
                    minimum=1,
                    maximum=24,
                    step=1,
                    value=8,
                    label='Top K',
                )
            with gr.Row():
                button_run = gr.Button('Run QA on Document')

            gr.Markdown("<h3><center>Summary</center></h3>")
            with gr.Row():
                gr.Markdown(_SUMMARY_MARKDOWN)
            with gr.Row():
                gr.Examples(
                    examples=examples,
                    inputs=[input_pdf_file, input_question_text, input_k_percent],
                )

        # ---- Second column: pipeline outputs ----
        with gr.Column():
            with gr.Row():
                output_gallery = gr.Gallery(label='DiT Predicted Entities')
            with gr.Row():
                gr.Markdown(_GALLERY_NOTES_MARKDOWN)
            with gr.Row():
                output_contexts = gr.File(label='Detected Contexts', interactive=False)
                output_ranked_contexts = gr.File(label='CrossEncoder Ranked Contexts', interactive=False)
            with gr.Row():
                output_qa_results = gr.Chatbot(color_map=['blue', 'green'], label='UnifiedQA Results').style()

    # ---- Footer: related work ----
    gr.Markdown("<h3><center>Related Work & Code</center></h3>")
    gr.Markdown("<center>DiT (Document Image Transformer) - <a href=https://arxiv.org/abs/2203.02378>Arxiv Page</a> | <a href=https://github.com/microsoft/unilm/tree/master/dit>Github Repo</a></center>")
    # NOTE(review): the CrossEncoder entry uses the same two URLs as the DiT
    # entry above; this looks like a copy-paste leftover -- confirm the
    # intended links.
    gr.Markdown("<center>CrossEncoder - <a href=https://arxiv.org/abs/2203.02378>Arxiv Page</a> | <a href=https://github.com/microsoft/unilm/tree/master/dit>Github Repo</a></center>")
    gr.Markdown("<center>UnifiedQA - <a href=https://arxiv.org/abs/2005.00700>Arxiv Page</a> | <a href=https://github.com/allenai/unifiedqa>Github Repo</a></center>")

    # Wire the button to the pipeline: three inputs in, four outputs back, in
    # the same order as run_fn's parameters and return tuple.
    button_run.click(
        fn=run_fn,
        inputs=[input_pdf_file, input_question_text, input_k_percent],
        outputs=[output_gallery, output_contexts, output_ranked_contexts, output_qa_results],
    )

demo.launch(enable_queue=True)