mpi_data_store / pages /05_testing_qa.py
rianders's picture
Functioning pages added
9831243
raw
history blame
5.27 kB
import streamlit as st
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFaceEndpoint
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
import tempfile
import zipfile
import os
st.title('Testing and QA')

# Model names are read from session state (presumably chosen on another page
# of the app -- confirm against the page that sets these keys); the literals
# are the fallbacks used when nothing has been selected yet.
EMBEDDING_MODEL_NAME = st.session_state.get('selected_embedding_model', "thenlper/gte-small")
LLM_MODEL_NAME = st.session_state.get('selected_llm_model', "mistralai/Mistral-7B-Instruct-v0.2")

# (Re)build the embedding model when nothing is cached yet, or when the cached
# instance was created for a different model name than the one currently
# selected. Without the second check, changing the selection would have no
# effect for the rest of the session.
_cached_embedding_model = st.session_state.get('embedding_model')
_cached_model_name = getattr(_cached_embedding_model, 'model_name', EMBEDDING_MODEL_NAME)
if _cached_embedding_model is None or _cached_model_name != EMBEDDING_MODEL_NAME:
    st.session_state['embedding_model'] = HuggingFaceEmbeddings(
        model_name=EMBEDDING_MODEL_NAME,
        multi_process=True,
        model_kwargs={"device": "cpu"},
        encode_kwargs={"normalize_embeddings": True},
    )
    st.info("embedding_model has been initialized.")  # Debug message for initialization
else:
    st.info("embedding_model was already initialized.")  # Debug message if already initialized

# The branch above guarantees the key exists, so a plain lookup is safe here.
embedding_model = st.session_state['embedding_model']
st.write("Accessing embedding_model...")  # Debug message for accessing
# Form for LLM settings, allowing dynamic model selection.
# Widgets inside a form do not trigger a rerun individually; their values are
# delivered together when the submit button is pressed.
with st.form("llm_settings_form"):
    st.subheader("LLM Settings")
    repo_id = st.text_input("Repo ID", value=LLM_MODEL_NAME, key="repo_id")
    # Integer inputs: bounded below so a non-positive value cannot be sent.
    max_new_tokens = st.number_input("Max New Tokens", min_value=1, value=250, key="max_new_tokens")
    top_k = st.number_input("Top K", min_value=1, value=3, key="top_k")
    # Probability-mass parameters are restricted to the [0, 1] range.
    top_p = st.number_input("Top P", min_value=0.0, max_value=1.0, value=0.95, key="top_p")
    typical_p = st.number_input("Typical P", min_value=0.0, max_value=1.0, value=0.95, key="typical_p")
    temperature = st.number_input("Temperature", min_value=0.0, value=0.01, key="temperature")
    # The default float format shows two decimals, which would display the
    # 1.035 default rounded; three decimals shows the value actually used.
    repetition_penalty = st.number_input(
        "Repetition Penalty",
        min_value=0.0,
        value=1.035,
        format="%.3f",
        key="repetition_penalty",
    )
    submitted = st.form_submit_button("Update LLM Settings")

# The endpoint client is (re)created only on submit and kept in session state
# so it survives Streamlit reruns.
if submitted:
    st.session_state['llm'] = HuggingFaceEndpoint(
        repo_id=repo_id,
        max_new_tokens=max_new_tokens,
        top_k=top_k,
        top_p=top_p,
        typical_p=typical_p,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
    )
    st.success("LLM settings updated.")
# Vector store upload and setup.
# The uploader is only shown until a vector store has been stored in the
# session; after that the stored instance is reused on every rerun.
if 'collection_vectorstore' not in st.session_state:
    uploaded_file = st.file_uploader("Upload Vector Store ZIP", type=["zip"])
    if uploaded_file is not None:
        # The archive is unpacked into a throwaway directory; the index must be
        # loaded before the directory is removed at the end of the `with`.
        with tempfile.TemporaryDirectory() as temp_dir:
            try:
                with zipfile.ZipFile(uploaded_file, 'r') as zip_ref:
                    zip_ref.extractall(temp_dir)
            except zipfile.BadZipFile:
                st.error("The uploaded file is not a valid ZIP archive.")
            else:
                docs_vectors_path = os.path.join(temp_dir, "docs_vectors")
                if not os.path.isdir(docs_vectors_path):
                    # Report a missing index folder instead of letting the
                    # loader fail with a raw traceback.
                    st.error("The ZIP archive does not contain a 'docs_vectors' folder.")
                else:
                    # SECURITY: allow_dangerous_deserialization=True lets the
                    # loader unpickle data from the uploaded archive, which can
                    # execute arbitrary code. Only upload archives from a
                    # trusted source.
                    st.session_state['collection_vectorstore'] = FAISS.load_local(
                        docs_vectors_path,
                        embeddings=embedding_model,
                        allow_dangerous_deserialization=True,
                    )
                    st.success("Vector store uploaded and loaded successfully.")
                    # Create the retriever as soon as the vector store is created
                    st.session_state['retriever'] = st.session_state['collection_vectorstore'].as_retriever()
                    st.info("Retriever has been created.")  # Debug message to confirm the retriever's creation
# Question answering: only available once both the LLM and the vector store
# are present in session state.
if 'llm' in st.session_state and 'collection_vectorstore' in st.session_state:
    # The gate above checks the vector store, not the retriever, so build the
    # retriever here if the store was placed in session state without one.
    if 'retriever' not in st.session_state:
        st.session_state['retriever'] = st.session_state['collection_vectorstore'].as_retriever()

    # Ensure there's a default prompt template
    if 'prompt_template' not in st.session_state:
        st.session_state['prompt_template'] = "You are a knowledgeable assistant answering the following question based on the provided documents: {context} Question: {question}"

    # Display the current template for editing. This widget is rendered on
    # every run, so its content is available when the button below is clicked.
    current_template = st.text_area("Edit Prompt Template", value=st.session_state['prompt_template'], key="current_prompt_template")

    # Persist the edited template only when the button is pressed. A widget
    # created inside the button branch would exist for a single rerun and
    # return an empty string, overwriting the stored template with "".
    if st.button("Update Prompt Template"):
        st.session_state['prompt_template'] = current_template
        st.success("Prompt template updated.")

    # Question input and processing
    question = st.text_input("Enter your question", key="question_input")
    # The "Ask" button is only rendered once a question has been entered.
    if question and st.button("Ask"):
        llm = st.session_state['llm']
        prompt = ChatPromptTemplate.from_template(current_template)
        retriever = st.session_state['retriever']
        # The retriever fills {context}; the raw question is passed through
        # unchanged to fill {question}.
        chain = (
            {"context": retriever, "question": RunnablePassthrough()}
            | prompt
            | llm
            | StrOutputParser()
        )
        result = chain.invoke(question)
        st.subheader("Answer:")
        st.write(result)
else:
    st.warning("Please configure and submit the LLM settings and ensure the vector store is loaded to ask questions.")