# rsp-test / app.py — Hugging Face Space file (commit d60334d, 1.84 kB)
"""
baseline_interactive.py
"""
import gradio as gr
from transformers import MBartForConditionalGeneration, MBartTokenizer
from transformers import pipeline
model_name = "momo/rsp-sum"
model = MBartForConditionalGeneration.from_pretrained(model_name)
tokenizer = MBartTokenizer.from_pretrained(model_name, src_lang="ko_KR", tgt_lang="ko_KR")
# prefix = "translate English to German: "
def summarization(News, Summary=None):
    """Summarize a Korean news article with the momo/rsp-sum pipeline.

    Args:
        News: Source article text to summarize.
        Summary: Unused; kept with a default so the single-input Gradio
            interface can call this function without a TypeError.

    Returns:
        The generated summary string (50–150 tokens).
    """
    summarizer = pipeline("summarization", model=model, tokenizer=tokenizer)
    # Run the pipeline once with the length constraints applied; the original
    # discarded this result and re-ran the pipeline without constraints.
    results = summarizer(News, min_length=50, max_length=150)
    if not results:
        return ""
    # Each result is a dict like {"summary_text": ...}; return the text so the
    # Gradio Textbox output shows the summary rather than a dict repr.
    return results[0]["summary_text"]
if __name__ == '__main__':
    # Wire the summarizer into a simple one-in/one-out Gradio interface
    # and start the local web server.
    demo = gr.Interface(
        fn=summarization,
        inputs=gr.inputs.Textbox(lines=10, label="News"),
        outputs=gr.outputs.Textbox(lines=10, label="Summary"),
        title="한국어 뉴스 요약 생성기",
        description="Korean News Summary Generator",
    )
    demo.launch()
# with torch.no_grad():
# while True:
# t = input("\nDocument: ")
# tokens = tokenizer(
# t,
# return_tensors="pt",
# truncation=True,
# padding=True,
# max_length=600
# )
# input_ids = tokens.input_ids.cuda()
# attention_mask = tokens.attention_mask.cuda()
# sample_output = model.generate(
# input_ids,
# max_length=150,
# num_beams=5,
# early_stopping=True,
# no_repeat_ngram_size=8,
# )
# # print("token:" + str(input_ids.detach().cpu()))
# # print("token:" + tokenizer.convert_ids_to_tokens(str(input_ids.detach().cpu())))
# print("Summary: " + tokenizer.decode(sample_output[0], skip_special_tokens=True))