# hf-iiee-msm/pages/2_🗣_Busqueda_Conversacional.py
from tiktoken import encoding_for_model
from utils.weaviate_interface_v3_spa import WeaviateClient, WhereFilter
from templates.prompt_templates_spa import question_answering_prompt_series_spa
from utils.openai_interface_spa import GPT_Turbo
from utils.app_features_spa import (generate_prompt_series, validate_token_threshold,
                                    load_content_cache, load_data, expand_content)
from utils.reranker_spa import ReRanker
from openai import OpenAI
from loguru import logger
import streamlit as st
import os
import templates.system_prompts as system_prompts
import base64
import json
# load environment variables
from dotenv import load_dotenv
load_dotenv('.env', override=True)
## PAGE CONFIGURATION
st.set_page_config(page_title="Búsqueda Conversacional",
page_icon="🗣",
layout="wide",
initial_sidebar_state="auto",
menu_items=None)
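# Helper: base64-encode an uploaded file (e.g. an image); defined here but only
# needed if image input is wired into the chat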
def encode_image(uploaded_file):
return base64.b64encode(uploaded_file.getvalue()).decode('utf-8')
## DATA + CACHE
data_path = 'data/1_IIEE_1_json_data_19_02_2024_22-17-49.json'
cache_path = ''
data = load_data(data_path)
cache = None # Initialize cache as None
# Check if the cache file exists before attempting to load it
if os.path.exists(cache_path):
cache = load_content_cache(cache_path)
else:
logger.warning(f"Cache file {cache_path} not found. Proceeding without cache.")
# build a sorted list of document titles for the sidebar document filter
guest_list = sorted({d['document_title'] for d in data})
with st.sidebar:
st.subheader("Selecciona tu Base de datos 🗃️")
client_type = st.radio(
"Selecciona el modo de acceso:",
('Cloud', 'Local'),
help='Elige un repositorio para determinar el conjunto de datos sobre el cual realizarás tu búsqueda. "Cloud" te permite acceder a datos alojados en nuestros servidores seguros, mientras que "Local" es para trabajar con datos alojados localmente en tu máquina.'
)
if client_type == 'Cloud':
api_key = st.secrets['WEAVIATE_CLOUD_API_KEY']
url = st.secrets['WEAVIATE_CLOUD_ENDPOINT']
weaviate_client = WeaviateClient(
endpoint=url,
api_key=api_key,
# model_name_or_path='./models/finetuned-all-MiniLM-L6-v2-300',
model_name_or_path="intfloat/multilingual-e5-small",
# openai_api_key=os.environ['OPENAI_API_KEY']
)
        available_classes = sorted(weaviate_client.show_classes())
        logger.info(f"Endpoint: {client_type} | Classes: {available_classes}")
elif client_type == 'Local':
url = st.secrets['WEAVIATE_LOCAL_ENDPOINT']
weaviate_client = WeaviateClient(
endpoint=url,
# api_key=api_key,
# model_name_or_path='./models/finetuned-all-MiniLM-L6-v2-300',
model_name_or_path="intfloat/multilingual-e5-small",
# openai_api_key=os.environ['OPENAI_API_KEY']
)
        available_classes = sorted(weaviate_client.show_classes())
logger.info(f"Endpoint: {client_type} | Classes: {available_classes}")
client = OpenAI(api_key=st.secrets["OPENAI_API_KEY"])
def main():
# Define the available user selected options
available_models = ['gpt-3.5-turbo', 'gpt-4-1106-preview']
# Define system prompts
system_prompt_list = ["🤖ChatGPT","🧙🏾‍♂️Professor Synapse", "👩🏼‍💼Marketing Jane"]
# Initialize selected options in session state
if "openai_data_model" not in st.session_state:
st.session_state["openai_data_model"] = available_models[0]
if "system_prompt_data_list" not in st.session_state and "system_prompt_data_model" not in st.session_state:
# This should be the emoji string the user selected
st.session_state["system_prompt_data_list"] = system_prompt_list[0]
# Now we get the corresponding prompt variable using the selected emoji string
st.session_state["system_prompt_data_model"] = system_prompts.prompt_mapping[system_prompt_list[0]]
# logger.debug(f"Assistant: {st.session_state['system_prompt_sync_list']}")
# logger.debug(f"System Prompt: {st.session_state['system_prompt_sync_model']}")
if 'class_name' not in st.session_state:
st.session_state['class_name'] = None
with st.sidebar:
        st.session_state['class_name'] = st.selectbox(
            label='Repositorio:',
            options=available_classes,
            index=None,
            placeholder='Repositorio',
            help='Elige un repositorio (colección de Weaviate) que determina el conjunto de datos sobre el cual realizarás tu búsqueda.'
        )
# Check if the collection name has been selected
class_name = st.session_state['class_name']
if class_name:
st.success(f"Repositorio seleccionado ✅: {st.session_state['class_name']}")
else:
st.warning("🎗️ No olvides seleccionar el repositorio 👆 a consultar 🗄️.")
st.stop() # Stop execution of the script
model_choice = st.selectbox(
label="Elige un modelo de OpenAI",
options=available_models,
            index=available_models.index(st.session_state["openai_data_model"]),
help='Escoge entre diferentes modelos de OpenAI para generar respuestas a tus consultas. Cada modelo tiene distintas capacidades y limitaciones.'
)
system_prompt = st.selectbox(
label="Elige un asistente",
options=system_prompt_list,
index=system_prompt_list.index(st.session_state["system_prompt_data_list"]),
)
with st.expander("Filtros de Busqueda"):
guest_input = st.selectbox(
label='Selección de Documento',
options=guest_list,
index=None,
placeholder='Documentos',
help='Elige un documento específico del repositorio para afinar tu búsqueda a datos relevantes.'
)
with st.expander("Parametros de Busqueda"):
retriever_choice = st.selectbox(
label="Selecciona un método",
options=["Hybrid", "Vector", "Keyword"],
help='Determina el método de recuperación de información: "Hybrid" combina búsqueda por palabras clave y por similitud semántica, "Vector" usa embeddings de texto para encontrar coincidencias semánticas, y "Keyword" realiza una búsqueda tradicional por palabras clave.'
)
reranker_enabled = st.checkbox(
label="Activar Reranker",
value=True,
help='Activa esta opción para ordenar los resultados de la búsqueda según su relevancia, utilizando un modelo de reordenamiento adicional.'
)
alpha_input = st.slider(
                label='Alpha para motor híbrido',
min_value=0.00,
max_value=1.00,
value=0.40,
step=0.05,
help='Ajusta el parámetro alfa para equilibrar los resultados entre los métodos de búsqueda por vector y por palabra clave en el motor híbrido.'
)
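            # In Weaviate's hybrid search, alpha=0 is pure keyword (BM25) scoring and
            # alpha=1 is pure vector scoring; the default 0.40 leans toward keywords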
retrieval_limit = st.slider(
label='Resultados a Reranker',
min_value=10,
max_value=300,
value=100,
step=10,
help='Establece el número de resultados que se recuperarán antes de aplicar el reordenamiento.'
)
top_k_limit = st.slider(
label='Top K Limit',
min_value=1,
max_value=5,
value=3,
step=1,
help='Define el número máximo de resultados a mostrar después de aplicar el reordenamiento.'
)
temperature_input = st.slider(
label='Temperatura',
min_value=0.0,
max_value=1.0,
value=0.20,
step=0.10,
help='Ajusta la temperatura para la generación de texto con GPT, lo que influirá en la creatividad de las respuestas.'
)
# Update the model choice in session state
if st.session_state["openai_data_model"]!=model_choice:
st.session_state["openai_data_model"] = model_choice
logger.info(f"Data model: {st.session_state['openai_data_model']}")
# Update the system prompt choice in session state
if st.session_state["system_prompt_data_list"] != system_prompt:
# This should be the emoji string the user selected
st.session_state["system_prompt_data_list"] = system_prompt
# Now we get the corresponding prompt variable using the selected emoji string
selected_prompt_variable = system_prompts.prompt_mapping[system_prompt]
st.session_state['system_prompt_data_model'] = selected_prompt_variable
# logger.info(f"System Prompt: {selected_prompt_variable}")
logger.info(f"Assistant: {st.session_state['system_prompt_data_list']}")
# logger.info(f"System Prompt: {st.session_state['system_prompt_sync_model']}")
logger.info(weaviate_client.display_properties)
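    # database_search is defined inside main() so it closes over the sidebar
    # settings (retriever choice, limits, filters) chosen on this rerun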
def database_search(query):
# Determine the appropriate limit based on reranking
search_limit = retrieval_limit if reranker_enabled else top_k_limit
# make hybrid call to weaviate
guest_filter = WhereFilter(
path=['document_title'],
operator='Equal',
valueText=guest_input).todict() if guest_input else None
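        # guest_filter narrows the search to a single document title when one is
        # selected in the sidebar; None means search the whole repository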
try:
# Perform the search based on retriever_choice
if retriever_choice == "Keyword":
query_results = weaviate_client.keyword_search(
request=query,
class_name=class_name,
limit=search_limit,
where_filter=guest_filter
)
elif retriever_choice == "Vector":
query_results = weaviate_client.vector_search(
request=query,
class_name=class_name,
limit=search_limit,
where_filter=guest_filter
)
elif retriever_choice == "Hybrid":
query_results = weaviate_client.hybrid_search(
request=query,
class_name=class_name,
alpha=alpha_input,
limit=search_limit,
properties=["content"],
where_filter=guest_filter
)
else:
return json.dumps({"error": "Invalid retriever choice"})
## RERANKER
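            # NOTE: instantiating the cross-encoder here reloads the model on every
            # query; hoisting it to module scope would avoid the repeated cost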
reranker = ReRanker(model_name='cross-encoder/ms-marco-MiniLM-L-12-v2')
model_name = model_choice
encoding = encoding_for_model(model_name)
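            # Tokenizer for the selected model, used below to keep the assembled
            # prompt under the token threshold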
# Rerank the results if enabled
if reranker_enabled:
search_results = reranker.rerank(
results=query_results,
query=query,
apply_sigmoid=True,
top_k=top_k_limit
)
else:
# Use the results directly if reranking is not enabled
search_results = query_results
# logger.debug(search_results)
# Save search results to session state for later use
# st.session_state['search_results'] = search_results
add_to_search_history(query=query, search_results=search_results)
expanded_response = expand_content(search_results, cache, content_key='doc_id', create_new_list=True)
# validate token count is below threshold
token_threshold = 8000
valid_response = validate_token_threshold(
ranked_results=expanded_response,
base_prompt=question_answering_prompt_series_spa,
query=query,
tokenizer=encoding,
token_threshold=token_threshold,
verbose=True
)
# generate LLM prompt
prompt = generate_prompt_series(query=query, results=valid_response)
return json.dumps({
"query": query,
"Search Results": prompt,
}, ensure_ascii=False)
except Exception as e:
# Handle any exceptions and return a JSON formatted error message
return json.dumps({
"error": "An error occurred during the search",
"details": str(e)
})
# Store each query together with its search results in the session history
def add_to_search_history(query, search_results):
st.session_state["data_search_history"].append({
"query": query,
"search_results": search_results,
})
# Function to display search results
def display_search_results():
# Loop through each item in the search history
for search in st.session_state['data_search_history']:
query = search["query"]
search_results = search["search_results"]
# Create an expander for each search query
with st.expander(f"Pregunta: {query}", expanded=False):
for i, hit in enumerate(search_results):
page_url = hit['page_url']
page_label = hit['page_label']
document_title = hit['document_title']
                    # 'page_summary' may be missing from a hit; fall back to a placeholder
                    page_summary = hit.get('page_summary', 'Summary not available')
st.markdown(f'''
<span style="color: #3498db; font-size: 19px; font-weight: bold;">{document_title}</span><br>
{page_summary}
[**Página:** {page_label}]({page_url})
''', unsafe_allow_html=True)
# with st.expander("📄 Clic aquí para ver contexto:"):
# try:
# content = hit['content']
# st.write(content)
# except Exception as e:
# st.write(f"Error displaying content: {e}")
# with col2:
# # If you have an image or want to display a placeholder image
# image = "URL_TO_A_PLACEHOLDER_IMAGE" # Replace with a relevant image URL if applicable
# st.image(image, caption=document_title, width=200, use_column_width=False)
# st.markdown(f'''
# <p style="text-align: right;">
# <b>Document Title:</b> {document_title}<br>
# <b>File Name:</b> {file_name}<br>
# </p>''', unsafe_allow_html=True)
########################
## SETUP MAIN DISPLAY ##
########################
st.image('./static/images/cervezas-mahou.jpeg', width=400)
st.subheader(f"✨🗣️📘 **Búsqueda Conversacional** 💡🗣️✨ - Impuestos Especiales")
st.write('\n')
col1, col2 = st.columns([50,50])
# Initialize chat history
if "data_chat_history" not in st.session_state:
st.session_state["data_chat_history"] = []
if "data_search_history" not in st.session_state:
st.session_state["data_search_history"] = []
with col1:
st.write("Chat History:")
# Create a container for chat history
chat_history_container = st.container(height=500, border=True)
        # Render the chat history inside the container; called on every rerun
        def update_chat_display():
            with chat_history_container:
                for message in st.session_state["data_chat_history"]:
                    with st.chat_message(message["role"]):
                        st.markdown(message["content"])
        if prompt := st.chat_input("Escribe tu pregunta..."):
            # Add user message to chat history
            st.session_state["data_chat_history"].append({"role": "user", "content": prompt})
        # Render the chat history exactly once, including the new user message, if any
        update_chat_display()
        if prompt:
with st.spinner('Generando Respuesta...'):
tools = [
{
"type": "function",
"function": {
"name": "database_search",
"description": "Takes the users query about the database and returns the results, extracting info to answer the user's question",
"parameters": {
"type": "object",
"properties": {
"query": {"type": "string", "description": "query"},
},
"required": ["query"],
},
}
}
]
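                # With tool_choice="auto" the model decides per turn whether to answer
                # directly or to call database_search with a self-formulated query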
# Display live assistant response in chat message container
with st.chat_message(
name="assistant",
avatar="./static/images/openai_purple_logo_hres.jpeg"):
message_placeholder = st.empty()
# Building the messages payload with proper OPENAI API structure
messages=[
{"role": "system", "content": st.session_state["system_prompt_data_model"]}
] + [
{"role": m["role"], "content": m["content"]} for m in st.session_state["data_chat_history"]
]
logger.debug(f"Initial Messages: {messages}")
# call the OpenAI API to get the response
RESPONSE = client.chat.completions.create(
model=st.session_state["openai_data_model"],
temperature=0.5,
messages=messages,
tools=tools,
tool_choice="auto", # auto is default, but we'll be explicit
stream=True
)
logger.debug(f"First Response: {RESPONSE}")
FULL_RESPONSE = ""
tool_calls = []
# build up the response structs from the streamed response, simultaneously sending message chunks to the browser
for chunk in RESPONSE:
delta = chunk.choices[0].delta
# logger.debug(f"chunk: {delta}")
if delta and delta.content:
text_chunk = delta.content
FULL_RESPONSE += str(text_chunk)
message_placeholder.markdown(FULL_RESPONSE + "▌")
elif delta and delta.tool_calls:
tcchunklist = delta.tool_calls
for tcchunk in tcchunklist:
if len(tool_calls) <= tcchunk.index:
tool_calls.append({"id": "", "type": "function", "function": { "name": "", "arguments": "" } })
tc = tool_calls[tcchunk.index]
if tcchunk.id:
tc["id"] += tcchunk.id
if tcchunk.function.name:
tc["function"]["name"] += tcchunk.function.name
if tcchunk.function.arguments:
tc["function"]["arguments"] += tcchunk.function.arguments
if tool_calls:
logger.debug(f"tool_calls: {tool_calls}")
                    # Map tool names to the actual callables (only one tool here,
                    # but more can be registered as needed)
                    available_functions = {
                        "database_search": database_search,
                    }
logger.debug(f"FuncCall Before messages: {messages}")
                    # The assistant message that requested the tools must precede the
                    # tool results in the follow-up payload
                    messages.append({
                        "role": "assistant",
                        "content": None,
                        "tool_calls": tool_calls,
                    })
                    # Process each tool call
                    for tool_call in tool_calls:
                        # Get the function name and arguments from the tool call
                        function_name = tool_call['function']['name']
                        function_args = json.loads(tool_call['function']['arguments'])
                        # Get the actual function to call
                        function_to_call = available_functions[function_name]
                        # Call the function and get the response
                        function_response = function_to_call(**function_args)
                        # Tool results go back with role "tool", keyed by tool_call_id
                        messages.append({
                            "role": "tool",
                            "tool_call_id": tool_call['id'],
                            "name": function_name,
                            "content": function_response,
                        })
logger.debug(f"FuncCall After messages: {messages}")
                    RESPONSE = client.chat.completions.create(
                        model=st.session_state["openai_data_model"],
                        temperature=temperature_input,  # use the sidebar temperature for the final answer
                        messages=messages,
                        stream=True
                    )
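                    # No tools are passed on this second call, so the model must
                    # compose its final answer from the tool results appended above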
logger.debug(f"Second Response: {RESPONSE}")
# build up the response structs from the streamed response, simultaneously sending message chunks to the browser
for chunk in RESPONSE:
delta = chunk.choices[0].delta
# logger.debug(f"chunk: {delta}")
if delta and delta.content:
text_chunk = delta.content
FULL_RESPONSE += str(text_chunk)
message_placeholder.markdown(FULL_RESPONSE + "▌")
# Add assistant response to chat history
st.session_state["data_chat_history"].append({"role": "assistant", "content": FULL_RESPONSE})
logger.debug(f"chat_history: {st.session_state['data_chat_history']}")
####################
## Search Results ##
####################
with col2:
st.write("Search Results:")
with st.container(height=500, border=True):
# Check if 'data_search_history' is in the session state and not empty
if 'data_search_history' in st.session_state and st.session_state['data_search_history']:
display_search_results()
if __name__ == '__main__':
main()