Spaces:
Build error
Build error
# Copyright (c) 2022, Lawrence Livermore National Security, LLC. | |
# All rights reserved. | |
# See the top-level LICENSE and NOTICE files for details. | |
# LLNL-CODE-838964 | |
# SPDX-License-Identifier: Apache-2.0-with-LLVM-exception | |
import sys
from pathlib import Path

import torch
import gradio as gr
from torchvision.transforms import ToPILImage, ToTensor

# Module-level tensor <-> PIL image converters.
tensor_to_image = ToPILImage()
image_to_tensor = ToTensor()

# The pipeline components live in sub-directories next to this script rather
# than in installed packages, so those directories are added to the import
# path. Resolving them against this file (instead of the current working
# directory) keeps the imports below independent of where Python was started.
_APP_DIR = Path(__file__).resolve().parent
sys.path.append(str(_APP_DIR / 'DiT_Extractor'))
sys.path.append(str(_APP_DIR / 'CrossEncoder'))
sys.path.append(str(_APP_DIR / 'UnifiedQA'))

# These imports must follow the sys.path setup above.
import dit_runner  # noqa: E402
import sentence_extractor  # noqa: E402
import cross_encoder  # noqa: E402
import demo_QA  # noqa: E402
def run_fn(pdf_file_obj, question_text, input_topk):
    """Run the Detect-Retrieve-Comprehend pipeline on one uploaded PDF.

    The stages are chained through JSON files whose names are derived from
    the PDF's file name:

    1. ``dit_runner.get_dit_preds`` -- DiT layout predictions; also returns
       the visualisation images shown in the gallery output.
    2. ``sentence_extractor.get_contexts`` -- given ``<name>.json``; as the
       names used below imply, it yields ``contexts_<name>.json``.
    3. ``cross_encoder.get_ranked_contexts`` -- ranks those contexts against
       the question; presumably yields ``ranked_contexts_<name>.json``.
    4. ``demo_QA.get_qa_results`` -- answers the question using the contexts
       and their ranking, limited by ``input_topk``.

    Args:
        pdf_file_obj: Uploaded file object from ``gr.File``; only its
            ``.name`` attribute (a filesystem path) is used.
        question_text: The question to answer.
        input_topk: Number of top contexts to pass to the QA stage. It is
            converted with ``int()`` because the slider may supply a float.

    Returns:
        A 4-tuple ``(viz_images, contexts_json, ranked_contexts_json,
        history)``. ``history`` is a list of ``(context_label, answer)``
        pairs for the chatbot output.

    Raises:
        ValueError: If no PDF was uploaded, the question is blank, or
            ``input_topk`` is not convertible to ``int``.
    """
    # Validate cheap inputs first so bad input fails before the multi-minute
    # pipeline starts (previously a missing upload raised AttributeError).
    if pdf_file_obj is None:
        raise ValueError('Please upload a PDF file before running QA.')
    if not (question_text or '').strip():
        raise ValueError('Please enter a question before running QA.')
    topk = int(input_topk)

    pdf = pdf_file_obj.name
    print('Running PDF: {0}'.format(pdf))

    # Stage 1: DiT layout analysis with a 0.5 confidence threshold.
    viz_images = dit_runner.get_dit_preds(pdf, score_threshold=0.5)

    # Intermediate file names come from the PDF's base name with its last
    # four characters (the ".pdf" extension) stripped.
    # NOTE(review): this must match the name dit_runner writes; confirm
    # before switching to Path(pdf).stem.
    entity_json = '{0}.json'.format(Path(pdf).name[:-4])

    # Stage 2: extract text contexts from the detected entities.
    sentence_extractor.get_contexts(entity_json)
    contexts_json = 'contexts_{0}'.format(entity_json)

    # Stage 3: rank the contexts against the question.
    cross_encoder.get_ranked_contexts(contexts_json, question_text)
    ranked_contexts_json = 'ranked_{0}'.format(contexts_json)

    # Stage 4: question answering over the ranked contexts.
    qa_results = demo_QA.get_qa_results(contexts_json, ranked_contexts_json, topk)

    history = [
        ('<<< [Retrieval Score: {0:.02f}] >>> {1}'.format(score, context), answer)
        for context, score, answer in zip(
            qa_results['contexts'],
            qa_results['context_scores'],
            qa_results['answers'],
        )
    ]
    # Show in ascending order of score, since the results box is already
    # scrolled down.
    history.reverse()
    return viz_images, contexts_json, ranked_contexts_json, history
# Build the Gradio UI:
# - left column: inputs (PDF, question, top-k), run button, pipeline summary
#   and examples;
# - right column: DiT visualisations, intermediate context files and QA
#   results;
# - bottom: links to the related work.
demo = gr.Blocks()
with demo:
    gr.Markdown("<h1><center>Detect-Retrieve-Comprehend for Document-Level QA</center></h1>")
    gr.Markdown("<center>This is a supplemental demo for our recent paper, expected to be publicly available around October: <b>Detect, Retrieve, Comprehend: A Flexible Framework for Zero-Shot Document-Level Question Answering</b>. In this system, our input is a PDF file with a specific question of interest. The output is a set of most probable answers. There are 4 main components in our deployed pipeline: (1) DiT Layout Analysis (2) Context Extraction (3) Cross-Encoder Retrieval (4) UnifiedQA. See below for example uses with further explanation. Note that demo runtimes may be between 2-8 minutes, since this is currently a CPU-based Space.</center>")
    with gr.Row():
        with gr.Column():
            with gr.Row():
                input_pdf_file = gr.File(file_count='single', label='PDF File')
            with gr.Row():
                input_question_text = gr.Textbox(label='Question')
            with gr.Row():
                # Despite the variable name, this is an absolute count (Top K),
                # not a percentage.
                input_k_percent = gr.Slider(minimum=1, maximum=24, step=1, value=8, label='Top K')
            with gr.Row():
                button_run = gr.Button('Run QA on Document')
            gr.Markdown("<h3><center>Summary</center></h3>")
            with gr.Row():
                gr.Markdown('''
- <u>**DiT - Document Image Transformer**</u>: PDF -> converted into a list of images -> each image receives Entity Predictions
    - Note that using this computer vision approach allows us to ignore things like *page numbers, footnotes, references*, etc
- <u>**Paragraph-based Text Extraction**</u>: DiT Bounding Boxes -> Convert into PDF-Space Coordinates -> Text Extraction using PDFMiner6 -> Tokenize & Sentence Split if tokenizer max length is exceeded
- <u>**CrossEncoder Context Retrieval**</u>: All Contexts + Question -> Top K Relevant Contexts best suited for answering question
- <u>**UnifiedQA**</u>: Most Relevant Contexts + Supplied Question -> Predict Set of Probable Answers
''')
            with gr.Row():
                # Each example row is [pdf path, question, top-k], matching the
                # order of `inputs` below.
                examples = [
                    ['examples/1909.00694.pdf', 'What is the seed lexicon?', 5],
                    ['examples/1909.00694.pdf', 'How big is seed lexicon used for training?', 5],
                    ['examples/1810.04805.pdf', 'What is this paper about?', 5],
                    ['examples/1810.04805.pdf', 'What is the model size?', 5],
                    ['examples/2105.03011.pdf', 'How many questions are in this dataset?', 5],
                    ['examples/1909.00694.pdf', 'How are relations used to propagate polarity?', 5],
                ]
                gr.Examples(examples=examples,
                            inputs=[input_pdf_file, input_question_text, input_k_percent])
        with gr.Column():
            with gr.Row():
                output_gallery = gr.Gallery(label='DiT Predicted Entities')
            with gr.Row():
                gr.Markdown('''
- The `DiT predicted Entities` output box is scrollable! Scroll to see different page predictions. Note that predictions with confidence scores < 0.5 are not passed forward for text extraction.
- If an image is clicked, the output box will switch to a gallery view. To view these outputs in much higher resolution, right-click and choose "open image in new tab"
''')
            with gr.Row():
                output_contexts = gr.File(label='Detected Contexts', interactive=False)
                output_ranked_contexts = gr.File(label='CrossEncoder Ranked Contexts', interactive=False)
            with gr.Row():
                output_qa_results = gr.Chatbot(color_map=['blue', 'green'], label='UnifiedQA Results')
    gr.Markdown("<h3><center>Related Work & Code</center></h3>")
    gr.Markdown("<center>DiT (Document Image Transformer) - <a href=https://arxiv.org/abs/2203.02378>Arxiv Page</a> | <a href=https://github.com/microsoft/unilm/tree/master/dit>Github Repo</a></center>")
    # NOTE(review): these links previously duplicated the DiT ones; assumed the
    # cross-encoder comes from sentence-transformers -- confirm.
    gr.Markdown("<center>CrossEncoder - <a href=https://arxiv.org/abs/1908.10084>Arxiv Page</a> | <a href=https://github.com/UKPLab/sentence-transformers>Github Repo</a></center>")
    gr.Markdown("<center>UnifiedQA - <a href=https://arxiv.org/abs/2005.00700>Arxiv Page</a> | <a href=https://github.com/allenai/unifiedqa>Github Repo</a></center>")
    # Output order must match run_fn's return tuple:
    # (viz_images, contexts_json, ranked_contexts_json, history).
    button_run.click(fn=run_fn, inputs=[input_pdf_file, input_question_text, input_k_percent], outputs=[output_gallery, output_contexts, output_ranked_contexts, output_qa_results])
demo.launch(enable_queue=True)