# apps/summarization.py
import os

import streamlit as st

from .utils import query

# Hugging Face Inference API credentials; the shared `query` helper in .utils is
# presumably responsible for attaching these headers to its requests.
HF_AUTH_TOKEN = os.getenv('HF_AUTH_TOKEN')
headers = {"Authorization": f"Bearer {HF_AUTH_TOKEN}"}
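
# Streamlit page entry point; the Space's main script is assumed to call write()
# when this page is selected from the navigation.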
def write():
    # st.set_page_config(page_title="Text Summarization", page_icon="📈")
    st.markdown("# Text Summarization")
    st.sidebar.header("Text Summarization")
    st.write(
        """Here, you can summarize your text using the fine-tuned TURNA summarization models."""
    )

    # Sidebar
    # Taken from https://huggingface.co/spaces/flax-community/spanish-gpt2/blob/main/app.py
    st.sidebar.subheader("Configurable parameters")
    model_name = st.sidebar.selectbox(
        "Model Selector",
        options=[
            "turna_summarization_tr_news",
            "turna_summarization_mlsum",
        ],
        index=0,
    )
    max_new_tokens = st.sidebar.number_input(
        "Maximum length",
        min_value=0,
        max_value=128,
        value=128,
        help="The maximum length of the sequence to be generated.",
    )
    length_penalty = st.sidebar.number_input(
        "Length penalty",
        value=2.0,
        help="length_penalty > 0.0 promotes longer sequences, while length_penalty < 0.0 encourages shorter sequences.",
    )
"""do_sample = st.sidebar.selectbox(
"Sampling?",
(True, False),
help="Whether or not to use sampling; use greedy decoding otherwise.",
)
num_beams = st.sidebar.number_input(
"Number of beams",
min_value=1,
max_value=10,
value=3,
help="The number of beams to use for beam search.",
)
repetition_penalty = st.sidebar.number_input(
"Repetition Penalty",
min_value=0.0,
value=3.0,
step=0.1,
help="The parameter for repetition penalty. 1.0 means no penalty",
)"""
    no_repeat_ngram_size = st.sidebar.number_input(
        "No Repeat N-Gram Size",
        min_value=0,
        value=3,
        help="If set to int > 0, all ngrams of that size can only occur once.",
    )

    input_text = st.text_area(
        label='Enter a text: ',
        height=200,
        value="Kalp krizi geçirenlerin yaklaşık üçte birinin kısa bir süre önce grip atlattığı düşünülüyor. Peki grip virüsü ne yapıyor da kalp krizine yol açıyor? Karpuz şöyle açıkladı: Grip virüsü kanın yapışkanlığını veya pıhtılaşmasını artırıyor.",
    )
    url = "https://api-inference.huggingface.co/models/boun-tabi-LMG/" + model_name.lower()
    # Generation parameters; the special token ids follow the T5-style convention
    # used by the TURNA checkpoints (pad/decoder start = 0, eos = 1).
    params = {
        "length_penalty": length_penalty,
        "no_repeat_ngram_size": no_repeat_ngram_size,
        "max_new_tokens": max_new_tokens,
        "decoder_start_token_id": 0,
        "eos_token_id": 1,
        "pad_token_id": 0,
    }
    if st.button("Generate"):
        with st.spinner('Generating...'):
            output = query(input_text, url, params)
        st.success(output)
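

# Illustrative sketch only (not used by the app): the shared `query` helper
# imported from .utils is not shown in this file. Under the assumption that it
# wraps the Hugging Face Inference API, it would look roughly like the function
# below, POSTing the input and generation parameters with the bearer token and
# returning the generated summary. The real implementation lives in apps/utils.py,
# and the function name here is hypothetical.
def _query_sketch(text, url, params):
    import requests  # local import so the sketch stays self-contained

    payload = {
        "inputs": text,
        "parameters": params,
        "options": {"wait_for_model": True},  # block until the model is loaded
    }
    response = requests.post(url, headers=headers, json=payload)
    response.raise_for_status()
    data = response.json()
    # text2text models typically return [{"generated_text": "..."}]
    # (or [{"summary_text": "..."}] for summarization pipelines).
    first = data[0] if isinstance(data, list) and data else {}
    return first.get("generated_text") or first.get("summary_text") or data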