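"""Streamlit demo: chat with an uploaded document through LlamaIndex, using a selectable LLM."""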
import random
import time
import streamlit as st
import os
import pathlib
from typing import List
from models.llms import (
    llm_llama_2_7b_chat,
    llm_mixtral_8x7b,
    llm_bloomz_560m,
    llm_gpt_3_5_turbo,
    llm_gpt_3_5_turbo_0125,
    llm_gpt_4_0125,
    llm_llama_13b_v2_replicate,
)
from models.embeddings import hf_embed_model, openai_embed_model
from models.llamaCustom import LlamaCustom
from utils.chatbox import show_previous_messages, show_chat_input
from llama_index.core import (
    SimpleDirectoryReader,
    Document,
    VectorStoreIndex,
    StorageContext,
    Settings,
    load_index_from_storage,
)
from llama_index.core.memory import ChatMemoryBuffer
from llama_index.core.base.llms.types import ChatMessage
SAVE_DIR = "uploaded_files"
VECTOR_STORE_DIR = "vectorStores"

# global
Settings.embed_model = hf_embed_model

llama_llms = {
    "bigscience/bloomz-560m": llm_bloomz_560m,
    "mistral/mixtral": llm_mixtral_8x7b,
    "meta-llama/Llama-2-7b-chat-hf": llm_llama_2_7b_chat,
    # "openai/gpt-3.5-turbo": llm_gpt_3_5_turbo,
    "openai/gpt-3.5-turbo-0125": llm_gpt_3_5_turbo_0125,
    # "openai/gpt-4-0125-preview": llm_gpt_4_0125,
    # "meta/llama-2-13b-chat": llm_llama_13b_v2_replicate,
}
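
# Seed Streamlit session state with a default assistant greeting, the LlamaIndex chat
# history, and a placeholder for the configured model wrapper.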
def init_session_state():
    if "llama_messages" not in st.session_state:
        st.session_state.llama_messages = [
            {"role": "assistant", "content": "How can I help you today?"}
        ]

    # TODO: create a chat history for each different document
    if "llama_chat_history" not in st.session_state:
        st.session_state.llama_chat_history = [
            ChatMessage.from_str(role="assistant", content="How can I help you today?")
        ]

    if "llama_custom" not in st.session_state:
        st.session_state.llama_custom = None
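
# Build a vector index for the chosen file, or load a previously persisted copy from
# VECTOR_STORE_DIR so repeated runs can skip re-embedding. Returns None on failure.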
# @st.cache_resource
def index_docs(
    filename: str,
) -> VectorStoreIndex:
    try:
        index_path = pathlib.Path(f"{VECTOR_STORE_DIR}/{filename.replace('.', '_')}")
        if index_path.exists():
            print("Loading index from storage ...")
            storage_context = StorageContext.from_defaults(persist_dir=str(index_path))
            index = load_index_from_storage(storage_context=storage_context)
            # test the index
            index.as_query_engine().query("What is the capital of France?")
        else:
            reader = SimpleDirectoryReader(input_files=[f"{SAVE_DIR}/{filename}"])
            docs = reader.load_data(show_progress=True)
            index = VectorStoreIndex.from_documents(
                documents=docs,
                show_progress=True,
            )
            # persist to the same path that is checked on the next run
            index.storage_context.persist(persist_dir=str(index_path))
    except Exception as e:
        print(f"Error: {e}")
        index = None
    return index
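
# Look up the LLM instance registered under the given model name.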
def load_llm(model_name: str):
    return llama_llms[model_name]
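
# st.set_page_config must be the first Streamlit command executed on the page.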
st.set_page_config(page_title="Llama", page_icon="🦙")

init_session_state()

st.header("Llama Index with Custom LLM Demo")

tab1, tab2 = st.tabs(["Config", "Chat"])
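# Tab 1: pick a model and a previously uploaded file, then build or load its index.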
with tab1:
    with st.form(key="llama_form"):
        selected_llm_name = st.selectbox(label="Select a model:", options=llama_llms.keys())

        if selected_llm_name.startswith("openai"):
            # ask for the api key
            if st.secrets.get("OPENAI_API_KEY") is None:
                # st.stop()
                st.info("OpenAI API Key not found in secrets. Please enter it below.")
                # st.secrets is read-only at runtime, so pass the key through the
                # environment instead of assigning into st.secrets
                openai_api_key = st.text_input(
                    "OpenAI API Key",
                    type="password",
                    help="Get your API key from https://platform.openai.com/account/api-keys",
                )
                if openai_api_key:
                    os.environ["OPENAI_API_KEY"] = openai_api_key

        selected_file = st.selectbox(
            label="Choose a file to chat with: ", options=os.listdir(SAVE_DIR)
        )
        if st.form_submit_button(label="Submit"):
            with st.status("Loading ...", expanded=True) as status:
                st.write("Loading Model ...")
                llama_llm = load_llm(selected_llm_name)
                Settings.llm = llama_llm

                st.write("Processing Data ...")
                index = index_docs(selected_file)
                if index is None:
                    st.error("Failed to index the documents.")
                    st.stop()

                st.write("Finishing Up ...")
                llama_custom = LlamaCustom(model_name=selected_llm_name, index=index)
                st.session_state.llama_custom = llama_custom

                status.update(label="Ready to query!", state="complete", expanded=False)
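
# Tab 2: chat with the indexed document using the model configured in Tab 1.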
with tab2:
    messages_container = st.container(height=300)

    show_previous_messages(framework="llama", messages_container=messages_container)
    show_chat_input(
        disabled=False,
        framework="llama",
        model=st.session_state.llama_custom,
        messages_container=messages_container,
    )

    def clear_history():
        messages_container.empty()
        st.session_state.llama_messages = [
            {"role": "assistant", "content": "How can I help you today?"}
        ]
        st.session_state.llama_chat_history = [
            ChatMessage.from_str(role="assistant", content="How can I help you today?")
        ]

    if st.button("Clear Chat History"):
        clear_history()
        st.rerun()