import os
import pickle

import streamlit as st
from dotenv import load_dotenv
from laas import ChatLaaS
from langchain.embeddings import CacheBackedEmbeddings
from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain.storage import LocalFileStore
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from langchain_community.document_loaders.generic import GenericLoader
from langchain_community.document_loaders.parsers.language.language_parser import (
    LanguageParser,
)
from langchain_community.retrievers import BM25Retriever
from langchain_community.vectorstores import FAISS
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_core.vectorstores import VectorStore
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_text_splitters import Language, RecursiveCharacterTextSplitter

# Load environment variables from .env
load_dotenv()

# Enable LangSmith tracing for this project
os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_PROJECT"] = "Code QA Bot"


@st.cache_resource
def setup_embeddings_and_db(project_folder: str):
    """Load the persisted FAISS index, backed by a disk-cached embedding model."""
    CACHE_ROOT_PATH = os.path.join(os.path.expanduser("~"), ".cache")
    CACHE_MODELS_PATH = os.path.join(CACHE_ROOT_PATH, "models")
    CACHE_EMBEDDINGS_PATH = os.path.join(CACHE_ROOT_PATH, "embeddings")
    os.makedirs(CACHE_MODELS_PATH, exist_ok=True)
    os.makedirs(CACHE_EMBEDDINGS_PATH, exist_ok=True)

    store = LocalFileStore(CACHE_EMBEDDINGS_PATH)

    model_name = "BAAI/bge-m3"
    model_kwargs = {"device": "cpu"}
    encode_kwargs = {"normalize_embeddings": False}
    embeddings = HuggingFaceEmbeddings(
        model_name=model_name,
        model_kwargs=model_kwargs,
        encode_kwargs=encode_kwargs,
        cache_folder=CACHE_MODELS_PATH,
        multi_process=False,
        show_progress=True,
    )

    # Cache computed vectors on disk so identical texts are never re-embedded
    cached_embeddings = CacheBackedEmbeddings.from_bytes_store(
        embeddings,
        store,
        namespace=embeddings.model_name,
    )

    FAISS_DB_INDEX = os.path.join(project_folder, "langchain_faiss")
    db = FAISS.load_local(
        FAISS_DB_INDEX,  # directory of the FAISS index to load
        cached_embeddings,  # embedding model used to encode queries
        allow_dangerous_deserialization=True,  # the index is pickled, so opt in to deserialization
    )
    return db


@st.cache_resource
def setup_retrievers_and_chain(_db: VectorStore, project_folder: str):
    """Build the ensemble retriever, cross-encoder reranker, and RAG chain.

    The leading underscore in `_db` tells st.cache_resource not to hash
    that argument.
    """
    faiss_retriever = _db.as_retriever(search_type="mmr", search_kwargs={"k": 20})

    bm25_retriever_path = os.path.join(project_folder, "bm25_retriever.pkl")
    with open(bm25_retriever_path, "rb") as f:
        bm25_retriever = pickle.load(f)
    bm25_retriever.k = 20

    # Fuse lexical (BM25) and dense (FAISS) results with fixed weights.
    # Note: EnsembleRetriever takes no search_type option; MMR is already
    # configured on the FAISS retriever above.
    ensemble_retriever = EnsembleRetriever(
        retrievers=[bm25_retriever, faiss_retriever],
        weights=[0.6, 0.4],
    )

    # Rerank the fused candidates with a cross-encoder and keep the top 5
    model = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-v2-m3")
    compressor = CrossEncoderReranker(model=model, top_n=5)
    compression_retriever = ContextualCompressionRetriever(
        base_compressor=compressor,
        base_retriever=ensemble_retriever,
    )

    laas = ChatLaaS(
        project=st.secrets["LAAS_PROJECT"],
        api_key=st.secrets["LAAS_API_KEY"],
        hash=st.secrets["LAAS_HASH"],
    )

    rag_chain = (
        {
            # Stringify the retrieved Documents for the LaaS prompt template
            "context": compression_retriever | RunnableLambda(lambda docs: str(docs)),
            "question": RunnablePassthrough(),
        }
        | RunnableLambda(
            lambda x: laas.invoke(
                "", params={"context": x["context"], "question": x["question"]}
            )
        )
        | StrOutputParser()
    )

    return rag_chain
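

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original app): the functions above
# load two prebuilt artifacts, `<project_folder>/langchain_faiss` and
# `<project_folder>/bm25_retriever.pkl`. A one-off indexing step along the
# lines below could produce them, which is also why GenericLoader,
# LanguageParser, BM25Retriever, and the text splitter are imported at the
# top of this file. The function name, the `.py`-only suffix filter, and the
# chunk sizes are assumptions, not taken from the original code.
# ---------------------------------------------------------------------------
def build_indexes(project_folder: str) -> None:
    """Hypothetical offline step that builds the FAISS index and BM25 pickle."""
    # Parse source files into Documents, splitting along Python syntax
    loader = GenericLoader.from_filesystem(
        project_folder,
        glob="**/*",
        suffixes=[".py"],  # assumed: only Python sources are indexed
        parser=LanguageParser(language=Language.PYTHON),
    )
    splitter = RecursiveCharacterTextSplitter.from_language(
        language=Language.PYTHON, chunk_size=1000, chunk_overlap=100
    )
    chunks = splitter.split_documents(loader.load())

    # Dense index: embed the chunks and persist the FAISS store
    embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-m3")
    FAISS.from_documents(chunks, embeddings).save_local(
        os.path.join(project_folder, "langchain_faiss")
    )

    # Lexical index: fit BM25 over the same chunks and pickle it
    bm25 = BM25Retriever.from_documents(chunks)
    with open(os.path.join(project_folder, "bm25_retriever.pkl"), "wb") as f:
        pickle.dump(bm25, f)
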
st.sidebar.title("사용 가이드") st.sidebar.info( """ 1. 왼쪽 텍스트 영역에 질문을 입력하세요. 2. '답변 생성' 버튼을 클릭하세요. 3. 답변이 아래에 표시됩니다. 4. 새로운 질문을 하려면 '답변 초기화' 버튼을 사용하세요. """ ) if st.sidebar.button("답변 초기화", key="reset"): st.session_state.answer = "" st.experimental_rerun() def main(): st.set_page_config(page_title="WDS QA 봇", page_icon="🤖", layout="wide") sidebar_content() st.title("🤖 WDS QA 봇") st.subheader("질문하기") user_question = st.text_area("코드에 대해 궁금한 점을 물어보세요:", height=100) if st.button("답변 생성", key="generate"): if user_question: with st.spinner("답변을 생성 중입니다..."): project_folder = "wds" db = setup_embeddings_and_db(project_folder) rag_chain = setup_retrievers_and_chain(db, project_folder) response = rag_chain.invoke(user_question) st.session_state.answer = response else: st.warning("질문을 입력해주세요.") if "answer" in st.session_state and st.session_state.answer: st.subheader("답변") st.markdown(st.session_state.answer) st.markdown("---") st.caption("© 2023 WDS QA 봇. 모든 권리 보유.") if __name__ == "__main__": main()