Upload 3 files
- app.py +54 -0
- helper.py +96 -0
- requirements.txt +10 -0
app.py
ADDED
@@ -0,0 +1,54 @@
import streamlit as st
from streamlit_chat import message
from helper import get_qa_chain, create_vector_db

st.set_page_config(layout="wide", page_title="Chat with PDF")


def process_answer(instruction):
    # Run the retrieval QA chain and return only the answer text
    qa = get_qa_chain()
    generated_text = qa(instruction)
    answer = generated_text['result']
    return answer


# Display conversation history using Streamlit messages
def display_conversation(history):
    for i in range(len(history["generated"])):
        message(history["past"][i], is_user=True, key=str(i) + "_user")
        message(history["generated"][i], key=str(i))


def main():
    st.header("Chat with your PDF")
    create_embeddings = st.button("Create Embeddings")

    if create_embeddings:
        with st.spinner('Creating embeddings...'):
            create_vector_db()
        st.success('Embeddings created successfully!')

    st.subheader("Chat Here")
    user_input = st.text_input("Ask a question", key="input")

    # Initialize session state for generated responses and past messages
    if "generated" not in st.session_state:
        st.session_state["generated"] = ["I am an AI assistant. How can I help?"]

    if "past" not in st.session_state:
        st.session_state["past"] = ["Hey there!"]

    # Search the database for a response based on user input and update session state
    if user_input:
        answer = process_answer({'query': user_input})
        st.session_state["past"].append(user_input)
        st.session_state["generated"].append(answer)

    # Display conversation history using Streamlit messages
    if st.session_state["generated"]:
        display_conversation(st.session_state)


if __name__ == "__main__":
    main()
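A quick way to exercise the chain outside the Streamlit UI (a minimal sketch: the question string is illustrative, and it assumes create_vector_db() has already built the vector_db folder):

from helper import get_qa_chain

qa = get_qa_chain()
result = qa({"query": "What is the refund policy?"})  # dict key matches input_key="query"
print(result["result"])                  # answer text
print(len(result["source_documents"]))   # retrieved chunks (return_source_documents=True)

Note that process_answer() passes the same {'query': ...} dict, which must match the input_key configured in get_qa_chain().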
helper.py
ADDED
@@ -0,0 +1,96 @@
import os
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from PyPDF2 import PdfReader
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.embeddings import HuggingFaceInstructEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_community.llms import HuggingFacePipeline
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from dotenv import load_dotenv

load_dotenv()  # take environment variables from .env
os.environ["LANGCHAIN_API_KEY"] = str(os.getenv("LANGCHAIN_API_KEY"))
os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"
os.environ["LANGCHAIN_PROJECT"] = "2.pdf_chat_router_issue_assistant"

# Create the LLM pipeline
# model = "C:/Users/arasu/Workspace/Projects/GenAI/models/MBZUAILaMini-Flan-T5-248M/"
model = "MBZUAI/LaMini-Flan-T5-248M"
tokenizer = AutoTokenizer.from_pretrained(model, truncation=True)
base_model = AutoModelForSeq2SeqLM.from_pretrained(model)
pipe = pipeline(
    'text2text-generation',
    model=base_model,
    tokenizer=tokenizer,
    max_length=256,
    do_sample=True,
    temperature=0.3,
    top_p=0.95
)
llm = HuggingFacePipeline(pipeline=pipe)

# Initialize instructor embeddings using the Hugging Face model
# instructor_embeddings = HuggingFaceInstructEmbeddings(model_name="C:/Users/arasu/Workspace/Projects/GenAI/embeddings/hkunlp_instructor-large")
instructor_embeddings = HuggingFaceInstructEmbeddings(model_name="hkunlp/instructor-large")
db_path = "vector_db"


def create_vector_db():
    # Extract raw text from every PDF under ./docs
    raw_text = ""
    text_splitter = CharacterTextSplitter(
        separator="\n",
        chunk_size=500,
        chunk_overlap=100,
        length_function=len,
    )
    for root, dirs, files in os.walk("docs"):
        for file in files:
            if file.endswith(".pdf"):
                pdf = PdfReader(os.path.join(root, file))
                for page in pdf.pages:
                    content = page.extract_text()
                    if content:
                        raw_text += content
    texts = text_splitter.split_text(raw_text)

    # Create a persistent vector database from the text chunks
    vector_db = Chroma.from_texts(texts, instructor_embeddings, persist_directory=db_path)
    vector_db.persist()
    vector_db = None


def get_qa_chain():
    # Load the vector database from the local folder
    vector_db = Chroma(persist_directory=db_path, embedding_function=instructor_embeddings)

    # Create a retriever for querying the vector database
    retriever = vector_db.as_retriever(search_kwargs={"k": 3})

    template = """
    You are a friendly customer care assistant helping the user based on the context provided.
    If the question contains a greeting, greet the user back and be friendly.
    If the answer is not found in the context, reply "No Evidence Found".
    context: {context}
    question: {question}
    """
    prompt = PromptTemplate(template=template, input_variables=["context", "question"])

    chain_type_kwargs = {"prompt": prompt}
    qa = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        input_key="query",
        return_source_documents=True,
        chain_type_kwargs=chain_type_kwargs
    )
    return qa
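To sanity-check what the retriever will see before wiring it into the chain, the persisted store can be queried directly (a minimal sketch; the query string is illustrative and it assumes create_vector_db() has already run):

from langchain_community.vectorstores import Chroma
from helper import instructor_embeddings, db_path

db = Chroma(persist_directory=db_path, embedding_function=instructor_embeddings)
for doc in db.similarity_search("warranty terms", k=3):  # same k as the chain's retriever
    print(doc.page_content[:120])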
requirements.txt
ADDED
@@ -0,0 +1,10 @@
transformers
torch
langchain
langchain-community
huggingface_hub
sentence-transformers==2.2.2
InstructorEmbedding
chromadb
PyPDF2
streamlit
streamlit_chat
python-dotenv
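Install the dependencies with pip install -r requirements.txt, then launch the app with streamlit run app.py. python-dotenv, torch, and langchain-community are listed because helper.py imports dotenv and langchain_community directly and the transformers pipeline needs a torch backend; sentence-transformers stays pinned at 2.2.2 since later releases are commonly reported to break InstructorEmbedding.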