# AIE4-Class3-RAG / app.py
# Author: jet-taekyo
# Commit ece0f5f: "change into langchain style" (4.8 kB)
import os
# For type hints
from typing import List
from langchain_core.vectorstores import VectorStoreRetriever
from langchain_openai import ChatOpenAI
from chainlit.types import AskFileResponse
from langchain_openai.embeddings import OpenAIEmbeddings
# Libraries to be used
from langchain_community.document_loaders.text import TextLoader
from langchain_community.document_loaders.pdf import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.prompts import ChatPromptTemplate
from langchain_wrappers.langchain_chat_models import MyChatOpenAI
from langchain_wrappers.langchain_embedding_models import MyOpenAIEmbeddings
from langchain_qdrant import QdrantVectorStore
from langchain_core.runnables import RunnablePassthrough, RunnableParallel
import chainlit as cl
from dotenv import load_dotenv
# Cache
from langchain.globals import set_llm_cache, get_llm_cache
from langchain_community.cache import InMemoryCache
# Register a process-wide in-memory LLM cache (langchain.globals.set_llm_cache).
# Repeated identical LLM calls made while this process runs can be served from
# memory. Nothing is persisted, so the cache is lost on restart.
set_llm_cache(InMemoryCache())
# Prompt pieces for the RAG chain. `{context}` and `{question}` are filled in by
# ChatPromptTemplate when the chain built in Get_RAG_pipeline runs.
# NOTE: no trailing backslash after the instruction sentence. A `\` there
# would join it with "Context:" and the LLM would see "...answer.Context:".
system_template = """\
Use the following context to answer a user's question. If you cannot find the answer in the context, say you don't know the answer.
Context:
{context}
"""
human_template = """\
Question:
{question}
"""
# (role, template) tuples in the form passed to ChatPromptTemplate.
system_msg = ('system', system_template)
user_msg = ('human', human_template)
# Shared splitter used by process_text_file(). It is constructed with no
# arguments, so it uses the library-default chunk_size / chunk_overlap.
text_splitter = RecursiveCharacterTextSplitter()
# Load variables from a .env file into os.environ at import time. main() reads
# LANGCHAIN_PROJECT. Presumably the OpenAI credentials used by the model
# wrappers also come from here; confirm.
load_dotenv()
### RAG chain
def Get_RAG_pipeline(retriever: VectorStoreRetriever, llm: ChatOpenAI):
    """Assemble the retrieve -> augment -> generate runnable.

    Pipeline stages (each tagged with a ``run_name`` for tracing):
      1-1. Run the retriever on the incoming question and keep the question.
      1-2. Join the retrieved documents' text into a ``context`` string.
      2.   Fill the system/human prompt with ``context`` and ``question``.
      3.   Send the prompt to the LLM.

    Args:
        retriever: Retriever invoked with the raw question as its input.
        llm: Chat model that generates the final answer.

    Returns:
        A runnable taking the user's question and producing the LLM output,
        wrapped with the run name ``'RAG pipeline'``.
    """
    def _join_page_contents(docs: List):
        # Every document's text is followed by a newline, including the last one.
        return "".join(doc.page_content + "\n" for doc in docs)

    named_retriever = retriever.with_config({'run_name': 'RAG: Retriever'})
    named_prompt = ChatPromptTemplate([system_msg, user_msg]).with_config(
        {'run_name': 'RAG Step2: Prompt (Augmented)'}
    )
    named_llm = llm.with_config({'run_name': 'RAG Step3: LLM (Generation)'})

    # Fan the question out to the retriever while also passing it through unchanged.
    retrieval_step = RunnableParallel(
        relevant_docs=named_retriever,
        question=lambda query: query,
    ).with_config({'run_name': 'RAG Step1-1: Get relevant docs (Retrieval)'})

    # Add a 'context' key alongside 'relevant_docs' and 'question'.
    context_step = RunnablePassthrough.assign(
        context=lambda state: _join_page_contents(state['relevant_docs']),
    ).with_config({'run_name': 'RAG Step1-2: Get context (Retrieval)'})

    pipeline = retrieval_step | context_step | named_prompt | named_llm
    return pipeline.with_config({'run_name': 'RAG pipeline'})
def process_text_file(file: AskFileResponse):
    """Load an uploaded ``.pdf`` or ``.txt`` file and split it into text chunks.

    The upload's bytes (``file.content``) are written to a temporary file
    because the LangChain loaders take a path. The temporary file is removed
    once loading has finished.

    Args:
        file: The uploaded file. Its name's extension (case-insensitive)
            selects the loader.

    Returns:
        List of chunk strings produced by the module-level ``text_splitter``.

    Raises:
        ValueError: If the file name does not end in ``.pdf`` or ``.txt``.
    """
    import tempfile

    # Case-insensitive so that e.g. "Report.PDF" is handled like "report.pdf".
    lowered_name = file.name.lower()
    if lowered_name.endswith('.pdf'):
        suffix = '.pdf'
    elif lowered_name.endswith('.txt'):
        suffix = '.txt'
    else:
        # Fail with a clear message instead of reaching the loader with no
        # loader selected.
        raise ValueError(
            f"Unsupported file type for {file.name!r}: expected a .pdf or .txt file"
        )

    # Write the raw bytes once, in binary mode, through the temp-file handle
    # itself. delete=False keeps the file after the handle closes so the
    # loader can reopen the path; reopening a still-open NamedTemporaryFile
    # fails on Windows.
    with tempfile.NamedTemporaryFile(mode="wb", delete=False, suffix=suffix) as temp_file:
        temp_file.write(file.content)
        temp_file_path = temp_file.name

    try:
        if suffix == '.pdf':
            print("PDF file detected")
            document_loader = PyPDFLoader(temp_file_path)
        else:
            document_loader = TextLoader(temp_file_path, autodetect_encoding=True)
        documents = document_loader.load()
    finally:
        # The Documents are already loaded (or loading failed), so always
        # remove the temp file instead of leaking one per upload.
        os.remove(temp_file_path)

    return [chunk.page_content for chunk in text_splitter.transform_documents(documents)]
@cl.on_chat_start
async def on_chat_start():
    """Ask for a document, index it, and store a RAG chain in the user session.

    Flow:
    1. Prompt the user for a text or PDF upload, re-prompting until one arrives.
    2. Split the file into chunks.
    3. Embed the chunks into an in-memory Qdrant collection.
    4. Build the RAG chain over a top-3 retriever.
    5. Save the chain under the ``"chain"`` session key, where ``main`` reads it.
    """
    files = None
    # send() can come back with None (presumably when the 180 s timeout
    # expires without an upload), so keep asking until a file arrives.
    while files is None:
        files = await cl.AskFileMessage(
            content="Please upload a Text or PDF file to begin!",
            accept=["text/plain", "application/pdf"],
            max_size_mb=5,
            timeout=180,
        ).send()
    file = files[0]

    msg = cl.Message(
        content=f"Processing `{file.name}`...", disable_human_feedback=True
    )
    await msg.send()

    # Loading and splitting (PDF parsing in particular) is synchronous. Run it
    # in a worker thread so it does not block the event loop for other sessions.
    texts = await cl.make_async(process_text_file)(file)
    print(f"Processing {len(texts)} text chunks")

    # Embed the chunks into an in-memory Qdrant collection (nothing is persisted).
    vector_db = await QdrantVectorStore.afrom_texts(
        texts, MyOpenAIEmbeddings.from_model('small'), location=":memory:", collection_name="texts"
    )

    RAG_chain = Get_RAG_pipeline(
        retriever=vector_db.as_retriever(search_kwargs={'k': 3}),
        llm=MyChatOpenAI.from_model(),
    )
    # Store the chain before announcing readiness, so a question sent right
    # after the "done" notice always finds it.
    cl.user_session.set("chain", RAG_chain)

    msg.content = f"Processing `{file.name}` done ({len(texts)} chunks in total). You can now ask questions!"
    await msg.update()
@cl.on_message
async def main(message):
    """Answer a user message with the RAG chain stored by ``on_chat_start``.

    The chain's output is streamed chunk by chunk into a new chat message.
    """
    # Mirror LANGCHAIN_PROJECT into LANGSMITH_PROJECT for tracing. os.environ
    # only accepts strings, so skip this when the source variable is unset;
    # assigning None would raise TypeError on every message.
    langchain_project = os.getenv('LANGCHAIN_PROJECT')
    if langchain_project is not None:
        os.environ['LANGSMITH_PROJECT'] = langchain_project

    chain = cl.user_session.get("chain")
    if chain is None:
        # No document has been indexed for this session yet (e.g. upload
        # still processing), so there is nothing to query.
        await cl.Message(
            content="No document is ready yet. Please upload a file and wait for processing to finish."
        ).send()
        return

    msg = cl.Message(content="")
    async for stream_resp in chain.astream(message.content):
        await msg.stream_token(stream_resp.content)
    await msg.send()