asr / main.py
HoneyTian's picture
update
7e1376c
raw
history blame
5.22 kB
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
from collections import defaultdict
import platform
import gradio as gr
from examples import examples
from models import model_map
from project_settings import project_path
def get_args():
    """Parse the command-line options of the demo.

    Returns:
        argparse.Namespace with two string attributes:
        ``examples_dir`` (default: ``<project_path>/data/examples``) and
        ``trained_model_dir`` (default: ``<project_path>/trained_models``).
    """
    arg_parser = argparse.ArgumentParser()
    # (flag, sub-directory of project_path used as the default value)
    option_defaults = (
        ("--examples_dir", "data/examples"),
        ("--trained_model_dir", "trained_models"),
    )
    for flag, sub_dir in option_defaults:
        arg_parser.add_argument(
            flag,
            default=(project_path / sub_dir).as_posix(),
            type=str,
        )
    return arg_parser.parse_args()
def update_model_dropdown(language: str):
    """Build a fresh model dropdown for the selected language.

    Args:
        language: A key of ``model_map``.

    Returns:
        An interactive ``gr.Dropdown`` whose choices are the ``repo_id`` of
        every model registered for ``language``, with the first one selected.

    Raises:
        ValueError: If ``language`` is not a key of ``model_map``.
    """
    if language not in model_map:
        raise ValueError(f"Unsupported language: {language}")

    repo_ids = [model["repo_id"] for model in model_map[language]]
    return gr.Dropdown(choices=repo_ids, value=repo_ids[0], interactive=True)
def build_html_output(s: str, style: str = "result_item_success"):
    """Wrap ``s`` in the result-box HTML used by the info panel.

    Args:
        s: Content placed inside the box. It is inserted verbatim (not
            escaped), so callers may pass HTML markup.
        style: Extra CSS class for the inner box, e.g.
            ``"result_item_success"`` or ``"result_item_error"``.

    Returns:
        The HTML snippet, one tag per line, starting and ending with a newline.
    """
    html_lines = [
        "",
        "<div class='result'>",
        f"<div class='result_item {style}'>",
        f"{s}",
        "</div>",
        "</div>",
        "",
    ]
    return "\n".join(html_lines)
def process_uploaded_file(
    language: str,
    repo_id: str,
    decoding_method: str,
    num_active_paths: int,
    add_punctuation: str,
    in_filename: str,
):
    """Recognize speech in an uploaded audio file (placeholder).

    Args:
        language: Selected language.
        repo_id: Selected model repo id.
        decoding_method: "greedy_search" or "modified_beam_search".
        num_active_paths: Beam size for modified_beam_search.
        add_punctuation: "Yes" or "No".
        in_filename: Path of the uploaded audio file.

    Returns:
        A ``(text, html)`` pair: the text for the result Textbox and an HTML
        snippet for the info panel.

    NOTE(review): this is a stub -- every argument is ignored and the fixed
    text "Dummy" is returned; real recognition is not implemented here.
    """
    placeholder_text = "Dummy"
    return placeholder_text, build_html_output(placeholder_text)
# CSS for the result boxes (classes result, result_item,
# result_item_success, result_item_error) shown in the HTML info panel.
# The style is copied from
# https://huggingface.co/spaces/alphacep/asr/blob/main/app.py#L113
css = "\n".join(
    (
        "",
        ".result {display:flex;flex-direction:column}",
        ".result_item {padding:15px;margin-bottom:8px;border-radius:15px;width:100%}",
        ".result_item_success {background-color:mediumaquamarine;color:white;align-self:start}",
        ".result_item_error {background-color:#ff7070;color:white;align-self:start}",
        "",
    )
)
def main():
    """Build the Gradio speech-recognition demo and serve it on port 7860.

    The UI has a single "Upload from disk" tab where the user picks a
    language, a model, decoding options and an audio file.
    ``process_uploaded_file`` produces the recognized text and an HTML info
    box. Changing the language refreshes the model list via
    ``update_model_dropdown``.
    """
    title = "# Automatic Speech Recognition with Next-gen Kaldi"

    language_choices = list(model_map.keys())
    # language -> repo_ids of its models, in model_map order.
    language_to_models = {
        language: [model["repo_id"] for model in models]
        for language, models in model_map.items()
    }
    default_language = language_choices[0]
    default_models = language_to_models[default_language]

    with gr.Blocks(css=css) as blocks:
        gr.Markdown(value=title)
        with gr.Tabs():
            with gr.TabItem("Upload from disk"):
                language_radio = gr.Radio(
                    label="Language",
                    choices=language_choices,
                    value=default_language,
                )
                model_dropdown = gr.Dropdown(
                    choices=default_models,
                    label="Select a model",
                    value=default_models[0],
                )
                decoding_method_radio = gr.Radio(
                    label="Decoding method",
                    choices=["greedy_search", "modified_beam_search"],
                    value="greedy_search",
                )
                num_active_paths_slider = gr.Slider(
                    minimum=1,
                    value=4,
                    step=1,
                    label="Number of active paths for modified_beam_search",
                )
                punct_radio = gr.Radio(
                    label="Whether to add punctuation (Only for Chinese and English)",
                    choices=["Yes", "No"],
                    value="Yes",
                )
                uploaded_file = gr.Audio(
                    sources=["upload"],
                    type="filepath",
                    label="Upload from disk",
                )
                upload_button = gr.Button("Submit for recognition")
                uploaded_output = gr.Textbox(label="Recognized speech from uploaded file")
                uploaded_html_info = gr.HTML(label="Info")

                # The examples and the submit button feed the same components,
                # in the parameter order of process_uploaded_file, so define
                # the lists once to keep the two wirings in sync.
                recognition_inputs = [
                    language_radio,
                    model_dropdown,
                    decoding_method_radio,
                    num_active_paths_slider,
                    punct_radio,
                    uploaded_file,
                ]
                recognition_outputs = [uploaded_output, uploaded_html_info]

                gr.Examples(
                    examples=examples,
                    inputs=recognition_inputs,
                    outputs=recognition_outputs,
                    fn=process_uploaded_file,
                )

        upload_button.click(
            process_uploaded_file,
            inputs=recognition_inputs,
            outputs=recognition_outputs,
        )
        language_radio.change(
            update_model_dropdown,
            inputs=language_radio,
            outputs=model_dropdown,
        )

    on_windows = platform.system() == "Windows"
    blocks.queue().launch(
        # The original chose False on both branches of a Windows check, so
        # public share links are always disabled.
        share=False,
        # Loopback only on Windows; all interfaces elsewhere.
        server_name="127.0.0.1" if on_windows else "0.0.0.0",
        server_port=7860,
    )
if __name__ == "__main__":
main()