# rsp-test / app.py — Hugging Face Space source file.
# (Scraped page header preserved below as comments: momo's picture / test /
#  fc3a84b / raw / history blame / 1.68 kB)
"""
baseline_interactive.py
"""
import gradio as gr
from transformers import MBartForConditionalGeneration, MBartTokenizer
from transformers import pipeline
model_name = "momo/rsp-sum"
model = MBartForConditionalGeneration.from_pretrained(model_name)
tokenizer = MBartTokenizer.from_pretrained(model_name, src_lang="ko_KR", tgt_lang="ko_KR")
summarizer = pipeline("summarization", model=model, tokenizer=tokenizer)
def summarization(News):
    """Generate a 50–150 token summary for the given Korean news text.

    Args:
        News: The article text to summarize.

    Returns:
        The generated summary string.
    """
    outputs = summarizer(News, min_length=50, max_length=150)
    return outputs[0]["summary_text"]
if __name__ == '__main__':
    # NOTE: gr.inputs.Textbox / gr.outputs.Textbox belong to the Gradio 2.x
    # API, deprecated in 3.x and removed in 4.x; the top-level component
    # classes (gr.Textbox) are the supported, backward-compatible equivalent.
    app = gr.Interface(
        fn=summarization,
        inputs=gr.Textbox(lines=10, label="News"),
        outputs=gr.Textbox(label="Summary"),
        title="한국어 뉴스 요약 생성기",
        description="Korean News Summary Generator",
    )
    # Blocks serving the web UI until the process is stopped.
    app.launch()
# with torch.no_grad():
# while True:
# t = input("\nDocument: ")
# tokens = tokenizer(
# t,
# return_tensors="pt",
# truncation=True,
# padding=True,
# max_length=600
# )
# input_ids = tokens.input_ids.cuda()
# attention_mask = tokens.attention_mask.cuda()
# sample_output = model.generate(
# input_ids,
# max_length=150,
# num_beams=5,
# early_stopping=True,
# no_repeat_ngram_size=8,
# )
# # print("token:" + str(input_ids.detach().cpu()))
# # print("token:" + tokenizer.convert_ids_to_tokens(str(input_ids.detach().cpu())))
# print("Summary: " + tokenizer.decode(sample_output[0], skip_special_tokens=True))