import os
import tempfile
import zipfile

import streamlit as st
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFaceEndpoint
from langchain_community.vectorstores import FAISS
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
st.title('Testing and QA')

# Dynamically load the selected models from the session state
EMBEDDING_MODEL_NAME = st.session_state.get('selected_embedding_model', "thenlper/gte-small")
LLM_MODEL_NAME = st.session_state.get('selected_llm_model', "mistralai/Mistral-7B-Instruct-v0.2")
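# Note: the 'selected_embedding_model' / 'selected_llm_model' keys are assumed
# to be set on a separate model-selection page of this multi-page app; when
# they are absent, the defaults above are used.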
# Initialization block for embedding_model, with a debug message
if 'embedding_model' not in st.session_state:
    st.session_state['embedding_model'] = HuggingFaceEmbeddings(
        model_name=EMBEDDING_MODEL_NAME,
        multi_process=True,
        model_kwargs={"device": "cpu"},
        encode_kwargs={"normalize_embeddings": True},
    )
    st.info("embedding_model has been initialized.")  # Debug message for initialization
else:
    st.info("embedding_model was already initialized.")  # Debug message if already initialized

# Now that embedding_model is guaranteed to be initialized, it can be accessed safely
embedding_model = st.session_state['embedding_model']
st.write("Accessing embedding_model...")  # Debug message for accessing
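# st.session_state persists across Streamlit reruns, so the (relatively slow)
# embedding model load above happens once per browser session rather than on
# every widget interaction.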
# Form for LLM settings, allowing dynamic model selection
with st.form("llm_settings_form"):
    st.subheader("LLM Settings")
    repo_id = st.text_input("Repo ID", value=LLM_MODEL_NAME, key="repo_id")
    max_new_tokens = st.number_input("Max New Tokens", value=250, key="max_new_tokens")
    top_k = st.number_input("Top K", value=3, key="top_k")
    top_p = st.number_input("Top P", value=0.95, key="top_p")
    typical_p = st.number_input("Typical P", value=0.95, key="typical_p")
    temperature = st.number_input("Temperature", value=0.01, key="temperature")
    repetition_penalty = st.number_input("Repetition Penalty", value=1.035, key="repetition_penalty")
    submitted = st.form_submit_button("Update LLM Settings")
    if submitted:
        st.session_state['llm'] = HuggingFaceEndpoint(
            repo_id=repo_id,
            max_new_tokens=max_new_tokens,
            top_k=top_k,
            top_p=top_p,
            typical_p=typical_p,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
        )
        st.success("LLM settings updated.")
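# Note: HuggingFaceEndpoint authenticates against the Hugging Face Inference API
# using the HUGGINGFACEHUB_API_TOKEN environment variable (or an explicit
# huggingfacehub_api_token argument). A minimal sanity check, sketched here:
if "HUGGINGFACEHUB_API_TOKEN" not in os.environ:
    st.warning("HUGGINGFACEHUB_API_TOKEN is not set; endpoint calls may fail.")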
# Vector store upload and setup
if 'collection_vectorstore' not in st.session_state:
    uploaded_file = st.file_uploader("Upload Vector Store ZIP", type=["zip"])
    if uploaded_file is not None:
        with tempfile.TemporaryDirectory() as temp_dir:
            with zipfile.ZipFile(uploaded_file, 'r') as zip_ref:
                zip_ref.extractall(temp_dir)
            docs_vectors_path = os.path.join(temp_dir, "docs_vectors")
            st.session_state['collection_vectorstore'] = FAISS.load_local(
                docs_vectors_path,
                embeddings=embedding_model,
                allow_dangerous_deserialization=True,
            )
        st.success("Vector store uploaded and loaded successfully.")

        # Create the retriever as soon as the vector store is loaded
        st.session_state['retriever'] = st.session_state['collection_vectorstore'].as_retriever()
        st.info("Retriever has been created.")  # Debug message to confirm the retriever's creation
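# The uploaded ZIP is assumed to contain a "docs_vectors" folder holding the
# files that FAISS.save_local() produces (index.faiss and index.pkl); adjust
# docs_vectors_path above if your archive uses a different layout.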
# Check if LLM and vector store are ready
if 'llm' in st.session_state and 'collection_vectorstore' in st.session_state:
    # Ensure there's a default prompt template
    if 'prompt_template' not in st.session_state:
        st.session_state['prompt_template'] = (
            "You are a knowledgeable assistant answering the following question "
            "based on the provided documents: {context} Question: {question}"
        )

    # Display the current template for editing; the button persists the edited
    # text back into session state for later reruns.
    current_template = st.text_area(
        "Edit Prompt Template",
        value=st.session_state['prompt_template'],
        key="current_prompt_template",
    )
    if st.button("Update Prompt Template"):
        st.session_state['prompt_template'] = current_template
        st.success("Prompt template updated.")
    # Question input and processing
    question = st.text_input("Enter your question", key="question_input")
    if question:
        llm = st.session_state['llm']
        prompt = ChatPromptTemplate.from_template(current_template)
        retriever = st.session_state['retriever']
        chain = (
            {"context": retriever, "question": RunnablePassthrough()}
            | prompt
            | llm
            | StrOutputParser()
        )
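        # The dict at the head of the chain fans the question out: the retriever
        # fetches the matching Documents for {context} (the prompt stringifies
        # the list as-is), while RunnablePassthrough() forwards the raw question.
        # A common refinement (a sketch, not part of this app) joins the page
        # contents explicitly instead of relying on str() of Document objects:
        #   format_docs = lambda docs: "\n\n".join(d.page_content for d in docs)
        #   {"context": retriever | format_docs, "question": RunnablePassthrough()}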
        if st.button("Ask"):
            result = chain.invoke(question)
            st.subheader("Answer:")
            st.write(result)
else:
    st.warning("Please configure and submit the LLM settings and ensure the vector store is loaded to ask questions.")