"""Streamlit "Organise" demo page.

Plots a sample set of contracts grouped by a pretrained clustering pipeline
downloaded from the Hugging Face Hub, and lets the user upload their own
.txt contracts to re-fit a copy of that pipeline and plot the resulting groups.
"""
import os
from copy import deepcopy

import joblib
import pandas as pd
import plotly.express as px
import streamlit as st
import streamlit_analytics
from huggingface_hub import hf_hub_download, snapshot_download

from utils import add_logo_to_sidebar, add_footer, add_email_signup_form

HF_TOKEN = os.environ.get("HF_TOKEN")
MODEL_REPO_ID = "simplexico/cuad-sklearn-contract-clustering"
DATA_REPO_ID = "simplexico/cuad-top-ten"
MODEL_FILENAME = "cuad_tfidf_umap_kmeans.pkl"
DATA_FILENAME = "cuad_top_ten_popular_contract_types.json"
# Only the first MAX_CHARS characters of each contract are fed to the model,
# for both the sample dataset and user uploads, so both are prepared the same way.
MAX_CHARS = 500

streamlit_analytics.start_tracking()

st.set_page_config(
    page_title="Organise Demo",
    page_icon="🗂",
    layout="wide",
    initial_sidebar_state="expanded",
    menu_items={
        'Get Help': 'mailto:hello@simplexico.ai',
        'Report a bug': None,
        'About': "## This is a demo showcasing different Legal AI Actions",
    },
)

add_logo_to_sidebar()

st.sidebar.success("👆 Select a demo above.")

st.title('🗂 Organise Demo')

st.write("""
This demo shows how AI can be used to organise contracts.
We've trained a model to group contracts into similar types.
The plot below shows a sample set of contracts that have been automatically grouped together.
Each point in the plot represents how the model interprets a contract, the closer together a pair of points are, the more similar they appear to the model.
Similar documents are grouped by color.
\n**TIP:** Hover over each point to see the filename of the contract. Groups can be added or removed by clicking on the symbol in the plot legend.
""")

st.write("**👈 Upload your own contracts on the left (as .txt files)** and hit the button **Organise Contracts** to see how your own contracts can be grouped together")


@st.cache(allow_output_mutation=True)
def load_model():
    """Download and unpickle the pretrained clustering pipeline.

    Cached with ``st.cache`` so Streamlit reruns reuse the loaded object.
    Callers that fit the model must work on a copy (see the training branch below).

    NOTE(review): ``joblib.load`` unpickles the downloaded file; this is only
    acceptable because MODEL_REPO_ID is a fixed, trusted repository.
    """
    model_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=MODEL_FILENAME, token=HF_TOKEN)
    return joblib.load(model_path)


@st.cache(allow_output_mutation=True)
def load_dataset():
    """Download the sample dataset repo into the working directory and read it.

    Returns:
        DataFrame read from DATA_FILENAME; this script uses its ``'text'``
        and ``'filename'`` columns.
    """
    snapshot_download(repo_id=DATA_REPO_ID, token=HF_TOKEN, local_dir='./', repo_type='dataset')
    return pd.read_json(DATA_FILENAME)


def get_transform_and_predictions(model, X):
    """Embed documents and assign cluster labels.

    Args:
        model: fitted sklearn-style pipeline. ``model[:2].transform`` must yield
            the embedding to plot and ``model.predict`` the cluster labels.
            Presumably TF-IDF -> UMAP -> KMeans, per MODEL_FILENAME; confirm.
        X: list of document strings.

    Returns:
        Tuple ``(X_transform, y)`` of the embedding and the predicted labels.
    """
    y = model.predict(X)
    X_transform = model[:2].transform(X)
    return X_transform, y


def generate_plot(X, y, filenames):
    """Build a 3-D scatter plot of embedded documents coloured by cluster.

    Args:
        X: 2-D array-like; columns 0, 1 and 2 are used as x, y and z.
        y: cluster label per row of X (stringified for discrete colours).
        filenames: label per row of X, shown on hover.

    Returns:
        A plotly Figure.
    """
    fig = px.scatter_3d(
        x=X[:, 0],
        y=X[:, 1],
        z=X[:, 2],
        color=[str(y_i) for y_i in y],
        hover_name=filenames,
    )
    fig.update_traces(
        marker_size=8,
        marker_line=dict(width=2),
        selector=dict(mode='markers'),
    )
    fig.update_layout(
        legend=dict(
            title='grouping',
            yanchor="top",
            y=0.99,
            xanchor="left",
            x=0.01,
        ),
        width=1100,
        height=900,
    )
    return fig


# The page asks for .txt files, so enforce that in the uploader itself.
uploaded_files = st.sidebar.file_uploader("Select contracts to organise ", type=["txt"], accept_multiple_files=True)
button = st.sidebar.button('Organise Contracts', type='primary', use_container_width=True)

with st.container():
    with st.spinner('⚙️ Loading model...'):
        cuad_tfidf_umap_kmeans = load_model()
        cuad_df = load_dataset()
        X = [text[:MAX_CHARS] for text in cuad_df['text'].to_list()]
        filenames = cuad_df['filename'].to_list()
        X_transform, y = get_transform_and_predictions(cuad_tfidf_umap_kmeans, X)
        fig = generate_plot(X_transform, y, filenames)
    figure = st.plotly_chart(fig, use_container_width=True)

if button:
    # Replace the sample plot with the results for the user's own contracts.
    figure.empty()
    with st.spinner('⚙️ Training model...'):
        if not uploaded_files or len(uploaded_files) < 2:
            st.write("Please add at least two contracts")
        else:
            n_clusters = 3 if len(uploaded_files) < 10 else 8
            # Decode before truncating: slicing raw bytes can split a multi-byte
            # UTF-8 character. getvalue() returns the whole buffer regardless of
            # the stream position.
            X_train = [
                uploaded_file.getvalue().decode('utf-8', errors='replace')[:MAX_CHARS]
                for uploaded_file in uploaded_files
            ]
            user_filenames = [uploaded_file.name for uploaded_file in uploaded_files]
            # KMeans raises if asked for more clusters than there are samples.
            # NOTE(review): UMAP may also struggle with very few documents; confirm the minimum.
            n_clusters = min(n_clusters, len(X_train))
            # Fit a copy so the cached pretrained model behind the sample plot is untouched.
            tfidf_umap_kmeans = deepcopy(cuad_tfidf_umap_kmeans)
            tfidf_umap_kmeans.set_params(kmeans__n_clusters=n_clusters)
            tfidf_umap_kmeans.fit(X_train)
            # Plot with the freshly fitted copy, not the pretrained model.
            X_transform, y = get_transform_and_predictions(tfidf_umap_kmeans, X_train)
            fig = generate_plot(X_transform, y, user_filenames)
            st.write("**Your organised contracts:**")
            st.plotly_chart(fig, use_container_width=True)

add_email_signup_form()

add_footer()

# Raises KeyError if ANALYTICS_PASSWORD is not set in the environment.
streamlit_analytics.stop_tracking(unsafe_password=os.environ["ANALYTICS_PASSWORD"])