vinhnx90 commited on
Commit
48881b2
β€’
1 Parent(s): 7a185b4
Files changed (3) hide show
  1. app.py +74 -102
  2. stream_handler.py +37 -0
  3. token_stream_handler.py +0 -13
app.py CHANGED
@@ -2,52 +2,48 @@ import os
2
  import tempfile
3
 
4
  import streamlit as st
5
- from chat_profile import ChatProfileRoleEnum
6
 
7
  from langchain.callbacks.base import BaseCallbackHandler
8
  from langchain.chains import ConversationalRetrievalChain
9
  from langchain.chat_models import ChatOpenAI
10
- from langchain_community.document_loaders import Docx2txtLoader, PyPDFLoader, TextLoader
11
  from langchain.embeddings import HuggingFaceEmbeddings
12
  from langchain.memory import ConversationBufferMemory
13
  from langchain.memory.chat_message_histories import StreamlitChatMessageHistory
14
  from langchain.text_splitter import RecursiveCharacterTextSplitter
 
15
  from langchain_community.vectorstores import DocArrayInMemorySearch
16
- from streamlit_extras.add_vertical_space import add_vertical_space
17
 
18
- # TODO: refactor
19
- # TODO: extract class
20
- # TODO: modularize
21
- # TODO: hide side bar
22
- # TODO: make the page attactive
23
 
24
- # configs
25
  LLM_MODEL_NAME = "gpt-3.5-turbo"
26
  EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
27
 
28
- st.set_page_config(
29
- page_title=":books: InkChatGPT: Chat with Documents",
30
- page_icon="πŸ“š",
31
- initial_sidebar_state="collapsed",
32
- menu_items={
33
- "Get Help": "https://x.com/vinhnx",
34
- "Report a bug": "https://github.com/vinhnx/InkChatGPT/issues",
35
- "About": "InkChatGPT is a Streamlit application that allows users to upload PDF documents and engage in a conversational Q&A with a language model (LLM) based on the content of those documents.",
36
- },
37
- )
38
-
39
- st.image("./assets/icon.jpg", width=100)
40
- st.header(
41
- ":gray[:books: InkChatGPT]",
42
- divider="blue",
43
- )
44
- st.write("**Chat** with Documents")
 
45
 
46
  # Setup memory for contextual conversation
47
  msgs = StreamlitChatMessageHistory()
48
 
49
 
50
- @st.cache_resource(ttl="1h")
51
  def configure_retriever(uploaded_files):
52
  # Read documents
53
  docs = []
@@ -70,7 +66,6 @@ def configure_retriever(uploaded_files):
70
  st.write("This document format is not supported!")
71
  return None
72
 
73
- # loader = PyPDFLoader(temp_filepath)
74
  docs.extend(loader.load())
75
 
76
  # Split documents
@@ -89,91 +84,68 @@ def configure_retriever(uploaded_files):
89
  return retriever
90
 
91
 
92
- class StreamHandler(BaseCallbackHandler):
93
- def __init__(
94
- self, container: st.delta_generator.DeltaGenerator, initial_text: str = ""
95
- ):
96
- self.container = container
97
- self.text = initial_text
98
- self.run_id_ignore_token = None
99
-
100
- def on_llm_start(self, serialized: dict, prompts: list, **kwargs):
101
- # Workaround to prevent showing the rephrased question as output
102
- if prompts[0].startswith("Human"):
103
- self.run_id_ignore_token = kwargs.get("run_id")
104
 
105
- def on_llm_new_token(self, token: str, **kwargs) -> None:
106
- if self.run_id_ignore_token == kwargs.get("run_id", False):
107
- return
108
- self.text += token
109
- self.container.markdown(self.text)
 
 
110
 
 
 
 
111
 
112
- class PrintRetrievalHandler(BaseCallbackHandler):
113
- def __init__(self, container):
114
- self.status = container.status("**Thinking...**")
115
- self.container = container
116
 
117
- def on_retriever_start(self, serialized: dict, query: str, **kwargs):
118
- self.status.write(f"**Checking document for query:** `{query}`. Please wait...")
 
119
 
120
- def on_retriever_end(self, documents, **kwargs):
121
- self.container.empty()
122
 
 
 
 
123
 
124
- with st.sidebar.expander("Documents"):
125
- st.subheader("Files")
126
- uploaded_files = st.file_uploader(
127
- label="Select files", type=["pdf", "txt", "docx"], accept_multiple_files=True
128
- )
129
-
130
-
131
- with st.sidebar.expander("Setup"):
132
- st.subheader("API Key")
133
- openai_api_key = st.text_input("OpenAI API Key", type="password")
134
 
135
- is_empty_chat_messages = len(msgs.messages) == 0
136
- if is_empty_chat_messages or st.button("Clear message history"):
137
- msgs.clear()
138
- msgs.add_ai_message("How can I help you?")
139
 
140
- if not openai_api_key:
141
- st.info("Please add your OpenAI API key in the sidebar to continue.")
142
- st.stop()
 
143
 
144
- if uploaded_files:
145
- retriever = configure_retriever(uploaded_files)
146
-
147
- memory = ConversationBufferMemory(
148
- memory_key="chat_history", chat_memory=msgs, return_messages=True
149
- )
150
-
151
- # Setup LLM and QA chain
152
- llm = ChatOpenAI(
153
- model_name=LLM_MODEL_NAME,
154
- openai_api_key=openai_api_key,
155
- temperature=0,
156
- streaming=True,
157
- )
158
-
159
- chain = ConversationalRetrievalChain.from_llm(
160
- llm, retriever=retriever, memory=memory, verbose=False
161
- )
162
 
163
- avatars = {
164
- ChatProfileRoleEnum.Human: "user",
165
- ChatProfileRoleEnum.AI: "assistant",
166
- }
167
 
168
- for msg in msgs.messages:
169
- st.chat_message(avatars[msg.type]).write(msg.content)
 
 
 
 
170
 
171
- if user_query := st.chat_input(placeholder="Ask me anything!"):
172
- st.chat_message("user").write(user_query)
173
 
174
- with st.chat_message("assistant"):
175
- retrieval_handler = PrintRetrievalHandler(st.empty())
176
- stream_handler = StreamHandler(st.empty())
177
- response = chain.run(
178
- user_query, callbacks=[retrieval_handler, stream_handler]
179
- )
 
2
  import tempfile
3
 
4
  import streamlit as st
 
5
 
6
  from langchain.callbacks.base import BaseCallbackHandler
7
  from langchain.chains import ConversationalRetrievalChain
8
  from langchain.chat_models import ChatOpenAI
 
9
  from langchain.embeddings import HuggingFaceEmbeddings
10
  from langchain.memory import ConversationBufferMemory
11
  from langchain.memory.chat_message_histories import StreamlitChatMessageHistory
12
  from langchain.text_splitter import RecursiveCharacterTextSplitter
13
+ from langchain_community.document_loaders import Docx2txtLoader, PyPDFLoader, TextLoader
14
  from langchain_community.vectorstores import DocArrayInMemorySearch
 
15
 
16
+ from chat_profile import ChatProfileRoleEnum
17
+ from stream_handler import PrintRetrievalHandler, StreamHandler
 
 
 
18
 
19
+ # Configuration
20
  LLM_MODEL_NAME = "gpt-3.5-turbo"
21
  EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
22
 
23
+
24
+ # Set up Streamlit app
25
+ def setup_streamlit_app():
26
+ st.set_page_config(
27
+ page_title=":books: InkChatGPT: Chat with Documents",
28
+ page_icon="πŸ“š",
29
+ initial_sidebar_state="collapsed",
30
+ menu_items={
31
+ "Get Help": "https://x.com/vinhnx",
32
+ "Report a bug": "https://github.com/vinhnx/InkChatGPT/issues",
33
+ "About": "InkChatGPT is a Streamlit application that allows users to upload PDF documents and engage in a conversational Q&A with a language model (LLM) based on the content of those documents.",
34
+ },
35
+ )
36
+
37
+ st.image("./assets/icon.jpg", width=100)
38
+ st.header(":gray[:books: InkChatGPT]", divider="blue")
39
+ st.write("**Chat** with Documents")
40
+
41
 
42
  # Setup memory for contextual conversation
43
  msgs = StreamlitChatMessageHistory()
44
 
45
 
46
+ # Load and process documents
47
  def configure_retriever(uploaded_files):
48
  # Read documents
49
  docs = []
 
66
  st.write("This document format is not supported!")
67
  return None
68
 
 
69
  docs.extend(loader.load())
70
 
71
  # Split documents
 
84
  return retriever
85
 
86
 
87
+ # Main function
88
+ def main():
89
+ setup_streamlit_app()
 
 
 
 
 
 
 
 
 
90
 
91
+ with st.sidebar.expander("Documents"):
92
+ st.subheader("Files")
93
+ uploaded_files = st.file_uploader(
94
+ label="Select files",
95
+ type=["pdf", "txt", "docx"],
96
+ accept_multiple_files=True,
97
+ )
98
 
99
+ with st.sidebar.expander("Setup"):
100
+ st.subheader("API Key")
101
+ openai_api_key = st.text_input("OpenAI API Key", type="password")
102
 
103
+ is_empty_chat_messages = len(msgs.messages) == 0
104
+ if is_empty_chat_messages or st.button("Clear message history"):
105
+ msgs.clear()
106
+ msgs.add_ai_message("How can I help you?")
107
 
108
+ if not openai_api_key:
109
+ st.info("Please add your OpenAI API key in the sidebar to continue.")
110
+ st.stop()
111
 
112
+ if uploaded_files:
113
+ retriever = configure_retriever(uploaded_files)
114
 
115
+ memory = ConversationBufferMemory(
116
+ memory_key="chat_history", chat_memory=msgs, return_messages=True
117
+ )
118
 
119
+ # Setup LLM and QA chain
120
+ llm = ChatOpenAI(
121
+ model_name=LLM_MODEL_NAME,
122
+ openai_api_key=openai_api_key,
123
+ temperature=0,
124
+ streaming=True,
125
+ )
 
 
 
126
 
127
+ chain = ConversationalRetrievalChain.from_llm(
128
+ llm, retriever=retriever, memory=memory, verbose=False
129
+ )
 
130
 
131
+ avatars = {
132
+ ChatProfileRoleEnum.Human: "user",
133
+ ChatProfileRoleEnum.AI: "assistant",
134
+ }
135
 
136
+ for msg in msgs.messages:
137
+ st.chat_message(avatars[msg.type]).write(msg.content)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
 
139
+ if user_query := st.chat_input(placeholder="Ask me anything!"):
140
+ st.chat_message("user").write(user_query)
 
 
141
 
142
+ with st.chat_message("assistant"):
143
+ retrieval_handler = PrintRetrievalHandler(st.empty())
144
+ stream_handler = StreamHandler(st.empty())
145
+ response = chain.run(
146
+ user_query, callbacks=[retrieval_handler, stream_handler]
147
+ )
148
 
 
 
149
 
150
+ if __name__ == "__main__":
151
+ main()
 
 
 
 
stream_handler.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import streamlit as st
3
+
4
+ from langchain.callbacks.base import BaseCallbackHandler
5
+
6
+
7
+ # Callback handlers
8
+ class StreamHandler(BaseCallbackHandler):
9
+ def __init__(
10
+ self, container: st.delta_generator.DeltaGenerator, initial_text: str = ""
11
+ ):
12
+ self.container = container
13
+ self.text = initial_text
14
+ self.run_id_ignore_token = None
15
+
16
+ def on_llm_start(self, serialized: dict, prompts: list, **kwargs):
17
+ # Workaround to prevent showing the rephrased question as output
18
+ if prompts[0].startswith("Human"):
19
+ self.run_id_ignore_token = kwargs.get("run_id")
20
+
21
+ def on_llm_new_token(self, token: str, **kwargs) -> None:
22
+ if self.run_id_ignore_token == kwargs.get("run_id", False):
23
+ return
24
+ self.text += token
25
+ self.container.markdown(self.text)
26
+
27
+
28
+ class PrintRetrievalHandler(BaseCallbackHandler):
29
+ def __init__(self, container):
30
+ self.status = container.status("**Thinking...**")
31
+ self.container = container
32
+
33
+ def on_retriever_start(self, serialized: dict, query: str, **kwargs):
34
+ self.status.write(f"**Checking document for query:** `{query}`. Please wait...")
35
+
36
+ def on_retriever_end(self, documents, **kwargs):
37
+ self.container.empty()
token_stream_handler.py DELETED
@@ -1,13 +0,0 @@
1
- import os
2
-
3
- from langchain.callbacks.base import BaseCallbackHandler
4
-
5
-
6
- class StreamHandler(BaseCallbackHandler):
7
- def __init__(self, container, initial_text=""):
8
- self.container = container
9
- self.text = initial_text
10
-
11
- def on_llm_new_token(self, token: str, **kwargs) -> None:
12
- self.text += token
13
- self.container.markdown(self.text)