rahgadda committed on
Commit
eb9afcc
1 Parent(s): aad602c

Initial Draft

Files changed (1)
  1. app.py +170 -0
app.py ADDED
@@ -0,0 +1,170 @@
+ import streamlit as st
+ import tempfile
+ import os
+ import re
+ import torch
+ from threading import Thread
+ 
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, TextIteratorStreamer
+ from langchain_community.document_loaders import PyPDFLoader
+ from langchain_community.vectorstores import FAISS
+ from langchain.chains import ConversationalRetrievalChain
+ from langchain_community.embeddings import HuggingFaceEmbeddings
+ from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
+ 
+ # Function returns langchain document object of PDF pages
+ def fn_read_pdf(lv_temp_file_path, mv_processing_message):
+     """Returns langchain document object of PDF pages"""
+ 
+     lv_pdf_loader = PyPDFLoader(lv_temp_file_path)
+     lv_pdf_content = lv_pdf_loader.load()
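+     # load() returns one langchain Document per PDF page, so retrieval works at page
+     # granularity; a text splitter (e.g. RecursiveCharacterTextSplitter) could be added
+     # here if individual pages are too long for the retriever.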
+     print("Step2: PDF content extracted")
+     mv_processing_message.text("Step2: PDF content extracted")
+ 
+     return lv_pdf_content
+ 
+ # Function returns FAISS Vector store
+ def fn_create_faiss_vector_store(lv_pdf_content, mv_processing_message):
+     """Returns FAISS vector store index of PDF Content"""
+ 
+     lv_embeddings = HuggingFaceEmbeddings(
+         model_name="sentence-transformers/msmarco-distilbert-base-v4",
+         model_kwargs={'device': 'cpu'},
+         encode_kwargs={'normalize_embeddings': False}
+     )
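+     # Embeddings are computed on CPU and left unnormalized, so the FAISS index built
+     # below compares vectors with its default (L2) distance.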
+     lv_vector_store = FAISS.from_documents(lv_pdf_content, lv_embeddings)
+     print("Step3: Vector store created")
+     mv_processing_message.text("Step3: Vector store created")
+ 
+     return lv_vector_store
+ 
+ # Function returns QA Response using Vector Store
+ def fn_generate_QnA_response(mv_selected_model, mv_user_question, lv_vector_store, mv_processing_message):
+     """Returns QA Response using Vector Store"""
+ 
+     lv_chat_history = []
+ 
+     if "chat_history" not in st.session_state:
+         st.session_state.chat_history = []
+     else:
+         lv_chat_history = st.session_state.chat_history
+ 
+     print("Step4: Generating LLM response")
+     mv_processing_message.text("Step4: Generating LLM response")
+ 
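+     # The tokenizer, phi-2 weights and pipeline below are rebuilt on every question;
+     # wrapping this load in st.cache_resource would avoid reloading the model each time.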
+     lv_tokenizer = AutoTokenizer.from_pretrained(mv_selected_model, trust_remote_code=True)
+     lv_model = AutoModelForCausalLM.from_pretrained(
+         mv_selected_model,
+         torch_dtype="auto",
+         device_map="cpu",
+         trust_remote_code=True
+     )
+     # lv_streamer = TextIteratorStreamer(
+     #     tokenizer=lv_tokenizer,
+     #     skip_prompt=True,
+     #     skip_special_tokens=True,
+     #     timeout=300.0
+     # )
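+     # ConversationalRetrievalChain condenses follow-up questions using the chat history,
+     # retrieves matching pages from the FAISS store and stuffs them into the phi-2 prompt.
+     # Note: phi-2's context window is 2048 tokens, so prompt plus completion must fit
+     # within that limit and the effective output stays well below max_new_tokens=4000.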
+     lv_ms_phi2_pipeline = pipeline(
+         "text-generation", tokenizer=lv_tokenizer, model=lv_model,
+         device_map="cpu", max_new_tokens=4000, return_full_text=True
+     )
+     lv_hf_phi2_pipeline = HuggingFacePipeline(pipeline=lv_ms_phi2_pipeline)
+     lv_chain = ConversationalRetrievalChain.from_llm(lv_hf_phi2_pipeline, lv_vector_store.as_retriever(), return_source_documents=True)
+     lv_response = lv_chain({"question": mv_user_question, 'chat_history': lv_chat_history})
+ 
+     lv_chat_history += [(mv_user_question, lv_response["answer"])]
+     st.session_state.chat_history = lv_chat_history
+ 
+     print("Step5: LLM response generated")
+     mv_processing_message.text("Step5: LLM response generated")
+ 
+     return lv_response['answer']
+ 
+ 
+ # Main Function
+ def main():
+ 
+     # -- Streamlit Settings
+     st.set_page_config(layout='wide')
+ 
+     # -- Initialize chat history
+     if "messages" not in st.session_state:
+         st.session_state.messages = []
+ 
+     col1, col2, col3 = st.columns(3)
+     col2.title("Chat with your PDF")
+     st.text("")
+ 
+     col1, col2, col3 = st.columns(3)
+     mv_selected_model = col3.selectbox('Select Model', ['microsoft/phi-2'])
+     st.text("")
+     st.text("")
+     st.text("")
+     col1, col2, col3 = st.columns(3)
+ 
+     # -- Reading PDF File
+     mv_pdf_input_file = col2.file_uploader("Choose a PDF file:", type=["pdf"])
+ 
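+     # A per-session scratch directory is created once and reused across Streamlit reruns
+     # by keeping its path in st.session_state.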
+     if 'mv_temp_file_storage_dir' not in st.session_state:
+         mv_temp_file_storage_dir = tempfile.mkdtemp()
+         st.session_state.mv_temp_file_storage_dir = mv_temp_file_storage_dir
+     else:
+         mv_temp_file_storage_dir = st.session_state.mv_temp_file_storage_dir
+ 
+     mv_processing_message = col2.empty()
+     st.text("")
+     st.text("")
+     st.text("")
+     st.text("")
+     st.text("")
+     st.text("")
+ 
+     mv_vector_storage_dir = "/workspace/knowledge-base/01-ML/01-dev/adhoc/Talk2PDF/vector_store"
+ 
+     if (mv_pdf_input_file is not None):
+         mv_file_name = mv_pdf_input_file.name
+         # mv_vectorstore_file_name = os.path.join(mv_vector_storage_dir, mv_file_name[:-4] + ".vectorstore")
+         # mv_metadata_file_name = os.path.join(mv_vector_storage_dir, mv_file_name[:-4] + ".metadata")
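+         # mv_vector_storage_dir and the commented-out file names point at persisting the
+         # FAISS index to disk (e.g. via lv_vector_store.save_local); this draft keeps the
+         # index only in st.session_state.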
129
+
130
+ if 'lv_vector_store' not in st.session_state:
131
+ # -- Storing Uploaded PDF locally
132
+ lv_temp_file_path = os.path.join(mv_temp_file_storage_dir,mv_file_name)
133
+ with open(lv_temp_file_path,"wb") as lv_file:
134
+ lv_file.write(mv_pdf_input_file.getbuffer())
135
+ print("Step1: PDF uploaded successfully at -> " + lv_temp_file_path)
136
+ mv_processing_message.text("Step1: PDF uploaded successfully at -> " + lv_temp_file_path)
137
+
138
+ # -- Extracting PDF Text
139
+ lv_pdf_content = fn_read_pdf(lv_temp_file_path, mv_processing_message)
140
+
141
+ # -- Creating FAISS Vector Store
142
+ lv_vector_store = fn_create_faiss_vector_store(lv_pdf_content, mv_processing_message)
143
+ st.session_state.lv_vector_store = lv_vector_store
144
+ else:
145
+ lv_vector_store = st.session_state.lv_vector_store
146
+
147
+ # -- Taking input question and generate answer
148
+ col1, col2, col3 = st.columns(3)
149
+ lv_chat_history = col2.chat_message
150
+
151
+ if mv_user_question := col2.chat_input("Chat on PDF Data"):
152
+ # -- Add user message to chat history
153
+ st.session_state.messages.append({"role": "user", "content": mv_user_question})
154
+
155
+ # -- Generating LLM response
156
+ lv_response = fn_generate_QnA_response(mv_selected_model, mv_user_question, lv_vector_store, mv_processing_message)
157
+
158
+ # -- Adding assistant response to chat history
159
+ st.session_state.messages.append({"role": "assistant", "content": lv_response})
160
+
161
+ # -- Display chat messages from history on app rerun
162
+ for message in st.session_state.messages:
163
+ with lv_chat_history(message["role"]):
164
+ st.markdown(message["content"])
165
+
166
+
167
+
168
+ # Calling Main Function
169
+ if __name__ == '__main__':
170
+ main()