scholarly360 committed on
Commit
9f104e9
1 Parent(s): f2bac35

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +183 -0
app.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Question Answering with Retrieval QA and LangChain Language Models featuring FAISS vector stores.
3
+ This script uses the LangChain Language Model API to answer questions using Retrieval QA
4
+ and FAISS vector stores. It also uses the Mistral huggingface inference endpoint to
5
+ generate responses.
6
+ """
7
+
8
+ import os
9
+ import streamlit as st
10
+ from dotenv import load_dotenv
11
+ from PyPDF2 import PdfReader
12
+ from langchain.text_splitter import CharacterTextSplitter
13
+ from langchain.embeddings import HuggingFaceBgeEmbeddings
14
+ from langchain.vectorstores import FAISS
15
+ from langchain.chat_models import ChatOpenAI
16
+ from langchain.memory import ConversationBufferMemory
17
+ from langchain.chains import ConversationalRetrievalChain
18
+ from htmlTemplates import css, bot_template, user_template
19
+ from langchain.llms import HuggingFaceHub
20
+
21
+
22
def get_pdf_text(pdf_docs):
    """
    Extract text from a list of PDF documents.

    Parameters
    ----------
    pdf_docs : list
        PDF documents (anything ``PdfReader`` accepts) to extract text from.
        ``None`` or an empty list is treated as "no documents".

    Returns
    -------
    str
        Text of every page of every document, concatenated in order with no
        separator. Empty string when there are no documents.
    """
    if not pdf_docs:
        return ""

    page_texts = []
    for pdf in pdf_docs:
        for page in PdfReader(pdf).pages:
            # Fall back to "" when a page yields no text (falsy/None result)
            # so a single unreadable page cannot break the whole extraction.
            page_texts.append(page.extract_text() or "")

    # Join once at the end instead of growing a string with += per page.
    return "".join(page_texts)
40
+
41
+
42
def get_text_chunks(text):
    """
    Break a long string into overlapping chunks.

    Parameters
    ----------
    text : str
        The text to split.

    Returns
    -------
    list
        Chunks produced by ``CharacterTextSplitter.split_text`` using a
        newline separator, a chunk size of 1500 and an overlap of 300,
        with lengths measured by ``len``.
    """
    splitter_options = {
        "separator": "\n",
        "chunk_size": 1500,
        "chunk_overlap": 300,
        "length_function": len,
    }
    return CharacterTextSplitter(**splitter_options).split_text(text)
59
+
60
+
61
def get_vectorstore(text_chunks):
    """
    Embed text chunks and index them in a FAISS vector store.

    Parameters
    ----------
    text_chunks : list
        Text chunks to embed.

    Returns
    -------
    FAISS
        Vector store built from ``text_chunks`` with
        ``HuggingFaceBgeEmbeddings`` ("BAAI/bge-base-en-v1.5", on CPU).
    """
    embedder = HuggingFaceBgeEmbeddings(
        model_name="BAAI/bge-base-en-v1.5",
        model_kwargs={"device": "cpu"},
        # Normalized embeddings: set True to compute cosine similarity.
        encode_kwargs={"normalize_embeddings": True},
    )
    return FAISS.from_texts(texts=text_chunks, embedding=embedder)
82
+
83
+
84
def get_conversation_chain(vectorstore):
    """
    Build the conversational retrieval chain used to answer questions.

    Parameters
    ----------
    vectorstore : FAISS
        Vector store whose retriever supplies context to the chain.

    Returns
    -------
    ConversationalRetrievalChain
        Chain combining a HuggingFaceHub LLM
        ("mistralai/Mistral-7B-Instruct-v0.1"), the store's retriever and a
        ``ConversationBufferMemory`` keyed on "chat_history".
    """
    generation_kwargs = {"temperature": 0.005, "max_length": 512}
    mistral_llm = HuggingFaceHub(
        repo_id="mistralai/Mistral-7B-Instruct-v0.1",
        model_kwargs=generation_kwargs,
    )
    # Alternative backend kept for reference:
    # llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo-0613")

    chat_memory = ConversationBufferMemory(
        memory_key="chat_history",
        return_messages=True,
    )
    return ConversationalRetrievalChain.from_llm(
        llm=mistral_llm,
        retriever=vectorstore.as_retriever(),
        memory=chat_memory,
    )
107
+
108
+
109
def handle_userinput(user_question):
    """
    Answer a user question with the conversation chain and render the chat.

    Reads the chain from ``st.session_state.conversation``, stores the
    returned history in ``st.session_state.chat_history`` and writes every
    message to the page, alternating between the user template (even
    positions) and the bot template (odd positions).

    Parameters
    ----------
    user_question : str
        The user's question.
    """
    import html

    chain = st.session_state.conversation
    if chain is None:
        # No documents have been indexed yet, so there is no chain to call;
        # calling None would raise TypeError.
        st.warning(
            "Please upload your PDFs and click on 'Index' before asking a question."
        )
        return

    response = chain({"question": user_question})
    st.session_state.chat_history = response["chat_history"]

    for i, message in enumerate(st.session_state.chat_history):
        template = user_template if i % 2 == 0 else bot_template
        # The message is injected into raw HTML (unsafe_allow_html=True), so
        # escape it to keep user/model text from being interpreted as markup.
        st.write(
            template.replace("{{MSG}}", html.escape(message.content)),
            unsafe_allow_html=True,
        )
130
+
131
+
132
def main():
    """
    Run the Streamlit app: page setup, question box and PDF indexing sidebar.

    Uploaded PDFs are turned into text, chunked, embedded into a vector
    store and wrapped in a conversation chain that is kept in
    ``st.session_state.conversation``. Questions typed by the user are
    passed to ``handle_userinput``.
    """
    st.set_page_config(
        page_title="Chat with a Bot that tries to answer questions about multiple PDFs",
        page_icon=":books:",
    )

    st.markdown("# Chat with Contracts Bot")
    st.markdown("This bot tries to answer questions about multiple PDFs using Open Source Mistral 7B")

    st.write(css, unsafe_allow_html=True)

    # Load a local .env file (if any), then fail early with a readable message
    # rather than a bare KeyError when the Hugging Face token is missing.
    # NOTE(review): the HuggingFaceHub LLM is expected to pick this variable
    # up from the environment -- confirm against the LangChain version in use.
    load_dotenv()
    if not os.environ.get("HUGGINGFACEHUB_API_TOKEN"):
        st.error("The HUGGINGFACEHUB_API_TOKEN environment variable is not set.")
        st.stop()

    if "conversation" not in st.session_state:
        st.session_state.conversation = None
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = None

    st.header("Chat with a Bot 🤖🦾 that tries to answer questions about multiple PDFs :books:")
    user_question = st.text_input("Ask a question about your contracts:")
    if user_question:
        handle_userinput(user_question)

    with st.sidebar:
        st.subheader("Your Contracts")
        pdf_docs = st.file_uploader(
            "Upload your PDFs here and click on 'Index'", accept_multiple_files=True
        )
        if st.button("Index"):
            if not pdf_docs:
                # Nothing uploaded: indexing would build an empty vector store.
                st.warning("Please upload at least one PDF before clicking on 'Index'.")
            else:
                with st.spinner("Processing"):
                    # get pdf text
                    raw_text = get_pdf_text(pdf_docs)

                    # get the text chunks
                    text_chunks = get_text_chunks(raw_text)

                    if not text_chunks:
                        # e.g. PDFs with no extractable text; there is
                        # nothing to embed, so do not build a chain.
                        st.error("No text could be extracted from the uploaded PDFs.")
                    else:
                        # create vector store
                        vectorstore = get_vectorstore(text_chunks)

                        # create conversation chain
                        st.session_state.conversation = get_conversation_chain(vectorstore)
180
+
181
+
182
# Script entry point: start the app only when this file is executed directly.
if __name__ == "__main__":
    main()