McCoasta commited on
Commit
c9bae40
1 Parent(s): 20e36ae

Upload 4 files

Browse files
Files changed (4) hide show
  1. RTDB_security_rules.txt +3 -0
  2. app.py +423 -0
  3. packages.txt +4 -0
  4. requirements.txt +3 -0
RTDB_security_rules.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ torch
2
+ git+https://github.com/huggingface/transformers.git@21f6f58721dd9154357576be6de54eefef1f1818
3
+ git+https://github.com/impira/docquery.git@a494fe5af452d20011da75637aa82d246a869fa0#egg=docquery[web,donut]
app.py ADDED
@@ -0,0 +1,423 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
4
+
5
+ from PIL import Image, ImageDraw
6
+ import traceback
7
+
8
+ import gradio as gr
9
+
10
+ import torch
11
+ from docquery import pipeline
12
+ from docquery.document import load_document, ImageDocument
13
+ from docquery.ocr_reader import get_ocr_reader
14
+
15
+
16
+ def ensure_list(x):
17
+ if isinstance(x, list):
18
+ return x
19
+ else:
20
+ return [x]
21
+
22
+
23
+ CHECKPOINTS = {
24
+ "LayoutLMv1": "impira/layoutlm-document-qa",
25
+ "LayoutLMv1 for Invoices": "impira/layoutlm-invoices",
26
+ "Donut": "naver-clova-ix/donut-base-finetuned-docvqa",
27
+ }
28
+
29
+ PIPELINES = {}
30
+
31
+
32
+ def construct_pipeline(task, model):
33
+ global PIPELINES
34
+ if model in PIPELINES:
35
+ return PIPELINES[model]
36
+
37
+ device = "cuda" if torch.cuda.is_available() else "cpu"
38
+ ret = pipeline(task=task, model=CHECKPOINTS[model], device=device)
39
+ PIPELINES[model] = ret
40
+ return ret
41
+
42
+
43
+ def run_pipeline(model, question, document, top_k):
44
+ pipeline = construct_pipeline("document-question-answering", model)
45
+ return pipeline(question=question, **document.context, top_k=top_k)
46
+
47
+
48
+ # TODO: Move into docquery
49
+ # TODO: Support words past the first page (or window?)
50
+ def lift_word_boxes(document, page):
51
+ return document.context["image"][page][1]
52
+
53
+
54
+ def expand_bbox(word_boxes):
55
+ if len(word_boxes) == 0:
56
+ return None
57
+
58
+ min_x, min_y, max_x, max_y = zip(*[x[1] for x in word_boxes])
59
+ min_x, min_y, max_x, max_y = [min(min_x), min(min_y), max(max_x), max(max_y)]
60
+ return [min_x, min_y, max_x, max_y]
61
+
62
+
63
+ # LayoutLM boxes are normalized to 0, 1000
64
+ def normalize_bbox(box, width, height, padding=0.005):
65
+ min_x, min_y, max_x, max_y = [c / 1000 for c in box]
66
+ if padding != 0:
67
+ min_x = max(0, min_x - padding)
68
+ min_y = max(0, min_y - padding)
69
+ max_x = min(max_x + padding, 1)
70
+ max_y = min(max_y + padding, 1)
71
+ return [min_x * width, min_y * height, max_x * width, max_y * height]
72
+
73
+
74
+ examples = [
75
+ [
76
+ "invoice.png",
77
+ "What is the invoice number?",
78
+ ],
79
+ [
80
+ "contract.jpeg",
81
+ "What is the purchase amount?",
82
+ ],
83
+ [
84
+ "statement.png",
85
+ "What are net sales for 2020?",
86
+ ],
87
+ # [
88
+ # "docquery.png",
89
+ # "How many likes does the space have?",
90
+ # ],
91
+ # [
92
+ # "hacker_news.png",
93
+ # "What is the title of post number 5?",
94
+ # ],
95
+ ]
96
+
97
+ question_files = {
98
+ "What are net sales for 2020?": "statement.pdf",
99
+ "How many likes does the space have?": "https://huggingface.co/spaces/impira/docquery",
100
+ "What is the title of post number 5?": "https://news.ycombinator.com",
101
+ }
102
+
103
+
104
+ def process_path(path):
105
+ error = None
106
+ if path:
107
+ try:
108
+ document = load_document(path)
109
+ return (
110
+ document,
111
+ gr.update(visible=True, value=document.preview),
112
+ gr.update(visible=True),
113
+ gr.update(visible=False, value=None),
114
+ gr.update(visible=False, value=None),
115
+ None,
116
+ )
117
+ except Exception as e:
118
+ traceback.print_exc()
119
+ error = str(e)
120
+ return (
121
+ None,
122
+ gr.update(visible=False, value=None),
123
+ gr.update(visible=False),
124
+ gr.update(visible=False, value=None),
125
+ gr.update(visible=False, value=None),
126
+ gr.update(visible=True, value=error) if error is not None else None,
127
+ None,
128
+ )
129
+
130
+
131
+ def process_upload(file):
132
+ if file:
133
+ return process_path(file.name)
134
+ else:
135
+ return (
136
+ None,
137
+ gr.update(visible=False, value=None),
138
+ gr.update(visible=False),
139
+ gr.update(visible=False, value=None),
140
+ gr.update(visible=False, value=None),
141
+ None,
142
+ )
143
+
144
+
145
+ colors = ["#64A087", "black", "black"]
146
+
147
+
148
+ def process_question(question, document, model=list(CHECKPOINTS.keys())[0]):
149
+ if not question or document is None:
150
+ return None, None, None
151
+
152
+ text_value = None
153
+ predictions = run_pipeline(model, question, document, 3)
154
+ pages = [x.copy().convert("RGB") for x in document.preview]
155
+ for i, p in enumerate(ensure_list(predictions)):
156
+ if i == 0:
157
+ text_value = p["answer"]
158
+ else:
159
+ # Keep the code around to produce multiple boxes, but only show the top
160
+ # prediction for now
161
+ break
162
+
163
+ if "word_ids" in p:
164
+ image = pages[p["page"]]
165
+ draw = ImageDraw.Draw(image, "RGBA")
166
+ word_boxes = lift_word_boxes(document, p["page"])
167
+ x1, y1, x2, y2 = normalize_bbox(
168
+ expand_bbox([word_boxes[i] for i in p["word_ids"]]),
169
+ image.width,
170
+ image.height,
171
+ )
172
+ draw.rectangle(((x1, y1), (x2, y2)), fill=(0, 255, 0, int(0.4 * 255)))
173
+
174
+ return (
175
+ gr.update(visible=True, value=pages),
176
+ gr.update(visible=True, value=predictions),
177
+ gr.update(
178
+ visible=True,
179
+ value=text_value,
180
+ ),
181
+ )
182
+
183
+
184
+ def load_example_document(img, question, model):
185
+ if img is not None:
186
+ if question in question_files:
187
+ document = load_document(question_files[question])
188
+ else:
189
+ document = ImageDocument(Image.fromarray(img), get_ocr_reader())
190
+ preview, answer, answer_text = process_question(question, document, model)
191
+ return document, question, preview, gr.update(visible=True), answer, answer_text
192
+ else:
193
+ return None, None, None, gr.update(visible=False), None, None
194
+
195
+
196
+ CSS = """
197
+ #question input {
198
+ font-size: 16px;
199
+ }
200
+ #url-textbox {
201
+ padding: 0 !important;
202
+ }
203
+ #short-upload-box .w-full {
204
+ min-height: 10rem !important;
205
+ }
206
+ /* I think something like this can be used to re-shape
207
+ * the table
208
+ */
209
+ /*
210
+ .gr-samples-table tr {
211
+ display: inline;
212
+ }
213
+ .gr-samples-table .p-2 {
214
+ width: 100px;
215
+ }
216
+ */
217
+ #select-a-file {
218
+ width: 100%;
219
+ }
220
+ #file-clear {
221
+ padding-top: 2px !important;
222
+ padding-bottom: 2px !important;
223
+ padding-left: 8px !important;
224
+ padding-right: 8px !important;
225
+ margin-top: 10px;
226
+ }
227
+ .gradio-container .gr-button-primary {
228
+ background: linear-gradient(180deg, #FAED27 0%, #FAED27 100%);
229
+ border: 1px solid #000000;
230
+ border-radius: 8px;
231
+ color: #000000;
232
+ }
233
+ .gradio-container.dark button#submit-button {
234
+ background: linear-gradient(180deg, #FAED27 0%, #FAED27 100%);
235
+ border: 1px solid #000000;
236
+ border-radius: 8px;
237
+ color: #000000
238
+ }
239
+
240
+ table.gr-samples-table tr td {
241
+ border: none;
242
+ outline: none;
243
+ }
244
+
245
+ table.gr-samples-table tr td:first-of-type {
246
+ width: 0%;
247
+ }
248
+
249
+ div#short-upload-box div.absolute {
250
+ display: none !important;
251
+ }
252
+
253
+ gradio-app > div > div > div > div.w-full > div, .gradio-app > div > div > div > div.w-full > div {
254
+ gap: 0px 2%;
255
+ }
256
+
257
+ gradio-app div div div div.w-full, .gradio-app div div div div.w-full {
258
+ gap: 0px;
259
+ }
260
+
261
+ gradio-app h2, .gradio-app h2 {
262
+ padding-top: 10px;
263
+ }
264
+
265
+ #answer {
266
+ overflow-y: scroll;
267
+ color: white;
268
+ background: #666;
269
+ border-color: #666;
270
+ font-size: 20px;
271
+ font-weight: bold;
272
+ }
273
+
274
+ #answer span {
275
+ color: white;
276
+ }
277
+
278
+ #answer textarea {
279
+ color:white;
280
+ background: #777;
281
+ border-color: #777;
282
+ font-size: 18px;
283
+ }
284
+
285
+ #url-error input {
286
+ color: red;
287
+ }
288
+ """
289
+
290
+ with gr.Blocks(css=CSS) as demo:
291
+ gr.Markdown()
292
+ gr.Markdown(
293
+
294
+ )
295
+
296
+ document = gr.Variable()
297
+ example_question = gr.Textbox(visible=False)
298
+ example_image = gr.Image(visible=False)
299
+
300
+ with gr.Row(equal_height=True):
301
+ with gr.Column():
302
+ with gr.Row():
303
+ gr.Markdown("## 1. Select a file", elem_id="select-a-file")
304
+ img_clear_button = gr.Button(
305
+ "Clear", variant="secondary", elem_id="file-clear", visible=False
306
+ )
307
+ image = gr.Gallery(visible=False)
308
+ with gr.Row(equal_height=True):
309
+ with gr.Column():
310
+ with gr.Row():
311
+ url = gr.Textbox(
312
+ show_label=False,
313
+ placeholder="URL",
314
+ lines=1,
315
+ max_lines=1,
316
+ elem_id="url-textbox",
317
+ )
318
+ submit = gr.Button("Get")
319
+ url_error = gr.Textbox(
320
+ visible=False,
321
+ elem_id="url-error",
322
+ max_lines=1,
323
+ interactive=False,
324
+ label="Error",
325
+ )
326
+ gr.Markdown("— or —")
327
+ upload = gr.File(label=None, interactive=True, elem_id="short-upload-box")
328
+ gr.Examples(
329
+ examples=examples,
330
+ inputs=[example_image, example_question],
331
+ )
332
+
333
+ with gr.Column() as col:
334
+ gr.Markdown("## 2. Ask a question")
335
+ question = gr.Textbox(
336
+ label="Question",
337
+ placeholder="e.g. What is the invoice number?",
338
+ lines=1,
339
+ max_lines=1,
340
+ )
341
+ model = gr.Radio(
342
+ choices=list(CHECKPOINTS.keys()),
343
+ value=list(CHECKPOINTS.keys())[0],
344
+ label="Model",
345
+ )
346
+
347
+ with gr.Row():
348
+ clear_button = gr.Button("Clear", variant="secondary")
349
+ submit_button = gr.Button(
350
+ "Submit", variant="primary", elem_id="submit-button"
351
+ )
352
+ with gr.Column():
353
+ output_text = gr.Textbox(
354
+ label="Top Answer", visible=False, elem_id="answer"
355
+ )
356
+ output = gr.JSON(label="Output", visible=False)
357
+
358
+ for cb in [img_clear_button, clear_button]:
359
+ cb.click(
360
+ lambda _: (
361
+ gr.update(visible=False, value=None),
362
+ None,
363
+ gr.update(visible=False, value=None),
364
+ gr.update(visible=False, value=None),
365
+ gr.update(visible=False),
366
+ None,
367
+ None,
368
+ None,
369
+ gr.update(visible=False, value=None),
370
+ None,
371
+ ),
372
+ inputs=clear_button,
373
+ outputs=[
374
+ image,
375
+ document,
376
+ output,
377
+ output_text,
378
+ img_clear_button,
379
+ example_image,
380
+ upload,
381
+ url,
382
+ url_error,
383
+ question,
384
+ ],
385
+ )
386
+
387
+ upload.change(
388
+ fn=process_upload,
389
+ inputs=[upload],
390
+ outputs=[document, image, img_clear_button, output, output_text, url_error],
391
+ )
392
+ submit.click(
393
+ fn=process_path,
394
+ inputs=[url],
395
+ outputs=[document, image, img_clear_button, output, output_text, url_error],
396
+ )
397
+
398
+ question.submit(
399
+ fn=process_question,
400
+ inputs=[question, document, model],
401
+ outputs=[image, output, output_text],
402
+ )
403
+
404
+ submit_button.click(
405
+ process_question,
406
+ inputs=[question, document, model],
407
+ outputs=[image, output, output_text],
408
+ )
409
+
410
+ model.change(
411
+ process_question,
412
+ inputs=[question, document, model],
413
+ outputs=[image, output, output_text],
414
+ )
415
+
416
+ example_image.change(
417
+ fn=load_example_document,
418
+ inputs=[example_image, example_question, model],
419
+ outputs=[document, question, image, img_clear_button, output, output_text],
420
+ )
421
+
422
+ if __name__ == "__main__":
423
+ demo.launch(enable_queue=False)
packages.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ poppler-utils
2
+ tesseract-ocr
3
+ chromium
4
+ chromium-driver
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ torch
2
+ git+https://github.com/huggingface/transformers.git@21f6f58721dd9154357576be6de54eefef1f1818
3
+ git+https://github.com/impira/docquery.git@a494fe5af452d20011da75637aa82d246a869fa0#egg=docquery[web,donut]