Epoching committed on
Commit
c14d9ad
1 Parent(s): 0b25b77
Files changed (35)
  1. .gitignore +2 -0
  2. CrossEncoder/cross_encoder.py +122 -0
  3. CrossEncoder/cross_encoder_env.yml +53 -0
  4. DiT_Extractor/base_utils.py +378 -0
  5. DiT_Extractor/dit_object_detection/README.md +120 -0
  6. DiT_Extractor/dit_object_detection/ditod/__init__.py +11 -0
  7. DiT_Extractor/dit_object_detection/ditod/backbone.py +156 -0
  8. DiT_Extractor/dit_object_detection/ditod/beit.py +671 -0
  9. DiT_Extractor/dit_object_detection/ditod/config.py +32 -0
  10. DiT_Extractor/dit_object_detection/ditod/deit.py +476 -0
  11. DiT_Extractor/dit_object_detection/publaynet_configs/Base-RCNN-FPN.yaml +69 -0
  12. DiT_Extractor/dit_object_detection/publaynet_configs/cascade/cascade_dit_base.yaml +20 -0
  13. DiT_Extractor/dit_object_detection/publaynet_configs/cascade/cascade_dit_large.yaml +28 -0
  14. DiT_Extractor/dit_object_detection/publaynet_configs/maskrcnn/maskrcnn_dit_base.yaml +15 -0
  15. DiT_Extractor/dit_object_detection/publaynet_configs/maskrcnn/maskrcnn_dit_large.yaml +22 -0
  16. DiT_Extractor/dit_runner.py +158 -0
  17. DiT_Extractor/sentence_extractor.py +136 -0
  18. LICENSE +207 -0
  19. NOTICE +21 -0
  20. README.md +14 -4
  21. UnifiedQA/demo_QA.py +180 -0
  22. app.py +120 -0
  23. env_setup.sh +32 -0
  24. examples/1810.04805.pdf +0 -0
  25. examples/1909.00694.pdf +0 -0
  26. examples/2105.03011.pdf +0 -0
  27. ms-marco-electra-base/CEBinaryClassificationEvaluator_MS-Marco_results.csv +43 -0
  28. ms-marco-electra-base/README.md +64 -0
  29. ms-marco-electra-base/config.json +31 -0
  30. ms-marco-electra-base/pytorch_model.bin +3 -0
  31. ms-marco-electra-base/special_tokens_map.json +1 -0
  32. ms-marco-electra-base/tokenizer_config.json +1 -0
  33. ms-marco-electra-base/vocab.txt +0 -0
  34. packages.txt +1 -0
  35. requirements.txt +13 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
1
+ .ipynb_checkpoints
2
+ __pycache__
CrossEncoder/cross_encoder.py ADDED
@@ -0,0 +1,122 @@
1
+ # Copyright (c) 2022, Lawrence Livermore National Security, LLC.
2
+ # All rights reserved.
3
+ # See the top-level LICENSE and NOTICE files for details.
4
+ # LLNL-CODE-838964
5
+
6
+ # SPDX-License-Identifier: Apache-2.0-with-LLVM-exception
7
+
8
+ from sentence_transformers.cross_encoder import CrossEncoder as CE
9
+ import numpy as np
10
+ from typing import List, Dict, Tuple
11
+ import json
12
+ from collections import defaultdict
13
+ import os
14
+
15
+
16
+ class CrossEncoder:
17
+ def __init__(self,
18
+ model_path: str = None,
19
+ max_length: int = None,
20
+ **kwargs):
21
+
22
+ if max_length is not None:
23
+ self.model = CE(model_path, max_length = max_length, **kwargs)
24
+ else:
25
+ self.model = CE(model_path, **kwargs)
26
+
27
+
28
+ def predict(self,
29
+ sentences: List[Tuple[str, str]],
30
+ batch_size: int = 32,
31
+ show_progress_bar: bool = False) -> List[float]:
32
+
33
+ return self.model.predict(sentences = sentences,
34
+ batch_size = batch_size,
35
+ show_progress_bar = show_progress_bar)
36
+
37
+
38
+ class CERank:
39
+
40
+ def __init__(self, model, batch_size: int =128, **kwargs):
41
+ self.cross_encoder = model
42
+ self.batch_size = batch_size
43
+
44
+
45
+ def flatten_examples(self, contexts: Dict[str, Dict], question: str):
46
+
47
+ text_pairs, pair_ids = [], []
48
+ for context_id, context in contexts.items():
49
+ pair_ids.append(['question_0', context_id])
50
+ text_pairs.append([question, context['text']])
51
+
52
+ return text_pairs, pair_ids
53
+
54
+ def group_questionrank(self, pair_ids, rank_scores):
55
+
56
+ unsorted = defaultdict(list)
57
+ for pair, score in zip(pair_ids, rank_scores):
58
+ query_id, paragraph_id = pair[0], pair[1]
59
+ unsorted[query_id].append((paragraph_id, score))
60
+
61
+
62
+ return unsorted
63
+
64
+ def get_rankings(self, pair_ids, rank_scores, text_pairs):
65
+
66
+ unsorted_ranks = self.group_questionrank(pair_ids, rank_scores)
67
+ rankings = defaultdict(dict)
68
+
69
+ for idx, (query_id, ranks) in enumerate(unsorted_ranks.items()):
70
+ sort_ranks = sorted(ranks, key = lambda item: item[1], reverse = True)
71
+ sorted_ranks, scores = list(zip(*sort_ranks))
72
+ rankings[query_id]['text'] = text_pairs[idx][0]
73
+ rankings[query_id]['scores'] = list(scores)
74
+ rankings[query_id]['ranks'] = list(sorted_ranks)
75
+
76
+ return rankings
77
+
78
+
79
+ def rank(self,
80
+ contexts: Dict[str, Dict],
81
+ question: str):
82
+
83
+
84
+ text_pairs, pair_ids = self.flatten_examples(contexts, question)
85
+ rank_scores = [float(score) for score in self.cross_encoder.predict(text_pairs, batch_size = self.batch_size)]
86
+ full_results = self.get_rankings(pair_ids, rank_scores, text_pairs)
87
+
88
+ return full_results
89
+
90
+
91
+
92
+ def get_ranked_contexts(context_json, question):
93
+
94
+ dirname = 'examples'
95
+ model_path = '/data/actici/pretrained_weights/ms-marco-electra-base'
96
+ max_length = 512
97
+
98
+ # Can't use use_fast (fast tokenizers) while gradio is running, causes conflict with tokenizer multiprocessing/parallelism.
99
+ cross_encoder = CrossEncoder(model_path, max_length, tokenizer_args={'use_fast':False})
100
+ ranker = CERank(cross_encoder)
101
+
102
+ with open(context_json, 'r') as fin:
103
+ contexts = json.load(fin)
104
+
105
+ rankings = ranker.rank(contexts, question)
106
+
107
+ with open('ranked_{0}.json'.format(context_json[:-5]), 'w') as fout:
108
+ json.dump(rankings, fout)
109
+
110
+ def get_ranked_contexts_in_memory(contexts, question):
111
+
112
+ dirname = 'examples'
113
+ model_path = '/data/actici/pretrained_weights/ms-marco-electra-base'
114
+ max_length = 512
115
+
116
+ # Can't use use_fast (fast tokenizers) while gradio is running, causes conflict with tokenizer multiprocessing/parallelism.
117
+ cross_encoder = CrossEncoder(model_path, max_length, tokenizer_args={'use_fast':False})
118
+ ranker = CERank(cross_encoder)
119
+
120
+ rankings = ranker.rank(contexts, question)
121
+
122
+ return rankings
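A minimal usage sketch of the two classes added above, assuming the bundled `ms-marco-electra-base` checkpoint is used, that `CrossEncoder/` is on the Python path, and that the context dictionary below stands in for the extractor output `CERank.flatten_examples` expects (`{context_id: {'text': ...}}`):

```python
# Sketch only: rank two candidate passages against one question using the
# CrossEncoder/CERank classes from cross_encoder.py. The model path and the
# example contexts are assumptions, not part of the commit.
from cross_encoder import CrossEncoder, CERank

contexts = {
    'ctx_0': {'text': 'BERT is pre-trained with a masked language modeling objective.'},
    'ctx_1': {'text': 'The Eiffel Tower is located in Paris.'},
}
question = 'How is BERT pre-trained?'

# use_fast=False avoids the tokenizer-parallelism clash noted in the comments above.
model = CrossEncoder('ms-marco-electra-base', max_length=512,
                     tokenizer_args={'use_fast': False})
ranker = CERank(model)

rankings = ranker.rank(contexts, question)
# rankings['question_0'] holds the question text plus context ids sorted by score.
print(rankings['question_0']['ranks'], rankings['question_0']['scores'])
```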
CrossEncoder/cross_encoder_env.yml ADDED
@@ -0,0 +1,53 @@
1
+ name: cross_encoder_env
2
+ channels:
3
+ - defaults
4
+ dependencies:
5
+ - _libgcc_mutex=0.1=main
6
+ - _openmp_mutex=5.1=1_gnu
7
+ - ca-certificates=2022.4.26=h06a4308_0
8
+ - certifi=2022.6.15=py39h06a4308_0
9
+ - ld_impl_linux-64=2.38=h1181459_1
10
+ - libffi=3.3=he6710b0_2
11
+ - libgcc-ng=11.2.0=h1234567_1
12
+ - libgomp=11.2.0=h1234567_1
13
+ - libstdcxx-ng=11.2.0=h1234567_1
14
+ - ncurses=6.3=h7f8727e_2
15
+ - openssl=1.1.1o=h7f8727e_0
16
+ - pip=21.2.4=py39h06a4308_0
17
+ - python=3.9.12=h12debd9_1
18
+ - readline=8.1.2=h7f8727e_1
19
+ - setuptools=61.2.0=py39h06a4308_0
20
+ - sqlite=3.38.5=hc218d9a_0
21
+ - tk=8.6.12=h1ccaba5_0
22
+ - tzdata=2022a=hda174b7_0
23
+ - wheel=0.37.1=pyhd3eb1b0_0
24
+ - xz=5.2.5=h7f8727e_1
25
+ - zlib=1.2.12=h7f8727e_2
26
+ - pip:
27
+ - charset-normalizer==2.0.12
28
+ - click==8.1.3
29
+ - filelock==3.7.1
30
+ - huggingface-hub==0.8.1
31
+ - idna==3.3
32
+ - joblib==1.1.0
33
+ - nltk==3.7
34
+ - numpy==1.23.0
35
+ - packaging==21.3
36
+ - pillow==9.1.1
37
+ - pyparsing==3.0.9
38
+ - pyyaml==6.0
39
+ - regex==2022.6.2
40
+ - requests==2.28.0
41
+ - scikit-learn==1.1.1
42
+ - scipy==1.8.1
43
+ - sentence-transformers==2.2.2
44
+ - sentencepiece==0.1.96
45
+ - threadpoolctl==3.1.0
46
+ - tokenizers==0.12.1
47
+ - torch==1.11.0
48
+ - torchvision==0.12.0
49
+ - tqdm==4.64.0
50
+ - transformers==4.20.1
51
+ - typing-extensions==4.2.0
52
+ - urllib3==1.26.9
53
+ prefix: /home/ordonez2/miniconda3/envs/cross_encoder
DiT_Extractor/base_utils.py ADDED
@@ -0,0 +1,378 @@
1
+ # Copyright (c) 2022, Lawrence Livermore National Security, LLC.
2
+ # All rights reserved.
3
+ # See the top-level LICENSE and NOTICE files for details.
4
+ # LLNL-CODE-838964
5
+
6
+ # SPDX-License-Identifier: Apache-2.0-with-LLVM-exception
7
+
8
+ from pdfminer.pdfpage import PDFParser
9
+ from pdfminer.pdfpage import PDFDocument
10
+ from pdfminer.pdfpage import PDFPage
11
+ from pdfminer.layout import LTTextBoxHorizontal
12
+ from pdfminer.layout import LTTextLineHorizontal
13
+ from pdfminer.layout import LTChar
14
+ from pdfminer.layout import LAParams
15
+ from pdfminer.layout import LTRect
16
+ from pdfminer.layout import LTFigure
17
+
18
+ from pdfminer.converter import PDFPageAggregator
19
+ from pdfminer.pdfinterp import PDFResourceManager
20
+ from pdfminer.pdfinterp import PDFPageInterpreter
21
+ from pdfminer import pdfinterp
22
+
23
+ from collections.abc import Iterable
24
+ from collections import Counter
25
+ from collections import OrderedDict
26
+
27
+ import os
28
+
29
+ # This is used for highlighting in PDFs
30
+ from PyPDF2.generic import (
31
+ DictionaryObject,
32
+ NumberObject,
33
+ FloatObject,
34
+ NameObject,
35
+ TextStringObject,
36
+ ArrayObject
37
+ )
38
+
39
+ # Used to extract pages
40
+ from PyPDF2 import PdfFileReader, PdfFileWriter
41
+
42
+ def get_page_sizes(document):
43
+ parser = PDFParser(open(document, 'rb'))
44
+ doc = PDFDocument(parser)
45
+ pageSizesList = []
46
+ for page in PDFPage.create_pages(doc):
47
+ # the media box is the page size as a list of 4 numbers: x0, y0, x1, y1
48
+ pageSizesList.append(page.mediabox) # <- appending
49
+ return pageSizesList
50
+
51
+ def get_page_count(document):
52
+ # Is there a better way of getting the page count than doing this?
53
+ parser = PDFParser(document)
54
+ tmpdoc = PDFDocument(parser)
55
+ page_count = pdfinterp.resolve1(tmpdoc.catalog['Pages'])['Count']
56
+ return page_count
57
+
58
+ def get_pdf_page_count(filename):
59
+ with open(filename, 'rb') as document:
60
+ return get_page_count(document)
61
+
62
+ def get_pages(document, page_numbers = None):
63
+ #Create resource manager
64
+ rsrcmgr = PDFResourceManager()
65
+ # Set parameters for analysis.
66
+ laparams = LAParams()
67
+ # Create a PDF page aggregator object.
68
+ device = PDFPageAggregator(rsrcmgr, laparams=laparams)
69
+ interpreter = PDFPageInterpreter(rsrcmgr, device)
70
+
71
+ page_count = get_page_count(document)
72
+
73
+ if page_numbers is None:
74
+ page_numbers = range(page_count)
75
+
76
+ for page, page_number in zip(PDFPage.get_pages(document, page_numbers), page_numbers):
77
+ interpreter.process_page(page)
78
+ # receive the LTPage object for the page.
79
+ layout = device.get_result()
80
+ #print("Yield page:", page_number)
81
+ yield layout, page_number
82
+
83
+ def partial_overlaps(box, other):
84
+ """
85
+ Determine if the two bounding boxes overlap each other.
86
+ TODO: Really should just use a standard Python library for this.
87
+
88
+ box -- 2 coordinate bounding box (x1,y1,x2,y2)
89
+ other -- 2 coordinate bounding box (x1,y1,x2,y2)
90
+ """
91
+ # a1 x1 a2 x2
92
+ # <------------------>
93
+ x_intersects = (other[0] < box[0] and other[2] > box[0]) or (
94
+ other[0] < box[2] and other[2] > box[2])
95
+ y_intersects = (other[1] < box[1] and other[3] > box[1]) or (
96
+ other[1] < box[3] and other[3] > box[3])
97
+
98
+ intersects = x_intersects or y_intersects
99
+ # TODO: Simplify?
100
+ return intersects and overlaps(box, other)
101
+ #return intersects
102
+
103
+ def overlaps(box, other):
104
+ """
105
+ Determine if the two bounding boxes overlap each other.
106
+ TODO: Really should just use a standard Python library for this.
107
+
108
+ box -- 2 coordinate bounding box (x1,y1,x2,y2)
109
+ other -- 2 coordinate bounding box (x1,y1,x2,y2)
110
+ """
111
+ x_intersects = box[0] > other[2] or box[2] < other[0]
112
+ y_intersects = box[1] > other[3] or box[3] < other[1]
113
+
114
+ intersects = not (x_intersects or y_intersects)
115
+ return intersects
116
+
117
+ def union(src, other):
118
+ """
119
+ Expand src by union of other bbox
120
+
121
+ src -- 2 coordinate bounding box (x1,y1,x2,y2)
122
+ other -- 2 coordinate bounding box (x1,y1,x2,y2)
123
+
124
+ returns union of src and other
125
+ """
126
+ xmin = min(src[0], other[0])
127
+ ymin = min(src[1], other[1])
128
+ xmax = max(src[2], other[2])
129
+ ymax = max(src[3], other[3])
130
+
131
+ return [xmin, ymin, xmax, ymax]
132
+
133
+
134
+
135
+ # See: https://gist.github.com/agentcooper/4c55133f5d95866acdee5017cd318558#file-pypdf2highlight-py
136
+ # x1, y1 starts in bottom left corner
137
+ def createHighlight(x1, y1, x2, y2, meta, color = [1, 0, 0]):
138
+ newHighlight = DictionaryObject()
139
+
140
+ newHighlight.update({
141
+ NameObject("/F"): NumberObject(4),
142
+ NameObject("/Type"): NameObject("/Annot"),
143
+ NameObject("/Subtype"): NameObject("/Highlight"),
144
+
145
+ NameObject("/T"): TextStringObject(meta["author"]),
146
+ NameObject("/Contents"): TextStringObject(meta["contents"]),
147
+
148
+ NameObject("/C"): ArrayObject([FloatObject(c) for c in color]),
149
+ NameObject("/Rect"): ArrayObject([
150
+ FloatObject(x1),
151
+ FloatObject(y1),
152
+ FloatObject(x2),
153
+ FloatObject(y2)
154
+ ]),
155
+ NameObject("/QuadPoints"): ArrayObject([
156
+ FloatObject(x1),
157
+ FloatObject(y2),
158
+ FloatObject(x2),
159
+ FloatObject(y2),
160
+ FloatObject(x1),
161
+ FloatObject(y1),
162
+ FloatObject(x2),
163
+ FloatObject(y1)
164
+ ]),
165
+ })
166
+
167
+ return newHighlight
168
+
169
+ def addHighlightToPage(highlight, page, output):
170
+ highlight_ref = output._addObject(highlight);
171
+
172
+ if "/Annots" in page:
173
+ page[NameObject("/Annots")].append(highlight_ref)
174
+ else:
175
+ page[NameObject("/Annots")] = ArrayObject([highlight_ref])
176
+
177
+ def get_pdf_words(document, page_numbers=None):
178
+ """
179
+ Get all words from LTChar or LTTextLineHorizontal objects from the document.
180
+
181
+ :param document: string path of the PDF file to process
182
+ :returns: A map of page #'s containing lists of coordinates and PDFMiner
183
+ objects. Ex.: {page_number: [[x1, y1, x2, y2, <LTTextLineHorizontal>],]}
184
+ """
185
+ pdf_doc = open(document, 'rb')
186
+
187
+ bboxes = {}
188
+ for layout, page in get_pages(pdf_doc, page_numbers):
189
+ #print(element.get_text())
190
+ bboxes[page] = []
191
+ for element in layout:
192
+ if not isinstance(element, Iterable):
193
+ continue # not iterable
194
+ for subElement in element:
195
+ #print('Subelement type:', type(subElement))
196
+ if isinstance(subElement, LTChar):
197
+ if (subElement.get_text() == ' '):
198
+ pass # TODO: Handle word delimiter
199
+ # Print the character in this class
200
+ # print(subElement.get_text(), end='')
201
+ item = list(subElement.bbox)
202
+ item.append(subElement)
203
+ bboxes[page].append(item)
204
+ elif isinstance(subElement, LTTextLineHorizontal):
205
+ #print(subElement.bbox)
206
+ item = list(subElement.bbox)
207
+ item.append(subElement)
208
+ bboxes[page].append(item)
209
+ else:
210
+ pass
211
+ return bboxes
212
+
213
+ def get_paragraphs(words):
214
+ paragraph_tolerance = 0.1
215
+ max_height_diff = 1
216
+ paragraphs = []
217
+
218
+ for page, elements in words.items():
219
+ # Find nominal font size
220
+ # Round to int
221
+ freq = Counter()
222
+ for element in elements:
223
+ height = int(element[3] - element[1])
224
+ #print(height,end=' ')
225
+ freq[height] += 1
226
+
227
+ nominal_font = freq.most_common(1)[0][0]
228
+ print("Nominal font is:", nominal_font)
229
+
230
+ print("Page:", page)
231
+ x_offset_prev_line = None
232
+ prev_x_offset = None
233
+ prev_y_offset = None
234
+ paragraph_content = ""
235
+ #print("Element count:", len(elements))
236
+ first_line = False
237
+ processed_first_line = False
238
+
239
+ for element in elements:
240
+ x_offset = element[0]
241
+ y_offset = element[1]
242
+ height = int(element[3] - element[1])
243
+ text = element[4].get_text()
244
+
245
+ if x_offset_prev_line != None:
246
+ large_x_offset = (abs(x_offset_prev_line - x_offset) > paragraph_tolerance)
247
+
248
+ # Font size mismatch?
249
+ if abs(height - nominal_font) > max_height_diff:
250
+ if len(paragraph_content) > 0:
251
+ print("Content append:", len(paragraph_content))
252
+ paragraphs.append(paragraph_content)
253
+ paragraph_content = ""
254
+ print("Continue due to height != nominal_font")
255
+ continue
256
+
257
+ print("ELEMENT:", element[0:4], text[0:15])
258
+ if prev_y_offset is not None and len(paragraph_content) > 0:
259
+ if y_offset < prev_y_offset - height * 1.5:
260
+ print("Content append:", len(paragraph_content))
261
+ if len(paragraph_content) > 0:
262
+ paragraphs.append(paragraph_content)
263
+ paragraph_content = text
264
+ prev_y_offset = None
265
+ continue
266
+
267
+ prev_y_offset = y_offset
268
+
269
+ prev_y_offset = y_offset
270
+ #print("element:", element)
271
+ if not isinstance(element[4], LTTextLineHorizontal):
272
+ continue
273
+
274
+ #print("Running text:", text)
275
+ #print(f"x_offset_prev_line , x_offset]: {x_offset_prev_line, x_offset}")
276
+
277
+
278
+ # Find first paragraph
279
+ if x_offset_prev_line is None:
280
+ #print("x_offset_prev is none")
281
+ x_offset_prev_line = x_offset
282
+ if not processed_first_line:
283
+ first_line = True
284
+ processed_first_line = True
285
+ if height == nominal_font:
286
+ paragraph_content += text
287
+ #print("Continue due to x_offset_prev_line is none")
288
+ continue
289
+
290
+
291
+
292
+ # Check case if first line was indented
293
+ if x_offset_prev_line > x_offset and first_line:
294
+ #print("x_offset < element[0]")
295
+ first_line = False
296
+ paragraph_content += text
297
+ x_offset_prev_line = x_offset
298
+ #print("Continue due to x_offset_prev_line > x_offset and first_line")
299
+ continue
300
+
301
+ # is this indented?
302
+ # and ignore small changes
303
+ if x_offset_prev_line < x_offset and large_x_offset:
304
+ #print(f"x_offset_prev_line > x_offset: {x_offset_prev_line, x_offset}")
305
+ if height == nominal_font and len(paragraph_content) > 0:
306
+ paragraphs.append(paragraph_content)
307
+
308
+ paragraph_content = text
309
+ # Reset at next line read
310
+ # What if next paragraph is also indented???
311
+ x_offset_prev_line = None
312
+ #print("Continue due to x_offset_prev_line < x_offset and large_x_offset")
313
+ continue
314
+
315
+ #print(element[0:4])
316
+ if height == nominal_font:
317
+ paragraph_content += text
318
+ #print("End of loop")
319
+
320
+ # TODO: Remove redundant space
321
+ if paragraph_content != "":
322
+ paragraphs.append(paragraph_content)
323
+
324
+ # Find paragraph indexes
325
+ c = 0
326
+ indexes = []
327
+ for p in paragraphs:
328
+ c += len(p)
329
+ indexes.append(c)
330
+
331
+ return paragraphs, indexes
332
+
333
+ def get_pdf_elements(document, element_type, page_numbers=None):
334
+ pdf_doc = open(document, 'rb')
335
+
336
+ items = {}
337
+ for layout, page in get_pages(pdf_doc, page_numbers):
338
+ #print(element.get_text())
339
+ items[page] = []
340
+ for element in layout:
341
+ if isinstance(element, element_type):
342
+ item = list(element.bbox)
343
+ if hasattr(element, 'non_stroking_color'):
344
+ item.append(element.non_stroking_color)
345
+ items[page].append(item)
346
+ print(items)
347
+ return items
348
+
349
+ def get_large_colored_background_rectangles(document, page_numbers=None):
350
+ # Only include rectangles that are at least 4" x 1" in size
351
+ min_size = (288.0, 72.0)
352
+
353
+ elements = get_pdf_elements(document, LTRect, page_numbers)
354
+ rects_out = {}
355
+ for page, rects in elements.items():
356
+ print("Rects:", rects)
357
+ for rect in rects:
358
+ width = rect[2] - rect[0]
359
+ height = rect[3] - rect[1]
360
+ print("Dimensions:", width, height)
361
+ if (width > min_size[0] and
362
+ height > min_size[1]):
363
+ if not page in rects_out:
364
+ rects_out[page] = []
365
+ rects_out[page].append(rect)
366
+ return rects_out
367
+
368
+ def extract_pages(document, output, page_numbers=None):
369
+ pdf = PdfFileReader(document)
370
+
371
+ pdf_writer = PdfFileWriter()
372
+ for page in page_numbers:
373
+ current_page = pdf.getPage(page)
374
+ pdf_writer.addPage(current_page)
375
+
376
+ with open(output, "wb") as out:
377
+ pdf_writer.write(out)
378
+
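To illustrate how the helpers above fit together, here is a small sketch; the bounding boxes are arbitrary example values, `DiT_Extractor/` is assumed to be on the Python path, and the PDF path points at one of the files added under `examples/` in this commit:

```python
# Sketch: the pure bounding-box helpers plus the page/word utilities
# from base_utils.py. Box coordinates are (x1, y1, x2, y2) in PDF points.
from base_utils import overlaps, union, get_pdf_page_count, get_pdf_words

a = [10, 10, 50, 30]
b = [40, 20, 90, 60]
if overlaps(a, b):
    print('merged box:', union(a, b))   # -> [10, 10, 90, 60]

pdf = 'examples/1810.04805.pdf'         # shipped with this commit
print('pages:', get_pdf_page_count(pdf))

# {page_number: [[x1, y1, x2, y2, <LTChar or LTTextLineHorizontal>], ...]}
words = get_pdf_words(pdf, page_numbers=[0])
print(len(words[0]), 'text boxes on page 0')
```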
DiT_Extractor/dit_object_detection/README.md ADDED
@@ -0,0 +1,120 @@
1
+ # DiT for Object Detection
2
+
3
+ This folder contains Mask R-CNN and Cascade Mask R-CNN running instructions on top of [Detectron2](https://github.com/facebookresearch/detectron2) for PubLayNet and ICDAR 2019 cTDaR.
4
+
5
+ ## Usage
6
+
7
+ ### Inference
8
+
9
+ The quickest way to try out DiT for document layout analysis is the web demo: [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/nielsr/dit-document-layout-analysis).
10
+
11
+ One can run inference using the `inference.py` script. It can be run as follows (from the root of the unilm repository):
12
+
13
+ ```
14
+ python ./dit/object_detection/inference.py \
15
+ --image_path ./dit/object_detection/publaynet_example.jpeg \
16
+ --output_file_name output.jpg \
17
+ --config ./dit/object_detection/publaynet_configs/maskrcnn/maskrcnn_dit_base.yaml \
18
+ --opts MODEL.WEIGHTS https://layoutlm.blob.core.windows.net/dit/dit-fts/publaynet_dit-b_mrcnn.pth \
19
+ ```
20
+
21
+ Make sure that the configuration file (YAML) and PyTorch checkpoint match. The example above uses DiT-base with the Mask R-CNN framework fine-tuned on PubLayNet.
22
+
23
+ ### Data Preparation
24
+
25
+ **PubLayNet**
26
+
27
+ Download the data from this [link](https://dax-cdn.cdn.appdomain.cloud/dax-publaynet/1.0.0/publaynet.tar.gz?_ga=2.218138265.1825957955.1646384196-1495010506.1633610665) (~96GB). Then extract it to `PATH-to-PubLayNet`.
28
+
29
+ A soft link needs to be created to make the data accessible for the program:`ln -s PATH-to-PubLayNet publaynet_data`.
30
+
31
+ **ICDAR 2019 cTDaR**
32
+
33
+ Download the data from this [link](https://github.com/cndplab-founder/ICDAR2019_cTDaR) (~4GB). Assume path to this repository is named as `PATH-to-ICDARrepo`.
34
+
35
+ Then run `python convert_to_coco_format.py --root_dir=PATH-to-ICDARrepo --target_dir=PATH-to-ICDAR`. The processed data will then be under `PATH-to-ICDAR`.
36
+
37
+ Run the following commands to produce adaptively binarized images for the archival subset.
38
+
39
+ ```
40
+ cp -r PATH-to-ICDAR/trackA_archival PATH-to-ICDAR/at_trackA_archival
41
+ python adaptive_binarize.py --root_dir PATH-to-ICDAR/at_trackA_archival
42
+ ```
43
+
44
+ The binarized archival subset will be in `PATH-to-ICDAR/at_trackA_archival`.
45
+
46
+ Depending on the subset you want to evaluate or fine-tune on, create a soft link: `ln -s PATH-to-ICDAR/trackA_modern data` or `ln -s PATH-to-ICDAR/at_trackA_archival data`.
47
+
48
+ ### Evaluation
49
+
50
+ The following commands provide two examples of evaluating the fine-tuned checkpoints.
51
+
52
+ The config files can be found in `icdar19_configs` and `publaynet_configs`.
53
+
54
+ 1) Evaluate the fine-tuned checkpoint of DiT-Base with Mask R-CNN on PubLayNet:
55
+ ```bash
56
+ python train_net.py --config-file publaynet_configs/maskrcnn/maskrcnn_dit_base.yaml --eval-only --num-gpus 8 MODEL.WEIGHTS <finetuned_checkpoint_file_path or link> OUTPUT_DIR <your_output_dir>
57
+ ```
58
+
59
+ 2) Evaluate the fine-tuned checkpoint of DiT-Large with Cascade Mask R-CNN on ICDAR 2019 cTDaR archival subset (make sure you have created a soft link from `PATH-to-ICDAR/at_trackA_archival` to `data`):
60
+ ```bash
61
+ python train_net.py --config-file icdar19_configs/cascade/cascade_dit_large.yaml --eval-only --num-gpus 8 MODEL.WEIGHTS <finetuned_checkpoint_file_path or link> OUTPUT_DIR <your_output_dir>
62
+ ```
63
+
64
+ **Note**: We have fixed the **bug** in the [ICDAR2019 measurement tool](https://github.com/cndplab-founder/ctdar_measurement_tool) while integrating the tool into our code. If you use the tool to get the evaluation score, please modify the [code](https://github.com/cndplab-founder/ctdar_measurement_tool/blob/738456d3164a838ffaeefe7d1b5e64f3a4368a0e/evaluate.py#L146
65
+ ) as follows:
66
+ ```bash
67
+ ...
68
+ # print(each_file)
69
+
70
+ # for file in gt_file_lst:
71
+ # if file.split(".") != "xml":
72
+ # gt_file_lst.remove(file)
73
+ # # print(gt_file_lst)
74
+
75
+ # Comment the code above and add the code below
76
+ for i in range(len(gt_file_lst) - 1, -1, -1):
77
+ if gt_file_lst[i].split(".")[-1] != "xml":
78
+ del gt_file_lst[i]
79
+
80
+ if len(gt_file_lst) > 0:
81
+ ...
82
+ ```
83
+
84
+ ### Training
85
+ The following commands provide two examples to train the Mask R-CNN/Cascade Mask R-CNN with DiT backbone on 8 32GB Nvidia V100 GPUs.
86
+
87
+ 1) Fine-tune DiT-Base with Cascade Mask R-CNN on PubLayNet:
88
+ ```bash
89
+ python train_net.py --config-file publaynet_configs/cascade/cascade_dit_base.yaml --num-gpus 8 MODEL.WEIGHTS <DiT-Base_file_path or link> OUTPUT_DIR <your_output_dir>
90
+ ```
91
+
92
+
93
+ 2) Fine-tune DiT-Large with Mask R-CNN on ICDAR 2019 cTDaR modern:
94
+ ```bash
95
+ python train_net.py --config-file icdar19_configs/maskrcnn/maskrcnn_dit_large.yaml --num-gpus 8 MODEL.WEIGHTS <DiT-Large_file_path or link> OUTPUT_DIR <your_output_dir>
96
+ ```
97
+
98
+
99
+
100
+ [Detectron2's documentation](https://detectron2.readthedocs.io/en/latest/tutorials/getting_started.html) may be helpful for more details.
101
+
102
+
103
+ ## Citation
104
+
105
+ If you find this repository useful, please consider citing our work:
106
+ ```
107
+ @misc{li2022dit,
108
+ title={DiT: Self-supervised Pre-training for Document Image Transformer},
109
+ author={Junlong Li and Yiheng Xu and Tengchao Lv and Lei Cui and Cha Zhang and Furu Wei},
110
+ year={2022},
111
+ eprint={2203.02378},
112
+ archivePrefix={arXiv},
113
+ primaryClass={cs.CV}
114
+ }
115
+ ```
116
+
117
+
118
+
119
+ ## Acknowledgment
120
+ Thanks to [Detectron2](https://github.com/facebookresearch/detectron2) for Mask R-CNN and Cascade Mask R-CNN implementation.
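For reference, the same inference flow that `inference.py` wraps can be driven directly from Python with detectron2's `DefaultPredictor`; the config and weight paths below are the ones named in this README, while the input image path is a placeholder for any rendered page:

```python
# Sketch, assuming detectron2 and timm are installed and this package's
# ditod module is importable. Not a verbatim copy of inference.py.
import cv2
from detectron2.config import get_cfg
from detectron2.engine import DefaultPredictor

from ditod import add_vit_config

cfg = get_cfg()
add_vit_config(cfg)  # adds the MODEL.VIT.* keys the DiT backbone reads
cfg.merge_from_file('publaynet_configs/maskrcnn/maskrcnn_dit_base.yaml')
cfg.MODEL.WEIGHTS = 'https://layoutlm.blob.core.windows.net/dit/dit-fts/publaynet_dit-b_mrcnn.pth'
cfg.MODEL.DEVICE = 'cpu'  # or 'cuda'

predictor = DefaultPredictor(cfg)
image = cv2.imread('page.png')            # placeholder: any page image (BGR)
instances = predictor(image)['instances']
print(instances.pred_classes, instances.pred_boxes)
```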
DiT_Extractor/dit_object_detection/ditod/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ # --------------------------------------------------------------------------------
2
+ # MPViT: Multi-Path Vision Transformer for Dense Prediction
3
+ # Copyright (c) 2022 Electronics and Telecommunications Research Institute (ETRI).
4
+ # All Rights Reserved.
5
+ # Written by Youngwan Lee
6
+ # This source code is licensed(Dual License(GPL3.0 & Commercial)) under the license found in the
7
+ # LICENSE file in the root directory of this source tree.
8
+ # --------------------------------------------------------------------------------
9
+
10
+ from .config import add_vit_config
11
+ from .backbone import build_vit_fpn_backbone
DiT_Extractor/dit_object_detection/ditod/backbone.py ADDED
@@ -0,0 +1,156 @@
1
+ # --------------------------------------------------------------------------------
2
+ # VIT: Multi-Path Vision Transformer for Dense Prediction
3
+ # Copyright (c) 2022 Electronics and Telecommunications Research Institute (ETRI).
4
+ # All Rights Reserved.
5
+ # Written by Youngwan Lee
6
+ # This source code is licensed(Dual License(GPL3.0 & Commercial)) under the license found in the
7
+ # LICENSE file in the root directory of this source tree.
8
+ # --------------------------------------------------------------------------------
9
+ # References:
10
+ # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
11
+ # CoaT: https://github.com/mlpc-ucsd/CoaT
12
+ # --------------------------------------------------------------------------------
13
+
14
+
15
+ import torch
16
+
17
+ from detectron2.layers import (
18
+ ShapeSpec,
19
+ )
20
+ from detectron2.modeling import Backbone, BACKBONE_REGISTRY, FPN
21
+ from detectron2.modeling.backbone.fpn import LastLevelP6P7, LastLevelMaxPool
22
+
23
+ from .beit import beit_base_patch16, dit_base_patch16, dit_large_patch16, beit_large_patch16
24
+ from .deit import deit_base_patch16, mae_base_patch16
25
+
26
+ __all__ = [
27
+ "build_vit_fpn_backbone",
28
+ ]
29
+
30
+
31
+ class VIT_Backbone(Backbone):
32
+ """
33
+ Implement VIT backbone.
34
+ """
35
+
36
+ def __init__(self, name, out_features, drop_path, img_size, pos_type, model_kwargs):
37
+ super().__init__()
38
+ self._out_features = out_features
39
+ if 'base' in name:
40
+ self._out_feature_strides = {"layer3": 4, "layer5": 8, "layer7": 16, "layer11": 32}
41
+ else:
42
+ self._out_feature_strides = {"layer7": 4, "layer11": 8, "layer15": 16, "layer23": 32}
43
+
44
+ if name == 'beit_base_patch16':
45
+ model_func = beit_base_patch16
46
+ self._out_feature_channels = {"layer3": 768, "layer5": 768, "layer7": 768, "layer11": 768}
47
+ elif name == 'dit_base_patch16':
48
+ model_func = dit_base_patch16
49
+ self._out_feature_channels = {"layer3": 768, "layer5": 768, "layer7": 768, "layer11": 768}
50
+ elif name == "deit_base_patch16":
51
+ model_func = deit_base_patch16
52
+ self._out_feature_channels = {"layer3": 768, "layer5": 768, "layer7": 768, "layer11": 768}
53
+ elif name == "mae_base_patch16":
54
+ model_func = mae_base_patch16
55
+ self._out_feature_channels = {"layer3": 768, "layer5": 768, "layer7": 768, "layer11": 768}
56
+ elif name == "dit_large_patch16":
57
+ model_func = dit_large_patch16
58
+ self._out_feature_channels = {"layer7": 1024, "layer11": 1024, "layer15": 1024, "layer23": 1024}
59
+ elif name == "beit_large_patch16":
60
+ model_func = beit_large_patch16
61
+ self._out_feature_channels = {"layer7": 1024, "layer11": 1024, "layer15": 1024, "layer23": 1024}
62
+ else:
63
+ raise ValueError("Unsupported VIT name yet.")
64
+
65
+ if 'beit' in name or 'dit' in name:
66
+ if pos_type == "abs":
67
+ self.backbone = model_func(img_size=img_size,
68
+ out_features=out_features,
69
+ drop_path_rate=drop_path,
70
+ use_abs_pos_emb=True,
71
+ **model_kwargs)
72
+ elif pos_type == "shared_rel":
73
+ self.backbone = model_func(img_size=img_size,
74
+ out_features=out_features,
75
+ drop_path_rate=drop_path,
76
+ use_shared_rel_pos_bias=True,
77
+ **model_kwargs)
78
+ elif pos_type == "rel":
79
+ self.backbone = model_func(img_size=img_size,
80
+ out_features=out_features,
81
+ drop_path_rate=drop_path,
82
+ use_rel_pos_bias=True,
83
+ **model_kwargs)
84
+ else:
85
+ raise ValueError()
86
+ else:
87
+ self.backbone = model_func(img_size=img_size,
88
+ out_features=out_features,
89
+ drop_path_rate=drop_path,
90
+ **model_kwargs)
91
+
92
+ def forward(self, x):
93
+ """
94
+ Args:
95
+ x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
96
+
97
+ Returns:
98
+ dict[str->Tensor]: names and the corresponding features
99
+ """
100
+ assert x.dim() == 4, f"VIT takes an input of shape (N, C, H, W). Got {x.shape} instead!"
101
+ return self.backbone.forward_features(x)
102
+
103
+ def output_shape(self):
104
+ return {
105
+ name: ShapeSpec(
106
+ channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
107
+ )
108
+ for name in self._out_features
109
+ }
110
+
111
+
112
+ def build_VIT_backbone(cfg):
113
+ """
114
+ Create a VIT instance from config.
115
+
116
+ Args:
117
+ cfg: a detectron2 CfgNode
118
+
119
+ Returns:
120
+ A VIT backbone instance.
121
+ """
122
+ # fmt: off
123
+ name = cfg.MODEL.VIT.NAME
124
+ out_features = cfg.MODEL.VIT.OUT_FEATURES
125
+ drop_path = cfg.MODEL.VIT.DROP_PATH
126
+ img_size = cfg.MODEL.VIT.IMG_SIZE
127
+ pos_type = cfg.MODEL.VIT.POS_TYPE
128
+
129
+ model_kwargs = eval(str(cfg.MODEL.VIT.MODEL_KWARGS).replace("`", ""))
130
+
131
+ return VIT_Backbone(name, out_features, drop_path, img_size, pos_type, model_kwargs)
132
+
133
+
134
+ @BACKBONE_REGISTRY.register()
135
+ def build_vit_fpn_backbone(cfg, input_shape: ShapeSpec):
136
+ """
137
+ Create a VIT w/ FPN backbone.
138
+
139
+ Args:
140
+ cfg: a detectron2 CfgNode
141
+
142
+ Returns:
143
+ backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`.
144
+ """
145
+ bottom_up = build_VIT_backbone(cfg)
146
+ in_features = cfg.MODEL.FPN.IN_FEATURES
147
+ out_channels = cfg.MODEL.FPN.OUT_CHANNELS
148
+ backbone = FPN(
149
+ bottom_up=bottom_up,
150
+ in_features=in_features,
151
+ out_channels=out_channels,
152
+ norm=cfg.MODEL.FPN.NORM,
153
+ top_block=LastLevelMaxPool(),
154
+ fuse_type=cfg.MODEL.FPN.FUSE_TYPE,
155
+ )
156
+ return backbone
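The YAML configs point `MODEL.BACKBONE.NAME` at the registered `build_vit_fpn_backbone` factory, so detectron2's generic `build_backbone` ends up calling the code above. A hedged sketch of that wiring (the config path is an assumption, and constructing the backbone requires timm):

```python
# Sketch: build the DiT backbone purely from config and inspect the
# feature-map specs the FPN exposes to the detection heads.
from detectron2.config import get_cfg
from detectron2.modeling import build_backbone

from ditod import add_vit_config

cfg = get_cfg()
add_vit_config(cfg)
cfg.merge_from_file('publaynet_configs/maskrcnn/maskrcnn_dit_base.yaml')

backbone = build_backbone(cfg)             # VIT_Backbone wrapped in an FPN
for name, spec in backbone.output_shape().items():
    print(name, spec.channels, spec.stride)   # e.g. p2..p6 at strides 4..64
```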
DiT_Extractor/dit_object_detection/ditod/beit.py ADDED
@@ -0,0 +1,671 @@
1
+ """ Vision Transformer (ViT) in PyTorch
2
+
3
+ A PyTorch implementation of Vision Transformers as described in
4
+ 'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' - https://arxiv.org/abs/2010.11929
5
+
6
+ The official jax code is released and available at https://github.com/google-research/vision_transformer
7
+
8
+ Status/TODO:
9
+ * Models updated to be compatible with official impl. Args added to support backward compat for old PyTorch weights.
10
+ * Weights ported from official jax impl for 384x384 base and small models, 16x16 and 32x32 patches.
11
+ * Trained (supervised on ImageNet-1k) my custom 'small' patch model to 77.9, 'base' to 79.4 top-1 with this code.
12
+ * Hopefully find time and GPUs for SSL or unsupervised pretraining on OpenImages w/ ImageNet fine-tune in future.
13
+
14
+ Acknowledgments:
15
+ * The paper authors for releasing code and weights, thanks!
16
+ * I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out
17
+ for some einops/einsum fun
18
+ * Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
19
+ * Bert reference code checks against Huggingface Transformers and Tensorflow Bert
20
+
21
+ Hacked together by / Copyright 2020 Ross Wightman
22
+ """
23
+ import warnings
24
+ import math
25
+ import torch
26
+ from functools import partial
27
+ import torch.nn as nn
28
+ import torch.nn.functional as F
29
+ import torch.utils.checkpoint as checkpoint
30
+ from timm.models.layers import drop_path, to_2tuple, trunc_normal_
31
+
32
+
33
+ def _cfg(url='', **kwargs):
34
+ return {
35
+ 'url': url,
36
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
37
+ 'crop_pct': .9, 'interpolation': 'bicubic',
38
+ 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
39
+ **kwargs
40
+ }
41
+
42
+
43
+ class DropPath(nn.Module):
44
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
45
+ """
46
+
47
+ def __init__(self, drop_prob=None):
48
+ super(DropPath, self).__init__()
49
+ self.drop_prob = drop_prob
50
+
51
+ def forward(self, x):
52
+ return drop_path(x, self.drop_prob, self.training)
53
+
54
+ def extra_repr(self) -> str:
55
+ return 'p={}'.format(self.drop_prob)
56
+
57
+
58
+ class Mlp(nn.Module):
59
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
60
+ super().__init__()
61
+ out_features = out_features or in_features
62
+ hidden_features = hidden_features or in_features
63
+ self.fc1 = nn.Linear(in_features, hidden_features)
64
+ self.act = act_layer()
65
+ self.fc2 = nn.Linear(hidden_features, out_features)
66
+ self.drop = nn.Dropout(drop)
67
+
68
+ def forward(self, x):
69
+ x = self.fc1(x)
70
+ x = self.act(x)
71
+ # x = self.drop(x)
72
+ # commented out to match the original BERT implementation
73
+ x = self.fc2(x)
74
+ x = self.drop(x)
75
+ return x
76
+
77
+
78
+ class Attention(nn.Module):
79
+ def __init__(
80
+ self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
81
+ proj_drop=0., window_size=None, attn_head_dim=None):
82
+ super().__init__()
83
+ self.num_heads = num_heads
84
+ head_dim = dim // num_heads
85
+ if attn_head_dim is not None:
86
+ head_dim = attn_head_dim
87
+ all_head_dim = head_dim * self.num_heads
88
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
89
+ self.scale = qk_scale or head_dim ** -0.5
90
+
91
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
92
+ if qkv_bias:
93
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
94
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
95
+ else:
96
+ self.q_bias = None
97
+ self.v_bias = None
98
+
99
+ if window_size:
100
+ self.window_size = window_size
101
+ self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
102
+ self.relative_position_bias_table = nn.Parameter(
103
+ torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
104
+ # cls to token & token 2 cls & cls to cls
105
+
106
+ # get pair-wise relative position index for each token inside the window
107
+ coords_h = torch.arange(window_size[0])
108
+ coords_w = torch.arange(window_size[1])
109
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
110
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
111
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
112
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
113
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
114
+ relative_coords[:, :, 1] += window_size[1] - 1
115
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
116
+ relative_position_index = \
117
+ torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
118
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
119
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
120
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
121
+ relative_position_index[0, 0] = self.num_relative_distance - 1
122
+
123
+ self.register_buffer("relative_position_index", relative_position_index)
124
+
125
+ # trunc_normal_(self.relative_position_bias_table, std=.0)
126
+ else:
127
+ self.window_size = None
128
+ self.relative_position_bias_table = None
129
+ self.relative_position_index = None
130
+
131
+ self.attn_drop = nn.Dropout(attn_drop)
132
+ self.proj = nn.Linear(all_head_dim, dim)
133
+ self.proj_drop = nn.Dropout(proj_drop)
134
+
135
+ def forward(self, x, rel_pos_bias=None, training_window_size=None):
136
+ B, N, C = x.shape
137
+ qkv_bias = None
138
+ if self.q_bias is not None:
139
+ qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
140
+ # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
141
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
142
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
143
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
144
+
145
+ q = q * self.scale
146
+ attn = (q @ k.transpose(-2, -1))
147
+
148
+ if self.relative_position_bias_table is not None:
149
+ if training_window_size == self.window_size:
150
+ relative_position_bias = \
151
+ self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
152
+ self.window_size[0] * self.window_size[1] + 1,
153
+ self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
154
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
155
+ attn = attn + relative_position_bias.unsqueeze(0)
156
+ else:
157
+ training_window_size = tuple(training_window_size.tolist())
158
+ new_num_relative_distance = (2 * training_window_size[0] - 1) * (2 * training_window_size[1] - 1) + 3
159
+ # new_num_relative_dis counts all possible relative-position options, including cls-cls, tok-cls, and cls-tok
160
+ new_relative_position_bias_table = F.interpolate(
161
+ self.relative_position_bias_table[:-3, :].permute(1, 0).view(1, self.num_heads,
162
+ 2 * self.window_size[0] - 1,
163
+ 2 * self.window_size[1] - 1),
164
+ size=(2 * training_window_size[0] - 1, 2 * training_window_size[1] - 1), mode='bicubic',
165
+ align_corners=False)
166
+ new_relative_position_bias_table = new_relative_position_bias_table.view(self.num_heads,
167
+ new_num_relative_distance - 3).permute(
168
+ 1, 0)
169
+ new_relative_position_bias_table = torch.cat(
170
+ [new_relative_position_bias_table, self.relative_position_bias_table[-3::]], dim=0)
171
+
172
+ # get pair-wise relative position index for each token inside the window
173
+ coords_h = torch.arange(training_window_size[0])
174
+ coords_w = torch.arange(training_window_size[1])
175
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
176
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
177
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
178
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
179
+ relative_coords[:, :, 0] += training_window_size[0] - 1 # shift to start from 0
180
+ relative_coords[:, :, 1] += training_window_size[1] - 1
181
+ relative_coords[:, :, 0] *= 2 * training_window_size[1] - 1
182
+ relative_position_index = \
183
+ torch.zeros(size=(training_window_size[0] * training_window_size[1] + 1,) * 2,
184
+ dtype=relative_coords.dtype)
185
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
186
+ relative_position_index[0, 0:] = new_num_relative_distance - 3
187
+ relative_position_index[0:, 0] = new_num_relative_distance - 2
188
+ relative_position_index[0, 0] = new_num_relative_distance - 1
189
+
190
+ relative_position_bias = \
191
+ new_relative_position_bias_table[relative_position_index.view(-1)].view(
192
+ training_window_size[0] * training_window_size[1] + 1,
193
+ training_window_size[0] * training_window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
194
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
195
+ attn = attn + relative_position_bias.unsqueeze(0)
196
+
197
+ if rel_pos_bias is not None:
198
+ attn = attn + rel_pos_bias
199
+
200
+ attn = attn.softmax(dim=-1)
201
+ attn = self.attn_drop(attn)
202
+
203
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
204
+ x = self.proj(x)
205
+ x = self.proj_drop(x)
206
+ return x
207
+
208
+
209
+ class Block(nn.Module):
210
+
211
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
212
+ drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
213
+ window_size=None, attn_head_dim=None):
214
+ super().__init__()
215
+ self.norm1 = norm_layer(dim)
216
+ self.attn = Attention(
217
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
218
+ attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)
219
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
220
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
221
+ self.norm2 = norm_layer(dim)
222
+ mlp_hidden_dim = int(dim * mlp_ratio)
223
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
224
+
225
+ if init_values is not None:
226
+ self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
227
+ self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
228
+ else:
229
+ self.gamma_1, self.gamma_2 = None, None
230
+
231
+ def forward(self, x, rel_pos_bias=None, training_window_size=None):
232
+ if self.gamma_1 is None:
233
+ x = x + self.drop_path(
234
+ self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, training_window_size=training_window_size))
235
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
236
+ else:
237
+ x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias,
238
+ training_window_size=training_window_size))
239
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
240
+ return x
241
+
242
+
243
+ class PatchEmbed(nn.Module):
244
+ """ Image to Patch Embedding
245
+ """
246
+
247
+ def __init__(self, img_size=[224, 224], patch_size=16, in_chans=3, embed_dim=768):
248
+ super().__init__()
249
+ img_size = to_2tuple(img_size)
250
+ patch_size = to_2tuple(patch_size)
251
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
252
+ self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
253
+ self.num_patches_w = self.patch_shape[0]
254
+ self.num_patches_h = self.patch_shape[1]
255
+ # the so-called patch_shape is the patch shape during pre-training
256
+ self.img_size = img_size
257
+ self.patch_size = patch_size
258
+ self.num_patches = num_patches
259
+
260
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
261
+
262
+ def forward(self, x, position_embedding=None, **kwargs):
263
+ # FIXME look at relaxing size constraints
264
+ # assert H == self.img_size[0] and W == self.img_size[1], \
265
+ # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
266
+ x = self.proj(x)
267
+ Hp, Wp = x.shape[2], x.shape[3]
268
+
269
+ if position_embedding is not None:
270
+ # interpolate the position embedding to the corresponding size
271
+ position_embedding = position_embedding.view(1, self.patch_shape[0], self.patch_shape[1], -1).permute(0, 3,
272
+ 1, 2)
273
+ position_embedding = F.interpolate(position_embedding, size=(Hp, Wp), mode='bicubic')
274
+ x = x + position_embedding
275
+
276
+ x = x.flatten(2).transpose(1, 2)
277
+ return x, (Hp, Wp)
278
+
279
+
280
+ class HybridEmbed(nn.Module):
281
+ """ CNN Feature Map Embedding
282
+ Extract feature map from CNN, flatten, project to embedding dim.
283
+ """
284
+
285
+ def __init__(self, backbone, img_size=[224, 224], feature_size=None, in_chans=3, embed_dim=768):
286
+ super().__init__()
287
+ assert isinstance(backbone, nn.Module)
288
+ img_size = to_2tuple(img_size)
289
+ self.img_size = img_size
290
+ self.backbone = backbone
291
+ if feature_size is None:
292
+ with torch.no_grad():
293
+ # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature
294
+ # map for all networks, the feature metadata has reliable channel and stride info, but using
295
+ # stride to calc feature dim requires info about padding of each stage that isn't captured.
296
+ training = backbone.training
297
+ if training:
298
+ backbone.eval()
299
+ o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]
300
+ feature_size = o.shape[-2:]
301
+ feature_dim = o.shape[1]
302
+ backbone.train(training)
303
+ else:
304
+ feature_size = to_2tuple(feature_size)
305
+ feature_dim = self.backbone.feature_info.channels()[-1]
306
+ self.num_patches = feature_size[0] * feature_size[1]
307
+ self.proj = nn.Linear(feature_dim, embed_dim)
308
+
309
+ def forward(self, x):
310
+ x = self.backbone(x)[-1]
311
+ x = x.flatten(2).transpose(1, 2)
312
+ x = self.proj(x)
313
+ return x
314
+
315
+
316
+ class RelativePositionBias(nn.Module):
317
+
318
+ def __init__(self, window_size, num_heads):
319
+ super().__init__()
320
+ self.window_size = window_size
321
+ self.num_heads = num_heads
322
+ self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
323
+ self.relative_position_bias_table = nn.Parameter(
324
+ torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
325
+ # cls to token & token 2 cls & cls to cls
326
+
327
+ # get pair-wise relative position index for each token inside the window
328
+ coords_h = torch.arange(window_size[0])
329
+ coords_w = torch.arange(window_size[1])
330
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
331
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
332
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
333
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
334
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
335
+ relative_coords[:, :, 1] += window_size[1] - 1
336
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
337
+ relative_position_index = \
338
+ torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
339
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
340
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
341
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
342
+ relative_position_index[0, 0] = self.num_relative_distance - 1
343
+
344
+ self.register_buffer("relative_position_index", relative_position_index)
345
+
346
+ # trunc_normal_(self.relative_position_bias_table, std=.02)
347
+
348
+ def forward(self, training_window_size):
349
+ if training_window_size == self.window_size:
350
+ relative_position_bias = \
351
+ self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
352
+ self.window_size[0] * self.window_size[1] + 1,
353
+ self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
354
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
355
+ else:
356
+ training_window_size = tuple(training_window_size.tolist())
357
+ new_num_relative_distance = (2 * training_window_size[0] - 1) * (2 * training_window_size[1] - 1) + 3
358
+ # new_num_relative_dis counts all possible relative-position options, including cls-cls, tok-cls, and cls-tok
359
+ new_relative_position_bias_table = F.interpolate(
360
+ self.relative_position_bias_table[:-3, :].permute(1, 0).view(1, self.num_heads,
361
+ 2 * self.window_size[0] - 1,
362
+ 2 * self.window_size[1] - 1),
363
+ size=(2 * training_window_size[0] - 1, 2 * training_window_size[1] - 1), mode='bicubic',
364
+ align_corners=False)
365
+ new_relative_position_bias_table = new_relative_position_bias_table.view(self.num_heads,
366
+ new_num_relative_distance - 3).permute(
367
+ 1, 0)
368
+ new_relative_position_bias_table = torch.cat(
369
+ [new_relative_position_bias_table, self.relative_position_bias_table[-3::]], dim=0)
370
+
371
+ # get pair-wise relative position index for each token inside the window
372
+ coords_h = torch.arange(training_window_size[0])
373
+ coords_w = torch.arange(training_window_size[1])
374
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
375
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
376
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
377
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
378
+ relative_coords[:, :, 0] += training_window_size[0] - 1 # shift to start from 0
379
+ relative_coords[:, :, 1] += training_window_size[1] - 1
380
+ relative_coords[:, :, 0] *= 2 * training_window_size[1] - 1
381
+ relative_position_index = \
382
+ torch.zeros(size=(training_window_size[0] * training_window_size[1] + 1,) * 2,
383
+ dtype=relative_coords.dtype)
384
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
385
+ relative_position_index[0, 0:] = new_num_relative_distance - 3
386
+ relative_position_index[0:, 0] = new_num_relative_distance - 2
387
+ relative_position_index[0, 0] = new_num_relative_distance - 1
388
+
389
+ relative_position_bias = \
390
+ new_relative_position_bias_table[relative_position_index.view(-1)].view(
391
+ training_window_size[0] * training_window_size[1] + 1,
392
+ training_window_size[0] * training_window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
393
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
394
+
395
+ return relative_position_bias
396
+
397
+
398
+ class BEiT(nn.Module):
399
+ """ Vision Transformer with support for patch or hybrid CNN input stage
400
+ """
401
+
402
+ def __init__(self,
403
+ img_size=[224, 224],
404
+ patch_size=16,
405
+ in_chans=3,
406
+ num_classes=80,
407
+ embed_dim=768,
408
+ depth=12,
409
+ num_heads=12,
410
+ mlp_ratio=4.,
411
+ qkv_bias=False,
412
+ qk_scale=None,
413
+ drop_rate=0.,
414
+ attn_drop_rate=0.,
415
+ drop_path_rate=0.,
416
+ hybrid_backbone=None,
417
+ norm_layer=None,
418
+ init_values=None,
419
+ use_abs_pos_emb=False,
420
+ use_rel_pos_bias=False,
421
+ use_shared_rel_pos_bias=False,
422
+ use_checkpoint=True,
423
+ pretrained=None,
424
+ out_features=None,
425
+ ):
426
+
427
+ super(BEiT, self).__init__()
428
+
429
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
430
+ self.num_classes = num_classes
431
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
432
+ self.use_checkpoint = use_checkpoint
433
+
434
+ if hybrid_backbone is not None:
435
+ self.patch_embed = HybridEmbed(
436
+ hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
437
+ else:
438
+ self.patch_embed = PatchEmbed(
439
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
440
+ num_patches = self.patch_embed.num_patches
441
+ self.out_features = out_features
442
+ self.out_indices = [int(name[5:]) for name in out_features]
443
+
444
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
445
+ # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
446
+ if use_abs_pos_emb:
447
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
448
+ else:
449
+ self.pos_embed = None
450
+ self.pos_drop = nn.Dropout(p=drop_rate)
451
+
452
+ self.use_shared_rel_pos_bias = use_shared_rel_pos_bias
453
+ if use_shared_rel_pos_bias:
454
+ self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
455
+ else:
456
+ self.rel_pos_bias = None
457
+
458
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
459
+ self.use_rel_pos_bias = use_rel_pos_bias
460
+ self.blocks = nn.ModuleList([
461
+ Block(
462
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
463
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
464
+ init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None)
465
+ for i in range(depth)])
466
+
467
+ # trunc_normal_(self.mask_token, std=.02)
468
+
469
+ if patch_size == 16:
470
+ self.fpn1 = nn.Sequential(
471
+ nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
472
+ # nn.SyncBatchNorm(embed_dim),
473
+ nn.BatchNorm2d(embed_dim),
474
+ nn.GELU(),
475
+ nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
476
+ )
477
+
478
+ self.fpn2 = nn.Sequential(
479
+ nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
480
+ )
481
+
482
+ self.fpn3 = nn.Identity()
483
+
484
+ self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)
485
+ elif patch_size == 8:
486
+ self.fpn1 = nn.Sequential(
487
+ nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
488
+ )
489
+
490
+ self.fpn2 = nn.Identity()
491
+
492
+ self.fpn3 = nn.Sequential(
493
+ nn.MaxPool2d(kernel_size=2, stride=2),
494
+ )
495
+
496
+ self.fpn4 = nn.Sequential(
497
+ nn.MaxPool2d(kernel_size=4, stride=4),
498
+ )
499
+
500
+ if self.pos_embed is not None:
501
+ trunc_normal_(self.pos_embed, std=.02)
502
+ trunc_normal_(self.cls_token, std=.02)
503
+ self.apply(self._init_weights)
504
+ self.fix_init_weight()
505
+
506
+ def fix_init_weight(self):
507
+ def rescale(param, layer_id):
508
+ param.div_(math.sqrt(2.0 * layer_id))
509
+
510
+ for layer_id, layer in enumerate(self.blocks):
511
+ rescale(layer.attn.proj.weight.data, layer_id + 1)
512
+ rescale(layer.mlp.fc2.weight.data, layer_id + 1)
513
+
514
+ def _init_weights(self, m):
515
+ if isinstance(m, nn.Linear):
516
+ trunc_normal_(m.weight, std=.02)
517
+ if isinstance(m, nn.Linear) and m.bias is not None:
518
+ nn.init.constant_(m.bias, 0)
519
+ elif isinstance(m, nn.LayerNorm):
520
+ nn.init.constant_(m.bias, 0)
521
+ nn.init.constant_(m.weight, 1.0)
522
+
523
+ '''
524
+ def init_weights(self):
525
+ """Initialize the weights in backbone.
526
+
527
+ Args:
528
+ pretrained (str, optional): Path to pre-trained weights.
529
+ Defaults to None.
530
+ """
531
+ logger = get_root_logger()
532
+
533
+ if self.pos_embed is not None:
534
+ trunc_normal_(self.pos_embed, std=.02)
535
+ trunc_normal_(self.cls_token, std=.02)
536
+ self.apply(self._init_weights)
537
+ self.fix_init_weight()
538
+
539
+ if self.init_cfg is None:
540
+ logger.warn(f'No pre-trained weights for '
541
+ f'{self.__class__.__name__}, '
542
+ f'training start from scratch')
543
+ else:
544
+ assert 'checkpoint' in self.init_cfg, f'Only support ' \
545
+ f'specify `Pretrained` in ' \
546
+ f'`init_cfg` in ' \
547
+ f'{self.__class__.__name__} '
548
+ logger.info(f"Will load ckpt from {self.init_cfg['checkpoint']}")
549
+ load_checkpoint(self,
550
+ filename=self.init_cfg['checkpoint'],
551
+ strict=False,
552
+ logger=logger,
553
+ beit_spec_expand_rel_pos = self.use_rel_pos_bias,
554
+ )
555
+ '''
556
+
557
+ def get_num_layers(self):
558
+ return len(self.blocks)
559
+
560
+ @torch.jit.ignore
561
+ def no_weight_decay(self):
562
+ return {'pos_embed', 'cls_token'}
563
+
564
+ def forward_features(self, x):
565
+ B, C, H, W = x.shape
566
+ x, (Hp, Wp) = self.patch_embed(x, self.pos_embed[:, 1:, :] if self.pos_embed is not None else None)
567
+ # Hp, Wp are HW for patches
568
+ batch_size, seq_len, _ = x.size()
569
+
570
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
571
+ if self.pos_embed is not None:
572
+ cls_tokens = cls_tokens + self.pos_embed[:, :1, :]
573
+ x = torch.cat((cls_tokens, x), dim=1)
574
+ x = self.pos_drop(x)
575
+
576
+ features = []
577
+ training_window_size = torch.tensor([Hp, Wp])
578
+
579
+ rel_pos_bias = self.rel_pos_bias(training_window_size) if self.rel_pos_bias is not None else None
580
+
581
+ for i, blk in enumerate(self.blocks):
582
+ if self.use_checkpoint:
583
+ x = checkpoint.checkpoint(blk, x, rel_pos_bias, training_window_size)
584
+ else:
585
+ x = blk(x, rel_pos_bias=rel_pos_bias, training_window_size=training_window_size)
586
+ if i in self.out_indices:
587
+ xp = x[:, 1:, :].permute(0, 2, 1).reshape(B, -1, Hp, Wp)
588
+ features.append(xp.contiguous())
589
+
590
+ ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]
591
+ for i in range(len(features)):
592
+ features[i] = ops[i](features[i])
593
+
594
+ feat_out = {}
595
+
596
+ for name, value in zip(self.out_features, features):
597
+ feat_out[name] = value
598
+
599
+ return feat_out
600
+
601
+ def forward(self, x):
602
+ x = self.forward_features(x)
603
+ return x
604
+
605
+
606
+ def beit_base_patch16(pretrained=False, **kwargs):
607
+ model = BEiT(
608
+ patch_size=16,
609
+ embed_dim=768,
610
+ depth=12,
611
+ num_heads=12,
612
+ mlp_ratio=4,
613
+ qkv_bias=True,
614
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
615
+ init_values=None,
616
+ **kwargs)
617
+ model.default_cfg = _cfg()
618
+ return model
619
+
620
+ def beit_large_patch16(pretrained=False, **kwargs):
621
+ model = BEiT(
622
+ patch_size=16,
623
+ embed_dim=1024,
624
+ depth=24,
625
+ num_heads=16,
626
+ mlp_ratio=4,
627
+ qkv_bias=True,
628
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
629
+ init_values=None,
630
+ **kwargs)
631
+ model.default_cfg = _cfg()
632
+ return model
633
+
634
+ def dit_base_patch16(pretrained=False, **kwargs):
635
+ model = BEiT(
636
+ patch_size=16,
637
+ embed_dim=768,
638
+ depth=12,
639
+ num_heads=12,
640
+ mlp_ratio=4,
641
+ qkv_bias=True,
642
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
643
+ init_values=0.1,
644
+ **kwargs)
645
+ model.default_cfg = _cfg()
646
+ return model
647
+
648
+ def dit_large_patch16(pretrained=False, **kwargs):
649
+ model = BEiT(
650
+ patch_size=16,
651
+ embed_dim=1024,
652
+ depth=24,
653
+ num_heads=16,
654
+ mlp_ratio=4,
655
+ qkv_bias=True,
656
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
657
+ init_values=1e-5,
658
+ **kwargs)
659
+ model.default_cfg = _cfg()
660
+ return model
661
+
662
+ if __name__ == '__main__':
663
+ model = BEiT(use_checkpoint=True, use_shared_rel_pos_bias=True)
664
+ model = model.to("cuda:0")
665
+ input1 = torch.rand(2, 3, 512, 762).to("cuda:0")
666
+ input2 = torch.rand(2, 3, 800, 1200).to("cuda:0")
667
+ input3 = torch.rand(2, 3, 720, 1000).to("cuda:0")
668
+ output1 = model(input1)
669
+ output2 = model(input2)
670
+ output3 = model(input3)
671
+ print("all done")
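Editor's note: the factory functions above (beit_base_patch16, dit_base_patch16, and the large variants) differ only in depth/width and in the layer-scale init_values; they are what the `build_vit_fpn_backbone` backbone named in the PubLayNet configs is expected to call. A minimal sketch of exercising one factory directly, outside detectron2 — the parameter choices here are assumptions, not values taken from the training configs:

    # Sketch (assumption): call dit_base_patch16 directly. The "layerN" names must match
    # the parsing in BEiT.__init__ (out_indices = [int(name[5:]) for name in out_features]).
    import torch

    model = dit_base_patch16(
        out_features=["layer3", "layer5", "layer7", "layer11"],
        use_checkpoint=False,  # gradient checkpointing is unnecessary for inference
    )
    model.eval()
    with torch.no_grad():
        feats = model(torch.rand(1, 3, 224, 224))
    for name, f in feats.items():
        print(name, tuple(f.shape))  # one multi-scale feature map per requested layer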
DiT_Extractor/dit_object_detection/ditod/config.py ADDED
@@ -0,0 +1,32 @@
1
+ from detectron2.config import CfgNode as CN
2
+
3
+
4
+ def add_vit_config(cfg):
5
+ """
6
+ Add config for VIT.
7
+ """
8
+ _C = cfg
9
+
10
+ _C.MODEL.VIT = CN()
11
+
12
+ # CoaT model name.
13
+ _C.MODEL.VIT.NAME = ""
14
+
15
+ # Output features from CoaT backbone.
16
+ _C.MODEL.VIT.OUT_FEATURES = ["layer3", "layer5", "layer7", "layer11"]
17
+
18
+ _C.MODEL.VIT.IMG_SIZE = [224, 224]
19
+
20
+ _C.MODEL.VIT.POS_TYPE = "shared_rel"
21
+
22
+ _C.MODEL.VIT.DROP_PATH = 0.
23
+
24
+ _C.MODEL.VIT.MODEL_KWARGS = "{}"
25
+
26
+ _C.SOLVER.OPTIMIZER = "ADAMW"
27
+
28
+ _C.SOLVER.BACKBONE_MULTIPLIER = 1.0
29
+
30
+ _C.AUG = CN()
31
+
32
+ _C.AUG.DETR = False
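Editor's note: add_vit_config has to register the MODEL.VIT, SOLVER, and AUG keys before any of the DiT YAMLs are merged, otherwise detectron2 rejects the unknown keys. A short sketch of that call order (it mirrors how dit_runner.py, added later in this commit, builds its config):

    # Sketch of the intended call order: register the extra keys, then merge a DiT YAML.
    from detectron2.config import get_cfg

    cfg = get_cfg()
    add_vit_config(cfg)  # registers MODEL.VIT, SOLVER.OPTIMIZER, AUG, etc.
    cfg.merge_from_file(
        "DiT_Extractor/dit_object_detection/publaynet_configs/cascade/cascade_dit_base.yaml")
    print(cfg.MODEL.VIT.NAME)  # -> "dit_base_patch16"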
DiT_Extractor/dit_object_detection/ditod/deit.py ADDED
@@ -0,0 +1,476 @@
1
+ """
2
+ Mostly copy-paste from DINO and timm library:
3
+ https://github.com/facebookresearch/dino
4
+ https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
5
+ """
6
+ import warnings
7
+
8
+ import math
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint as checkpoint
12
+ from timm.models.layers import trunc_normal_, drop_path, to_2tuple
13
+ from functools import partial
14
+
15
+ def _cfg(url='', **kwargs):
16
+ return {
17
+ 'url': url,
18
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
19
+ 'crop_pct': .9, 'interpolation': 'bicubic',
20
+ 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
21
+ **kwargs
22
+ }
23
+
24
+ class DropPath(nn.Module):
25
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
26
+ """
27
+
28
+ def __init__(self, drop_prob=None):
29
+ super(DropPath, self).__init__()
30
+ self.drop_prob = drop_prob
31
+
32
+ def forward(self, x):
33
+ return drop_path(x, self.drop_prob, self.training)
34
+
35
+ def extra_repr(self) -> str:
36
+ return 'p={}'.format(self.drop_prob)
37
+
38
+
39
+ class Mlp(nn.Module):
40
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
41
+ super().__init__()
42
+ out_features = out_features or in_features
43
+ hidden_features = hidden_features or in_features
44
+ self.fc1 = nn.Linear(in_features, hidden_features)
45
+ self.act = act_layer()
46
+ self.fc2 = nn.Linear(hidden_features, out_features)
47
+ self.drop = nn.Dropout(drop)
48
+
49
+ def forward(self, x):
50
+ x = self.fc1(x)
51
+ x = self.act(x)
52
+ x = self.drop(x)
53
+ x = self.fc2(x)
54
+ x = self.drop(x)
55
+ return x
56
+
57
+
58
+ class Attention(nn.Module):
59
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
60
+ super().__init__()
61
+ self.num_heads = num_heads
62
+ head_dim = dim // num_heads
63
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
64
+ self.scale = qk_scale or head_dim ** -0.5
65
+
66
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
67
+ self.attn_drop = nn.Dropout(attn_drop)
68
+ self.proj = nn.Linear(dim, dim)
69
+ self.proj_drop = nn.Dropout(proj_drop)
70
+
71
+ def forward(self, x):
72
+ B, N, C = x.shape
73
+ q, k, v = self.qkv(x).reshape(B, N, 3, self.num_heads,
74
+ C // self.num_heads).permute(2, 0, 3, 1, 4)
75
+
76
+ attn = (q @ k.transpose(-2, -1)) * self.scale
77
+ attn = attn.softmax(dim=-1)
78
+ attn = self.attn_drop(attn)
79
+
80
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
81
+ x = self.proj(x)
82
+ x = self.proj_drop(x)
83
+ return x
84
+
85
+
86
+ class Block(nn.Module):
87
+
88
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
89
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
90
+ super().__init__()
91
+ self.norm1 = norm_layer(dim)
92
+ self.attn = Attention(
93
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
94
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
95
+ self.drop_path = DropPath(
96
+ drop_path) if drop_path > 0. else nn.Identity()
97
+ self.norm2 = norm_layer(dim)
98
+ mlp_hidden_dim = int(dim * mlp_ratio)
99
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
100
+ act_layer=act_layer, drop=drop)
101
+
102
+ def forward(self, x):
103
+ x = x + self.drop_path(self.attn(self.norm1(x)))
104
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
105
+ return x
106
+
107
+
108
+ class PatchEmbed(nn.Module):
109
+ """ Image to Patch Embedding
110
+ """
111
+
112
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
113
+ super().__init__()
114
+ img_size = to_2tuple(img_size)
115
+ patch_size = to_2tuple(patch_size)
116
+
117
+ self.window_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
118
+
119
+ self.num_patches_w, self.num_patches_h = self.window_size
120
+
121
+ self.num_patches = self.window_size[0] * self.window_size[1]
122
+ self.img_size = img_size
123
+ self.patch_size = patch_size
124
+
125
+ self.proj = nn.Conv2d(in_chans, embed_dim,
126
+ kernel_size=patch_size, stride=patch_size)
127
+
128
+ def forward(self, x):
129
+ x = self.proj(x)
130
+ return x
131
+
132
+
133
+ class HybridEmbed(nn.Module):
134
+ """ CNN Feature Map Embedding
135
+ Extract feature map from CNN, flatten, project to embedding dim.
136
+ """
137
+
138
+ def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
139
+ super().__init__()
140
+ assert isinstance(backbone, nn.Module)
141
+ img_size = to_2tuple(img_size)
142
+ self.img_size = img_size
143
+ self.backbone = backbone
144
+ if feature_size is None:
145
+ with torch.no_grad():
146
+ # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature
147
+ # map for all networks, the feature metadata has reliable channel and stride info, but using
148
+ # stride to calc feature dim requires info about padding of each stage that isn't captured.
149
+ training = backbone.training
150
+ if training:
151
+ backbone.eval()
152
+ o = self.backbone(torch.zeros(
153
+ 1, in_chans, img_size[0], img_size[1]))[-1]
154
+ feature_size = o.shape[-2:]
155
+ feature_dim = o.shape[1]
156
+ backbone.train(training)
157
+ else:
158
+ feature_size = to_2tuple(feature_size)
159
+ feature_dim = self.backbone.feature_info.channels()[-1]
160
+ self.num_patches = feature_size[0] * feature_size[1]
161
+ self.proj = nn.Linear(feature_dim, embed_dim)
162
+
163
+ def forward(self, x):
164
+ x = self.backbone(x)[-1]
165
+ x = x.flatten(2).transpose(1, 2)
166
+ x = self.proj(x)
167
+ return x
168
+
169
+
170
+ class ViT(nn.Module):
171
+ """ Vision Transformer with support for patch or hybrid CNN input stage
172
+ """
173
+
174
+ def __init__(self,
175
+ model_name='vit_base_patch16_224',
176
+ img_size=384,
177
+ patch_size=16,
178
+ in_chans=3,
179
+ embed_dim=1024,
180
+ depth=24,
181
+ num_heads=16,
182
+ num_classes=19,
183
+ mlp_ratio=4.,
184
+ qkv_bias=True,
185
+ qk_scale=None,
186
+ drop_rate=0.1,
187
+ attn_drop_rate=0.,
188
+ drop_path_rate=0.,
189
+ hybrid_backbone=None,
190
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
191
+ norm_cfg=None,
192
+ pos_embed_interp=False,
193
+ random_init=False,
194
+ align_corners=False,
195
+ use_checkpoint=False,
196
+ num_extra_tokens=1,
197
+ out_features=None,
198
+ **kwargs,
199
+ ):
200
+
201
+ super(ViT, self).__init__()
202
+ self.model_name = model_name
203
+ self.img_size = img_size
204
+ self.patch_size = patch_size
205
+ self.in_chans = in_chans
206
+ self.embed_dim = embed_dim
207
+ self.depth = depth
208
+ self.num_heads = num_heads
209
+ self.num_classes = num_classes
210
+ self.mlp_ratio = mlp_ratio
211
+ self.qkv_bias = qkv_bias
212
+ self.qk_scale = qk_scale
213
+ self.drop_rate = drop_rate
214
+ self.attn_drop_rate = attn_drop_rate
215
+ self.drop_path_rate = drop_path_rate
216
+ self.hybrid_backbone = hybrid_backbone
217
+ self.norm_layer = norm_layer
218
+ self.norm_cfg = norm_cfg
219
+ self.pos_embed_interp = pos_embed_interp
220
+ self.random_init = random_init
221
+ self.align_corners = align_corners
222
+ self.use_checkpoint = use_checkpoint
223
+ self.num_extra_tokens = num_extra_tokens
224
+ self.out_features = out_features
225
+ self.out_indices = [int(name[5:]) for name in out_features]
226
+
227
+ # self.num_stages = self.depth
228
+ # self.out_indices = tuple(range(self.num_stages))
229
+
230
+ if self.hybrid_backbone is not None:
231
+ self.patch_embed = HybridEmbed(
232
+ self.hybrid_backbone, img_size=self.img_size, in_chans=self.in_chans, embed_dim=self.embed_dim)
233
+ else:
234
+ self.patch_embed = PatchEmbed(
235
+ img_size=self.img_size, patch_size=self.patch_size, in_chans=self.in_chans, embed_dim=self.embed_dim)
236
+ self.num_patches = self.patch_embed.num_patches
237
+
238
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
239
+
240
+ if self.num_extra_tokens == 2:
241
+ self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
242
+
243
+ self.pos_embed = nn.Parameter(torch.zeros(
244
+ 1, self.num_patches + self.num_extra_tokens, self.embed_dim))
245
+ self.pos_drop = nn.Dropout(p=self.drop_rate)
246
+
247
+ # self.num_extra_tokens = self.pos_embed.shape[-2] - self.num_patches
248
+ dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate,
249
+ self.depth)] # stochastic depth decay rule
250
+ self.blocks = nn.ModuleList([
251
+ Block(
252
+ dim=self.embed_dim, num_heads=self.num_heads, mlp_ratio=self.mlp_ratio, qkv_bias=self.qkv_bias,
253
+ qk_scale=self.qk_scale,
254
+ drop=self.drop_rate, attn_drop=self.attn_drop_rate, drop_path=dpr[i], norm_layer=self.norm_layer)
255
+ for i in range(self.depth)])
256
+
257
+ # NOTE as per official impl, we could have a pre-logits representation dense layer + tanh here
258
+ # self.repr = nn.Linear(embed_dim, representation_size)
259
+ # self.repr_act = nn.Tanh()
260
+
261
+ if patch_size == 16:
262
+ self.fpn1 = nn.Sequential(
263
+ nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
264
+ nn.SyncBatchNorm(embed_dim),
265
+ nn.GELU(),
266
+ nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
267
+ )
268
+
269
+ self.fpn2 = nn.Sequential(
270
+ nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
271
+ )
272
+
273
+ self.fpn3 = nn.Identity()
274
+
275
+ self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)
276
+ elif patch_size == 8:
277
+ self.fpn1 = nn.Sequential(
278
+ nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2),
279
+ )
280
+
281
+ self.fpn2 = nn.Identity()
282
+
283
+ self.fpn3 = nn.Sequential(
284
+ nn.MaxPool2d(kernel_size=2, stride=2),
285
+ )
286
+
287
+ self.fpn4 = nn.Sequential(
288
+ nn.MaxPool2d(kernel_size=4, stride=4),
289
+ )
290
+
291
+ trunc_normal_(self.pos_embed, std=.02)
292
+ trunc_normal_(self.cls_token, std=.02)
293
+ if self.num_extra_tokens==2:
294
+ trunc_normal_(self.dist_token, std=0.2)
295
+ self.apply(self._init_weights)
296
+ # self.fix_init_weight()
297
+
298
+ def fix_init_weight(self):
299
+ def rescale(param, layer_id):
300
+ param.div_(math.sqrt(2.0 * layer_id))
301
+
302
+ for layer_id, layer in enumerate(self.blocks):
303
+ rescale(layer.attn.proj.weight.data, layer_id + 1)
304
+ rescale(layer.mlp.fc2.weight.data, layer_id + 1)
305
+
306
+ def _init_weights(self, m):
307
+ if isinstance(m, nn.Linear):
308
+ trunc_normal_(m.weight, std=.02)
309
+ if isinstance(m, nn.Linear) and m.bias is not None:
310
+ nn.init.constant_(m.bias, 0)
311
+ elif isinstance(m, nn.LayerNorm):
312
+ nn.init.constant_(m.bias, 0)
313
+ nn.init.constant_(m.weight, 1.0)
314
+
315
+ '''
316
+ def init_weights(self):
317
+ logger = get_root_logger()
318
+
319
+ trunc_normal_(self.pos_embed, std=.02)
320
+ trunc_normal_(self.cls_token, std=.02)
321
+ self.apply(self._init_weights)
322
+
323
+ if self.init_cfg is None:
324
+ logger.warn(f'No pre-trained weights for '
325
+ f'{self.__class__.__name__}, '
326
+ f'training start from scratch')
327
+ else:
328
+ assert 'checkpoint' in self.init_cfg, f'Only support ' \
329
+ f'specify `Pretrained` in ' \
330
+ f'`init_cfg` in ' \
331
+ f'{self.__class__.__name__} '
332
+ logger.info(f"Will load ckpt from {self.init_cfg['checkpoint']}")
333
+ load_checkpoint(self, filename=self.init_cfg['checkpoint'], strict=False, logger=logger)
334
+ '''
335
+
336
+ def get_num_layers(self):
337
+ return len(self.blocks)
338
+
339
+ @torch.jit.ignore
340
+ def no_weight_decay(self):
341
+ return {'pos_embed', 'cls_token'}
342
+
343
+ def _conv_filter(self, state_dict, patch_size=16):
344
+ """ convert patch embedding weight from manual patchify + linear proj to conv"""
345
+ out_dict = {}
346
+ for k, v in state_dict.items():
347
+ if 'patch_embed.proj.weight' in k:
348
+ v = v.reshape((v.shape[0], 3, patch_size, patch_size))
349
+ out_dict[k] = v
350
+ return out_dict
351
+
352
+ def to_2D(self, x):
353
+ n, hw, c = x.shape
354
+ h = w = int(math.sqrt(hw))
355
+ x = x.transpose(1, 2).reshape(n, c, h, w)
356
+ return x
357
+
358
+ def to_1D(self, x):
359
+ n, c, h, w = x.shape
360
+ x = x.reshape(n, c, -1).transpose(1, 2)
361
+ return x
362
+
363
+ def interpolate_pos_encoding(self, x, w, h):
364
+ npatch = x.shape[1] - self.num_extra_tokens
365
+ N = self.pos_embed.shape[1] - self.num_extra_tokens
366
+ if npatch == N and w == h:
367
+ return self.pos_embed
368
+
369
+ class_ORdist_pos_embed = self.pos_embed[:, 0:self.num_extra_tokens]
370
+
371
+ patch_pos_embed = self.pos_embed[:, self.num_extra_tokens:]
372
+
373
+ dim = x.shape[-1]
374
+ w0 = w // self.patch_embed.patch_size[0]
375
+ h0 = h // self.patch_embed.patch_size[1]
376
+ # we add a small number to avoid floating point error in the interpolation
377
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
378
+ w0, h0 = w0 + 0.1, h0 + 0.1
379
+ patch_pos_embed = nn.functional.interpolate(
380
+ patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
381
+ scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
382
+ mode='bicubic',
383
+ )
384
+ assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
385
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
386
+
387
+ return torch.cat((class_ORdist_pos_embed, patch_pos_embed), dim=1)
388
+
389
+ def prepare_tokens(self, x, mask=None):
390
+ B, nc, w, h = x.shape
391
+ # patch linear embedding
392
+ x = self.patch_embed(x)
393
+
394
+ # mask image modeling
395
+ if mask is not None:
396
+ x = self.mask_model(x, mask)
397
+ x = x.flatten(2).transpose(1, 2)
398
+
399
+ # add the [CLS] token to the embed patch tokens
400
+ all_tokens = [self.cls_token.expand(B, -1, -1)]
401
+
402
+ if self.num_extra_tokens == 2:
403
+ dist_tokens = self.dist_token.expand(B, -1, -1)
404
+ all_tokens.append(dist_tokens)
405
+ all_tokens.append(x)
406
+
407
+ x = torch.cat(all_tokens, dim=1)
408
+
409
+ # add positional encoding to each token
410
+ x = x + self.interpolate_pos_encoding(x, w, h)
411
+
412
+ return self.pos_drop(x)
413
+
414
+ def forward_features(self, x):
415
+ # print(f"==========shape of x is {x.shape}==========")
416
+ B, _, H, W = x.shape
417
+ Hp, Wp = H // self.patch_size, W // self.patch_size
418
+ x = self.prepare_tokens(x)
419
+
420
+ features = []
421
+ for i, blk in enumerate(self.blocks):
422
+ if self.use_checkpoint:
423
+ x = checkpoint.checkpoint(blk, x)
424
+ else:
425
+ x = blk(x)
426
+ if i in self.out_indices:
427
+ xp = x[:, self.num_extra_tokens:, :].permute(0, 2, 1).reshape(B, -1, Hp, Wp)
428
+ features.append(xp.contiguous())
429
+
430
+ ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]
431
+ for i in range(len(features)):
432
+ features[i] = ops[i](features[i])
433
+
434
+ feat_out = {}
435
+
436
+ for name, value in zip(self.out_features, features):
437
+ feat_out[name] = value
438
+
439
+ return feat_out
440
+
441
+ def forward(self, x):
442
+ x = self.forward_features(x)
443
+ return x
444
+
445
+
446
+ def deit_base_patch16(pretrained=False, **kwargs):
447
+ model = ViT(
448
+ patch_size=16,
449
+ drop_rate=0.,
450
+ embed_dim=768,
451
+ depth=12,
452
+ num_heads=12,
453
+ num_classes=1000,
454
+ mlp_ratio=4.,
455
+ qkv_bias=True,
456
+ use_checkpoint=True,
457
+ num_extra_tokens=2,
458
+ **kwargs)
459
+ model.default_cfg = _cfg()
460
+ return model
461
+
462
+ def mae_base_patch16(pretrained=False, **kwargs):
463
+ model = ViT(
464
+ patch_size=16,
465
+ drop_rate=0.,
466
+ embed_dim=768,
467
+ depth=12,
468
+ num_heads=12,
469
+ num_classes=1000,
470
+ mlp_ratio=4.,
471
+ qkv_bias=True,
472
+ use_checkpoint=True,
473
+ num_extra_tokens=1,
474
+ **kwargs)
475
+ model.default_cfg = _cfg()
476
+ return model
DiT_Extractor/dit_object_detection/publaynet_configs/Base-RCNN-FPN.yaml ADDED
@@ -0,0 +1,69 @@
1
+ MODEL:
2
+ MASK_ON: True
3
+ META_ARCHITECTURE: "GeneralizedRCNN"
4
+ PIXEL_MEAN: [123.675, 116.280, 103.530]
5
+ PIXEL_STD: [58.395, 57.120, 57.375]
6
+ BACKBONE:
7
+ NAME: "build_vit_fpn_backbone"
8
+ VIT:
9
+ OUT_FEATURES: ["layer3", "layer5", "layer7", "layer11"]
10
+ DROP_PATH: 0.1
11
+ IMG_SIZE: [224,224]
12
+ POS_TYPE: "abs"
13
+ FPN:
14
+ IN_FEATURES: ["layer3", "layer5", "layer7", "layer11"]
15
+ ANCHOR_GENERATOR:
16
+ SIZES: [[32], [64], [128], [256], [512]] # One size for each in feature map
17
+ ASPECT_RATIOS: [[0.5, 1.0, 2.0]] # Three aspect ratios (same for all in feature maps)
18
+ RPN:
19
+ IN_FEATURES: ["p2", "p3", "p4", "p5", "p6"]
20
+ PRE_NMS_TOPK_TRAIN: 2000 # Per FPN level
21
+ PRE_NMS_TOPK_TEST: 1000 # Per FPN level
22
+ # Detectron1 uses 2000 proposals per-batch,
23
+ # (See "modeling/rpn/rpn_outputs.py" for details of this legacy issue)
24
+ # which is approximately 1000 proposals per-image since the default batch size for FPN is 2.
25
+ POST_NMS_TOPK_TRAIN: 1000
26
+ POST_NMS_TOPK_TEST: 1000
27
+ ROI_HEADS:
28
+ NAME: "StandardROIHeads"
29
+ IN_FEATURES: ["p2", "p3", "p4", "p5"]
30
+ NUM_CLASSES: 5
31
+ ROI_BOX_HEAD:
32
+ NAME: "FastRCNNConvFCHead"
33
+ NUM_FC: 2
34
+ POOLER_RESOLUTION: 7
35
+ ROI_MASK_HEAD:
36
+ NAME: "MaskRCNNConvUpsampleHead"
37
+ NUM_CONV: 4
38
+ POOLER_RESOLUTION: 14
39
+ DATASETS:
40
+ TRAIN: ("publaynet_train",)
41
+ TEST: ("publaynet_val",)
42
+ SOLVER:
43
+ LR_SCHEDULER_NAME: "WarmupCosineLR"
44
+ AMP:
45
+ ENABLED: True
46
+ OPTIMIZER: "ADAMW"
47
+ BACKBONE_MULTIPLIER: 1.0
48
+ CLIP_GRADIENTS:
49
+ ENABLED: True
50
+ CLIP_TYPE: "full_model"
51
+ CLIP_VALUE: 1.0
52
+ NORM_TYPE: 2.0
53
+ WARMUP_FACTOR: 0.01
54
+ BASE_LR: 0.0004
55
+ WEIGHT_DECAY: 0.05
56
+ IMS_PER_BATCH: 32
57
+ INPUT:
58
+ CROP:
59
+ ENABLED: True
60
+ TYPE: "absolute_range"
61
+ SIZE: (384, 600)
62
+ MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800)
63
+ FORMAT: "RGB"
64
+ DATALOADER:
65
+ FILTER_EMPTY_ANNOTATIONS: False
66
+ VERSION: 2
67
+ AUG:
68
+ DETR: True
69
+ SEED: 42
DiT_Extractor/dit_object_detection/publaynet_configs/cascade/cascade_dit_base.yaml ADDED
@@ -0,0 +1,20 @@
1
+ _BASE_: "../Base-RCNN-FPN.yaml"
2
+ MODEL:
3
+ PIXEL_MEAN: [ 127.5, 127.5, 127.5 ]
4
+ PIXEL_STD: [ 127.5, 127.5, 127.5 ]
5
+ WEIGHTS: "https://layoutlm.blob.core.windows.net/dit/dit-pts/dit-base-224-p16-500k-62d53a.pth"
6
+ VIT:
7
+ NAME: "dit_base_patch16"
8
+ ROI_HEADS:
9
+ NAME: CascadeROIHeads
10
+ ROI_BOX_HEAD:
11
+ CLS_AGNOSTIC_BBOX_REG: True
12
+ RPN:
13
+ POST_NMS_TOPK_TRAIN: 2000
14
+ SOLVER:
15
+ WARMUP_ITERS: 1000
16
+ IMS_PER_BATCH: 16
17
+ MAX_ITER: 60000
18
+ CHECKPOINT_PERIOD: 2000
19
+ TEST:
20
+ EVAL_PERIOD: 2000
DiT_Extractor/dit_object_detection/publaynet_configs/cascade/cascade_dit_large.yaml ADDED
@@ -0,0 +1,28 @@
1
+ _BASE_: "../Base-RCNN-FPN.yaml"
2
+ MODEL:
3
+ PIXEL_MEAN: [ 127.5, 127.5, 127.5 ]
4
+ PIXEL_STD: [ 127.5, 127.5, 127.5 ]
5
+ WEIGHTS: "https://layoutlm.blob.core.windows.net/dit/dit-pts/dit-large-224-p16-500k-d7a2fb.pth"
6
+ VIT:
7
+ NAME: "dit_large_patch16"
8
+ OUT_FEATURES: [ "layer7", "layer11", "layer15", "layer23" ]
9
+ DROP_PATH: 0.2
10
+ FPN:
11
+ IN_FEATURES: [ "layer7", "layer11", "layer15", "layer23" ]
12
+ ROI_HEADS:
13
+ NAME: CascadeROIHeads
14
+ ROI_BOX_HEAD:
15
+ CLS_AGNOSTIC_BBOX_REG: True
16
+ RPN:
17
+ POST_NMS_TOPK_TRAIN: 2000
18
+ SOLVER:
19
+ WARMUP_ITERS: 1000
20
+ IMS_PER_BATCH: 16
21
+ MAX_ITER: 60000
22
+ CHECKPOINT_PERIOD: 2000
23
+ BASE_LR: 0.0001
24
+ STEPS: (40000, 53333)
25
+ AMP:
26
+ ENABLED: False
27
+ TEST:
28
+ EVAL_PERIOD: 2000
DiT_Extractor/dit_object_detection/publaynet_configs/maskrcnn/maskrcnn_dit_base.yaml ADDED
@@ -0,0 +1,15 @@
1
+ _BASE_: "../Base-RCNN-FPN.yaml"
2
+ MODEL:
3
+ PIXEL_MEAN: [ 127.5, 127.5, 127.5 ]
4
+ PIXEL_STD: [ 127.5, 127.5, 127.5 ]
5
+ WEIGHTS: "https://layoutlm.blob.core.windows.net/dit/dit-pts/dit-base-224-p16-500k-62d53a.pth"
6
+ VIT:
7
+ NAME: "dit_base_patch16"
8
+ SOLVER:
9
+ WARMUP_ITERS: 1000
10
+ IMS_PER_BATCH: 16
11
+ MAX_ITER: 60000
12
+ CHECKPOINT_PERIOD: 2000
13
+ TEST:
14
+ EVAL_PERIOD: 2000
15
+ OUTPUT_DIR: $AMLT_OUTPUT_DIR
DiT_Extractor/dit_object_detection/publaynet_configs/maskrcnn/maskrcnn_dit_large.yaml ADDED
@@ -0,0 +1,22 @@
1
+ _BASE_: "../Base-RCNN-FPN.yaml"
2
+ MODEL:
3
+ PIXEL_MEAN: [ 127.5, 127.5, 127.5 ]
4
+ PIXEL_STD: [ 127.5, 127.5, 127.5 ]
5
+ WEIGHTS: "https://layoutlm.blob.core.windows.net/dit/dit-pts/dit-large-224-p16-500k-d7a2fb.pth"
6
+ VIT:
7
+ NAME: "dit_large_patch16"
8
+ OUT_FEATURES: [ "layer7", "layer11", "layer15", "layer23" ]
9
+ DROP_PATH: 0.2
10
+ FPN:
11
+ IN_FEATURES: [ "layer7", "layer11", "layer15", "layer23" ]
12
+ SOLVER:
13
+ WARMUP_ITERS: 1000
14
+ IMS_PER_BATCH: 16
15
+ MAX_ITER: 60000
16
+ CHECKPOINT_PERIOD: 2000
17
+ BASE_LR: 0.0001
18
+ AMP:
19
+ ENABLED: False
20
+ TEST:
21
+ EVAL_PERIOD: 2000
22
+ OUTPUT_DIR: "output/publaynet/mask_rcnn/dit_base_multistep_3x_ms"
DiT_Extractor/dit_runner.py ADDED
@@ -0,0 +1,158 @@
1
+ # Copyright (c) 2022, Lawrence Livermore National Security, LLC.
2
+ # All rights reserved.
3
+ # See the top-level LICENSE and NOTICE files for details.
4
+ # LLNL-CODE-838964
5
+
6
+ # SPDX-License-Identifier: Apache-2.0-with-LLVM-exception
7
+
8
+ import cv2
9
+ from pathlib import Path
10
+ import torch
11
+ import json
12
+
13
+ from detectron2.config import CfgNode as CN
14
+ from detectron2.config import get_cfg
15
+ from detectron2.utils.visualizer import ColorMode, Visualizer
16
+ from detectron2.data import MetadataCatalog
17
+ from detectron2.engine import DefaultPredictor
18
+
19
+ from pdf2image import convert_from_path
20
+
21
+ from PIL import Image
22
+ import numpy as np
23
+
24
+ from dit_object_detection.ditod import add_vit_config
25
+ import base_utils
26
+ from pdfminer.layout import LTTextLineHorizontal, LTTextBoxHorizontal, LTAnno, LTChar
27
+
28
+ from tokenizers.pre_tokenizers import Whitespace
29
+
30
+ import warnings
31
+ warnings.filterwarnings("ignore")
32
+
33
+ dit_path = Path('DiT_Extractor/dit_object_detection')
34
+
35
+ cfg = get_cfg()
36
+ add_vit_config(cfg)
37
+ cfg.merge_from_file(dit_path / "publaynet_configs/cascade/cascade_dit_base.yaml")
38
+
39
+ cfg.MODEL.WEIGHTS = "https://layoutlm.blob.core.windows.net/dit/dit-fts/publaynet_dit-b_cascade.pth"
40
+ cfg.MODEL.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
41
+
42
+ predictor = DefaultPredictor(cfg)
43
+
44
+ thing_classes = ["text","title","list","table","figure"]
45
+ thing_map = dict(map(reversed, enumerate(thing_classes)))
46
+ md = MetadataCatalog.get(cfg.DATASETS.TEST[0])
47
+ md.set(thing_classes=thing_classes)
48
+
49
+
50
+ def get_pdf_image(pdf_file, page):
51
+ image = convert_from_path(pdf_file, dpi=200, first_page=page, last_page=page)
52
+ return image
53
+
54
+ def get_characters(subelement):
55
+ all_chars = []
56
+ if isinstance(subelement, LTTextLineHorizontal):
57
+ for char in subelement:
58
+ if isinstance(char, LTChar):
59
+ all_chars.append((char.bbox, char.get_text()))
60
+ if isinstance(char, LTAnno):
61
+ # No bbox, just a space, so make a thin slice after previous text
62
+ bbox = all_chars[-1][0]
63
+ bbox = (bbox[2],bbox[1],bbox[2],bbox[3])
64
+ all_chars.append((bbox, char.get_text()))
65
+ return all_chars
66
+
67
+
68
+ def get_dit_preds(pdf, score_threshold=0.5):
69
+
70
+ page_count = base_utils.get_pdf_page_count(pdf)
71
+
72
+ # Input is numpy array of PIL image
73
+ page_sizes = base_utils.get_page_sizes(pdf)
74
+
75
+ sections = {}
76
+ viz_images = []
77
+ page_words = base_utils.get_pdf_words(pdf)
78
+ for page in range(1, page_count+1): #range(2, page_count + 1):
79
+ image = get_pdf_image(pdf, page)
80
+ image = np.array(image[0])
81
+ # Get prediction
82
+ output = predictor(image)["instances"]
83
+ output = output.to('cpu')
84
+
85
+ # Visualize predictions
86
+ v = Visualizer(image[:, :, ::-1],
87
+ md,
88
+ scale=1.0,
89
+ instance_mode=ColorMode.SEGMENTATION)
90
+ result = v.draw_instance_predictions(output)
91
+ result_image = result.get_image()[:, :, ::-1]
92
+ viz_img = Image.fromarray(result_image)
93
+ viz_images.append(viz_img)
94
+
95
+ words = page_words[page-1]
96
+
97
+ # Convert from image_size to page size
98
+ pdf_dimensions = page_sizes[page-1][2:]
99
+ # Swap height/width
100
+ pdf_image_size = (output.image_size[1], output.image_size[0])
101
+
102
+ scale = np.array(pdf_dimensions) / np.array(pdf_image_size)
103
+ scale_box = np.hstack((scale,scale))
104
+ # Words are in page coordinates
105
+
106
+ id = 0
107
+ sections[page-1] = []
108
+ draw = image.copy()
109
+ for box_t, clazz, score in zip(output.get('pred_boxes'), output.get('pred_classes'), output.get('scores')):
110
+
111
+ if score < score_threshold:
112
+ continue
113
+
114
+ box = box_t.numpy()
115
+ # Flip along Y axis
116
+ box[1] = pdf_image_size[1] - box[1]
117
+ box[3] = pdf_image_size[1] - box[3]
118
+ # Scale
119
+ scaled = box * scale_box
120
+ # This is the correct order
121
+ scaled = [scaled[0], scaled[3], scaled[2], scaled[1]]
122
+ if clazz != thing_map['text']:
123
+ continue
124
+
125
+ start = box[0:2].tolist()
126
+ end = box[2:4].tolist()
127
+ start = [int(x) for x in start]
128
+ end = [int(x) for x in end]
129
+
130
+ out = {}
131
+
132
+ for word in words.copy():
133
+ if base_utils.partial_overlaps(word[0:4], scaled):
134
+ if out == {}:
135
+ id += 1
136
+ out['coord'] = word[0:4]
137
+ out['subelements'] = []
138
+ out['type'] = 'content_block'
139
+ out['id']= id
140
+ out['text'] = ''
141
+
142
+ out['coord'] = base_utils.union(out['coord'], word[0:4])
143
+ out['text'] = out['text'] + word[4].get_text()
144
+
145
+ characters = get_characters(word[4])
146
+ out['subelements'].append(characters)
147
+ words.remove(word)
148
+
149
+ if len(out) != 0:
150
+ sections[page-1].append(out)
151
+
152
+ # Write final annotation
153
+
154
+ out_name = Path(pdf).name[:-4] + ".json"
155
+ with open(out_name, 'w', encoding='utf8') as json_out:
156
+ json.dump(sections, json_out, ensure_ascii=False, indent=4)
157
+
158
+ return viz_images
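Editor's note: get_dit_preds both writes a `<pdf stem>.json` layout file in the working directory and returns the annotated page images. A hedged usage sketch, assuming the script is run from the repo root with one of the bundled example PDFs (the checkpoint is fetched by detectron2 on first use):

    # Usage sketch (paths are assumptions based on the files added in this commit).
    viz_pages = get_dit_preds("examples/2105.03011.pdf", score_threshold=0.5)
    viz_pages[0].save("page_1_layout.png")  # annotated layout predictions for page 1
    # "2105.03011.json" with the detected text blocks is written alongside.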
DiT_Extractor/sentence_extractor.py ADDED
@@ -0,0 +1,136 @@
1
+ # Copyright (c) 2022, Lawrence Livermore National Security, LLC.
2
+ # All rights reserved.
3
+ # See the top-level LICENSE and NOTICE files for details.
4
+ # LLNL-CODE-838964
5
+
6
+ # SPDX-License-Identifier: Apache-2.0-with-LLVM-exception
7
+
8
+ import json
9
+ from tokenizers.pre_tokenizers import Whitespace
10
+ import base_utils
11
+ import spacy
12
+
13
+ def guess_sentences(tokens, text):
14
+ sentence_delems = ('.', '?', ').', '!')
15
+ sentences = []
16
+ sentence = []
17
+ maybe_delem = None
18
+ for token in tokens:
19
+ # check next token to see if there is space after prev delem
20
+ if maybe_delem != None:
21
+ if maybe_delem[1][1] < token[1][0]:
22
+ sentences.append(sentence)
23
+ sentence = []
24
+ maybe_delem = None
25
+
26
+ sentence.append(token)
27
+ if token[0] in sentence_delems:
28
+ maybe_delem = token
29
+ if sentence != []:
30
+ sentences.append(sentence)
31
+ return sentences
32
+
33
+ def spacey_sentences(text):
34
+ nlp = spacy.blank('en')
35
+ nlp.add_pipe('sentencizer')
36
+ sentences = [s.text for s in nlp(text).sents]
37
+ return sentences
38
+
39
+ def add_coords(sentences, all_coords):
40
+ sentences_out = []
41
+ for sentence in sentences:
42
+ new_sentence = []
43
+ for token in sentence:
44
+ indexes = token[1]
45
+ bbox = all_coords[indexes[0]]
46
+ for i in range(indexes[0]+1, indexes[1]):
47
+ bbox = base_utils.union(bbox, all_coords[i])
48
+ new_sentence.append((token[0],token[1],bbox))
49
+ sentences_out.append(new_sentence)
50
+ return sentences_out
51
+
52
+ def sentence_extract(document):
53
+ """
54
+ Convert extract .PDF result .pkl into tokens with max length of 384 tokens, seperated
55
+ on sentence delimiter boundaries such as .!?
56
+ """
57
+ max_tokens = 384
58
+ document_tree = json.load(open(document,'r'))
59
+ sections_per_page = {}
60
+ for page_num, page in document_tree.items():
61
+ # Tokenize per section (rectangular block that was detected by DIT)
62
+ word_sections = []
63
+ text_sections = []
64
+ for section in page:
65
+ text_sections.append(section['text'])
66
+ all_text = ''
67
+ all_coord = []
68
+ if 'subelements' not in section:
69
+ continue
70
+ for subelement in section['subelements']:
71
+ for char in subelement:
72
+ all_text += char[1]
73
+ all_coord.append(char[0])
74
+ # check for weird characters, e.g. "(cid:206)", "ff", "fi", etc
75
+ # if string isn't just 1 character, it's an irregular LTChar (character) from pdfminer.
76
+ # instead of skipping them, we can just create extra duplicate coordinates for the additional characters.
77
+ if len(char[1]) > 1:
78
+ bad_char_len = len(char[1])
79
+ dupe_coord_amt = (bad_char_len - 1)
80
+ for dupe_i in range(dupe_coord_amt):
81
+ all_coord.append(char[0])
82
+
83
+ pre_tokenizer = Whitespace()
84
+
85
+ sentences_pre_tok = spacey_sentences(all_text)
86
+ sentences = []
87
+ for sentence in sentences_pre_tok:
88
+ tokenized = pre_tokenizer.pre_tokenize_str(sentence)
89
+ sentences.append(tokenized)
90
+
91
+ sentences = add_coords(sentences, all_coord)
92
+
93
+ word_section = []
94
+ t = 0
95
+ for sentence in sentences:
96
+ t += len(sentence)
97
+ if t <= max_tokens:
98
+ word_section += sentence
99
+ else:
100
+ word_sections.append(word_section)
101
+ word_section = sentence
102
+ t = len(sentence)
103
+ word_sections.append(word_section)
104
+ sections = {'text_sections':text_sections, 'word_sections':word_sections}
105
+ sections_per_page[page_num] = sections
106
+ return sections_per_page
107
+
108
+ def format_output_contexts(sections_per_page):
109
+
110
+ all_contexts = {}
111
+
112
+ for page_idx in sections_per_page.keys():
113
+
114
+ text_sections = sections_per_page[page_idx]['text_sections']
115
+ word_sections = sections_per_page[page_idx]['word_sections']
116
+
117
+ for text_section, word_section in zip(text_sections, word_sections):
118
+ whitespaced_text = ' '.join([word[0] for word in word_section])
119
+ words_info = []
120
+ for word in word_section:
121
+ words_info.append({'word_text:':word[0], 'char_indices':word[1], 'word_bbox':word[2]})
122
+
123
+ context_row = {'text':text_section, 'whitespaced_text':whitespaced_text, 'page_idx':int(page_idx), 'words_info':words_info}
124
+ context_id = 'context_{0}'.format(len(all_contexts))
125
+ all_contexts[context_id] = context_row
126
+
127
+ return all_contexts
128
+
129
+ def get_contexts(json_input):
130
+ json_output = 'contexts_{0}'.format(json_input)
131
+ sections_per_page = sentence_extract(json_input)
132
+
133
+ all_contexts = format_output_contexts(sections_per_page)
134
+
135
+ with open(json_output, 'w', encoding='utf8') as json_out:
136
+ json.dump(all_contexts, json_out, ensure_ascii=False, indent=4)
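Editor's note: get_contexts consumes the layout JSON written by dit_runner.get_dit_preds and writes a `contexts_<name>.json` keyed by context_0, context_1, and so on. A brief sketch of chaining the two stages (the file name is an assumption carried over from the example above):

    # Usage sketch (assumption: "2105.03011.json" was produced by get_dit_preds).
    import json

    get_contexts("2105.03011.json")  # writes contexts_2105.03011.json
    with open("contexts_2105.03011.json") as fp:
        contexts = json.load(fp)
    print(len(contexts), "contexts; first one starts on page", contexts["context_0"]["page_idx"])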
LICENSE ADDED
@@ -0,0 +1,207 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright 2018, Lawrence Livermore National Security, LLC
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
202
+
203
+ ---- LLVM Exceptions to the Apache 2.0 License ----
204
+
205
+ As an exception, if, as a result of your compiling your source code, portions of this Software are embedded into an Object form of such source code, you may redistribute such embedded portions in such Object form without complying with the conditions of Sections 4(a), 4(b) and 4(d) of the License.
206
+
207
+ In addition, if you combine or link compiled forms of this Software with software that is licensed under the GPLv2 ("Combined Software") and if a court of competent jurisdiction determines that the patent provision (Section 3), the indemnity provision (Section 9) or other Section of the License conflicts with the conditions of the GPLv2, you may retroactively and prospectively choose to deem waived or otherwise exclude such Section(s) of the License, but only in their entirety and only with respect to the Combined Software.
NOTICE ADDED
@@ -0,0 +1,21 @@
1
+ This work was produced under the auspices of the U.S. Department of
2
+ Energy by Lawrence Livermore National Laboratory under Contract
3
+ DE-AC52-07NA27344.
4
+
5
+ This work was prepared as an account of work sponsored by an agency of
6
+ the United States Government. Neither the United States Government nor
7
+ Lawrence Livermore National Security, LLC, nor any of their employees
8
+ makes any warranty, expressed or implied, or assumes any legal liability
9
+ or responsibility for the accuracy, completeness, or usefulness of any
10
+ information, apparatus, product, or process disclosed, or represents that
11
+ its use would not infringe privately owned rights.
12
+
13
+ Reference herein to any specific commercial product, process, or service
14
+ by trade name, trademark, manufacturer, or otherwise does not necessarily
15
+ constitute or imply its endorsement, recommendation, or favoring by the
16
+ United States Government or Lawrence Livermore National Security, LLC.
17
+
18
+ The views and opinions of authors expressed herein do not necessarily
19
+ state or reflect those of the United States Government or Lawrence
20
+ Livermore National Security, LLC, and shall not be used for advertising
21
+ or product endorsement purposes.
README.md CHANGED
@@ -1,8 +1,8 @@
1
  ---
2
  title: Detect Retrieve Comprehend
3
- emoji: 👀
4
- colorFrom: pink
5
- colorTo: red
6
  sdk: gradio
7
  sdk_version: 3.1.7
8
  app_file: app.py
@@ -10,4 +10,14 @@ pinned: false
10
  license: apache-2.0
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
  ---
2
  title: Detect Retrieve Comprehend
3
+ emoji: 📚
4
+ colorFrom: green
5
+ colorTo: pink
6
  sdk: gradio
7
  sdk_version: 3.1.7
8
  app_file: app.py
 
10
  license: apache-2.0
11
  ---
12
 
13
+ # Release
14
+
15
+ ---
16
+
17
+ **Detect, Retrieve, Comprehend** is distributed under the terms of the Apache 2.0 license with LLVM exception.
18
+
19
+ See [LICENSE]() and [NOTICE]() for details.
20
+
21
+ SPDX-License-Identifier: Apache-2.0-with-LLVM-exception
22
+
23
+ LLNL-CODE-838964
UnifiedQA/demo_QA.py ADDED
@@ -0,0 +1,180 @@
1
+ # Copyright (c) 2022, Lawrence Livermore National Security, LLC.
2
+ # All rights reserved.
3
+ # See the top-level LICENSE and NOTICE files for details.
4
+ # LLNL-CODE-838964
5
+
6
+ # SPDX-License-Identifier: Apache-2.0-with-LLVM-exception
7
+
8
+ import sys
9
+ import json
10
+ from math import ceil
11
+
12
+ import torch
13
+ import numpy as np
14
+ from torch import tensor
15
+ from torch.nn.functional import log_softmax
16
+ from torch.distributions.categorical import Categorical
17
+ from transformers import T5Tokenizer, T5ForConditionalGeneration
18
+
19
+ # load UnifiedQA onto device
20
+ model_name = "allenai/unifiedqa-v2-t5-large-1363200"
21
+ tokenizer = T5Tokenizer.from_pretrained(model_name)
22
+ model = T5ForConditionalGeneration.from_pretrained(model_name)
23
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
24
+ model.to(device)
25
+
26
+ def get_inputs(contexts_json, ranked_contexts_json):
27
+ with open(contexts_json, 'rt') as fp:
28
+ contexts = json.load(fp)
29
+
30
+ with open(ranked_contexts_json, 'rt') as fp:
31
+ ranked_contexts = json.load(fp)
32
+
33
+ question_id = list(ranked_contexts.keys())[0]
34
+ # assert len(questions) == 1, f'JSON should only have 1 question but found {len(questions)}: {questions}'
35
+ question = ranked_contexts[question_id]['text']
36
+ context_ids_sorted = ranked_contexts[question_id]['ranks']
37
+ context_scores = ranked_contexts[question_id]['scores']
38
+ contexts = [contexts[context_id]['text'] for context_id in context_ids_sorted]
39
+
40
+ # returns the question (str) and its contexts (sequence)
41
+ return question, contexts, context_scores
42
+
43
+ def get_tokens(text, tokenizer, max_tokens):
44
+ return tokenizer.encode_plus(text, return_tensors='pt', max_length=max_tokens, padding='max_length', truncation=True)['input_ids']
45
+
46
+ def prepare_inputs(tokenizer, max_tokens, context, question):
47
+ input_str = f'{question} \\n {context}'
48
+ inputs = get_tokens(input_str, tokenizer, max_tokens)
49
+ return inputs
50
+
51
+ def get_outputs(model, tokenizer, input_tokens, max_tokens):
52
+ output_dict = model.generate(input_tokens, output_scores=True, return_dict_in_generate=True, **{'max_length': max_tokens})
53
+ pred_tokens = output_dict['sequences'].squeeze().tolist()
54
+
55
+ # initialize metrics
56
+ logit_entropy = []
57
+ sentence_probs = []
58
+
59
+ # accumulate metrics over logit_sequence
60
+ logit_sequence = output_dict['scores'][:-1] # discard end token
61
+ for logit in logit_sequence:
62
+ log_probs = log_softmax(logit, dim=-1)
63
+
64
+ # update metrics
65
+ logit_entropy.append(Categorical(log_probs.exp()).entropy())
66
+ sentence_probs.append(log_probs.max())
67
+
68
+ # finish metrics calculation
69
+ logit_entropy = tensor(logit_entropy)
70
+ sentence_probs = tensor(sentence_probs)
71
+ entropy = logit_entropy.mean()
72
+ sentence_std = 0 if len(logit_sequence) == 1 else sentence_probs.std(unbiased=True).exp()
73
+
74
+ # use entropy * sentence_std as uncertainty
75
+ uncertainty = (entropy * sentence_std).item()
76
+
77
+ # convert answer tokens to str
78
+ pred_str = tokenizer.decode(pred_tokens, skip_special_tokens=True).lower()
79
+
80
+ return pred_str, uncertainty
81
+
82
+ # k_percent: fraction of contexts to use; the resulting number of contexts is clamped to [min_k, max_k]
83
+ # min_k: minimum number of contexts to use, if possible. Setting this too small reduces recall
84
+ # max_k: maximum number of contexts to use. Setting this too big reduces precision
85
+ # recommended uncertainty thresholds are 2, 3, 4, and 5; the lower the threshold, the more aggressive the filtering
86
+ def run_model(model, tokenizer, device, question, contexts, context_scores, k_percent=0.1, min_k=10, max_k=25, uncertainty_thresh=3):
87
+ k = min(max(ceil(k_percent * len(contexts)), min_k), max_k)
88
+ contexts = contexts[:k]
89
+ context_scores = context_scores[:k]
90
+
91
+ # iterate through top-k contexts
92
+ answers = []
93
+ uncertainty = []
94
+ for context in contexts:
95
+ input_tokens = prepare_inputs(tokenizer, 512, context, question).to(device)
96
+ pred_str, uncertainty_1 = get_outputs(model, tokenizer, input_tokens, 512)
97
+ answers.append(pred_str)
98
+ uncertainty.append(uncertainty_1)
99
+
100
+ # contexts = np.array(contexts)
101
+ # answers = np.array(answers)
102
+ # uncertainty = np.array(uncertainty)
103
+
104
+ # sort by uncertainty, ascending order
105
+ # order = np.argsort(uncertainty)
106
+ # contexts = contexts[order]
107
+ # answers = answers[order]
108
+ # uncertainty = uncertainty[order]
109
+
110
+ # init lists for threshed answers
111
+ # weak_contexts = []
112
+ # weak_answers = []
113
+ # weak_uncertainty = []
114
+
115
+ # filter by uncertainty
116
+ # if len(answers) > min_k:
117
+ # weak = np.argwhere(uncertainty > uncertainty_thresh) # exceeds threshold
118
+ # weak_contexts = contexts[weak].tolist()
119
+ # weak_answers = answers[weak].tolist()
120
+ # weak_uncertainty = uncertainty[weak].tolist()
121
+
122
+ # strong = np.argwhere(uncertainty <= uncertainty_thresh) # within threshold
123
+ # contexts = contexts[strong]
124
+ # answers = answers[strong]
125
+ # uncertainty = uncertainty[strong]
126
+
127
+ # contexts = contexts.tolist()
128
+ # answers = answers.tolist()
129
+ # uncertainty = uncertainty.tolist()
130
+
131
+ # return {'contexts': contexts, 'answers': answers, 'uncertainty': uncertainty}, \
132
+ # {'contexts': weak_contexts, 'answers': weak_answers, 'uncertainty': weak_uncertainty}
133
+
134
+ return {'contexts': contexts, 'answers': answers, 'context_scores':context_scores, 'uncertainty': uncertainty}
135
+
136
+ def get_qa_results(contexts_json, ranked_contexts_json, topk):
137
+
138
+ # extract question and contexts from json
139
+ question, contexts, context_scores = get_inputs(contexts_json, ranked_contexts_json)
140
+
141
+ # infer answers
142
+ with torch.inference_mode(True):
143
+ # strong_answers, weak_answers = run_model(model, tokenizer, device, question, contexts, k_percent=k_percent)
144
+ qa_results = run_model(model, tokenizer, device, question, contexts, context_scores, k_percent=1.0, min_k=1, max_k=topk)
145
+
146
+ return qa_results
147
+
148
+ def get_qa_results_in_memory(contexts, ranked_contexts, topk):
149
+
150
+ question_id = list(ranked_contexts.keys())[0]
151
+ # assert len(questions) == 1, f'JSON should only have 1 question but found {len(questions)}: {questions}'
152
+ question = ranked_contexts[question_id]['text']
153
+ context_ids_sorted = ranked_contexts[question_id]['ranks']
154
+ context_scores = ranked_contexts[question_id]['scores']
155
+ contexts = [contexts[context_id]['text'] for context_id in context_ids_sorted]
156
+
157
+ # infer answers
158
+ with torch.inference_mode(True):
159
+ # strong_answers, weak_answers = run_model(model, tokenizer, device, question, contexts, k_percent=k_percent)
160
+ qa_results = run_model(model, tokenizer, device, question, contexts, context_scores, k_percent=1.0, min_k=1, max_k=topk)
161
+
162
+ return qa_results
163
+
164
+ def load_custom_model(finetuned_model_path):
165
+ global tokenizer
166
+ global model
167
+
168
+ # load UnifiedQA onto device
169
+ tokenizer = T5Tokenizer.from_pretrained(finetuned_model_path)
170
+ model = T5ForConditionalGeneration.from_pretrained(finetuned_model_path)
171
+ model.to(device)
172
+
173
+ def get_qa_results_in_memory_finetuned_unifiedqa(question, context_scores, contexts, topk):
174
+
175
+ # infer answers
176
+ with torch.inference_mode(True):
177
+ # strong_answers, weak_answers = run_model(model, tokenizer, device, question, contexts, k_percent=k_percent)
178
+ qa_results = run_model(model, tokenizer, device, question, contexts, context_scores, k_percent=1.0, min_k=1, max_k=topk)
179
+
180
+ return qa_results
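
The `entropy * sentence_std` heuristic in `get_outputs` is the only answer-scoring logic that is not commented out above, so it helps to see it in isolation. The following is a minimal sketch, not part of the repository: `toy_uncertainty` and the toy logits are illustrative, but the arithmetic mirrors the function above (mean per-token entropy times the exponentiated standard deviation of the per-token max log-probabilities).

```python
# Minimal sketch (illustrative, not from the repo): reproduce the uncertainty score
# used in get_outputs() on a toy logit sequence so the metric can be inspected alone.
import torch
from torch import tensor
from torch.nn.functional import log_softmax
from torch.distributions.categorical import Categorical

def toy_uncertainty(logit_sequence):
    # logit_sequence: list of [1, vocab_size] tensors, one per generated token
    # (the end-of-sequence logit is assumed to be dropped already, as above)
    logit_entropy, sentence_probs = [], []
    for logit in logit_sequence:
        log_probs = log_softmax(logit, dim=-1)
        logit_entropy.append(Categorical(log_probs.exp()).entropy())
        sentence_probs.append(log_probs.max())
    logit_entropy = tensor(logit_entropy)
    sentence_probs = tensor(sentence_probs)
    entropy = logit_entropy.mean()
    sentence_std = 0 if len(logit_sequence) == 1 else sentence_probs.std(unbiased=True).exp()
    return (entropy * sentence_std).item()

peaked = [torch.tensor([[10.0, 0.0, 0.0, 0.0]])] * 3  # confident token distributions
flat = [torch.tensor([[1.0, 1.0, 1.0, 1.0]])] * 3     # maximally unsure distributions
print(toy_uncertainty(peaked))  # ~0.002
print(toy_uncertainty(flat))    # ~1.39, i.e. ln(4) for a 4-token vocabulary
```

A peaked distribution scores near 0 while a uniform one over four tokens scores about ln(4) ≈ 1.39, which is why thresholds in the 2-5 range only reject fairly diffuse generations.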
app.py ADDED
@@ -0,0 +1,120 @@
1
+ # Copyright (c) 2022, Lawrence Livermore National Security, LLC.
2
+ # All rights reserved.
3
+ # See the top-level LICENSE and NOTICE files for details.
4
+ # LLNL-CODE-838964
5
+
6
+ # SPDX-License-Identifier: Apache-2.0-with-LLVM-exception
7
+
8
+ import torch
9
+ import gradio as gr
10
+ from pathlib import Path
11
+
12
+ from torchvision.transforms import ToPILImage, ToTensor
13
+ tensor_to_image = ToPILImage()
14
+ image_to_tensor = ToTensor()
15
+
16
+ import sys
17
+ sys.path.append('DiT_Extractor/')
18
+ sys.path.append('CrossEncoder/')
19
+ sys.path.append('UnifiedQA/')
20
+
21
+ import dit_runner
22
+ import sentence_extractor
23
+ import cross_encoder
24
+ import demo_QA
25
+
26
+ from torchvision.transforms import ToPILImage
27
+ tensor_to_image = ToPILImage()
28
+
29
+ def run_fn(pdf_file_obj, question_text, input_topk):
30
+
31
+ pdf = pdf_file_obj.name
32
+ viz_images = dit_runner.get_dit_preds(pdf, score_threshold=0.5)
33
+ entity_json = '{0}.json'.format(Path(pdf).name[:-4])
34
+
35
+ sentence_extractor.get_contexts(entity_json)
36
+
37
+ contexts_json = 'contexts_{0}'.format(entity_json)
38
+ # contexts_json = 'contexts_2105u2iwiwxh.03011.json'
39
+
40
+ cross_encoder.get_ranked_contexts(contexts_json, question_text)
41
+
42
+ ranked_contexts_json = 'ranked_{0}'.format(contexts_json)
43
+ # ranked_contexts_json = 'ranked_contexts_2105u2iwiwxh.03011.json'
44
+
45
+ input_topk = int(input_topk)
46
+
47
+ # viz_images = [tensor_to_image(x) for x in torch.randn(4, 3, 256, 256)]
48
+
49
+ qa_results = demo_QA.get_qa_results(contexts_json, ranked_contexts_json, input_topk)
50
+
51
+ history = [('<<< [Retrieval Score: {0:.02f}] >>> {1}'.format(s, c), a) for c, s, a in zip(qa_results['contexts'], qa_results['context_scores'], qa_results['answers'])]
52
+
53
+ # Show in ascending order of score, since results box is already scrolled down.
54
+ history = history[::-1]
55
+
56
+ return viz_images, contexts_json, ranked_contexts_json, history
57
+
58
+ demo = gr.Blocks()
59
+
60
+ with demo:
61
+
62
+ gr.Markdown("<h1><center>Document-based Question Answering</center></h1>")
63
+ gr.Markdown("<center>This is a supplemental demo for our publication, [Document-based Question Answering](https://www.google.com). In this system, our input is a PDF file with a specific question of interest. The output is a set of most probable answers. There are 4 main components in our deployed pipeline: (1) DiT Layout Analysis (2) Context Extraction (3) Cross-Encoder Retrieval (4) UnifiedQA. See below for example uses with further explanation.</center>")
64
+
65
+ with gr.Row():
66
+ with gr.Column():
67
+ with gr.Row():
68
+ input_pdf_file = gr.File(file_count='single', label='PDF File')
69
+ with gr.Row():
70
+ input_question_text = gr.Textbox(label='Question')
71
+ with gr.Row():
72
+ input_k_percent = gr.Slider(minimum=1, maximum=24, step=1, value=8, label='Top K')
73
+ with gr.Row():
74
+ button_run = gr.Button('Run QA on Document')
75
+
76
+ gr.Markdown("<h3><center>Summary</center></h3>")
77
+ with gr.Row():
78
+ gr.Markdown('''
79
+ - <u>**DiT - Document Image Transformer**</u>: PDF -> converted into a list of images -> each image receives Entity Predictions
80
+ - Note that using this computer vision approach allows us to ignore elements such as *page numbers, footnotes, references*, etc.
81
+ - <u>**Paragraph-based Text Extraction**</u>: DiT Bounding Boxes -> Convert into PDF-Space Coordinates -> Text Extraction using pdfminer.six -> Tokenize & Sentence Split if tokenizer max length is exceeded
82
+ - <u>**CrossEncoder Context Retrieval**</u>: All Contexts + Question -> Top K Relevant Contexts best suited for answering question
83
+ - <u>**UnifiedQA**</u>: Most Relevant Contexts + Supplied Question -> Predict Set of Probable Answers
84
+ ''')
85
+
86
+ with gr.Column():
87
+ with gr.Row():
88
+ output_gallery = gr.Gallery(label='DiT Predicted Entities')
89
+ with gr.Row():
90
+ gr.Markdown('''
91
+ - The `DiT predicted Entities` output box is scrollable! Scroll to see different page predictions. Note that predictions with confidence scores < 0.5 are not passed forward for text extraction.
92
+ - If an image is clicked, the output box will switch to a gallery view. To view these outputs in much higher resolution, right-click and choose "open image in new tab"
93
+ ''')
94
+ with gr.Row():
95
+ output_contexts = gr.File(label='Detected Contexts', interactive=False)
96
+ output_ranked_contexts = gr.File(label='CrossEncoder Ranked Contexts', interactive=False)
97
+ with gr.Row():
98
+ output_qa_results = gr.Chatbot(color_map=['blue', 'green'], label='UnifiedQA Results').style()
99
+
100
+ gr.Markdown("<h3><center>Related Work & Code</center></h3>")
101
+ gr.Markdown("<center>DiT (Document Image Transformer) - <a href=https://arxiv.org/abs/2203.02378>Arxiv Page</a> | <a href=https://github.com/microsoft/unilm/tree/master/dit>Github Repo</a></center>")
102
+ gr.Markdown("<center>CrossEncoder - <a href=https://arxiv.org/abs/2203.02378>Arxiv Page</a> | <a href=https://github.com/microsoft/unilm/tree/master/dit>Github Repo</a></center>")
103
+ gr.Markdown("<center>UnifiedQA - <a href=https://arxiv.org/abs/2005.00700>Arxiv Page</a> | <a href=https://github.com/allenai/unifiedqa>Github Repo</a></center>")
104
+
105
+ button_run.click(fn=run_fn, inputs=[input_pdf_file, input_question_text, input_k_percent], outputs=[output_gallery, output_contexts, output_ranked_contexts, output_qa_results])
106
+
107
+ examples = [
108
+ ['examples/1909.00694.pdf', 'What is the seed lexicon?', 5],
109
+ ['examples/1909.00694.pdf', 'How big is seed lexicon used for training?', 5],
110
+ ['examples/1810.04805.pdf', 'What is this paper about?', 5],
111
+ ['examples/1810.04805.pdf', 'What is the model size?', 5],
112
+ ['examples/2105.03011.pdf', 'How many questions are in this dataset?', 5],
113
+ ['examples/1909.00694.pdf', 'How are relations used to propagate polarity?', 5],
114
+
115
+ ]
116
+ gr.Examples(examples=examples,
117
+ inputs=[input_pdf_file, input_question_text, input_k_percent])
118
+
119
+ # examples = gr.Dataset(components=[input_pdf_file, input_question_text], samples=[[open('examples/1810.04805.pdf', mode='rb'), 'How many parameters are in the model?']])
120
+ demo.launch(enable_queue=True)
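
For readers who want to exercise the pipeline without the Gradio UI, the sketch below mirrors `run_fn` step by step. It is a hypothetical driver, not part of the repo: the example PDF, question, and top-k value are taken from the `examples` list above, and it assumes the same module-path and working-directory conventions as `app.py` (intermediate JSON files are written alongside the script).

```python
# Hypothetical headless driver mirroring run_fn() above (illustrative sketch only).
import sys
from pathlib import Path

sys.path.append('DiT_Extractor/')
sys.path.append('CrossEncoder/')
sys.path.append('UnifiedQA/')

import dit_runner
import sentence_extractor
import cross_encoder
import demo_QA

pdf = 'examples/1909.00694.pdf'
question = 'What is the seed lexicon?'
topk = 5

# Same four stages run_fn wires into the UI:
# (1) DiT layout analysis -> entity JSON, (2) paragraph/context extraction,
# (3) cross-encoder ranking, (4) UnifiedQA answering.
dit_runner.get_dit_preds(pdf, score_threshold=0.5)
entity_json = '{0}.json'.format(Path(pdf).name[:-4])
sentence_extractor.get_contexts(entity_json)
contexts_json = 'contexts_{0}'.format(entity_json)
cross_encoder.get_ranked_contexts(contexts_json, question)
ranked_contexts_json = 'ranked_{0}'.format(contexts_json)

qa_results = demo_QA.get_qa_results(contexts_json, ranked_contexts_json, topk)
for context, score, answer in zip(qa_results['contexts'],
                                  qa_results['context_scores'],
                                  qa_results['answers']):
    print('[{0:.02f}] {1}... -> {2}'.format(score, context[:80], answer))
```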
env_setup.sh ADDED
@@ -0,0 +1,32 @@
1
+ conda create --name llnl_actici_env python=3.9
2
+ conda activate llnl_actici_env
3
+
4
+ conda install pytorch=1.10 torchvision torchaudio cudatoolkit=11.3 -c pytorch
5
+
6
+ # For DiT
7
+ python -m pip install detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu113/torch1.10/index.html
8
+
9
+
10
+ # For DiT
11
+ pip install opencv-python
12
+ pip install timm
13
+ pip install pdfminer.six
14
+ conda install -c conda-forge poppler
15
+ pip install pdf2image
16
+ pip install pypdf2
17
+ pip install spacy
18
+ # pytesseract, in case we need in future
19
+ pip install pytesseract
20
+
21
+ # For Retrieval & QA
22
+ pip install transformers==4.20
23
+ pip install sentence-transformers
24
+
25
+ # For Demo
26
+ pip install gradio
27
+
28
+ # If Jupyter is allowed
29
+ pip install jupyter
30
+
31
+ # (Optional, adding this custom env to the base environment's jupyter)
32
+ python -m ipykernel install --user --name llnl_actici_env --display-name "Python (llnl_actici_env)"
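
After running the setup script (or installing `requirements.txt`), a quick import check can catch a missing wheel before launching the demo. This snippet is an optional sketch, not part of the repo; the module list simply reflects the packages installed above.

```python
# Optional environment sanity check (illustrative, not part of the repo):
# confirm the main dependencies import cleanly and report whether CUDA is visible.
import importlib

import torch

for name in ['detectron2', 'timm', 'cv2', 'pdfminer', 'pdf2image',
             'transformers', 'sentence_transformers', 'spacy', 'gradio']:
    try:
        module = importlib.import_module(name)
        print('{0:<22} OK ({1})'.format(name, getattr(module, '__version__', 'version n/a')))
    except ImportError as err:
        print('{0:<22} MISSING: {1}'.format(name, err))

print('torch', torch.__version__, '| CUDA available:', torch.cuda.is_available())
```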
examples/1810.04805.pdf ADDED
Binary file (775 kB). View file
 
examples/1909.00694.pdf ADDED
Binary file (540 kB). View file
 
examples/2105.03011.pdf ADDED
Binary file (507 kB). View file
 
ms-marco-electra-base/CEBinaryClassificationEvaluator_MS-Marco_results.csv ADDED
@@ -0,0 +1,43 @@
1
+ epoch,steps,Accuracy,Accuracy_Threshold,F1,F1_Threshold,Precision,Recall,Average_Precision
2
+ 0,5000,0.9297070292970703,0.25256121158599854,0.8307839388145314,0.19771124422550201,0.7957875457875457,0.869,0.8904110467492587
3
+ 0,10000,0.939006099390061,0.5306986570358276,0.8460807600950118,0.28808051347732544,0.8058823529411765,0.8905,0.910544278892506
4
+ 0,15000,0.9393060693930607,0.5750397443771362,0.8566081871345029,0.48249387741088867,0.8048351648351648,0.9155,0.9132147986720082
5
+ 0,20000,0.9405059494050595,0.591253936290741,0.8546298558514537,0.570050835609436,0.8356426182513139,0.8745,0.9073685536522613
6
+ 0,25000,0.9436056394360564,0.5074090957641602,0.8603960396039605,0.5057582855224609,0.8519607843137255,0.869,0.9167379821993755
7
+ 0,30000,0.9396060393960604,0.8262588381767273,0.8542471042471043,0.7406325340270996,0.8255597014925373,0.885,0.8979176130668384
8
+ 0,35000,0.9425057494250575,0.46686679124832153,0.8596070915189268,0.28302955627441406,0.8252069917203312,0.897,0.9163289965092976
9
+ 0,40000,0.9417058294170583,0.6763133406639099,0.8575602629656682,0.6603987216949463,0.8357854769814903,0.8805,0.9173776247925393
10
+ 0,45000,0.9426057394260574,0.4643915295600891,0.8605042016806723,0.29147765040397644,0.8277136258660508,0.896,0.9120726077810245
11
+ 0,50000,0.945005499450055,0.5493776798248291,0.8624535315985131,0.4713650643825531,0.855036855036855,0.87,0.9209400105864155
12
+ 0,55000,0.9454054594540546,0.6156725287437439,0.864585893339887,0.5604670643806458,0.8501691638472693,0.8795,0.9206262233464874
13
+ 0,60000,0.9421057894210579,0.39554399251937866,0.8605827112930412,0.3811936378479004,0.8300046446818393,0.8935,0.9193948306076224
14
+ 0,65000,0.9428057194280572,0.5363738536834717,0.8629682313892841,0.32784485816955566,0.8205590622182146,0.91,0.9227492855045069
15
+ 0,70000,0.9438056194380562,0.38333064317703247,0.8628501827040195,0.3524332344532013,0.8413301662707838,0.8855,0.9236299441431376
16
+ 0,75000,0.9468053194680532,0.48936331272125244,0.8696717295443409,0.48936331272125244,0.8525456292026897,0.8875,0.9254413650794524
17
+ 0,80000,0.9454054594540546,0.3127445578575134,0.8651851851851852,0.3127445578575134,0.8546341463414634,0.876,0.9213706944185774
18
+ 0,85000,0.9443055694430557,0.31547677516937256,0.8655280250180418,0.21403872966766357,0.8340287436254057,0.8995,0.9237103419372517
19
+ 0,90000,0.9465053494650535,0.3857932686805725,0.8702401164200824,0.3761560022830963,0.8450306170513424,0.897,0.9258501989030058
20
+ 0,95000,0.9453054694530547,0.3604514002799988,0.8669713735867213,0.29048818349838257,0.8354195642095503,0.901,0.9226658871253511
21
+ 0,100000,0.9453054694530547,0.6748594045639038,0.8686288585786074,0.4552273154258728,0.8329508949059201,0.9075,0.9252677323330876
22
+ 0,105000,0.9435056494350565,0.40062007308006287,0.8639551192145862,0.1210024282336235,0.8112379280070237,0.924,0.9237990563267019
23
+ 0,110000,0.944905509449055,0.4197750985622406,0.8656429942418427,0.27975988388061523,0.8321033210332104,0.902,0.9247201058651281
24
+ 0,115000,0.9464053594640536,0.4172205924987793,0.8698167791706846,0.2961992919445038,0.839851024208566,0.902,0.927117403879296
25
+ 0,120000,0.9474052594740526,0.44686269760131836,0.8712047012732614,0.4383932948112488,0.8536468330134357,0.8895,0.9279628711835812
26
+ 0,125000,0.945005499450055,0.4358792304992676,0.8655339805825243,0.28539055585861206,0.8410377358490566,0.8915,0.9268525722856882
27
+ 0,130000,0.9462053794620537,0.21194982528686523,0.8703747911195989,0.16292141377925873,0.8328003654636821,0.9115,0.925512309638313
28
+ 0,135000,0.9454054594540546,0.2292814701795578,0.8678621991505427,0.11477036774158478,0.82171581769437,0.9195,0.9268551457216524
29
+ 0,140000,0.9482051794820517,0.31556186079978943,0.8758076094759513,0.26744428277015686,0.8398347865993575,0.915,0.9275073681003255
30
+ 0,145000,0.9478052194780522,0.3485147953033447,0.8719556305763203,0.12995882332324982,0.8421052631578947,0.904,0.9278250006342896
31
+ 0,150000,0.9483051694830517,0.32228657603263855,0.8726037369570493,0.21710461378097534,0.8477133427628477,0.899,0.9259328370035781
32
+ 0,155000,0.9474052594740526,0.1903868019580841,0.8731307284129282,0.18298938870429993,0.8434296365330848,0.905,0.9261096325445609
33
+ 0,160000,0.9473052694730527,0.5740681886672974,0.872194660996929,0.17134147882461548,0.8266905508284819,0.923,0.927973529121574
34
+ 0,165000,0.9495050494950505,0.38968273997306824,0.87591956841589,0.34622055292129517,0.8594802694898941,0.893,0.9241440163389828
35
+ 0,170000,0.9459054094590541,0.47478723526000977,0.8706669854171647,0.11328981816768646,0.8341731562070546,0.9105,0.9289979858500923
36
+ 0,175000,0.9473052694730527,0.5903739929199219,0.8703747911195989,0.15506823360919952,0.8328003654636821,0.9115,0.9305074303915251
37
+ 0,180000,0.9463053694630537,0.23235449194908142,0.8702585165498912,0.23235449194908142,0.841982234689107,0.9005,0.9291547676197442
38
+ 0,185000,0.9478052194780522,0.174373060464859,0.8734852157052836,0.171615868806839,0.8476011288805269,0.901,0.9280170204346545
39
+ 0,190000,0.949005099490051,0.5715193748474121,0.8747241971071341,0.5108739137649536,0.8581048581048581,0.892,0.9271410745170057
40
+ 0,195000,0.9461053894610539,0.5194154977798462,0.8679334916864608,0.170893132686615,0.8266968325791855,0.9135,0.9271023702066649
41
+ 0,200000,0.9468053194680532,0.3094758987426758,0.8707931277947754,0.11578939855098724,0.82258781680747,0.925,0.9290083868621436
42
+ 0,205000,0.9461053894610539,0.6028298139572144,0.8679067577113257,0.13052904605865479,0.8202047174009791,0.9215,0.9276186176796931
43
+ 0,210000,0.9459054094590541,0.49049288034439087,0.8694616484040019,0.16249723732471466,0.8303002729754322,0.9125,0.9285170114050436
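
The CSV above is the `CEBinaryClassificationEvaluator` log written every 5,000 training steps on the MS MARCO dev data. A short sketch for summarizing it and picking the strongest checkpoint by `Average_Precision` (it assumes `pandas` is available, which is not listed in `requirements.txt`):

```python
# Illustrative only: summarize the evaluator log above and report the training step
# with the best Average_Precision. Assumes pandas is installed.
import pandas as pd

df = pd.read_csv('ms-marco-electra-base/CEBinaryClassificationEvaluator_MS-Marco_results.csv')
best = df.loc[df['Average_Precision'].idxmax()]
print('best step: {0:.0f} | AP={1:.4f} | F1={2:.4f} | Accuracy={3:.4f}'.format(
    best['steps'], best['Average_Precision'], best['F1'], best['Accuracy']))
```

In this particular log, `Average_Precision` peaks at step 175,000 (≈0.9305).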
ms-marco-electra-base/README.md ADDED
@@ -0,0 +1,64 @@
1
+ ---
2
+ license: apache-2.0
3
+ ---
4
+ # Cross-Encoder for MS Marco
5
+
6
+ This model was trained on the [MS Marco Passage Ranking](https://github.com/microsoft/MSMARCO-Passage-Ranking) task.
7
+
8
+ The model can be used for Information Retrieval: given a query, encode the query with all possible passages (e.g. retrieved with ElasticSearch), then sort the passages in decreasing order of score. See [SBERT.net Retrieve & Re-rank](https://www.sbert.net/examples/applications/retrieve_rerank/README.html) for more details. The training code is available here: [SBERT.net Training MS Marco](https://github.com/UKPLab/sentence-transformers/tree/master/examples/training/ms_marco)
9
+
10
+
11
+ ## Usage with Transformers
12
+
13
+ ```python
14
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
15
+ import torch
16
+
17
+ model = AutoModelForSequenceClassification.from_pretrained('model_name')
18
+ tokenizer = AutoTokenizer.from_pretrained('model_name')
19
+
20
+ features = tokenizer(['How many people live in Berlin?', 'How many people live in Berlin?'], ['Berlin has a population of 3,520,031 registered inhabitants in an area of 891.82 square kilometers.', 'New York City is famous for the Metropolitan Museum of Art.'], padding=True, truncation=True, return_tensors="pt")
21
+
22
+ model.eval()
23
+ with torch.no_grad():
24
+ scores = model(**features).logits
25
+ print(scores)
26
+ ```
27
+
28
+
29
+ ## Usage with SentenceTransformers
30
+
31
+ Usage is easier when you have [SentenceTransformers](https://www.sbert.net/) installed. Then you can use the pre-trained models like this:
32
+ ```python
33
+ from sentence_transformers import CrossEncoder
34
+ model = CrossEncoder('model_name', max_length=512)
35
+ scores = model.predict([('Query', 'Paragraph1'), ('Query', 'Paragraph2') , ('Query', 'Paragraph3')])
36
+ ```
37
+
38
+
39
+ ## Performance
40
+ In the following table, we provide various pre-trained Cross-Encoders together with their performance on the [TREC Deep Learning 2019](https://microsoft.github.io/TREC-2019-Deep-Learning/) and the [MS Marco Passage Reranking](https://github.com/microsoft/MSMARCO-Passage-Ranking/) dataset.
41
+
42
+
43
+ | Model-Name | NDCG@10 (TREC DL 19) | MRR@10 (MS Marco Dev) | Docs / Sec |
44
+ | ------------- |:-------------| -----| --- |
45
+ | **Version 2 models** | | |
46
+ | cross-encoder/ms-marco-TinyBERT-L-2-v2 | 69.84 | 32.56 | 9000
47
+ | cross-encoder/ms-marco-MiniLM-L-2-v2 | 71.01 | 34.85 | 4100
48
+ | cross-encoder/ms-marco-MiniLM-L-4-v2 | 73.04 | 37.70 | 2500
49
+ | cross-encoder/ms-marco-MiniLM-L-6-v2 | 74.30 | 39.01 | 1800
50
+ | cross-encoder/ms-marco-MiniLM-L-12-v2 | 74.31 | 39.02 | 960
51
+ | **Version 1 models** | | |
52
+ | cross-encoder/ms-marco-TinyBERT-L-2 | 67.43 | 30.15 | 9000
53
+ | cross-encoder/ms-marco-TinyBERT-L-4 | 68.09 | 34.50 | 2900
54
+ | cross-encoder/ms-marco-TinyBERT-L-6 | 69.57 | 36.13 | 680
55
+ | cross-encoder/ms-marco-electra-base | 71.99 | 36.41 | 340
56
+ | **Other models** | | |
57
+ | nboost/pt-tinybert-msmarco | 63.63 | 28.80 | 2900
58
+ | nboost/pt-bert-base-uncased-msmarco | 70.94 | 34.75 | 340
59
+ | nboost/pt-bert-large-msmarco | 73.36 | 36.48 | 100
60
+ | Capreolus/electra-base-msmarco | 71.23 | 36.89 | 340
61
+ | amberoad/bert-multilingual-passage-reranking-msmarco | 68.40 | 35.54 | 330
62
+ | sebastian-hofstaetter/distilbert-cat-margin_mse-T2-msmarco | 72.82 | 37.88 | 720
63
+
64
+ Note: Runtime was computed on a V100 GPU.
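
In this Space, the `'model_name'` placeholder above corresponds to the bundled `ms-marco-electra-base/` directory. A brief sketch (the question and passages are illustrative) of loading the model from that local folder and sorting candidate contexts by score, much as the CrossEncoder retrieval step in this pipeline does:

```python
# Sketch (not from the model card): point the SentenceTransformers CrossEncoder at the
# local ms-marco-electra-base/ folder and rank candidate contexts for a question.
from sentence_transformers import CrossEncoder

model = CrossEncoder('ms-marco-electra-base', max_length=512)

question = 'What is the seed lexicon?'            # illustrative question
contexts = ['Paragraph describing the seed lexicon.',   # illustrative passages
            'Paragraph about unrelated experiments.']

scores = model.predict([(question, c) for c in contexts])
ranked = sorted(zip(contexts, scores), key=lambda pair: pair[1], reverse=True)
for context, score in ranked:
    print('{0:.3f}  {1}'.format(score, context))
```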
ms-marco-electra-base/config.json ADDED
@@ -0,0 +1,31 @@
1
+ {
2
+ "_name_or_path": "google/electra-base-discriminator",
3
+ "architectures": [
4
+ "ElectraForSequenceClassification"
5
+ ],
6
+ "attention_probs_dropout_prob": 0.1,
7
+ "embedding_size": 768,
8
+ "hidden_act": "gelu",
9
+ "hidden_dropout_prob": 0.1,
10
+ "hidden_size": 768,
11
+ "id2label": {
12
+ "0": "LABEL_0"
13
+ },
14
+ "initializer_range": 0.02,
15
+ "intermediate_size": 3072,
16
+ "label2id": {
17
+ "LABEL_0": 0
18
+ },
19
+ "layer_norm_eps": 1e-12,
20
+ "max_position_embeddings": 512,
21
+ "model_type": "electra",
22
+ "num_attention_heads": 12,
23
+ "num_hidden_layers": 12,
24
+ "pad_token_id": 0,
25
+ "summary_activation": "gelu",
26
+ "summary_last_dropout": 0.1,
27
+ "summary_type": "first",
28
+ "summary_use_proj": true,
29
+ "type_vocab_size": 2,
30
+ "vocab_size": 30522
31
+ }
ms-marco-electra-base/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c554473d61458bf2969566b1bb464eb280ef7de9cacb6ec787b4fe7f0a9a80d9
3
+ size 438022601
ms-marco-electra-base/special_tokens_map.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}
ms-marco-electra-base/tokenizer_config.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"do_lower_case": true, "do_basic_tokenize": true, "never_split": null, "unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]", "tokenize_chinese_chars": true, "strip_accents": null, "model_max_length": 512, "name_or_path": "google/electra-base-discriminator"}
ms-marco-electra-base/vocab.txt ADDED
The diff for this file is too large to render. See raw diff
 
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ poppler-utils
requirements.txt ADDED
@@ -0,0 +1,13 @@
1
+ torch==1.10.0
2
+ torchvision
3
+ opencv-python
4
+ timm
5
+ pdfminer.six
6
+ pdf2image
7
+ pypdf2
8
+ spacy
9
+ pytesseract
10
+ transformers==4.20
11
+ sentence-transformers
12
+ https://dl.fbaipublicfiles.com/detectron2/wheels/cpu/torch1.10/detectron2-0.6%2Bcpu-cp38-cp38-linux_x86_64.whl
13
+ gradio