deepsh2207 committed on
Commit
d812fea
‒
1 Parent(s): cdd9cb4

Updates in readme, parameters

Browse files
Files changed (3) hide show
  1. README.md +1 -1
  2. app.py +23 -15
  3. backend/pytorch.py +16 -16
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: docTR
3
  emoji: 📑
4
  colorFrom: purple
5
  colorTo: pink
 
1
  ---
2
+ title: Text Extractor
3
  emoji: 📑
4
  colorFrom: purple
5
  colorTo: pink
app.py CHANGED
@@ -1,5 +1,5 @@
1
- import cv2
2
- import matplotlib.pyplot as plt
3
  import numpy as np
4
  import streamlit as st
5
  import torch
@@ -7,7 +7,7 @@ import torch
7
  from doctr.io import DocumentFile
8
  from doctr.utils.visualization import visualize_page
9
 
10
- from backend.pytorch import DET_ARCHS, RECO_ARCHS, forward_image, load_predictor
11
 
12
  forward_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
13
 
@@ -50,18 +50,18 @@ def main(det_archs, reco_archs):
50
  det_arch = st.sidebar.selectbox("Text detection model", det_archs)
51
  reco_arch = st.sidebar.selectbox("Text recognition model", reco_archs)
52
 
53
- # For newline
54
- st.sidebar.write("\n")
55
- # Only straight pages or possible rotation
56
- st.sidebar.title("Parameters")
57
- assume_straight_pages = st.sidebar.checkbox("Assume straight pages", value=True)
58
- st.sidebar.write("\n")
59
- # Straighten pages
60
- straighten_pages = st.sidebar.checkbox("Straighten pages", value=False)
61
- st.sidebar.write("\n")
62
- # Binarization threshold
63
- bin_thresh = st.sidebar.slider("Binarization threshold", min_value=0.1, max_value=0.9, value=0.3, step=0.1)
64
- st.sidebar.write("\n")
65
 
66
  if st.sidebar.button("Analyze page"):
67
  if uploaded_file is None:
@@ -69,6 +69,9 @@ def main(det_archs, reco_archs):
69
 
70
  else:
71
  with st.spinner("Loading model..."):
 
 
 
72
  predictor = load_predictor(
73
  det_arch, reco_arch, assume_straight_pages, straighten_pages, bin_thresh, forward_device
74
  )
@@ -96,6 +99,11 @@ def main(det_archs, reco_archs):
96
  img = out.pages[0].synthesize()
97
  cols[2].image(img, clamp=True)
98
 
 
 
 
 
 
99
  # Display JSON
100
  st.markdown("\nHere are your analysis results in JSON format:")
101
  st.json(page_export, expanded=False)
 
1
+ # import cv2
2
+ # import matplotlib.pyplot as plt
3
  import numpy as np
4
  import streamlit as st
5
  import torch
 
7
  from doctr.io import DocumentFile
8
  from doctr.utils.visualization import visualize_page
9
 
10
+ from backend.pytorch import DET_ARCHS, RECO_ARCHS, load_predictor #forward_image
11
 
12
  forward_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
13
 
 
50
  det_arch = st.sidebar.selectbox("Text detection model", det_archs)
51
  reco_arch = st.sidebar.selectbox("Text recognition model", reco_archs)
52
 
53
+ # # For newline
54
+ # st.sidebar.write("\n")
55
+ # # Only straight pages or possible rotation
56
+ # st.sidebar.title("Parameters")
57
+ # assume_straight_pages = st.sidebar.checkbox("Assume straight pages", value=True)
58
+ # st.sidebar.write("\n")
59
+ # # Straighten pages
60
+ # straighten_pages = st.sidebar.checkbox("Straighten pages", value=False)
61
+ # st.sidebar.write("\n")
62
+ # # Binarization threshold
63
+ # bin_thresh = st.sidebar.slider("Binarization threshold", min_value=0.1, max_value=0.9, value=0.3, step=0.1)
64
+ # st.sidebar.write("\n")
65
 
66
  if st.sidebar.button("Analyze page"):
67
  if uploaded_file is None:
 
69
 
70
  else:
71
  with st.spinner("Loading model..."):
72
+ # Default Values
73
+ assume_straight_pages, straighten_pages, bin_thresh = True, False, 0.3
74
+
75
  predictor = load_predictor(
76
  det_arch, reco_arch, assume_straight_pages, straighten_pages, bin_thresh, forward_device
77
  )
 
99
  img = out.pages[0].synthesize()
100
  cols[2].image(img, clamp=True)
101
 
102
+ print('out',out)
103
+ print('\n')
104
+ print('page_export',page_export)
105
+ print('\n')
106
+
107
  # Display JSON
108
  st.markdown("\nHere are your analysis results in JSON format:")
109
  st.json(page_export, expanded=False)
backend/pytorch.py CHANGED
@@ -60,22 +60,22 @@ def load_predictor(
60
  return predictor
61
 
62
 
63
- def forward_image(predictor: OCRPredictor, image: np.ndarray, device: torch.device) -> np.ndarray:
64
- """Forward an image through the predictor
65
 
66
- Args:
67
- ----
68
- predictor: instance of OCRPredictor
69
- image: image to process
70
- device: torch.device, the device to process the image on
71
 
72
- Returns:
73
- -------
74
- segmentation map
75
- """
76
- with torch.no_grad():
77
- processed_batches = predictor.det_predictor.pre_processor([image])
78
- out = predictor.det_predictor.model(processed_batches[0].to(device), return_model_output=True)
79
- seg_map = out["out_map"].to("cpu").numpy()
80
 
81
- return seg_map
 
60
  return predictor
61
 
62
 
63
+ # def forward_image(predictor: OCRPredictor, image: np.ndarray, device: torch.device) -> np.ndarray:
64
+ # """Forward an image through the predictor
65
 
66
+ # Args:
67
+ # ----
68
+ # predictor: instance of OCRPredictor
69
+ # image: image to process
70
+ # device: torch.device, the device to process the image on
71
 
72
+ # Returns:
73
+ # -------
74
+ # segmentation map
75
+ # """
76
+ # with torch.no_grad():
77
+ # processed_batches = predictor.det_predictor.pre_processor([image])
78
+ # out = predictor.det_predictor.model(processed_batches[0].to(device), return_model_output=True)
79
+ # seg_map = out["out_map"].to("cpu").numpy()
80
 
81
+ # return seg_map