File size: 8,166 Bytes
99e744f
c293aab
 
99e744f
0528be1
 
 
c293aab
99e744f
 
0528be1
99e744f
0528be1
 
 
 
 
 
 
c293aab
99e744f
 
 
 
 
c293aab
 
1df85e0
c293aab
 
 
 
 
99e744f
c293aab
 
 
99e744f
c293aab
99e744f
c293aab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1df85e0
c293aab
 
 
 
 
 
 
 
052ff21
 
 
c293aab
 
 
 
 
052ff21
c293aab
 
052ff21
 
c293aab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1df85e0
 
 
c293aab
 
 
 
 
 
1df85e0
 
 
99e744f
 
 
0528be1
 
052ff21
 
 
 
 
 
 
0528be1
c293aab
0528be1
 
 
 
 
 
 
 
 
99e744f
0528be1
 
99e744f
0528be1
 
 
99e744f
0528be1
 
 
 
 
 
 
 
99e744f
c293aab
99e744f
c293aab
 
 
 
 
99e744f
c293aab
99e744f
c293aab
 
 
99e744f
 
c293aab
 
 
 
99e744f
c293aab
 
99e744f
c293aab
 
 
 
 
 
 
 
 
 
 
99e744f
 
 
 
 
 
4cd6173
 
99e744f
 
c293aab
 
 
 
 
 
99e744f
 
 
c293aab
 
 
 
99e744f
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
import datetime
import os
import time
import logging
import nltk
import validators
import streamlit as st
from summarizer import summarizer_init, summarizer_summarize
from config import MODELS
from warnings import filterwarnings

filterwarnings("ignore")
from utils import (
    clean_text,
    fetch_article_text,
    preprocess_text_for_abstractive_summarization,
    read_text_from_file,
)

# from rouge import Rouge

logger = logging.getLogger(__name__)

def initialize_app():
    """Download the NLTK "punkt" data and seed default Streamlit session values.

    Only keys that are not already present in ``st.session_state`` are
    written, so values set earlier in the session are preserved.
    """
    nltk.download("punkt")
    defaults = {
        "model_type": "local",
        "model_name": "Boardpac summarizer v1",
        "summarizer_type": "Map Reduce",
        "is_parameters_changed": False,
        "openai_api_key": "",
    }

    # Collect only the keys the session does not have yet, then fill them in.
    missing = {
        name: value
        for name, value in defaults.items()
        if name not in st.session_state
    }
    for name, value in missing.items():
        st.session_state[name] = value

@st.cache_resource
def init_summarizer(model_name,api_key=None):
    """Load the tokenizer and summarizer for ``model_name`` (cached by Streamlit).

    Because of ``st.cache_resource`` the result is reused for identical
    ``(model_name, api_key)`` arguments, so callers should pass the OpenAI
    key explicitly to get a fresh summarizer when the key changes.

    Args:
        model_name: key into ``MODELS``. ``"OpenAI"`` selects the ``"openai"``
            backend; any other name is loaded as a ``"local"`` model.
        api_key: OpenAI API key, used only for the OpenAI backend. When falsy,
            ``st.session_state.openai_api_key`` is used instead.

    Returns:
        Tuple ``(model_type, tokenizer, base_summarizer)`` where ``model_type``
        is ``"openai"`` or ``"local"`` and the other two items come from
        ``summarizer_init``.

    Raises:
        Whatever ``MODELS[model_name]`` raises for an unknown name
        (presumably ``KeyError`` if ``MODELS`` is a dict).
    """
    with st.spinner(
            text="initialising the summarizer. This might take a few seconds ..."
        ):
        model_type = "openai" if model_name == "OpenAI" else "local"

        model_path = MODELS[model_name]
        if model_type == "openai":
            # Previously the argument was always overwritten here, so the
            # cache key (which includes api_key) never reflected the key
            # actually used. Only fall back to session state when no key
            # was supplied.
            if not api_key:
                api_key = st.session_state.openai_api_key
            tokenizer, base_summarizer = summarizer_init(model_path, model_type, api_key)
        else:
            logger.info(f"Model for summarization : {model_path}")
            tokenizer, base_summarizer = summarizer_init(model_path, model_type)

        alert = st.success("summarizer initialised")
        time.sleep(1)  # leave the confirmation visible for one second
        alert.empty()  # then remove it
        return model_type, tokenizer, base_summarizer

def update_parameters_change():
    """Set the ``is_parameters_changed`` flag in session state to True."""
    st.session_state["is_parameters_changed"] = True


def parameters_change_button(model_name, summarizer_type):
    """Store the chosen model settings in session state and briefly confirm.

    Args:
        model_name: value written to ``st.session_state.model_name``.
        summarizer_type: value written to ``st.session_state.summarizer_type``.
    """
    state = st.session_state
    state.model_name = model_name
    state.summarizer_type = summarizer_type
    state.is_parameters_changed = False  # the pending change is now applied

    notice = st.success("chat parameters updated")
    time.sleep(2)  # keep the confirmation on screen for two seconds
    notice.empty()

import re
def is_valid_open_ai_api_key(secretKey):
    """Return True if ``secretKey`` has the shape of an OpenAI secret key.

    Only the format is checked: ``sk-`` followed by at least 32 ASCII
    letters/digits. The key is not verified against the OpenAI API.

    Args:
        secretKey: candidate key. Non-string values (e.g. ``None``) are
            rejected instead of raising ``TypeError``.

    Returns:
        bool: whether the entire string matches the pattern.
    """
    if not isinstance(secretKey, str):
        return False
    # fullmatch anchors both ends; "^...$" with search() also accepted a
    # trailing "\n" because "$" matches just before a final newline.
    return re.fullmatch(r"sk-[a-zA-Z0-9]{32,}", secretKey) is not None

def side_bar():
    """Render the sidebar: model-parameter form, OpenAI key form, and usage notes.

    The "Save Parameters" button is created disabled, so the selections in
    the parameter form are not copied into ``st.session_state.model_name``
    or ``st.session_state.summarizer_type`` from here. The OpenAI key form
    is shown only when the active model is ``"OpenAI"``.
    """
    with st.sidebar:
        st.subheader("Model parameters")

        with st.form('param_form'):
            # The selected value is kept under st.session_state["Model Name"].
            st.selectbox(
                "Summary model",
                MODELS,
                key="Model Name",
                help="Select the LLM model for summarization",
            )

            st.selectbox(
                "Summarizer Type for Long Text",
                options=["Map Reduce"]
            )

            # st.form requires a submit button; this one is intentionally
            # disabled, so parameters_change_button() is never triggered here.
            st.form_submit_button(
                "Save Parameters",
                disabled=True,
            )

        st.markdown("\n")
        # "OpenAI" is the model name init_summarizer() maps to the OpenAI
        # backend. The former lowercase comparison ('openai') never matched,
        # so this form could never be displayed.
        if st.session_state.model_name == "OpenAI":
            with st.form('openai api key'):
                api_key = st.text_input(
                    "Enter openai api key",
                    type="password",
                    value=st.session_state.openai_api_key,
                    help="enter an openai api key created from 'https://platform.openai.com/account/api-keys'",
                )

                submit_key = st.form_submit_button("Save key")

                if submit_key:
                    st.session_state.openai_api_key = api_key
                    alert = st.success("openai api key updated")
                    time.sleep(1)  # keep the confirmation visible for one second
                    alert.empty()  # then clear it
        st.markdown(
            "### How to use\n"
            "1. Select the Summarization model\n"  # noqa: E501
            "1. Enter the text to get the summary."
        )
        st.markdown("---")
        st.markdown("""
           This app supports text in the following formats:
            - Raw text in text box 
            - .txt, .pdf, .docx file formats
        """
        )


def load_app():
    """Render the main page: take input text, preview it, and summarize on request.

    Input precedence: if the text area holds a URL, the article is fetched
    with ``fetch_article_text``. Otherwise an uploaded file is read and
    cleaned. Otherwise the text-area content is cleaned and used. Clicking
    "Summarize" with no usable text shows a warning instead of calling the
    model.
    """
    st.title("Text Summarizer 📝")

    inp_text = st.text_area(
        "Enter text here"
    )
    st.markdown(
        "<h4 style='text-align: center; color: green;'>OR</h4>",
        unsafe_allow_html=True,
    )
    uploaded_file = st.file_uploader(
        "Upload a .txt, .pdf, .docx file for summarization"
    )

    is_url = validators.url(inp_text)
    if is_url:
        logger.info("Text Input Type: URL")
        # fetch_article_text returns (complete text, chunks to summarize);
        # only the chunks are used below.
        _, cleaned_txt = fetch_article_text(url=inp_text)
    elif uploaded_file:
        logger.info("Text Input Type: FILE")
        cleaned_txt = clean_text(read_text_from_file(uploaded_file))
    else:
        logger.info("Text Input Type: INPUT TEXT")
        cleaned_txt = clean_text(inp_text)

    with st.expander("View input text"):
        if is_url:
            # Preview only the first chunk; an article that produced no
            # chunks would otherwise raise IndexError here.
            if cleaned_txt:
                st.write(cleaned_txt[0])
        else:
            st.write(cleaned_txt)

    if st.button("Summarize"):
        text_to_summarize = " ".join(cleaned_txt) if is_url else cleaned_txt
        # Do not send empty input to the model (e.g. the button was clicked
        # before any text was entered or a file uploaded).
        if not text_to_summarize:
            st.warning("Please enter some text or upload a file to summarize.")
            return
        submit_text_to_summarize(text_to_summarize)

def submit_text_to_summarize(text_to_summarize):
    """Summarize ``text_to_summarize`` and render the summary with its timing."""
    summary, elapsed = get_summary(text_to_summarize)
    display_output(summary, elapsed)


def get_summary(text_to_summarize):
    """Summarize text using the model and long-text strategy stored in session state.

    Args:
        text_to_summarize: input passed through to ``summarizer_summarize``.

    Returns:
        The two items returned by ``summarizer_summarize``, presumably
        ``(summarized_text, elapsed_seconds)``; display_output renders the
        second item with an ``s`` suffix.
    """
    model_name = st.session_state.model_name
    summarizer_type = st.session_state.summarizer_type

    # Pass the OpenAI key explicitly: init_summarizer is cached with
    # st.cache_resource on its arguments, so always passing None reused a
    # summarizer built with an old (or empty) key after the user saved a new one.
    api_key = st.session_state.openai_api_key if model_name == "OpenAI" else None
    model_type, tokenizer, base_summarizer = init_summarizer(model_name, api_key=api_key)

    logger.info(f"Model Name: {model_name}")
    logger.info(f"Summarization Type for Long Text: {summarizer_type}")

    # "Refine" selects the refine strategy; every other value uses map_reduce.
    chain_type = "refine" if summarizer_type == "Refine" else "map_reduce"

    with st.spinner(
        text="Creating summary. This might take a few seconds ..."
    ):
        summarized_text, elapsed = summarizer_summarize(
            model_type,
            tokenizer,
            base_summarizer,
            text_to_summarize,
            summarizer_type=chain_type,
        )
    return summarized_text, elapsed


def display_output(summarized_text,time):
    """Log the summary and its duration, then render both on the page.

    Args:
        summarized_text: summary shown inside an info box.
        time: duration shown with an ``s`` suffix. The parameter name shadows
            the ``time`` module inside this function; it is kept as-is for
            caller compatibility.
    """
    duration_label = f"{time}s"

    logger.info(f"SUMMARY: {summarized_text}")
    logger.info(f"Summary took {duration_label}")

    st.subheader("Summarized text")
    st.info(f"{summarized_text}")
    st.markdown(f"Time: {duration_label}")


def main():
    """Set up session state, then draw the sidebar and the main page, in that order."""
    for render_step in (initialize_app, side_bar, load_app):
        render_step()
    # chat_body()


if __name__ == "__main__":
    # Build the page only when this file is run as the main script.
    main()