import os import torch import gradio as gr from dotenv import load_dotenv from langchain.callbacks.base import BaseCallbackHandler from langchain.embeddings import CacheBackedEmbeddings from langchain_community.retrievers import BM25Retriever from langchain.retrievers import EnsembleRetriever from langchain.storage import LocalFileStore from langchain_anthropic import ChatAnthropic from langchain_community.chat_models import ChatOllama from langchain_community.document_loaders import NotebookLoader, TextLoader from langchain_community.document_loaders.generic import GenericLoader from langchain_community.document_loaders.parsers.language.language_parser import ( LanguageParser, ) from langchain_community.embeddings import HuggingFaceBgeEmbeddings from langchain_community.vectorstores import FAISS, Chroma from langchain_core.callbacks.manager import CallbackManager from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHandler from langchain_core.output_parsers import StrOutputParser from langchain_core.prompts import PromptTemplate from langchain_core.runnables import ConfigurableField, RunnablePassthrough from langchain_google_genai import GoogleGenerativeAI from langchain_groq import ChatGroq from langchain_openai import ChatOpenAI, OpenAIEmbeddings from langchain_text_splitters import Language, RecursiveCharacterTextSplitter # Load environment variables load_dotenv() # Repository directories repo_root_dir = "./docs/langchain" repo_dirs = [ "libs/core/langchain_core", "libs/community/langchain_community", "libs/experimental/langchain_experimental", "libs/partners", "libs/cookbook", ] repo_dirs = [os.path.join(repo_root_dir, repo) for repo in repo_dirs] # Load Python documents py_documents = [] for path in repo_dirs: py_loader = GenericLoader.from_filesystem( path, glob="**/*", suffixes=[".py"], parser=LanguageParser(language=Language.PYTHON, parser_threshold=30), ) py_documents.extend(py_loader.load()) print(f".py 파일의 개수: {len(py_documents)}") # Load Markdown documents mdx_documents = [] for dirpath, _, filenames in os.walk(repo_root_dir): for file in filenames: if file.endswith(".mdx") and "*venv/" not in dirpath: try: mdx_loader = TextLoader(os.path.join(dirpath, file), encoding="utf-8") mdx_documents.extend(mdx_loader.load()) except Exception: pass print(f".mdx 파일의 개수: {len(mdx_documents)}") # Load Jupyter Notebook documents ipynb_documents = [] for dirpath, _, filenames in os.walk(repo_root_dir): for file in filenames: if file.endswith(".ipynb") and "*venv/" not in dirpath: try: ipynb_loader = NotebookLoader( os.path.join(dirpath, file), include_outputs=True, max_output_length=20, remove_newline=True, ) ipynb_documents.extend(ipynb_loader.load()) except Exception: pass print(f".ipynb 파일의 개수: {len(ipynb_documents)}") # Split documents into chunks def split_documents(documents, language, chunk_size=2000, chunk_overlap=200): splitter = RecursiveCharacterTextSplitter.from_language( language=language, chunk_size=chunk_size, chunk_overlap=chunk_overlap ) return splitter.split_documents(documents) py_docs = split_documents(py_documents, Language.PYTHON) mdx_docs = split_documents(mdx_documents, Language.MARKDOWN) ipynb_docs = split_documents(ipynb_documents, Language.PYTHON) print(f"분할된 .py 파일의 개수: {len(py_docs)}") print(f"분할된 .mdx 파일의 개수: {len(mdx_docs)}") print(f"분할된 .ipynb 파일의 개수: {len(ipynb_docs)}") combined_documents = py_docs + mdx_docs + ipynb_docs print(f"총 도큐먼트 개수: {len(combined_documents)}") # Define the device setting function def get_device(): if torch.cuda.is_available(): return "cuda:0" elif torch.backends.mps.is_available(): return "mps" else: return "cpu" # Use the function to set the device in model_kwargs device = get_device() # Initialize embeddings and cache store = LocalFileStore("~/.cache/embedding") embeddings = HuggingFaceBgeEmbeddings( model_name="BAAI/bge-m3", model_kwargs={"device": device}, encode_kwargs={"normalize_embeddings": True}, ) cached_embeddings = CacheBackedEmbeddings.from_bytes_store( embeddings, store, namespace=embeddings.model_name ) # Create and save FAISS index FAISS_DB_INDEX = "./langchain_faiss" # faiss_db = FAISS.from_documents( # documents=combined_documents, # embedding=cached_embeddings, # ) # faiss_db.save_local(folder_path=FAISS_DB_INDEX) # Create and save Chroma index CHROMA_DB_INDEX = "./langchain_chroma" # chroma_db = Chroma.from_documents( # documents=combined_documents, # embedding=cached_embeddings, # persist_directory=CHROMA_DB_INDEX, # ) # load vectorstore faiss_db = FAISS.load_local( FAISS_DB_INDEX, cached_embeddings, allow_dangerous_deserialization=True ) chroma_db = Chroma( embedding_function=cached_embeddings, persist_directory=CHROMA_DB_INDEX, ) # Create retrievers faiss_retriever = faiss_db.as_retriever(search_type="mmr", search_kwargs={"k": 10}) chroma_retriever = chroma_db.as_retriever(search_type="mmr", search_kwargs={"k": 10}) bm25_retriever = BM25Retriever.from_documents(combined_documents) bm25_retriever.k = 10 ensemble_retriever = EnsembleRetriever( retrievers=[bm25_retriever, faiss_retriever, chroma_retriever], weights=[0.4, 0.3, 0.3], ) # Create prompt template prompt = PromptTemplate.from_template( """당신은 20년차 AI 개발자입니다. 당신의 임무는 주어진 질문에 대하여 최대한 문서의 정보를 활용하여 답변하는 것입니다. 문서는 Python 코드에 대한 정보를 담고 있습니다. 따라서, 답변을 작성할 때에는 Python 코드에 대한 상세한 code snippet을 포함하여 작성해주세요. 최대한 자세하게 답변하고, 한글로 답변해 주세요. 주어진 문서에서 답변을 찾을 수 없는 경우, "문서에 답변이 없습니다."라고 답변해 주세요. 답변은 출처(source)를 반드시 표기해 주세요. #참고문서: {context} #질문: {question} #답변: 출처: - source1 - source2 - ... """ ) # Define callback handler for streaming class StreamCallback(BaseCallbackHandler): def on_llm_new_token(self, token: str, **kwargs): print(token, end="", flush=True) streaming = os.getenv("STREAMING", "true") == "true" print("STREAMING", streaming) # Initialize LLMs with configuration llm = ChatOpenAI( model="gpt-4o", temperature=0, streaming=streaming, callbacks=[StreamCallback()], ).configurable_alternatives( ConfigurableField(id="llm"), default_key="gpt4", claude=ChatAnthropic( model="claude-3-opus-20240229", temperature=0, streaming=True, callbacks=[StreamCallback()], ), gpt3=ChatOpenAI( model="gpt-3.5-turbo", temperature=0, streaming=True, callbacks=[StreamCallback()], ), gemini=GoogleGenerativeAI( model="gemini-1.5-flash", temperature=0, streaming=True, callbacks=[StreamCallback()], ), llama3=ChatGroq( model_name="llama3-70b-8192", temperature=0, streaming=True, callbacks=[StreamCallback()], ), ollama=ChatOllama( model="EEVE-Korean-10.8B:long", callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]), ), ) # Create retrieval-augmented generation chain rag_chain = ( {"context": ensemble_retriever, "question": RunnablePassthrough()} | prompt | llm | StrOutputParser() ) model_key = os.getenv("MODEL_KEY", "gemini") print("MODEL_KEY", model_key) def respond_stream( message, history: list[tuple[str, str]], ): response = "" for chunk in rag_chain.with_config(configurable={"llm": model_key}).stream(message): response += chunk yield response def respond( message, history: list[tuple[str, str]], ): return rag_chain.with_config(configurable={"llm": model_key}).invoke(message) """ For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface """ demo = gr.ChatInterface( respond_stream if streaming else respond, title="랭체인에 대해서 물어보세요!", description="안녕하세요!\n저는 랭체인에 대한 인공지능 QA봇입니다. 랭체인에 대해 깊은 지식을 가지고 있어요. 랭체인 개발에 관한 도움이 필요하시면 언제든지 질문해주세요!", ) if __name__ == "__main__": demo.launch()