Spaces:
Runtime error
Runtime error
import streamlit as st | |
import os | |
import pathlib | |
from typing import List | |
# local imports | |
from models.llms import load_llm, integrated_llms | |
from models.embeddings import openai_embed_model | |
from models.llamaCustom import LlamaCustom | |
# from models.llamaCustomV2 import LlamaCustomV2 | |
from models.vector_database import get_pinecone_index | |
from utils.chatbox import show_previous_messages, show_chat_input | |
from utils.util import validate_openai_api_key | |
# llama_index | |
from llama_index.core import ( | |
SimpleDirectoryReader, | |
Document, | |
VectorStoreIndex, | |
StorageContext, | |
Settings, | |
load_index_from_storage, | |
) | |
from llama_index.vector_stores.pinecone import PineconeVectorStore | |
from llama_index.core.memory import ChatMemoryBuffer | |
from llama_index.core.base.llms.types import ChatMessage | |
# huggingface | |
from huggingface_hub import HfApi | |
# Folder holding the documents a user can pick to chat with.
SAVE_DIR: str = "uploaded_files"
# Folder under which locally persisted LlamaIndex stores are kept (see get_index).
VECTOR_STORE_DIR: str = "vectorStores"
# Hugging Face repository id (not referenced by the code visible in this file).
HF_REPO_ID: str = "zhtet/RegBotBeta"

# Hugging Face Hub client shared at module level.
hf_api = HfApi()

# Global LlamaIndex default: embed documents and queries with the OpenAI
# embedding model (a Hugging Face embedder was used previously).
Settings.embed_model = openai_embed_model
def init_session_state():
    """Seed ``st.session_state`` with defaults for every key not yet present.

    Values that already exist are never overwritten, so this is safe to call on
    every Streamlit rerun without losing chat history or entered credentials.
    """
    greeting = "How can I help you today?"

    # Each default is produced by a factory so a fresh object is only built
    # when its key is actually missing.
    default_factories = {
        "llama_messages": lambda: [{"role": "assistant", "content": greeting}],
        # TODO: create a chat history for each different document
        "llama_chat_history": lambda: [
            ChatMessage.from_str(role="assistant", content=greeting)
        ],
        "llama_custom": lambda: None,
        "openai_api_key": lambda: "",
        "replicate_api_token": lambda: "",
        "hf_token": lambda: "",
    }

    for state_key, make_default in default_factories.items():
        if state_key not in st.session_state:
            st.session_state[state_key] = make_default()
# NOTE: caching via @st.cache_resource is disabled, so each call reloads (or
# rebuilds) the index from disk.
def get_index(
    filename: str,
) -> VectorStoreIndex:
    """Load the local vector index for ``filename``, building it if absent.

    The index lives under ``VECTOR_STORE_DIR/<filename with '.' replaced by '_'>``.
    If that directory exists, the stored index is loaded. Otherwise
    ``SAVE_DIR/<filename>`` is read, embedded into a new ``VectorStoreIndex``,
    and persisted to that same directory so later calls can reuse it.

    Args:
        filename: Name of a file inside ``SAVE_DIR``.

    Returns:
        The loaded or newly built ``VectorStoreIndex``.

    Raises:
        Exception: Any error from loading, reading or indexing is printed and
            then re-raised unchanged.
    """
    try:
        # Compute the store location once, from the shared constant, so the
        # lookup path and the persist path can never drift apart.
        index_path = pathlib.Path(VECTOR_STORE_DIR) / filename.replace(".", "_")

        if index_path.exists():
            print("Loading index from storage ...")
            storage_context = StorageContext.from_defaults(persist_dir=str(index_path))
            return load_index_from_storage(storage_context=storage_context)

        reader = SimpleDirectoryReader(
            input_files=[str(pathlib.Path(SAVE_DIR) / filename)]
        )
        docs = reader.load_data(show_progress=True)
        index = VectorStoreIndex.from_documents(
            documents=docs,
            show_progress=True,
        )
        index.storage_context.persist(persist_dir=str(index_path))
        return index
    except Exception as e:
        print(f"Error: {e}")
        # Bare raise re-raises the active exception with its original traceback.
        raise
def check_api_key(model_name: str, source: str):
    """Make sure the credential needed by the selected model's provider is set.

    The provider is chosen by the prefix of ``source``:

    * ``openai...``: prompt for an OpenAI API key and check it with
      ``validate_openai_api_key``. A rejected key halts this script run via
      ``st.stop()``.
    * ``replicate...``: prompt for a Replicate API token (not validated) and
      export it as ``REPLICATE_API_TOKEN``.
    * ``huggingface...``: prompt for a Hugging Face token (not validated) and
      export it as ``HF_TOKEN``.

    Nothing is rendered when the credential is already in session state, or
    when ``source`` matches none of the prefixes.

    Args:
        model_name: Selected model name. Currently unused; kept for callers.
        source: Provider identifier of the selected model.
    """
    if source.startswith("openai"):
        if not st.session_state.openai_api_key:
            with st.expander("OpenAI API Key", expanded=True):
                openai_api_key = st.text_input(
                    label="Enter your OpenAI API Key:",
                    type="password",
                    help="Get your key from https://platform.openai.com/account/api-keys",
                    value=st.session_state.openai_api_key,
                )
            if openai_api_key:
                # st.spinner() returns a context manager. It only displays when
                # used with `with`; as a boolean it is always truthy.
                with st.spinner("Validating OpenAI API Key ..."):
                    result = validate_openai_api_key(openai_api_key)
                if result["status"] == "success":
                    st.session_state.openai_api_key = openai_api_key
                    st.success(result["message"])
                else:
                    st.error(result["message"])
                    st.info("You can still select a different model to proceed.")
                    st.stop()
    elif source.startswith("replicate"):
        # TODO: need to validate the token
        _prompt_for_token(
            expander_label="Replicate API Token",
            input_label="Enter your Replicate API Token:",
            help_text="Get your key from https://replicate.ai/account",
            state_key="replicate_api_token",
            env_var="REPLICATE_API_TOKEN",
        )
    elif source.startswith("huggingface"):
        _prompt_for_token(
            expander_label="Hugging Face Token",
            input_label="Enter your Hugging Face Token:",
            help_text="Get your key from https://huggingface.co/settings/token",
            state_key="hf_token",
            env_var="HF_TOKEN",
        )


def _prompt_for_token(expander_label, input_label, help_text, state_key, env_var):
    """Ask for a secret (if not yet stored) and save it to session state and env."""
    if st.session_state[state_key]:
        return
    with st.expander(expander_label, expanded=True):
        token = st.text_input(
            label=input_label,
            type="password",
            help=help_text,
            value=st.session_state[state_key],
        )
    if token:
        st.session_state[state_key] = token
        # NOTE(review): os.environ is process-wide, so this token becomes
        # visible to every session served by this Streamlit process.
        os.environ[env_var] = token
# set_page_config() must be the first Streamlit command in the script run, so
# it is called before anything that could render an element.
st.set_page_config(page_title="Llama", page_icon="🦙")
init_session_state()
st.header("California Drinking Water Regulation Chatbot - LlamaIndex Demo")

tab1, tab2 = st.tabs(["Config", "Chat"])

with tab1:
    # Map each display label back to its (model name, source) pair rather than
    # re-parsing the label, which would break if a name itself contained "|".
    llm_options = {
        f"{key} | {value}": (f"{key}".strip(), f"{value}".strip())
        for key, value in integrated_llms.items()
    }
    selected_llm_name = st.selectbox(
        label="Select a model:",
        options=list(llm_options),
    )
    if selected_llm_name is None:
        # st.selectbox returns None when there are no options.
        st.error("No models are available.")
        st.stop()
    model_name, source = llm_options[selected_llm_name]
    check_api_key(model_name=model_name, source=source)

    # A missing upload folder used to crash the page with FileNotFoundError;
    # sorting makes the option order deterministic across reruns.
    available_files = sorted(os.listdir(SAVE_DIR)) if os.path.isdir(SAVE_DIR) else []
    if not available_files:
        st.warning(f"No files found in '{SAVE_DIR}'. Upload a document first.")
    selected_file = st.selectbox(
        label="Choose a file to chat with: ", options=available_files
    )

    if st.button("Clear all api keys"):
        st.session_state.openai_api_key = ""
        st.session_state.replicate_api_token = ""
        st.session_state.hf_token = ""
        st.success("All API keys cleared!")
        st.rerun()

    if st.button("Submit", key="submit", help="Submit the form"):
        with st.status("Loading ...", expanded=True) as status:
            try:
                if selected_file is None:
                    raise ValueError("No file selected!")

                st.write("Loading Model ...")
                llama_llm = load_llm(model_name=model_name, source=source)
                if llama_llm is None:
                    raise ValueError("Model not found!")
                Settings.llm = llama_llm

                st.write("Processing Data ...")
                # index = get_index(selected_file)
                index = get_pinecone_index(selected_file)

                st.write("Finishing Up ...")
                llama_custom = LlamaCustom(model_name=selected_llm_name, index=index)
                # llama_custom = LlamaCustomV2(model_name=selected_llm_name, index=index)
                st.session_state.llama_custom = llama_custom

                status.update(label="Ready to query!", state="complete", expanded=False)
            except Exception as e:
                status.update(label="Error!", state="error", expanded=False)
                st.error(f"Error: {e}")
                st.stop()
with tab2:
    messages_container = st.container(height=300)

    # Arguments shared by both chat helpers.
    _chat_kwargs = {"framework": "llama", "messages_container": messages_container}
    show_previous_messages(**_chat_kwargs)
    show_chat_input(
        disabled=False,
        model=st.session_state.llama_custom,
        **_chat_kwargs,
    )

    def clear_history():
        """Empty the chat container and reset both chat histories to the greeting."""
        welcome_text = "How can I help you today?"
        messages_container.empty()
        st.session_state.llama_messages = [
            {"role": "assistant", "content": welcome_text}
        ]
        st.session_state.llama_chat_history = [
            ChatMessage.from_str(role="assistant", content=welcome_text)
        ]

    if st.button("Clear Chat History"):
        clear_history()
        st.rerun()