Spaces:
Sleeping
Sleeping
import gradio as gr | |
from transformers import pipeline | |
from transformers import AutoTokenizer | |
# Display name shown in the UI -> Hugging Face Hub repository id.
models = {
    "RUPunct-small": "RUPunct/RUPunct_small",
    "RUPunct-big": "RUPunct/RUPunct_big",
    "RUPunct-medium": "RUPunct/RUPunct_medium",
}


def _build_pipeline(repo_id):
    """Create the token-classification ("ner") pipeline for one RUPunct repo.

    The tokenizer keeps accents (``strip_accents=False``) and is created with
    ``add_prefix_space=True``. Tokens are grouped into words with the "first"
    aggregation strategy.
    """
    tok = AutoTokenizer.from_pretrained(
        repo_id, strip_accents=False, add_prefix_space=True
    )
    return pipeline(
        "ner", model=repo_id, tokenizer=tok, aggregation_strategy="first"
    )


# Every model is loaded eagerly at import time, in the order listed above.
pipelines = {name: _build_pipeline(repo_id) for name, repo_id in models.items()}
# Punctuation appended for the <PUNCT> part of a RUPunct label.
# "O" means no punctuation. The transliterated names are:
# TIRE = dash, DVOETOCHIE = colon, VOSKL = exclamation mark,
# PERIODCOMMA = semicolon, DEFIS = hyphen, MNOGOTOCHIE = ellipsis.
_PUNCT_SUFFIXES = {
    "O": "",
    "PERIOD": ".",
    "COMMA": ",",
    "QUESTION": "?",
    # The dash carries a leading space so it is set off from the preceding
    # word. This is applied to every case variant; previously LOWER_TIRE
    # omitted the space.
    "TIRE": " —",
    "DVOETOCHIE": ":",
    "VOSKL": "!",
    "PERIODCOMMA": ";",
    "DEFIS": "-",
    "MNOGOTOCHIE": "...",
    "QUESTIONVOSKL": "?!",
}

# Case transform selected by the label's <CASE> prefix.
# "UPPER_TOTAL_" must be tested before "UPPER_" because the latter is a
# prefix of the former.
# Note that str.capitalize also lowercases the remainder of the token.
_CASE_TRANSFORMS = (
    ("UPPER_TOTAL_", str.upper),
    ("UPPER_", str.capitalize),
    ("LOWER_", lambda text: text),
)


def process_token(token, label):
    """Apply the casing and trailing punctuation encoded by a RUPunct label.

    Labels have the form ``<CASE>_<PUNCT>``:

    - ``<CASE>`` is ``LOWER`` (token kept as-is), ``UPPER`` (capitalized), or
      ``UPPER_TOTAL`` (fully upper-cased).
    - ``<PUNCT>`` is a key of ``_PUNCT_SUFFIXES``.

    Args:
        token: The word text produced by the pipeline.
        label: The predicted entity group, e.g. ``"UPPER_COMMA"``.

    Returns:
        The transformed token with its punctuation appended. An unrecognized
        label returns the token unchanged, rather than None, so callers that
        concatenate the result do not crash.
    """
    for prefix, transform in _CASE_TRANSFORMS:
        if label.startswith(prefix):
            suffix = _PUNCT_SUFFIXES.get(label[len(prefix):])
            if suffix is not None:
                return transform(token) + suffix
            # The case prefix matched but the punctuation part is unknown.
            break
    return token
def punctuate(input_text, model_name):
    """Restore casing and punctuation in ``input_text`` using a RUPunct model.

    Args:
        input_text: Raw text to punctuate.
        model_name: A key of ``pipelines``. None (no model selected in the UI)
            falls back to the first configured model.

    Returns:
        The processed words joined by single spaces. The original whitespace,
        including newlines, is not preserved. Blank input returns "".
    """
    if not input_text or not input_text.strip():
        # Nothing to punctuate; skip the model call entirely.
        return ""
    if model_name is None:
        model_name = next(iter(pipelines))
    classifier = pipelines[model_name]

    pieces = []
    for item in classifier(input_text):
        word = item["word"]
        # A token that is itself "." gets no extra punctuation appended,
        # presumably to avoid doubling punctuation the input already had.
        label = "LOWER_O" if word == "." else item["entity_group"]
        pieces.append(process_token(word.strip(), label))
    return " ".join(pieces).strip()
# Web UI: free-text input plus a model selector, plain-text output.
iface = gr.Interface(
    fn=punctuate,
    inputs=[
        gr.components.Textbox(lines=5, placeholder="Введите текст"),
        # Preselect the first model. With no default, an unselected Radio
        # passes None to `punctuate`, and None is not a key of `pipelines`.
        gr.components.Radio(
            list(models.keys()), value=next(iter(models)), label="Модель"
        ),
    ],
    outputs="text",
    title="RUPunct",
    description="Демо RUPunct - модели для автоматической расстановки знаков препинания в русском тексте.",
)
iface.launch()