"""Streamlit page for text summarization with the fine-tuned TURNA models.

The page collects generation parameters in the sidebar, sends the user's text
to the Hugging Face Inference API through ``query`` and displays the result.
"""
import requests
import streamlit as st
import time
from transformers import pipeline
import os

from .utils import query

# NOTE(review): ``requests``, ``time`` and ``pipeline`` are not used in this
# module as shown; drop them if nothing else relies on them (importing
# ``transformers`` is comparatively slow).

HF_AUTH_TOKEN = os.getenv('HF_AUTH_TOKEN')
# NOTE(review): ``headers`` is not passed to ``query`` below, and it becomes
# "Bearer None" when HF_AUTH_TOKEN is unset. Confirm whether ``query`` builds
# its own headers before relying on this.
headers = {"Authorization": f"Bearer {HF_AUTH_TOKEN}"}

# Prefix of the hosted TURNA model endpoints; the selected model name is appended.
_MODEL_URL_PREFIX = "https://api-inference.huggingface.co/models/boun-tabi-LMG/"

# Sample Turkish news text pre-filled in the input box.
_DEFAULT_INPUT_TEXT = (
    "Kalp krizi geçirenlerin yaklaşık üçte birinin kısa bir süre önce grip "
    "atlattığı düşünülüyor. Peki grip virüsü ne yapıyor da kalp krizine yol "
    "açıyor? "
    "Karpuz şöyle açıkladı: Grip virüsü kanın yapışkanlığını veya "
    "pıhtılaşmasını artırıyor."
)


def write():
    """Render the text-summarization page.

    Shows the page header, a sidebar with generation parameters (model,
    maximum new tokens, length penalty, no-repeat n-gram size) and a text
    area. When "Generate" is pressed with non-empty text, the text and the
    parameters are sent to the selected model's inference endpoint via
    ``query``. Its return value is shown in a success box.
    """
    st.markdown("# Text Summarization")
    st.sidebar.header("Text Summarization")
    st.write(
        """Here, you can summarize your text using the fine-tuned TURNA summarization models. """
    )

    # Sidebar
    # Taken from https://huggingface.co/spaces/flax-community/spanish-gpt2/blob/main/app.py
    st.sidebar.subheader("Configurable parameters")

    model_name = st.sidebar.selectbox(
        "Model Selector",
        options=[
            "turna_summarization_tr_news",
            "turna_summarization_mlsum",
        ],
        index=0,
    )
    # Generating zero new tokens can never yield a summary, so the lower
    # bound is 1.
    max_new_tokens = st.sidebar.number_input(
        "Maximum length",
        min_value=1,
        max_value=128,
        value=128,
        help="The maximum length of the sequence to be generated.",
    )
    # NOTE(review): per the transformers generation docs, length_penalty only
    # affects beam-based generation; num_beams is not sent here, so confirm
    # the endpoint's default decoding actually uses beams.
    length_penalty = st.sidebar.number_input(
        "Length penalty",
        value=2.0,
        help=" length_penalty > 0.0 promotes longer sequences, while length_penalty < 0.0 encourages shorter sequences. ",
    )
    no_repeat_ngram_size = st.sidebar.number_input(
        "No Repeat N-Gram Size",
        min_value=0,
        value=3,
        help="If set to int > 0, all ngrams of that size can only occur once.",
    )

    input_text = st.text_area(
        label='Enter a text: ', height=200, value=_DEFAULT_INPUT_TEXT
    )

    url = f"{_MODEL_URL_PREFIX}{model_name.lower()}"
    params = {
        "length_penalty": length_penalty,
        "no_repeat_ngram_size": no_repeat_ngram_size,
        "max_new_tokens": max_new_tokens,
        # NOTE(review): hard-coded special-token ids; presumably they match
        # the TURNA tokenizer (pad/decoder-start = 0, eos = 1) — verify
        # against the model's config.
        "decoder_start_token_id": 0,
        "eos_token_id": 1,
        "pad_token_id": 0,
    }

    if st.button("Generate"):
        # Avoid sending an empty request to the inference API.
        if not input_text.strip():
            st.warning("Please enter a text to summarize.")
        else:
            with st.spinner('Generating...'):
                output = query(input_text, url, params)
            st.success(output)