Spaces:
Running
on
T4
Running
on
T4
debug print
Browse files
src/htr_pipeline/inferencer.py
CHANGED
@@ -26,25 +26,51 @@ class Inferencer:
|
|
26 |
|
27 |
@timer_func
|
28 |
def predict_regions(self, input_image, pred_score_threshold=0.5, containments_threshold=0.5, visualize=True):
|
|
|
|
|
|
|
|
|
29 |
input_image = self.preprocess_img.binarize_img(input_image)
|
30 |
|
31 |
image = mmcv.imread(input_image)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
result = self.seg_model(image, return_datasample=True)
|
33 |
result_pred = result["predictions"][0]
|
|
|
|
|
|
|
|
|
|
|
34 |
|
35 |
filtered_result_pred = self.postprocess_seg_mask.filter_on_pred_threshold(
|
36 |
result_pred, pred_score_threshold=pred_score_threshold
|
37 |
)
|
38 |
|
|
|
|
|
|
|
|
|
39 |
if len(filtered_result_pred.pred_instances.masks) == 0:
|
40 |
raise gr.Error("No Regions were predicted by the model")
|
41 |
|
42 |
else:
|
|
|
|
|
43 |
result_align = self.process_seg_mask.align_masks_with_image(filtered_result_pred, image)
|
44 |
result_clean = self.postprocess_seg_mask.remove_overlapping_masks(
|
45 |
predicted_mask=result_align, containments_threshold=containments_threshold
|
46 |
)
|
47 |
|
|
|
|
|
|
|
|
|
48 |
if visualize:
|
49 |
result_viz = self.seg_model.visualize(
|
50 |
inputs=[image], preds=[result_clean], return_vis=True, no_save_vis=True
|
@@ -52,6 +78,8 @@ class Inferencer:
|
|
52 |
else:
|
53 |
result_viz = None
|
54 |
|
|
|
|
|
55 |
regions_cropped, polygons = self.process_seg_mask.crop_masks(result_clean, image)
|
56 |
order = self.ordering.order_regions_marginalia(result_clean)
|
57 |
|
@@ -59,6 +87,10 @@ class Inferencer:
|
|
59 |
polygons_ordered = [polygons[i] for i in order]
|
60 |
masks_ordered = [result_clean.pred_instances.masks[i] for i in order]
|
61 |
|
|
|
|
|
|
|
|
|
62 |
return result_viz, regions_cropped_ordered, polygons_ordered, masks_ordered
|
63 |
|
64 |
@timer_func
|
|
|
26 |
|
27 |
@timer_func
|
28 |
def predict_regions(self, input_image, pred_score_threshold=0.5, containments_threshold=0.5, visualize=True):
|
29 |
+
import time
|
30 |
+
|
31 |
+
t1 = time.time()
|
32 |
+
|
33 |
input_image = self.preprocess_img.binarize_img(input_image)
|
34 |
|
35 |
image = mmcv.imread(input_image)
|
36 |
+
|
37 |
+
t2 = time.time()
|
38 |
+
|
39 |
+
print(f"Function executed bin and read in {(t2-t1):.4f}s")
|
40 |
+
|
41 |
+
t1 = time.time()
|
42 |
+
|
43 |
result = self.seg_model(image, return_datasample=True)
|
44 |
result_pred = result["predictions"][0]
|
45 |
+
t2 = time.time()
|
46 |
+
|
47 |
+
print(f"Function executed predict in {(t2-t1):.4f}s")
|
48 |
+
|
49 |
+
t1 = time.time()
|
50 |
|
51 |
filtered_result_pred = self.postprocess_seg_mask.filter_on_pred_threshold(
|
52 |
result_pred, pred_score_threshold=pred_score_threshold
|
53 |
)
|
54 |
|
55 |
+
t2 = time.time()
|
56 |
+
|
57 |
+
print(f"Function executed filter in {(t2-t1):.4f}s")
|
58 |
+
|
59 |
if len(filtered_result_pred.pred_instances.masks) == 0:
|
60 |
raise gr.Error("No Regions were predicted by the model")
|
61 |
|
62 |
else:
|
63 |
+
t1 = time.time()
|
64 |
+
|
65 |
result_align = self.process_seg_mask.align_masks_with_image(filtered_result_pred, image)
|
66 |
result_clean = self.postprocess_seg_mask.remove_overlapping_masks(
|
67 |
predicted_mask=result_align, containments_threshold=containments_threshold
|
68 |
)
|
69 |
|
70 |
+
t2 = time.time()
|
71 |
+
|
72 |
+
print(f"Function executed align and remove in {(t2-t1):.4f}s")
|
73 |
+
|
74 |
if visualize:
|
75 |
result_viz = self.seg_model.visualize(
|
76 |
inputs=[image], preds=[result_clean], return_vis=True, no_save_vis=True
|
|
|
78 |
else:
|
79 |
result_viz = None
|
80 |
|
81 |
+
t1 = time.time()
|
82 |
+
|
83 |
regions_cropped, polygons = self.process_seg_mask.crop_masks(result_clean, image)
|
84 |
order = self.ordering.order_regions_marginalia(result_clean)
|
85 |
|
|
|
87 |
polygons_ordered = [polygons[i] for i in order]
|
88 |
masks_ordered = [result_clean.pred_instances.masks[i] for i in order]
|
89 |
|
90 |
+
t2 = time.time()
|
91 |
+
|
92 |
+
print(f"Function executed crop and margin in {(t2-t1):.4f}s")
|
93 |
+
|
94 |
return result_viz, regions_cropped_ordered, polygons_ordered, masks_ordered
|
95 |
|
96 |
@timer_func
|
src/htr_pipeline/utils/pipeline_inferencer.py
CHANGED
@@ -6,6 +6,8 @@ from src.htr_pipeline.utils.xml_helper import XMLHelper
|
|
6 |
|
7 |
terminate = False
|
8 |
|
|
|
|
|
9 |
|
10 |
class PipelineInferencer:
|
11 |
def __init__(self, process_seg_mask: SegMaskHelper, xml_helper: XMLHelper):
|
@@ -29,7 +31,7 @@ class PipelineInferencer:
|
|
29 |
pred_score_threshold_regions,
|
30 |
pred_score_threshold_lines,
|
31 |
containments_threshold,
|
32 |
-
htr_threshold=0.
|
33 |
):
|
34 |
global terminate
|
35 |
|
@@ -77,7 +79,7 @@ class PipelineInferencer:
|
|
77 |
return region_data if mean_htr_score > htr_threshold + 0.1 else None
|
78 |
|
79 |
def _process_lines(
|
80 |
-
self, text_region, inferencer, pred_score_threshold, containments_threshold, mask, region_id, htr_threshold=0.
|
81 |
):
|
82 |
_, lines_cropped_ordered, line_polygons_ordered = inferencer.predict_lines(
|
83 |
text_region, pred_score_threshold, containments_threshold, visualize=False, custom_track=False
|
|
|
6 |
|
7 |
terminate = False
|
8 |
|
9 |
+
# TODO: check why region is so slow to start. Is there an error with loading the model?
|
10 |
+
|
11 |
|
12 |
class PipelineInferencer:
|
13 |
def __init__(self, process_seg_mask: SegMaskHelper, xml_helper: XMLHelper):
|
|
|
31 |
pred_score_threshold_regions,
|
32 |
pred_score_threshold_lines,
|
33 |
containments_threshold,
|
34 |
+
htr_threshold=0.6,
|
35 |
):
|
36 |
global terminate
|
37 |
|
|
|
79 |
return region_data if mean_htr_score > htr_threshold + 0.1 else None
|
80 |
|
81 |
def _process_lines(
|
82 |
+
self, text_region, inferencer, pred_score_threshold, containments_threshold, mask, region_id, htr_threshold=0.6
|
83 |
):
|
84 |
_, lines_cropped_ordered, line_polygons_ordered = inferencer.predict_lines(
|
85 |
text_region, pred_score_threshold, containments_threshold, visualize=False, custom_track=False
|
tabs/htr_tool.py
CHANGED
@@ -86,7 +86,7 @@ with gr.Blocks() as htr_tool_tab:
|
|
86 |
)
|
87 |
with gr.Row():
|
88 |
gr.Slider(
|
89 |
-
value=0.
|
90 |
minimum=0.5,
|
91 |
maximum=1,
|
92 |
label="HTR threshold",
|
@@ -94,7 +94,7 @@ with gr.Blocks() as htr_tool_tab:
|
|
94 |
scale=1,
|
95 |
)
|
96 |
gr.Slider(
|
97 |
-
value=0.
|
98 |
minimum=0.6,
|
99 |
maximum=1,
|
100 |
label="Avg threshold",
|
@@ -105,7 +105,7 @@ with gr.Blocks() as htr_tool_tab:
|
|
105 |
htr_tool_region_segment_model_dropdown = gr.Dropdown(
|
106 |
choices=["Riksarkivet/RmtDet_region"],
|
107 |
value="Riksarkivet/RmtDet_region",
|
108 |
-
label="Region
|
109 |
info="Will add more models later!",
|
110 |
)
|
111 |
|
@@ -113,15 +113,15 @@ with gr.Blocks() as htr_tool_tab:
|
|
113 |
htr_tool_line_segment_model_dropdown = gr.Dropdown(
|
114 |
choices=["Riksarkivet/RmtDet_lines"],
|
115 |
value="Riksarkivet/RmtDet_lines",
|
116 |
-
label="Line
|
117 |
info="Will add more models later!",
|
118 |
)
|
119 |
|
120 |
htr_tool_transcriber_model_dropdown = gr.Dropdown(
|
121 |
choices=["Riksarkivet/SATRN_transcriber", "microsoft/trocr-base-handwritten"],
|
122 |
value="Riksarkivet/SATRN_transcriber",
|
123 |
-
label="
|
124 |
-
info="
|
125 |
)
|
126 |
|
127 |
with gr.Column(scale=2):
|
|
|
86 |
)
|
87 |
with gr.Row():
|
88 |
gr.Slider(
|
89 |
+
value=0.6,
|
90 |
minimum=0.5,
|
91 |
maximum=1,
|
92 |
label="HTR threshold",
|
|
|
94 |
scale=1,
|
95 |
)
|
96 |
gr.Slider(
|
97 |
+
value=0.7,
|
98 |
minimum=0.6,
|
99 |
maximum=1,
|
100 |
label="Avg threshold",
|
|
|
105 |
htr_tool_region_segment_model_dropdown = gr.Dropdown(
|
106 |
choices=["Riksarkivet/RmtDet_region"],
|
107 |
value="Riksarkivet/RmtDet_region",
|
108 |
+
label="Region Segment models",
|
109 |
info="Will add more models later!",
|
110 |
)
|
111 |
|
|
|
113 |
htr_tool_line_segment_model_dropdown = gr.Dropdown(
|
114 |
choices=["Riksarkivet/RmtDet_lines"],
|
115 |
value="Riksarkivet/RmtDet_lines",
|
116 |
+
label="Line Segment models",
|
117 |
info="Will add more models later!",
|
118 |
)
|
119 |
|
120 |
htr_tool_transcriber_model_dropdown = gr.Dropdown(
|
121 |
choices=["Riksarkivet/SATRN_transcriber", "microsoft/trocr-base-handwritten"],
|
122 |
value="Riksarkivet/SATRN_transcriber",
|
123 |
+
label="Transcriber models",
|
124 |
+
info="Models will be continuously updated with future additions for specific cases.",
|
125 |
)
|
126 |
|
127 |
with gr.Column(scale=2):
|