summarizer-space / pages / 1_📈_predict.py
ashhadahsan's picture
Update pages/1_📈_predict.py
c1dd675
import streamlit as st
import pandas as pd
from transformers import BertTokenizer, TFBertForSequenceClassification
from transformers import TextClassificationPipeline
from transformers import pipeline
from stqdm import stqdm
from simplet5 import SimpleT5
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import BertTokenizer, TFBertForSequenceClassification
import logging
from datasets import load_dataset
import gc
from typing import List
from collections import OrderedDict
from datetime import datetime
# Tokenizer settings forwarded to each BERT TextClassificationPipeline built in this file.
tokenizer_kwargs = {"max_length": 128, "truncation": True, "padding": True}
# Passed as ``model_kwargs`` to the FLAN-T5 text2text-generation pipeline.
flan_t5_kwargs = {"repetition_penalty": 1.2}
# NOTE(review): not referenced anywhere in the visible code — confirm before removing.
SLEEP = 2
# Run date as YYYY-MM-DD; embedded in the downloaded CSV's file name.
date = datetime.now().strftime("%Y-%m-%d")
def clean_memory(obj: "TextClassificationPipeline"):
    """Run a full garbage-collection pass.

    ``obj`` is accepted only for backward compatibility with existing callers.
    Deleting a parameter inside a function merely unbinds the local name: the
    caller's variable (and, for pipelines returned by the ``@st.cache_resource``
    loaders, Streamlit's cache) still references the object, so it cannot be
    freed from here. The former ``del obj`` was therefore a no-op and has been
    removed.
    """
    gc.collect()
def _labels_from_dataset(dataset_name: str) -> List[str]:
    """Return the distinct labels from column 1 of ``dataset_name``'s train split, without "Unknown".

    The result is sorted. These labels end up in the FLAN-T5 prompt and in the
    zero-shot candidate list, and ``list(set(...))`` ordering varies between
    interpreter runs (string hashing is randomized), so sorting keeps the
    output reproducible. Missing (NaN) cells are dropped: they are not labels
    and would break ``", ".join`` in the prompt builder.
    NOTE(review): assumes the label column holds strings — confirm against the dataset.
    """
    frame = load_dataset(dataset_name)["train"].to_pandas()
    column = frame.iloc[:, 1].dropna()
    return sorted(label for label in set(column.tolist()) if label != "Unknown")


@st.cache_data
def get_all_cats():
    """Return the known review themes from the ``ashhadahsan/amazon_theme`` dataset (cached)."""
    return _labels_from_dataset("ashhadahsan/amazon_theme")


@st.cache_data
def get_all_subcats():
    """Return the known review subthemes from the ``ashhadahsan/amazon_subtheme`` dataset (cached)."""
    return _labels_from_dataset("ashhadahsan/amazon_subtheme")
@st.cache_resource
def load_zero_shot_classification_large():
    """Return a ``facebook/bart-large-mnli`` zero-shot-classification pipeline.

    Built once and then reused via Streamlit's ``st.cache_resource``.
    """
    return pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
def assign_label_zeroshot(zero, to: str, old: List):
    """Return the candidate label that the zero-shot classifier scores highest for ``to``.

    Args:
        zero: a zero-shot classification pipeline (or any callable) invoked as
            ``zero(to, old)`` that returns a mapping with parallel ``"labels"``
            and ``"scores"`` sequences.
        to: the text to classify.
        old: the candidate labels.

    Returns:
        The label with the highest score. On ties, the first one listed wins.
    """
    result = zero(to, old)
    # Take the maximum directly instead of building and sorting an OrderedDict.
    # max() keeps the first of equal scores, just as the stable reverse sort did.
    best_label, _best_score = max(
        zip(result["labels"], result["scores"]), key=lambda pair: pair[1]
    )
    return best_label
def assign_labels_flant5(pipe, what: str, to: str, old: List):
    """Ask a text2text-generation pipeline to propose a new one-word label.

    Args:
        pipe: callable invoked with one prompt string that returns a list whose
            first element has a ``"generated_text"`` key (the result shape of a
            Hugging Face ``text2text-generation`` pipeline).
        what: the kind of label to generate, e.g. "theme" or "subtheme".
        to: the review summary the new label should describe.
        old: the labels already in use; they are listed in the prompt as context.

    Returns:
        The generated text of the first candidate.
    """
    # Interpolate the caller's ``old`` labels, not a module-level global. That
    # way theme and subtheme prompts each list their own label set, and the
    # function also works outside the Process handler.
    existing = ", ".join(old)
    # The answer cue mirrors ``what``, so subtheme prompts end in "subtheme:".
    prompt = f"""'Generate a new one word {what} to this summary of the text of a review
{to} for context
already assigned {what} are , {existing}
{what}:"""
    return pipe(prompt)[0]["generated_text"]
@st.cache_resource
def load_t5() -> "tuple[AutoModelForSeq2SeqLM, AutoTokenizer]":
    """Load the ``t5-base`` seq2seq model and its tokenizer, cached via ``st.cache_resource``.

    The return annotation was a tuple literal ``(A, B)``, which is not a valid
    type hint. It is now ``tuple[A, B]``, written as a string so it is never
    evaluated at runtime.

    Returns:
        ``(model, tokenizer)``.
    """
    model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="t5-base")
    return model, tokenizer
@st.cache_resource
def load_flan_t5_large():
    """Return the ``google/flan-t5-large`` text2text-generation pipeline, cached via ``st.cache_resource``.

    The module-level ``flan_t5_kwargs`` are forwarded as ``model_kwargs``.
    """
    settings = dict(
        task="text2text-generation",
        model="google/flan-t5-large",
        model_kwargs=flan_t5_kwargs,
    )
    return pipeline(**settings)
@st.cache_resource
def summarizationModel():
    """Return the custom summarization pipeline loaded from ``my_awesome_sum/``, cached via ``st.cache_resource``.

    NOTE(review): the path is relative, so it resolves against the process's working directory.
    """
    local_checkpoint = "my_awesome_sum/"
    return pipeline(task="summarization", model=local_checkpoint)
@st.cache_data
def convert_df(df: pd.DataFrame):
    """Serialize ``df`` to UTF-8 CSV bytes (without the index) for ``st.download_button``.

    The function returns serializable data rather than a shared resource, so
    it uses ``st.cache_data``. That is Streamlit's cache for data results (and
    the decorator used in Streamlit's own download-button example);
    ``st.cache_resource`` is meant for global objects such as models and
    connections.
    """
    return df.to_csv(index=False).encode("utf-8")
def load_one_line_summarizer(model):
    """Load the ``snrspeaks/t5-one-line-summary`` checkpoint into ``model``.

    Args:
        model: a SimpleT5 instance, or anything whose ``load_model`` method
            accepts ``(model_type, checkpoint)`` positionally.

    Returns:
        Whatever ``model.load_model`` returns.
    """
    model_type, checkpoint = "t5", "snrspeaks/t5-one-line-summary"
    return model.load_model(model_type, checkpoint)
def _bert_text_classifier(repo_id: str) -> TextClassificationPipeline:
    """Build a TextClassificationPipeline from the fine-tuned BERT tokenizer and TF model at ``repo_id``."""
    tokenizer = BertTokenizer.from_pretrained(repo_id)
    model = TFBertForSequenceClassification.from_pretrained(repo_id)
    # Named ``classifier`` (not ``pipeline``) so the imported transformers
    # ``pipeline`` factory is not shadowed, even locally.
    classifier = TextClassificationPipeline(
        model=model,
        tokenizer=tokenizer,
        **tokenizer_kwargs,
    )
    return classifier


@st.cache_resource
def classify_theme() -> TextClassificationPipeline:
    """Return the cached review-theme classifier (``amazon-theme-bert-base-finetuned``)."""
    return _bert_text_classifier("ashhadahsan/amazon-theme-bert-base-finetuned")


@st.cache_resource
def classify_sub_theme() -> TextClassificationPipeline:
    """Return the cached review-subtheme classifier (``amazon-subtheme-bert-base-finetuned``)."""
    return _bert_text_classifier("ashhadahsan/amazon-subtheme-bert-base-finetuned")
# set_page_config must be the first Streamlit command of the script.
st.set_page_config(layout="wide", page_title="Amazon Review | Summarizer")
st.title(body="Amazon Review Summarizer")
uploaded_file = st.file_uploader(label="Choose a file", type=["xlsx", "xls", "csv"])
summarizer_option = st.selectbox(
    label="Select Summarizer",
    options=("Custom trained on the dataset", "t5-base", "t5-one-line-summary"),
)
col1, col2, col3 = st.columns(spec=[1, 1, 1])
with col1:
    summary_yes = st.checkbox(label="Summarization", value=False)
with col2:
    classification = st.checkbox(label="Classify Category", value=True)
with col3:
    sub_theme = st.checkbox(label="Sub theme classification", value=True)
# The (misspelled) name ``treshold`` is kept because the processing code reads it.
# The processing code compares scores with ``<=`` and keeps the model's label,
# adding suggestions in extra columns; the help text says exactly that.
treshold = st.slider(
    label="Model Confidence value",
    min_value=0.1,
    max_value=0.8,
    step=0.1,
    value=0.6,
    help="If the model's confidence score is at or below this value, a new label is also suggested (0.6 means 60 percent, and so on)",
)
# NOTE(review): ``ps`` is not referenced elsewhere in this file; it only reserves an empty slot in the layout.
ps = st.empty()
def _read_uploaded_file(uploaded) -> pd.DataFrame:
    """Read an uploaded ``.csv`` / ``.xlsx`` / ``.xls`` file and lower-case its column names.

    Raises:
        ValueError: if the file extension is not one of the supported types.
    """
    extension = uploaded.name.rsplit(".", 1)[-1].lower()
    if extension == "csv":
        frame = pd.read_csv(filepath_or_buffer=uploaded)
    elif extension == "xlsx":
        frame = pd.read_excel(io=uploaded, engine="openpyxl")
    elif extension == "xls":
        # openpyxl only reads the xlsx format; let pandas choose the engine for legacy .xls.
        frame = pd.read_excel(io=uploaded)
    else:
        raise ValueError(f"Unsupported file type: {uploaded.name}")
    # str() guards against non-string headers (e.g. numeric column names).
    frame.columns = [str(column).lower() for column in frame.columns]
    return frame


def _make_summarizer(option: str, one_line_model):
    """Return a callable that maps one review text to its summary, for the selected summarizer ``option``.

    Raises:
        ValueError: if ``option`` is not one of the selectbox choices.
    """
    if option == "Custom trained on the dataset":
        custom_model = summarizationModel()
        return lambda review: custom_model(
            f"summarize: {review}",
            max_length=50,
            early_stopping=True,
        )[0]["summary_text"]
    if option == "t5-base":
        t5_model, t5_tokenizer = load_t5()

        def summarize_with_t5(review: str) -> str:
            tokens_input = t5_tokenizer.encode(
                "summarize: " + review,
                return_tensors="pt",
                max_length=t5_tokenizer.model_max_length,
                truncation=True,
            )
            summary_ids = t5_model.generate(
                tokens_input,
                min_length=80,
                max_length=150,
                length_penalty=20,
                num_beams=2,
            )
            return t5_tokenizer.decode(summary_ids[0], skip_special_tokens=True)

        return summarize_with_t5
    if option == "t5-one-line-summary":
        # Same checkpoint as the SimpleT5 model already loaded for labelling, so reuse it
        # instead of loading a second copy.
        return lambda review: one_line_model.predict(source_text=review)[0]
    raise ValueError(f"Unknown summarizer option: {option}")


def _summarize_all(texts: List[str], summarize_one, cancel_slot) -> List[str]:
    """Summarize every review, keeping the result aligned row-for-row with ``texts``.

    A row whose summarization raises gets "" (and the error is logged). Rows not
    reached because the user pressed Cancel also get "". This way the column
    always has exactly ``len(texts)`` entries.
    """
    summaries: List[str] = []
    for index in stqdm(iterable=range(len(texts))):
        if cancel_slot.button(label="Cancel", key=index):
            break
        try:
            summaries.append(summarize_one(texts[index]))
        except Exception:
            logging.exception("Summarization failed for row %d", index)
            summaries.append("")
    cancel_slot.empty()
    summaries.extend([""] * (len(texts) - len(summaries)))
    return summaries


def _assign_labels(
    texts,
    classifier,
    known_labels,
    what,
    desc,
    colour,
    threshold,
    one_line_model,
    generator,
    zero_shot_model,
):
    """Classify each review and, for low-confidence predictions, suggest alternative labels.

    For every text, the top prediction from ``classifier`` is recorded. When its
    score (rounded to 2 decimals) is ``<= threshold``, a one-line summary is
    produced. From that summary, one label is generated with FLAN-T5
    (``generator``) and another is picked by zero-shot classification over
    ``known_labels``. Otherwise both suggestions are "".

    Returns:
        Three equally long lists: ``(predicted, generated, zero_shot)``.
    """
    predicted, generated, zero_shot = [], [], []
    for review in stqdm(iterable=texts, desc=desc, total=len(texts), colour=colour):
        # Run the classifier once per review (the original called it twice).
        top = classifier(review)[0]
        predicted.append(top["label"])
        if round(top["score"], 2) <= threshold:
            one_line = one_line_model.predict(source_text=review)[0]
            generated.append(
                assign_labels_flant5(generator, what=what, to=one_line, old=known_labels)
            )
            zero_shot.append(
                assign_label_zeroshot(zero=zero_shot_model, to=one_line, old=known_labels)
            )
        else:
            generated.append("")
            zero_shot.append("")
    return predicted, generated, zero_shot


if st.button("Process", type="primary"):
    if uploaded_file is None:
        st.warning(body="Please upload a file before pressing Process.", icon="⚠️")
    else:
        # Kept as module-level names: other code in this file may read them (e.g. ``themes``).
        themes = get_all_cats()
        subthemes = get_all_subcats()
        oneline = SimpleT5()
        load_one_line_summarizer(model=oneline)
        zeroline = load_zero_shot_classification_large()
        bot = load_flan_t5_large()
        cancel_button = st.empty()
        try:
            df = _read_uploaded_file(uploaded_file)
            if "text" not in df.columns:
                # Checked explicitly instead of catching KeyError, so an unrelated
                # KeyError deeper in processing is not mis-reported as a missing column.
                st.error(
                    body="Please make sure that your data has a column named text",
                    icon="🚨",
                )
                st.info(body="Text column must have amazon reviews", icon="ℹ️")
            else:
                # Empty cells would reach the models as float NaN; treat them as empty text.
                text = df["text"].fillna("").astype(str).tolist()
                outputdf = pd.DataFrame({"text": text})
                if summary_yes:
                    outputdf["summary"] = _summarize_all(
                        texts=text,
                        summarize_one=_make_summarizer(summarizer_option, oneline),
                        cancel_slot=cancel_button,
                    )
                if classification:
                    theme_labels, theme_new, theme_zero = _assign_labels(
                        text,
                        classify_theme(),
                        themes,
                        what="theme",
                        desc="Assigning Themes ...",
                        colour="red",
                        threshold=treshold,
                        one_line_model=oneline,
                        generator=bot,
                        zero_shot_model=zeroline,
                    )
                    outputdf["Review Theme"] = theme_labels
                    outputdf["Review Theme-issue-new"] = theme_new
                    # Previously written to the SubTheme column, where it was overwritten.
                    outputdf["Review Theme-issue-zero"] = theme_zero
                if sub_theme:
                    sub_labels, sub_new, sub_zero = _assign_labels(
                        text,
                        classify_sub_theme(),
                        subthemes,
                        what="subtheme",
                        desc="Assigning Subthemes ...",
                        colour="green",
                        threshold=treshold,
                        one_line_model=oneline,
                        generator=bot,
                        zero_shot_model=zeroline,
                    )
                    outputdf["Review SubTheme"] = sub_labels
                    outputdf["Review SubTheme-issue-new"] = sub_new
                    outputdf["Review SubTheme-issue-zero"] = sub_zero
                st.download_button(
                    label="Download output as CSV",
                    data=convert_df(outputdf),
                    file_name=f"{summarizer_option}_{date}_df.csv",
                    mime="text/csv",
                    use_container_width=True,
                )
        except Exception as e:
            # Exception, not BaseException: Streamlit's rerun/stop signals derive from
            # BaseException and must propagate (e.g. when Cancel triggers a rerun).
            logging.exception(msg="An exception occurred")
            st.exception(e)