File size: 5,544 Bytes
d812fea
 
640c986
 
 
d840e96
63f1335
640c986
 
 
 
d812fea
640c986
 
 
 
 
 
 
 
 
d06a6f3
 
 
640c986
383e64a
640c986
 
 
 
 
383e64a
 
640c986
383e64a
cdd9cb4
 
640c986
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d812fea
 
 
 
 
 
 
 
 
 
 
 
640c986
 
 
 
 
 
 
d812fea
 
 
640c986
 
 
 
 
cdd9cb4
 
 
 
 
 
 
 
 
 
640c986
 
 
 
cdd9cb4
640c986
 
 
 
 
cdd9cb4
640c986
d812fea
 
 
 
84f3cde
 
 
 
 
 
9bc57af
84f3cde
 
 
 
 
63f1335
84f3cde
d812fea
640c986
63f1335
 
 
ae0acef
d6bde52
63f1335
640c986
 
c00e3bb
 
fd26dfb
640c986
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
# import cv2
# import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import torch
import json
import base64

from doctr.io import DocumentFile
from doctr.utils.visualization import visualize_page

from backend.pytorch import DET_ARCHS, RECO_ARCHS, load_predictor #forward_image

# Device used for inference: the first CUDA GPU when one is visible to torch, else the CPU.
forward_device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")


def _load_document(uploaded_file):
    """Decode an uploaded Streamlit file into docTR pages (PDF or image).

    The extension test is case-insensitive so that e.g. ``SCAN.PDF`` goes to
    the PDF reader instead of being handed to the image decoder.
    """
    data = uploaded_file.read()
    if uploaded_file.name.lower().endswith(".pdf"):
        return DocumentFile.from_pdf(data)
    return DocumentFile.from_images(data)


def _extract_text(page_export):
    """Flatten an exported docTR page (``blocks -> lines -> words``) to plain text.

    Words of one OCR line are joined by single spaces, and lines from all
    blocks, in export order, are joined by newlines.
    """
    return "\n".join(
        " ".join(word["value"] for word in line["words"])
        for block in page_export["blocks"]
        for line in block["lines"]
    )


def _json_download_link(page_export):
    """Return a base64 ``data:`` URI whose payload is *page_export* serialized as JSON."""
    encoded = base64.b64encode(json.dumps(page_export).encode("utf-8")).decode("utf-8")
    return f"data:file/txt;base64,{encoded}"


def main(det_archs, reco_archs):
    """Build the Streamlit OCR layout and run the selected docTR predictor on demand.

    Args:
        det_archs: text-detection architecture names offered in the sidebar.
        reco_archs: text-recognition architecture names offered in the sidebar.
    """
    # Wide mode (issued before any other Streamlit call in this function).
    st.set_page_config(layout="wide")

    st.markdown("Used Github Actions to automatically build the app on any updates on this [github repo link](https://github.com/deepanshu2207/imgtotxt_using_DocTR)")
    st.caption("Made with ❤️ by Deepanshu. Credits to 🤗 Spaces for Hosting this.")

    # Designing the interface
    st.title("Document Text Extraction")
    st.write("\n")
    st.markdown("*Hint: click on the top-right corner of an image to enlarge it!*")
    # NOTE: a "Segmentation heatmap" column and sidebar parameter controls
    # existed previously and are disabled; see version control to restore them.
    cols = st.columns((1, 1, 1))
    cols[0].subheader("Input page")
    cols[1].subheader("OCR output")
    cols[2].subheader("Page reconstitution")

    # Sidebar: document selection
    st.sidebar.title("Document selection")
    uploaded_file = st.sidebar.file_uploader("Upload files", type=["pdf", "png", "jpeg", "jpg"])
    page = None
    if uploaded_file is not None:
        doc = _load_document(uploaded_file)
        page_idx = st.sidebar.selectbox("Page selection", [idx + 1 for idx in range(len(doc))]) - 1
        page = doc[page_idx]
        cols[0].image(page)

    # Sidebar: model selection
    st.sidebar.title("Model selection")
    det_arch = st.sidebar.selectbox("Text detection model", det_archs)
    reco_arch = st.sidebar.selectbox("Text recognition model", reco_archs)

    if not st.sidebar.button("Analyze page"):
        return
    if uploaded_file is None:
        st.sidebar.write("Please upload a document")
        return

    # Fixed defaults while the sidebar parameter controls are disabled.
    assume_straight_pages, straighten_pages, bin_thresh = True, False, 0.3

    with st.spinner("Loading model..."):
        predictor = load_predictor(
            det_arch, reco_arch, assume_straight_pages, straighten_pages, bin_thresh, forward_device
        )

    with st.spinner("Analyzing..."):
        out = predictor([page])
        # Export once; the same dict feeds the overlay, the text and the JSON views.
        page_export = out.pages[0].export()

        # Plot OCR output
        fig = visualize_page(page_export, out.pages[0].page, interactive=False, add_labels=False)
        cols[1].pyplot(fig)

        # Page reconstitution (equivalent to the former
        # `a or (not a and b)` condition).
        if assume_straight_pages or straighten_pages:
            cols[2].image(out.pages[0].synthesize(), clamp=True)

        # Display text. st.text keeps line breaks and does not interpret OCR
        # output as Markdown (st.write collapsed single newlines).
        # The former print() calls were removed: they dumped every uploaded
        # document's content to the server log.
        st.markdown("\n### **Here is your text:**")
        st.text(_extract_text(page_export))

        # Display JSON
        st.markdown("\n### **Here is your document structure in JSON format:**")
        st.markdown(f"[Download JSON]( {_json_download_link(page_export)} )", unsafe_allow_html=True)
        st.json(page_export, expanded=False)

    st.success('Done!')
    st.balloons()
    

if __name__ == "__main__":
    # Offer every detection / recognition architecture exported by backend.pytorch.
    main(det_archs=DET_ARCHS, reco_archs=RECO_ARCHS)