rizgiak committed on
Commit
737c08b
•
1 Parent(s): 81861fc

initial commit

AUTHORS.rst ADDED
@@ -0,0 +1,13 @@
1
+ =======
2
+ Credits
3
+ =======
4
+
5
+ Development Lead
6
+ ----------------
7
+
8
+ * Full name of the author <Email of the author>
9
+
10
+ Contributors
11
+ ------------
12
+
13
+ None yet. Why not be the first?
README.md CHANGED
@@ -1,13 +1,23 @@
1
  ---
2
- title: Table Caption Extraction
3
- emoji: 🌖
4
- colorFrom: indigo
5
- colorTo: pink
6
  sdk: streamlit
7
- sdk_version: 1.31.1
8
  app_file: app.py
9
  pinned: false
10
- license: mit
11
  ---
 
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
  ---
2
+ title: Table Extraction (Table Transformer + PaddleOCR)
3
+ emoji: 🚀
4
+ colorFrom: blue
5
+ colorTo: yellow
6
  sdk: streamlit
7
+ sdk_version: 1.21.0
8
  app_file: app.py
9
  pinned: false
 
10
  ---
11
+ # huggingface-space
12
 
13
+ Imported from https://huggingface.co/spaces/jurgendn/table-extraction with some adjustments.
14
+
15
+ Current pipeline:
16
+
17
+ Table detection: https://huggingface.co/microsoft/table-transformer-detection
18
+
19
+ Table recognition: https://huggingface.co/microsoft/table-transformer-structure-recognition
20
+
21
+ OCR (original): https://github.com/pbcquoc/vietocr
22
+
23
+ OCR (current): https://github.com/PaddlePaddle/PaddleOCR
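
Below is a minimal sketch of how these components fit together, mirroring the detection and OCR calls made in `app.py` in this commit. The image path `page.png` and the 0.8 detection threshold are placeholder values, not part of the app.

```python
# Hedged sketch: table detection with Table Transformer, then PaddleOCR on each crop.
import numpy as np
import torch
from PIL import Image
from paddleocr import PaddleOCR
from transformers import DetrImageProcessor, TableTransformerForObjectDetection

detector = TableTransformerForObjectDetection.from_pretrained(
    "microsoft/table-transformer-detection")
processor = DetrImageProcessor(do_resize=True, size=800, max_size=800)
ocr = PaddleOCR(use_angle_cls=True, lang="en", use_gpu=False)

image = Image.open("page.png").convert("RGB")  # placeholder input image
encoding = processor(image, return_tensors="pt")
with torch.no_grad():
    outputs = detector(**encoding)

# Keep detections above a probability threshold and rescale boxes to image size
probas = outputs.logits.softmax(-1)[0, :, :-1]
keep = probas.max(-1).values > 0.8  # placeholder threshold
target_sizes = torch.tensor(image.size[::-1]).unsqueeze(0)
boxes = processor.post_process(outputs, target_sizes)[0]["boxes"][keep]

# Run OCR over each detected table region
for xmin, ymin, xmax, ymax in boxes.tolist():
    crop = image.crop((xmin, ymin, xmax, ymax))
    result = ocr.ocr(np.asarray(crop), cls=True)[0]
    if result is not None:
        print(" ".join(line[1][0] for line in result))
```

The app itself adds table structure recognition (microsoft/table-transformer-structure-recognition) and caption extraction on top of this flow.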
app.py ADDED
@@ -0,0 +1,740 @@
1
+ import asyncio
2
+ import string
3
+ import random
4
+ from collections import Counter
5
+ from itertools import count, tee
6
+ import base64
7
+
8
+ import cv2
9
+ import matplotlib.pyplot as plt
10
+ import numpy as np
11
+ import pandas as pd
12
+ import streamlit as st
13
+ import torch
14
+ from PIL import Image
15
+ from transformers import DetrImageProcessor, TableTransformerForObjectDetection
16
+ from paddleocr import PaddleOCR
17
+
18
+ ocr = PaddleOCR(use_angle_cls=True, lang="en", use_gpu=False, ocr_version='PP-OCRv3')
19
+
20
+ st.set_option('deprecation.showPyplotGlobalUse', False)
21
+ st.set_page_config(layout='wide')
22
+ st.title("Table Detection and Table Structure Recognition")
23
+ st.write(
24
+ "Implemented by MSFT team: https://github.com/microsoft/table-transformer")
25
+
26
+ table_detection_model = TableTransformerForObjectDetection.from_pretrained(
27
+ "microsoft/table-transformer-detection")
28
+
29
+ table_recognition_model = TableTransformerForObjectDetection.from_pretrained(
30
+ "microsoft/table-transformer-structure-recognition")
31
+
32
+ def reload_ocr(vlang):
33
+ global ocr
34
+ ocr = PaddleOCR(use_angle_cls=True, lang=vlang, use_gpu=False, ocr_version='PP-OCRv4')
35
+
36
+
37
+ def PIL_to_cv(pil_img):
38
+ return cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
39
+
40
+
41
+ def cv_to_PIL(cv_img):
42
+ return Image.fromarray(cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB))
43
+
44
+
45
+ async def pytess(cell_pil_img, threshold: float = 0.5):
46
+ cell_pil_img = TableExtractionPipeline.add_padding(pil_img=cell_pil_img, top=20, right=10, bottom=20, left=10, color=(255, 255, 255))
47
+ result = ocr.ocr(np.asarray(cell_pil_img), cls=True)[0]
48
+
49
+ #Debug
50
+ # filename = str(random.random())
51
+ # cell_pil_img.save("dump/" + filename + ".png")
52
+ # print(filename)
53
+ # print(result)
54
+
55
+ text = ""
56
+ if result is not None:
57
+ txts = [line[1][0] for line in result]
58
+ text = " ".join(txts)
59
+ return text
60
+
61
+
62
+ def sharpen_image(pil_img):
63
+
64
+ img = PIL_to_cv(pil_img)
65
+ sharpen_kernel = np.array([[-1, -1, -1], [-1, 9, -1], [-1, -1, -1]])
66
+
67
+ sharpen = cv2.filter2D(img, -1, sharpen_kernel)
68
+ pil_img = cv_to_PIL(sharpen)
69
+ return pil_img
70
+
71
+
72
+ def uniquify(seq, suffs=count(1)):
73
+ """Make all the items unique by adding a suffix (1, 2, etc).
74
+ Credit: https://stackoverflow.com/questions/30650474/python-rename-duplicates-in-list-with-progressive-numbers-without-sorting-list
75
+ `seq` is mutable sequence of strings.
76
+ `suffs` is an optional alternative suffix iterable.
77
+ """
78
+ not_unique = [k for k, v in Counter(seq).items() if v > 1]
79
+
80
+ suff_gens = dict(zip(not_unique, tee(suffs, len(not_unique))))
81
+ for idx, s in enumerate(seq):
82
+ try:
83
+ suffix = str(next(suff_gens[s]))
84
+ except KeyError:
85
+ continue
86
+ else:
87
+ seq[idx] += suffix
88
+
89
+ return seq
90
+
91
+
92
+ def binarizeBlur_image(pil_img):
93
+ image = PIL_to_cv(pil_img)
94
+ thresh = cv2.threshold(image, 150, 255, cv2.THRESH_BINARY_INV)[1]
95
+
96
+ result = cv2.GaussianBlur(thresh, (5, 5), 0)
97
+ result = 255 - result
98
+ return cv_to_PIL(result)
99
+
100
+
101
+ def td_postprocess(pil_img):
102
+ '''
103
+ Removes gray background from tables
104
+ '''
105
+ img = PIL_to_cv(pil_img)
106
+
107
+ hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
108
+ mask = cv2.inRange(hsv, (0, 0, 100),
109
+ (255, 5, 255)) # (0, 0, 100), (255, 5, 255)
110
+ nzmask = cv2.inRange(hsv, (0, 0, 5),
111
+ (255, 255, 255)) # (0, 0, 5), (255, 255, 255))
112
+ nzmask = cv2.erode(nzmask, np.ones((3, 3))) # (3,3)
113
+ mask = mask & nzmask
114
+
115
+ new_img = img.copy()
116
+ new_img[np.where(mask)] = 255
117
+
118
+ return cv_to_PIL(new_img)
119
+
120
+
121
+ # def super_res(pil_img):
122
+ # # requires opencv-contrib-python installed without the opencv-python
123
+ # sr = dnn_superres.DnnSuperResImpl_create()
124
+ # image = PIL_to_cv(pil_img)
125
+ # model_path = "./LapSRN_x8.pb"
126
+ # model_name = model_path.split('/')[1].split('_')[0].lower()
127
+ # model_scale = int(model_path.split('/')[1].split('_')[1].split('.')[0][1])
128
+
129
+ # sr.readModel(model_path)
130
+ # sr.setModel(model_name, model_scale)
131
+ # final_img = sr.upsample(image)
132
+ # final_img = cv_to_PIL(final_img)
133
+
134
+ # return final_img
135
+
136
+
137
+ def table_detector(image, THRESHOLD_PROBA):
138
+ '''
139
+ Table detection using the DEtection TRansformer (DETR), pre-trained on 1 million tables
140
+
141
+ '''
142
+
143
+ feature_extractor = DetrImageProcessor(do_resize=True,
144
+ size=800,
145
+ max_size=800)
146
+ encoding = feature_extractor(image, return_tensors="pt")
147
+
148
+ with torch.no_grad():
149
+ outputs = table_detection_model(**encoding)
150
+
151
+ probas = outputs.logits.softmax(-1)[0, :, :-1]
152
+ keep = probas.max(-1).values > THRESHOLD_PROBA
153
+
154
+ target_sizes = torch.tensor(image.size[::-1]).unsqueeze(0)
155
+ postprocessed_outputs = feature_extractor.post_process(
156
+ outputs, target_sizes)
157
+ bboxes_scaled = postprocessed_outputs[0]['boxes'][keep]
158
+
159
+ return (probas[keep], bboxes_scaled)
160
+
161
+
162
+ def table_struct_recog(image, THRESHOLD_PROBA):
163
+ '''
164
+ Table structure recognition using the DEtection TRansformer (DETR), pre-trained on 1 million tables
165
+ '''
166
+
167
+ feature_extractor = DetrImageProcessor(do_resize=True,
168
+ size=1000,
169
+ max_size=1000)
170
+ encoding = feature_extractor(image, return_tensors="pt")
171
+
172
+ with torch.no_grad():
173
+ outputs = table_recognition_model(**encoding)
174
+
175
+ probas = outputs.logits.softmax(-1)[0, :, :-1]
176
+ keep = probas.max(-1).values > THRESHOLD_PROBA
177
+
178
+ target_sizes = torch.tensor(image.size[::-1]).unsqueeze(0)
179
+ postprocessed_outputs = feature_extractor.post_process(
180
+ outputs, target_sizes)
181
+ bboxes_scaled = postprocessed_outputs[0]['boxes'][keep]
182
+
183
+ return (probas[keep], bboxes_scaled)
184
+
185
+
186
+ class TableExtractionPipeline():
187
+
188
+ colors = ["red", "blue", "green", "yellow", "orange", "violet"]
189
+
190
+ # colors = ["red", "blue", "green", "red", "red", "red"]
191
+
192
+ @staticmethod
193
+ def add_padding(pil_img,
194
+ top,
195
+ right,
196
+ bottom,
197
+ left,
198
+ color=(255, 255, 255)):
199
+ '''
200
+ Image padding as part of TSR pre-processing to prevent missing table edges
201
+ '''
202
+ width, height = pil_img.size
203
+ new_width = width + right + left
204
+ new_height = height + top + bottom
205
+ result = Image.new(pil_img.mode, (new_width, new_height), color)
206
+ result.paste(pil_img, (left, top))
207
+ return result
208
+
209
+ @staticmethod
210
+ def dynamic_delta(xmin, ymin, xmax, ymax, delta_xmin, delta_ymin, delta_xmax, delta_ymax, pil_img):
211
+ offset_x = (xmax - xmin) * 0.05
212
+ offset_y = (ymax - ymin) * 0.05
213
+
214
+ w_img, h_img = pil_img.size
215
+
216
+ doxmin = xmin - (delta_xmin + offset_x)
217
+ if (doxmin < 0):
218
+ doxmin = 0
219
+
220
+ doymin = ymin - (delta_ymin + offset_y)
221
+ if (doymin < 0):
222
+ doymin = 0
223
+
224
+ doxmax = xmax + (delta_xmax + offset_x)
225
+ if (doxmax > w_img):
226
+ doxmax = w_img
227
+
228
+ doymax = ymax + (delta_ymax + offset_y)
229
+ if (doymax > h_img):
230
+ doymax = h_img
231
+
232
+
233
+ return doxmin, doymin, doxmax, doymax
234
+
235
+ @staticmethod
236
+ def get_cxy(pil_img, xmin, ymin, xmax, ymax, offset):
237
+ '''
238
+ get the possible position of table caption
239
+ '''
240
+ w_img, h_img = pil_img.size
241
+ c_xmin = xmin
242
+ c_xmax = xmax
243
+
244
+ delta_x = xmax - xmin
245
+
246
+ if delta_x / w_img > 0.5: # full page
247
+ c_xmin = 0
248
+ c_xmax = w_img
249
+ else:
250
+ cx_dist = c_xmax-c_xmin
251
+ delta_dist = w_img * 0.4 # 0.4 = 0.5 minus ~0.1 assumed page margin, assuming a two-column paper layout
252
+ print("cx_dist: " + str(cx_dist))
253
+ print("delta_dist: " + str(delta_dist))
254
+ if cx_dist < delta_dist:
255
+ d_off = int((delta_dist - cx_dist) / 2)
256
+ print("d_off: " + str(d_off))
257
+ c_xmin = c_xmin - d_off
258
+ if c_xmin < 0:
259
+ c_xmin = 0
260
+ c_xmax = c_xmax + d_off
261
+ if c_xmax > w_img:
262
+ c_xmax = w_img
263
+
264
+
265
+ if offset < 0:
266
+ c_ymin = ymin + offset
267
+ c_ymax = ymin
268
+ if c_ymin < 0:
269
+ c_ymin = 0
270
+ else:
271
+ c_ymin = ymax
272
+ c_ymax = ymax + offset
273
+ if c_ymax > h_img:
274
+ c_ymax = h_img
275
+
276
+ return c_xmin, c_ymin, c_xmax, c_ymax
277
+
278
+ def plot_results_detection(self, c1, model, pil_img, prob, boxes,
279
+ delta_xmin, delta_ymin, delta_xmax, delta_ymax):
280
+ '''
281
+ crop_tables and plot_results_detection must apply the same coordinate shifts: one only plots, the other updates the coordinates
282
+ '''
283
+ # st.write('img_obj')
284
+ # st.write(pil_img)
285
+ plt.imshow(pil_img)
286
+ ax = plt.gca()
287
+
288
+ for p, (xmin, ymin, xmax, ymax) in zip(prob, boxes.tolist()):
289
+ cl = p.argmax()
290
+ xmin, ymin, xmax, ymax = self.dynamic_delta(xmin, ymin, xmax, ymax, delta_xmin, delta_ymin, delta_xmax, delta_ymax, pil_img)
291
+ ax.add_patch(
292
+ plt.Rectangle((xmin, ymin),
293
+ xmax - xmin,
294
+ ymax - ymin,
295
+ fill=False,
296
+ color='red',
297
+ linewidth=3))
298
+ text = f'{model.config.id2label[cl.item()]}: {p[cl]:0.2f}'
299
+ ax.text(xmin - 20,
300
+ ymin - 50,
301
+ text,
302
+ fontsize=10,
303
+ bbox=dict(facecolor='yellow', alpha=0.5))
304
+
305
+ # Caption possibility (bottom)
306
+ offset = 200
307
+ c_xmin, c_ymin, c_xmax, c_ymax = self.get_cxy(pil_img, xmin, ymin, xmax, ymax, offset)
308
+
309
+ ax.add_patch(
310
+ plt.Rectangle((c_xmin, c_ymin),
311
+ c_xmax - c_xmin,
312
+ c_ymax - c_ymin,
313
+ fill=False,
314
+ color='blue',
315
+ linewidth=1))
316
+
317
+ # Caption possibility (top)
318
+ offset = -200
319
+ c_xmin, c_ymin, c_xmax, c_ymax = self.get_cxy(pil_img, xmin, ymin, xmax, ymax, offset)
320
+
321
+ ax.add_patch(
322
+ plt.Rectangle((c_xmin, c_ymin),
323
+ c_xmax - c_xmin,
324
+ c_ymax - c_ymin,
325
+ fill=False,
326
+ color='green',
327
+ linewidth=1))
328
+
329
+ plt.axis('off')
330
+ c1.pyplot()
331
+
332
+ def crop_tables(self, pil_img, prob, boxes, delta_xmin, delta_ymin,
333
+ delta_xmax, delta_ymax):
334
+ '''
335
+ crop_tables and plot_results_detection must apply the same coordinate shifts: one only plots, the other updates the coordinates
336
+ '''
337
+ cropped_img_list = []
338
+
339
+ for p, (xmin, ymin, xmax, ymax) in zip(prob, boxes.tolist()):
340
+ xmin, ymin, xmax, ymax = self.dynamic_delta(xmin, ymin, xmax, ymax, delta_xmin, delta_ymin, delta_xmax, delta_ymax, pil_img)
341
+ cropped_img = pil_img.crop((xmin, ymin, xmax, ymax))
342
+ cropped_img_list.append(cropped_img)
343
+
344
+ return cropped_img_list
345
+
346
+ def crop_caption(self, pil_img, prob, boxes, delta_xmin, delta_ymin,
347
+ delta_xmax, delta_ymax):
348
+ '''
349
+ Crops the candidate caption regions above and below each table, using the same coordinate shifts as crop_tables and plot_results_detection
350
+ '''
351
+
352
+ cropped_caption_list = []
353
+
354
+ for p, (xmin, ymin, xmax, ymax) in zip(prob, boxes.tolist()):
355
+ xmin, ymin, xmax, ymax = self.dynamic_delta(xmin, ymin, xmax, ymax, delta_xmin, delta_ymin, delta_xmax, delta_ymax, pil_img)
356
+
357
+ # Caption possibility (top)
358
+ offset = -200
359
+ c_xmin, c_ymin, c_xmax, c_ymax = self.get_cxy(pil_img, xmin, ymin, xmax, ymax, offset)
360
+ cropped_caption = pil_img.crop((c_xmin, c_ymin, c_xmax, c_ymax))
361
+ cropped_caption_list.append(cropped_caption)
362
+
363
+ # Caption possibility (bottom)
364
+ offset = 200
365
+ c_xmin, c_ymin, c_xmax, c_ymax = self.get_cxy(pil_img, xmin, ymin, xmax, ymax, offset)
366
+ cropped_caption = pil_img.crop((c_xmin, c_ymin, c_xmax, c_ymax))
367
+ cropped_caption_list.append(cropped_caption)
368
+
369
+ return cropped_caption_list
370
+
371
+ def generate_structure(self, c2, model, pil_img, prob, boxes,
372
+ expand_rowcol_bbox_top, expand_rowcol_bbox_bottom):
373
+ '''
374
+ Co-ordinates are adjusted here by 3 'pixels'
375
+ Plots the table PIL image and the TSR bounding boxes on the table
376
+ '''
377
+ # st.write('img_obj')
378
+ # st.write(pil_img)
379
+ plt.figure(figsize=(32, 20))
380
+ plt.imshow(pil_img)
381
+ ax = plt.gca()
382
+ rows = {}
383
+ cols = {}
384
+ idx = 0
385
+
386
+ for p, (xmin, ymin, xmax, ymax) in zip(prob, boxes.tolist()):
387
+
388
+ xmin, ymin, xmax, ymax = xmin, ymin, xmax, ymax
389
+ cl = p.argmax()
390
+ class_text = model.config.id2label[cl.item()]
391
+ text = f'{class_text}: {p[cl]:0.2f}'
392
+ # or (class_text == 'table column')
393
+ if (class_text
394
+ == 'table row') or (class_text
395
+ == 'table projected row header') or (
396
+ class_text == 'table column'):
397
+ ax.add_patch(
398
+ plt.Rectangle((xmin, ymin),
399
+ xmax - xmin,
400
+ ymax - ymin,
401
+ fill=False,
402
+ color=self.colors[cl.item()],
403
+ linewidth=2))
404
+ ax.text(xmin - 10,
405
+ ymin - 10,
406
+ text,
407
+ fontsize=5,
408
+ bbox=dict(facecolor='yellow', alpha=0.5))
409
+
410
+ if class_text == 'table row':
411
+ rows['table row.' +
412
+ str(idx)] = (xmin, ymin - expand_rowcol_bbox_top, xmax,
413
+ ymax + expand_rowcol_bbox_bottom)
414
+ if class_text == 'table column':
415
+ cols['table column.' +
416
+ str(idx)] = (xmin, ymin - expand_rowcol_bbox_top, xmax,
417
+ ymax + expand_rowcol_bbox_bottom)
418
+
419
+ idx += 1
420
+
421
+ plt.axis('on')
422
+ c2.pyplot()
423
+ return rows, cols
424
+
425
+ def sort_table_featuresv2(self, rows: dict, cols: dict):
426
+ # Sometimes the header and first row overlap, and we need the header bbox not to have first row's bbox inside the headers bbox
427
+ rows_ = {
428
+ table_feature: (xmin, ymin, xmax, ymax)
429
+ for table_feature, (
430
+ xmin, ymin, xmax,
431
+ ymax) in sorted(rows.items(), key=lambda tup: tup[1][1])
432
+ }
433
+ cols_ = {
434
+ table_feature: (xmin, ymin, xmax, ymax)
435
+ for table_feature, (
436
+ xmin, ymin, xmax,
437
+ ymax) in sorted(cols.items(), key=lambda tup: tup[1][0])
438
+ }
439
+
440
+ return rows_, cols_
441
+
442
+ def individual_table_featuresv2(self, pil_img, rows: dict, cols: dict):
443
+
444
+ for k, v in rows.items():
445
+ xmin, ymin, xmax, ymax = v
446
+ cropped_img = pil_img.crop((xmin, ymin, xmax, ymax))
447
+ rows[k] = xmin, ymin, xmax, ymax, cropped_img
448
+
449
+ for k, v in cols.items():
450
+ xmin, ymin, xmax, ymax = v
451
+ cropped_img = pil_img.crop((xmin, ymin, xmax, ymax))
452
+ cols[k] = xmin, ymin, xmax, ymax, cropped_img
453
+
454
+ return rows, cols
455
+
456
+ def object_to_cellsv2(self, master_row: dict, cols: dict,
457
+ expand_rowcol_bbox_top, expand_rowcol_bbox_bottom,
458
+ padd_left):
459
+ '''Removes redundant bboxes for rows & columns and divides each row into cells using the column boxes
460
+ Args:
461
+ master_row, cols: dicts of row / column bounding boxes with their cropped images
462
+ Returns:
463
+ cells_img: dict mapping each row to its list of cell images
464
+ max_cols, max_rows: number of columns and data rows
465
+ '''
466
+ cells_img = {}
467
+ header_idx = 0
468
+ row_idx = 0
469
+ previous_xmax_col = 0
470
+ new_cols = {}
471
+ new_master_row = {}
472
+ previous_ymin_row = 0
473
+ new_cols = cols
474
+ new_master_row = master_row
475
+ ## Below 2 for loops remove redundant bounding boxes ###
476
+ # for k_col, v_col in cols.items():
477
+ # xmin_col, _, xmax_col, _, col_img = v_col
478
+ # if (np.isclose(previous_xmax_col, xmax_col, atol=5)) or (xmin_col >= xmax_col):
479
+ # print('Found a column with double bbox')
480
+ # continue
481
+ # previous_xmax_col = xmax_col
482
+ # new_cols[k_col] = v_col
483
+
484
+ # for k_row, v_row in master_row.items():
485
+ # _, ymin_row, _, ymax_row, row_img = v_row
486
+ # if (np.isclose(previous_ymin_row, ymin_row, atol=5)) or (ymin_row >= ymax_row):
487
+ # print('Found a row with double bbox')
488
+ # continue
489
+ # previous_ymin_row = ymin_row
490
+ # new_master_row[k_row] = v_row
491
+ ######################################################
492
+ for k_row, v_row in new_master_row.items():
493
+
494
+ _, _, _, _, row_img = v_row
495
+ xmax, ymax = row_img.size
496
+ xa, ya, xb, yb = 0, 0, 0, ymax
497
+ row_img_list = []
498
+ # plt.imshow(row_img)
499
+ # st.pyplot()
500
+ for idx, kv in enumerate(new_cols.items()):
501
+ k_col, v_col = kv
502
+ xmin_col, _, xmax_col, _, col_img = v_col
503
+ xmin_col, xmax_col = xmin_col - padd_left - 10, xmax_col - padd_left
504
+ xa = xmin_col
505
+ xb = xmax_col
506
+ if idx == 0:
507
+ xa = 0
508
+ if idx == len(new_cols) - 1:
509
+ xb = xmax
510
+ xa, ya, xb, yb = xa, ya, xb, yb
511
+
512
+ row_img_cropped = row_img.crop((xa, ya, xb, yb))
513
+ row_img_list.append(row_img_cropped)
514
+
515
+ cells_img[k_row + '.' + str(row_idx)] = row_img_list
516
+ row_idx += 1
517
+
518
+ return cells_img, len(new_cols), len(new_master_row) - 1
519
+
520
+ def clean_dataframe(self, df):
521
+ '''
522
+ Remove irrelevant symbols that sometimes appear in the OCR output
523
+ '''
524
+ # df.columns = [col.replace('|', '') for col in df.columns]
525
+
526
+ for col in df.columns:
527
+
528
+ df[col] = df[col].str.replace("'", '', regex=False)
529
+ df[col] = df[col].str.replace('"', '', regex=False)
530
+ df[col] = df[col].str.replace(']', '', regex=False)
531
+ df[col] = df[col].str.replace('[', '', regex=False)
532
+ df[col] = df[col].str.replace('{', '', regex=False)
533
+ df[col] = df[col].str.replace('}', '', regex=False)
534
+ return df
535
+
536
+ @st.cache
537
+ def convert_df(self, df):
538
+ csv = df.to_csv(index=False, encoding='utf-8-sig') # utf-8-sig to handle BOM for Excel
539
+ return csv.encode('utf-8')
540
+
541
+ def create_dataframe(self, c3, cell_ocr_res: list, max_cols: int,
542
+ max_rows: int):
543
+ '''Create dataframe using list of cell values of the table, also checks for valid header of dataframe
544
+ Args:
545
+ cell_ocr_res: list of strings, each element representing a cell in a table
546
+ max_cols, max_rows: number of columns and rows
547
+ Returns:
548
+ dataframe : final dataframe after all pre-processing
549
+ '''
550
+
551
+ headers = cell_ocr_res[:max_cols]
552
+ new_headers = uniquify(headers,
553
+ (f' {x!s}' for x in string.ascii_lowercase))
554
+ counter = 0
555
+
556
+ cells_list = cell_ocr_res[max_cols:]
557
+ df = pd.DataFrame("", index=range(0, max_rows), columns=new_headers)
558
+
559
+ cell_idx = 0
560
+ for nrows in range(max_rows):
561
+ for ncols in range(max_cols):
562
+ df.iat[nrows, ncols] = str(cells_list[cell_idx])
563
+ cell_idx += 1
564
+
565
+ ## To check if there are duplicate headers if result of uniquify+col == col
566
+ ## This check removes headers when all headers are empty or if median of header word count is less than 6
567
+ for x, col in zip(string.ascii_lowercase, new_headers):
568
+ if f' {x!s}' == col:
569
+ counter += 1
570
+ header_char_count = [len(col) for col in new_headers]
571
+
572
+ # if (counter == len(new_headers)) or (statistics.median(header_char_count) < 6):
573
+ # st.write('woooot')
574
+ # df.columns = uniquify(df.iloc[0], (f' {x!s}' for x in string.ascii_lowercase))
575
+ # df = df.iloc[1:,:]
576
+
577
+ df = self.clean_dataframe(df)
578
+
579
+ c3.dataframe(df)
580
+ csv = self.convert_df(df)
581
+
582
+ try:
583
+ numkey = str(df.iloc[0, 0])
584
+ except IndexError:
585
+ numkey = str(0)
586
+
587
+ # Create a download link with filename and extension
588
+ filename = f"table_{numkey}.csv" # Adjust the filename as needed
589
+ b64_csv = base64.b64encode(csv).decode() # Encode CSV data to base64
590
+ href = f'<a href="data:file/csv;base64,{b64_csv}" download="{filename}">Download {filename}</a>'
591
+ c3.markdown(href, unsafe_allow_html=True)
592
+
593
+ return df
594
+
595
+ async def start_process(self, image_path: str, TD_THRESHOLD, TSR_THRESHOLD,
596
+ OCR_THRESHOLD, padd_top, padd_left, padd_bottom,
597
+ padd_right, delta_xmin, delta_ymin, delta_xmax,
598
+ delta_ymax, expand_rowcol_bbox_top,
599
+ expand_rowcol_bbox_bottom):
600
+ '''
601
+ Initiates process of generating pandas dataframes from raw pdf-page images
602
+
603
+ '''
604
+ image = Image.open(image_path).convert("RGB")
605
+ probas, bboxes_scaled = table_detector(image,
606
+ THRESHOLD_PROBA=TD_THRESHOLD)
607
+
608
+ if bboxes_scaled.nelement() == 0:
609
+ st.write('No table found in the pdf-page image')
610
+ return ''
611
+
612
+ # try:
613
+ # st.write('Document: '+image_path.split('/')[-1])
614
+ c1, c2, c3 = st.columns((1, 1, 1))
615
+
616
+ self.plot_results_detection(c1, table_detection_model, image, probas,
617
+ bboxes_scaled, delta_xmin, delta_ymin,
618
+ delta_xmax, delta_ymax)
619
+ cropped_img_list = self.crop_tables(image, probas, bboxes_scaled,
620
+ delta_xmin, delta_ymin, delta_xmax,
621
+ delta_ymax)
622
+
623
+ cropped_caption_list = self.crop_caption(image, probas, bboxes_scaled,
624
+ delta_xmin, delta_ymin, delta_xmax,
625
+ delta_ymax)
626
+
627
+ # for p, (xmin, ymin, xmax, ymax) in zip(probas, bboxes_scaled.tolist()):
628
+ # print(p.argmax())
629
+ # print(xmin, ymin, xmax, ymax)
630
+
631
+ sequential_caption_img_list = []
632
+ for idx, caption_img in enumerate(cropped_caption_list):
633
+ if idx%2 == 0: # top
634
+ print("top")
635
+ else: # bottom
636
+ print("bottom")
637
+ plt.imshow(caption_img)
638
+ c2.pyplot()
639
+ sequential_caption_img_list.append(pytess(cell_pil_img=caption_img, threshold=OCR_THRESHOLD))
640
+
641
+ caption_ocr_res = await asyncio.gather(*sequential_caption_img_list)
642
+ flag_caption_pos = 0 # 0=top, 1=bottom
643
+ for idx, caption_text in enumerate(caption_ocr_res):
644
+ if caption_text == "" or "table" not in caption_text.lower():
645
+ if idx%2==0:
646
+ flag_caption_pos=1
647
+ break
648
+
649
+ for idx, caption_text in enumerate(caption_ocr_res):
650
+ if idx%2==flag_caption_pos:
651
+ c3.text(str(idx) + "_" + caption_text)
652
+
653
+
654
+ # for idx, unpadded_table in enumerate(cropped_img_list):
655
+
656
+ # table = self.add_padding(unpadded_table, padd_top, padd_right,
657
+ # padd_bottom, padd_left)
658
+ # # table = super_res(table)
659
+ # # table = binarizeBlur_image(table)
660
+ # # table = sharpen_image(table) # Test sharpen image next
661
+ # # table = td_postprocess(table)
662
+
663
+ # # table.save("result"+str(idx)+".png")
664
+
665
+ # probas, bboxes_scaled = table_struct_recog(
666
+ # table, THRESHOLD_PROBA=TSR_THRESHOLD)
667
+ # rows, cols = self.generate_structure(c2, table_recognition_model,
668
+ # table, probas, bboxes_scaled,
669
+ # expand_rowcol_bbox_top,
670
+ # expand_rowcol_bbox_bottom)
671
+ # # st.write(len(rows), len(cols))
672
+ # rows, cols = self.sort_table_featuresv2(rows, cols)
673
+ # master_row, cols = self.individual_table_featuresv2(
674
+ # table, rows, cols)
675
+
676
+ # cells_img, max_cols, max_rows = self.object_to_cellsv2(
677
+ # master_row, cols, expand_rowcol_bbox_top,
678
+ # expand_rowcol_bbox_bottom, padd_left)
679
+
680
+ # sequential_cell_img_list = []
681
+ # for k, img_list in cells_img.items():
682
+ # for img in img_list:
683
+ # # img = super_res(img)
684
+ # # img = sharpen_image(img) # Test sharpen image next
685
+ # # img = binarizeBlur_image(img)
686
+ # # img = self.add_padding(img, 10,10,10,10)
687
+ # # plt.imshow(img)
688
+ # # c3.pyplot()
689
+ # sequential_cell_img_list.append(
690
+ # pytess(cell_pil_img=img, threshold=OCR_THRESHOLD))
691
+
692
+ # cell_ocr_res = await asyncio.gather(*sequential_cell_img_list)
693
+
694
+ # self.create_dataframe(c3, cell_ocr_res, max_cols, max_rows)
695
+ # st.write(
696
+ # 'Errors in OCR is due to either quality of the image or performance of the OCR'
697
+ # )
698
+ # except:
699
+ # st.write('Either incorrectly identified table or no table, to debug remove try/except')
700
+ # break
701
+ # break
702
+
703
+
704
+ if __name__ == "__main__":
705
+
706
+ st_up, st_lang = st.columns((1, 1))
707
+ img_name = st_up.file_uploader("Upload an image with table(s)")
708
+ lang = st_lang.selectbox('Language', ('en', 'japan'))
709
+ reload_ocr(lang)
710
+
711
+ st1, st2, st3 = st.columns((1, 1, 1))
712
+ TD_th = st1.slider('Table detection threshold', 0.0, 1.0, 0.8)
713
+ TSR_th = st2.slider('Table structure recognition threshold', 0.0, 1.0, 0.7)
714
+ OCR_th = st3.slider("Text Probs Threshold", 0.0, 1.0, 0.5)
715
+
716
+ st1, st2, st3, st4 = st.columns((1, 1, 1, 1))
717
+
718
+ padd_top = st1.slider('Padding top', 0, 200, 90)
719
+ padd_left = st2.slider('Padding left', 0, 200, 40)
720
+ padd_right = st3.slider('Padding right', 0, 200, 40)
721
+ padd_bottom = st4.slider('Padding bottom', 0, 200, 90)
722
+
723
+ te = TableExtractionPipeline()
724
+ # for img in image_list:
725
+ if img_name is not None:
726
+ asyncio.run(
727
+ te.start_process(img_name,
728
+ TD_THRESHOLD=TD_th,
729
+ TSR_THRESHOLD=TSR_th,
730
+ OCR_THRESHOLD=OCR_th,
731
+ padd_top=padd_top,
732
+ padd_left=padd_left,
733
+ padd_bottom=padd_bottom,
734
+ padd_right=padd_right,
735
+ delta_xmin=10, # add offset to the left of the table
736
+ delta_ymin=3, # add offset to the bottom of the table
737
+ delta_xmax=10, # add offset to the right of the table
738
+ delta_ymax=3, # add offset to the top of the table
739
+ expand_rowcol_bbox_top=0,
740
+ expand_rowcol_bbox_bottom=0))
components/callbacks.py ADDED
@@ -0,0 +1,4 @@
1
+ # Define callbacks here
2
+ from pytorch_lightning.callbacks import EarlyStopping
3
+
4
+ early_stopping = EarlyStopping(monitor="loss", min_delta=0, patience=3)
components/data_module.py ADDED
@@ -0,0 +1,81 @@
1
+ from typing import Callable, List, Optional, Union
2
+
3
+ import torch
4
+ from pytorch_lightning import LightningDataModule
5
+ from torch.utils.data import DataLoader, Dataset
6
+
7
+
8
+ class SampleDataset(Dataset):
9
+
10
+ def __init__(self,
11
+ x: Union[List, torch.Tensor],
12
+ y: Union[List, torch.Tensor],
13
+ transforms: Optional[Callable] = None) -> None:
14
+ super(SampleDataset, self).__init__()
15
+ self.x = x
16
+ self.y = y
17
+
18
+ if transforms is None:
19
+ # Replace None with some default transforms
20
+ # If image, could be an Resize and ToTensor
21
+ self.transforms = lambda x: x
22
+ else:
23
+ self.transforms = transforms
24
+
25
+ def __len__(self):
26
+ return len(self.x)
27
+
28
+ def __getitem__(self, index: int):
29
+ x = self.x[index]
30
+ y = self.y[index]
31
+
32
+ x = self.transforms(x)
33
+ return x, y
34
+
35
+
36
+ class SampleDataModule(LightningDataModule):
37
+
38
+ def __init__(self,
39
+ x: Union[List, torch.Tensor],
40
+ y: Union[List, torch.Tensor],
41
+ transforms: Optional[Callable] = None,
42
+ val_ratio: float = 0,
43
+ batch_size: int = 32) -> None:
44
+ super(SampleDataModule, self).__init__()
45
+ assert 0 <= val_ratio < 1
46
+ assert isinstance(batch_size, int)
47
+ self.x = x
48
+ self.y = y
49
+
50
+ self.transforms = transforms
51
+ self.val_ratio = val_ratio
52
+ self.batch_size = batch_size
53
+
54
+ self.setup()
55
+ self.prepare_data()
56
+
57
+ def setup(self, stage: Optional[str] = None) -> None:
58
+ pass
59
+
60
+ def prepare_data(self) -> None:
61
+ n_samples: int = len(self.x)
62
+ train_size: int = n_samples - int(n_samples * self.val_ratio)
63
+
64
+ self.train_dataset = SampleDataset(x=self.x[:train_size],
65
+ y=self.y[:train_size],
66
+ transforms=self.transforms)
67
+ if train_size < n_samples:
68
+ self.val_dataset = SampleDataset(x=self.x[train_size:],
69
+ y=self.y[train_size:],
70
+ transforms=self.transforms)
71
+ else:
72
+ self.val_dataset = SampleDataset(x=self.x[-self.batch_size:],
73
+ y=self.y[-self.batch_size:],
74
+ transforms=self.transforms)
75
+
76
+ def train_dataloader(self) -> DataLoader:
77
+ return DataLoader(dataset=self.train_dataset,
78
+ batch_size=self.batch_size)
79
+
80
+ def val_dataloader(self) -> DataLoader:
81
+ return DataLoader(dataset=self.val_dataset, batch_size=self.batch_size)
config.py ADDED
@@ -0,0 +1,3 @@
1
+ from dynaconf import Dynaconf
2
+
3
+ CFG = Dynaconf(envvar_prefix="DYNACONF", settings_files=["config/config.yaml"])
config/config.yaml ADDED
File without changes
data/.gitkeep ADDED
File without changes
docker-compose.yml ADDED
@@ -0,0 +1,23 @@
1
+ version: "3.7"
2
+
3
+ services:
4
+ model_name:
5
+ build:
6
+ context: .
7
+ dockerfile: .docker/Dockerfile
8
+ container_name: model_name
9
+ ports:
10
+ - "8996:8996"
11
+ env_file:
12
+ - ./.env
13
+ volumes:
14
+ - ./data:/home/working/data:ro
15
+
16
+ # This part is used to enable GPU support
17
+ deploy:
18
+ resources:
19
+ reservations:
20
+ devices:
21
+ - driver: nvidia
22
+ count: 1
23
+ capabilities: [ gpu ]
models/__init__.py ADDED
File without changes
models/base_model/classification.py ADDED
@@ -0,0 +1,63 @@
1
+ from abc import abstractmethod
2
+ from typing import Any, Dict, List
3
+
4
+ import torch
5
+ from pytorch_lightning import LightningModule
6
+ from torch import Tensor
7
+
8
+
9
+ class LightningClassification(LightningModule):
10
+
11
+ @abstractmethod
12
+ def __init__(self, *args, **kwargs) -> None:
13
+ super(LightningClassification, self).__init__(*args, **kwargs)
14
+ self.train_batch_output: List[Dict] = []
15
+ self.validation_batch_output: List[Dict] = []
16
+ self.log_value_list: List[str] = ['loss', 'f1', 'precision', 'recall']
17
+
18
+ @abstractmethod
19
+ def forward(self, *args, **kwargs) -> Any:
20
+ pass
21
+
22
+ @abstractmethod
23
+ def configure_optimizers(self):
24
+ pass
25
+
26
+ @abstractmethod
27
+ def loss(self, input: Tensor, target: Tensor, **kwargs) -> Tensor:
28
+ pass
29
+
30
+ @abstractmethod
31
+ def training_step(self, batch, batch_idx):
32
+ pass
33
+
34
+ def __average(self, key: str, outputs: List[Dict]) -> Tensor:
35
+ target_arr = torch.Tensor([val[key] for val in outputs]).float()
36
+ return target_arr.mean()
37
+
38
+ @torch.no_grad()
39
+ def on_train_epoch_start(self) -> None:
40
+ self.train_batch_output = []
41
+
42
+ @torch.no_grad()
43
+ def on_train_epoch_end(self) -> None:
44
+ for key in self.log_value_list:
45
+ val = self.__average(key=key, outputs=self.train_batch_output)
46
+ log_name = f"training/{key}"
47
+ self.log(name=log_name, value=val)
48
+
49
+ @abstractmethod
50
+ @torch.no_grad()
51
+ def validation_step(self, batch, batch_idx):
52
+ pass
53
+
54
+ @torch.no_grad()
55
+ def on_validation_epoch_start(self) -> None:
56
+ self.validation_batch_output = []
57
+
58
+ @torch.no_grad()
59
+ def on_validation_epoch_end(self) -> None:
60
+ for key in self.log_value_list:
61
+ val = self.__average(key=key, outputs=self.validation_batch_output)
62
+ log_name = f"val/{key}"
63
+ self.log(name=log_name, value=val)
models/base_model/gan.py ADDED
File without changes
models/base_model/regression.py ADDED
@@ -0,0 +1,55 @@
1
+ from abc import abstractmethod
2
+ from typing import Any, Dict, List
3
+
4
+ import torch
5
+ from pytorch_lightning import LightningModule
6
+ from torch import Tensor
7
+
8
+
9
+ class LightningRegression(LightningModule):
10
+
11
+ @abstractmethod
12
+ def __init__(self, *args, **kwargs) -> None:
13
+ super(LightningRegression, self).__init__(*args, **kwargs)
14
+ self.train_step_output: List[Dict] = []
15
+ self.validation_step_output: List[Dict] = []
16
+ self.log_value_list: List[str] = ['loss', 'mse', 'mape']
17
+
18
+ @abstractmethod
19
+ def forward(self, *args, **kwargs) -> Any:
20
+ pass
21
+
22
+ @abstractmethod
23
+ def configure_optimizers(self):
24
+ pass
25
+
26
+ @abstractmethod
27
+ def loss(self, input: Tensor, output: Tensor, **kwargs):
28
+ return 0
29
+
30
+ @abstractmethod
31
+ def training_step(self, batch, batch_idx):
32
+ pass
33
+
34
+ def __average(self, key: str, outputs: List[Dict]) -> Tensor:
35
+ target_arr = torch.Tensor([val[key] for val in outputs]).float()
36
+ return target_arr.mean()
37
+
38
+ @torch.no_grad()
39
+ def on_train_epoch_end(self) -> None:
40
+ for key in self.log_value_list:
41
+ val = self.__average(key=key, outputs=self.train_step_output)
42
+ log_name = f"training/{key}"
43
+ self.log(name=log_name, value=val)
44
+
45
+ @torch.no_grad()
46
+ @abstractmethod
47
+ def validation_step(self, batch, batch_idx):
48
+ pass
49
+
50
+ @torch.no_grad()
51
+ def validation_epoch_end(self, outputs):
52
+ for key in self.log_value_list:
53
+ val = self.__average(key=key, outputs=self.validation_step_output)
54
+ log_name = f"val/{key}"
55
+ self.log(name=log_name, value=val)
models/metrics/classification.py ADDED
@@ -0,0 +1,44 @@
1
+ from typing import Dict
2
+
3
+ import torch
4
+ from torchmetrics import functional as FM
5
+
6
+
7
+ def classification_metrics(
8
+ preds: torch.Tensor,
9
+ target: torch.Tensor,
10
+ num_classes: int,
11
+ average: str = 'macro',
12
+ task: str = 'multiclass') -> Dict[str, torch.Tensor]:
13
+ """
14
+ get_classification_metrics
15
+ Return some metrics evaluation the classification task
16
+
17
+ Parameters
18
+ ----------
19
+ preds : torch.Tensor
20
+ logits, probs
21
+ target : torch.Tensor
22
+ targets label
23
+
24
+ Returns
25
+ -------
26
+ Dict[str, torch.Tensor]
27
+ _description_
28
+ """
29
+ f1 = FM.f1_score(preds=preds,
30
+ target=target,
31
+ num_classes=num_classes,
32
+ task=task,
33
+ average=average)
34
+ recall = FM.recall(preds=preds,
35
+ target=target,
36
+ num_classes=num_classes,
37
+ task=task,
38
+ average=average)
39
+ precision = FM.precision(preds=preds,
40
+ target=target,
41
+ num_classes=num_classes,
42
+ task=task,
43
+ average=average)
44
+ return dict(f1=f1, precision=precision, recall=recall)
models/metrics/regression.py ADDED
@@ -0,0 +1,28 @@
1
+ from typing import Dict
2
+
3
+ import torch
4
+ from torchmetrics import functional as FM
5
+
6
+
7
+ def regression_metrics(preds: torch.Tensor,
8
+ target: torch.Tensor) -> Dict[str, torch.Tensor]:
9
+ """
10
+ get_classification_metrics
11
+ Return some metrics evaluation the classification task
12
+
13
+ Parameters
14
+ ----------
15
+ preds : torch.Tensor
16
+ logits, probs
17
+ target : torch.Tensor
18
+ targets label
19
+
20
+ Returns
21
+ -------
22
+ Dict[str, torch.Tensor]
23
+ _description_
24
+ """
25
+ mse: torch.Tensor = FM.mean_squared_error(preds=preds, target=target)
26
+ mape: torch.Tensor = FM.mean_absolute_percentage_error(preds=preds,
27
+ target=target)
28
+ return dict(mse=mse, mape=mape)
models/model_lit.py ADDED
@@ -0,0 +1,50 @@
1
+ from torch import Tensor, nn, optim
2
+ from torch.nn import functional as F
3
+
4
+ from .base_model.classification import LightningClassification
5
+ from .metrics.classification import classification_metrics
6
+ from .modules.sample_torch_module import UselessLayer
7
+
8
+
9
+ class UselessClassification(LightningClassification):
10
+
11
+ def __init__(self, n_classes: int, lr: float, **kwargs) -> None:
12
+ super(UselessClassification, self).__init__()
13
+ self.save_hyperparameters()
14
+ self.n_classes = n_classes
15
+ self.lr = lr
16
+ self.main = nn.Sequential(UselessLayer(), nn.GELU())
17
+
18
+ def forward(self, x: Tensor) -> Tensor:
19
+ return self.main(x)
20
+
21
+ def loss(self, input: Tensor, target: Tensor) -> Tensor:
22
+ return F.mse_loss(input=input, target=target)
23
+
24
+ def configure_optimizers(self):
25
+ optimizer = optim.Adam(params=self.parameters(), lr=self.lr)
26
+ return optimizer
27
+
28
+ def training_step(self, batch, batch_idx):
29
+ x, y = batch
30
+
31
+ logits = self.forward(x)
32
+ loss = self.loss(input=x, target=y)
33
+ metrics = classification_metrics(preds=logits,
34
+ target=y,
35
+ num_classes=self.n_classes)
36
+
37
+ self.train_batch_output.append({'loss': loss, **metrics})
38
+ return loss
39
+
40
+ def validation_step(self, batch, batch_idx):
41
+ x, y = batch
42
+
43
+ logits = self.forward(x)
44
+ loss = self.loss(input=x, target=y)
45
+ metrics = classification_metrics(preds=logits,
46
+ target=y,
47
+ num_classes=self.n_classes)
48
+
49
+ self.validation_batch_output.append({'loss': loss, **metrics})
50
+ return loss
models/modules/sample_torch_module.py ADDED
@@ -0,0 +1,12 @@
1
+ from torch import Tensor, nn
2
+
3
+
4
+ class UselessLayer(nn.Module):
5
+
6
+ def __init__(self) -> None:
7
+ super(UselessLayer, self).__init__()
8
+ self.seq = nn.Identity()
9
+
10
+ def forward(self, x: Tensor) -> Tensor:
11
+ x = self.seq(x)
12
+ return x
requirements.txt ADDED
@@ -0,0 +1,10 @@
1
+ timm==0.9.2
2
+ torch --index-url https://download.pytorch.org/whl/cpu
3
+ torchvision --index-url https://download.pytorch.org/whl/cpu
4
+ torchaudio --index-url https://download.pytorch.org/whl/cpu
5
+ streamlit==1.21.0
6
+ pandas
7
+ transformers==4.29.1
8
+ Pillow==10.0.1
9
+ paddlepaddle
10
+ paddleocr
test_pdf2img.py ADDED
@@ -0,0 +1,16 @@
1
+ import os
2
+ from pdf2image import convert_from_path
3
+
4
+ # Set the PDF file path
5
+ pdf_path = 'test.pdf'
6
+
7
+ # Convert the selected page range of the PDF to JPEG images
8
+ first = 14
9
+ last = 14
10
+ images = convert_from_path(pdf_path, dpi=300, first_page=first, last_page=last, poppler_path=r"C:\poppler-23.07.0\Library\bin")
11
+
12
+ # Save the image file
13
+ image_path = os.path.splitext(pdf_path)[0]
14
+
15
+ for index, image in enumerate(images):
16
+ image.save(image_path + "p" + str(index+first) + '.jpg', 'JPEG')
tests/test_resource.py ADDED
@@ -0,0 +1,4 @@
1
+ def test_cuda():
2
+ from torch.cuda import is_available
3
+ assert is_available()
4
+
utils/.gitkeep ADDED
File without changes