# Spaces: Runtime error
import os

import nltk
import streamlit as st
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

from src.abstractive_sum import summarize_text_with_model
from src.clean import clean_license_text
from src.diff import strikethrough_diff
from src.doc2vec import inference
from src.parameters import help_messages, captions, options
from src.read_data import read_license_text_data
from src.textrank import custom_textrank_summarizer, get_labels_for_license
nltk.download('punkt')


def _cached_resource(func):
    """Return *func* wrapped in Streamlit's resource cache when one exists.

    Streamlit re-executes the whole script on every widget interaction, so an
    expensive object that should survive a rerun has to be held by one of its
    caches. ``st.cache_resource`` is preferred; older releases expose the same
    idea as ``st.experimental_singleton``. If neither attribute is present the
    function is returned unchanged, i.e. it simply runs on every rerun.
    """
    decorator = (
        getattr(st, "cache_resource", None)
        or getattr(st, "experimental_singleton", None)
    )
    if decorator is None:
        return func
    # The caller wraps the load in its own spinner, so silence the cache's.
    return decorator(show_spinner=False)(func)


@_cached_resource
def _load_model_and_tokenizer(model_name, _device):
    """Load the seq2seq model (moved onto *_device*) and its tokenizer.

    ``_device`` carries a leading underscore so Streamlit's cache does not
    try to hash it; the cache entry is keyed on ``model_name`` only.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(_device)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer


if __name__ == "__main__":
    CUSTOM_MODEL_NAME = "utkarshsaboo45/ClearlyDefinedLicenseSummarizer"
    # Top doc2vec match must score above this to be used for the diff view
    # and the license-properties table.
    SIMILARITY_THRESHOLD = 0.8
    # Definitions/exceptions whose stripped length is 10 characters or fewer
    # are treated as "nothing to show".
    MIN_SECTION_CHARS = 10

    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    with st.spinner(captions.LOADING):
        model, tokenizer = _load_model_and_tokenizer(CUSTOM_MODEL_NAME, device)

    # ------------------------------------------------------------------
    # Sidebar: summarization mode and the options that depend on it.
    # ------------------------------------------------------------------
    summarization_type = st.sidebar.selectbox(
        captions.SELECT_SUMMARIZATION_TYPE,
        (options.EXTRACTIVE, options.ABSTRACTIVE, options.BOTH),
        help=help_messages.SUMMARIZATION_TYPE,
    )

    # Defaults for values that only some modes produce.
    cleaned_view = None
    exceptions = ""
    definitions = ""

    if summarization_type == options.ABSTRACTIVE:
        st.sidebar.caption(captions.SUMMARY_BY_T5)
        st.sidebar.caption(captions.WARNING_ABSTRACTIVE)
    elif summarization_type == options.EXTRACTIVE:
        st.sidebar.caption(captions.SUMMARY_BY_TEXTRANK)
        summary_len = st.sidebar.slider(
            captions.SUMMARY_LENGTH_PERCENTAGE,
            1,
            100,
            30,
            help=help_messages.SLIDER,
        )
        summary_view = st.sidebar.selectbox(
            captions.SUMMARY_VIEW,
            (
                options.DISPLAY_SUMMARY_ONLY,
                options.DISPLAY_HIGHLIGHTED_SUMMARY,
            ),
            help=help_messages.SUMMARY_VIEW,
        )
        if summary_view == options.DISPLAY_SUMMARY_ONLY:
            st.sidebar.caption(captions.DISPLAY_SUMMARY_ONLY_DESC)
        elif summary_view == options.DISPLAY_HIGHLIGHTED_SUMMARY:
            st.sidebar.caption(captions.DISPLAY_HIGHLIGHTED_SUMMARY_DESC)

        cleaned_view = st.sidebar.selectbox(
            captions.CLEANED_LICENSE_VIEW,
            (
                options.HIDE_CLEANED_LICENSE,
                options.DISPLAY_CLEANED_LICENSE,
                options.DISPLAY_CLEANED_DIFF,
            ),
            help=help_messages.CLEANED_LICENSE_VIEW,
        )
        if cleaned_view == options.DISPLAY_CLEANED_LICENSE:
            st.sidebar.caption(captions.CLEANED_LICENSE_ONLY)
        elif cleaned_view == options.DISPLAY_CLEANED_DIFF:
            st.sidebar.caption(captions.CLEANED_LICENSE_WITH_DIFF)
        elif cleaned_view == options.HIDE_CLEANED_LICENSE:
            st.sidebar.caption(captions.HIDE_CLEANED_LICENSE)
    elif summarization_type == options.BOTH:
        st.sidebar.caption(captions.SUMMARY_BY_BOTH)
        st.sidebar.caption(captions.WARNING_BOTH)

    # ------------------------------------------------------------------
    # Main page: input box.
    # ------------------------------------------------------------------
    st.title(captions.APP_TITLE)
    st.caption(captions.APP_DISCLAIMER)
    license_input = st.text_area(
        captions.LICENSE_TEXT,
        placeholder=captions.ENTER_LICENSE_CONTENT,
    )

    # Whitespace-only input carries no license text, so treat it as empty
    # instead of sending it through the summarizers.
    if license_input.strip():
        cleaned_modified_license_text = clean_license_text(license_input)[0]

        with st.spinner(captions.LOADING):
            # The three modes are mutually exclusive, hence one if/elif chain.
            if summarization_type == options.ABSTRACTIVE:
                summary, definitions = summarize_text_with_model(
                    license_input,
                    model,
                    tokenizer,
                )
            elif summarization_type == options.EXTRACTIVE:
                if summary_view == options.DISPLAY_SUMMARY_ONLY:
                    summary, definitions, exceptions = custom_textrank_summarizer(
                        license_input,
                        summary_len=summary_len / 100,
                    )
                elif summary_view == options.DISPLAY_HIGHLIGHTED_SUMMARY:
                    summary, definitions, exceptions = custom_textrank_summarizer(
                        license_input,
                        summary_len=summary_len / 100,
                        return_summary_only=False,
                    )
            elif summarization_type == options.BOTH:
                # Abstractive pass first, then TextRank over the generated
                # summary with summary_len=1. The second call's
                # definitions/exceptions replace the model's definitions.
                summary, definitions = summarize_text_with_model(
                    license_input,
                    model,
                    tokenizer,
                )
                summary, definitions, exceptions = custom_textrank_summarizer(
                    summary,
                    summary_len=1,
                )

        st.header(captions.SUMMARY)
        # NOTE(review): `summary` is derived from user-pasted text and is
        # rendered with raw HTML enabled. Confirm the summarizers escape or
        # sanitise their input before relying on this.
        st.markdown(summary, unsafe_allow_html=True)

        # Similarity of the pasted text against known licenses; row 0 is used
        # as the best match.
        prediction_scores = inference(license_input)
        top1_result = prediction_scores.loc[0, :]
        has_confident_match = top1_result["Similarity Scores"] > SIMILARITY_THRESHOLD

        st.header(captions.SIMILARITY_INDEX)
        st.caption(captions.SIMILARITY_INDEX_DISCLAIMER)
        st.dataframe(prediction_scores)

        if cleaned_view == options.DISPLAY_CLEANED_DIFF:
            st.header(captions.CLEANED_LICENSE_DIFF)
            if has_confident_match:
                matched_license = top1_result["License"]
                readable_name = matched_license.replace("-", " ")
                st.caption(f"Comparing against the official {readable_name} license")

                original_license_text = read_license_text_data(
                    matched_license.lower()
                )
                cleaned_original_license_text = clean_license_text(
                    original_license_text
                )[0]
                # NOTE(review): the diff is rendered as raw HTML and includes
                # user-supplied text; confirm strikethrough_diff escapes it.
                st.markdown(
                    strikethrough_diff(
                        cleaned_original_license_text,
                        cleaned_modified_license_text,
                    ),
                    unsafe_allow_html=True,
                )
            else:
                st.caption(captions.NO_SIMILAR_LICENSE_FOUND)
        elif cleaned_view == options.DISPLAY_CLEANED_LICENSE:
            st.header(captions.CLEANED_LICENSE_TEXT)
            st.write(cleaned_modified_license_text)

        # Optional sections. Each checkbox is disabled when there is nothing
        # to show, and the condition is re-checked after the checkbox so a
        # box that was ticked earlier cannot surface content once disabled.
        show_properties = st.sidebar.checkbox(
            options.SHOW_LICENSE_PROPERTIES,
            disabled=not has_confident_match,
            value=False,
            help=help_messages.PROPERTIES_CHECKBOX,
        )
        if show_properties and has_confident_match:
            license_properties = get_labels_for_license(
                top1_result["License"].lower()
            )
            st.header(captions.PROPERTIES)
            st.caption(captions.PROPERTIES_DISCLAIMER)
            st.dataframe(license_properties)

        has_definitions = len(definitions.strip()) > MIN_SECTION_CHARS
        show_definitions = st.sidebar.checkbox(
            options.SHOW_LICENSE_DEFINITIONS,
            disabled=not has_definitions,
            value=False,
            help=help_messages.DEFINITIONS_CHECKBOX,
        )
        if show_definitions and has_definitions:
            st.header(captions.DEFINITIONS)
            st.write(definitions)

        has_exceptions = len(exceptions.strip()) > MIN_SECTION_CHARS
        show_exceptions = st.sidebar.checkbox(
            options.SHOW_LICENSE_EXCEPTIONS,
            disabled=not has_exceptions,
            value=False,
            help=help_messages.EXCEPTIONS_CHECKBOX,
        )
        if show_exceptions and has_exceptions:
            st.header(captions.EXCEPTIONS)
            st.write(exceptions)