Spaces:
Runtime error
Runtime error
feat: application ready
Browse files
app.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import streamlit as st
|
3 |
+
|
4 |
+
from PIL import Image
|
5 |
+
from io import BytesIO
|
6 |
+
from transformers import VisionEncoderDecoderModel, VisionEncoderDecoderConfig , DonutProcessor
|
7 |
+
|
8 |
+
|
9 |
+
def run_prediction(sample):
    """Run the Donut model on one sample and return the parsed output.

    Relies on the module-level ``pretrained_model``, ``processor``,
    ``task_prompt`` and ``device`` having been assigned before the call.

    Args:
        sample: Either a dict carrying ``"pixel_values"`` and
            ``"target_sequence"`` entries, or an image that ``processor``
            can encode directly.

    Returns:
        A ``(prediction, target)`` tuple. ``prediction`` is the output of
        ``processor.token2json`` applied to the decoded sequence. ``target``
        is ``processor.token2json(sample["target_sequence"])`` for dict
        samples and the placeholder string ``"<not_provided>"`` otherwise.
    """
    global pretrained_model, processor, task_prompt, device

    has_reference = isinstance(sample, dict)

    # prepare encoder inputs
    if has_reference:
        pixel_values = torch.tensor(sample["pixel_values"]).unsqueeze(0)
    else:
        # Encode the argument itself; the original read the module-level
        # ``image`` here, silently ignoring whatever the caller passed in.
        pixel_values = processor(sample, return_tensors="pt").pixel_values

    # prepare decoder inputs: the task prompt seeds generation
    decoder_input_ids = processor.tokenizer(
        task_prompt, add_special_tokens=False, return_tensors="pt"
    ).input_ids

    # run inference
    outputs = pretrained_model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=pretrained_model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    # process output
    prediction = processor.batch_decode(outputs.sequences)[0]

    # post-processing: EOS/PAD tokens are only stripped for "cord" prompts.
    # NOTE(review): stripping the leading task-start token was disabled in
    # the original code; it is left disabled here.
    if "cord" in task_prompt:
        prediction = prediction.replace(processor.tokenizer.eos_token, "")
        prediction = prediction.replace(processor.tokenizer.pad_token, "")
    prediction = processor.token2json(prediction)

    # load reference target
    if has_reference:
        target = processor.token2json(sample["target_sequence"])
    else:
        target = "<not_provided>"

    return prediction, target
|
50 |
+
|
51 |
+
|
52 |
+
# Default decoder prompt; the "Parse sample" handler may rebind it per model.
task_prompt = "<s>"

# Logo display is currently disabled.
# logo = Image.open("./img/rsz_unstructured_logo.png")
# st.image(logo)

st.markdown('''
### Donut Common Crawl
Experimental OCR-free Document Understanding Vision Transformer nicknamed π©, fine-tuned with few samples of the common-crawl with some specific document elements.
''')

# Sidebar controls: which predictor to use and which bundled sample to parse.
with st.sidebar:
    information = st.radio(
        "Choose one predictor:?",
        ('Base Common-Crawl π©', 'Hierarchical Common-Crawl π©'))
    image_choice = st.selectbox('Pick one π', ['1', '2', '3'], index=1)

st.text(f'{information} mode is ON!\nTarget π: {image_choice}')

# Left column shows the input image, right column the parsed result.
col1, col2 = st.columns(2)

# The sample images ship in the ``samples/`` directory of the repository, so
# the paths must include that prefix; bare filenames fail in ``Image.open``.
image_choice_map = {
    '1': 'samples/commoncrawl_amandalacombznewspolice-bust-man-sawed-oal_1.jpg',
    '2': 'samples/commoncrawl_canyonhillschroniclecomtagwomens-basketbll_0.png',
    '3': 'samples/commoncrawl_celstuttgartdeideaa-different-stort-of-nfe_0.png',
}
image = Image.open(image_choice_map[image_choice])
with col1:
    st.image(image, caption='Your target sample')
|
80 |
+
|
81 |
+
# Run the selected predictor on the chosen sample when the button is clicked.
if st.button('Parse sample! π'):
    # Normalise the sample by round-tripping it through an RGB JPEG on disk.
    image = image.convert('RGB')
    image.save('./target_image.jpg')
    image = Image.open('./target_image.jpg')
    with st.spinner('baking the π©s...'):
        # Only run inference once a model has actually been loaded; the
        # original called run_prediction() unconditionally, which raised a
        # NameError for the not-yet-implemented hierarchical predictor.
        model_ready = False

        if information == 'Base Common-Crawl π©':
            processor = DonutProcessor.from_pretrained(".")  # laverdes/donut-commoncrawl
            pretrained_model = VisionEncoderDecoderModel.from_pretrained("checkpoints")  # laverdes/donut-commoncrawl
            task_prompt = "<s>"
            device = "cuda" if torch.cuda.is_available() else "cpu"
            pretrained_model.to(device)
            model_ready = True

        elif information == 'Hierarchical Common-Crawl π©':
            st.info("Not implemented yet...")

        if model_ready:
            with col2:
                st.info('parsing π...')
                parsed_info, _ = run_prediction(image)
                st.text(f'\n{information}')
                st.json(parsed_info)
|
requirements.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
streamlit
|
2 |
+
torch
|
3 |
+
transformers
|
4 |
+
sentencepiece
|
samples/commoncrawl_amandalacombznewspolice-bust-man-sawed-oal_1.jpg
ADDED
samples/commoncrawl_canyonhillschroniclecomtagwomens-basketbll_0.png
ADDED
samples/commoncrawl_celstuttgartdeideaa-different-stort-of-nfe_0.png
ADDED