Spaces:
Runtime error
Runtime error
import streamlit as st | |
import pandas as pd | |
from transformers import pipeline | |
from stqdm import stqdm | |
from simplet5 import SimpleT5 | |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
def load_t5():
    """Load the pretrained ``t5-base`` seq2seq model together with its tokenizer.

    Both objects come from the same ``t5-base`` checkpoint via the
    ``transformers`` Auto* factories.

    Returns:
        tuple: ``(model, tokenizer)``.

    NOTE(review): this reloads the checkpoint on every call (i.e. on every
    "Process" click); consider a Streamlit resource cache if the installed
    Streamlit version provides one.
    """
    checkpoint_name = "t5-base"
    seq2seq_model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint_name)
    t5_tokenizer = AutoTokenizer.from_pretrained(checkpoint_name)
    return seq2seq_model, t5_tokenizer
def custom_model():
    """Build a ``summarization`` pipeline from the locally fine-tuned checkpoint.

    The checkpoint path ``my_awesome_sum/`` is relative, so it presumably
    resolves against the app's working directory — TODO confirm for deployment.

    Returns:
        The ``transformers`` pipeline object for the ``summarization`` task.
    """
    local_checkpoint_dir = "my_awesome_sum/"
    summarization_pipe = pipeline("summarization", model=local_checkpoint_dir)
    return summarization_pipe
def convert_df(df):
    """Serialize *df* to CSV text (index included) and return it as UTF-8 bytes.

    The bytes are what the app hands to ``st.download_button`` as ``data``.

    NOTE(review): nothing here is cached — the CSV is rebuilt on every call.
    If reruns make this expensive, wrap it with Streamlit's data cache
    (``st.cache_data`` in newer Streamlit releases).
    """
    csv_text = df.to_csv()
    return csv_text.encode("utf-8")
def load_one_line_summarizer(model):
    """Load the ``snrspeaks/t5-one-line-summary`` weights into *model*.

    Args:
        model: an object exposing ``load_model(model_type, model_dir)`` — the
            caller passes a ``SimpleT5`` instance.

    Returns:
        Whatever ``model.load_model`` returns (the visible caller ignores it).
    """
    one_line_checkpoint = "snrspeaks/t5-one-line-summary"
    return model.load_model("t5", one_line_checkpoint)
# The t5-base and t5-one-line-summary options only process this many leading
# rows (the original app hard-coded 10); the custom model processes every row.
T5_ROW_LIMIT = 10


def _read_uploaded_table(upload):
    """Parse an uploaded ``.xlsx`` / ``.xls`` / ``.csv`` file into a DataFrame.

    ``pd.read_excel`` cannot parse CSV, yet the uploader accepts it, so the
    reader is chosen from the uploaded file's name.
    """
    if upload.name.lower().endswith(".csv"):
        return pd.read_csv(upload)
    return pd.read_excel(upload)


def _make_custom_summarizer():
    """Return a ``text -> summary`` callable backed by the local fine-tuned pipeline."""
    model = custom_model()

    def summarize(item):
        result = model(f"summarize: {item}", max_length=50, early_stopping=True)
        return result[0]["summary_text"]

    return summarize


def _make_t5_base_summarizer():
    """Return a ``text -> summary`` callable using pretrained t5-base beam search."""
    model, tokenizer = load_t5()

    def summarize(item):
        tokens_input = tokenizer.encode(
            # str() so a non-string cell (e.g. an empty cell read as NaN) does
            # not raise TypeError, matching the custom path's f-string.
            "summarize: " + str(item),
            return_tensors="pt",
            max_length=tokenizer.model_max_length,
            truncation=True,
        )
        summary_ids = model.generate(
            tokens_input,
            min_length=80,
            max_length=150,
            length_penalty=20,
            num_beams=2,
        )
        return tokenizer.decode(summary_ids[0], skip_special_tokens=True)

    return summarize


def _make_one_line_summarizer():
    """Return a ``text -> summary`` callable using the SimpleT5 one-line model."""
    model = SimpleT5()
    load_one_line_summarizer(model=model)

    def summarize(item):
        return model.predict(item)[0]

    return summarize


# Selectbox label -> (summarizer factory, row limit; None means all rows).
# Insertion order defines the order of options shown in the selectbox.
_SUMMARIZERS = {
    "Custom trained on the dataset": (_make_custom_summarizer, None),
    "t5-base": (_make_t5_base_summarizer, T5_ROW_LIMIT),
    "t5-one-line-summary": (_make_one_line_summarizer, T5_ROW_LIMIT),
}


def _summarize_rows(texts, summarize):
    """Apply *summarize* to every text, keeping the output aligned with the input.

    A row whose summarization raises gets an empty summary instead of being
    silently dropped, so the result always has ``len(texts)`` entries and the
    output DataFrame can be built.

    Returns:
        tuple: ``(summaries, failure_count)``.
    """
    summaries = []
    failures = 0
    for item in stqdm(texts):
        try:
            summaries.append(summarize(item))
        except Exception as err:  # best-effort per row; surfaced to the user below
            failures += 1
            print(f"Summarization failed for a row: {err!r}")
            summaries.append("")
    return summaries, failures


def _process_upload(upload, option):
    """Summarize the ``text`` column of *upload* with *option* and offer a CSV."""
    if upload is None:
        st.warning("Please upload a file before clicking Process.")
        return
    df = _read_uploaded_table(upload)
    # Normalize headers so "Text"/"TEXT" match; str() guards non-string headers.
    df.columns = [str(col).lower() for col in df.columns]
    print(option)
    if "text" not in df.columns:
        st.error("The uploaded file must contain a 'text' column.")
        return

    make_summarizer, row_limit = _SUMMARIZERS[option]
    # Slicing caps the slow models without failing on files shorter than the cap.
    texts = df["text"].tolist()[:row_limit]
    summaries, failures = _summarize_rows(texts, make_summarizer())
    if failures:
        st.warning(
            f"{failures} of {len(texts)} row(s) could not be summarized; "
            "their summary is left blank."
        )

    output = pd.DataFrame({"text": texts, "summary": summaries})
    st.download_button(
        label="Download data as CSV",
        data=convert_df(output),
        file_name=f"{option}_df.csv",
        mime="text/csv",
    )


st.set_page_config(layout="wide", page_title="Amazon Review Summarizer")
st.title("Amazon Review Summarizer")
uploaded_file = st.file_uploader("Choose a file", type=["xlsx", "xls", "csv"])
summarizer_option = st.selectbox("Select Summarizer", tuple(_SUMMARIZERS))
# Placeholder element; not referenced in the visible code — kept in case
# later parts of the file use it.
ps = st.empty()
if st.button("Process"):
    _process_upload(uploaded_file, summarizer_option)