AIE4-Class3-RAG / app.py
jet-taekyo's picture
modify load_env() for main branch
bc0a1d6
import os
# For type hints
from typing import List
from langchain_core.vectorstores import VectorStoreRetriever
from langchain_openai import ChatOpenAI
from chainlit.types import AskFileResponse
from langchain_openai.embeddings import OpenAIEmbeddings
from langchain_core.runnables import Runnable
from langchain_core.documents import Document
# Libraries to be used
from langchain_community.document_loaders.text import TextLoader
from langchain_community.document_loaders.pdf import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.prompts import ChatPromptTemplate
from langchain_wrappers.langchain_chat_models import MyChatOpenAI
from langchain_wrappers.langchain_embedding_models import MyOpenAIEmbeddings
from langchain_qdrant import QdrantVectorStore
from langchain_core.runnables import RunnablePassthrough, RunnableParallel, Runnable
from rag_prompts import system_msg, user_msg
import chainlit as cl
from dotenv import load_dotenv
# Cache: register an in-memory cache as LangChain's global LLM cache at import
# time (get_llm_cache is imported but not used in the visible code).
from langchain.globals import set_llm_cache, get_llm_cache
from langchain_community.cache import InMemoryCache
set_llm_cache(InMemoryCache())
# Load environment variables from a .env file, if one is found
# (presumably the API keys and LANGCHAIN_* settings -- confirm against deployment).
load_dotenv()
# RAG chain
def Get_RAG_pipeline(retriever: VectorStoreRetriever, llm: ChatOpenAI)-> Runnable:
    """Assemble the retrieval-augmented generation chain.

    The returned runnable is invoked with the question itself. It runs, in
    order: retrieval of relevant documents, concatenation of their text into
    a ``context`` value, the chat prompt built from ``system_msg`` and
    ``user_msg``, and finally ``llm``. Every stage is given a ``run_name`` so
    it can be told apart in traces.

    Args:
        retriever: Retriever used to fetch the documents for a question.
        llm: Chat model that produces the final answer.

    Returns:
        The composed chain, named ``'RAG pipeline'``.
    """
    def get_context(relevant_docs: List):
        # One retrieved chunk after another, each terminated by a newline.
        return "".join(doc.page_content + "\n" for doc in relevant_docs)

    named_retriever = retriever.with_config({'run_name': 'RAG: Retriever'})
    augmented_prompt = ChatPromptTemplate([system_msg, user_msg]).with_config(
        {'run_name': 'RAG Step2: Prompt (Augmented)'}
    )
    generator = llm.with_config({'run_name': 'RAG Step3: LLM (Generation)'})

    # Step 1-1: run the retriever while passing the raw input through as `question`.
    fetch_docs = RunnableParallel(
        relevant_docs=named_retriever,
        question=lambda x: x,
    ).with_config({'run_name': 'RAG Step1-1: Get relevant docs (Retrieval)'})

    # Step 1-2: add a `context` key derived from the retrieved documents.
    build_context = RunnablePassthrough.assign(
        context=lambda x: get_context(x['relevant_docs']),
    ).with_config({'run_name': 'RAG Step1-2: Get context (Retrieval)'})

    pipeline = fetch_docs | build_context | augmented_prompt | generator
    return pipeline.with_config({'run_name': 'RAG pipeline'})
# Split documents
def process_text_file(file: AskFileResponse)-> List[str]:
    """Load an uploaded ``.txt`` or ``.pdf`` file and split it into text chunks.

    The upload's bytes are written to a temporary file so the path-based
    document loaders can read it; the temporary file is removed again once
    loading has finished.

    Args:
        file: The uploaded file; its ``name`` selects the loader and its
            ``content`` (bytes) is what gets parsed.

    Returns:
        The text of every chunk produced by ``RecursiveCharacterTextSplitter``
        with its default settings, in document order.

    Raises:
        ValueError: If the file name does not end in ``.txt`` or ``.pdf``
            (compared case-insensitively).
    """
    import tempfile

    # Choose the loader from the extension; reject anything else up front
    # instead of failing later on an unbound variable.
    suffix = os.path.splitext(file.name)[1].lower()
    if suffix == '.txt':
        base_loader = TextLoader
    elif suffix == '.pdf':
        base_loader = PyPDFLoader
    else:
        raise ValueError(
            f"Unsupported file type for `{file.name}`: expected a .txt or .pdf file"
        )

    # Write the bytes through a single binary handle and close it before the
    # loader opens the path (a second open of a still-open temp file fails on
    # Windows). delete=False keeps the file alive after the `with` block.
    with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix=suffix) as temp_file:
        temp_file.write(file.content)
        temp_file_path = temp_file.name

    try:
        documents = base_loader(temp_file_path).load()
    finally:
        # The loaded documents hold the text in memory; the file is no longer needed.
        os.remove(temp_file_path)

    text_splitter = RecursiveCharacterTextSplitter()
    return [chunk.page_content for chunk in text_splitter.transform_documents(documents)]
@cl.on_chat_start
async def on_chat_start():
    """Set up a chat session: ask for a file, index it, and store the RAG chain.

    The user is prompted until a file is uploaded. The file is split into
    chunks, the chunks are embedded into an in-memory Qdrant collection, and
    a RAG chain over that collection is saved in the user session under the
    key ``"chain"``.
    """
    files = None
    # Keep prompting for as long as no file has been returned
    # (for example after the 180-second timeout has elapsed).
    while files is None:
        files = await cl.AskFileMessage(
            content="Please upload a text or PDF file to begin!",
            accept=["text/plain", "application/pdf"],
            max_size_mb=5,
            timeout=180,
        ).send()
    file = files[0]

    msg = cl.Message(
        content=f"Processing `{file.name}`...", disable_human_feedback=True
    )
    await msg.send()

    # Load the file and split it into text chunks
    texts = process_text_file(file)
    print(f"Processing {len(texts)} text chunks")

    # Embed the chunks into an in-memory Qdrant collection
    vector_db = await QdrantVectorStore.afrom_texts(
        texts, MyOpenAIEmbeddings.from_model('small'), location=":memory:", collection_name="texts"
    )

    # Create a chain that retrieves the 3 most relevant chunks per question
    RAG_chain = Get_RAG_pipeline(
        retriever=vector_db.as_retriever(search_kwargs = {'k':3}),
        llm=MyChatOpenAI.from_model()
    )

    # Store the chain before announcing readiness so the first question
    # always finds it in the session.
    cl.user_session.set("chain", RAG_chain)

    # Let the user know that the system is ready
    msg.content = f"Processing `{file.name}` done ({len(texts)} chunks in total). You can now ask questions!"
    await msg.update()
@cl.on_message
async def main(message):
    """Answer a user message by streaming the session's RAG chain output.

    Args:
        message: The incoming chat message; its ``content`` is passed to the
            chain as the question.
    """
    # Mirror LANGCHAIN_PROJECT into LANGSMITH_PROJECT only when it is set:
    # os.environ accepts strings only, so assigning None raises TypeError.
    langchain_project = os.getenv('LANGCHAIN_PROJECT')
    if langchain_project is not None:
        os.environ['LANGSMITH_PROJECT'] = langchain_project

    chain = cl.user_session.get("chain")
    if chain is None:
        # No chain has been stored for this session yet (file not processed).
        await cl.Message(
            content="No document has been processed yet. Please upload a file first."
        ).send()
        return

    msg = cl.Message(content="")

    # Forward each streamed chunk to the UI as it arrives, then finalize the message.
    async for stream_resp in chain.astream(message.content):
        await msg.stream_token(stream_resp.content)

    await msg.send()