# Page-scrape residue (not program code): "Spaces: Runtime error / Runtime error"
import streamlit as st | |
import pandas as pd | |
from transformers import BertTokenizer, TFBertForSequenceClassification | |
from transformers import TextClassificationPipeline | |
from transformers import pipeline | |
from stqdm import stqdm | |
from simplet5 import SimpleT5 | |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
from transformers import BertTokenizer, TFBertForSequenceClassification | |
import logging | |
from datasets import load_dataset | |
import gc | |
from typing import List | |
from collections import OrderedDict | |
from datetime import datetime | |
# Keyword arguments forwarded to the BERT text-classification pipelines.
tokenizer_kwargs = {"max_length": 128, "truncation": True, "padding": True}
# Keyword arguments forwarded (as model_kwargs) to the FLAN-T5 pipeline.
flan_t5_kwargs = {"repetition_penalty": 1.2}
SLEEP = 2
# Date stamp computed once at import time; used in the download file name.
date = datetime.now().strftime(r"%Y-%m-%d")
def clean_memory(obj: TextClassificationPipeline):
    """Drop the local reference to ``obj`` and run a garbage-collection pass.

    Only this function's own name binding is deleted; any reference the
    caller still holds keeps the object alive.
    """
    del obj
    gc.collect()
def get_all_cats():
    """Return the distinct theme labels of the ``ashhadahsan/amazon_theme`` dataset.

    Labels are read from the second column of the ``train`` split; the
    placeholder label ``"Unknown"`` is excluded. Order is unspecified.
    """
    frame = load_dataset("ashhadahsan/amazon_theme")["train"].to_pandas()
    distinct = set(frame.iloc[:, 1].values.tolist())
    del frame
    return [label for label in distinct if label != "Unknown"]
def get_all_subcats():
    """Return the distinct sub-theme labels of the ``ashhadahsan/amazon_subtheme`` dataset.

    Labels are read from the second column of the ``train`` split; the
    placeholder label ``"Unknown"`` is excluded. Order is unspecified.
    """
    frame = load_dataset("ashhadahsan/amazon_subtheme")["train"].to_pandas()
    distinct = set(frame.iloc[:, 1].values.tolist())
    del frame
    return [label for label in distinct if label != "Unknown"]
def load_zero_shot_classification_large():
    """Build a zero-shot classification pipeline backed by ``facebook/bart-large-mnli``."""
    return pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
def assign_label_zeroshot(zero, to: str, old: List):
    """Return the candidate label that ``zero`` scores highest for ``to``.

    Args:
        zero: Callable invoked as ``zero(to, old)``; its result must expose
            parallel ``"labels"`` and ``"scores"`` sequences.
        to: Text to classify.
        old: Candidate labels.

    Returns:
        The label with the highest score (first one wins on ties).
    """
    result = zero(to, old)
    score_by_label = dict(zip(result["labels"], result["scores"]))
    # Stable descending sort: among equal scores the earliest label stays first.
    ranked = sorted(score_by_label, key=score_by_label.get, reverse=True)
    best = ranked[0]
    print(best)
    print(type(best))
    return best
def assign_labels_flant5(pipe, what: str, to: str, old: List):
    """Ask a text2text pipeline to propose a new one-word label for a review summary.

    Args:
        pipe: Callable taking a prompt string and returning a list whose first
            item is a mapping with a ``"generated_text"`` entry.
        what: Kind of label being requested (e.g. ``"theme"`` or ``"subtheme"``).
        to: Summary of the review, given to the model as context.
        old: Labels that already exist; listed in the prompt.

    Returns:
        The ``"generated_text"`` of the first result.
    """
    # Fix: the prompt previously interpolated the global ``themes`` instead of
    # the labels passed in, so sub-theme prompts listed theme labels (and the
    # function raised NameError when ``themes`` was not defined).
    existing = ", ".join(old)
    prompt = f"""'Generate a new one word {what} to this summary of the text of a review
{to} for context
already assigned {what} are , {existing}
theme:"""
    return pipe(prompt)[0]["generated_text"]
def load_t5() -> (AutoModelForSeq2SeqLM, AutoTokenizer):
    """Load the ``t5-base`` seq2seq model and its tokenizer.

    Returns:
        A ``(model, tokenizer)`` pair.
    """
    checkpoint = "t5-base"
    seq2seq_model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
    seq2seq_tokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path=checkpoint
    )
    return seq2seq_model, seq2seq_tokenizer
def load_flan_t5_large():
    """Build a text2text-generation pipeline for ``google/flan-t5-large``.

    The module-level ``flan_t5_kwargs`` are passed through as ``model_kwargs``.
    """
    generator = pipeline(
        task="text2text-generation",
        model="google/flan-t5-large",
        model_kwargs=flan_t5_kwargs,
    )
    return generator
def summarizationModel():
    """Build a summarization pipeline from the local ``my_awesome_sum/`` model directory."""
    summarizer = pipeline(task="summarization", model="my_awesome_sum/")
    return summarizer
def convert_df(df: pd.DataFrame):
    """Serialize ``df`` to CSV (without the index) and return it as UTF-8 bytes."""
    csv_text = df.to_csv(index=False)
    return csv_text.encode("utf-8")
def load_one_line_summarizer(model):
    """Load the ``snrspeaks/t5-one-line-summary`` T5 checkpoint into ``model``.

    Returns whatever ``model.load_model`` returns.
    """
    model_type = "t5"
    checkpoint = "snrspeaks/t5-one-line-summary"
    return model.load_model(model_type, checkpoint)
def classify_theme() -> TextClassificationPipeline:
    """Build the theme classifier from ``ashhadahsan/amazon-theme-bert-base-finetuned``.

    The module-level ``tokenizer_kwargs`` are forwarded to the pipeline.
    """
    checkpoint = "ashhadahsan/amazon-theme-bert-base-finetuned"
    bert_tokenizer = BertTokenizer.from_pretrained(checkpoint)
    bert_model = TFBertForSequenceClassification.from_pretrained(checkpoint)
    # Local deliberately not named ``pipeline`` so the imported factory is not shadowed.
    classifier = TextClassificationPipeline(
        model=bert_model,
        tokenizer=bert_tokenizer,
        **tokenizer_kwargs,
    )
    return classifier
def classify_sub_theme() -> TextClassificationPipeline:
    """Build the sub-theme classifier from ``ashhadahsan/amazon-subtheme-bert-base-finetuned``.

    The module-level ``tokenizer_kwargs`` are forwarded to the pipeline.
    """
    checkpoint = "ashhadahsan/amazon-subtheme-bert-base-finetuned"
    bert_tokenizer = BertTokenizer.from_pretrained(checkpoint)
    bert_model = TFBertForSequenceClassification.from_pretrained(checkpoint)
    # Local deliberately not named ``pipeline`` so the imported factory is not shadowed.
    classifier = TextClassificationPipeline(
        model=bert_model,
        tokenizer=bert_tokenizer,
        **tokenizer_kwargs,
    )
    return classifier
# ---- Page layout and input widgets -------------------------------------
st.set_page_config(layout="wide", page_title="Amazon Review | Summarizer")
st.title(body="Amazon Review Summarizer")

# Input file with the reviews and the choice of summarization model.
uploaded_file = st.file_uploader(label="Choose a file", type=["xlsx", "xls", "csv"])
summarizer_option = st.selectbox(
    label="Select Summarizer",
    options=("Custom trained on the dataset", "t5-base", "t5-one-line-summary"),
)

# One checkbox per processing stage, laid out side by side.
col1, col2, col3 = st.columns(spec=[1, 1, 1])
with col1:
    # Fix: label was misspelled "Summrization".
    summary_yes = st.checkbox(label="Summarization", value=False)
with col2:
    classification = st.checkbox(label="Classify Category", value=True)
with col3:
    sub_theme = st.checkbox(label="Sub theme classification", value=True)

# Classifier scores at or below this value trigger fallback label suggestions.
treshold = st.slider(
    label="Model Confidence value",
    min_value=0.1,
    max_value=0.8,
    step=0.1,
    value=0.6,
    help="If the model has a confidence score below this number , then a new label is assigned (0.6) means 60 percent and so on",
)
ps = st.empty()
def _read_uploaded_table(uploaded) -> pd.DataFrame:
    """Parse the uploaded spreadsheet/CSV into a DataFrame based on its extension."""
    extension = uploaded.name.rsplit(".", 1)[-1].lower()
    if extension == "xlsx":
        return pd.read_excel(io=uploaded, engine="openpyxl")
    if extension == "xls":
        # openpyxl cannot read the legacy .xls format; let pandas pick the engine.
        return pd.read_excel(io=uploaded)
    if extension == "csv":
        # Fix: the extension was compared against ".csv", which never matched.
        return pd.read_csv(filepath_or_buffer=uploaded)
    raise ValueError(f"Unsupported file type: {uploaded.name}")


def _build_summarizer(option: str):
    """Return a ``review -> summary`` callable for the selected summarizer option."""
    if option == "Custom trained on the dataset":
        custom = summarizationModel()

        def summarize(review):
            return custom(
                f"summarize: {review}",
                max_length=50,
                early_stopping=True,
            )[0]["summary_text"]

        return summarize

    if option == "t5-base":
        model, tokenizer = load_t5()

        def summarize(review):
            tokens_input = tokenizer.encode(
                f"summarize: {review}",
                return_tensors="pt",
                max_length=tokenizer.model_max_length,
                truncation=True,
            )
            summary_ids = model.generate(
                tokens_input,
                min_length=80,
                max_length=150,
                length_penalty=20,
                num_beams=2,
            )
            return tokenizer.decode(summary_ids[0], skip_special_tokens=True)

        return summarize

    if option == "t5-one-line-summary":
        one_line = SimpleT5()
        load_one_line_summarizer(model=one_line)

        def summarize(review):
            return one_line.predict(source_text=review)[0]

        return summarize

    raise ValueError(f"Unknown summarizer option: {option}")


def _summarize_reviews(reviews: List, summarize, cancel_slot) -> List:
    """Summarize every review; the result always has one entry per review.

    Rows that fail, and rows skipped after "Cancel" is pressed, get an empty
    string so the list can be assigned as a DataFrame column.
    """
    summaries = []
    for index, review in enumerate(stqdm(iterable=reviews)):
        if cancel_slot.button("Cancel", key=index):
            break
        try:
            summaries.append(summarize(review))
        except Exception:
            # Keep going on a bad row, but leave a trace instead of hiding it.
            logging.exception("Summarization failed for row %d", index)
            summaries.append("")
    summaries.extend([""] * (len(reviews) - len(summaries)))
    return summaries


def _build_label_suggester():
    """Load the fallback models and return ``suggest(review, what, known_labels)``.

    ``suggest`` returns ``(flan_t5_label, zero_shot_label)`` computed from a
    one-line summary of the review.
    """
    one_line = SimpleT5()
    load_one_line_summarizer(model=one_line)
    zero_shot = load_zero_shot_classification_large()
    flan = load_flan_t5_large()

    def suggest(review, what, known_labels):
        headline = one_line.predict(source_text=review)[0]
        return (
            assign_labels_flant5(flan, what=what, to=headline, old=known_labels),
            assign_label_zeroshot(zero=zero_shot, to=headline, old=known_labels),
        )

    return suggest


def _label_reviews(reviews, load_classifier, known_labels, what, threshold, suggest, colour):
    """Classify each review and collect fallback suggestions for low-confidence rows.

    Returns three lists of ``len(reviews)``: predicted labels, FLAN-T5
    suggestions and zero-shot suggestions (the latter two are ``""`` when the
    classifier score is above ``threshold``).
    """
    classifier = load_classifier()
    predicted, generated, zero_shot = [], [], []
    progress = stqdm(
        iterable=reviews,
        desc=f"Assigning {what.capitalize()}s ...",
        total=len(reviews),
        colour=colour,
    )
    for review in progress:
        # Run the classifier once and reuse the result for label and score.
        top = classifier(review)[0]
        predicted.append(top["label"])
        if round(top["score"], 2) <= threshold:
            new_label, zero_label = suggest(review, what, known_labels)
        else:
            new_label, zero_label = "", ""
        generated.append(new_label)
        zero_shot.append(zero_label)
    clean_memory(classifier)
    return predicted, generated, zero_shot


def _process_upload(uploaded, option, want_summary, want_themes, want_subthemes, threshold):
    """Run the selected stages on the uploaded file and offer the result as a CSV download."""
    df = _read_uploaded_table(uploaded)
    df.columns = [str(column).lower() for column in df.columns]
    if "text" not in df.columns:
        st.error(
            body="Please Make sure that your data must have a column named text",
            icon="🚨",
        )
        st.info(body="Text column must have amazon reviews", icon="ℹ️")
        return

    text = df["text"].values.tolist()
    outputdf = pd.DataFrame({"text": text})

    if want_summary:
        outputdf["summary"] = _summarize_reviews(
            text, _build_summarizer(option), st.empty()
        )
        gc.collect()

    stages = (
        (want_themes, "theme", "Review Theme", classify_theme, get_all_cats, "red"),
        (want_subthemes, "subtheme", "Review SubTheme", classify_sub_theme, get_all_subcats, "green"),
    )
    # Fallback models are heavy: load them only when a classification stage runs.
    suggest = _build_label_suggester() if (want_themes or want_subthemes) else None
    for enabled, what, prefix, load_classifier, load_labels, colour in stages:
        if not enabled:
            continue
        predicted, generated, zero_shot = _label_reviews(
            text, load_classifier, load_labels(), what, threshold, suggest, colour
        )
        # Each stage writes its own columns (theme results used to be
        # overwritten by the sub-theme "-issue-zero" column).
        outputdf[prefix] = predicted
        outputdf[f"{prefix}-issue-new"] = generated
        outputdf[f"{prefix}-issue-zero"] = zero_shot

    st.download_button(
        label="Download output as CSV",
        data=convert_df(outputdf),
        file_name=f"{option}_{date}_df.csv",
        mime="text/csv",
        use_container_width=True,
    )


if st.button("Process", type="primary"):
    if uploaded_file is None:
        st.warning("Please upload a file before processing.")
    else:
        try:
            _process_upload(
                uploaded_file,
                summarizer_option,
                summary_yes,
                classification,
                sub_theme,
                treshold,
            )
        except Exception as e:
            # Exception (not BaseException) so interpreter/Streamlit control-flow
            # signals propagate; surface the failure instead of only logging it.
            logging.exception(msg="An exception was occurred")
            st.exception(e)