karida committed
Commit d4a6a10 • 1 parent: 8b5a1c6
Add gradio
README.md CHANGED
@@ -5,7 +5,7 @@ colorFrom: purple
 colorTo: purple
 sdk: gradio
 sdk_version: 4.15.0
-app_file:
+app_file: main.py
 pinned: false
 license: other
 ---
main.py ADDED (new file, 165 lines)
import torch
from PIL import ImageDraw, ImageFont, Image
from transformers import AutoModelForTokenClassification, AutoProcessor
import fitz  # PyMuPDF
import io


def extract_data_from_pdf(pdf_path, page_number=0):
    """
    Extracts image, words, and bounding boxes from a specified page of a PDF.

    Args:
    - pdf_path (str): Path to the PDF file.
    - page_number (int): Page number to extract data from (0-indexed).

    Returns:
    - image: An image of the specified page.
    - words: A list of words found on the page.
    - boxes: A list of bounding boxes corresponding to the words.
    """
    # Open the PDF
    doc = fitz.open(pdf_path)
    page = doc.load_page(page_number)

    # Extract image of the page
    pix = page.get_pixmap()
    image_bytes = pix.tobytes("png")
    image = Image.open(io.BytesIO(image_bytes))

    # Extract words and their bounding boxes
    words = []
    boxes = []
    for word in page.get_text("words"):
        words.append(word[4])
        boxes.append(word[:4])  # (x0, y0, x1, y1)

    doc.close()
    return image, words, boxes


def merge_pairs_v2(pairs):
    """Merge consecutive [label, text] pairs that share the same label."""
    if not pairs:
        return []

    merged = [pairs[0]]
    for current in pairs[1:]:
        last = merged[-1]
        if last[0] == current[0]:
            # Same label as the previous pair: concatenate the text
            merged[-1] = [last[0], last[1] + " " + current[1]]
        else:
            merged.append(current)

    return merged


def create_pretty_table(data):
    """Render [label, text] rows as simple colored HTML paragraphs."""
    table = "<div>"
    for row in data:
        color = (
            "blue"
            if row[0] == "Header"
            else "green" if row[0] == "Section" else "black"
        )
        table += "<p style='color:{};'>---{}---</p>{}".format(color, row[0], row[1])
    table += "</div>"
    return table


# When using this function in Gradio, set the output type to 'html'


def interference(example, page_number=0):
    """Run LayoutLMv3 token classification on one PDF page; return the annotated image and an HTML summary."""
    image, words, boxes = extract_data_from_pdf(example, page_number)
    boxes = [list(map(int, box)) for box in boxes]

    # Load the fine-tuned model and the base processor (OCR disabled; boxes come from PyMuPDF above)
    model = AutoModelForTokenClassification.from_pretrained("karida/LayoutLMv3_RFP")
    processor = AutoProcessor.from_pretrained(
        "microsoft/layoutlmv3-base", apply_ocr=False
    )
    encoding = processor(image, words, boxes=boxes, return_tensors="pt")

    # Prediction
    with torch.no_grad():
        outputs = model(**encoding)

    logits = outputs.logits
    predictions = logits.argmax(-1).squeeze().tolist()
    model_words = encoding.word_ids()

    # Process predictions
    token_boxes = encoding.bbox.squeeze().tolist()
    width, height = image.size

    true_predictions = [model.config.id2label[pred] for pred in predictions]
    true_boxes = token_boxes

    # Draw annotations on the image
    draw = ImageDraw.Draw(image)
    font = ImageFont.load_default()

    def iob_to_label(label):
        # Strip the IOB prefix ("B-", "I-"); empty labels map to "other"
        label = label[2:]
        return "other" if not label else label.lower()

    label2color = {
        "question": "blue",
        "answer": "green",
        "header": "orange",
        "other": "violet",
    }

    # print(len(true_predictions), len(true_boxes), len(model_words))

    table = []
    ids = set()

    for prediction, box, model_word in zip(true_predictions, true_boxes, model_words):
        predicted_label = iob_to_label(prediction)
        draw.rectangle(box, outline=label2color[predicted_label], width=2)
        # draw.text((box[0] + 10, box[1] - 10), text=predicted_label, fill=label2color[predicted_label], font=font)
        # word_ids() yields None for special tokens; keep only the first prediction per word
        if model_word is not None and model_word not in ids and predicted_label != "other":
            ids.add(model_word)
            table.append([predicted_label[0], words[model_word]])

    values = merge_pairs_v2(table)
    values = [["Header", x[1]] if x[0] == "q" else ["Section", x[1]] for x in values]
    table = create_pretty_table(values)
    return image, table


import gradio as gr

description_text = """
<p>
Heading - <span style='color: blue;'>shown in blue</span><br>
Section - <span style='color: green;'>shown in green</span><br>
other (ignored) - <span style='color: violet;'>shown in violet</span>
</p>
"""

flagging_options = ["great example", "bad example"]


iface = gr.Interface(
    fn=interference,
    inputs=["file", "number"],
    outputs=["image", "html"],
    # examples=[["output.pdf", 1]],
    description=description_text,
    flagging_options=flagging_options,
)
# iface.save(".")
if __name__ == "__main__":
    iface.launch()
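Not part of the commit itself, but as a quick way to sanity-check the new module outside the Gradio UI, a minimal sketch along these lines could be used (the input path "sample.pdf" is a hypothetical placeholder; everything else refers to names defined in main.py above):

# Minimal local check of the new helpers; "sample.pdf" is a hypothetical input path.
from main import extract_data_from_pdf, interference

# PDF parsing only: page image plus word strings and their bounding boxes.
image, words, boxes = extract_data_from_pdf("sample.pdf", page_number=0)
print(len(words), "words on page 0; first few:", words[:5])

# Full pipeline: annotated page image and the HTML heading/section summary.
annotated, html_table = interference("sample.pdf", page_number=0)
annotated.save("annotated_page0.png")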