JMuscatello committed on
Commit
6d70e63
•
1 Parent(s): 916cbfe

Add custom model+display

Browse files
Files changed (1) hide show
  1. pages/5_🗂_Organise_Demo.py +65 -13
pages/5_🗂_Organise_Demo.py CHANGED
@@ -1,6 +1,8 @@
1
  import os
2
  import joblib
3
 
 
 
4
  import pandas as pd
5
  import plotly.express as px
6
 
@@ -39,6 +41,7 @@ This demo shows how AI can be used to organise contracts.
39
  We've trained a model to group contracts into similar types.
40
  The plot below shows a sample set of contracts that have been automatically grouped together.
41
  Each point in the plot represents how the model interprets a contract, the closer together a pair of points are, the more similar they appear to the model.
 
42
  \n**TIP:** Hover over each point to see the filename of the contract. Groups can be added or removed by clicking on the symbol in the plot legend.
43
  """)
44
  st.write("**👈 Upload your own contracts on the left (as .txt files)** and hit the button **Organise Data** to see how your own contracts can be grouped together")
@@ -56,23 +59,23 @@ def load_dataset():
56
  df = pd.read_json(DATA_FILENAME)
57
  return df
58
 
59
- def get_transform_and_predictions(model, df):
60
- X = [text[:500] for text in df['text'].to_list()]
61
  y = model.predict(X)
62
  X_transform = model[:2].transform(X)
63
  return X_transform, y
64
 
65
- with st.spinner('⚙️ Loading model...'):
66
- cuad_tfidf_umap_kmeans = load_model()
67
- cuad_df = load_dataset()
68
-
69
- X_transform, y = get_transform_and_predictions(cuad_tfidf_umap_kmeans, cuad_df)
70
-
71
  fig = px.scatter_3d(
72
- x=X_transform[:,0],
73
- y=X_transform[:,1],
74
- z=X_transform[:,2],
75
- color=[str(y_i) for y_i in y], hover_name=cuad_df['filename'].to_list())
 
 
 
 
 
 
76
 
77
  fig.update_layout(
78
  legend=dict(
@@ -85,8 +88,57 @@ with st.spinner('⚙️ Loading model...'):
85
  width=1100,
86
  height=900
87
  )
88
- st.plotly_chart(fig, use_container_width=True, height=1600)
89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  add_email_signup_form()
91
 
92
  add_footer()
 
1
  import os
2
  import joblib
3
 
4
+ from copy import deepcopy
5
+
6
  import pandas as pd
7
  import plotly.express as px
8
 
 
41
  We've trained a model to group contracts into similar types.
42
  The plot below shows a sample set of contracts that have been automatically grouped together.
43
  Each point in the plot represents how the model interprets a contract, the closer together a pair of points are, the more similar they appear to the model.
44
+ Similar documents are grouped by color.
45
  \n**TIP:** Hover over each point to see the filename of the contract. Groups can be added or removed by clicking on the symbol in the plot legend.
46
  """)
47
  st.write("**👈 Upload your own contracts on the left (as .txt files)** and hit the button **Organise Data** to see how your own contracts can be grouped together")
 
59
  df = pd.read_json(DATA_FILENAME)
60
  return df
61
 
def get_transform_and_predictions(model, X):
    """Return the embedded documents and their predicted cluster labels.

    Args:
        model: A fitted pipeline that supports slicing (``model[:2]`` and
            ``model[2:]``), e.g. a scikit-learn ``Pipeline``. The first two
            stages produce the embedding; the remaining stage(s) predict
            labels from it.
        X: The raw documents to embed and cluster.

    Returns:
        A ``(X_transform, y)`` tuple: the output of the first two stages,
        and the labels predicted from that same output.
    """
    # Embed once and reuse the result for prediction. Calling
    # ``model.predict(X)`` would run the first two stages a second time.
    X_transform = model[:2].transform(X)
    y = model[2:].predict(X_transform)
    return X_transform, y
66
 
67
+ def generate_plot(X, y, filenames):
 
 
 
 
 
68
  fig = px.scatter_3d(
69
+ x=X[:,0],
70
+ y=X[:,1],
71
+ z=X[:,2],
72
+ color=[str(y_i) for y_i in y], hover_name=filenames)
73
+
74
+ fig.update_traces(
75
+ marker_size=8,
76
+ marker_line=dict(width=2),
77
+ selector=dict(mode='markers')
78
+ )
79
 
80
  fig.update_layout(
81
  legend=dict(
 
88
  width=1100,
89
  height=900
90
  )
 
91
 
92
+ return fig
93
+
# Sidebar controls: contract upload plus the button that triggers re-clustering.
uploaded_files = st.sidebar.file_uploader("Select contracts to organise ", accept_multiple_files=True)

button = st.sidebar.button('Organise Contracts', type='primary', use_container_width=True)

# NOTE(review): block nesting below is reconstructed from a flattened view of
# the file — confirm it against the original indentation.
with st.container():
    with st.spinner('⚙️ Loading model...'):
        cuad_tfidf_umap_kmeans = load_model()
        cuad_df = load_dataset()

        # Only the first 500 characters of each contract are fed to the model.
        X = [text[:500] for text in cuad_df['text'].to_list()]
        filenames = cuad_df['filename'].to_list()

        X_transform, y = get_transform_and_predictions(cuad_tfidf_umap_kmeans, X)

        fig = generate_plot(X_transform, y, filenames)

        # Keep the element handle so the sample plot can be cleared later.
        figure = st.plotly_chart(fig, use_container_width=True)

    if button:
        # Remove the sample plot before showing results for the uploads.
        figure.empty()

        with st.spinner('⚙️ Training model...'):

            if not uploaded_files or len(uploaded_files) < 2:
                st.write(
                    "Please add at least two contracts"
                )
            else:
                # Fewer clusters for small uploads, and never more clusters
                # than documents (k-means cannot fit that).
                n_clusters = 3 if len(uploaded_files) < 10 else 8
                n_clusters = min(n_clusters, len(uploaded_files))

                # Decode before truncating so the cut is at 500 characters,
                # not 500 bytes (which could split a multi-byte character).
                # getvalue() returns the full content regardless of the
                # current read position of the upload buffer.
                X_train = [
                    uploaded_file.getvalue().decode("utf-8", errors="replace")[:500]
                    for uploaded_file in uploaded_files
                ]
                filenames = [uploaded_file.name for uploaded_file in uploaded_files]

                # Refit a copy so the pretrained sample model stays untouched.
                tfidf_umap_kmeans = deepcopy(cuad_tfidf_umap_kmeans)
                tfidf_umap_kmeans.set_params(kmeans__n_clusters=n_clusters)
                tfidf_umap_kmeans.fit(X_train)

                # Use the model just fitted on the uploads, not the pretrained one.
                X_transform, y = get_transform_and_predictions(tfidf_umap_kmeans, X_train)

                fig = generate_plot(X_transform, y, filenames)

                st.write("**Your organised contracts:**")

                st.plotly_chart(fig, use_container_width=True)

142
  add_email_signup_form()
143
 
144
  add_footer()