Ritesh-hf committed on
Commit
043287c
1 Parent(s): ef8105a

added referenced functionality

Browse files
Files changed (4) hide show
  1. .gitignore +2 -1
  2. Dockerfile +1 -1
  3. app.py +89 -103
  4. requirements.txt +4 -6
.gitignore CHANGED
@@ -1,2 +1,3 @@
1
  .env
2
- *.ipynb
 
 
1
  .env
2
+ *.ipynb
3
+ __pycache__/*
Dockerfile CHANGED
@@ -13,4 +13,4 @@ COPY --chown=user ./requirements.txt requirements.txt
13
  RUN pip install --no-cache-dir --upgrade -r requirements.txt
14
 
15
  COPY --chown=user . /app
16
- CMD ["gunicorn", "-b", "0.0.0.0:7860", "-k", "geventwebsocket.gunicorn.workers.GeventWebSocketWorker", "-w", "1", "app:app"]
 
13
  RUN pip install --no-cache-dir --upgrade -r requirements.txt
14
 
15
  COPY --chown=user . /app
16
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
app.py CHANGED
@@ -1,18 +1,13 @@
1
- from gevent import monkey
2
- monkey.patch_all()
3
-
4
- import nltk
5
- nltk.download('punkt_tab')
6
-
7
  import nltk
8
  nltk.download('punkt_tab')
9
 
10
  import os
11
  from dotenv import load_dotenv
12
  import asyncio
13
- from flask import Flask, request, render_template
14
- from flask_cors import CORS
15
- from flask_socketio import SocketIO, emit, join_room, leave_room
 
16
  from langchain.chains import create_history_aware_retriever, create_retrieval_chain
17
  from langchain.chains.combine_documents import create_stuff_documents_chain
18
  from langchain_community.chat_message_histories import ChatMessageHistory
@@ -27,6 +22,7 @@ from langchain.retrievers import ContextualCompressionRetriever
27
  from langchain_community.chat_models import ChatPerplexity
28
  from langchain.retrievers.document_compressors import CrossEncoderReranker
29
  from langchain_community.cross_encoders import HuggingFaceCrossEncoder
 
30
 
31
  # Load environment variables
32
  load_dotenv(".env")
@@ -41,14 +37,19 @@ os.environ['USER_AGENT'] = USER_AGENT
41
  os.environ["GROQ_API_KEY"] = GROQ_API_KEY
42
  os.environ["TOKENIZERS_PARALLELISM"] = 'true'
43
 
44
- # Initialize Flask app and SocketIO with CORS
45
- app = Flask(__name__)
46
- CORS(app)
47
- socketio = SocketIO(app, async_mode='gevent', cors_allowed_origins="*")
48
- app.config['SESSION_COOKIE_SECURE'] = True # Use HTTPS
49
- app.config['SESSION_COOKIE_HTTPONLY'] = True
50
- app.config['SESSION_COOKIE_SAMESITE'] = 'Lax'
51
- app.config['SECRET_KEY'] = SECRET_KEY
 
 
 
 
 
52
 
53
  # Function to initialize Pinecone connection
54
  def initialize_pinecone(index_name: str):
@@ -59,7 +60,6 @@ def initialize_pinecone(index_name: str):
59
  print(f"Error initializing Pinecone: {e}")
60
  raise
61
 
62
-
63
  ##################################################
64
  ## Change down here
65
  ##################################################
@@ -71,8 +71,6 @@ bm25 = BM25Encoder().load("./mbzuai-policies.json")
71
  ##################################################
72
  ##################################################
73
 
74
- # old_embed_model = HuggingFaceEmbeddings(model_name="sentence-transformers/gte-multilingual-base")
75
-
76
  # Initialize models and retriever
77
  embed_model = HuggingFaceEmbeddings(model_name="jinaai/jina-embeddings-v3", model_kwargs={"trust_remote_code":True})
78
  retriever = PineconeHybridSearchRetriever(
@@ -81,20 +79,15 @@ retriever = PineconeHybridSearchRetriever(
81
  index=pinecone_index,
82
  top_k=20,
83
  alpha=0.5,
84
-
85
  )
86
 
87
  # Initialize LLM
88
- # llm = ChatGroq(model="llama-3.1-70b-versatile", temperature=0, max_tokens=1024, max_retries=2)
89
- llm = ChatPerplexity(temperature=0, pplx_api_key=GROQ_API_KEY, model="llama-3.1-sonar-large-128k-chat", max_tokens=1024, max_retries=2)
90
-
91
 
92
  # Initialize Reranker
93
- # compressor = FlashrankRerank()
94
  model = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base")
95
  compressor = CrossEncoderReranker(model=model, top_n=20)
96
 
97
-
98
  compression_retriever = ContextualCompressionRetriever(
99
  base_compressor=compressor, base_retriever=retriever
100
  )
@@ -115,33 +108,35 @@ contextualize_q_prompt = ChatPromptTemplate.from_messages(
115
  history_aware_retriever = create_history_aware_retriever(llm, compression_retriever, contextualize_q_prompt)
116
 
117
  # QA system prompt and chain
118
- qa_system_prompt = """You are a highly skilled information retrieval assistant. Use the following context to answer questions effectively. \
119
- If you don't know the answer, simply state that you don't know. \
120
- Your answer should be in {language} language. \
121
- Provide answers in proper HTML format and keep them concise. \
122
-
123
- When responding to queries, follow these guidelines: \
124
-
125
- 1. Provide Clear Answers: \
126
- - Based on the language of the question, you have to answer in that language. E.g. if the question is in English language then answer in the English language or if the question is in Arabic language then you should answer in Arabic language. /
127
- - Ensure the response directly addresses the query with accurate and relevant information.\
128
-
129
- 2. Include Detailed References: \
130
- - Links to Sources: Include URLs to credible sources where users can verify information or explore further. \
131
- - Reference Sites: Mention specific websites or platforms that offer additional information. \
132
- - Downloadable Materials: Provide links to any relevant downloadable resources if applicable. \
133
-
134
- 3. Formatting for Readability: \
135
- - The answer should be in a proper HTML format with appropriate tags. \
136
- - For arabic language response align the text to right and convert numbers also.
137
- - Double check if the language of answer is correct or not.
138
- - Use bullet points or numbered lists where applicable to present information clearly. \
139
- - Highlight key details using bold or italics. \
140
- - Provide proper and meaningful abbreviations for urls. Do not include naked urls. \
141
-
142
- 4. Organize Content Logically: \
143
- - Structure the content in a logical order, ensuring easy navigation and understanding for the user. \
144
-
 
 
145
  {context}
146
  """
147
  qa_prompt = ChatPromptTemplate.from_messages(
@@ -151,7 +146,9 @@ qa_prompt = ChatPromptTemplate.from_messages(
151
  ("human", "{input}")
152
  ]
153
  )
154
- question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
 
 
155
 
156
  # Retrieval and Generative (RAG) Chain
157
  rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
@@ -159,9 +156,6 @@ rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chai
159
  # Chat message history storage
160
  store = {}
161
 
162
- def clean_temporary_data():
163
- store.clear()
164
-
165
  def get_session_history(session_id: str) -> BaseChatMessageHistory:
166
  if session_id not in store:
167
  store[session_id] = ChatMessageHistory()
@@ -177,53 +171,45 @@ conversational_rag_chain = RunnableWithMessageHistory(
177
  output_messages_key="answer",
178
  )
179
 
180
- # Function to handle WebSocket connection
181
- @socketio.on('connect')
182
- def handle_connect():
183
- print(f"Client connected: {request.sid}")
184
- emit('connection_response', {'message': 'Connected successfully.'})
185
-
186
- # Function to handle WebSocket disconnection
187
- @socketio.on('disconnect')
188
- def handle_disconnect():
189
- print(f"Client disconnected: {request.sid}")
190
- clean_temporary_data()
191
-
192
- # Function to handle WebSocket messages
193
- @socketio.on('message')
194
- def handle_message(data):
195
- question = data.get('question')
196
- language = data.get('language')
197
- if "en" in language:
198
- language = "English"
199
- else:
200
- language = "Arabic"
201
- session_id = data.get('session_id', SESSION_ID_DEFAULT)
202
- # chain = conversational_rag_chain.pick("answer")
203
-
204
- # try:
205
- # for chunk in conversational_rag_chain.stream(
206
- # {"input": question, 'language': language},
207
- # config={"configurable": {"session_id": session_id}},
208
- # ):
209
- # emit('response', chunk, room=request.sid)
210
- # except Exception as e:
211
- # print(f"Error during message handling: {e}")
212
- # emit('response', "An error occurred while processing your request." + str(e), room=request.sid)
213
 
 
 
 
 
 
 
214
  try:
215
- response = conversational_rag_chain.invoke({"input": question, 'language': language}, config={"configurable": {"session_id": session_id}})
216
- emit('response', response, room=request.sid)
217
- except Exception as e:
218
- print(f"Error during message handling: {e}")
219
- emit('response', "An error occurred while processing your request." + str(e), room=request.sid)
220
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221
 
222
  # Home route
223
- @app.route("/")
224
- def index_view():
225
- return render_template('chat.html')
226
-
227
- # Main function to run the app
228
- if __name__ == '__main__':
229
- socketio.run(app, debug=True)
 
 
 
 
 
 
 
1
  import nltk
2
  nltk.download('punkt_tab')
3
 
4
  import os
5
  from dotenv import load_dotenv
6
  import asyncio
7
+ from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect
8
+ from fastapi.responses import HTMLResponse
9
+ from fastapi.templating import Jinja2Templates
10
+ from fastapi.middleware.cors import CORSMiddleware
11
  from langchain.chains import create_history_aware_retriever, create_retrieval_chain
12
  from langchain.chains.combine_documents import create_stuff_documents_chain
13
  from langchain_community.chat_message_histories import ChatMessageHistory
 
22
  from langchain_community.chat_models import ChatPerplexity
23
  from langchain.retrievers.document_compressors import CrossEncoderReranker
24
  from langchain_community.cross_encoders import HuggingFaceCrossEncoder
25
+ from langchain_core.prompts import PromptTemplate
26
 
27
  # Load environment variables
28
  load_dotenv(".env")
 
37
  os.environ["GROQ_API_KEY"] = GROQ_API_KEY
38
  os.environ["TOKENIZERS_PARALLELISM"] = 'true'
39
 
40
# Initialize FastAPI app and CORS.
app = FastAPI()

# Allowed origins come from the comma-separated ALLOWED_ORIGINS environment
# variable; when it is unset the app stays open to every origin ("*").
origins = [
    origin.strip()
    for origin in os.environ.get("ALLOWED_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

# A wildcard origin combined with credentials lets any website issue
# credentialed cross-origin requests, so credentials are only enabled when an
# explicit origin list has been configured.
allow_credentials = "*" not in origins

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=allow_credentials,
    allow_methods=["*"],
    allow_headers=["*"],
)
51
+
52
+ templates = Jinja2Templates(directory="templates")
53
 
54
  # Function to initialize Pinecone connection
55
  def initialize_pinecone(index_name: str):
 
60
  print(f"Error initializing Pinecone: {e}")
61
  raise
62
 
 
63
  ##################################################
64
  ## Change down here
65
  ##################################################
 
71
  ##################################################
72
  ##################################################
73
 
 
 
74
  # Initialize models and retriever
75
  embed_model = HuggingFaceEmbeddings(model_name="jinaai/jina-embeddings-v3", model_kwargs={"trust_remote_code":True})
76
  retriever = PineconeHybridSearchRetriever(
 
79
  index=pinecone_index,
80
  top_k=20,
81
  alpha=0.5,
 
82
  )
83
 
84
  # Initialize LLM
85
+ llm = ChatPerplexity(temperature=0, pplx_api_key=GROQ_API_KEY, model="llama-3.1-sonar-large-128k-chat", max_tokens=512, max_retries=2)
 
 
86
 
87
  # Initialize Reranker
 
88
  model = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base")
89
  compressor = CrossEncoderReranker(model=model, top_n=20)
90
 
 
91
  compression_retriever = ContextualCompressionRetriever(
92
  base_compressor=compressor, base_retriever=retriever
93
  )
 
108
  history_aware_retriever = create_history_aware_retriever(llm, compression_retriever, contextualize_q_prompt)
109
 
110
  # QA system prompt and chain
111
+ qa_system_prompt = """ You are a highly skilled information retrieval assistant. Use the following context to answer questions effectively.
112
+ If you don't know the answer, simply state that you don't know.
113
+ Your answer should be in {language} language.
114
+
115
+ When responding to queries, follow these guidelines:
116
+
117
+ 1. Provide Clear Answers:
118
+ - Based on the language of the question, you have to answer in that language. E.g., if the question is in English, then answer in English; if the question is in Arabic, you should answer in Arabic.
119
+ - Ensure the response directly addresses the query with accurate and relevant information.
120
+ - Do not give long answers. Provide detailed but concise responses.
121
+
122
+ 2. Formatting for Readability:
123
+ - Provide the entire response in proper markdown format.
124
+ - Use structured Maekdown elements such as headings, subheading, lists, tables, and links.
125
+ - Use emaphsis on headings, important texts and phrases.
126
+
127
+ 3. Proper Citations and References:
128
+ - ALWAYS INCLUDE SOURCES URLs where users can verify information or explore further.
129
+ - Use inline citations with embed referenced source link in the format [1], [2], etc., in the response to reference sources.
130
+ - ALWAYS PROVIDE "References" SECTION AT THE END OF RESPONSE.
131
+ - In the "References" section, list the referenced sources with their urls in the following format
132
+ 'References
133
+ [1] Heading 1[Source 1 url] \
134
+ [2] Heading 2[Source 2 url] \
135
+ [3] Heading 3[Source 2 url] \
136
+ '
137
+
138
+ FOLLOW ALL THE GIVEN INSTRUCTIONS, FAILURE TO DO SO WILL RESULT IN TERMINATION OF THE CHAT.
139
+
140
  {context}
141
  """
142
  qa_prompt = ChatPromptTemplate.from_messages(
 
146
  ("human", "{input}")
147
  ]
148
  )
149
+
150
+ document_prompt = PromptTemplate(input_variables=["page_content", "source"], template="{page_content} \n\n Source: {source}")
151
+ question_answer_chain = create_stuff_documents_chain(llm, qa_prompt, document_prompt=document_prompt)
152
 
153
  # Retrieval and Generative (RAG) Chain
154
  rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
 
156
  # Chat message history storage
157
  store = {}
158
 
 
 
 
159
  def get_session_history(session_id: str) -> BaseChatMessageHistory:
160
  if session_id not in store:
161
  store[session_id] = ChatMessageHistory()
 
171
  output_messages_key="answer",
172
  )
173
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
 
175
+ # WebSocket endpoint with streaming
176
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Serve one chat client over a WebSocket, streaming answers chunk by chunk.

    Each incoming message is expected to be a JSON object with a ``question``,
    a ``language`` code and an optional ``session_id`` (falling back to
    ``SESSION_ID_DEFAULT``). Any language value containing ``"en"`` maps to
    "English"; everything else, including a missing value, maps to "Arabic".
    Every chunk of the chain output that carries an ``"answer"`` key is sent
    back as ``{'response': ...}``.

    Malformed or incomplete messages get the generic error response and the
    connection stays open. The history stored for the most recently used
    session id is dropped when the connection ends, however it ends.
    """
    error_payload = {'response': "Something went wrong, Please try again.."}

    await websocket.accept()
    print(f"Client connected: {websocket.client}")
    session_id = None
    try:
        while True:
            try:
                data = await websocket.receive_json()
            except ValueError as e:
                # Invalid JSON text surfaces as a ValueError subclass
                # (json.JSONDecodeError); report it instead of dying.
                print(f"Error during message handling: {e}")
                await websocket.send_json(error_payload)
                continue

            # Reject payloads that are not objects or carry no question
            # before touching the chain.
            if not isinstance(data, dict) or not data.get('question'):
                print("Error during message handling: invalid message payload")
                await websocket.send_json(error_payload)
                continue

            question = data['question']
            # A missing language used to raise TypeError on the `in` test and
            # tear down the connection; coerce to a string first.
            language_code = str(data.get('language') or "")
            language = "English" if "en" in language_code else "Arabic"
            session_id = data.get('session_id', SESSION_ID_DEFAULT)

            try:
                async for chunk in conversational_rag_chain.astream(
                    {"input": question, 'language': language},
                    config={"configurable": {"session_id": session_id}},
                ):
                    # Only forward the chunks that contain answer text.
                    if "answer" in chunk:
                        await websocket.send_json({'response': chunk['answer']})
            except WebSocketDisconnect:
                # The client left mid-stream: do not try to send on the
                # closed socket, let the outer handler deal with it.
                raise
            except Exception as e:
                print(f"Error during message handling: {e}")
                await websocket.send_json(error_payload)
    except WebSocketDisconnect:
        print(f"Client disconnected: {websocket.client}")
    finally:
        # Release the chat history on every exit path, not only on a clean
        # disconnect, so failed connections do not leak stored sessions.
        if session_id:
            store.pop(session_id, None)
211
 
212
  # Home route
213
@app.get("/", response_class=HTMLResponse)
async def read_index(request: Request):
    """Render the ``chat.html`` template for the home route."""
    # The request object is handed to the template under the "request" key.
    template_context = {"request": request}
    return templates.TemplateResponse("chat.html", template_context)
 
 
 
 
requirements.txt CHANGED
@@ -5,11 +5,9 @@ langchain-huggingface
5
  pinecone
6
  pinecone-text
7
  flashrank
8
- flask
9
- flask-cors
10
- flask-socketio
11
- gunicorn
12
- gevent
13
- gevent-websocket
14
  openai
15
  einops
 
5
  pinecone
6
  pinecone-text
7
  flashrank
8
+ fastapi>=0.68.0
9
+ uvicorn[standard]>=0.15.0
10
+ websockets>=10.0
11
+ python-multipart>=0.0.5
 
 
12
  openai
13
  einops