import os
import pathlib

import streamlit as st

from models.llms import (
    llm_llama_2_7b_chat,
    llm_mixtral_8x7b,
    llm_bloomz_560m,
    llm_gpt_3_5_turbo,
    llm_gpt_3_5_turbo_0125,
    llm_gpt_4_0125,
    llm_llama_13b_v2_replicate,
)
from models.embeddings import hf_embed_model
from models.llamaCustom import LlamaCustom
from utils.chatbox import show_previous_messages, show_chat_input
from llama_index.core import (
    SimpleDirectoryReader,
    VectorStoreIndex,
    StorageContext,
    Settings,
    load_index_from_storage,
)
from llama_index.core.base.llms.types import ChatMessage

SAVE_DIR = "uploaded_files"
VECTOR_STORE_DIR = "vectorStores"

# Global embedding model shared by every index built or loaded below.
Settings.embed_model = hf_embed_model

llama_llms = {
    "bigscience/bloomz-560m": llm_bloomz_560m,
    "mistral/mixtral": llm_mixtral_8x7b,
    "meta-llama/Llama-2-7b-chat-hf": llm_llama_2_7b_chat,
    # "openai/gpt-3.5-turbo": llm_gpt_3_5_turbo,
    "openai/gpt-3.5-turbo-0125": llm_gpt_3_5_turbo_0125,
    # "openai/gpt-4-0125-preview": llm_gpt_4_0125,
    # "meta/llama-2-13b-chat": llm_llama_13b_v2_replicate,
}


def init_session_state():
    if "llama_messages" not in st.session_state:
        st.session_state.llama_messages = [
            {"role": "assistant", "content": "How can I help you today?"}
        ]

    # TODO: create a chat history for each different document
    if "llama_chat_history" not in st.session_state:
        st.session_state.llama_chat_history = [
            ChatMessage.from_str(role="assistant", content="How can I help you today?")
        ]

    if "llama_custom" not in st.session_state:
        st.session_state.llama_custom = None


# @st.cache_resource
def index_docs(filename: str) -> VectorStoreIndex:
    """Load the persisted index for `filename`, or build and persist a new one."""
    try:
        index_path = pathlib.Path(f"{VECTOR_STORE_DIR}/{filename.replace('.', '_')}")
        if index_path.exists():
            print("Loading index from storage ...")
            storage_context = StorageContext.from_defaults(persist_dir=str(index_path))
            index = load_index_from_storage(storage_context=storage_context)
            # Smoke-test the loaded index with a throwaway query.
            index.as_query_engine().query("What is the capital of France?")
        else:
            reader = SimpleDirectoryReader(input_files=[f"{SAVE_DIR}/{filename}"])
            docs = reader.load_data(show_progress=True)
            index = VectorStoreIndex.from_documents(
                documents=docs,
                show_progress=True,
            )
            index.storage_context.persist(persist_dir=str(index_path))
    except Exception as e:
        print(f"Error: {e}")
        index = None
    return index


def load_llm(model_name: str):
    return llama_llms[model_name]


init_session_state()

st.set_page_config(page_title="Llama", page_icon="🦙")

st.header("Llama Index with Custom LLM Demo")

tab1, tab2 = st.tabs(["Config", "Chat"])

with tab1:
    with st.form(key="llama_form"):
        selected_llm_name = st.selectbox(
            label="Select a model:", options=llama_llms.keys()
        )

        if selected_llm_name.startswith("openai"):
            # Ask for the API key if it is not already configured.
            if st.secrets.get("OPENAI_API_KEY") is None:
                st.info("OpenAI API Key not found in secrets. Please enter it below.")
                # st.secrets is read-only at runtime, so expose the key through
                # the environment instead of assigning into st.secrets.
                os.environ["OPENAI_API_KEY"] = st.text_input(
                    "OpenAI API Key",
                    type="password",
                    help="Get your API key from https://platform.openai.com/account/api-keys",
                )

        selected_file = st.selectbox(
            label="Choose a file to chat with: ", options=os.listdir(SAVE_DIR)
        )

        if st.form_submit_button(label="Submit"):
            with st.status("Loading ...", expanded=True) as status:
                st.write("Loading Model ...")
                llama_llm = load_llm(selected_llm_name)
                Settings.llm = llama_llm

                st.write("Processing Data ...")
                index = index_docs(selected_file)
                if index is None:
                    st.error("Failed to index the documents.")
                    st.stop()

                st.write("Finishing Up ...")
                llama_custom = LlamaCustom(model_name=selected_llm_name, index=index)
                st.session_state.llama_custom = llama_custom

                status.update(label="Ready to query!", state="complete", expanded=False)

with tab2:
    messages_container = st.container(height=300)

    show_previous_messages(framework="llama", messages_container=messages_container)
    show_chat_input(
        disabled=False,
        framework="llama",
        model=st.session_state.llama_custom,
        messages_container=messages_container,
    )

    def clear_history():
        messages_container.empty()
        st.session_state.llama_messages = [
            {"role": "assistant", "content": "How can I help you today?"}
        ]
        st.session_state.llama_chat_history = [
            ChatMessage.from_str(role="assistant", content="How can I help you today?")
        ]

    if st.button("Clear Chat History"):
        clear_history()
        st.rerun()