Upload LayoutLMv2Main_cord2_gOcr_folder.py
Browse files
LayoutLMv2Main_cord2_gOcr_folder.py
ADDED
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
"""Inference with LayoutLMv2ForTokenClassification (.ipynb export).

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1nhfx6XRncq2XsOBREZGJI7tRByt_TJIa

## Inference with LayoutLMv2ForTokenClassification + Gradio demo

In this notebook, we perform inference with `LayoutLMv2ForTokenClassification` on new document images for which no label information is available. At the end, we also build a [Gradio](https://gradio.app/) demo that turns the inference code into a web interface.

## Install libraries

Let's first install the required libraries:
* HuggingFace Transformers + Detectron2 (for the model)
* HuggingFace Datasets (for getting the data)
* PyTesseract (for OCR)
"""
|
20 |
+
|
21 |
+
# !pip install -q transformers
|
22 |
+
# !pip install -q gradio
|
23 |
+
|
24 |
+
# !pip install 'git+https://github.com/facebookresearch/detectron2.git'
|
25 |
+
|
26 |
+
# !pip install -q datasets
|
27 |
+
|
28 |
+
# !sudo apt install tesseract-ocr
|
29 |
+
# !pip install -q pytesseract
|
30 |
+
# pip install torchvision
|
31 |
+
|
32 |
+
# import gradio as gr
|
33 |
+
import os
|
34 |
+
import time
|
35 |
+
import numpy as np
|
36 |
+
from transformers import LayoutLMv2Processor, LayoutLMv2ForTokenClassification
|
37 |
+
from datasets import load_dataset
|
38 |
+
import torch
|
39 |
+
from transformers import LayoutLMv2ForTokenClassification
|
40 |
+
from PIL import Image, ImageDraw, ImageFont
|
41 |
+
import json
|
42 |
+
from GoogleVisionService import GoogleVisionService
|
43 |
+
from getTextHelper import cord_label_to_color, get_word_boxes_google, get_word_boxes_tesseract, getImg, getImgAndPath, normalize_bbox, unnormalize_box
|
44 |
+
from datasets import load_dataset
|
45 |
+
|
46 |
+
import pytesseract
|
47 |
+
import cv2
|
48 |
+
|
49 |
+
|
50 |
+
class labelCounter:
    """Holder for the label that is currently highlighted in the viewer.

    The attributes are read and written directly on the class itself, so they
    behave as state shared by every function that touches them.
    """

    # Index into the label list of the most recently selected label.
    lbl_i = 0
    # Name of the currently selected label; None until one has been chosen.
    lbl = None
|
53 |
+
|
54 |
+
|
55 |
+
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# With apply_ocr=False the processor does not run its own OCR; words and boxes
# must then be supplied by the caller.  Set to True to let the processor OCR.
apply_ocr = False
processor = LayoutLMv2Processor.from_pretrained(
    "microsoft/layoutlmv2-base-uncased", apply_ocr=apply_ocr)

workingPath = "/Users/eliaweiss/Documents/doc2txt/lineCv/"
docNumberList = [6, 7, 10, 21, 25, 29, 48, 50]

# Load the fine-tuned model from the hub and move it to the chosen device.
model = LayoutLMv2ForTokenClassification.from_pretrained(
    "doc2txt/layoutlmv2-finetuned-cord")
model.to(device)
|
68 |
+
|
69 |
+
|
70 |
+
# datasets = load_dataset("MarkusDressel/cord")
|
71 |
+
|
72 |
+
"""Let's create a list containing all unique labels, as well as dictionaries mapping integers to their label names and vice versa. This will be useful to convert the model's predictions to actual label names."""
|
73 |
+
|
74 |
+
# labels = datasets['train'].features['ner_tags'].feature.names
|
75 |
+
# Label names, hard-coded instead of being read from the dataset features.
# NOTE(review): the order must match the class indices the fine-tuned model
# was trained with -- confirm against the model's own config.
labels = [
    'I-menu.cnt',
    'I-menu.discountprice',
    'I-menu.etc',
    'I-menu.itemsubtotal',
    'I-menu.nm',
    'I-menu.num',
    'I-menu.price',
    'I-menu.sub_cnt',
    'I-menu.sub_etc',
    'I-menu.sub_nm',
    'I-menu.sub_price',
    'I-menu.sub_unitprice',
    'I-menu.unitprice',
    'I-menu.vatyn',
    'I-sub_total.discount_price',
    'I-sub_total.etc',
    'I-sub_total.othersvc_price',
    'I-sub_total.service_price',
    'I-sub_total.subtotal_price',
    'I-sub_total.tax_price',
    'I-total.cashprice',
    'I-total.changeprice',
    'I-total.creditcardprice',
    'I-total.emoneyprice',
    'I-total.menuqty_cnt',
    'I-total.menutype_cnt',
    'I-total.total_etc',
    'I-total.total_price',
    'I-void_menu.nm',
    'I-void_menu.price',
]
# print(labels)

# Index -> label name, and the inverse mapping.
id2label = dict(enumerate(labels))
label2id = {name: index for index, name in enumerate(labels)}
|
81 |
+
|
82 |
+
|
83 |
+
"""## Inference
|
84 |
+
|
85 |
+
"""
|
86 |
+
# font = ImageFont.load_default()
|
87 |
+
# Font for the on-image label caption.  The TrueType file lives at a
# machine-specific path, so fall back to Pillow's built-in font when it
# cannot be read instead of aborting the whole script.
try:
    font = ImageFont.truetype(
        "/Users/eliaweiss/work/ocrPlus/ocrPlus/DejaVuSans.ttf", 50)
except OSError:
    font = ImageFont.load_default()
|
89 |
+
|
90 |
+
|
91 |
+
# print(example.keys())
|
92 |
+
# for nn in docNumberList:
|
93 |
+
def getNextLabel(labels, true_predictions):
    """Advance the shared label cursor to the next label that was predicted.

    Starting just after ``labelCounter.lbl_i``, step through ``labels``
    (wrapping around at the end) until a label that occurs in
    ``true_predictions`` is found or every label has been tried once.

    Args:
        labels: Sequence of all label names.
        true_predictions: Collection of predicted label names.

    Returns:
        The selected label name, which is also stored in ``labelCounter.lbl``.
        When no label occurs in ``true_predictions`` this is the last label
        tried; when ``labels`` is empty it is ``None``.
    """
    labelCounter.lbl = None
    if not labels:
        return None

    for _ in range(len(labels)):
        # Wrap with a modulo over the full length so that the final label is
        # reachable too (wrapping at len - 1 skipped it).
        labelCounter.lbl_i = (labelCounter.lbl_i + 1) % len(labels)
        labelCounter.lbl = labels[labelCounter.lbl_i]
        if labelCounter.lbl in true_predictions:
            break
    return labelCounter.lbl
|
105 |
+
|
106 |
+
|
107 |
+
def iob_to_label(label):
    """Strip a one-letter IOB prefix such as ``I-`` or ``B-`` from ``label``.

    Args:
        label: Tag name, with or without an ``X-`` prefix.

    Returns:
        The tag without its prefix.  ``'other'`` is returned for the empty
        string, for the bare outside tag ``'O'`` and for a prefix with nothing
        after it.  A tag that carries no ``X-`` prefix is returned unchanged
        rather than losing its first two characters.
    """
    if len(label) >= 2 and label[1] == "-":
        label = label[2:]
    elif label == "O":
        label = ""
    return label or 'other'
|
112 |
+
|
113 |
+
|
114 |
+
def drawLabels(image, true_predictions, true_boxes):
    """Return a copy of ``image`` annotated for the currently selected label.

    The selected label (``labelCounter.lbl``) is written near the top-left
    corner, and a rectangle is drawn for every box whose prediction contains
    that label as a substring.  ``image`` itself is left untouched.

    Args:
        image: PIL image to annotate.
        true_predictions: Predicted label names, paired positionally with
            ``true_boxes``.
        true_boxes: Boxes in a form accepted by ``ImageDraw.rectangle``.

    Returns:
        The annotated copy of ``image``.
    """
    selected = labelCounter.lbl
    annotated = image.copy()
    draw = ImageDraw.Draw(annotated)

    draw.text((10, 10), text=selected,
              fill=cord_label_to_color(selected), font=font)

    for prediction, box in zip(true_predictions, true_boxes):
        # Only boxes belonging to the selected label are outlined.
        if selected not in prediction:
            continue
        draw.rectangle(box, outline=cord_label_to_color(prediction), width=5)
    return annotated
|
131 |
+
|
132 |
+
# folder = "/Users/eliaweiss/.cache/huggingface/datasets/downloads/extracted/87634c2ab68012df3def8353986bcb092170ef7341c69e1a9cd97be52e513079/CORD/test/image/"
|
133 |
+
# folder = "/Users/eliaweiss/Documents/doc2txt/en_invoice_printed"
|
134 |
+
folder = "/Users/eliaweiss/ai/ICDAR-2019-SROIE/data/img"


def _show(image, true_predictions, true_boxes):
    """Draw the selected label's boxes and display the result with cv2."""
    annotated = np.array(drawLabels(image, true_predictions, true_boxes))
    # Arrays built from an RGB Pillow image are in RGB order, whereas
    # cv2.imshow interprets colour data as BGR; convert so red and blue are
    # not swapped on screen.
    # cv2.imwrite('output_image_v2.jpg', annotated)
    cv2.imshow('Window Name ', cv2.cvtColor(annotated, cv2.COLOR_RGB2BGR))


# Interactive viewer: for each image in `folder`, run OCR and the model, then
# show the boxes of one label at a time.
#   tab        -> highlight the next label that was predicted for this image
#   shift + `  -> move on to the next image
# Sorted so that images are visited in a reproducible order.
for img_name in sorted(os.listdir(folder)):
    labelCounter.lbl_i = 0
    start_time = time.time()
    img_path = os.path.join(folder, img_name)
    try:
        image = Image.open(img_path).convert("RGB")
    except OSError:
        # Sub-directories and non-image files (e.g. .DS_Store) end up here.
        print("Skipping " + img_path)
        continue
    # image = Image.open(example['image_path'])

    # pathOcr = workingPath + docNumber+".json"
    # with open(pathOcr, encoding="utf-8") as f:
    #     gOcrJson = json.load(f)
    gOcr = GoogleVisionService(img_path)
    gOcrJson = gOcr.googleOcr()

    width, height = image.size

    # Extract words and bounding boxes from the OCR result, then pass the
    # boxes through the project's normalize_bbox helper for the processor.
    words, boxes = get_word_boxes_google(gOcrJson)
    boxes = [normalize_bbox(box, width, height) for box in boxes]

    # Alternative: word-level boxes from pytesseract.
    # cv_image = cv2.imread(image_path)
    # gray = cv2.cvtColor(cv_image, cv2.COLOR_BGR2GRAY)
    # data = pytesseract.image_to_data(gray, output_type=pytesseract.Output.DICT)
    # words, boxes = get_word_boxes_tesseract(data)
    # boxes = [normalize_bbox(box, width, height) for box in boxes]

    # Prepare the inputs with LayoutLMv2Processor.  truncation=True keeps the
    # token sequence within the model's maximum length; without it a long
    # document makes the forward pass fail.
    if not apply_ocr:
        encoding = processor(image, words, boxes=boxes, truncation=True,
                             return_offsets_mapping=True, return_tensors="pt")
    else:
        encoding = processor(image, truncation=True,
                             return_offsets_mapping=True, return_tensors="pt")
    # The offsets are only needed for post-processing; the model does not
    # accept them as an input.
    offset_mapping = encoding.pop('offset_mapping')
    print(encoding.keys())

    # Move every input tensor to the GPU, if it's available.
    for k, v in encoding.items():
        encoding[k] = v.to(device)

    # Forward pass; no gradients are needed for inference.
    with torch.no_grad():
        outputs = model(**encoding)
    # print(outputs.logits.shape)
    print("Time: " + str(time.time() - start_time))

    # Keep only the predictions and boxes of tokens that start a word.  A
    # token whose character offset does not start at 0 is a sub-word piece.
    predictions = outputs.logits.argmax(-1).squeeze(0).tolist()
    token_boxes = encoding.bbox.squeeze(0).tolist()
    is_subword = np.array(offset_mapping.squeeze(0).tolist())[:, 0] != 0

    true_predictions = [id2label[pred]
                        for pred, sub in zip(predictions, is_subword)
                        if not sub]
    true_boxes = [unnormalize_box(box, width, height)
                  for box, sub in zip(token_boxes, is_subword)
                  if not sub]

    # print(true_predictions)
    # print(true_boxes)

    # Visualize the result, starting with the total price label.
    labelCounter.lbl = "I-total.total_price"  # getNextLabel(labels, true_predictions)
    _show(image, true_predictions, true_boxes)

    while True:
        k = cv2.waitKey(0) & 0xFF
        if k == 255:  # no key reported
            continue
        if k == 126:  # shift + ` - go to the next image
            break
        print("k", k)
        if k == 9:  # tab - highlight the next predicted label
            labelCounter.lbl = getNextLabel(labels, true_predictions)

        _show(image, true_predictions, true_boxes)

# Close the viewer window once every image has been shown.
cv2.destroyAllWindows()
|