luis-mi committed
Commit c5685a6 β€’ 1 Parent(s): b2285b0

Upload 24 files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ static/images/screen_recording_busqueda_final_2.gif filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,18 @@
+ # Use an official Python runtime as a parent image
+ FROM python:3.11-slim
+
+ # Set the working directory in the container
+ WORKDIR /app
+
+ # Install any needed packages specified in requirements.txt
+ COPY requirements.txt /app/
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Copy the rest of your application's code
+ COPY . /app
+
+ # Make port 7860 available to the world outside this container
+ EXPOSE 7860
+
+ # Run Home.py with Streamlit when the container launches
+ CMD ["streamlit", "run", "Home.py", "--server.address=0.0.0.0", "--server.port=7860"]
Home.py ADDED
@@ -0,0 +1,46 @@
+ import streamlit as st
+ import base64
+
+ ## PAGE CONFIGURATION
+ st.set_page_config(page_title="BΓΊsqueda Aumentada MSM para Impuestos Especiales",
+ page_icon="πŸ”",
+ layout="centered",
+ initial_sidebar_state="auto",
+ menu_items=None)
+
+ st.image('./static/images/cervezas-mahou.jpeg', width=700)
+
+ # Mensaje de bienvenida
+ st.markdown(
+ """
+ # Β‘Bienvenido a BΓΊsqueda Aumentada MSM para Impuestos Especiales! πŸ”πŸ“š
+
+ Esta aplicaciΓ³n es una herramienta diseΓ±ada especΓ­ficamente para la exploraciΓ³n y anΓ‘lisis de datos en el Γ‘mbito de Impuestos Especiales utilizando el poder de la Inteligencia Artificial.
+
+ **πŸ‘ˆ Selecciona una opciΓ³n en la barra lateral** para comenzar a explorar las diferentes funcionalidades que ofrece la aplicaciΓ³n.
+ """)
+ file_ = open('./static/images/screen_recording_busqueda_final_2.gif', "rb")
+ contents = file_.read()
+ data_url = base64.b64encode(contents).decode("utf-8")
+ file_.close()
+
+ st.subheader("Uso de la AplicaciΓ³n: πŸ” BΓΊsqueda Aumentada")
+ st.caption("Observa en acciΓ³n cΓ³mo la bΓΊsqueda aumentada con una potente IA simplifica la bΓΊsqueda de informaciΓ³n, todo con una interfaz de usuario fΓ‘cil de usar.")
+ st.markdown(
+ f'<div style="text-align: center;"><img src="data:image/gif;base64,{data_url}" alt="demo gif" style="max-width: 100%; height: auto;"></div>',
+ unsafe_allow_html=True,
+ )
+
+ st.markdown("""
+
+ ### ΒΏQuieres aprender mΓ‘s?
+ - Visita nuestra [pΓ‘gina web](https://tupagina.com)
+ - SumΓ©rgete en nuestra [documentaciΓ³n](https://tudocumentacion.com)
+ - Participa y pregunta en nuestros [foros comunitarios](https://tucomunidad.com)
+
+ ### Explora demos mΓ‘s complejos
+ - Descubre cΓ³mo aplicamos la IA para [analizar datasets especializados](https://tulinkdedataset.com)
+ - Explora [bases de datos de acceso pΓΊblico](https://tulinkdedatasetpublico.com) y ve la IA en acciΓ³n
+ """,
+ unsafe_allow_html=True
+ )
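Home.py base64-inlines the demo GIF so it can be centered with raw HTML. A lighter-weight alternative sketch (not what this commit ships; it avoids embedding the ~44 MB file as base64 on every page load) would be:

# Sketch: let Streamlit serve the GIF directly instead of inlining it as base64
st.image('./static/images/screen_recording_busqueda_final_2.gif', use_column_width=True)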
data/1_IIEE_1_json_data_19_02_2024_22-17-49.json ADDED
The diff for this file is too large to render. See raw diff
 
pages/1_πŸ”_Busqueda_Aumentada.py ADDED
@@ -0,0 +1,377 @@
1
+ from tiktoken import get_encoding, encoding_for_model
2
+ from utils.weaviate_interface_v3_spa import WeaviateClient, WhereFilter
3
+ from templates.prompt_templates_spa import question_answering_prompt_series_spa
4
+ from utils.openai_interface_spa import GPT_Turbo
5
+ from openai import BadRequestError
6
+ from utils.app_features_spa import (convert_seconds, generate_prompt_series, search_result,
7
+ validate_token_threshold, load_content_cache, load_data, expand_content)
8
+ from utils.reranker_spa import ReRanker
9
+ from loguru import logger
10
+ import streamlit as st
11
+ import os
12
+
13
+ # load environment variables
14
+ from dotenv import load_dotenv
15
+ load_dotenv('.env', override=True)
16
+
17
+ ## PAGE CONFIGURATION
18
+ st.set_page_config(page_title="Busqueda Aumentada",
19
+ page_icon="πŸ”",
20
+ layout="wide",
21
+ initial_sidebar_state="auto",
22
+ menu_items=None)
23
+
24
+ ## DATA + CACHE
25
+ data_path = 'data/1_IIEE_1_json_data_19_02_2024_22-17-49.json'
26
+ cache_path = ''
27
+ data = load_data(data_path)
28
+ cache = None # Initialize cache as None
29
+
30
+ # Check if the cache file exists before attempting to load it
31
+ if os.path.exists(cache_path):
32
+ cache = load_content_cache(cache_path)
33
+ else:
34
+ logger.warning(f"Cache file {cache_path} not found. Proceeding without cache.")
35
+
36
+ #creates list of guests for sidebar
37
+ guest_list = sorted(list(set([d['document_title'] for d in data])))
38
+
39
+ with st.sidebar:
40
+ st.subheader("Selecciona tu Base de datos πŸ—ƒοΈ")
41
+ client_type = st.radio(
42
+ "Selecciona el modo de acceso:",
43
+ ('Cloud', 'Local'),
44
+ help='Elige un repositorio para determinar el conjunto de datos sobre el cual realizarΓ‘s tu bΓΊsqueda. "Cloud" te permite acceder a datos alojados en nuestros servidores seguros, mientras que "Local" es para trabajar con datos alojados localmente en tu mΓ‘quina.'
45
+ )
46
+ if client_type == 'Cloud':
47
+ api_key = st.secrets['WEAVIATE_CLOUD_API_KEY']
48
+ url = st.secrets['WEAVIATE_CLOUD_ENDPOINT']
49
+
50
+ weaviate_client = WeaviateClient(
51
+ endpoint=url,
52
+ api_key=api_key,
53
+ # model_name_or_path='./models/finetuned-all-MiniLM-L6-v2-300',
54
+ model_name_or_path="intfloat/multilingual-e5-small",
55
+ # openai_api_key=os.environ['OPENAI_API_KEY']
56
+ )
57
+ available_classes=sorted(weaviate_client.show_classes())
58
+ logger.info(available_classes)
59
+ logger.info(f"Endpoint: {client_type} | Classes: {available_classes}")
60
+ elif client_type == 'Local':
61
+ url = st.secrets['WEAVIATE_LOCAL_ENDPOINT']
62
+ weaviate_client = WeaviateClient(
63
+ endpoint=url,
64
+ # api_key=api_key,
65
+ # model_name_or_path='./models/finetuned-all-MiniLM-L6-v2-300',
66
+ model_name_or_path="intfloat/multilingual-e5-small",
67
+ # openai_api_key=os.environ['OPENAI_API_KEY']
68
+ )
69
+ available_classes=sorted(weaviate_client.show_classes())
70
+ logger.info(f"Endpoint: {client_type} | Classes: {available_classes}")
71
+
72
+ def main():
73
+
74
+ # Define the available user selected options
75
+ available_models = ['gpt-3.5-turbo', 'gpt-4-1106-preview']
76
+ # Define system prompts
77
+
78
+ # Initialize selected options in session state
79
+ if "openai_data_model" not in st.session_state:
80
+ st.session_state["openai_data_model"] = available_models[0]
81
+
82
+ if 'class_name' not in st.session_state:
83
+ st.session_state['class_name'] = None
84
+
85
+ with st.sidebar:
86
+ st.session_state['class_name'] = st.selectbox(
87
+ label='Repositorio:',
88
+ options=available_classes,
89
+ index=None,
90
+ placeholder='Repositorio',
91
+ help='Elige un repositorio para determinar el conjunto de datos sobre el cual realizarΓ‘s tu bΓΊsqueda.'
92
+ )
93
+ # Check if the collection name has been selected
94
+ class_name = st.session_state['class_name']
95
+ if class_name:
96
+ st.success(f"Repositorio seleccionado βœ…: {st.session_state['class_name']}")
97
+
98
+ else:
99
+ st.warning("πŸŽ—οΈ No olvides seleccionar el repositorio πŸ‘† a consultar πŸ—„οΈ.")
100
+ st.stop() # Stop execution of the script
101
+
102
+ model_choice = st.selectbox(
103
+ label="Elige un modelo de OpenAI",
104
+ options=available_models,
105
+ index= available_models.index(st.session_state["openai_data_model"]),
106
+ help='Escoge entre diferentes modelos de OpenAI para generar respuestas a tus consultas. Cada modelo tiene distintas capacidades y limitaciones.'
107
+ )
108
+ st.sidebar.make_llm_call = st.checkbox(
109
+ label="Activar GPT",
110
+ help='Marca esta casilla para activar la generaciΓ³n de texto con GPT. Esto te permitirΓ‘ obtener respuestas automΓ‘ticas a tus consultas.'
111
+ )
112
+
113
+ with st.expander("Filtros de Busqueda"):
114
+ guest_input = st.selectbox(
115
+ label='SelecciΓ³n de documentos',
116
+ options=guest_list,
117
+ index=None,
118
+ placeholder='Documento',
119
+ help='Elige un documento especΓ­fico del repositorio para afinar tu bΓΊsqueda a datos relevantes.'
120
+ )
121
+
122
+ with st.expander("Parametros de Busqueda"):
123
+ retriever_choice = st.selectbox(
124
+ label="Selecciona un mΓ©todo",
125
+ options=["Hybrid", "Vector", "Keyword"],
126
+ help='Determina el mΓ©todo de recuperaciΓ³n de informaciΓ³n: "Hybrid" combina bΓΊsqueda por palabras clave y por similitud semΓ‘ntica, "Vector" usa embeddings de texto para encontrar coincidencias semΓ‘nticas, y "Keyword" realiza una bΓΊsqueda tradicional por palabras clave.'
127
+ )
128
+
129
+ reranker_enabled = st.checkbox(
130
+ label="Activar Reranker",
131
+ value=True,
132
+ help='Activa esta opciΓ³n para ordenar los resultados de la bΓΊsqueda segΓΊn su relevancia, utilizando un modelo de reordenamiento adicional.'
133
+ )
134
+
135
+ alpha_input = st.slider(
136
+ label='Alpha para motor hΓ­brido',
137
+ min_value=0.00,
138
+ max_value=1.00,
139
+ value=0.40,
140
+ step=0.05,
141
+ help='Ajusta el parΓ‘metro alfa para equilibrar los resultados entre los mΓ©todos de bΓΊsqueda por vector y por palabra clave en el motor hΓ­brido.'
142
+ )
143
+
144
+ retrieval_limit = st.slider(
145
+ label='Resultados a Reranker',
146
+ min_value=10,
147
+ max_value=300,
148
+ value=100,
149
+ step=10,
150
+ help='Establece el nΓΊmero de resultados que se recuperarΓ‘n antes de aplicar el reordenamiento.'
151
+ )
152
+
153
+ top_k_limit = st.slider(
154
+ label='Top K Limit',
155
+ min_value=1,
156
+ max_value=5,
157
+ value=3,
158
+ step=1,
159
+ help='Define el nΓΊmero mΓ‘ximo de resultados a mostrar despuΓ©s de aplicar el reordenamiento.'
160
+ )
161
+
162
+ temperature_input = st.slider(
163
+ label='Temperatura',
164
+ min_value=0.0,
165
+ max_value=1.0,
166
+ value=0.10,
167
+ step=0.10,
168
+ help='Ajusta la temperatura para la generaciΓ³n de texto con GPT, lo que influirΓ‘ en la creatividad de las respuestas.'
169
+ )
170
+
171
+ logger.info(weaviate_client.display_properties)
172
+
173
+ def perform_search(client, retriever_choice, query, class_name, search_limit, guest_filter, display_properties, alpha_input):
174
+ if retriever_choice == "Keyword":
175
+ return weaviate_client.keyword_search(
176
+ request=query,
177
+ class_name=class_name,
178
+ limit=search_limit,
179
+ where_filter=guest_filter,
180
+ display_properties=display_properties
181
+ ), "Resultados de la Busqueda - Motor: Keyword: "
182
+ elif retriever_choice == "Vector":
183
+ return weaviate_client.vector_search(
184
+ request=query,
185
+ class_name=class_name,
186
+ limit=search_limit,
187
+ where_filter=guest_filter,
188
+ display_properties=display_properties
189
+ ), "Resultados de la Busqueda - Motor: Vector"
190
+ elif retriever_choice == "Hybrid":
191
+ return weaviate_client.hybrid_search(
192
+ request=query,
193
+ class_name=class_name,
194
+ alpha=alpha_input,
195
+ limit=search_limit,
196
+ properties=["content"],
197
+ where_filter=guest_filter,
198
+ display_properties=display_properties
199
+ ), "Resultados de la Busqueda - Motor: Hybrid"
200
+
201
+
202
+ ## RERANKER
203
+ reranker = ReRanker(model_name='cross-encoder/ms-marco-MiniLM-L-12-v2')
204
+
205
+ ## LLM
206
+ model_name = model_choice
207
+ llm = GPT_Turbo(model=model_name, api_key=st.secrets['OPENAI_API_KEY'])
208
+ encoding = encoding_for_model(model_name)
209
+
210
+
211
+ ########################
212
+ ## SETUP MAIN DISPLAY ##
213
+ ########################
214
+ st.image('./static/images/cervezas-mahou.jpeg', width=300)
215
+ st.subheader(f"βœ¨πŸ”πŸ“š **BΓΊsqueda Aumentada** πŸ“–πŸ”βœ¨ Impuestos Especiales ")
216
+ st.caption("Descubre insights ocultos y responde a tus preguntas especializadas utilizando el poder de la IA")
217
+ st.write('\n')
218
+
219
+ query = st.text_input('Escribe tu pregunta aquΓ­: ')
220
+ st.write('\n\n\n\n\n')
221
+
222
+ ############
223
+ ## SEARCH ##
224
+ ############
225
+ if query:
226
+ # make hybrid call to weaviate
227
+ guest_filter = WhereFilter(
228
+ path=['document_title'],
229
+ operator='Equal',
230
+ valueText=guest_input).todict() if guest_input else None
231
+
232
+
233
+ # Determine the appropriate limit based on reranking
234
+ search_limit = retrieval_limit if reranker_enabled else top_k_limit
235
+
236
+ # Perform the search
237
+ query_response, subheader_msg = perform_search(
238
+ client=weaviate_client,
239
+ retriever_choice=retriever_choice,
240
+ query=query,
241
+ class_name=class_name,
242
+ search_limit=search_limit,
243
+ guest_filter=guest_filter,
244
+ display_properties=weaviate_client.display_properties,
245
+ alpha_input=alpha_input if retriever_choice == "Hybrid" else None
246
+ )
247
+
248
+
249
+ # Rerank the results if enabled
250
+ if reranker_enabled:
251
+ search_results = reranker.rerank(
252
+ results=query_response,
253
+ query=query,
254
+ apply_sigmoid=True,
255
+ top_k=top_k_limit
256
+ )
257
+ subheader_msg += " Reranked"
258
+ else:
259
+ # Use the results directly if reranking is not enabled
260
+ search_results = query_response
261
+
262
+ logger.info(search_results)
263
+ expanded_response = expand_content(search_results, cache, content_key='doc_id', create_new_list=True)
264
+
265
+ # validate token count is below threshold
266
+ token_threshold = 8000 if model_name == 'gpt-3.5-turbo-16k' else 3500
267
+ valid_response = validate_token_threshold(
268
+ ranked_results=expanded_response,
269
+ base_prompt=question_answering_prompt_series_spa,
270
+ query=query,
271
+ tokenizer=encoding,
272
+ token_threshold=token_threshold,
273
+ verbose=True
274
+ )
275
+ logger.info(valid_response)
276
+ #########
277
+ ## LLM ##
278
+ #########
279
+ make_llm_call = st.sidebar.make_llm_call
280
+ # prep for streaming response
281
+ st.subheader("Respuesta GPT:")
282
+ with st.spinner('Generando Respuesta...'):
283
+ st.markdown("----")
284
+ # Creates container for LLM response
285
+ chat_container, response_box = [], st.empty()
286
+
287
+ # generate LLM prompt
288
+ prompt = generate_prompt_series(query=query, results=valid_response)
289
+ # logger.info(prompt)
290
+ if make_llm_call:
291
+
292
+ try:
293
+ for resp in llm.get_chat_completion(
294
+ prompt=prompt,
295
+ temperature=temperature_input,
296
+ max_tokens=350, # expand for more verbose answers
297
+ show_response=True,
298
+ stream=True):
299
+
300
+ # inserts chat stream from LLM
301
+ with response_box:
302
+ content = resp.choices[0].delta.content
303
+ if content:
304
+ chat_container.append(content)
305
+ result = "".join(chat_container).strip()
306
+ st.write(f'{result}')
307
+ except BadRequestError:
308
+ logger.info('Making request with smaller context...')
309
+ valid_response = validate_token_threshold(
310
+ ranked_results=search_results,
311
+ base_prompt=question_answering_prompt_series_spa,
312
+ query=query,
313
+ tokenizer=encoding,
314
+ token_threshold=token_threshold,
315
+ verbose=True
316
+ )
317
+
318
+ # generate LLM prompt
319
+ prompt = generate_prompt_series(query=query, results=valid_response)
320
+ for resp in llm.get_chat_completion(
321
+ prompt=prompt,
322
+ temperature=temperature_input,
323
+ max_tokens=350, # expand for more verbose answers
324
+ show_response=True,
325
+ stream=True):
326
+
327
+ try:
328
+ # inserts chat stream from LLM
329
+ with response_box:
330
+ content = resp.choices[0].delta.content
331
+ if content:
332
+ chat_container.append(content)
333
+ result = "".join(chat_container).strip()
334
+ st.write(f'{result}')
335
+ except Exception as e:
336
+ print(e)
337
+
338
+ ####################
339
+ ## Search Results ##
340
+ ####################
341
+ st.subheader(subheader_msg)
342
+ for i, hit in enumerate(search_results):
343
+ col1, col2 = st.columns([7, 3], gap='large')
344
+ page_url = hit['page_url']
345
+ page_label = hit['page_label']
346
+ document_title = hit['document_title']
347
+ # Assuming 'page_summary' is available and you want to display it
348
+ page_summary = hit.get('page_summary', 'Summary not available')
349
+
350
+ with col1:
351
+ st.markdown(f'''
352
+ <span style="color: #3498db; font-size: 19px; font-weight: bold;">{document_title}</span><br>
353
+ {page_summary}
354
+ [**Página:** {page_label}]({page_url})
355
+ ''', unsafe_allow_html=True)
356
+
357
+ with st.expander("πŸ“„ Clic aquΓ­ para ver contexto:"):
358
+ try:
359
+ content = hit['content']
360
+ st.write(content)
361
+ except Exception as e:
362
+ st.write(f"Error displaying content: {e}")
363
+
364
+ # with col2:
365
+ # # If you have an image or want to display a placeholder image
366
+ # image = "URL_TO_A_PLACEHOLDER_IMAGE" # Replace with a relevant image URL if applicable
367
+ # st.image(image, caption=document_title, width=200, use_column_width=False)
368
+ # st.markdown(f'''
369
+ # <p style="text-align: right;">
370
+ # <b>Document Title:</b> {document_title}<br>
371
+ # <b>File Name:</b> {file_name}<br>
372
+ # </p>''', unsafe_allow_html=True)
373
+
374
+
375
+
376
+ if __name__ == '__main__':
377
+ main()
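Both search pages read their credentials from Streamlit secrets (st.secrets). A minimal .streamlit/secrets.toml sketch with the keys this page expects (the values shown are placeholders, not part of the commit):

OPENAI_API_KEY = "sk-..."
WEAVIATE_CLOUD_API_KEY = "..."
WEAVIATE_CLOUD_ENDPOINT = "https://your-cluster.weaviate.network"
WEAVIATE_LOCAL_ENDPOINT = "http://localhost:8080"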
pages/2_πŸ—£_Busqueda_Conversacional.py ADDED
@@ -0,0 +1,576 @@
1
+ from tiktoken import get_encoding, encoding_for_model
2
+ from utils.weaviate_interface_v3_spa import WeaviateClient, WhereFilter
3
+ from templates.prompt_templates_spa import question_answering_prompt_series_spa
4
+ from utils.openai_interface_spa import GPT_Turbo
5
+ from openai import BadRequestError
6
+ from utils.app_features_spa import (convert_seconds, generate_prompt_series, search_result,
7
+ validate_token_threshold, load_content_cache, load_data, expand_content)
8
+ from utils.reranker_spa import ReRanker
9
+ from openai import OpenAI
10
+
11
+ from loguru import logger
12
+ import streamlit as st
13
+ import os
14
+ import templates.system_prompts as system_prompts
15
+ import base64
16
+ import json
17
+
18
+ # load environment variables
19
+ from dotenv import load_dotenv
20
+ load_dotenv('.env', override=True)
21
+
22
+ ## PAGE CONFIGURATION
23
+ st.set_page_config(page_title="Busqueda Conversacional",
24
+ page_icon="πŸ—£",
25
+ layout="wide",
26
+ initial_sidebar_state="auto",
27
+ menu_items=None)
28
+
29
+ def encode_image(uploaded_file):
30
+ return base64.b64encode(uploaded_file.getvalue()).decode('utf-8')
31
+
32
+ ## DATA + CACHE
33
+ data_path = 'data/1_IIEE_1_json_data_19_02_2024_22-17-49.json'
34
+ cache_path = ''
35
+ data = load_data(data_path)
36
+ cache = None # Initialize cache as None
37
+
38
+ # Check if the cache file exists before attempting to load it
39
+ if os.path.exists(cache_path):
40
+ cache = load_content_cache(cache_path)
41
+ else:
42
+ logger.warning(f"Cache file {cache_path} not found. Proceeding without cache.")
43
+
44
+ #creates list of guests for sidebar
45
+ guest_list = sorted(list(set([d['document_title'] for d in data])))
46
+
47
+ with st.sidebar:
48
+ st.subheader("Selecciona tu Base de datos πŸ—ƒοΈ")
49
+ client_type = st.radio(
50
+ "Selecciona el modo de acceso:",
51
+ ('Cloud', 'Local'),
52
+ help='Elige un repositorio para determinar el conjunto de datos sobre el cual realizarΓ‘s tu bΓΊsqueda. "Cloud" te permite acceder a datos alojados en nuestros servidores seguros, mientras que "Local" es para trabajar con datos alojados localmente en tu mΓ‘quina.'
53
+ )
54
+ if client_type == 'Cloud':
55
+ api_key = st.secrets['WEAVIATE_CLOUD_API_KEY']
56
+ url = st.secrets['WEAVIATE_CLOUD_ENDPOINT']
57
+
58
+ weaviate_client = WeaviateClient(
59
+ endpoint=url,
60
+ api_key=api_key,
61
+ # model_name_or_path='./models/finetuned-all-MiniLM-L6-v2-300',
62
+ model_name_or_path="intfloat/multilingual-e5-small",
63
+ # openai_api_key=os.environ['OPENAI_API_KEY']
64
+ )
65
+ available_classes=sorted(weaviate_client.show_classes())
66
+ logger.info(available_classes)
67
+ logger.info(f"Endpoint: {client_type} | Classes: {available_classes}")
68
+ elif client_type == 'Local':
69
+ url = st.secrets['WEAVIATE_LOCAL_ENDPOINT']
70
+ weaviate_client = WeaviateClient(
71
+ endpoint=url,
72
+ # api_key=api_key,
73
+ # model_name_or_path='./models/finetuned-all-MiniLM-L6-v2-300',
74
+ model_name_or_path="intfloat/multilingual-e5-small",
75
+ # openai_api_key=os.environ['OPENAI_API_KEY']
76
+ )
77
+ available_classes=sorted(weaviate_client.show_classes())
78
+ logger.info(f"Endpoint: {client_type} | Classes: {available_classes}")
79
+
80
+ client = OpenAI(api_key=st.secrets["OPENAI_API_KEY"])
81
+
82
+ def main():
83
+
84
+ # Define the available user selected options
85
+ available_models = ['gpt-3.5-turbo', 'gpt-4-1106-preview']
86
+ # Define system prompts
87
+ system_prompt_list = ["πŸ€–ChatGPT","πŸ§™πŸΎβ€β™‚οΈProfessor Synapse", "πŸ‘©πŸΌβ€πŸ’ΌMarketing Jane"]
88
+
89
+
90
+ # Initialize selected options in session state
91
+ if "openai_data_model" not in st.session_state:
92
+ st.session_state["openai_data_model"] = available_models[0]
93
+ if "system_prompt_data_list" not in st.session_state and "system_prompt_data_model" not in st.session_state:
94
+ # This should be the emoji string the user selected
95
+ st.session_state["system_prompt_data_list"] = system_prompt_list[0]
96
+ # Now we get the corresponding prompt variable using the selected emoji string
97
+ st.session_state["system_prompt_data_model"] = system_prompts.prompt_mapping[system_prompt_list[0]]
98
+
99
+ # logger.debug(f"Assistant: {st.session_state['system_prompt_sync_list']}")
100
+ # logger.debug(f"System Prompt: {st.session_state['system_prompt_sync_model']}")
101
+
102
+ if 'class_name' not in st.session_state:
103
+ st.session_state['class_name'] = None
104
+
105
+ with st.sidebar:
106
+ st.session_state['class_name'] = st.selectbox(
107
+ label='Repositorio:',
108
+ options=available_classes,
109
+ index=None,
110
+ placeholder='Repositorio',
111
+ help='Elige un repositorio para determinar el conjunto de datos sobre el cual realizarΓ‘s tu bΓΊsqueda.'
112
+ )
113
+
114
+ # Check if the collection name has been selected
115
+ class_name = st.session_state['class_name']
116
+ if class_name:
117
+ st.success(f"Repositorio seleccionado βœ…: {st.session_state['class_name']}")
118
+
119
+ else:
120
+ st.warning("πŸŽ—οΈ No olvides seleccionar el repositorio πŸ‘† a consultar πŸ—„οΈ.")
121
+ st.stop() # Stop execution of the script
122
+
123
+ model_choice = st.selectbox(
124
+ label="Elige un modelo de OpenAI",
125
+ options=available_models,
126
+ index= available_models.index(st.session_state["openai_data_model"]),
127
+ help='Escoge entre diferentes modelos de OpenAI para generar respuestas a tus consultas. Cada modelo tiene distintas capacidades y limitaciones.'
128
+ )
129
+
130
+ system_prompt = st.selectbox(
131
+ label="Elige un asistente",
132
+ options=system_prompt_list,
133
+ index=system_prompt_list.index(st.session_state["system_prompt_data_list"]),
134
+ )
135
+
136
+ with st.expander("Filtros de Busqueda"):
137
+ guest_input = st.selectbox(
138
+ label='SelecciΓ³n de Documento',
139
+ options=guest_list,
140
+ index=None,
141
+ placeholder='Documentos',
142
+ help='Elige un documento especΓ­fico del repositorio para afinar tu bΓΊsqueda a datos relevantes.'
143
+ )
144
+ with st.expander("Parametros de Busqueda"):
145
+ retriever_choice = st.selectbox(
146
+ label="Selecciona un mΓ©todo",
147
+ options=["Hybrid", "Vector", "Keyword"],
148
+ help='Determina el mΓ©todo de recuperaciΓ³n de informaciΓ³n: "Hybrid" combina bΓΊsqueda por palabras clave y por similitud semΓ‘ntica, "Vector" usa embeddings de texto para encontrar coincidencias semΓ‘nticas, y "Keyword" realiza una bΓΊsqueda tradicional por palabras clave.'
149
+ )
150
+
151
+ reranker_enabled = st.checkbox(
152
+ label="Activar Reranker",
153
+ value=True,
154
+ help='Activa esta opciΓ³n para ordenar los resultados de la bΓΊsqueda segΓΊn su relevancia, utilizando un modelo de reordenamiento adicional.'
155
+ )
156
+
157
+ alpha_input = st.slider(
158
+ label='Alpha para motor hΓ­brido',
159
+ min_value=0.00,
160
+ max_value=1.00,
161
+ value=0.40,
162
+ step=0.05,
163
+ help='Ajusta el parΓ‘metro alfa para equilibrar los resultados entre los mΓ©todos de bΓΊsqueda por vector y por palabra clave en el motor hΓ­brido.'
164
+ )
165
+
166
+ retrieval_limit = st.slider(
167
+ label='Resultados a Reranker',
168
+ min_value=10,
169
+ max_value=300,
170
+ value=100,
171
+ step=10,
172
+ help='Establece el nΓΊmero de resultados que se recuperarΓ‘n antes de aplicar el reordenamiento.'
173
+ )
174
+
175
+ top_k_limit = st.slider(
176
+ label='Top K Limit',
177
+ min_value=1,
178
+ max_value=5,
179
+ value=3,
180
+ step=1,
181
+ help='Define el nΓΊmero mΓ‘ximo de resultados a mostrar despuΓ©s de aplicar el reordenamiento.'
182
+ )
183
+
184
+ temperature_input = st.slider(
185
+ label='Temperatura',
186
+ min_value=0.0,
187
+ max_value=1.0,
188
+ value=0.20,
189
+ step=0.10,
190
+ help='Ajusta la temperatura para la generaciΓ³n de texto con GPT, lo que influirΓ‘ en la creatividad de las respuestas.'
191
+ )
192
+
193
+ # Update the model choice in session state
194
+ if st.session_state["openai_data_model"]!=model_choice:
195
+ st.session_state["openai_data_model"] = model_choice
196
+ logger.info(f"Data model: {st.session_state['openai_data_model']}")
197
+
198
+ # Update the system prompt choice in session state
199
+ if st.session_state["system_prompt_data_list"] != system_prompt:
200
+ # This should be the emoji string the user selected
201
+ st.session_state["system_prompt_data_list"] = system_prompt
202
+ # Now we get the corresponding prompt variable using the selected emoji string
203
+ selected_prompt_variable = system_prompts.prompt_mapping[system_prompt]
204
+ st.session_state['system_prompt_data_model'] = selected_prompt_variable
205
+ # logger.info(f"System Prompt: {selected_prompt_variable}")
206
+ logger.info(f"Assistant: {st.session_state['system_prompt_data_list']}")
207
+ # logger.info(f"System Prompt: {st.session_state['system_prompt_sync_model']}")
208
+
209
+ logger.info(weaviate_client.display_properties)
210
+
211
+ def database_search(query):
212
+ # Determine the appropriate limit based on reranking
213
+ search_limit = retrieval_limit if reranker_enabled else top_k_limit
214
+
215
+ # make hybrid call to weaviate
216
+ guest_filter = WhereFilter(
217
+ path=['document_title'],
218
+ operator='Equal',
219
+ valueText=guest_input).todict() if guest_input else None
220
+
221
+ try:
222
+ # Perform the search based on retriever_choice
223
+ if retriever_choice == "Keyword":
224
+ query_results = weaviate_client.keyword_search(
225
+ request=query,
226
+ class_name=class_name,
227
+ limit=search_limit,
228
+ where_filter=guest_filter
229
+ )
230
+ elif retriever_choice == "Vector":
231
+ query_results = weaviate_client.vector_search(
232
+ request=query,
233
+ class_name=class_name,
234
+ limit=search_limit,
235
+ where_filter=guest_filter
236
+ )
237
+ elif retriever_choice == "Hybrid":
238
+ query_results = weaviate_client.hybrid_search(
239
+ request=query,
240
+ class_name=class_name,
241
+ alpha=alpha_input,
242
+ limit=search_limit,
243
+ properties=["content"],
244
+ where_filter=guest_filter
245
+ )
246
+ else:
247
+ return json.dumps({"error": "Invalid retriever choice"})
248
+
249
+
250
+ ## RERANKER
251
+ reranker = ReRanker(model_name='cross-encoder/ms-marco-MiniLM-L-12-v2')
252
+ model_name = model_choice
253
+ encoding = encoding_for_model(model_name)
254
+
255
+ # Rerank the results if enabled
256
+ if reranker_enabled:
257
+ search_results = reranker.rerank(
258
+ results=query_results,
259
+ query=query,
260
+ apply_sigmoid=True,
261
+ top_k=top_k_limit
262
+ )
263
+
264
+ else:
265
+ # Use the results directly if reranking is not enabled
266
+ search_results = query_results
267
+
268
+ # logger.debug(search_results)
269
+ # Save search results to session state for later use
270
+ # st.session_state['search_results'] = search_results
271
+ add_to_search_history(query=query, search_results=search_results)
272
+ expanded_response = expand_content(search_results, cache, content_key='doc_id', create_new_list=True)
273
+
274
+ # validate token count is below threshold
275
+ token_threshold = 8000
276
+ valid_response = validate_token_threshold(
277
+ ranked_results=expanded_response,
278
+ base_prompt=question_answering_prompt_series_spa,
279
+ query=query,
280
+ tokenizer=encoding,
281
+ token_threshold=token_threshold,
282
+ verbose=True
283
+ )
284
+
285
+ # generate LLM prompt
286
+ prompt = generate_prompt_series(query=query, results=valid_response)
287
+
288
+ # If the strings in 'prompt' are double-escaped, decode them before dumping to JSON
289
+ # prompt_decoded = prompt.encode().decode('unicode_escape')
290
+
291
+ # Then, when you dump to JSON, it should no longer double-escape the characters
292
+ return json.dumps({
293
+ "query": query,
294
+ "Search Results": prompt,
295
+ }, ensure_ascii=False)
296
+
297
+ except Exception as e:
298
+ # Handle any exceptions and return a JSON formatted error message
299
+ return json.dumps({
300
+ "error": "An error occurred during the search",
301
+ "details": str(e)
302
+ })
303
+
304
+ # When a new message is added, include the type and content
305
+ def add_to_search_history(query, search_results):
306
+ st.session_state["data_search_history"].append({
307
+ "query": query,
308
+ "search_results": search_results,
309
+ })
310
+
311
+ # Function to display search results
312
+ def display_search_results():
313
+ # Loop through each item in the search history
314
+ for search in st.session_state['data_search_history']:
315
+ query = search["query"]
316
+ search_results = search["search_results"]
317
+ # Create an expander for each search query
318
+ with st.expander(f"Pregunta: {query}", expanded=False):
319
+ for i, hit in enumerate(search_results):
320
+ # col1, col2 = st.columns([7, 3], gap='large')
321
+ page_url = hit['page_url']
322
+ page_label = hit['page_label']
323
+ document_title = hit['document_title']
324
+ # Assuming 'page_summary' is available and you want to display it
325
+ page_summary = hit.get('page_summary', 'Summary not available')
326
+
327
+ # with col1:
328
+ st.markdown(f'''
329
+ <span style="color: #3498db; font-size: 19px; font-weight: bold;">{document_title}</span><br>
330
+ {page_summary}
331
+ [**Página:** {page_label}]({page_url})
332
+ ''', unsafe_allow_html=True)
333
+
334
+ # with st.expander("πŸ“„ Clic aquΓ­ para ver contexto:"):
335
+ # try:
336
+ # content = hit['content']
337
+ # st.write(content)
338
+ # except Exception as e:
339
+ # st.write(f"Error displaying content: {e}")
340
+
341
+ # with col2:
342
+ # # If you have an image or want to display a placeholder image
343
+ # image = "URL_TO_A_PLACEHOLDER_IMAGE" # Replace with a relevant image URL if applicable
344
+ # st.image(image, caption=document_title, width=200, use_column_width=False)
345
+ # st.markdown(f'''
346
+ # <p style="text-align: right;">
347
+ # <b>Document Title:</b> {document_title}<br>
348
+ # <b>File Name:</b> {file_name}<br>
349
+ # </p>''', unsafe_allow_html=True)
350
+
351
+ ########################
352
+ ## SETUP MAIN DISPLAY ##
353
+ ########################
354
+
355
+ st.image('./static/images/cervezas-mahou.jpeg', width=400)
356
+ st.subheader(f"βœ¨πŸ—£οΈπŸ“˜ **BΓΊsqueda Conversacional** πŸ’‘πŸ—£οΈβœ¨ - Impuestos Especiales")
357
+ st.write('\n')
358
+ col1, col2 = st.columns([50,50])
359
+
360
+ # Initialize chat history
361
+ if "data_chat_history" not in st.session_state:
362
+ st.session_state["data_chat_history"] = []
363
+
364
+ if "data_search_history" not in st.session_state:
365
+ st.session_state["data_search_history"] = []
366
+
367
+ with col1:
368
+ st.write("Chat History:")
369
+ # Create a container for chat history
370
+ chat_history_container = st.container(height=500, border=True)
371
+ # Display chat messages from history on app rerun
372
+ with chat_history_container:
373
+ for message in st.session_state["data_chat_history"]:
374
+ with st.chat_message(message["role"]):
375
+ st.markdown(message["content"])
376
+ # Function to update chat display
377
+ def update_chat_display():
378
+ with chat_history_container:
379
+ for message in st.session_state["data_chat_history"]:
380
+ with st.chat_message(message["role"]):
381
+ st.markdown(message["content"])
382
+
383
+ if prompt := st.chat_input("What is up?"):
384
+ # Add user message to chat history
385
+ st.session_state["data_chat_history"].append({"role": "user", "content": prompt})
386
+ # Initially display the chat history
387
+ update_chat_display()
388
+ # # Display user message in chat message container
389
+ # with st.chat_message("user"):
390
+ # st.markdown(prompt)
391
+
392
+ with st.spinner('Generando Respuesta...'):
393
+ tools = [
394
+ {
395
+ "type": "function",
396
+ "function": {
397
+ "name": "database_search",
398
+ "description": "Takes the users query about the database and returns the results, extracting info to answer the user's question",
399
+ "parameters": {
400
+ "type": "object",
401
+ "properties": {
402
+ "query": {"type": "string", "description": "query"},
403
+
404
+ },
405
+ "required": ["query"],
406
+ },
407
+ }
408
+ }
409
+ ]
410
+
411
+ # Display live assistant response in chat message container
412
+ with st.chat_message(
413
+ name="assistant",
414
+ avatar="./static/images/openai_purple_logo_hres.jpeg"):
415
+ message_placeholder = st.empty()
416
+
417
+ # Building the messages payload with proper OPENAI API structure
418
+ messages=[
419
+ {"role": "system", "content": st.session_state["system_prompt_data_model"]}
420
+ ] + [
421
+ {"role": m["role"], "content": m["content"]} for m in st.session_state["data_chat_history"]
422
+ ]
423
+ logger.debug(f"Initial Messages: {messages}")
424
+ # call the OpenAI API to get the response
425
+
426
+ RESPONSE = client.chat.completions.create(
427
+ model=st.session_state["openai_data_model"],
428
+ temperature=0.5,
429
+ messages=messages,
430
+ tools=tools,
431
+ tool_choice="auto", # auto is default, but we'll be explicit
432
+ stream=True
433
+ )
434
+ logger.debug(f"First Response: {RESPONSE}")
435
+
436
+
437
+ FULL_RESPONSE = ""
438
+ tool_calls = []
439
+ # build up the response structs from the streamed response, simultaneously sending message chunks to the browser
440
+ for chunk in RESPONSE:
441
+ delta = chunk.choices[0].delta
442
+ # logger.debug(f"chunk: {delta}")
443
+
444
+
445
+
446
+ if delta and delta.content:
447
+ text_chunk = delta.content
448
+ FULL_RESPONSE += str(text_chunk)
449
+ message_placeholder.markdown(FULL_RESPONSE + "β–Œ")
450
+
451
+ elif delta and delta.tool_calls:
452
+ tcchunklist = delta.tool_calls
453
+ for tcchunk in tcchunklist:
454
+ if len(tool_calls) <= tcchunk.index:
455
+ tool_calls.append({"id": "", "type": "function", "function": { "name": "", "arguments": "" } })
456
+ tc = tool_calls[tcchunk.index]
457
+
458
+ if tcchunk.id:
459
+ tc["id"] += tcchunk.id
460
+ if tcchunk.function.name:
461
+ tc["function"]["name"] += tcchunk.function.name
462
+ if tcchunk.function.arguments:
463
+ tc["function"]["arguments"] += tcchunk.function.arguments
464
+ if tool_calls:
465
+ logger.debug(f"tool_calls: {tool_calls}")
466
+ # Define a dictionary mapping function names to actual functions
467
+ available_functions = {
468
+ "database_search": database_search,
469
+ # Add other functions as necessary
470
+ }  # only one function in this example, but you can have multiple
474
+ logger.debug(f"FuncCall Before messages: {messages}")
475
+ # Record the assistant turn that issued the tool calls (the follow-up completion expects it), then process each call
+ messages.append({"role": "assistant", "content": None, "tool_calls": tool_calls})
+ for tool_call in tool_calls:
477
+ # Get the function name and arguments from the tool call
478
+ function_name = tool_call['function']['name']
479
+ function_args = json.loads(tool_call['function']['arguments'])
480
+
481
+ # Get the actual function to call
482
+ function_to_call = available_functions[function_name]
483
+
484
+ # Call the function and get the response
485
+ function_response = function_to_call(**function_args)
486
+
487
+ # Append the function response to the messages list
488
+ messages.append({
489
+ "role": "assistant",
490
+ "tool_call_id": tool_call['id'],
491
+ "name": function_name,
492
+ "content": function_response,
493
+ })
494
+ logger.debug(f"FuncCall After messages: {messages}")
495
+
496
+ RESPONSE = client.chat.completions.create(
497
+ model=st.session_state["openai_data_model"],
498
+ temperature=0.1,
499
+ messages=messages,
500
+ stream=True
501
+ )
502
+ logger.debug(f"Second Response: {RESPONSE}")
503
+
504
+ # build up the response structs from the streamed response, simultaneously sending message chunks to the browser
505
+ for chunk in RESPONSE:
506
+ delta = chunk.choices[0].delta
507
+ # logger.debug(f"chunk: {delta}")
508
+
509
+ if delta and delta.content:
510
+ text_chunk = delta.content
511
+ FULL_RESPONSE += str(text_chunk)
512
+ message_placeholder.markdown(FULL_RESPONSE + "β–Œ")
513
+ # Add assistant response to chat history
514
+ st.session_state["data_chat_history"].append({"role": "assistant", "content": FULL_RESPONSE})
515
+ logger.debug(f"chat_history: {st.session_state['data_chat_history']}")
516
+
517
+ # Next block of code...
518
+
519
+
520
+ ####################
521
+ ## Search Results ##
522
+ ####################
523
+ # st.subheader(subheader_msg)
524
+ with col2:
525
+ st.write("Search Results:")
526
+ with st.container(height=500, border=True):
527
+ # Check if 'data_search_history' is in the session state and not empty
528
+ if 'data_search_history' in st.session_state and st.session_state['data_search_history']:
529
+ display_search_results()
530
+ # # Extract the latest message from the search history
531
+ # latest_search = st.session_state['data_search_history'][-1]
532
+ # query = latest_search["query"]
533
+ # with st.expander(query, expanded=False):
534
+ # # Extract the latest message from the search history
535
+ # latest_search = st.session_state['data_search_history'][-1]
536
+ # query = latest_search["query"]
537
+ # for i, hit in enumerate(latest_search["search_results"]):
538
+ # col1, col2 = st.columns([7, 3], gap='large')
539
+ # episode_url = hit['episode_url']
540
+ # title = hit['title']
541
+ # guest=hit['guest']
542
+ # show_length = hit['length']
543
+ # time_string = convert_seconds(show_length)
544
+ # # content = ranked_response[i]['content'] # Get 'content' from the same index in ranked_response
545
+ # content = hit['content']
546
+
547
+ # with col1:
548
+ # st.write( search_result(i=i,
549
+ # url=episode_url,
550
+ # guest=guest,
551
+ # title=title,
552
+ # content=content,
553
+ # length=time_string),
554
+ # unsafe_allow_html=True)
555
+ # st.write('\n\n')
556
+
557
+ # # with st.container("Episode Summary:"):
558
+ # # try:
559
+ # # ep_summary = hit['summary']
560
+ # # st.write(ep_summary)
561
+ # # except Exception as e:
562
+ # # st.error(f"Error displaying summary: {e}")
563
+
564
+ # with col2:
565
+ # image = hit['thumbnail_url']
566
+ # st.image(image, caption=title.split('|')[0], width=200, use_column_width=False)
567
+ # st.markdown(f'''
568
+ # <p style="text-align: right;">
569
+ # <b>Episode:</b> {title.split('|')[0]}<br>
570
+ # <b>Guest:</b> {hit['guest']}<br>
571
+ # <b>Length:</b> {time_string}
572
+ # </p>''', unsafe_allow_html=True)
573
+
574
+
575
+ if __name__ == '__main__':
576
+ main()
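For reference, the message sequence the second chat.completions call expects after a tool call, as used above (a sketch of the OpenAI tools flow; the id and contents are illustrative):

# Sketch of the messages list right before the follow-up completion (values illustrative)
messages = [
    {"role": "system", "content": "<system prompt>"},
    {"role": "user", "content": "<user question>"},
    # assistant turn that requested the tool call
    {"role": "assistant", "content": None, "tool_calls": [
        {"id": "call_abc123", "type": "function",
         "function": {"name": "database_search", "arguments": "{\"query\": \"...\"}"}}]},
    # tool result fed back to the model
    {"role": "tool", "tool_call_id": "call_abc123", "content": "{\"query\": \"...\", \"Search Results\": \"...\"}"},
]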
pages/__init__.py ADDED
File without changes
requirements.txt ADDED
@@ -0,0 +1,18 @@
+ loguru==0.7.0
+ numpy==1.24.4
+ openai==1.10.0
+ pandas==2.0.3
+ protobuf==4.23.4
+ pyarrow==12.0.1
+ python-dotenv==1.0.0
+ rank-bm25==0.2.2
+ requests==2.31.0
+ requests-oauthlib==1.3.1
+ rich==13.7.0
+ sentence-transformers==2.2.2
+ streamlit==1.31.1
+ tiktoken==0.5.1
+ tokenizers==0.13.3
+ torch==2.0.1
+ transformers==4.33.1
+ weaviate-client==3.25.3
static/.DS_Store ADDED
Binary file (6.15 kB).
 
static/images/cervezas-mahou.jpeg ADDED
static/images/fabrica-mahou-1200x675.jpeg ADDED
static/images/openai_logo.png ADDED
static/images/openai_logo_circle.png ADDED
static/images/openai_purple_logo_hres.jpeg ADDED
static/images/screen_recording_busqueda_final_2.gif ADDED

Git LFS Details

  • SHA256: aa4222b0ddf313d66a88eb14c8589d773f71fa1fd533321e223d7133293486df
  • Pointer size: 133 Bytes
  • Size of remote file: 43.8 MB
utils/.DS_Store ADDED
Binary file (6.15 kB).
 
utils/__init__.py ADDED
File without changes
utils/app_features_spa.py ADDED
@@ -0,0 +1,177 @@
1
+ import time
2
+ import json
3
+ from utils.preprocessing import FileIO
4
+ from typing import List, Optional
5
+ import tiktoken
6
+ from loguru import logger
7
+ from templates.prompt_templates_spa import context_block_spa, question_answering_prompt_series_spa
8
+ import streamlit as st
9
+
10
+ @st.cache_data
11
+ def load_content_cache(data_path: str):
12
+ data = FileIO().load_parquet(data_path)
13
+ content_data = {d['doc_id']: d['content'] for d in data}
14
+ return content_data
15
+
16
+ @st.cache_data
17
+ def load_data(data_path: str):
18
+ with open(data_path, 'r') as f:
19
+ data = json.load(f)
20
+ return data
21
+
22
+ def convert_seconds(seconds: int):
23
+ """
24
+ Converts seconds to a string of format Hours:Minutes:Seconds
25
+ """
26
+ return time.strftime("%H:%M:%S", time.gmtime(seconds))
27
+
28
+ def generate_prompt_series(query: str, results: List[dict]) -> str:
29
+ """
30
+ Generates a prompt for the OpenAI API by joining the context blocks of the top results.
31
+ Provides context to the LLM by supplying the summary, document name, and retrieved content of each result.
32
+
33
+ Args:
34
+ -----
35
+ query : str
36
+ User query
37
+ results : List[dict]
38
+ List of results from the Weaviate client
39
+ """
40
+ context_series = '\n'.join([context_block_spa.format(summary=res['page_summary'],
41
+ document=res['document_title'],
42
+ transcript=res['content']
43
+ )for res in results]).strip()
44
+ prompt = question_answering_prompt_series_spa.format(question=query, series=context_series)
45
+ return prompt
46
+
47
+ def expand_content(ranked_results: List[dict],
48
+ content_cache: Optional[dict] = None,
49
+ content_key: str = 'doc_id',
50
+ create_new_list: bool = False
51
+ ) -> List[dict]:
52
+ '''
53
+ Updates or creates a list of ranked results with content from a cache.
54
+
55
+ This function iterates over a list of dictionaries representing ranked results.
56
+ If a cache is provided, it adds or updates the 'content' key in each dictionary
57
+ with the corresponding content from the cache based on the content_key.
58
+
59
+ Args:
60
+ - ranked_results (List[dict]): A list of dictionaries, each representing a ranked result.
61
+ - content_cache (Optional[dict]): A dictionary that maps content_key to content.
62
+ If None, the content of ranked results will not be updated.
63
+ - content_key (str): The key used in both the ranked results and content cache to match
64
+ the ranked results with their corresponding content in the cache.
65
+ - create_new_list (bool): If True, a new list of dictionaries will be created and
66
+ returned with the content updated. If False, the ranked_results will be updated in place.
67
+
68
+ Returns:
69
+ - List[dict]: A new list with updated content if create_new_list is True; otherwise,
70
+ the original ranked_results list with updated content.
71
+
72
+ Note:
73
+ - If create_new_list is False, the function will mutate the original ranked_results list.
74
+ - The function only updates content if the content_key exists in both the ranked result
75
+ and the content cache.
76
+
77
+ Example:
78
+ ```
79
+ ranked_results = [{'doc_id': '123', 'title': 'Title 1'}, {'doc_id': '456', 'title': 'Title 2'}]
80
+ content_cache = {'123': 'Content for 123', '456': 'Content for 456'}
81
+ updated_results = expand_content(ranked_results, content_cache, create_new_list=True)
82
+ # updated_results is now [{'doc_id': '123', 'title': 'Title 1', 'content': 'Content for 123'},
83
+ # {'doc_id': '456', 'title': 'Title 2', 'content': 'Content for 456'}]
84
+ ```
85
+ '''
86
+ if create_new_list:
87
+ expanded_response = [{k:v for k, v in resp.items()} for resp in ranked_results]
88
+ if content_cache is not None:
89
+ for resp in expanded_response:
90
+ if resp[content_key] in content_cache:
91
+ resp['content'] = content_cache[resp[content_key]]
92
+ return expanded_response
93
+ else:
94
+ for resp in ranked_results:
95
+ if content_cache and resp[content_key] in content_cache:
96
+ resp['content'] = content_cache[resp[content_key]]
97
+ return ranked_results
98
+
99
+ def validate_token_threshold(ranked_results: List[dict],
100
+ base_prompt: str,
101
+ query: str,
102
+ tokenizer: tiktoken.Encoding,
103
+ token_threshold: int,
104
+ verbose: bool = False
105
+ ) -> List[dict]:
106
+ """
107
+ Validates that prompt is below the set token threshold by adding lengths of:
108
+ 1. Base prompt
109
+ 2. User query
110
+ 3. Context material
111
+ If threshold is exceeded, context results are reduced incrementally until the
112
+ combined prompt tokens are below the threshold. This function does not take into
113
+ account every token passed to the LLM, but it is a good approximation.
114
+ """
115
+ overhead_len = len(tokenizer.encode(base_prompt.format(question=query, series='')))
116
+ context_len = _get_batch_length(ranked_results, tokenizer)
117
+
118
+ token_count = overhead_len + context_len
119
+ if token_count > token_threshold:
120
+ print('Token count exceeds token count threshold, reducing size of returned results below token threshold')
121
+
122
+ while token_count > token_threshold and len(ranked_results) > 1:
123
+ num_results = len(ranked_results)
124
+
125
+ # remove the last ranked (most irrelevant) result
126
+ ranked_results = ranked_results[:num_results-1]
127
+ # recalculate new token_count
128
+ token_count = overhead_len + _get_batch_length(ranked_results, tokenizer)
129
+
130
+ if verbose:
131
+ logger.info(f'Total Final Token Count: {token_count}')
132
+ return ranked_results
133
+
134
+ def _get_batch_length(ranked_results: List[dict], tokenizer: tiktoken.Encoding) -> int:
135
+ '''
136
+ Convenience function to get the length in tokens of a batch of results
137
+ '''
138
+ contexts = tokenizer.encode_batch([r['content'] for r in ranked_results])
139
+ context_len = sum(list(map(len, contexts)))
140
+ return context_len
141
+
142
+ def search_result(i: int,
143
+ url: str,
144
+ title: str,
145
+ content: str,
146
+ guest: str,
147
+ length: str,
148
+ space: str='&nbsp; &nbsp;'
149
+ ) -> str:
150
+
151
+ '''
152
+ HTML to display search results.
153
+
154
+ Args:
155
+ -----
156
+ i: int
157
+ index of search result
158
+ url: str
159
+ url of YouTube video
160
+ title: str
161
+ title of episode
162
+ content: str
163
+ content chunk of episode
164
+ '''
165
+ return f"""
166
+ <div style="font-size:120%;">
167
+ {i + 1}.<a href="{url}">{title}</a>
168
+ </div>
169
+
170
+ <div style="font-size:95%;">
171
+ <p>Episode Length: {length} {space}{space} Guest: {guest}</p>
172
+ <div style="color:grey;float:left;">
173
+ ...
174
+ </div>
175
+ {content}
176
+ </div>
177
+ """
utils/openai_interface_spa.py ADDED
@@ -0,0 +1,95 @@
1
+ import os
2
+ from openai import OpenAI
3
+ from typing import List, Any, Tuple
4
+ from dotenv import load_dotenv
5
+ from tqdm import tqdm
6
+ from concurrent.futures import ThreadPoolExecutor, as_completed
7
+ _ = load_dotenv('./.env', override=True) # read local .env file
8
+
9
+
10
+ class GPT_Turbo:
11
+
12
+ def __init__(self, model: str="gpt-3.5-turbo-0613", api_key: str=os.environ['OPENAI_API_KEY']):
13
+ self.model = model
14
+ self.client = OpenAI(api_key=api_key)
15
+
16
+ def get_chat_completion(self,
17
+ prompt: str,
18
+ system_message: str='You are a helpful assistant.',
19
+ temperature: int=0,
20
+ max_tokens: int=500,
21
+ stream: bool=False,
22
+ show_response: bool=False
23
+ ) -> str:
24
+ messages = [
25
+ {'role': 'system', 'content': system_message},
26
+ {'role': 'user', 'content': prompt}
27
+ ]
28
+
29
+ response = self.client.chat.completions.create( model=self.model,
30
+ messages=messages,
31
+ temperature=temperature,
32
+ max_tokens=max_tokens,
33
+ stream=stream)
34
+ if show_response:
35
+ return response
36
+ return response.choices[0].message.content
37
+
38
+ def multi_thread_request(self,
39
+ filepath: str,
40
+ prompt: str,
41
+ content: List[str],
42
+ temperature: int=0
43
+ ) -> List[Any]:
44
+
45
+ data = []
46
+ with ThreadPoolExecutor(max_workers=2*os.cpu_count()) as exec:
47
+ futures = [exec.submit(self.get_chat_completion, prompt=f'{prompt} ```{c}```', temperature=temperature, max_tokens=500) for c in content]
48
+ with open(filepath, 'a') as f:
49
+ for future in as_completed(futures):
50
+ result = future.result()
51
+ if len(data) % 10 == 0:
52
+ print(f'{len(data)} of {len(content)} completed.')
53
+ if result:
54
+ data.append(result)
55
+ self.write_to_file(file_handle=f, data=result)
56
+ return [res for res in data if res]
57
+
58
+ def generate_question_context_pairs(self,
59
+ context_tuple: Tuple[str, str],
60
+ num_questions_per_chunk: int=2,
61
+ max_words_per_question: int=10
62
+ ) -> List[str]:
63
+
64
+ doc_id, context = context_tuple
65
+ prompt = f'Context information is included below enclosed in triple backticks. Given the context information and not prior knowledge, generate questions based on the below query.\n\nYou are an end user querying for information about your favorite podcast. \
66
+ Your task is to setup {num_questions_per_chunk} questions that can be answered using only the given context. The questions should be diverse in nature across the document and be no longer than {max_words_per_question} words. \
67
+ Restrict the questions to the context information provided.\n\
68
+ ```{context}```\n\n'
69
+
70
+ response = self.get_chat_completion(prompt=prompt, temperature=0, max_tokens=500, show_response=True)
+ questions = response.choices[0].message.content
72
+ return (doc_id, questions)
73
+
74
+ def batch_generate_question_context_pairs(self,
75
+ context_tuple_list: List[Tuple[str, str]],
76
+ num_questions_per_chunk: int=2,
77
+ max_words_per_question: int=10
78
+ ) -> List[Tuple[str, str]]:
79
+ data = []
80
+ progress = tqdm(unit="Generated Questions", total=len(context_tuple_list))
81
+ with ThreadPoolExecutor(max_workers=2*os.cpu_count()) as exec:
82
+ futures = [exec.submit(self.generate_question_context_pairs, context_tuple, num_questions_per_chunk, max_words_per_question) for context_tuple in context_tuple_list]
83
+ for future in as_completed(futures):
84
+ result = future.result()
85
+ if result:
86
+ data.append(result)
87
+ progress.update(1)
88
+ return data
89
+
90
+ def get_embedding(self):
91
+ pass
92
+
93
+ def write_to_file(self, file_handle, data: str) -> None:
94
+ file_handle.write(data)
95
+ file_handle.write('\n')
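A minimal, non-streaming usage sketch for the GPT_Turbo wrapper above (assumes OPENAI_API_KEY is available; the prompt text is illustrative):

# Sketch: one-shot completion with the wrapper above
llm = GPT_Turbo(model='gpt-3.5-turbo')
answer = llm.get_chat_completion(prompt='Resume en una frase quΓ© son los impuestos especiales.', temperature=0)
print(answer)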
utils/preprocessing.py ADDED
@@ -0,0 +1,123 @@
1
+ import os
2
+ import json
3
+ import pandas as pd
4
+ from typing import List, Union, Dict
5
+ from loguru import logger
7
+ import pathlib
8
+
9
+
10
+ ## Set of helper functions that support data preprocessing
11
+ class FileIO:
12
+ '''
13
+ Convenience class for saving and loading data in parquet and
14
+ json formats to/from disk.
15
+ '''
16
+
17
+ def save_as_parquet(self,
18
+ file_path: str,
19
+ data: Union[List[dict], pd.DataFrame],
20
+ overwrite: bool=False) -> None:
21
+ '''
22
+ Saves DataFrame to disk as a parquet file. Removes the index.
23
+
24
+ Args:
25
+ -----
26
+ file_path : str
27
+ Output path to save file, if not included "parquet" will be appended
28
+ as file extension.
29
+ data : Union[List[dict], pd.DataFrame]
30
+ Data to save as parquet file. If data is a list of dicts, it will be
31
+ converted to a DataFrame before saving.
32
+ overwrite : bool
33
+ Overwrite existing file if True, otherwise raise FileExistsError.
34
+ '''
35
+ if isinstance(data, list):
36
+ data = self._convert_toDataFrame(data)
37
+ if not file_path.endswith('parquet'):
38
+ file_path = self._rename_file_extension(file_path, 'parquet')
39
+ self._check_file_path(file_path, overwrite=overwrite)
40
+ data.to_parquet(file_path, index=False)
41
+ logger.info(f'DataFrame saved as parquet file here: {file_path}')
42
+
43
+ def _convert_toDataFrame(self, data: List[dict]) -> pd.DataFrame:
44
+ return pd.DataFrame().from_dict(data)
45
+
46
+ def _rename_file_extension(self, file_path: str, extension: str):
47
+ '''
48
+ Renames file with appropriate extension if file_path
49
+ does not already have correct extension.
50
+ '''
51
+ prefix = os.path.splitext(file_path)[0]
52
+ file_path = prefix + '.' + extension
53
+ return file_path
54
+
55
+ def _check_file_path(self, file_path: str, overwrite: bool) -> None:
56
+ '''
57
+ Checks for existence of file and overwrite permissions.
58
+ '''
59
+ if os.path.exists(file_path) and overwrite == False:
60
+ raise FileExistsError(f'File by name {file_path} already exists, try using another file name or set overwrite to True.')
61
+ elif os.path.exists(file_path):
62
+ os.remove(file_path)
63
+ else:
64
+ file_name = os.path.basename(file_path)
65
+ dir_structure = file_path.replace(file_name, '')
66
+ pathlib.Path(dir_structure).mkdir(parents=True, exist_ok=True)
67
+
68
+ def load_parquet(self, file_path: str, verbose: bool=True) -> List[dict]:
69
+ '''
70
+ Loads parquet from disk, converts to pandas DataFrame as intermediate
71
+ step and outputs a list of dicts (docs).
72
+ '''
73
+ df = pd.read_parquet(file_path)
74
+ vector_labels = ['content_vector', 'image_vector', 'content_embedding']
75
+ for label in vector_labels:
76
+ if label in df.columns:
77
+ df[label] = df[label].apply(lambda x: x.tolist())
78
+ if verbose:
79
+ memory_usage = round(df.memory_usage().sum()/(1024*1024),2)
80
+ print(f'Shape of data: {df.values.shape}')
81
+ print(f'Memory Usage: {memory_usage}+ MB')
82
+ list_of_dicts = df.to_dict('records')
83
+ return list_of_dicts
84
+
85
+ def load_json(self, file_path: str):
86
+ '''
87
+ Loads json file from disk.
88
+ '''
89
+ with open(file_path) as f:
90
+ data = json.load(f)
91
+ return data
92
+
93
+ def save_as_json(self,
94
+ file_path: str,
95
+ data: Union[List[dict], dict],
96
+ indent: int=4,
97
+ overwrite: bool=False
98
+ ) -> None:
99
+ '''
100
+ Saves data to disk as a json file. Data can be a list of dicts or a single dict.
101
+ '''
102
+ if not file_path.endswith('json'):
103
+ file_path = self._rename_file_extension(file_path, 'json')
104
+ self._check_file_path(file_path, overwrite=overwrite)
105
+ with open(file_path, 'w') as f:
106
+ json.dump(data, f, indent=indent)
107
+ logger.info(f'Data saved as json file here: {file_path}')
108
+
109
+ class Utilities:
110
+
111
+ def create_video_url(self, video_id: str, playlist_id: str):
112
+ '''
113
+ Creates a hyperlink to a video episode given a video_id and playlist_id.
114
+
115
+ Args:
116
+ -----
117
+ video_id : str
118
+ Video id of the episode from YouTube
119
+ playlist_id : str
120
+ Playlist id of the episode from YouTube
121
+ '''
122
+ return f'https://www.youtube.com/watch?v={video_id}&list={playlist_id}'
123
+
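A quick usage sketch for the FileIO helper above; the paths and documents are placeholders rather than values from this commit, and imports are shown as if run alongside the utils modules:

from preprocessing import FileIO

io = FileIO()
docs = [{'doc_id': '1', 'content': 'texto del primer fragmento'},
        {'doc_id': '2', 'content': 'texto del segundo fragmento'}]

# a list of dicts is converted to a DataFrame and written without the index
io.save_as_parquet(file_path='./data/docs.parquet', data=docs, overwrite=True)

# round-trips back to a list of dicts; vector columns, if present, are cast to lists
reloaded = io.load_parquet(file_path='./data/docs.parquet')

# the json helpers work the same way for plain dict/list payloads
io.save_as_json(file_path='./data/docs.json', data=docs, overwrite=True)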
utils/prompt_templates_spa.py ADDED
@@ -0,0 +1,26 @@
1
+ question_answering_prompt_series_spa = '''
2
+ Su tarea es sintetizar y razonar sobre una serie de contenidos proporcionados.
3
+ DespuΓ©s de su sΓ­ntesis, utilice estos contenidos para responder a la pregunta a continuaciΓ³n. La serie estarΓ‘ en el siguiente formato:\n
4
+ ```
5
+ RESUMEN: <summary>
6
+ DOCUMENTO: <document>
7
+ CONTENIDO: <transcript>
8
+ ```\n\n
9
+ Inicio de la Serie:
10
+ ```
11
+ {series}
12
+ ```
13
+ Pregunta:\n
14
+ {question}\n
15
+ Responda a la pregunta y proporcione razonamientos si es necesario para explicar la respuesta.
16
+ Si el contexto no proporciona suficiente informaciΓ³n para responder a la pregunta, entonces
17
+ indique que no puede responder a la pregunta con el contexto proporcionado.
18
+
19
+ Respuesta:
20
+ '''
21
+
22
+ context_block_spa = '''
23
+ RESUMEN: {summary}
24
+ DOCUMENTO: {document}
25
+ CONTENIDO: {transcript}
26
+ '''
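A sketch of how these templates are intended to be combined: each retrieved hit is rendered with context_block_spa and the concatenated blocks fill the {series} slot of the question-answering prompt. The hit field names below are assumptions based on the display properties used elsewhere in this commit:

from prompt_templates_spa import question_answering_prompt_series_spa, context_block_spa

hits = [{'page_summary': 'Resumen de la pΓ‘gina...',
         'document_title': 'Ley de Impuestos Especiales',
         'content': 'Fragmento recuperado...'}]

# one context block per hit, joined into a single series
series = ''.join(context_block_spa.format(summary=hit['page_summary'],
                                          document=hit['document_title'],
                                          transcript=hit['content'])
                 for hit in hits)

prompt = question_answering_prompt_series_spa.format(
    series=series,
    question='ΒΏQuΓ© productos estΓ‘n sujetos al impuesto sobre la cerveza?')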
utils/reranker_spa.py ADDED
@@ -0,0 +1,89 @@
1
+ from sentence_transformers import CrossEncoder
2
+ from torch.nn import Sigmoid
3
+ from typing import List, Union
4
+ import numpy as np
5
+ from loguru import logger
6
+
7
+ class ReRanker(CrossEncoder):
8
+ '''
9
+ Cross-Encoder models achieve higher performance than Bi-Encoders,
10
+ however, they do not scale well to large datasets. The lack of scalability
11
+ is due to the underlying cross-attention mechanism, which is computationally
12
+ expensive. Thus a Bi-Encoder is best used for 1st-stage document retrieval and
13
+ a Cross-Encoder is used to re-rank the retrieved documents.
14
+
15
+ https://www.sbert.net/examples/applications/cross-encoder/README.html
16
+ '''
17
+
18
+ def __init__(self,
19
+ model_name: str='cross-encoder/ms-marco-MiniLM-L-6-v2',
20
+ **kwargs
21
+ ):
22
+ super().__init__(model_name=model_name,
23
+ **kwargs)
24
+ self.model_name = model_name
25
+ self.score_field = 'cross_score'
26
+ self.activation_fct = Sigmoid()
27
+
28
+ def _cross_encoder_score(self,
29
+ results: List[dict],
30
+ query: str,
31
+ hit_field: str='content',
32
+ apply_sigmoid: bool=True,
33
+ return_scores: bool=False
34
+ ) -> Union[np.array, None]:
35
+ '''
36
+ Given a list of hits from a Retriever:
37
+ 1. Scores hits by passing query and results through CrossEncoder model.
38
+ 2. Adds cross-score key to results dictionary.
39
+ 3. If desired returns np.array of Cross Encoder scores.
40
+ '''
41
+ activation_fct = self.activation_fct if apply_sigmoid else None
42
+ #build query/content list
43
+ cross_inp = [[query, hit[hit_field]] for hit in results]
44
+ #get scores
45
+ cross_scores = self.predict(cross_inp, activation_fct=activation_fct)
46
+ for i, result in enumerate(results):
47
+ result[self.score_field]=cross_scores[i]
48
+
49
+ if return_scores:return cross_scores
50
+
51
+ def rerank(self,
52
+ results: List[dict],
53
+ query: str,
54
+ top_k: int=10,
55
+ apply_sigmoid: bool=True,
56
+ threshold: float=None
57
+ ) -> List[dict]:
58
+ '''
59
+ Given a list of hits from a Retriever:
60
+ 1. Scores hits by passing query and results through CrossEncoder model.
61
+ 2. Adds cross_score key to results dictionary.
62
+ 3. Returns reranked results limited by either a threshold value or top_k.
63
+
64
+ Args:
65
+ -----
66
+ results : List[dict]
67
+ List of results from the Weaviate client
68
+ query : str
69
+ User query
70
+ top_k : int=10
71
+ Number of results to return
72
+ apply_sigmoid : bool=True
73
+ Whether to apply sigmoid activation to cross-encoder scores. If False,
74
+ returns raw cross-encoder scores (logits).
75
+ threshold : float=None
76
+ Minimum cross-encoder score to return. If no hits are above threshold,
77
+ returns top_k hits.
78
+ '''
79
+ # Sort results by the cross-encoder scores
80
+ self._cross_encoder_score(results=results, query=query, apply_sigmoid=apply_sigmoid)
81
+
82
+ sorted_hits = sorted(results, key=lambda x: x[self.score_field], reverse=True)
83
+ if threshold or threshold == 0:
84
+ filtered_hits = [hit for hit in sorted_hits if hit[self.score_field] >= threshold]
85
+ if not any(filtered_hits):
86
+ logger.warning(f'No hits above threshold {threshold}. Returning top {top_k} hits.')
87
+ return sorted_hits[:top_k]
88
+ return filtered_hits
89
+ return sorted_hits[:top_k]
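As the class docstring notes, the cross-encoder is meant to re-score a small first-stage result set rather than search the whole corpus. A minimal sketch, with a placeholder query and placeholder hits:

from reranker_spa import ReRanker

reranker = ReRanker()  # defaults to cross-encoder/ms-marco-MiniLM-L-6-v2

query = 'tipos impositivos de la cerveza'
hits = [{'doc_id': 'a', 'content': 'La cerveza tributa segΓΊn su grado alcohΓ³lico...'},
        {'doc_id': 'b', 'content': 'El impuesto sobre hidrocarburos grava otros productos...'}]

# adds a cross_score key to every hit and returns the best-scoring ones,
# optionally filtered by a minimum score threshold
reranked = reranker.rerank(results=hits, query=query, top_k=2, threshold=0.35)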
utils/retrieval_evaluation_spa.py ADDED
@@ -0,0 +1,332 @@
1
+ #external files
2
+ from openai_interface_spa import GPT_Turbo
3
+ from weaviate_interface_v3_spa import WeaviateClient
4
+ from llama_index.finetuning import EmbeddingQAFinetuneDataset
5
+ from templates.prompt_templates_spa import qa_generation_prompt
6
+ from reranker_spa import ReRanker
7
+
8
+ #standard library imports
9
+ import json
10
+ import time
11
+ import uuid
12
+ import os
13
+ import re
14
+ import random
15
+ from datetime import datetime
16
+ from typing import List, Dict, Tuple, Union, Literal
17
+
18
+ #misc
19
+ from tqdm import tqdm
20
+
21
+
22
+ class QueryContextGenerator:
23
+ '''
24
+ Class designed for the generation of query/context pairs using a
25
+ Generative LLM. The LLM is used to generate questions from a given
26
+ corpus of text. The query/context pairs can be used to fine-tune
27
+ an embedding model using a MultipleNegativesRankingLoss loss function
28
+ or can be used to create evaluation datasets for retrieval models.
29
+ '''
30
+ def __init__(self, openai_key: str, model_id: str='gpt-3.5-turbo-0613'):
31
+ self.llm = GPT_Turbo(model=model_id, api_key=openai_key)
32
+
33
+ def clean_validate_data(self,
34
+ data: List[dict],
35
+ valid_fields: List[str]=['content', 'summary', 'guest', 'doc_id'],
36
+ total_chars: int=950
37
+ ) -> List[dict]:
38
+ '''
39
+ Strip original data chunks so they only contain valid_fields.
40
+ Remove any chunks less than total_chars in size. Prevents LLM
41
+ from asking questions from sparse content.
42
+ '''
43
+ clean_docs = [{k:v for k,v in d.items() if k in valid_fields} for d in data]
44
+ valid_docs = [d for d in clean_docs if len(d['content']) > total_chars]
45
+ return valid_docs
46
+
47
+ def train_val_split(self,
48
+ data: List[dict],
49
+ n_train_questions: int,
50
+ n_val_questions: int,
51
+ n_questions_per_chunk: int=2,
52
+ total_chars: int=950):
53
+ '''
54
+ Splits corpus into training and validation sets. Training and
55
+ validation samples are randomly selected from the corpus. total_chars
56
+ parameter is set based on pre-analysis of average doc length in the
57
+ training corpus.
58
+ '''
59
+ clean_data = self.clean_validate_data(data, total_chars=total_chars)
60
+ random.shuffle(clean_data)
61
+ train_index = n_train_questions//n_questions_per_chunk
62
+ valid_index = n_val_questions//n_questions_per_chunk
63
+ end_index = valid_index + train_index
64
+ if end_index > len(clean_data):
65
+ raise ValueError('Cannot create dataset with desired number of questions, try using a larger dataset')
66
+ train_data = clean_data[:train_index]
67
+ valid_data = clean_data[train_index:end_index]
68
+ print(f'Length Training Data: {len(train_data)}')
69
+ print(f'Length Validation Data: {len(valid_data)}')
70
+ return train_data, valid_data
71
+
72
+ def generate_qa_embedding_pairs(
73
+ self,
74
+ data: List[dict],
75
+ generate_prompt_tmpl: str=None,
76
+ num_questions_per_chunk: int = 2,
77
+ ) -> EmbeddingQAFinetuneDataset:
78
+ """
79
+ Generate query/context pairs from a list of documents. The query/context pairs
80
+ can be used for fine-tuning an embedding model using a MultipleNegativesRankingLoss
81
+ or can be used to create an evaluation dataset for retrieval models.
82
+
83
+ This function was adapted for this course from the llama_index.finetuning.common module:
84
+ https://github.com/run-llama/llama_index/blob/main/llama_index/finetuning/embeddings/common.py
85
+ """
86
+ generate_prompt_tmpl = qa_generation_prompt if not generate_prompt_tmpl else generate_prompt_tmpl
87
+ queries = {}
88
+ relevant_docs = {}
89
+ corpus = {chunk['doc_id'] : chunk['content'] for chunk in data}
90
+ for chunk in tqdm(data):
91
+ page_summary = chunk['page_summary']
92
+ # guest = chunk['guest']
93
+ context_str = chunk['content']
94
+ node_id = chunk['doc_id']
95
+ query = generate_prompt_tmpl.format(page_summary=page_summary,
96
+ # guest=guest,
97
+ context_str=context_str,
98
+ num_questions_per_chunk=num_questions_per_chunk)
99
+ try:
100
+ response = self.llm.get_chat_completion(prompt=query, temperature=0.1, max_tokens=100)
101
+ except Exception as e:
102
+ print(e)
103
+ continue
104
+ result = str(response).strip().split("\n")
105
+ questions = [
106
+ re.sub(r"^\d+[\).\s]", "", question).strip() for question in result
107
+ ]
108
+ questions = [question for question in questions if len(question) > 0]
109
+
110
+ for question in questions:
111
+ question_id = str(uuid.uuid4())
112
+ queries[question_id] = question
113
+ relevant_docs[question_id] = [node_id]
114
+
115
+ # construct dataset
116
+ return EmbeddingQAFinetuneDataset(
117
+ queries=queries, corpus=corpus, relevant_docs=relevant_docs
118
+ )
119
+
120
+ def execute_evaluation(dataset: EmbeddingQAFinetuneDataset,
121
+ class_name: str,
122
+ retriever: WeaviateClient,
123
+ reranker: ReRanker=None,
124
+ alpha: float=0.5,
125
+ retrieve_limit: int=100,
126
+ top_k: int=5,
127
+ chunk_size: int=256,
128
+ hnsw_config_keys: List[str]=['maxConnections', 'efConstruction', 'ef'],
129
+ search_type: Literal['kw', 'vector', 'hybrid', 'all']='all',
130
+ display_properties: List[str]=['doc_id', 'content'],
131
+ dir_outpath: str='./eval_results',
132
+ include_miss_info: bool=False,
133
+ user_def_params: dict=None
134
+ ) -> Union[dict, Tuple[dict, List[dict]]]:
135
+ '''
136
+ Given a dataset, a retriever, and a reranker, evaluate the performance of the retriever and reranker.
137
+ Returns a dict of kw, vector, and hybrid hit rates and mrr scores. If include_miss_info is True, will
138
+ also return a list of kw and vector responses and their associated queries that did not return a hit.
139
+
140
+ Args:
141
+ -----
142
+ dataset: EmbeddingQAFinetuneDataset
143
+ Dataset to be used for evaluation
144
+ class_name: str
145
+ Name of Class on Weaviate host to be used for retrieval
146
+ retriever: WeaviateClient
147
+ WeaviateClient object to be used for retrieval
148
+ reranker: ReRanker
149
+ ReRanker model to be used for results reranking
150
+ alpha: float=0.5
151
+ Weighting factor for BM25 and Vector search.
152
+ alpha can be any number from 0 to 1, defaulting to 0.5:
153
+ alpha = 0 executes a pure keyword search method (BM25)
154
+ alpha = 0.5 weighs the BM25 and vector methods evenly
155
+ alpha = 1 executes a pure vector search method
156
+ retrieve_limit: int=100
157
+ Number of documents to retrieve from Weaviate host
158
+ top_k: int=5
159
+ Number of top results to evaluate
160
+ chunk_size: int=256
161
+ Number of tokens used to chunk text
162
+ hnsw_config_keys: List[str]=['maxConnections', 'efConstruction', 'ef']
163
+ List of keys to be used for retrieving HNSW Index parameters from Weaviate host
164
+ search_type: Literal['kw', 'vector', 'hybrid', 'all']='all'
165
+ Type of search to be evaluated. Options are 'kw', 'vector', 'hybrid', or 'all'
166
+ display_properties: List[str]=['doc_id', 'content']
167
+ List of properties to be returned from Weaviate host for display in response
168
+ dir_outpath: str='./eval_results'
169
+ Directory path for saving results. Directory will be created if it does not
170
+ already exist.
171
+ include_miss_info: bool=False
172
+ Option to include queries and their associated search response values
173
+ for queries that are "total misses"
174
+ user_def_params : dict=None
175
+ Option for user to pass in a dictionary of user-defined parameters and their values.
176
+ Will be automatically added to the results_dict if correct type is passed.
177
+ '''
178
+
179
+ reranker_name = reranker.model_name if reranker else "None"
180
+
181
+ results_dict = {'n':retrieve_limit,
182
+ 'top_k': top_k,
183
+ 'alpha': alpha,
184
+ 'Retriever': retriever.model_name_or_path,
185
+ 'Ranker': reranker_name,
186
+ 'chunk_size': chunk_size,
187
+ 'kw_hit_rate': 0,
188
+ 'kw_mrr': 0,
189
+ 'vector_hit_rate': 0,
190
+ 'vector_mrr': 0,
191
+ 'hybrid_hit_rate':0,
192
+ 'hybrid_mrr': 0,
193
+ 'total_misses': 0,
194
+ 'total_questions':0
195
+ }
196
+ #add extra params to results_dict
197
+ results_dict = add_params(retriever, class_name, results_dict, user_def_params, hnsw_config_keys)
198
+
199
+ start = time.perf_counter()
200
+ miss_info = []
201
+ for query_id, q in tqdm(dataset.queries.items(), 'Queries'):
202
+ results_dict['total_questions'] += 1
203
+ hit = False
204
+ #make Keyword, Vector, and Hybrid calls to Weaviate host
205
+ try:
206
+ kw_response = retriever.keyword_search(request=q, class_name=class_name, limit=retrieve_limit, display_properties=display_properties)
207
+ vector_response = retriever.vector_search(request=q, class_name=class_name, limit=retrieve_limit, display_properties=display_properties)
208
+ hybrid_response = retriever.hybrid_search(request=q, class_name=class_name, alpha=alpha, limit=retrieve_limit, display_properties=display_properties)
209
+ #rerank returned responses if reranker is provided
210
+ if reranker:
211
+ kw_response = reranker.rerank(kw_response, q, top_k=top_k)
212
+ vector_response = reranker.rerank(vector_response, q, top_k=top_k)
213
+ hybrid_response = reranker.rerank(hybrid_response, q, top_k=top_k)
214
+
215
+ #collect doc_ids to check for document matches (include only results_top_k)
216
+ kw_doc_ids = {result['doc_id']:i for i, result in enumerate(kw_response[:top_k], 1)}
217
+ vector_doc_ids = {result['doc_id']:i for i, result in enumerate(vector_response[:top_k], 1)}
218
+ hybrid_doc_ids = {result['doc_id']:i for i, result in enumerate(hybrid_response[:top_k], 1)}
219
+
220
+ #extract doc_id for scoring purposes
221
+ doc_id = dataset.relevant_docs[query_id][0]
222
+
223
+ #increment hit_rate counters and mrr scores
224
+ if doc_id in kw_doc_ids:
225
+ results_dict['kw_hit_rate'] += 1
226
+ results_dict['kw_mrr'] += 1/kw_doc_ids[doc_id]
227
+ hit = True
228
+ if doc_id in vector_doc_ids:
229
+ results_dict['vector_hit_rate'] += 1
230
+ results_dict['vector_mrr'] += 1/vector_doc_ids[doc_id]
231
+ hit = True
232
+ if doc_id in hybrid_doc_ids:
233
+ results_dict['hybrid_hit_rate'] += 1
234
+ results_dict['hybrid_mrr'] += 1/hybrid_doc_ids[doc_id]
235
+ hit = True
236
+ # if no hits, let's capture that
237
+ if not hit:
238
+ results_dict['total_misses'] += 1
239
+ miss_info.append({'query': q,
240
+ 'answer': dataset.corpus[doc_id],
241
+ 'doc_id': doc_id,
242
+ 'kw_response': kw_response,
243
+ 'vector_response': vector_response,
244
+ 'hybrid_response': hybrid_response})
245
+ except Exception as e:
246
+ print(e)
247
+ continue
248
+
249
+ #use raw counts to calculate final scores
250
+ calc_hit_rate_scores(results_dict, search_type=search_type)
251
+ calc_mrr_scores(results_dict, search_type=search_type)
252
+
253
+ end = time.perf_counter() - start
254
+ print(f'Total Processing Time: {round(end/60, 2)} minutes')
255
+ record_results(results_dict, chunk_size, dir_outpath=dir_outpath, as_text=True)
256
+
257
+ if include_miss_info:
258
+ return results_dict, miss_info
259
+ return results_dict
260
+
261
+ def calc_hit_rate_scores(results_dict: Dict[str, Union[str, int]],
262
+ search_type: Literal['kw', 'vector', 'hybrid', 'all']='all'
263
+ ) -> None:
264
+ if search_type == 'all':
265
+ search_type = ['kw', 'vector', 'hybrid']
266
+ for prefix in search_type:
267
+ results_dict[f'{prefix}_hit_rate'] = round(results_dict[f'{prefix}_hit_rate']/results_dict['total_questions'],2)
268
+
269
+ def calc_mrr_scores(results_dict: Dict[str, Union[str, int]],
270
+ search_type: Literal['kw', 'vector', 'hybrid', 'all']='all'
271
+ ) -> None:
272
+ if search_type == 'all':
273
+ search_type = ['kw', 'vector', 'hybrid']
274
+ for prefix in search_type:
275
+ results_dict[f'{prefix}_mrr'] = round(results_dict[f'{prefix}_mrr']/results_dict['total_questions'],2)
276
+
277
+ def create_dir(dir_path: str) -> None:
278
+ '''
279
+ Checks if directory exists, and creates new directory
280
+ if it does not exist
281
+ '''
282
+ if not os.path.exists(dir_path):
283
+ os.makedirs(dir_path)
284
+
285
+ def record_results(results_dict: Dict[str, Union[str, int]],
286
+ chunk_size: int,
287
+ dir_outpath: str='./eval_results',
288
+ as_text: bool=False
289
+ ) -> None:
290
+ '''
291
+ Write results to output file in either txt or json format
292
+
293
+ Args:
294
+ -----
295
+ results_dict: Dict[str, Union[str, int]]
296
+ Dictionary containing results of evaluation
297
+ chunk_size: int
298
+ Size of text chunks in tokens
299
+ dir_outpath: str
300
+ Path to output directory. Directory only, filename is hardcoded
301
+ as part of this function.
302
+ as_text: bool
303
+ If True, write results as text file. If False, write as json file.
304
+ '''
305
+ create_dir(dir_outpath)
306
+ time_marker = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
307
+ ext = 'txt' if as_text else 'json'
308
+ path = os.path.join(dir_outpath, f'retrieval_eval_{chunk_size}_{time_marker}.{ext}')
309
+ if as_text:
310
+ with open(path, 'a') as f:
311
+ f.write(f"{results_dict}\n")
312
+ else:
313
+ with open(path, 'w') as f:
314
+ json.dump(results_dict, f, indent=4)
315
+
316
+ def add_params(client: WeaviateClient,
317
+ class_name: str,
318
+ results_dict: dict,
319
+ param_options: dict,
320
+ hnsw_config_keys: List[str]
321
+ ) -> dict:
322
+ hnsw_params = {k:v for k,v in client.show_class_config(class_name)['vectorIndexConfig'].items() if k in hnsw_config_keys}
323
+ if hnsw_params:
324
+ results_dict = {**results_dict, **hnsw_params}
325
+ if param_options and isinstance(param_options, dict):
326
+ results_dict = {**results_dict, **param_options}
327
+ return results_dict
328
+
329
+
330
+
331
+
332
+
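A sketch of an end-to-end evaluation run with the pieces above; the endpoint, class name and dataset path are placeholders, it assumes a golden dataset previously saved with EmbeddingQAFinetuneDataset's save_json, and imports are shown as if run alongside the utils modules:

import os
from llama_index.finetuning import EmbeddingQAFinetuneDataset
from weaviate_interface_v3_spa import WeaviateClient
from reranker_spa import ReRanker
from retrieval_evaluation_spa import execute_evaluation

dataset = EmbeddingQAFinetuneDataset.from_json('./data/golden_dataset.json')

client = WeaviateClient(endpoint='https://my-cluster.weaviate.network',
                        api_key=os.environ.get('WEAVIATE_API_KEY'))
reranker = ReRanker()

# returns hit rate and MRR for keyword, vector and hybrid search over the dataset
results = execute_evaluation(dataset=dataset,
                             class_name='ImpuestosEspeciales_256',
                             retriever=client,
                             reranker=reranker,
                             alpha=0.5,
                             retrieve_limit=100,
                             top_k=5,
                             chunk_size=256)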
utils/system_prompts.py ADDED
@@ -0,0 +1,72 @@
1
+ chatgpt = '''
2
+ You are a helpful assistant.
3
+ '''
4
+
5
+ professor_synapse = '''
6
+ Act as Professor SynapseπŸ§™πŸΎβ€β™‚οΈ, a conductor of expert agents. Your job is to support me in accomplishing my goals by finding alignment with me, then calling upon an expert agent perfectly suited to the task by initializing:
7
+
8
+ Synapse_CoR = "[emoji]: I am an expert in [role&domain]. I know [context]. I will reason step-by-step to determine the best course of action to achieve [goal]. I can use [tools] and [relevant frameworks] to help in this process.
9
+
10
+ I will help you accomplish your goal by following these steps:
11
+ [reasoned steps]
12
+
13
+ My task ends when [completion].
14
+
15
+ [first step, question]"
16
+
17
+ Instructions:
18
+
19
+ 1. πŸ§™πŸΎβ€β™‚οΈ gather context, relevant information and clarify my goals by asking questions
20
+ 2. Once confirmed, initialize Synapse_CoR
21
+ 3. πŸ§™πŸΎβ€β™‚οΈ and [emoji] support me until goal is complete
22
+
23
+ Commands:
24
+ /start=πŸ§™πŸΎβ€β™‚οΈ,introduce and begin with step one
25
+ /ts=πŸ§™πŸΎβ€β™‚οΈ,summon (Synapse_CoR*3) town square debate
26
+ /saveπŸ§™πŸΎβ€β™‚οΈ, restate goal, summarize progress, reason next step
27
+
28
+ Personality:
29
+ -curious, inquisitive, encouraging
30
+ -use emojis to express yourself
31
+
32
+ Rules:
33
+ -End every output with a question or reasoned next step
34
+ -Start every output with πŸ§™πŸΎβ€β™‚οΈ: or [emoji]: to indicate who is speaking.
35
+ -Organize every output β€œπŸ§™πŸΎβ€β™‚οΈ: [aligning on my goal], [emoji]: [actionable response]”
36
+ -πŸ§™πŸΎβ€β™‚οΈ, recommend save after each task is completed
37
+ '''
38
+
39
+ marketing_jane = '''
40
+ Act as πŸ‘©πŸΌβ€πŸ’Ό Marketing Jane, a strategist adept at melding analytics with creative zest. With mastery over data-driven marketing and an innate knack for storytelling, your mission is to carve out distinctive marketing strategies for businesses of every size, from fledgling startups to seasoned giants.
41
+
42
+ Your strategy formulation entails:
43
+ - Understanding the business's narrative, competitive landscape, and audience psyche.
44
+ - Crafting a data-informed marketing roadmap, encompassing various channels, and innovative tactics.
45
+ - Leveraging storytelling to forge brand engagement and pioneering avant-garde campaigns.
46
+
47
+ Your endeavor culminates when the user possesses a dynamic, data-enriched marketing strategy, resonating with their business ethos.
48
+
49
+ Steps:
50
+ 1. πŸ‘©πŸΌβ€πŸ’Ό, Grasp the business's ethos, objectives, and challenges
51
+ 2. Design a data-backed marketing strategy, resonating with audience sentiments and business goals
52
+ 3. Engage in feedback loops, iteratively refining the strategy
53
+
54
+ Commands:
55
+ /explore - Modify the strategic focus or delve deeper into specific marketing nuances
56
+ /save - Chronicle progress, dissect strategy elements, and chart future endeavors
57
+ /critic - πŸ‘©πŸΌβ€πŸ’Ό seeks insights from fellow marketing aficionados
58
+ /reason - πŸ‘©πŸΌβ€πŸ’Ό and user collaboratively weave the marketing narrative
59
+ /new - Ignite a fresh strategic quest for a new venture or campaign
60
+
61
+ Rules:
62
+ - Culminate with an evocative campaign concept or the next strategic juncture
63
+ - Preface with πŸ‘©πŸΌβ€πŸ’Ό: for clarity
64
+ - Integrate data insights with creative innovation
65
+ '''
66
+
67
+ # Define a dictionary to map the emojis to the variables
68
+ prompt_mapping = {
69
+ "πŸ€–ChatGPT": chatgpt,
70
+ "πŸ§™πŸΎβ€β™‚οΈProfessor Synapse": professor_synapse,
71
+ "πŸ‘©πŸΌβ€πŸ’ΌMarketing Jane": marketing_jane,
72
+ }
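prompt_mapping is a lookup from the UI label to the underlying system prompt. A small sketch of how it might be consumed; the message layout follows the standard OpenAI-style chat format and the user text is a placeholder:

from system_prompts import prompt_mapping

selected = "πŸ§™πŸΎβ€β™‚οΈProfessor Synapse"   # e.g. the value of a Streamlit selectbox
messages = [
    {'role': 'system', 'content': prompt_mapping[selected]},
    {'role': 'user', 'content': 'AyΓΊdame a planificar un anΓ‘lisis de los impuestos especiales.'},
]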
utils/weaviate_interface_v3_spa.py ADDED
@@ -0,0 +1,436 @@
1
+ from weaviate import Client, AuthApiKey
2
+ from dataclasses import dataclass
3
+ from openai import OpenAI
4
+ from sentence_transformers import SentenceTransformer
5
+ from typing import List, Union, Callable
6
+ from torch import cuda
7
+ from tqdm import tqdm
8
+ import time
9
+
10
+ class WeaviateClient(Client):
11
+ '''
12
+ A python native Weaviate Client class that encapsulates Weaviate functionalities
13
+ in one object. Several convenience methods are added for ease of use.
14
+
15
+ Args
16
+ ----
17
+ api_key: str
18
+ The API key for the Weaviate Cloud Service (WCS) instance.
19
+ https://console.weaviate.cloud/dashboard
20
+
21
+ endpoint: str
22
+ The url endpoint for the Weaviate Cloud Service instance.
23
+
24
+ model_name_or_path: str='sentence-transformers/all-MiniLM-L6-v2'
25
+ The name or path of the SentenceTransformer model to use for vector search.
26
+ Will also support OpenAI text-embedding-ada-002 model. This param enables
27
+ the use of most leading models on MTEB Leaderboard:
28
+ https://huggingface.co/spaces/mteb/leaderboard
29
+ openai_api_key: str=None
30
+ The API key for the OpenAI API. Only required if using OpenAI text-embedding-ada-002 model.
31
+ '''
32
+ def __init__(self,
33
+ endpoint: str,
34
+ api_key: str = None, # Make the api_key optional
35
+ model_name_or_path: str = 'sentence-transformers/all-MiniLM-L6-v2',
36
+ openai_api_key: str = None,
37
+ **kwargs
38
+ ):
39
+ if api_key: # Only use AuthApiKey if api_key is provided
40
+ auth_config = AuthApiKey(api_key=api_key)
41
+ super().__init__(auth_client_secret=auth_config, url=endpoint, **kwargs)
42
+ else:
43
+ super().__init__(url=endpoint, **kwargs)
44
+
45
+ self.model_name_or_path = model_name_or_path
46
+ self.openai_model = False
47
+ if self.model_name_or_path == 'text-embedding-ada-002':
48
+ if not openai_api_key:
49
+ raise ValueError(f'OpenAI API key must be provided to use this model: {self.model_name_or_path}')
50
+ self.model = OpenAI(api_key=openai_api_key)
51
+ self.openai_model = True
52
+ else:
53
+ self.model = SentenceTransformer(self.model_name_or_path) if self.model_name_or_path else None
54
+
55
+ self.display_properties = ['file_name', 'page_label', 'document_title', 'page_summary', 'page_url', 'doc_id', \
56
+ 'content']
57
+
58
+ def show_classes(self) -> Union[List[str], str]:
59
+ '''
60
+ Shows all available classes (indexes) on the Weaviate instance.
61
+ '''
62
+ schema = self.schema.get()
63
+ if 'classes' in schema:
64
+ return [cls['class'] for cls in schema['classes']]
65
+ else:
66
+ return "No classes found on cluster."
67
+
68
+ def show_class_info(self) -> Union[List[dict], str]:
69
+ '''
70
+ Shows all information related to the classes (indexes) on the Weaviate instance.
71
+ '''
72
+ schema = self.schema.get()
73
+ if 'classes' in schema:
74
+ return schema['classes']
75
+ else:
76
+ return "No classes found on cluster."
77
+
78
+ def show_class_properties(self, class_name: str) -> Union[dict, str]:
79
+ '''
80
+ Shows all properties of a class (index) on the Weaviate instance.
81
+ '''
82
+ classes = self.schema.get()
83
+ if classes:
84
+ all_classes = classes['classes']
85
+ for d in all_classes:
86
+ if d['class'] == class_name:
87
+ return d['properties']
88
+ return f'Class "{class_name}" not found on host'
89
+ return f'No Classes found on host'
90
+
91
+ def show_class_config(self, class_name: str) -> Union[dict, str]:
92
+ '''
93
+ Shows all configuration of a class (index) on the Weaviate instance.
94
+ '''
95
+ classes = self.schema.get()
96
+ if classes:
97
+ all_classes = classes['classes']
98
+ for d in all_classes:
99
+ if d['class'] == class_name:
100
+ return d
101
+ return f'Class "{class_name}" not found on host'
102
+ return f'No Classes found on host'
103
+
104
+ def delete_class(self, class_name: str) -> str:
105
+ '''
106
+ Deletes a class (index) on the Weaviate instance, if it exists.
107
+ '''
108
+ available = self._check_class_avialability(class_name)
109
+ if isinstance(available, bool):
110
+ if available:
111
+ self.schema.delete_class(class_name)
112
+ not_deleted = self._check_class_avialability(class_name)
113
+ if isinstance(not_deleted, bool):
114
+ if not_deleted:
115
+ return f'Class "{class_name}" was not deleted. Try again.'
116
+ else:
117
+ return f'Class "{class_name}" deleted'
118
+ return f'Class "{class_name}" deleted and there are no longer any classes on host'
119
+ return f'Class "{class_name}" not found on host'
120
+ return available
121
+
122
+ def _check_class_avialability(self, class_name: str) -> Union[bool, str]:
123
+ '''
124
+ Checks if a class (index) exists on the Weaviate instance.
125
+ '''
126
+ classes = self.schema.get()
127
+ if classes:
128
+ all_classes = classes['classes']
129
+ for d in all_classes:
130
+ if d['class'] == class_name:
131
+ return True
132
+ return False
133
+ else:
134
+ return f'No Classes found on host'
135
+
136
+ def format_response(self,
137
+ response: dict,
138
+ class_name: str
139
+ ) -> List[dict]:
140
+ '''
141
+ Formats json response from Weaviate into a list of dictionaries.
142
+ Expands _additional fields if present into top-level dictionary.
143
+ '''
144
+ if response.get('errors'):
145
+ return response['errors'][0]['message']
146
+ results = []
147
+ hits = response['data']['Get'][class_name]
148
+ for d in hits:
149
+ temp = {k:v for k,v in d.items() if k != '_additional'}
150
+ if d.get('_additional'):
151
+ for key in d['_additional']:
152
+ temp[key] = d['_additional'][key]
153
+ results.append(temp)
154
+ return results
155
+
156
+ def update_ef_value(self, class_name: str, ef_value: int) -> str:
157
+ '''
158
+ Updates ef_value for a class (index) on the Weaviate instance.
159
+ '''
160
+ self.schema.update_config(class_name=class_name, config={'vectorIndexConfig': {'ef': ef_value}})
161
+ print(f'ef_value updated to {ef_value} for class {class_name}')
162
+ return self.show_class_config(class_name)['vectorIndexConfig']
163
+
164
+ def keyword_search(self,
165
+ request: str,
166
+ class_name: str,
167
+ properties: List[str]=['content'],
168
+ limit: int=10,
169
+ where_filter: dict=None,
170
+ display_properties: List[str]=None,
171
+ return_raw: bool=False) -> Union[dict, List[dict]]:
172
+ '''
173
+ Executes Keyword (BM25) search.
174
+
175
+ Args
176
+ ----
177
+ query: str
178
+ User query.
179
+ class_name: str
180
+ Class (index) to search.
181
+ properties: List[str]
182
+ List of properties to search across.
183
+ limit: int=10
184
+ Number of results to return.
185
+ display_properties: List[str]=None
186
+ List of properties to return in response.
187
+ If None, returns all properties.
188
+ return_raw: bool=False
189
+ If True, returns raw response from Weaviate.
190
+ '''
191
+ display_properties = display_properties if display_properties else self.display_properties
192
+ response = (self.query
193
+ .get(class_name, display_properties)
194
+ .with_bm25(query=request, properties=properties)
195
+ .with_additional(['score', "id"])
196
+ .with_limit(limit)
197
+ )
198
+ response = response.with_where(where_filter).do() if where_filter else response.do()
199
+ if return_raw:
200
+ return response
201
+ else:
202
+ return self.format_response(response, class_name)
203
+
204
+ def vector_search(self,
205
+ request: str,
206
+ class_name: str,
207
+ limit: int=10,
208
+ where_filter: dict=None,
209
+ display_properties: List[str]=None,
210
+ return_raw: bool=False,
211
+ device: str='cuda:0' if cuda.is_available() else 'cpu'
212
+ ) -> Union[dict, List[dict]]:
213
+ '''
214
+ Executes vector search using embedding model defined on instantiation
215
+ of WeaviateClient instance.
216
+
217
+ Args
218
+ ----
219
+ query: str
220
+ User query.
221
+ class_name: str
222
+ Class (index) to search.
223
+ limit: int=10
224
+ Number of results to return.
225
+ display_properties: List[str]=None
226
+ List of properties to return in response.
227
+ If None, returns all properties.
228
+ return_raw: bool=False
229
+ If True, returns raw response from Weaviate.
230
+ '''
231
+ display_properties = display_properties if display_properties else self.display_properties
232
+ query_vector = self._create_query_vector(request, device=device)
233
+ response = (
234
+ self.query
235
+ .get(class_name, display_properties)
236
+ .with_near_vector({"vector": query_vector})
237
+ .with_limit(limit)
238
+ .with_additional(['distance'])
239
+ )
240
+ response = response.with_where(where_filter).do() if where_filter else response.do()
241
+ if return_raw:
242
+ return response
243
+ else:
244
+ return self.format_response(response, class_name)
245
+
246
+ def _create_query_vector(self, query: str, device: str) -> List[float]:
247
+ '''
248
+ Creates embedding vector from text query.
249
+ '''
250
+ return self.get_openai_embedding(query) if self.openai_model else self.model.encode(query, device=device).tolist()
251
+
252
+ def get_openai_embedding(self, query: str) -> List[float]:
253
+ '''
254
+ Gets embedding from OpenAI API for query.
255
+ '''
256
+ embedding = self.model.embeddings.create(input=query, model='text-embedding-ada-002').model_dump()
257
+ if embedding:
258
+ return embedding['data'][0]['embedding']
259
+ else:
260
+ raise ValueError(f'No embedding found for query: {query}')
261
+
262
+ def hybrid_search(self,
263
+ request: str,
264
+ class_name: str,
265
+ properties: List[str]=['content'],
266
+ alpha: float=0.5,
267
+ limit: int=10,
268
+ where_filter: dict=None,
269
+ display_properties: List[str]=None,
270
+ return_raw: bool=False,
271
+ device: str='cuda:0' if cuda.is_available() else 'cpu'
272
+ ) -> Union[dict, List[dict]]:
273
+ '''
274
+ Executes Hybrid (BM25 + Vector) search.
275
+
276
+ Args
277
+ ----
278
+ query: str
279
+ User query.
280
+ class_name: str
281
+ Class (index) to search.
282
+ properties: List[str]
283
+ List of properties to search across (using BM25)
284
+ alpha: float=0.5
285
+ Weighting factor for BM25 and Vector search.
286
+ alpha can be any number from 0 to 1, defaulting to 0.5:
287
+ alpha = 0 executes a pure keyword search method (BM25)
288
+ alpha = 0.5 weighs the BM25 and vector methods evenly
289
+ alpha = 1 executes a pure vector search method
290
+ limit: int=10
291
+ Number of results to return.
292
+ display_properties: List[str]=None
293
+ List of properties to return in response.
294
+ If None, returns all properties.
295
+ return_raw: bool=False
296
+ If True, returns raw response from Weaviate.
297
+ '''
298
+ display_properties = display_properties if display_properties else self.display_properties
299
+ query_vector = self._create_query_vector(request, device=device)
300
+ response = (
301
+ self.query
302
+ .get(class_name, display_properties)
303
+ .with_hybrid(query=request,
304
+ alpha=alpha,
305
+ vector=query_vector,
306
+ properties=properties,
307
+ fusion_type='relativeScoreFusion') #hard coded option for now
308
+ .with_additional(["score", "explainScore"])
309
+ .with_limit(limit)
310
+ )
311
+
312
+ response = response.with_where(where_filter).do() if where_filter else response.do()
313
+ if return_raw:
314
+ return response
315
+ else:
316
+ return self.format_response(response, class_name)
317
+
318
+
319
+ class WeaviateIndexer:
320
+
321
+ def __init__(self,
322
+ client: WeaviateClient,
323
+ batch_size: int=150,
324
+ num_workers: int=4,
325
+ dynamic: bool=True,
326
+ creation_time: int=5,
327
+ timeout_retries: int=3,
328
+ connection_error_retries: int=3,
329
+ callback: Callable=None,
330
+ ):
331
+ '''
332
+ Class designed to batch index documents into Weaviate. Instantiating
333
+ this class will automatically configure the Weaviate batch client.
334
+ '''
335
+ self._client = client
336
+ self._callback = callback if callback else self._default_callback
337
+
338
+ self._client.batch.configure(batch_size=batch_size,
339
+ num_workers=num_workers,
340
+ dynamic=dynamic,
341
+ creation_time=creation_time,
342
+ timeout_retries=timeout_retries,
343
+ connection_error_retries=connection_error_retries,
344
+ callback=self._callback
345
+ )
346
+
347
+ def _default_callback(self, results: dict):
348
+ """
349
+ Check batch results for errors.
350
+
351
+ Parameters
352
+ ----------
353
+ results : dict
354
+ The Weaviate batch creation return value.
355
+ """
356
+
357
+ if results is not None:
358
+ for result in results:
359
+ if "result" in result and "errors" in result["result"]:
360
+ if "error" in result["result"]["errors"]:
361
+ print(result["result"])
362
+
363
+ def batch_index_data(self,
364
+ data: List[dict],
365
+ class_name: str,
366
+ vector_property: str='content_embedding'
367
+ ) -> None:
368
+ '''
369
+ Batch function for fast indexing of data onto Weaviate cluster.
370
+ This method assumes that self._client.batch is already configured.
371
+ '''
372
+ start = time.perf_counter()
373
+ with self._client.batch as batch:
374
+ for d in tqdm(data):
375
+
376
+ #define single document
377
+ properties = {k:v for k,v in d.items() if k != vector_property}
378
+ try:
379
+ #add data object to batch
380
+ batch.add_data_object(
381
+ data_object=properties,
382
+ class_name=class_name,
383
+ vector=d[vector_property]
384
+ )
385
+ except Exception as e:
386
+ print(e)
387
+ continue
388
+
389
+ end = time.perf_counter() - start
390
+
391
+ print(f'Batch job completed in {round(end/60, 2)} minutes.')
392
+ # class_info = self._client.show_class_info()
393
+ # for i, c in enumerate(class_info):
394
+ # if c['class'] == class_name:
395
+ # print(class_info[i])
396
+ self._client.batch.shutdown()
397
+
398
+ @dataclass
399
+ class WhereFilter:
400
+
401
+ '''
402
+ Simplified interface for constructing a WhereFilter object.
403
+
404
+ Args
405
+ ----
406
+ path: List[str]
407
+ List of properties to filter on.
408
+ operator: str
409
+ Operator to use for filtering. Options: ['And', 'Or', 'Equal', 'NotEqual',
410
+ 'GreaterThan', 'GreaterThanEqual', 'LessThan', 'LessThanEqual', 'Like',
411
+ 'WithinGeoRange', 'IsNull', 'ContainsAny', 'ContainsAll']
412
+ value[dataType]: Union[int, bool, str, float, datetime]
413
+ Value to filter on. The dataType suffix must match the data type of the
414
+ property being filtered on. At least and only one value type must be provided.
415
+ '''
416
+ path: List[str]
417
+ operator: str
418
+ valueInt: int=None
419
+ valueBoolean: bool=None
420
+ valueText: str=None
421
+ valueNumber: float=None
422
+ valueDate: str=None
423
+
424
+ def __post_init__(self):
425
+ operators = ['And', 'Or', 'Equal', 'NotEqual','GreaterThan', 'GreaterThanEqual', 'LessThan',\
426
+ 'LessThanEqual', 'Like', 'WithinGeoRange', 'IsNull', 'ContainsAny', 'ContainsAll']
427
+ if self.operator not in operators:
428
+ raise ValueError(f'operator must be one of: {operators}, got {self.operator}')
429
+ values = [v for v in [self.valueInt, self.valueBoolean, self.valueText, self.valueNumber, self.valueDate] if v is not None]
430
+ if not values:
431
+ raise ValueError('At least one value must be provided.')
432
+ if len(values) > 1:
433
+ raise ValueError('At most one value can be provided.')
434
+
435
+ def todict(self):
436
+ return {k:v for k,v in self.__dict__.items() if v is not None}
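A usage sketch for the client and filter classes above; the endpoint, class name and filter value are placeholders rather than values from this commit:

import os
from weaviate_interface_v3_spa import WeaviateClient, WhereFilter

client = WeaviateClient(endpoint='https://my-cluster.weaviate.network',
                        api_key=os.environ.get('WEAVIATE_API_KEY'))

# restrict the search to a single source document
doc_filter = WhereFilter(path=['file_name'],
                         operator='Equal',
                         valueText='ley_impuestos_especiales.pdf').todict()

# alpha=0.5 weighs BM25 and vector scores evenly, as described in the docstring
hits = client.hybrid_search(request='tipos impositivos de la cerveza',
                            class_name='ImpuestosEspeciales_256',
                            alpha=0.5,
                            limit=10,
                            where_filter=doc_filter)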