"""Gradio demo that normalizes (verbalizes) Ukrainian text with an mBART-50 model.

The app loads the checkpoint from a local path, shows an input textbox, and
returns the model's normalized rendering of the input together with the time
spent producing it.
"""

import sys
import time
from importlib.metadata import version

import torch
import gradio as gr
from transformers import MBartForConditionalGeneration, AutoTokenizer

# Config
model_name = "/home/user/app/mbart-large-50-verbalization"
concurrency_limit = 5

device = "cuda" if torch.cuda.is_available() else "cpu"

# Upper bound (in tokens) used both for truncating the encoded input and for
# the length of the generated output.
_MAX_LENGTH = 1024

# Load the model
model = MBartForConditionalGeneration.from_pretrained(
    model_name,
    low_cpu_mem_usage=True,
    device_map=device,
)
model.eval()

tokenizer = AutoTokenizer.from_pretrained(model_name)
# Source and target language codes are both Ukrainian: input and output are
# expected to be in the same language.
tokenizer.src_lang = "uk_XX"
tokenizer.tgt_lang = "uk_XX"

examples = [
    "WP: F-16 навряд чи значно змінять ситуацію на полі бою",
    "Над Україною збили ракету та 7 з 8 Шахедів",
    "Олімпійські ігри-2024. Розклад змагань українських спортсменів на 28 липня",
    "Кампанія Гарріс менш як за тиждень зібрала понад $200 млн",
    "За тиждень Нацбанк продав майже 800 мільйонів доларів на міжбанку",
    "Париж-2024. День 2. Текстова трансляція",
]

title = "Normalize Text for Ukrainian"

# https://www.tablesgenerator.com/markdown_tables
authors_table = """
## Authors

Follow them on social networks and **contact** if you need any help or have any questions:

| **Yehor Smoliakov** |
|-------------------------------------------------------------------------------------------------|
| https://t.me/smlkw in Telegram |
| https://x.com/yehor_smoliakov at X |
| https://github.com/egorsmkv at GitHub |
| https://huggingface.co/Yehor at Hugging Face |
| or use egorsmkv@gmail.com |
""".strip()

description_head = f"""
# {title}

## Overview

This space uses https://huggingface.co/skypro1111/mbart-large-50-verbalization model.

Paste the text you want to enhance.
""".strip()

description_foot = f"""
{authors_table}
""".strip()

normalized_text_value = """
Normalized text will appear here.

Choose **an example** below the Normalize button or paste **your text**.
""".strip()

tech_env = f"""
#### Environment

- Python: {sys.version}
""".strip()

tech_libraries = f"""
#### Libraries

- torch: {version('torch')}
- gradio: {version('gradio')}
- transformers: {version('transformers')}
""".strip()


def _normalize_sentence(sentence):
    """Run the model on one non-empty, stripped sentence and return its normalized form.

    The input is prefixed with ":" before tokenization, as the app has always
    done (NOTE(review): presumably the prompt format the checkpoint expects —
    confirm against the model card).

    Returns:
        The decoded, stripped model output, or "-" when that output is empty.
    """
    encoded_input = tokenizer(
        ":" + sentence,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=_MAX_LENGTH,
    ).to(device)

    output_ids = model.generate(
        **encoded_input, max_length=_MAX_LENGTH, num_beams=5, early_stopping=True
    )

    # Strip BEFORE checking for emptiness so a whitespace-only decode also
    # falls back to the "-" placeholder instead of rendering as a bare "> ".
    normalized = tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
    return normalized or "-"


def _format_results(results):
    """Render per-sentence results as "> "-quoted lines plus the total elapsed time."""
    lines = []
    for result in results:
        lines.append(f'> {result["normalized_text"]}')
        lines.append("\n")

    # Sum the raw timings and round once, so the total does not pick up float
    # noise from adding already-rounded values (e.g. 0.1 + 0.2).
    total_elapsed = round(sum(result["elapsed_time"] for result in results), 2)
    lines.append(f"Elapsed time: {total_elapsed} seconds")

    return "\n".join(lines)


def inference(text, progress=gr.Progress()):
    """Normalize *text* with the model and return a formatted report.

    Args:
        text: Raw user input from the "Text" textbox.
        progress: Gradio progress tracker; Gradio supplies it because the
            default value is a ``gr.Progress`` instance.

    Returns:
        A string containing the normalized text prefixed with "> ", followed
        by the total elapsed time in seconds.

    Raises:
        gr.Error: If *text* is empty or contains only whitespace.
    """
    # Reject whitespace-only input as well; otherwise the loop below would skip
    # it and the user would get nothing but "Elapsed time: 0 seconds".
    if not text or not text.strip():
        raise gr.Error("Please paste your text.")

    gr.Info("Starting normalizing", duration=2)

    progress(0, desc="Normalizing...")

    results = []

    # The whole input is processed as a single entry; the loop is written so
    # the progress bar and report formatting work for several entries too.
    sentences = [text]

    for sentence in progress.tqdm(sentences, desc="Normalizing...", unit="sentence"):
        sentence = sentence.strip()
        if not sentence:
            continue

        # perf_counter is monotonic and intended for measuring intervals,
        # unlike time.time(), which can jump with system clock changes.
        t0 = time.perf_counter()
        normalized_text = _normalize_sentence(sentence)
        elapsed_time = time.perf_counter() - t0

        results.append(
            {
                "sentence": sentence,
                "normalized_text": normalized_text,
                "elapsed_time": elapsed_time,
            }
        )

    gr.Info("Finished!", duration=2)

    return _format_results(results)


demo = gr.Blocks(
    title=title,
    analytics_enabled=False,
    theme=gr.themes.Base(),
)

with demo:
    gr.Markdown(description_head)

    gr.Markdown("## Usage")

    with gr.Row():
        text = gr.Textbox(label="Text", autofocus=True, max_lines=1)
        normalized_text = gr.Textbox(
            label="Normalized text",
            placeholder=normalized_text_value,
            show_copy_button=True,
        )

    gr.Button("Normalize").click(
        inference,
        concurrency_limit=concurrency_limit,
        inputs=text,
        outputs=normalized_text,
    )

    with gr.Row():
        gr.Examples(label="Choose an example", inputs=text, examples=examples)

    gr.Markdown(description_foot)

    gr.Markdown("### Gradio app uses:")
    gr.Markdown(tech_env)
    gr.Markdown(tech_libraries)

if __name__ == "__main__":
    demo.queue()
    demo.launch()