"""Streamlit front end for summarizing license text.

The sidebar selects a summarization mode (extractive, abstractive, or both)
and its display options.  Once license text is entered, the script renders
the summary, a table of similarity scores against known licenses, and
optional sections: cleaned text or diff, license properties, definitions,
and exceptions.
"""
import os

import nltk
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch
import streamlit as st

from src.doc2vec import inference
from src.abstractive_sum import summarize_text_with_model
from src.textrank import custom_textrank_summarizer, get_labels_for_license
from src.clean import clean_license_text
from src.read_data import read_license_text_data
from src.diff import strikethrough_diff
from src.parameters import help_messages, captions, options

nltk.download('punkt')


if __name__ == "__main__":
    CUSTOM_MODEL_NAME = "utkarshsaboo45/ClearlyDefinedLicenseSummarizer"
    # The top match must score strictly above this to be treated as a known
    # license (enables the diff view and the properties section).
    SIMILARITY_THRESHOLD = 0.8
    # Definitions / exceptions are only offered when their stripped text is
    # longer than this many characters.
    MIN_SECTION_CHARS = 10

    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    device = (
        torch.device("cuda") if torch.cuda.is_available()
        else torch.device("cpu")
    )

    # NOTE(review): this runs on every execution of the script, including
    # when only the extractive mode is selected. If reruns are slow, consider
    # wrapping the load in Streamlit's resource cache -- confirm which cache
    # API the deployed Streamlit version provides first.
    with st.spinner(captions.LOADING):
        model = AutoModelForSeq2SeqLM.from_pretrained(
            CUSTOM_MODEL_NAME
        ).to(device)
        tokenizer = AutoTokenizer.from_pretrained(CUSTOM_MODEL_NAME)

    # ------------------------------------------------------------------
    # Sidebar: summarization mode and mode-specific options
    # ------------------------------------------------------------------
    summarization_type = st.sidebar.selectbox(
        captions.SELECT_SUMMARIZATION_TYPE,
        (options.EXTRACTIVE, options.ABSTRACTIVE, options.BOTH),
        help=help_messages.SUMMARIZATION_TYPE,
    )

    # Defaults for values that only some modes produce.
    cleaned_view = None
    exceptions = ""
    definitions = ""

    if summarization_type == options.ABSTRACTIVE:
        st.sidebar.caption(captions.SUMMARY_BY_T5)
        st.sidebar.caption(captions.WARNING_ABSTRACTIVE)

    elif summarization_type == options.EXTRACTIVE:
        st.sidebar.caption(captions.SUMMARY_BY_TEXTRANK)

        # Slider value is a percentage (1-100, default 30); it is divided by
        # 100 before being handed to the summarizer below.
        summary_len = st.sidebar.slider(
            captions.SUMMARY_LENGTH_PERCENTAGE,
            1,
            100,
            30,
            help=help_messages.SLIDER,
        )

        summary_view = st.sidebar.selectbox(
            captions.SUMMARY_VIEW,
            (
                options.DISPLAY_SUMMARY_ONLY,
                options.DISPLAY_HIGHLIGHTED_SUMMARY,
            ),
            help=help_messages.SUMMARY_VIEW,
        )
        if summary_view == options.DISPLAY_SUMMARY_ONLY:
            st.sidebar.caption(captions.DISPLAY_SUMMARY_ONLY_DESC)
        elif summary_view == options.DISPLAY_HIGHLIGHTED_SUMMARY:
            st.sidebar.caption(captions.DISPLAY_HIGHLIGHTED_SUMMARY_DESC)

        cleaned_view = st.sidebar.selectbox(
            captions.CLEANED_LICENSE_VIEW,
            (
                options.HIDE_CLEANED_LICENSE,
                options.DISPLAY_CLEANED_LICENSE,
                options.DISPLAY_CLEANED_DIFF,
            ),
            help=help_messages.CLEANED_LICENSE_VIEW,
        )
        if cleaned_view == options.DISPLAY_CLEANED_LICENSE:
            st.sidebar.caption(captions.CLEANED_LICENSE_ONLY)
        elif cleaned_view == options.DISPLAY_CLEANED_DIFF:
            st.sidebar.caption(captions.CLEANED_LICENSE_WITH_DIFF)
        elif cleaned_view == options.HIDE_CLEANED_LICENSE:
            st.sidebar.caption(captions.HIDE_CLEANED_LICENSE)

    elif summarization_type == options.BOTH:
        st.sidebar.caption(captions.SUMMARY_BY_BOTH)
        st.sidebar.caption(captions.WARNING_BOTH)

    # ------------------------------------------------------------------
    # Main page: input
    # ------------------------------------------------------------------
    st.title(captions.APP_TITLE)
    st.caption(captions.APP_DISCLAIMER)

    license_input = st.text_area(
        captions.LICENSE_TEXT,
        placeholder=captions.ENTER_LICENSE_CONTENT,
    )

    # Whitespace-only input is treated the same as no input so the
    # summarizers are never run on blank text.
    if license_input.strip():
        cleaned_modified_license_text = clean_license_text(license_input)[0]

        # --------------------------------------------------------------
        # Summarize. The three modes are mutually exclusive.
        # --------------------------------------------------------------
        with st.spinner(captions.LOADING):
            if summarization_type == options.ABSTRACTIVE:
                summary, definitions = summarize_text_with_model(
                    license_input, model, tokenizer
                )
            elif summarization_type == options.EXTRACTIVE:
                if summary_view == options.DISPLAY_SUMMARY_ONLY:
                    summary, definitions, exceptions = (
                        custom_textrank_summarizer(
                            license_input,
                            summary_len=summary_len / 100,
                        )
                    )
                elif summary_view == options.DISPLAY_HIGHLIGHTED_SUMMARY:
                    summary, definitions, exceptions = (
                        custom_textrank_summarizer(
                            license_input,
                            summary_len=summary_len / 100,
                            return_summary_only=False,
                        )
                    )
            elif summarization_type == options.BOTH:
                # Run the model first, then pass its output through the
                # extractive summarizer with summary_len=1. The second call
                # overwrites `definitions` and supplies `exceptions`.
                summary, definitions = summarize_text_with_model(
                    license_input, model, tokenizer
                )
                summary, definitions, exceptions = custom_textrank_summarizer(
                    summary, summary_len=1
                )

        st.header(captions.SUMMARY)
        # NOTE(review): the summary is derived from user-supplied text and is
        # rendered with raw HTML enabled -- presumably needed for the
        # highlighted view; confirm the summarizers escape their input.
        st.markdown(summary, unsafe_allow_html=True)

        # --------------------------------------------------------------
        # Similarity against known licenses
        # --------------------------------------------------------------
        prediction_scores = inference(license_input)
        # Row 0 is used as the best match (presumably the table is sorted by
        # descending similarity -- verify against `inference`).
        top1_result = prediction_scores.loc[0, :]
        has_close_match = bool(
            top1_result["Similarity Scores"] > SIMILARITY_THRESHOLD
        )

        st.header(captions.SIMILARITY_INDEX)
        st.caption(captions.SIMILARITY_INDEX_DISCLAIMER)
        st.dataframe(prediction_scores)

        # --------------------------------------------------------------
        # Cleaned license text / diff against the matched license
        # --------------------------------------------------------------
        if cleaned_view == options.DISPLAY_CLEANED_DIFF:
            st.header(captions.CLEANED_LICENSE_DIFF)
            if has_close_match:
                readable_name = top1_result["License"].replace("-", " ")
                st.caption(
                    f"Comparing against the official {readable_name} license"
                )

                top_license_name = top1_result["License"].lower()
                original_license_text = read_license_text_data(
                    top_license_name
                )
                cleaned_original_license_text = clean_license_text(
                    original_license_text
                )[0]

                st.markdown(
                    strikethrough_diff(
                        cleaned_original_license_text,
                        cleaned_modified_license_text,
                    ),
                    unsafe_allow_html=True,
                )
            else:
                st.caption(captions.NO_SIMILAR_LICENSE_FOUND)
        elif cleaned_view == options.DISPLAY_CLEANED_LICENSE:
            st.header(captions.CLEANED_LICENSE_TEXT)
            st.write(cleaned_modified_license_text)

        # --------------------------------------------------------------
        # Optional sections. Each checkbox is disabled when its content is
        # unavailable, and each section re-checks that condition before
        # rendering so a stale checked state cannot show empty or
        # mismatched content.
        # --------------------------------------------------------------
        show_properties = st.sidebar.checkbox(
            options.SHOW_LICENSE_PROPERTIES,
            disabled=not has_close_match,
            value=False,
            help=help_messages.PROPERTIES_CHECKBOX,
        )
        if show_properties and has_close_match:
            license_properties = get_labels_for_license(
                top1_result["License"].lower()
            )
            st.header(captions.PROPERTIES)
            st.caption(captions.PROPERTIES_DISCLAIMER)
            st.dataframe(license_properties)

        has_definitions = len(definitions.strip()) > MIN_SECTION_CHARS
        show_definitions = st.sidebar.checkbox(
            options.SHOW_LICENSE_DEFINITIONS,
            disabled=not has_definitions,
            value=False,
            help=help_messages.DEFINITIONS_CHECKBOX,
        )
        if show_definitions and has_definitions:
            st.header(captions.DEFINITIONS)
            st.write(definitions)

        has_exceptions = len(exceptions.strip()) > MIN_SECTION_CHARS
        show_exceptions = st.sidebar.checkbox(
            options.SHOW_LICENSE_EXCEPTIONS,
            disabled=not has_exceptions,
            value=False,
            help=help_messages.EXCEPTIONS_CHECKBOX,
        )
        if show_exceptions and has_exceptions:
            st.header(captions.EXCEPTIONS)
            st.write(exceptions)