# WDS-QA-Bot / app.py
# (Hugging Face Space page residue preserved as a comment:
#  jeongsk — "Update app.py", commit d8cad3f, verified)
import os
import pickle
import streamlit as st
from dotenv import load_dotenv
from laas import ChatLaaS
from langchain.embeddings import CacheBackedEmbeddings
from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
from langchain.retrievers.document_compressors import (
CrossEncoderReranker,
FlashrankRerank,
)
from langchain.storage import LocalFileStore
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from langchain_community.document_loaders.generic import GenericLoader
from langchain_community.document_loaders.parsers.language.language_parser import (
LanguageParser,
)
from langchain_community.retrievers import BM25Retriever
from langchain_community.vectorstores import FAISS
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_core.vectorstores import VectorStore
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_text_splitters import Language, RecursiveCharacterTextSplitter
# Load environment variables
load_dotenv()
# Set up environment variables
os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_PROJECT"] = "Code QA Bot"
@st.cache_resource
def setup_embeddings_and_db(project_folder: str): # Note the underscore before 'docs'
CACHE_ROOT_PATH = os.path.join(os.path.expanduser("~"), ".cache")
CACHE_MODELS_PATH = os.path.join(CACHE_ROOT_PATH, "models")
CACHE_EMBEDDINGS_PATH = os.path.join(CACHE_ROOT_PATH, "embeddings")
if not os.path.exists(CACHE_MODELS_PATH):
os.makedirs(CACHE_MODELS_PATH)
if not os.path.exists(CACHE_EMBEDDINGS_PATH):
os.makedirs(CACHE_EMBEDDINGS_PATH)
store = LocalFileStore(CACHE_EMBEDDINGS_PATH)
model_name = "BAAI/bge-m3"
model_kwargs = {"device": "cpu"}
encode_kwargs = {"normalize_embeddings": False}
embeddings = HuggingFaceEmbeddings(
model_name=model_name,
model_kwargs=model_kwargs,
encode_kwargs=encode_kwargs,
cache_folder=CACHE_MODELS_PATH,
multi_process=False,
show_progress=True,
)
cached_embeddings = CacheBackedEmbeddings.from_bytes_store(
embeddings,
store,
namespace=embeddings.model_name,
)
FAISS_DB_INDEX = os.path.join(project_folder, "langchain_faiss")
db = FAISS.load_local(
FAISS_DB_INDEX, # λ‘œλ“œν•  FAISS 인덱슀의 디렉토리 이름
cached_embeddings, # μž„λ² λ”© 정보λ₯Ό 제곡
allow_dangerous_deserialization=True, # 역직렬화λ₯Ό ν—ˆμš©ν•˜λŠ” μ˜΅μ…˜
)
return db
# Function to set up retrievers and chain
@st.cache_resource
def setup_retrievers_and_chain(
_db: VectorStore, project_folder: str
): # Note the underscores
faiss_retriever = _db.as_retriever(search_type="mmr", search_kwargs={"k": 20})
bm25_retriever_path = os.path.join(project_folder, "bm25_retriever.pkl")
with open(bm25_retriever_path, "rb") as f:
bm25_retriever = pickle.load(f)
bm25_retriever.k = 20
ensemble_retriever = EnsembleRetriever(
retrievers=[bm25_retriever, faiss_retriever],
weights=[0.6, 0.4],
search_type="mmr",
)
model = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-v2-m3")
compressor = CrossEncoderReranker(model=model, top_n=5)
compression_retriever = ContextualCompressionRetriever(
base_compressor=compressor,
base_retriever=ensemble_retriever,
)
laas = ChatLaaS(
project=st.secrets["LAAS_PROJECT"],
api_key=st.secrets["LAAS_API_KEY"],
hash=st.secrets["LAAS_HASH"],
)
rag_chain = (
{
"context": compression_retriever | RunnableLambda(lambda x: str(x)),
"question": RunnablePassthrough(),
}
| RunnableLambda(
lambda x: laas.invoke(
"", params={"context": x["context"], "question": x["question"]}
)
)
| StrOutputParser()
)
return rag_chain
def sidebar_content():
st.sidebar.title("μ‚¬μš© κ°€μ΄λ“œ")
st.sidebar.info(
"""
1. μ™Όμͺ½ ν…μŠ€νŠΈ μ˜μ—­μ— μ§ˆλ¬Έμ„ μž…λ ₯ν•˜μ„Έμš”.
2. 'λ‹΅λ³€ 생성' λ²„νŠΌμ„ ν΄λ¦­ν•˜μ„Έμš”.
3. 닡변이 μ•„λž˜μ— ν‘œμ‹œλ©λ‹ˆλ‹€.
4. μƒˆλ‘œμš΄ μ§ˆλ¬Έμ„ ν•˜λ €λ©΄ 'λ‹΅λ³€ μ΄ˆκΈ°ν™”' λ²„νŠΌμ„ μ‚¬μš©ν•˜μ„Έμš”.
"""
)
if st.sidebar.button("λ‹΅λ³€ μ΄ˆκΈ°ν™”", key="reset"):
st.session_state.answer = ""
st.experimental_rerun()
def main():
st.set_page_config(page_title="WDS QA 봇", page_icon="πŸ€–", layout="wide")
sidebar_content()
st.title("πŸ€– WDS QA 봇")
st.subheader("μ§ˆλ¬Έν•˜κΈ°")
user_question = st.text_area("μ½”λ“œμ— λŒ€ν•΄ κΆκΈˆν•œ 점을 λ¬Όμ–΄λ³΄μ„Έμš”:", height=100)
if st.button("λ‹΅λ³€ 생성", key="generate"):
if user_question:
with st.spinner("닡변을 생성 μ€‘μž…λ‹ˆλ‹€..."):
project_folder = "wds"
db = setup_embeddings_and_db(project_folder)
rag_chain = setup_retrievers_and_chain(db, project_folder)
response = rag_chain.invoke(user_question)
st.session_state.answer = response
else:
st.warning("μ§ˆλ¬Έμ„ μž…λ ₯ν•΄μ£Όμ„Έμš”.")
if "answer" in st.session_state and st.session_state.answer:
st.subheader("λ‹΅λ³€")
st.markdown(st.session_state.answer)
st.markdown("---")
st.caption("Β© 2023 WDS QA 봇. λͺ¨λ“  ꢌ리 보유.")
if __name__ == "__main__":
main()