PDF-CHAT-BOT / app.py
abhisheksasidharanr's picture
Update app.py
ba3db0d verified
import streamlit as st
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from PyPDF2 import PdfReader
from pinecone import Pinecone, ServerlessSpec
from sentence_transformers import SentenceTransformer
from langchain_groq import ChatGroq
from langchain.chains import LLMChain
from langchain_core.prompts import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
MessagesPlaceholder,
)
from langchain.chains.conversation.memory import ConversationBufferWindowMemory
from langchain_core.messages import SystemMessage
import os
import string
import random
# Pinecone client and a handle to the index that stores the PDF chunk vectors.
pc = Pinecone(api_key=st.secrets["PINE_CONE_KEY"])
index = pc.Index('example-index')


@st.cache_resource(show_spinner=False)
def _load_embedding_model():
    """Load the sentence-embedding model once and reuse it across reruns.

    Streamlit re-executes this script on every widget interaction; without
    caching, the model would be instantiated again on each rerun.
    """
    return SentenceTransformer('all-mpnet-base-v2')


# Embedding model used both for indexing chunks and for encoding questions.
model = _load_embedding_model()

# Chat transcript for this session: a list of {'question', 'answer'} dicts.
if 'body' not in st.session_state:
    st.session_state.body = []
def randomIdGenerate():
    """Return a random 5-character string of uppercase letters and digits."""
    alphabet = string.ascii_uppercase + string.digits
    return ''.join(random.choices(alphabet, k=5))
def readFiles(files):
    """Extract text from the uploaded PDFs, index it, and return the chunks.

    The text of every page of every file is collected, split into chunks
    with ``get_text_chunks``, embedded with ``embedThetext`` and stored with
    ``saveInPinecone``.

    Args:
        files: Iterable of uploaded PDF files (anything ``PdfReader`` accepts).

    Returns:
        The list of text chunks that were embedded and stored, or an empty
        list when no text could be extracted from the files.
    """
    st.session_state.processing = "Reading files..."
    page_texts = []
    for pdf in files:
        pdf_reader = PdfReader(pdf)
        for page in pdf_reader.pages:
            # Treat a page without extractable text as empty rather than
            # failing on a falsy/None result.
            page_texts.append(page.extract_text() or "")
    # Join pages with a newline so the last word of one page does not fuse
    # with the first word of the next.
    text = "\n".join(page_texts)
    splits = get_text_chunks(text) if text.strip() else []
    if not splits:
        # Nothing to embed (e.g. scanned/image-only PDFs): skip the embedding
        # and upsert steps instead of sending an empty request.
        st.warning("No extractable text was found in the uploaded files.")
        return []
    emb = embedThetext(splits)
    saveInPinecone(emb)
    return splits
def get_text_chunks(text, chunk_size=1000, chunk_overlap=500):
    """Split ``text`` into overlapping chunks for embedding.

    Args:
        text: The full text to split.
        chunk_size: Maximum size of each chunk, passed to
            ``RecursiveCharacterTextSplitter``. Defaults to 1000.
        chunk_overlap: Overlap between consecutive chunks, passed to
            ``RecursiveCharacterTextSplitter``. Defaults to 500.

    Returns:
        The list of chunk strings produced by the splitter.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    return text_splitter.split_text(text)
def embedThetext(text):
    """Embed text chunks and package them as Pinecone upsert records.

    Args:
        text: List of chunk strings to embed.

    Returns:
        A list of dicts with keys ``'id'``, ``'values'`` (the embedding as a
        plain list) and ``'metadata'`` (``{"text": chunk}``); an empty list
        when ``text`` is empty.
    """
    st.session_state.processing = "Embedding text..."
    if not text:
        return []
    embeddings = model.encode(text)
    # One random tag per call plus the chunk position guarantees distinct ids
    # within a batch. Independent 5-character random ids could collide, and a
    # collision would silently overwrite a vector on upsert.
    batch_tag = randomIdGenerate()
    return [
        {
            'id': f'id-{batch_tag}-{position}',
            # Plain list, matching how the query embedding is converted.
            'values': embedding.tolist(),
            'metadata': {"text": chunk},
        }
        for position, (chunk, embedding) in enumerate(zip(text, embeddings))
    ]
def saveInPinecone(vector, batch_size=100):
    """Upsert vector records into the current session's Pinecone namespace.

    Args:
        vector: List of records as produced by ``embedThetext``.
        batch_size: Maximum number of records sent per upsert request.
            Sending everything in one request can exceed the service's
            request limits for large documents. Defaults to 100.
    """
    st.session_state.processing = "Inserting to prinecone vector..."
    records = list(vector)
    namespace = st.session_state.namespace
    # An empty ``records`` list results in no request at all.
    for start in range(0, len(records), batch_size):
        index.upsert(
            vectors=records[start:start + batch_size],
            namespace=namespace,
        )
def getFinalResponse(user_question):
    """Answer a question from the indexed PDF chunks and render the answer.

    The question is embedded, the five closest chunks are fetched from the
    session's Pinecone namespace, and they are placed in the system prompt of
    a Groq chat model that also receives the windowed chat history held in
    ``st.session_state.memory``. The answer is written to the page.

    Args:
        user_question: The question typed by the user.

    Returns:
        A dict ``{'question': user_question, 'answer': response}``.
    """
    query_embedding = model.encode([user_question])[0].tolist()
    result = index.query(
        top_k=5,
        namespace=st.session_state.namespace,
        vector=query_embedding,
        # Only the stored text is used below; skip returning the vectors.
        include_values=False,
        include_metadata=True,
    )
    sources = [item['metadata']['text'] for item in result['matches']]
    # List each retrieved passage once, numbered, so the model has something
    # concrete to cite. Previously the same joined text was inserted twice.
    if sources:
        context = "\n".join(
            f"[Source {number}] {passage}"
            for number, passage in enumerate(sources, start=1)
        )
    else:
        context = "(no matching passages were found)"
    sys_prompt = f"""
Instructions:
- Never answer external questions
- Utilize the context provided for accurate and specific information.
- when an out of context question comes return it is out of context question. If so, strictly don't give any other information.
- Don't give external data please. why are you doing so?
- Dont add According to the provided information.
- Cite your sources
Context: {context}
"""
    prompt = ChatPromptTemplate.from_messages(
        [
            # Persistent system prompt, always first in the conversation.
            SystemMessage(content=sys_prompt),
            # Replaced with the stored chat history at call time.
            MessagesPlaceholder(variable_name="chat_history"),
            # The user's current input is injected here.
            HumanMessagePromptTemplate.from_template("{human_input}"),
        ]
    )
    groq_chat = ChatGroq(
        groq_api_key=st.secrets["GROQ_API_KEY"],
        model_name="llama3-8b-8192",
    )
    conversation = LLMChain(
        llm=groq_chat,
        prompt=prompt,
        verbose=False,  # Set to True for verbose output while debugging.
        memory=st.session_state.memory,  # Stores the conversation history.
    )
    response = conversation.predict(human_input=user_question)
    # Render inside an assistant bubble, matching how past answers are shown.
    with st.chat_message("Assistant"):
        st.write(response)
    return {'question': user_question, 'answer': response}
# Size of the conversation window (the ``k`` argument of the memory below).
conversational_memory_length = 5

# Per-session defaults; each is created only on the first run of a session.
if 'memory' not in st.session_state:
    st.session_state.memory = ConversationBufferWindowMemory(
        k=conversational_memory_length,
        memory_key="chat_history",
        return_messages=True,
    )
if 'processing' not in st.session_state:
    st.session_state.processing = 'Processing...'
if 'namespace' not in st.session_state:
    # Random namespace so each session reads and writes its own vectors.
    st.session_state.namespace = randomIdGenerate()
def main():
    """Render the sidebar PDF uploader and the chat interface."""
    with st.sidebar:
        st.header("Upload Multiple PDF Files Here", divider='rainbow')
        st.write("When you refresh, new namespace will be selected. So after reload the previous data is not accessable.")
        files = st.file_uploader(
            # A non-empty label (hidden visually) instead of an empty string.
            'Upload PDF files',
            # Restrict uploads to PDFs, since the files go to a PDF reader.
            type=['pdf'],
            accept_multiple_files=True,
            label_visibility='collapsed',
        )
        if st.button("Process"):
            if files:
                # The spinner text is read once here, so later updates to
                # st.session_state.processing are not reflected while it runs.
                with st.spinner(st.session_state.processing):
                    chunks = readFiles(files)
                # readFiles returns the list of text chunks; report success
                # only when at least one chunk was produced.
                if chunks:
                    st.success('Files Processed Successfully')
            else:
                st.error('No files selected')

    st.header("Chat with your PDF | RAG", divider='rainbow')
    # Replay the stored transcript, since the script reruns on every input.
    for chat in st.session_state.body:
        with st.chat_message("user"):
            st.write(chat["question"])
        with st.chat_message("Assistant"):
            st.write(chat["answer"])

    user_question = st.chat_input('Ask Something')
    if user_question:
        st.chat_message("user").write(user_question)
        with st.spinner("Processing..."):
            # getFinalResponse renders the answer itself and returns the pair.
            result = getFinalResponse(user_question)
        st.session_state.body.append(result)


if __name__ == "__main__":
    main()