htr_demo/tabs/htr_tool.py
Gabriel's picture
test new scheduler
14d4a0b
raw
history blame
8.84 kB
import gradio as gr
from helper.examples.examples import DemoImages
from src.htr_pipeline.gradio_backend import FastTrack, SingletonModelLoader
# Model loader instance, constructed once at import time and handed to FastTrack.
model_loader = SingletonModelLoader()
# Backend object whose methods are registered as the Gradio callbacks in this file.
fast_track = FastTrack(model_loader)
# Supplies `examples_list`, which feeds the example-image gallery.
images_for_demo = DemoImages()
# NOTE(review): this flag is not read or written anywhere else in the visible part of
# this file; the Stop handler sets `pipeline_inferencer.terminate` instead. Confirm
# whether it is still needed.
terminate = False
# Gradio UI for the HTR tool; the Blocks object is bound to `htr_tool_tab`.
# NOTE(review): leading indentation is missing in this copy of the file, so the nesting
# of the `with` containers cannot be read off the text. The comments below describe only
# what each statement creates and how it is used by the handlers registered in this
# file; confirm the container hierarchy against the original source.
with gr.Blocks() as htr_tool_tab:
with gr.Row(equal_height=True):
with gr.Column(scale=2):
with gr.Row():
# Input image; used as an input by the HTR, API and visualisation handlers.
fast_track_input_region_image = gr.Image(
label="Image to run HTR on", type="numpy", tool="editor", elem_id="image_upload", height=395
)
with gr.Row():
# Selecting this tab shows the settings row and hides the image-viewer row.
with gr.Tab("Output and Settings") as tab_output_and_setting_selector:
with gr.Row():
# Stop button; its click handler cancels the running "Run HTR" event.
stop_htr_button = gr.Button(
value="Stop HTR",
variant="stop",
)
# Visible trigger for the HTR run (fast_track.segment_to_xml).
htr_pipeline_button = gr.Button(
"Run HTR",
variant="primary",
visible=True,
elem_id="run_pipeline_button",
)
# Hidden button whose click handler is registered with api_name="predict".
htr_pipeline_button_api = gr.Button("Run pipeline", variant="primary", visible=False, scale=1)
# File widget that receives the outputs of the "Run HTR" handler.
fast_file_downlod = gr.File(
label="Download output file", visible=True, scale=1, height=100, elem_id="download_file"
)
# Selecting this tab hides the settings row and shows the image-viewer row.
with gr.Tab("Image Viewer") as tab_image_viewer_selector:
with gr.Row():
# Link-only button pointing at the externally hosted viewer.
gr.Button(
value="External Image Viewer",
variant="secondary",
link="https://huggingface.co/spaces/Riksarkivet/Viewer_demo",
interactive=True,
)
# Runs fast_track.visualize_image_viewer on the input image.
run_image_visualizer_button = gr.Button(
value="Visualize results", variant="primary", interactive=True
)
# Read-only textbox filled by fast_track.get_text_from_coords on image selection.
selection_text_from_image_viewer = gr.Textbox(
interactive=False, label="Text Selector", info="Select a mask on Image Viewer to return text"
)
with gr.Column(scale=4):
with gr.Box():
# Row toggled by the tab select handlers; visible by default.
with gr.Row(visible=True) as output_and_setting_tab:
with gr.Column(scale=3):
with gr.Row():
with gr.Group():
gr.Markdown("   ⚙️ Settings ")
with gr.Row():
# Despite the name this is a CheckboxGroup; its selection is passed
# to fast_track.segment_to_xml together with the input image.
radio_file_input = gr.CheckboxGroup(
choices=["Txt", "XML"],
value=["Txt", "XML"],
label="Output file extension",
# info="Only txt and page xml is supported for now!",
scale=1,
)
# NOTE(review): the two checkboxes and two sliders below are not bound to names,
# so no handler in this file can read them; they appear to be display-only for
# now. Confirm.
with gr.Row():
gr.Checkbox(
value=True,
label="Binarize image",
info="Binarize image to reduce background noise",
)
gr.Checkbox(
value=True,
label="Output prediction threshold",
info="Output XML with prediction score",
)
with gr.Row():
gr.Slider(
value=0.6,
minimum=0.5,
maximum=1,
label="HTR threshold",
info="Prediction score threshold for transcribed lines",
scale=1,
)
gr.Slider(
value=0.7,
minimum=0.6,
maximum=1,
label="Avg threshold",
info="Average prediction score for a region",
scale=1,
)
# Model dropdowns. None of them is listed as an input of the HTR handlers
# registered in this file.
htr_tool_region_segment_model_dropdown = gr.Dropdown(
choices=["Riksarkivet/rtmdet_region"],
value="Riksarkivet/rtmdet_region",
label="Region Segment models",
info="Will add more models later!",
)
# with gr.Accordion("Transcribe settings:", open=False):
htr_tool_line_segment_model_dropdown = gr.Dropdown(
choices=["Riksarkivet/rtmdet_lines"],
value="Riksarkivet/rtmdet_lines",
label="Line Segment models",
info="Will add more models later!",
)
# A change handler registered in this file resets this dropdown to
# "Riksarkivet/satrn_htr" whenever its value changes.
htr_tool_transcriber_model_dropdown = gr.Dropdown(
choices=["Riksarkivet/satrn_htr", "microsoft/trocr-base-handwritten"],
value="Riksarkivet/satrn_htr",
label="Transcriber models",
info="Models will be continuously updated with future additions for specific cases.",
)
with gr.Column(scale=2):
# Hidden component listed first in the Examples inputs, ahead of the image.
fast_name_files_placeholder = gr.Markdown(visible=False)
gr.Examples(
examples=images_for_demo.examples_list,
inputs=[fast_name_files_placeholder, fast_track_input_region_image],
label="Example images",
examples_per_page=5,
)
# Row toggled by the tab select handlers; hidden by default.
with gr.Row(visible=False) as image_viewer_tab:
# Holder written by visualize_image_viewer and read by get_text_from_coords.
text_polygon_dict = gr.Variable()
fast_track_output_image = gr.Image(
label="Image Viewer", type="numpy", height=600, interactive=False
)
# Hidden textbox that receives the output of the api-only handler.
xml_rendered_placeholder_for_api = gr.Textbox(visible=False)
# "Run HTR": feed the input image and the selected output extensions to the backend.
# The returned event handle is kept so that the Stop button can cancel it.
# NOTE(review): the download widget is listed twice as an output, so segment_to_xml is
# expected to return two values for it (presumably the file(s) plus a property update).
# Confirm against the backend.
htr_event_click_event = htr_pipeline_button.click(
    fn=fast_track.segment_to_xml,
    inputs=[fast_track_input_region_image, radio_file_input],
    outputs=[fast_file_downlod, fast_file_downlod],
)

# Hidden trigger that exposes the image-only variant under the API name "predict".
htr_pipeline_button_api.click(
    fn=fast_track.segment_to_xml_api,
    inputs=[fast_track_input_region_image],
    outputs=[xml_rendered_placeholder_for_api],
    api_name="predict",
)
def dummy_update_htr_tool_transcriber_model_dropdown(htr_tool_transcriber_model_dropdown):
    """Build an update that forces the transcriber dropdown to the default model.

    The incoming selection is ignored: whatever value is passed in, the returned
    update sets the value to "Riksarkivet/satrn_htr".
    """
    forced_choice = "Riksarkivet/satrn_htr"
    return gr.update(value=forced_choice)
# Whenever the transcriber dropdown changes, route its value through the dummy handler
# and write the handler's update back to the same dropdown.
htr_tool_transcriber_model_dropdown.change(
fn=dummy_update_htr_tool_transcriber_model_dropdown,
inputs=htr_tool_transcriber_model_dropdown,
outputs=htr_tool_transcriber_model_dropdown,
)
def update_selected_tab_output_and_setting():
    """Return a (show, hide) pair of visibility updates.

    The first update sets visible=True and the second sets visible=False; the
    select handler applies them to its two output components in that order.
    """
    reveal = gr.update(visible=True)
    conceal = gr.update(visible=False)
    return reveal, conceal
def update_selected_tab_image_viewer():
    """Return a (hide, show) pair of visibility updates.

    The first update sets visible=False and the second sets visible=True; the
    select handler applies them to its two output components in that order.
    """
    conceal = gr.update(visible=False)
    reveal = gr.update(visible=True)
    return conceal, reveal
# Selecting the "Output and Settings" tab: first output (settings row) gets the
# visible=True update, second output (image-viewer row) gets visible=False.
tab_output_and_setting_selector.select(
fn=update_selected_tab_output_and_setting, outputs=[output_and_setting_tab, image_viewer_tab]
)
# Selecting the "Image Viewer" tab: same outputs, opposite visibility.
tab_image_viewer_selector.select(
fn=update_selected_tab_image_viewer, outputs=[output_and_setting_tab, image_viewer_tab]
)
def stop_function():
    """Flag the inference pipeline for termination and notify the user.

    Sets `terminate` to True on `pipeline_inferencer` (imported from
    src.htr_pipeline.utils at call time) and then raises a Gradio info message.
    """
    # Imported inside the function, as in the original, so the lookup happens on click.
    from src.htr_pipeline.utils import pipeline_inferencer as inferencer

    inferencer.terminate = True
    gr.Info("The HTR execution was halted")
# Stop button: run stop_function and cancel the in-flight "Run HTR" event.
stop_htr_button.click(
    fn=stop_function,
    inputs=None,
    outputs=None,
    cancels=[htr_event_click_event],
)

# "Visualize results": the backend returns the viewer image plus the value stored in
# text_polygon_dict.
run_image_visualizer_button.click(
    fn=fast_track.visualize_image_viewer,
    inputs=fast_track_input_region_image,
    outputs=[fast_track_output_image, text_polygon_dict],
)

# Selecting on the viewer image: pass the stored value to the backend and show the
# returned text in the selector textbox.
fast_track_output_image.select(
    fn=fast_track.get_text_from_coords,
    inputs=text_polygon_dict,
    outputs=selection_text_from_image_viewer,
)