update index and reranker
Browse files- .gitignore +2 -0
- app.py +26 -11
- mbzuai-policies.json +0 -0
.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
.env
|
2 |
+
*.ipynb
|
app.py
CHANGED
@@ -4,6 +4,9 @@ monkey.patch_all()
|
|
4 |
import nltk
|
5 |
nltk.download('punkt_tab')
|
6 |
|
|
|
|
|
|
|
7 |
import os
|
8 |
from dotenv import load_dotenv
|
9 |
import asyncio
|
@@ -20,10 +23,10 @@ from pinecone import Pinecone
|
|
20 |
from pinecone_text.sparse import BM25Encoder
|
21 |
from langchain_huggingface import HuggingFaceEmbeddings
|
22 |
from langchain_community.retrievers import PineconeHybridSearchRetriever
|
23 |
-
from langchain_groq import ChatGroq
|
24 |
from langchain.retrievers import ContextualCompressionRetriever
|
25 |
-
from langchain.retrievers.document_compressors import FlashrankRerank
|
26 |
from langchain_community.chat_models import ChatPerplexity
|
|
|
|
|
27 |
|
28 |
# Load environment variables
|
29 |
load_dotenv(".env")
|
@@ -62,7 +65,7 @@ def initialize_pinecone(index_name: str):
|
|
62 |
##################################################
|
63 |
|
64 |
# Initialize Pinecone index and BM25 encoder
|
65 |
-
pinecone_index = initialize_pinecone("updated-mbzuai-policies")
|
66 |
bm25 = BM25Encoder().load("./new_mbzuai-policies.json")
|
67 |
|
68 |
##################################################
|
@@ -77,7 +80,8 @@ retriever = PineconeHybridSearchRetriever(
|
|
77 |
sparse_encoder=bm25,
|
78 |
index=pinecone_index,
|
79 |
top_k=20,
|
80 |
-
alpha=0.5
|
|
|
81 |
)
|
82 |
|
83 |
# Initialize LLM
|
@@ -86,7 +90,11 @@ llm = ChatPerplexity(temperature=0, pplx_api_key=GROQ_API_KEY, model="llama-3.1-
|
|
86 |
|
87 |
|
88 |
# Initialize Reranker
|
89 |
-
compressor = FlashrankRerank()
|
|
|
|
|
|
|
|
|
90 |
compression_retriever = ContextualCompressionRetriever(
|
91 |
base_compressor=compressor, base_retriever=retriever
|
92 |
)
|
@@ -191,14 +199,21 @@ def handle_message(data):
|
|
191 |
else:
|
192 |
language = "Arabic"
|
193 |
session_id = data.get('session_id', SESSION_ID_DEFAULT)
|
194 |
-
chain = conversational_rag_chain.pick("answer")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
195 |
|
196 |
try:
|
197 |
-
|
198 |
-
|
199 |
-
config={"configurable": {"session_id": session_id}},
|
200 |
-
):
|
201 |
-
emit('response', chunk, room=request.sid)
|
202 |
except Exception as e:
|
203 |
print(f"Error during message handling: {e}")
|
204 |
emit('response', "An error occurred while processing your request." + str(e), room=request.sid)
|
|
|
4 |
import nltk
|
5 |
nltk.download('punkt_tab')
|
6 |
|
7 |
+
import nltk
|
8 |
+
nltk.download('punkt_tab')
|
9 |
+
|
10 |
import os
|
11 |
from dotenv import load_dotenv
|
12 |
import asyncio
|
|
|
23 |
from pinecone_text.sparse import BM25Encoder
|
24 |
from langchain_huggingface import HuggingFaceEmbeddings
|
25 |
from langchain_community.retrievers import PineconeHybridSearchRetriever
|
|
|
26 |
from langchain.retrievers import ContextualCompressionRetriever
|
|
|
27 |
from langchain_community.chat_models import ChatPerplexity
|
28 |
+
from langchain.retrievers.document_compressors import CrossEncoderReranker
|
29 |
+
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
|
30 |
|
31 |
# Load environment variables
|
32 |
load_dotenv(".env")
|
|
|
65 |
##################################################
|
66 |
|
67 |
# Initialize Pinecone index and BM25 encoder
|
68 |
+
pinecone_index = initialize_pinecone("updated-mbzuai-policies-17112024")
|
69 |
bm25 = BM25Encoder().load("./new_mbzuai-policies.json")
|
70 |
|
71 |
##################################################
|
|
|
80 |
sparse_encoder=bm25,
|
81 |
index=pinecone_index,
|
82 |
top_k=20,
|
83 |
+
alpha=0.5,
|
84 |
+
|
85 |
)
|
86 |
|
87 |
# Initialize LLM
|
|
|
90 |
|
91 |
|
92 |
# Initialize Reranker
|
93 |
+
# compressor = FlashrankRerank()
|
94 |
+
model = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base")
|
95 |
+
compressor = CrossEncoderReranker(model=model, top_n=20)
|
96 |
+
|
97 |
+
|
98 |
compression_retriever = ContextualCompressionRetriever(
|
99 |
base_compressor=compressor, base_retriever=retriever
|
100 |
)
|
|
|
199 |
else:
|
200 |
language = "Arabic"
|
201 |
session_id = data.get('session_id', SESSION_ID_DEFAULT)
|
202 |
+
# chain = conversational_rag_chain.pick("answer")
|
203 |
+
|
204 |
+
# try:
|
205 |
+
# for chunk in conversational_rag_chain.stream(
|
206 |
+
# {"input": question, 'language': language},
|
207 |
+
# config={"configurable": {"session_id": session_id}},
|
208 |
+
# ):
|
209 |
+
# emit('response', chunk, room=request.sid)
|
210 |
+
# except Exception as e:
|
211 |
+
# print(f"Error during message handling: {e}")
|
212 |
+
# emit('response', "An error occurred while processing your request." + str(e), room=request.sid)
|
213 |
|
214 |
try:
|
215 |
+
response = conversational_rag_chain.invoke({"input": question, 'language': language}, config={"configurable": {"session_id": session_id}})
|
216 |
+
emit('response', response, room=request.sid)
|
|
|
|
|
|
|
217 |
except Exception as e:
|
218 |
print(f"Error during message handling: {e}")
|
219 |
emit('response', "An error occurred while processing your request." + str(e), room=request.sid)
|
mbzuai-policies.json
CHANGED
The diff for this file is too large to render.
See raw diff
|
|