Refactor: redid the UI of the fast track
- app.py +1 -1
- helper/gradio_config.py +1 -0
- src/htr_pipeline/gradio_backend.py +30 -13
- src/htr_pipeline/pipeline.py +2 -2
- src/htr_pipeline/utils/pipeline_inferencer.py +11 -1
- src/htr_pipeline/utils/visualize_xml.py +11 -45
- tabs/htr_tool.py +145 -34
app.py
CHANGED
@@ -118,7 +118,7 @@ print(job.result())
 # demo.load(None, None, None, _js=js)


-demo.queue(concurrency_count=
+demo.queue(concurrency_count=2, max_size=2)


 if __name__ == "__main__":
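The only functional change in app.py caps the request queue: concurrency_count=2 allows at most two jobs to run at the same time, and max_size=2 turns away new requests once two are already waiting. A minimal sketch of the same setup in a stand-alone Gradio 3.x app (the echo handler and components below are placeholders, not code from this repo):

import gradio as gr

def echo(text: str) -> str:
    # placeholder handler; the real app runs the HTR pipeline here
    return text

with gr.Blocks() as demo:
    inp = gr.Textbox(label="Input")
    out = gr.Textbox(label="Output")
    inp.submit(echo, inputs=inp, outputs=out)

# at most 2 jobs execute concurrently; at most 2 more may wait in the queue
demo.queue(concurrency_count=2, max_size=2)

if __name__ == "__main__":
    demo.launch()

On a single shared GPU this keeps a slow HTR job from accumulating an unbounded backlog of waiting requests.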
helper/gradio_config.py
CHANGED
@@ -22,6 +22,7 @@ class GradioConfig:
     #gallery {height: 400px}
     .fixed-height.svelte-g4rw9.svelte-g4rw9 {min-height: 400px;}

+    #download_file > div.empty.svelte-lk9eg8.large.unpadded_box {min-height: 100px;}
     #gallery_lines > div.preview.svelte-1b19cri > div.thumbnails.scroll-hide.svelte-1b19cri {display: none;}

     """
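The added CSS rule gives the empty download box (elem_id="download_file", created as fast_file_downlod in tabs/htr_tool.py) a 100px minimum height so the layout does not jump when a file appears. The selectors depend on Gradio's generated Svelte class names and are therefore version-specific. How GradioConfig injects this string is not part of the diff; a minimal sketch assuming the standard gr.Blocks(css=...) pattern:

import gradio as gr

custom_css = """
#download_file > div.empty.svelte-lk9eg8.large.unpadded_box {min-height: 100px;}
"""

# the css string is injected into the page when the Blocks app is built
with gr.Blocks(css=custom_css) as demo:
    gr.File(label="Download output file", elem_id="download_file")

if __name__ == "__main__":
    demo.launch()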
src/htr_pipeline/gradio_backend.py
CHANGED
@@ -1,6 +1,8 @@
 import os

+import cv2
 import gradio as gr
+import numpy as np
 import pandas as pd

 from src.htr_pipeline.inferencer import Inferencer, InferencerInterface
@@ -23,13 +25,21 @@ class SingletonModelLoader:
         self.pipeline = Pipeline(self.inferencer)


+def handling_callback_stop_inferencer():
+    from src.htr_pipeline.utils import pipeline_inferencer
+
+    pipeline_inferencer.terminate = False
+
+
 # fast track
 class FastTrack:
     def __init__(self, model_loader):
         self.pipeline: PipelineInterface = model_loader.pipeline

     def segment_to_xml(self, image, radio_button_choices):
-
+        handling_callback_stop_inferencer()
+
+        gr.Info("Excuting HTR on image")
         xml_xml = "page_xml.xml"
         xml_txt = "page_txt.txt"

@@ -41,10 +51,18 @@ class FastTrack:
         with open(xml_xml, "w") as f:
             f.write(rendered_xml)

-
+        if os.path.exists(f"./{xml_txt}"):
+            os.remove(f"./{xml_txt}")
+
+        self.pipeline.parse_xml_to_txt()
+
         returned_file_extension = self.file_extenstion_to_return(radio_button_choices, xml_xml, xml_txt)

-        return
+        return returned_file_extension, gr.update(visible=True)
+
+    def visualize_image_viewer(self, image):
+        xml_img, text_polygon_dict = self.pipeline.visualize_xml(image)
+        return xml_img, text_polygon_dict

     def file_extenstion_to_return(self, radio_button_choices, xml_xml, xml_txt):
         if len(radio_button_choices) < 2:
@@ -56,20 +74,19 @@ class FastTrack:
             returned_file_extension = [xml_txt, xml_xml]
         return returned_file_extension

+    def get_text_from_coords(self, text_polygon_dict, evt: gr.SelectData):
+        x, y = evt.index[0], evt.index[1]
+
+        for text, polygon_coords in text_polygon_dict.items():
+            if (
+                cv2.pointPolygonTest(np.array(polygon_coords), (x, y), False) >= 0
+            ):  # >= 0 means on the polygon or inside
+                return text
+
     def segment_to_xml_api(self, image):
         rendered_xml = self.pipeline.running_htr_pipeline(image)
         return rendered_xml

-    def visualize_xml_and_return_txt(self, img, xml_txt):
-        xml_img = self.pipeline.visualize_xml(img)
-
-        if os.path.exists(f"./{xml_txt}"):
-            os.remove(f"./{xml_txt}")
-
-        self.pipeline.parse_xml_to_txt()
-
-        return xml_img
-

 # Custom track
 class CustomTrack:
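get_text_from_coords is the piece that makes the new Image Viewer interactive: Gradio passes the clicked pixel through gr.SelectData (evt.index), and cv2.pointPolygonTest reports whether that point lies inside any of the stored text-line polygons, returning the matching transcription. A minimal self-contained sketch of the same lookup outside Gradio (the polygons and texts are made-up test data):

import cv2
import numpy as np

# maps transcribed text -> polygon points, as produced by XmlViz.visualize_xml
text_polygon_dict = {
    "first line": [(10, 10), (110, 10), (110, 40), (10, 40)],
    "second line": [(10, 60), (110, 60), (110, 90), (10, 90)],
}

def text_at(x: int, y: int):
    for text, polygon_coords in text_polygon_dict.items():
        # measureDist=False -> returns +1 inside, 0 on the edge, -1 outside
        contour = np.array(polygon_coords, dtype=np.int32)
        if cv2.pointPolygonTest(contour, (float(x), float(y)), False) >= 0:
            return text
    return None  # click landed outside every polygon

print(text_at(50, 25))   # "first line"
print(text_at(50, 200))  # None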
src/htr_pipeline/pipeline.py
CHANGED
@@ -40,8 +40,8 @@ class Pipeline:
     def visualize_xml(self, input_image: np.ndarray) -> np.ndarray:
         xml_viz = XmlViz()
         bin_input_image = self.preprocess_img.binarize_img(input_image)
-        xml_image = xml_viz.visualize_xml(bin_input_image)
-        return xml_image
+        xml_image, text_polygon_dict = xml_viz.visualize_xml(bin_input_image)
+        return xml_image, text_polygon_dict

     @timer_func
     def parse_xml_to_txt(self) -> None:
src/htr_pipeline/utils/pipeline_inferencer.py
CHANGED
@@ -4,6 +4,8 @@ from tqdm import tqdm
 from src.htr_pipeline.utils.process_segmask import SegMaskHelper
 from src.htr_pipeline.utils.xml_helper import XMLHelper

+terminate = False
+

 class PipelineInferencer:
     def __init__(self, process_seg_mask: SegMaskHelper, xml_helper: XMLHelper):
@@ -29,6 +31,8 @@ class PipelineInferencer:
         containments_threshold,
         htr_threshold=0.7,
     ):
+        global terminate
+
         _, regions_cropped_ordered, reg_polygons_ordered, reg_masks_ordered = inferencer.predict_regions(
             image,
             pred_score_threshold=pred_score_threshold_regions,
@@ -38,6 +42,8 @@ class PipelineInferencer:
         gr.Info(f"Found {len(regions_cropped_ordered)} Regions to parse")
         region_data_list = []
         for i, data in tqdm(enumerate(zip(regions_cropped_ordered, reg_polygons_ordered, reg_masks_ordered))):
+            if terminate:
+                break
             region_data = self._create_region_data(
                 data, i, inferencer, pred_score_threshold_lines, containments_threshold, htr_threshold
             )
@@ -68,7 +74,7 @@ class PipelineInferencer:
         region_data["textLines"] = text_lines
         mean_htr_score = sum(htr_scores) / len(htr_scores) if htr_scores else 0

-        return region_data if mean_htr_score > htr_threshold else None
+        return region_data if mean_htr_score > htr_threshold + 0.1 else None

     def _process_lines(
         self, text_region, inferencer, pred_score_threshold, containments_threshold, mask, region_id, htr_threshold=0.7
@@ -90,7 +96,11 @@ class PipelineInferencer:

         gr.Info(f" Region {id_number}, found {total_lines_len} lines to parse and transcribe.")

+        global terminate
+
         for index, (line, line_pol) in enumerate(zip(lines_cropped_ordered, line_polygons_ordered_trans)):
+            if terminate:
+                break
             line_data, htr_score = self._create_line_data(line, line_pol, index, region_id, inferencer, htr_threshold)

             if line_data:
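Two behavioural changes land here. First, a module-level terminate flag gives the UI a way to stop a run part-way: the region and line loops check it on every iteration and break, and handling_callback_stop_inferencer in gradio_backend.py resets it to False before the next run. Second, the region filter is stricter, keeping a region only when its mean line HTR score exceeds htr_threshold + 0.1 rather than htr_threshold. The cancellation is cooperative; a minimal sketch of the flag pattern with a plain worker thread (names and timings are illustrative, not from the repo):

import threading
import time

terminate = False  # module-level flag, re-read inside the loop on every iteration

processed = []

def worker(items):
    # stands in for PipelineInferencer's region/line loops
    for item in items:
        if terminate:      # cooperative check: the loop only stops
            break          # at the next iteration boundary
        time.sleep(0.05)   # placeholder for slow inference work
        processed.append(item)

t = threading.Thread(target=worker, args=(range(100),))
t.start()

time.sleep(0.2)
terminate = True  # what the "Stop HTR" handler does via pipeline_inferencer.terminate = True
t.join()
print(f"processed {len(processed)} items before the stop flag was seen")

terminate = False  # reset before the next run, like handling_callback_stop_inferencer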
src/htr_pipeline/utils/visualize_xml.py
CHANGED
@@ -1,8 +1,8 @@
 import random
 import xml.etree.ElementTree as ET

-import
-
+import cv2
+import numpy as np


 class XmlViz:
@@ -11,58 +11,24 @@ class XmlViz:
         self.root = self.tree.getroot()
         self.namespace = "{http://schema.primaresearch.org/PAGE/gts/pagecontent/2013-07-15}"

-    def visualize_xml(
-        self,
-        background_image,
-        font_size=9,
-        text_offset=10,
-        font_path_tff="./src/htr_pipeline/utils/templates/arial.ttf",
-    ):
-        image = Image.fromarray(background_image).convert("RGBA")
-
-        text_offset = -text_offset
-        base_font_size = font_size
-        font_path = font_path_tff
-
-        max_bbox_width = 0  # Initialize maximum bounding box width
-        gr.Info("Parsing XML to visualize the data.")
-        for textregion in self.root.findall(f".//{self.namespace}TextRegion"):
-            coords = textregion.find(f"{self.namespace}Coords").attrib["points"].split()
-            points = [tuple(map(int, point.split(","))) for point in coords]
-            x_coords, y_coords = zip(*points)
-            min_x, max_x = min(x_coords), max(x_coords)
-            bbox_width = max_x - min_x  # Width of the current bounding box
-            max_bbox_width = max(max_bbox_width, bbox_width)  # Update maximum bounding box width
-
-        scaling_factor = max_bbox_width / 400.0  # Use maximum bounding box width for scaling
-        font_size_scaled = int(base_font_size * scaling_factor)
-        font = ImageFont.truetype(font_path, font_size_scaled)
+    def visualize_xml(self, background_image):
+        overlay = background_image.copy()
+        text_polygon_dict = {}

         for textregion in self.root.findall(f".//{self.namespace}TextRegion"):
-            fill_color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)
+            fill_color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
             for textline in textregion.findall(f".//{self.namespace}TextLine"):
                 coords = textline.find(f"{self.namespace}Coords").attrib["points"].split()
                 points = [tuple(map(int, point.split(","))) for point in coords]
-
-                poly_image = Image.new("RGBA", image.size)
-                poly_draw = ImageDraw.Draw(poly_image)
-                poly_draw.polygon(points, fill=fill_color)
+                cv2.fillPoly(overlay, [np.array(points)], fill_color)

                 text = textline.find(f"{self.namespace}TextEquiv").find(f"{self.namespace}Unicode").text
+                text_polygon_dict[text] = points

-
-
-                min_y = min(y_coords)
-                text_width, text_height = poly_draw.textsize(text, font=font)  # Get text size
-                text_position = (
-                    (min_x + max_x) // 2 - text_width // 2,
-                    min_y + text_offset,
-                )  # Center text horizontally
-
-                poly_draw.text(text_position, text, fill=(0, 0, 0), font=font)
-                image = Image.alpha_composite(image, poly_image)
+        # Blend the overlay with the original image
+        cv2.addWeighted(overlay, 0.5, background_image, 0.5, 0, background_image)

-        return
+        return background_image, text_polygon_dict


 if __name__ == "__main__":
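The PIL-based text rendering is dropped in favour of a simpler OpenCV overlay: each text-line polygon is filled with a random colour on a copy of the page, cv2.addWeighted blends that copy back at 50% opacity, and the polygons are collected in text_polygon_dict so the UI can later map clicks back to transcriptions. A minimal self-contained sketch of the blend (synthetic image and polygon, not repo data):

import cv2
import numpy as np

# white "page" and one fake text-line polygon
background_image = np.full((200, 300, 3), 255, dtype=np.uint8)
points = [(20, 30), (280, 30), (280, 70), (20, 70)]

overlay = background_image.copy()
fill_color = (0, 180, 0)  # any BGR colour

# draw the filled polygon on the copy only
cv2.fillPoly(overlay, [np.array(points, dtype=np.int32)], fill_color)

# blend copy and original 50/50; writing into background_image (last arg)
# mirrors how XmlViz.visualize_xml modifies its input in place
cv2.addWeighted(overlay, 0.5, background_image, 0.5, 0, background_image)

cv2.imwrite("overlay_demo.png", background_image)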
tabs/htr_tool.py
CHANGED
@@ -9,6 +9,8 @@ fast_track = FastTrack(model_loader)

 images_for_demo = DemoImages()

+terminate = False
+

 with gr.Blocks() as htr_tool_tab:
     with gr.Row(equal_height=True):
@@ -19,54 +21,131 @@ with gr.Blocks() as htr_tool_tab:
             )

             with gr.Row():
+                with gr.Tab("Output and Settings") as tab_output_and_setting_selector:
+                    with gr.Row():
+                        stop_htr_button = gr.Button(
+                            value="Stop HTR",
+                            variant="stop",
+                        )

+                        htr_pipeline_button = gr.Button(
+                            "Run HTR",
+                            variant="primary",
+                            visible=True,
+                            elem_id="run_pipeline_button",
+                        )
+
+                        htr_pipeline_button_api = gr.Button("Run pipeline", variant="primary", visible=False, scale=1)
+
+                    fast_file_downlod = gr.File(
+                        label="Download output file", visible=True, scale=1, height=100, elem_id="download_file"
+                    )
+
+                with gr.Tab("Image Viewer") as tab_image_viewer_selector:
+                    with gr.Row():
+                        gr.Button(
+                            value="External Image Viewer",
+                            variant="secondary",
+                            link="https://huggingface.co/spaces/Riksarkivet/Viewer_demo",
+                            interactive=True,
+                        )
+
+                        run_image_visualizer_button = gr.Button(
+                            value="Visualize results", variant="primary", interactive=True
+                        )
+
+                    selection_text_from_image_viewer = gr.Textbox(
+                        interactive=False, label="Text Selector", info="Select a mask on Image Viewer to return text"
+                    )
+
+        with gr.Column(scale=4):
+            with gr.Box():
+                with gr.Row(visible=True) as output_and_setting_tab:
+                    with gr.Column(scale=3):
+                        with gr.Row():
+                            with gr.Group():
+                                gr.Markdown(" ⚙️ Settings ")
+                                with gr.Row():
+                                    radio_file_input = gr.CheckboxGroup(
+                                        choices=["Txt", "XML"],
+                                        value=["XML"],
+                                        label="Output file extension",
+                                        # info="Only txt and page xml is supported for now!",
+                                        scale=1,
+                                    )
+                                with gr.Row():
+                                    gr.Checkbox(
+                                        value=True,
+                                        label="Binarize image",
+                                        info="Binarize image to reduce background noise",
+                                    )
+                                    gr.Checkbox(
+                                        value=True,
+                                        label="Output prediction threshold",
+                                        info="Output XML with prediction score",
+                                    )
+                                with gr.Row():
+                                    gr.Slider(
+                                        value=0.7,
+                                        minimum=0.5,
+                                        maximum=1,
+                                        label="HTR threshold",
+                                        info="Prediction score threshold for transcribed lines",
+                                        scale=1,
+                                    )
+                                    gr.Slider(
+                                        value=0.8,
+                                        minimum=0.6,
+                                        maximum=1,
+                                        label="Avg threshold",
+                                        info="Average prediction score for a region",
+                                        scale=1,
+                                    )
+
+                                htr_tool_region_segment_model_dropdown = gr.Dropdown(
+                                    choices=["Riksarkivet/RmtDet_region"],
+                                    value="Riksarkivet/RmtDet_region",
+                                    label="Region segment model",
+                                    info="Will add more models later!",
+                                )
+
+                                # with gr.Accordion("Transcribe settings:", open=False):
+                                htr_tool_line_segment_model_dropdown = gr.Dropdown(
+                                    choices=["Riksarkivet/RmtDet_lines"],
+                                    value="Riksarkivet/RmtDet_lines",
+                                    label="Line segment model",
+                                    info="Will add more models later!",
+                                )
+
+                                htr_tool_transcriber_model_dropdown = gr.Dropdown(
+                                    choices=["Riksarkivet/SATRN_transcriber", "microsoft/trocr-base-handwritten"],
+                                    value="Riksarkivet/SATRN_transcriber",
+                                    label="Transcribe model",
+                                    info="Will add more models later!",
+                                )

+                    with gr.Column(scale=2):
+                        fast_name_files_placeholder = gr.Markdown(visible=False)
                         gr.Examples(
                             examples=images_for_demo.examples_list,
                             inputs=[fast_name_files_placeholder, fast_track_input_region_image],
                             label="Example images",
                             examples_per_page=5,
                         )
-        with gr.Row():
-            gr.Markdown(
-                """
-                Image viewer for xml output:
-                <p align="center">
-                    <a href="https://huggingface.co/spaces/Riksarkivet/Viewer_demo">
-                        <img src="https://huggingface.co/datasets/huggingface/badges/raw/main/open-in-hf-spaces-xl-dark.svg" alt="Badge 1">
-                    </a>
-                </p>
-
-                """
-            )

+                with gr.Row(visible=False) as image_viewer_tab:
+                    text_polygon_dict = gr.Variable()
+
+                    fast_track_output_image = gr.Image(
+                        label="Image Viewer", type="numpy", height=600, interactive=False
+                    )

                 xml_rendered_placeholder_for_api = gr.Textbox(visible=False)
+
+    htr_event_click_event = htr_pipeline_button.click(
         fast_track.segment_to_xml,
         inputs=[fast_track_input_region_image, radio_file_input],
-        outputs=[
+        outputs=[fast_file_downlod, fast_file_downlod],
     )

     htr_pipeline_button_api.click(
@@ -75,3 +154,35 @@ with gr.Blocks() as htr_tool_tab:
         outputs=[xml_rendered_placeholder_for_api],
         api_name="predict",
     )
+
+    def update_selected_tab_output_and_setting():
+        return gr.update(visible=True), gr.update(visible=False)
+
+    def update_selected_tab_image_viewer():
+        return gr.update(visible=False), gr.update(visible=True)
+
+    tab_output_and_setting_selector.select(
+        fn=update_selected_tab_output_and_setting, outputs=[output_and_setting_tab, image_viewer_tab]
+    )
+
+    tab_image_viewer_selector.select(
+        fn=update_selected_tab_image_viewer, outputs=[output_and_setting_tab, image_viewer_tab]
+    )
+
+    def stop_function():
+        from src.htr_pipeline.utils import pipeline_inferencer
+
+        pipeline_inferencer.terminate = True
+        gr.Info("The HTR execution was halted")
+
+    stop_htr_button.click(fn=stop_function, inputs=None, outputs=None, cancels=[htr_event_click_event])
+
+    run_image_visualizer_button.click(
+        fn=fast_track.visualize_image_viewer,
+        inputs=fast_track_input_region_image,
+        outputs=[fast_track_output_image, text_polygon_dict],
+    )
+
+    fast_track_output_image.select(
+        fast_track.get_text_from_coords, inputs=text_polygon_dict, outputs=selection_text_from_image_viewer
+    )