Upload 4 files

- .gitattributes +2 -0
- app.py +307 -0
- contexts-emb.txt +3 -0
- requirements.txt +5 -0
- synthetic-dataset.csv +3 -0
.gitattributes
CHANGED
@@ -32,3 +32,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+contexts-emb.txt filter=lfs diff=lfs merge=lfs -text
+synthetic-dataset.csv filter=lfs diff=lfs merge=lfs -text
app.py
ADDED
@@ -0,0 +1,307 @@
+import os
+os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
+import json
+import numpy as np
+import pandas as pd
+import torch
+# import faiss  # only needed if the FAISS index path below is enabled
+from sentence_transformers import util
+from sentence_transformers.cross_encoder import CrossEncoder
+import streamlit as st
+
+
+def get_embeddings_from_contexts(model, contexts):
+    """
+    Takes a list of contexts and returns their embeddings.
+
+    :param model: the model used to compute the embeddings
+    :param contexts: a list of strings, each string a context
+    :return: the embeddings of the contexts
+    """
+    return model.encode(contexts)
+
+
+def load_semantic_search_model(model_name):
+    """
+    Loads a sentence-level embedding model.
+
+    :param model_name: the name of the model to load
+    :return: a SentenceTransformer object
+    """
+    from sentence_transformers import SentenceTransformer
+
+    return SentenceTransformer(model_name)
+
+
+def convert_embeddings_to_faiss_index(embeddings, context_ids):
+    """
+    Converts the embeddings to a numpy array, wraps a flat inner-product index
+    in an IndexIDMap, adds the embeddings with their context IDs, and moves
+    the index to the GPU.
+
+    :param embeddings: the embeddings to convert to a FAISS index
+    :param context_ids: the IDs of the contexts
+    :return: a GPU-resident index
+    """
+    import faiss  # local import: the module-level import is commented out
+
+    embeddings = np.array(embeddings).astype("float32")  # Step 1: change data type
+
+    index = faiss.IndexFlatIP(embeddings.shape[1])  # Step 2: instantiate the index
+    index = faiss.IndexIDMap(index)  # Step 3: wrap it in an IndexIDMap
+
+    index.add_with_ids(embeddings, context_ids)  # Step 4: add vectors and their IDs
+
+    res = faiss.StandardGpuResources()  # Step 5: instantiate the GPU resources
+    gpu_index = faiss.index_cpu_to_gpu(res, 0, index)  # Step 6: move the index to the GPU
+    return gpu_index
+
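+# CPU-only variant (a sketch, assuming the faiss-cpu package): without a GPU,
+# steps 5-6 above are unavailable and the step-4 index can be searched directly.
+# def convert_embeddings_to_faiss_index_cpu(embeddings, context_ids):
+#     embeddings = np.array(embeddings).astype("float32")
+#     index = faiss.IndexIDMap(faiss.IndexFlatIP(embeddings.shape[1]))
+#     index.add_with_ids(embeddings, context_ids)
+#     return index
+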
+def vector_search(query, model, index, num_results=20):
+    """Transforms the query into a vector using a pretrained sentence-level
+    model and finds similar vectors using FAISS.
+    """
+    vector = model.encode(list(query))
+    D, I = index.search(np.array(vector).astype("float32"), k=num_results)
+    return D, I
+
+
+def id2details(df, I, column):
+    """Returns the values of `column` for the rows matched by the FAISS hits."""
+    return [list(df[df.index.values == idx][column])[0] for idx in I[0]]
+
+
+def combine(user_query, model, index, df, column, num_results=10):
+    """
+    Takes a user query, a model, an index, a dataframe, and a column name, and
+    returns the top `num_results` rows from the dataframe.
+
+    :param user_query: the query to search for
+    :param model: the embedding model
+    :param index: the FAISS index over the vectorized dataframe
+    :param df: the dataframe containing the data
+    :param column: the dataframe column containing the text to search
+    :param num_results: the number of results to return, defaults to 10 (optional)
+    :return: the top `num_results` results from the vector search
+    """
+    D, I = vector_search([user_query], model, index, num_results=num_results)
+    return id2details(df, I, column)
+
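+# combine/id2details are exercised only on the FAISS path (when an index is
+# passed to evaluate_semantic_model below); the default path uses get_context.
+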
+def get_context(model, query, contexts, contexts_emb, top_k=100):
+    """
+    Given a query, a dataframe of contexts, and their embeddings, returns the
+    top_k contexts with the highest cosine similarity to the query.
+
+    :param model: the embedding model
+    :param query: the query string
+    :param contexts: dataframe of contexts with a `premise` column
+    :param contexts_emb: the embeddings of the contexts
+    :param top_k: the number of contexts to return, defaults to 100 (optional)
+    :return: a list of the top_k contexts most similar to the query
+    """
+    # Encode the query; the context embeddings are precomputed
+    query_emb = model.encode(query)
+    query_emb = torch.from_numpy(query_emb.reshape(1, -1))
+    contexts_emb = torch.from_numpy(contexts_emb)
+    # Compute similarity scores between the query and all context embeddings
+    scores = util.cos_sim(query_emb, contexts_emb)[0].cpu().tolist()
+    # Combine contexts & scores, then keep the top_k contexts by score
+    contexts_score_pairs = list(zip(contexts.premise.tolist(), scores))
+    result = sorted(contexts_score_pairs, key=lambda x: x[1], reverse=True)[:top_k]
+    top_context = [c for c, s in result]
+    return top_context
+
+
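+# Illustrative use of get_context (the values here are examples, not app state):
+#   model = load_semantic_search_model("distiluse-base-multilingual-cased-v1")
+#   top = get_context(model, "ma question", contexts, context_emb, top_k=5)
+#   # `top` is the list of the 5 premises closest to the query by cosine similarity
+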
+def get_answer(model, query, context):
+    """
+    Given a generative model, a query, and a context, returns the answer.
+
+    :param model: a text-generation pipeline
+    :param query: the question to ask
+    :param context: the context for the question
+    :return: a string
+    """
+    formatted_query = f"{query}\n{context}"
+    res = model(formatted_query)
+    return res[0]["generated_text"]
+
+
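+# Note: get_answer is not called anywhere in this app; it assumes a callable
+# returning [{"generated_text": ...}] (e.g. a transformers text2text pipeline),
+# not one of the CrossEncoder rerankers loaded below.
+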
+def evaluate_semantic_model(model, question, contexts, contexts_emb, index=None):
+    """
+    For the given question, uses the model to find the most similar contexts.
+
+    :param model: the embedding model to evaluate
+    :param question: the question to search with
+    :param contexts: the dataframe of contexts
+    :param contexts_emb: the embeddings of the contexts
+    :param index: an optional FAISS index over the context embeddings
+    :return: the predicted contexts
+    """
+    # Use FAISS when an index is supplied, otherwise rank by cosine similarity
+    predictions = (
+        combine(question, model, index, contexts, "premise")
+        if index
+        else get_context(model, question, contexts, contexts_emb)
+    )
+    return predictions
+
+@st.experimental_singleton
+def load_models():
+    semantic_search_model = load_semantic_search_model("distiluse-base-multilingual-cased-v1")
+    model_nli_stsb = CrossEncoder('ssilwal/nli-stsb-fr', max_length=512, device='cpu')
+    model_nli = CrossEncoder('ssilwal/CASS-civile-nli', max_length=512, device='cpu')
+    model_baseline = CrossEncoder('amberoad/bert-multilingual-passage-reranking-msmarco', max_length=512, device='cpu')
+
+    df = pd.read_csv('synthetic-dataset.csv')
+    contexts = df.premise.unique()
+    contexts = pd.DataFrame(contexts, columns=['premise'])
+    context_emb = np.loadtxt('contexts-emb.txt', dtype=np.float32)
+
+    return semantic_search_model, model_nli, model_nli_stsb, model_baseline, contexts, context_emb
+
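+# contexts-emb.txt is expected to hold one embedding row per unique premise in
+# synthetic-dataset.csv, in the same order, so that get_context can zip
+# premises and scores together correctly.
+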
+def callback(state, value):
+    # No-op placeholder for the Streamlit widget on_change callbacks below.
+    return
+
+
+if 'slider' not in st.session_state:
+    st.session_state['slider'] = 0
+
+if 'radio' not in st.session_state:
+    st.session_state['radio'] = 'Model 1'
+
+if 'show' not in st.session_state:
+    st.session_state['show'] = False
+
+if 'results' not in st.session_state:
+    st.session_state['results'] = None
+
+
+semantic_search_model, model_nli, model_nli_stsb, model_baseline, contexts, context_emb = load_models()
+
+@st.cache(suppress_st_warning=True)
+def run_inference(model_name, query):
+    # Retrieve candidate contexts with the bi-encoder (cosine-similarity path;
+    # pass the FAISS index instead to use vector search)
+    pred = evaluate_semantic_model(
+        semantic_search_model,
+        query,
+        contexts,
+        context_emb,
+        # index,  # uncomment to use the FAISS path
+    )
+
+    # Build the (query, candidate) sentence combinations for reranking
+    sentence_combinations = [[query, corpus_sentence] for corpus_sentence in pred]
+
+    # Compute similarity scores for these combinations and rerank
+    if model_name == 'Model 1':
+        similarity_scores = model_nli.predict(sentence_combinations)
+        scores = [(score_max[0], idx) for idx, score_max in enumerate(similarity_scores)]
+        sim_scores_argsort = sorted(scores, key=lambda x: x[0], reverse=True)
+        results = [pred[idx] for _, idx in list(sim_scores_argsort)[:int(top_K)]]
+
+    if model_name == 'Model 2':
+        similarity_scores = model_nli_stsb.predict(sentence_combinations)
+        sim_scores_argsort = reversed(np.argsort(similarity_scores))
+        results = [pred[idx] for idx in list(sim_scores_argsort)[:int(top_K)]]
+
+    if model_name == 'Model 3':
+        similarity_scores = model_baseline.predict(sentence_combinations)
+        scores = [(score_max[0], idx) for idx, score_max in enumerate(similarity_scores)]
+        sim_scores_argsort = sorted(scores, key=lambda x: x[0], reverse=True)
+        results = [pred[idx] for _, idx in list(sim_scores_argsort)[:int(top_K)]]
+
+    return results
+
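+# Models 1 and 3 return a score vector per sentence pair (only the first entry
+# is used for ranking), while Model 2 returns a single score per pair, hence
+# the two ranking code paths above.
+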
+# Only needed for the FAISS path:
+# index = convert_embeddings_to_faiss_index(context_emb, contexts.index.values)
+
+
+query = st.text_input('Civil Legal Query', 'Quelles protections la Loi sur la protection du consommateur accorde-t-elle aux individus?')
+top_K = st.text_input('Choose Number of Results:', '10')
+
+
+model_name = st.radio(
+    "Choose Model",
+    ("Model 1", "Model 2", "Model 3"),
+    key='radio', on_change=callback, args=('radio', 'Model 1')
+)
+
+
+if st.button('Run', key='run'):
+    results = run_inference(model_name, query)
+
+    st.session_state['show'] = True
+    st.session_state['results'] = results
+    st.session_state['query'] = query
+    model_dict = {'Model 1': 'NLI-Syn', 'Model 2': 'NLI-stsb', 'Model 3': 'NLI-baseline'}
+    st.session_state['model'] = model_dict[model_name]
+
+
+if st.session_state['show'] and st.session_state['results'] is not None:
+    st.write("-" * 50)
+    for result in st.session_state['results']:
+        line = f'Context: {result}\n\n'
+        st.write(line)
+
+    rate = st.slider('Please rate this output', min_value=0, max_value=5, key='slider', on_change=callback, args=('slider', '0'))
+
+    if st.session_state['slider'] != 0:
+        rate = st.session_state['slider']
+        st.write(f'You rated {rate}')
+
+
+if st.button('Submit', key='rate'):
+    if st.session_state['results'] is not None:
+        item = {'query': st.session_state['query'], 'results': st.session_state['results'], 'model': st.session_state['model'], 'rating': st.session_state['slider']}
+        # Append the rating to human.json, creating the file on first use
+        try:
+            with open('human.json', 'r') as file:
+                archive = json.load(file)
+            archive.append(item)
+            with open('human.json', 'w') as file:
+                json.dump(archive, file)
+        except FileNotFoundError:
+            data = [item]
+            with open('human.json', 'w') as file:
+                json.dump(data, file)
contexts-emb.txt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3d46f89976c4ea5e8c573950b51c94db43b03d231c557096a8273d75cb506576
+size 76331907
requirements.txt
ADDED
@@ -0,0 +1,5 @@
+numpy
+pandas
+torch
+sentence_transformers
+streamlit
synthetic-dataset.csv
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9b05dffb7e3b522a85fe20263c22ab91430f6e9c535705515dd6bf869a20199d
+size 38688491