legal-ai-actions / pages /5_🗂_Organise_Demo.py
JMuscatello
Add custom model+display
6d70e63
raw
history blame
4.68 kB
import os
import joblib
from copy import deepcopy
import pandas as pd
import plotly.express as px
from huggingface_hub import hf_hub_download, snapshot_download
import streamlit as st
import streamlit_analytics
from utils import add_logo_to_sidebar, add_footer, add_email_signup_form
# Hugging Face repositories and artefact names used by this page.
HF_TOKEN = os.environ.get("HF_TOKEN")  # None when the variable is unset
MODEL_REPO_ID = "simplexico/cuad-sklearn-contract-clustering"
DATA_REPO_ID = "simplexico/cuad-top-ten"
MODEL_FILENAME = "cuad_tfidf_umap_kmeans.pkl"
DATA_FILENAME = "cuad_top_ten_popular_contract_types.json"

# Paired with streamlit_analytics.stop_tracking() at the end of the script.
streamlit_analytics.start_tracking()

st.set_page_config(
    page_title="Organise Demo",
    page_icon="🗂",
    layout="wide",
    initial_sidebar_state="expanded",
    menu_items={
        # NOTE(review): this address looks mangled by an e-mail obfuscation
        # scraper ("[email protected]") — restore the real contact address.
        'Get Help': 'mailto:[email protected]',
        'Report a bug': None,
        'About': "## This is a demo showcasing different Legal AI Actions",
    },
)
add_logo_to_sidebar()
st.sidebar.success("👆 Select a demo above.")

st.title('🗂 Organise Demo')
st.write("""
This demo shows how AI can be used to organise contracts.
We've trained a model to group contracts into similar types.
The plot below shows a sample set of contracts that have been automatically grouped together.
Each point in the plot represents how the model interprets a contract, the closer together a pair of points are, the more similar they appear to the model.
Similar documents are grouped by color.
\n**TIP:** Hover over each point to see the filename of the contract. Groups can be added or removed by clicking on the symbol in the plot legend.
""")
# The button name below must match the sidebar button label ('Organise Contracts').
st.write("**👈 Upload your own contracts on the left (as .txt files)** and hit the button **Organise Contracts** to see how your own contracts can be grouped together")
@st.cache(allow_output_mutation=True)
def load_model():
    """Fetch and deserialise the pretrained contract-clustering pipeline.

    The artefact MODEL_FILENAME is downloaded from MODEL_REPO_ID on the
    Hugging Face Hub (authenticating with HF_TOKEN) and loaded with joblib.

    The result is cached by Streamlit; ``allow_output_mutation=True`` means
    the same object is handed back on every rerun, so callers should copy it
    before re-fitting.

    NOTE(review): ``st.cache`` is deprecated in newer Streamlit releases in
    favour of ``st.cache_resource`` — migrate once the pinned version allows.
    """
    local_model_path = hf_hub_download(
        repo_id=MODEL_REPO_ID,
        filename=MODEL_FILENAME,
        token=HF_TOKEN,
    )
    # joblib.load unpickles the file (which can execute code), so
    # MODEL_REPO_ID must only ever point at a trusted repository.
    return joblib.load(local_model_path)
@st.cache(allow_output_mutation=True)
def load_dataset():
    """Download the sample CUAD contracts and return them as a DataFrame.

    Only DATA_FILENAME is fetched from the DATA_REPO_ID dataset repository,
    into the Hugging Face cache. Snapshotting the whole repository into the
    working directory (``local_dir='./'``) could overwrite any same-named
    file belonging to the app itself.

    Returns:
        pandas.DataFrame parsed from DATA_FILENAME; this page reads its
        ``text`` and ``filename`` columns.
    """
    data_path = hf_hub_download(
        repo_id=DATA_REPO_ID,
        filename=DATA_FILENAME,
        repo_type='dataset',
        token=HF_TOKEN,
    )
    return pd.read_json(data_path)
def get_transform_and_predictions(model, X):
    """Embed *X* and assign each document to a cluster in a single pass.

    ``model`` is presumably the TF-IDF -> UMAP -> KMeans scikit-learn
    ``Pipeline`` used by this page. Its first two steps produce the
    embedding that gets plotted; the remaining step(s) predict the cluster.

    Args:
        model: fitted, sliceable pipeline supporting ``model[:2]`` and
            ``model[2:]``.
        X: sequence of documents accepted by the pipeline's first step.

    Returns:
        ``(X_transform, y)``: the output of the first two steps and the
        predicted cluster label for each document.
    """
    # Run the embedding steps once and feed that result to the remaining
    # steps. Calling model.predict(X) and model[:2].transform(X) separately
    # repeats the (expensive) embedding. If a step's transform is not
    # deterministic, the plotted points can also disagree with the labels.
    X_transform = model[:2].transform(X)
    y = model[2:].predict(X_transform)
    return X_transform, y
def generate_plot(X, y, filenames):
    """Build a 3-D scatter plot of document embeddings coloured by cluster.

    Args:
        X: array indexed as ``X[:, k]``; columns 0, 1 and 2 become the x, y
            and z coordinates.
        y: cluster label per row of X. Labels are converted to strings so
            plotly treats them as discrete groups (one legend entry each)
            rather than a continuous colour scale.
        filenames: hover label shown for each row of X.

    Returns:
        The configured plotly figure.
    """
    labels = [str(label) for label in y]
    fig = px.scatter_3d(
        x=X[:, 0],
        y=X[:, 1],
        z=X[:, 2],
        color=labels,
        hover_name=filenames,
    )
    legend_style = dict(
        title='grouping',
        yanchor="top",
        y=0.99,
        xanchor="left",
        x=0.01,
    )
    # Both update_* methods return the figure itself, so they can be chained.
    return (
        fig
        .update_traces(
            marker_size=8,
            marker_line=dict(width=2),
            selector=dict(mode='markers'),
        )
        .update_layout(legend=legend_style, width=1100, height=900)
    )
uploaded_files = st.sidebar.file_uploader(
    "Select contracts to organise ",
    type=["txt"],  # the page instructions ask for .txt files; reject others up front
    accept_multiple_files=True,
)
button = st.sidebar.button('Organise Contracts', type='primary', use_container_width=True)

with st.container():
    # Sample view: cluster the bundled CUAD contracts with the pretrained pipeline.
    with st.spinner('⚙️ Loading model...'):
        cuad_tfidf_umap_kmeans = load_model()
        cuad_df = load_dataset()
    # Only the first 500 characters of each contract are fed to the model.
    X = [text[:500] for text in cuad_df['text'].to_list()]
    filenames = cuad_df['filename'].to_list()
    X_transform, y = get_transform_and_predictions(cuad_tfidf_umap_kmeans, X)
    fig = generate_plot(X_transform, y, filenames)
    figure = st.plotly_chart(fig, use_container_width=True)

if button:
    # Replace the sample chart with the user's own results.
    figure.empty()
    if not uploaded_files or len(uploaded_files) < 2:
        st.write(
            "Please add at least two contracts"
        )
    else:
        # Fewer clusters for small uploads, but never more clusters than
        # documents: KMeans cannot fit n_clusters > n_samples.
        n_clusters = 3 if len(uploaded_files) < 10 else 8
        n_clusters = min(n_clusters, len(uploaded_files))
        # Decode before truncating so the 500-character cut matches the sample
        # data and cannot split a multi-byte UTF-8 sequence. getvalue() returns
        # the whole upload regardless of the current read position.
        X_train = [
            uploaded_file.getvalue().decode('utf-8', errors='replace')[:500]
            for uploaded_file in uploaded_files
        ]
        filenames = [uploaded_file.name for uploaded_file in uploaded_files]
        with st.spinner('⚙️ Training model...'):
            # Re-fit a copy: the cached pretrained pipeline is shared across reruns.
            tfidf_umap_kmeans = deepcopy(cuad_tfidf_umap_kmeans)
            tfidf_umap_kmeans.set_params(kmeans__n_clusters=n_clusters)
            try:
                tfidf_umap_kmeans.fit(X_train)
                # Embed/predict with the pipeline just fitted on the uploads,
                # not the pretrained one.
                X_transform, y = get_transform_and_predictions(tfidf_umap_kmeans, X_train)
            except ValueError as err:
                # e.g. an empty vocabulary from blank/stop-word-only files.
                st.error(f"Could not organise these contracts: {err}")
            else:
                fig = generate_plot(X_transform, y, filenames)
                st.write("**Your organised contracts:**")
                st.plotly_chart(fig, use_container_width=True)

add_email_signup_form()
add_footer()

# Deliberately os.environ[...] rather than .get(). NOTE(review): without a
# password, streamlit_analytics appears to expose its results page to anyone —
# confirm before relaxing this to an optional setting.
streamlit_analytics.stop_tracking(unsafe_password=os.environ["ANALYTICS_PASSWORD"])