import os
import pickle

import streamlit as st
from dotenv import load_dotenv
from laas import ChatLaaS
from langchain.embeddings import CacheBackedEmbeddings
from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain.storage import LocalFileStore
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from langchain_community.retrievers import BM25Retriever  # class of the pickled retriever
from langchain_community.vectorstores import FAISS
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_core.vectorstores import VectorStore
from langchain_huggingface import HuggingFaceEmbeddings

# Load environment variables (e.g. LangSmith credentials)
load_dotenv()

# Enable LangSmith tracing for this app
os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_PROJECT"] = "Code QA Bot"


@st.cache_resource
def setup_embeddings_and_db(project_folder: str):
    """Load the cached embedding model and the project's prebuilt FAISS index."""
    # Local cache directories for the downloaded model and computed embeddings
    CACHE_ROOT_PATH = os.path.join(os.path.expanduser("~"), ".cache")
    CACHE_MODELS_PATH = os.path.join(CACHE_ROOT_PATH, "models")
    CACHE_EMBEDDINGS_PATH = os.path.join(CACHE_ROOT_PATH, "embeddings")
    os.makedirs(CACHE_MODELS_PATH, exist_ok=True)
    os.makedirs(CACHE_EMBEDDINGS_PATH, exist_ok=True)

    store = LocalFileStore(CACHE_EMBEDDINGS_PATH)
model_name = "BAAI/bge-m3" | |
model_kwargs = {"device": "cpu"} | |
encode_kwargs = {"normalize_embeddings": False} | |
embeddings = HuggingFaceEmbeddings( | |
model_name=model_name, | |
model_kwargs=model_kwargs, | |
encode_kwargs=encode_kwargs, | |
cache_folder=CACHE_MODELS_PATH, | |
multi_process=False, | |
show_progress=True, | |
) | |
cached_embeddings = CacheBackedEmbeddings.from_bytes_store( | |
embeddings, | |
store, | |
namespace=embeddings.model_name, | |
) | |
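
    # The FAISS index under <project_folder>/langchain_faiss is assumed to
    # have been built offline with this same embedding model; a different
    # model would produce vectors of the wrong dimensionality for the index.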
    FAISS_DB_INDEX = os.path.join(project_folder, "langchain_faiss")
    db = FAISS.load_local(
        FAISS_DB_INDEX,  # directory holding the saved FAISS index
        cached_embeddings,  # embeddings used to encode queries
        allow_dangerous_deserialization=True,  # index is pickled; only load trusted files
    )
    return db


@st.cache_resource
def setup_retrievers_and_chain(
    _db: VectorStore, project_folder: str
):  # the underscore tells st.cache_resource not to hash the unhashable db object
    """Build the hybrid retriever stack and the RAG answer chain."""
    # Dense retriever: MMR search over the FAISS index for diverse results
    faiss_retriever = _db.as_retriever(search_type="mmr", search_kwargs={"k": 20})

    # Lexical retriever: BM25 index pickled at build time
    bm25_retriever_path = os.path.join(project_folder, "bm25_retriever.pkl")
    with open(bm25_retriever_path, "rb") as f:
        bm25_retriever = pickle.load(f)
    bm25_retriever.k = 20  # return the same number of candidates as FAISS

    # Hybrid retrieval: merge both result lists with weighted Reciprocal Rank Fusion
    ensemble_retriever = EnsembleRetriever(
        retrievers=[bm25_retriever, faiss_retriever],
        weights=[0.6, 0.4],
    )
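    # Under weighted Reciprocal Rank Fusion, a document at rank r in one
    # retriever's list contributes weight / (r + c) to its fused score
    # (c defaults to 60), so BM25 matches are favored slightly over dense ones.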

    # Rerank the merged candidates with a cross-encoder, keeping only the top 5
    model = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-v2-m3")
    compressor = CrossEncoderReranker(model=model, top_n=5)
    compression_retriever = ContextualCompressionRetriever(
        base_compressor=compressor,
        base_retriever=ensemble_retriever,
    )

    # LaaS-hosted prompt; credentials come from Streamlit secrets
    laas = ChatLaaS(
        project=st.secrets["LAAS_PROJECT"],
        api_key=st.secrets["LAAS_API_KEY"],
        hash=st.secrets["LAAS_HASH"],
    )
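
    # LCEL chain: the dict step runs the retriever and forwards the raw
    # question in parallel; the lambda then fills the hosted prompt's
    # {context} and {question} parameters, and StrOutputParser normalizes
    # the response to a plain string.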
    rag_chain = (
        {
            # Join retrieved documents into a single context string
            "context": compression_retriever
            | RunnableLambda(lambda docs: "\n\n".join(doc.page_content for doc in docs)),
            "question": RunnablePassthrough(),
        }
        | RunnableLambda(
            lambda x: laas.invoke(
                "", params={"context": x["context"], "question": x["question"]}
            )
        )
        | StrOutputParser()
    )
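
    # End-to-end, the chain maps a question string to an answer string, e.g.:
    #   answer = rag_chain.invoke("Where is the FAISS index loaded?")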
    return rag_chain


def sidebar_content():
    st.sidebar.title("Usage Guide")
    st.sidebar.info(
        """
        1. Enter your question in the text area.
        2. Click the 'Generate Answer' button.
        3. The answer will be shown below it.
        4. To ask a new question, use the 'Reset Answer' button.
        """
    )
    if st.sidebar.button("Reset Answer", key="reset"):
        st.session_state.answer = ""
        st.rerun()


def main():
    st.set_page_config(page_title="WDS QA Bot", page_icon="🤖", layout="wide")

    sidebar_content()

    st.title("🤖 WDS QA Bot")
    st.subheader("Ask a Question")

    user_question = st.text_area("Ask anything about the code:", height=100)

    if st.button("Generate Answer", key="generate"):
        if user_question:
            with st.spinner("Generating answer..."):
                project_folder = "wds"
                db = setup_embeddings_and_db(project_folder)
                rag_chain = setup_retrievers_and_chain(db, project_folder)
                response = rag_chain.invoke(user_question)
                st.session_state.answer = response
        else:
            st.warning("Please enter a question.")

    if "answer" in st.session_state and st.session_state.answer:
        st.subheader("Answer")
        st.markdown(st.session_state.answer)

    st.markdown("---")
    st.caption("© 2023 WDS QA Bot. All rights reserved.")


if __name__ == "__main__":
    main()
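
# Expected artifacts, inferred from the code above:
#   - wds/langchain_faiss/    : FAISS index built with BAAI/bge-m3
#   - wds/bm25_retriever.pkl  : pickled BM25Retriever over the same corpus
#   - Streamlit secrets       : LAAS_PROJECT, LAAS_API_KEY, LAAS_HASH
#   - .env                    : presumably LangSmith credentials for tracing
# Run with: streamlit run <this file>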