Ritesh-hf commited on
Commit
4e322c2
1 Parent(s): 256e3d5

update app.py

Browse files
Files changed (2) hide show
  1. app.py +81 -69
  2. static/script.js +50 -37
app.py CHANGED
@@ -1,65 +1,70 @@
1
  import os
2
  from dotenv import load_dotenv
3
- load_dotenv(".env")
4
-
5
- os.environ['USER_AGENT'] = os.getenv("USER_AGENT")
6
- os.environ["GROQ_API_KEY"] = os.getenv("GROQ_API_KEY")
7
- os.environ["TOKENIZERS_PARALLELISM"]='true'
8
-
9
  from langchain.chains import create_history_aware_retriever, create_retrieval_chain
10
  from langchain.chains.combine_documents import create_stuff_documents_chain
11
  from langchain_community.chat_message_histories import ChatMessageHistory
12
- from langchain_community.document_loaders import WebBaseLoader
13
  from langchain_core.chat_history import BaseChatMessageHistory
14
  from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
15
  from langchain_core.runnables.history import RunnableWithMessageHistory
16
-
17
  from pinecone import Pinecone
18
  from pinecone_text.sparse import BM25Encoder
19
-
20
  from langchain_huggingface import HuggingFaceEmbeddings
21
  from langchain_community.retrievers import PineconeHybridSearchRetriever
22
-
23
  from langchain_groq import ChatGroq
24
 
25
- from flask import Flask, request, render_template
26
- from flask_cors import CORS
27
- from flask_socketio import SocketIO, emit
28
-
 
 
 
 
 
 
 
 
 
 
29
  app = Flask(__name__)
30
  CORS(app)
31
  socketio = SocketIO(app, cors_allowed_origins="*")
32
  app.config['SESSION_COOKIE_SECURE'] = True # Use HTTPS
33
  app.config['SESSION_COOKIE_HTTPONLY'] = True
34
  app.config['SESSION_COOKIE_SAMESITE'] = 'Lax'
35
- app.config['SECRET_KEY'] = os.getenv('SECRET_KEY')
36
-
37
- try:
38
- pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
39
- index_name = "traveler-demo-website-vectorstore"
40
- # connect to index
41
- pinecone_index = pc.Index(index_name)
42
- except:
43
- pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
44
- index_name = "traveler-demo-website-vectorstore"
45
- # connect to index
46
- pinecone_index = pc.Index(index_name)
47
 
 
 
 
 
 
 
 
 
 
 
 
48
  bm25 = BM25Encoder().load("./bm25_traveler_website.json")
49
 
 
50
  embed_model = HuggingFaceEmbeddings(model_name="Alibaba-NLP/gte-large-en-v1.5", model_kwargs={"trust_remote_code":True})
51
-
52
  retriever = PineconeHybridSearchRetriever(
53
  embeddings=embed_model,
54
  sparse_encoder=bm25,
55
  index=pinecone_index,
56
  top_k=20,
57
- alpha=0.5,
58
  )
59
 
60
- llm = ChatGroq(model="llama-3.1-70b-versatile", temperature=0.1, max_tokens=1024, max_retries=2)
 
61
 
62
- ### Contextualize question ###
63
  contextualize_q_system_prompt = """Given a chat history and the latest user question \
64
  which might reference context in the chat history, formulate a standalone question \
65
  which can be understood without the chat history. Do NOT answer the question, \
@@ -72,34 +77,32 @@ contextualize_q_prompt = ChatPromptTemplate.from_messages(
72
  ("human", "{input}")
73
  ]
74
  )
 
75
 
76
- history_aware_retriever = create_history_aware_retriever(
77
- llm, retriever, contextualize_q_prompt
78
- )
79
-
80
 
81
- qa_system_prompt = """You are a highly skilled information retrieval assistant. Use the following pieces of retrieved context to answer the question. \
82
- Provide links to sources provided in the answer. \
83
- If you don't know the answer, just say that you don't know. \
84
- Do not give extra long answers. \
85
- When responding to queries, your responses should be comprehensive and well-organized. For each response: \
86
 
87
- 1. Provide Clear Answers \
 
88
 
89
  2. Include Detailed References: \
90
- - Include links to sources and any links or sites where there is a mentioned in the answer.
91
- - Links to Sources: Provide URLs to credible sources where users can verify the information or explore further. \
92
- - Downloadable Materials: Include links to any relevant downloadable resources if applicable. \
93
  - Reference Sites: Mention specific websites or platforms that offer additional information. \
94
-
 
95
  3. Formatting for Readability: \
96
- - Bullet Points or Lists: Where applicable, use bullet points or numbered lists to present information clearly. \
97
- - Emphasize Important Information: Use bold or italics to highlight key details. \
98
-
99
- 4. Organize Content Logically \
100
-
101
- Do not include anything about context in the answer. \
102
-
 
103
  {context}
104
  """
105
  qa_prompt = ChatPromptTemplate.from_messages(
@@ -111,20 +114,21 @@ qa_prompt = ChatPromptTemplate.from_messages(
111
  )
112
  question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
113
 
 
114
  rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
115
 
116
- ### Statefully manage chat history ###
117
  store = {}
118
 
119
  def clean_temporary_data():
120
- store = {}
121
 
122
  def get_session_history(session_id: str) -> BaseChatMessageHistory:
123
  if session_id not in store:
124
  store[session_id] = ChatMessageHistory()
125
  return store[session_id]
126
 
127
-
128
  conversational_rag_chain = RunnableWithMessageHistory(
129
  rag_chain,
130
  get_session_history,
@@ -133,33 +137,41 @@ conversational_rag_chain = RunnableWithMessageHistory(
133
  output_messages_key="answer",
134
  )
135
 
136
- # Stream response to client
 
 
 
 
 
 
 
 
 
 
 
 
137
  @socketio.on('message')
138
  def handle_message(data):
139
  question = data.get('question')
140
- session_id = data.get('session_id', 'abc123')
141
  chain = conversational_rag_chain.pick("answer")
142
-
143
  try:
144
  for chunk in chain.stream(
145
  {"input": question},
146
- config={
147
- "configurable": {"session_id": "abc123"}
148
- },
149
  ):
150
- emit('response', chunk, room=request.sid)
151
- except:
152
- for chunk in chain.stream(
153
- {"input": question},
154
- config={
155
- "configurable": {"session_id": "abc123"}
156
- },
157
- ):
158
- emit('response', chunk, room=request.sid)
159
 
 
160
  @app.route("/")
161
  def index_view():
162
  return render_template('chat.html')
163
 
 
164
  if __name__ == '__main__':
165
  socketio.run(app, debug=True)
 
1
  import os
2
  from dotenv import load_dotenv
3
+ import asyncio
4
+ from flask import Flask, request, render_template
5
+ from flask_cors import CORS
6
+ from flask_socketio import SocketIO, emit, join_room, leave_room
 
 
7
  from langchain.chains import create_history_aware_retriever, create_retrieval_chain
8
  from langchain.chains.combine_documents import create_stuff_documents_chain
9
  from langchain_community.chat_message_histories import ChatMessageHistory
 
10
  from langchain_core.chat_history import BaseChatMessageHistory
11
  from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
12
  from langchain_core.runnables.history import RunnableWithMessageHistory
 
13
  from pinecone import Pinecone
14
  from pinecone_text.sparse import BM25Encoder
 
15
  from langchain_huggingface import HuggingFaceEmbeddings
16
  from langchain_community.retrievers import PineconeHybridSearchRetriever
 
17
  from langchain_groq import ChatGroq
18
 
19
+ # Load environment variables
20
+ load_dotenv(".env")
21
+ USER_AGENT = os.getenv("USER_AGENT")
22
+ GROQ_API_KEY = os.getenv("GROQ_API_KEY")
23
+ SECRET_KEY = os.getenv("SECRET_KEY")
24
+ PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
25
+ SESSION_ID_DEFAULT = "abc123"
26
+
27
+ # Set environment variables
28
+ os.environ['USER_AGENT'] = USER_AGENT
29
+ os.environ["GROQ_API_KEY"] = GROQ_API_KEY
30
+ os.environ["TOKENIZERS_PARALLELISM"] = 'true'
31
+
32
+ # Initialize Flask app and SocketIO with CORS
33
  app = Flask(__name__)
34
  CORS(app)
35
  socketio = SocketIO(app, cors_allowed_origins="*")
36
  app.config['SESSION_COOKIE_SECURE'] = True # Use HTTPS
37
  app.config['SESSION_COOKIE_HTTPONLY'] = True
38
  app.config['SESSION_COOKIE_SAMESITE'] = 'Lax'
39
+ app.config['SECRET_KEY'] = SECRET_KEY
 
 
 
 
 
 
 
 
 
 
 
40
 
41
+ # Function to initialize Pinecone connection
42
+ def initialize_pinecone(index_name: str):
43
+ try:
44
+ pc = Pinecone(api_key=PINECONE_API_KEY)
45
+ return pc.Index(index_name)
46
+ except Exception as e:
47
+ print(f"Error initializing Pinecone: {e}")
48
+ raise
49
+
50
+ # Initialize Pinecone index and BM25 encoder
51
+ pinecone_index = initialize_pinecone("traveler-demo-website-vectorstore")
52
  bm25 = BM25Encoder().load("./bm25_traveler_website.json")
53
 
54
+ # Initialize models and retriever
55
  embed_model = HuggingFaceEmbeddings(model_name="Alibaba-NLP/gte-large-en-v1.5", model_kwargs={"trust_remote_code":True})
 
56
  retriever = PineconeHybridSearchRetriever(
57
  embeddings=embed_model,
58
  sparse_encoder=bm25,
59
  index=pinecone_index,
60
  top_k=20,
61
+ alpha=0.5
62
  )
63
 
64
+ # Initialize LLM
65
+ llm = ChatGroq(model="llama-3.1-8b-instant", temperature=0, max_tokens=1024, max_retries=2)
66
 
67
+ # Contextualization prompt and retriever
68
  contextualize_q_system_prompt = """Given a chat history and the latest user question \
69
  which might reference context in the chat history, formulate a standalone question \
70
  which can be understood without the chat history. Do NOT answer the question, \
 
77
  ("human", "{input}")
78
  ]
79
  )
80
+ history_aware_retriever = create_history_aware_retriever(llm, retriever, contextualize_q_prompt)
81
 
82
+ # QA system prompt and chain
83
+ qa_system_prompt = """You are a highly skilled information retrieval assistant. Use the following context to answer questions effectively. \
84
+ If you don't know the answer, simply state that you don't know. \
85
+ Provide answers in proper HTML format and keep them concise. \
86
 
87
+ When responding to queries, follow these guidelines: \
 
 
 
 
88
 
89
+ 1. Provide Clear Answers: \
90
+ - Ensure the response directly addresses the query with accurate and relevant information.\
91
 
92
  2. Include Detailed References: \
93
+ - Links to Sources: Include URLs to credible sources where users can verify information or explore further. \
 
 
94
  - Reference Sites: Mention specific websites or platforms that offer additional information. \
95
+ - Downloadable Materials: Provide links to any relevant downloadable resources if applicable. \
96
+
97
  3. Formatting for Readability: \
98
+ - The answer should be in a proper HTML format with appropriate tags. \
99
+ - Use bullet points or numbered lists where applicable to present information clearly. \
100
+ - Highlight key details using bold or italics. \
101
+ - Provide proper and meaningful abbreviations for urls. Do not include naked urls. \
102
+
103
+ 4. Organize Content Logically: \
104
+ - Structure the content in a logical order, ensuring easy navigation and understanding for the user. \
105
+
106
  {context}
107
  """
108
  qa_prompt = ChatPromptTemplate.from_messages(
 
114
  )
115
  question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
116
 
117
+ # Retrieval and Generative (RAG) Chain
118
  rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
119
 
120
+ # Chat message history storage
121
  store = {}
122
 
123
  def clean_temporary_data():
124
+ store.clear()
125
 
126
  def get_session_history(session_id: str) -> BaseChatMessageHistory:
127
  if session_id not in store:
128
  store[session_id] = ChatMessageHistory()
129
  return store[session_id]
130
 
131
+ # Conversational RAG chain with message history
132
  conversational_rag_chain = RunnableWithMessageHistory(
133
  rag_chain,
134
  get_session_history,
 
137
  output_messages_key="answer",
138
  )
139
 
140
+ # Function to handle WebSocket connection
141
+ @socketio.on('connect')
142
+ def handle_connect():
143
+ print(f"Client connected: {request.sid}")
144
+ emit('connection_response', {'message': 'Connected successfully.'})
145
+
146
+ # Function to handle WebSocket disconnection
147
+ @socketio.on('disconnect')
148
+ def handle_disconnect():
149
+ print(f"Client disconnected: {request.sid}")
150
+ clean_temporary_data()
151
+
152
+ # Function to handle WebSocket messages
153
  @socketio.on('message')
154
  def handle_message(data):
155
  question = data.get('question')
156
+ session_id = data.get('session_id', SESSION_ID_DEFAULT)
157
  chain = conversational_rag_chain.pick("answer")
158
+
159
  try:
160
  for chunk in chain.stream(
161
  {"input": question},
162
+ config={"configurable": {"session_id": session_id}},
 
 
163
  ):
164
+ emit('response', chunk, room=request.sid)
165
+ except Exception as e:
166
+ print(f"Error during message handling: {e}")
167
+ emit('response', {"error": "An error occurred while processing your request."}, room=request.sid)
168
+
 
 
 
 
169
 
170
+ # Home route
171
  @app.route("/")
172
  def index_view():
173
  return render_template('chat.html')
174
 
175
+ # Main function to run the app
176
  if __name__ == '__main__':
177
  socketio.run(app, debug=True)
static/script.js CHANGED
@@ -3,79 +3,92 @@ const socket = io.connect(document.baseURI);
3
  const chatBox = document.getElementById('chat-box');
4
  const chatInput = document.getElementById('chat-input');
5
  const sendButton = document.getElementById('send-button');
6
- var converter = new showdown.Converter();
7
- var response="";
8
 
 
9
 
10
- function addLoader(){
11
- // loader_ele = `
12
- // <div class="dot-loader">
13
- // <div></div>
14
- // <div></div>
15
- // <div></div>
16
- // </div>
17
- // `
18
- const loader_ele = document.createElement('div');
19
- loader_ele.classList.add('dot-loader');
20
- loader_ele.innerHTML = `
21
- <div></div>
22
- <div></div>
23
- <div></div>
24
  `;
25
- chatBox.appendChild(loader_ele);
26
  }
27
 
 
28
  function appendMessage(message, sender) {
29
- if(sender == "bot"){
30
  response += message;
31
- message = converter.makeHtml(response);
32
 
33
- let loader_ele = chatBox.lastElementChild;
34
-
35
- if(!loader_ele.classList.contains("hidden")){
36
- chatBox.removeChild(loader_ele);
37
  const messageElement = document.createElement('div');
38
  messageElement.classList.add('chat-message', sender);
39
- messageElement.innerHTML = `<span>${message}</span>`;
40
  chatBox.append(messageElement);
41
  chatBox.scrollTop = chatBox.scrollHeight;
42
- }else{
43
- last_message_ele = chatBox.lastElementChild.lastChild;
44
- last_message_ele.innerHTML = message;
 
 
45
  chatBox.scrollTop = chatBox.scrollHeight;
46
  }
47
- }else{
48
  const messageElement = document.createElement('div');
49
  messageElement.classList.add('chat-message', sender);
50
  messageElement.innerHTML = `<span>${message}</span>`;
51
  chatBox.append(messageElement);
52
  chatBox.scrollTop = chatBox.scrollHeight;
53
- setTimeout(() => {
54
- addLoader()
55
- }, 500);
56
  }
57
- chatBox.scrollTop = chatBox.scrollHeight;
58
  }
59
 
 
60
  sendButton.addEventListener('click', () => {
61
  const message = chatInput.value.trim();
62
  if (message) {
63
  appendMessage(message, 'user');
64
  socket.emit('message', { question: message, session_id: 'abc123' });
65
- // setInterval(()=>{
66
- // appendMessage("This is a test message", "bot");
67
- // }, 2000)
68
  chatInput.value = '';
69
  response = "";
 
 
70
  }
71
  });
72
 
 
73
  chatInput.addEventListener('keypress', (e) => {
74
  if (e.key === 'Enter') {
75
  sendButton.click();
76
  }
77
  });
78
 
79
- socket.on('response', (response) => {
80
- appendMessage(response, 'bot');
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  });
 
3
  const chatBox = document.getElementById('chat-box');
4
  const chatInput = document.getElementById('chat-input');
5
  const sendButton = document.getElementById('send-button');
6
+ const converter = new showdown.Converter(); // If you're using showdown.js for markdown to HTML conversion
 
7
 
8
+ let response = "";
9
 
10
+ // Function to add a loader element
11
+ function addLoader() {
12
+ const loaderEle = document.createElement('div');
13
+ loaderEle.classList.add('dot-loader');
14
+ loaderEle.innerHTML = `
15
+ <div></div>
16
+ <div></div>
17
+ <div></div>
 
 
 
 
 
 
18
  `;
19
+ chatBox.appendChild(loaderEle);
20
  }
21
 
22
+ // Function to append a message to the chat box
23
  function appendMessage(message, sender) {
24
+ if (sender === 'bot') {
25
  response += message;
 
26
 
27
+ const loaderEle = chatBox.lastElementChild;
28
+
29
+ if (loaderEle && loaderEle.classList.contains('dot-loader')) {
30
+ chatBox.removeChild(loaderEle);
31
  const messageElement = document.createElement('div');
32
  messageElement.classList.add('chat-message', sender);
33
+ messageElement.innerHTML = `<span>${response}</span>`;
34
  chatBox.append(messageElement);
35
  chatBox.scrollTop = chatBox.scrollHeight;
36
+ } else {
37
+ const lastMessageEle = chatBox.lastElementChild;
38
+ if (lastMessageEle) {
39
+ lastMessageEle.innerHTML = response;
40
+ }
41
  chatBox.scrollTop = chatBox.scrollHeight;
42
  }
43
+ } else {
44
  const messageElement = document.createElement('div');
45
  messageElement.classList.add('chat-message', sender);
46
  messageElement.innerHTML = `<span>${message}</span>`;
47
  chatBox.append(messageElement);
48
  chatBox.scrollTop = chatBox.scrollHeight;
49
+
50
+ // Add a loader after a slight delay
51
+ setTimeout(addLoader, 500);
52
  }
 
53
  }
54
 
55
+ // Event listener for the send button
56
  sendButton.addEventListener('click', () => {
57
  const message = chatInput.value.trim();
58
  if (message) {
59
  appendMessage(message, 'user');
60
  socket.emit('message', { question: message, session_id: 'abc123' });
 
 
 
61
  chatInput.value = '';
62
  response = "";
63
+ } else {
64
+ console.error("Message cannot be empty.");
65
  }
66
  });
67
 
68
+ // Event listener for 'Enter' key press in the chat input
69
  chatInput.addEventListener('keypress', (e) => {
70
  if (e.key === 'Enter') {
71
  sendButton.click();
72
  }
73
  });
74
 
75
+ // Handle incoming responses from the server
76
+ socket.on('response', (data) => {
77
+ if (data && typeof data === 'string') {
78
+ appendMessage(data, 'bot');
79
+ } else {
80
+ console.error("Invalid response format received from the server.");
81
+ }
82
+ });
83
+
84
+ // Handle connection errors
85
+ socket.on('connect_error', (error) => {
86
+ console.error("Connection error:", error);
87
+ appendMessage("Sorry, there was a problem connecting to the server. Please try again later.", 'bot');
88
+ });
89
+
90
+ // Handle disconnection
91
+ socket.on('disconnect', (reason) => {
92
+ console.warn("Disconnected from server:", reason);
93
+ appendMessage("You have been disconnected from the server. Please refresh the page to reconnect.", 'bot');
94
  });