# pages/llama_custom_demo.py
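"""Streamlit demo page: chat with an uploaded document through a LlamaIndex
VectorStoreIndex, with a selectable backing LLM wrapped by
models.llamaCustom.LlamaCustom."""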
import os
import pathlib

import streamlit as st
from models.llms import (
    llm_llama_2_7b_chat,
    llm_mixtral_8x7b,
    llm_bloomz_560m,
    llm_gpt_3_5_turbo,
    llm_gpt_3_5_turbo_0125,
    llm_gpt_4_0125,
    llm_llama_13b_v2_replicate,
)
from models.embeddings import hf_embed_model, openai_embed_model
from models.llamaCustom import LlamaCustom
from utils.chatbox import show_previous_messages, show_chat_input
from llama_index.core import (
    SimpleDirectoryReader,
    Document,
    VectorStoreIndex,
    StorageContext,
    Settings,
    load_index_from_storage,
)
from llama_index.core.memory import ChatMemoryBuffer
from llama_index.core.base.llms.types import ChatMessage
SAVE_DIR = "uploaded_files"
VECTOR_STORE_DIR = "vectorStores"

# Ensure the upload directory exists so the file selector below cannot fail.
os.makedirs(SAVE_DIR, exist_ok=True)

# global embedding model shared by every index built on this page
Settings.embed_model = hf_embed_model

llama_llms = {
    "bigscience/bloomz-560m": llm_bloomz_560m,
    "mistral/mixtral": llm_mixtral_8x7b,
    "meta-llama/Llama-2-7b-chat-hf": llm_llama_2_7b_chat,
    # "openai/gpt-3.5-turbo": llm_gpt_3_5_turbo,
    "openai/gpt-3.5-turbo-0125": llm_gpt_3_5_turbo_0125,
    # "openai/gpt-4-0125-preview": llm_gpt_4_0125,
    # "meta/llama-2-13b-chat": llm_llama_13b_v2_replicate,
}
def init_session_state():
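    """Seed the session with a greeting message, a chat history, and an
    empty slot for the LlamaCustom model that is built on form submit."""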
if "llama_messages" not in st.session_state:
st.session_state.llama_messages = [
{"role": "assistant", "content": "How can I help you today?"}
]
# TODO: create a chat history for each different document
if "llama_chat_history" not in st.session_state:
st.session_state.llama_chat_history = [
ChatMessage.from_str(role="assistant", content="How can I help you today?")
]
if "llama_custom" not in st.session_state:
st.session_state.llama_custom = None
# @st.cache_resource
def index_docs(filename: str) -> VectorStoreIndex:
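    """Load the persisted index for ``filename`` when one exists; otherwise
    build a VectorStoreIndex from the uploaded file and persist it.
    Returns None when indexing fails."""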
    try:
        index_path = pathlib.Path(VECTOR_STORE_DIR) / filename.replace(".", "_")
        if index_path.exists():
            print("Loading index from storage ...")
            storage_context = StorageContext.from_defaults(persist_dir=str(index_path))
            index = load_index_from_storage(storage_context=storage_context)

            # smoke-test the loaded index with a throwaway query
            index.as_query_engine().query("What is the capital of France?")
        else:
            reader = SimpleDirectoryReader(input_files=[f"{SAVE_DIR}/{filename}"])
            docs = reader.load_data(show_progress=True)
            index = VectorStoreIndex.from_documents(
                documents=docs,
                show_progress=True,
            )
            index.storage_context.persist(persist_dir=str(index_path))
    except Exception as e:
        print(f"Error: {e}")
        index = None

    return index
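
# Example (hypothetical filename): a document saved as
# "uploaded_files/report.pdf" is indexed once, persisted under
# "vectorStores/report_pdf", and reloaded from there on later runs.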
def load_llm(model_name: str):
    return llama_llms[model_name]
st.set_page_config(page_title="Llama", page_icon="🦙")

init_session_state()
st.header("Llama Index with Custom LLM Demo")
tab1, tab2 = st.tabs(["Config", "Chat"])
with tab1:
    with st.form(key="llama_form"):
        selected_llm_name = st.selectbox(
            label="Select a model:", options=llama_llms.keys()
        )

        if selected_llm_name.startswith("openai"):
            # ask for the api key if it is not already configured
            if st.secrets.get("OPENAI_API_KEY") is None:
                st.info("OpenAI API Key not found in secrets. Please enter it below.")
                api_key = st.text_input(
                    "OpenAI API Key",
                    type="password",
                    help="Get your API key from https://platform.openai.com/account/api-keys",
                )
                if api_key:
                    # st.secrets is read-only at runtime, so hand the key to the
                    # OpenAI client through the environment instead.
                    os.environ["OPENAI_API_KEY"] = api_key

        selected_file = st.selectbox(
            label="Choose a file to chat with:", options=os.listdir(SAVE_DIR)
        )

        if st.form_submit_button(label="Submit"):
            with st.status("Loading ...", expanded=True) as status:
                st.write("Loading Model ...")
                llama_llm = load_llm(selected_llm_name)
                Settings.llm = llama_llm

                st.write("Processing Data ...")
                index = index_docs(selected_file)
                if index is None:
                    st.error("Failed to index the documents.")
                    st.stop()

                st.write("Finishing Up ...")
                llama_custom = LlamaCustom(model_name=selected_llm_name, index=index)
                st.session_state.llama_custom = llama_custom

                status.update(
                    label="Ready to query!", state="complete", expanded=False
                )
with tab2:
    messages_container = st.container(height=300)

    show_previous_messages(framework="llama", messages_container=messages_container)
    show_chat_input(
        disabled=False,
        framework="llama",
        model=st.session_state.llama_custom,
        messages_container=messages_container,
    )

    def clear_history():
        messages_container.empty()
        st.session_state.llama_messages = [
            {"role": "assistant", "content": "How can I help you today?"}
        ]
        st.session_state.llama_chat_history = [
            ChatMessage.from_str(role="assistant", content="How can I help you today?")
        ]

    if st.button("Clear Chat History"):
        clear_history()
        st.rerun()
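
# To run this page standalone (assuming the repo's standard layout; in the
# full multipage app it is served via the main Streamlit entry point):
#   streamlit run pages/llama_custom_demo.py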