Update BanglaRAG/bangla_rag_pipeline.py
BanglaRAG/bangla_rag_pipeline.py (+23 -30)
@@ -26,7 +26,6 @@ warnings.filterwarnings("ignore")
 class BanglaRAGChain:
     """
     Bangla Retrieval-Augmented Generation (RAG) Chain for question answering.
-
     This class uses a HuggingFace/local language model for text generation, a Chroma vector database for
     document retrieval, and a custom prompt template to create a RAG chain that can generate
     responses to user queries in Bengali.
@@ -74,7 +73,6 @@ class BanglaRAGChain:
     ):
         """
         Loads the required models and data for the RAG chain.
-
         Args:
             chat_model_id (str): The Hugging Face model ID for the chat model.
             embed_model_id (str): The Hugging Face model ID for the embedding model.
@@ -119,6 +117,7 @@ class BanglaRAGChain:

         rprint(Panel("[bold green]Initializing LLM...", expand=False))
         self._get_llm()
+
         rprint(Panel("[bold green]Creating chain...", expand=False))
         self._create_chain()

@@ -141,17 +140,14 @@ class BanglaRAGChain:
                     low_cpu_mem_usage=True,
                     quantization_config=bnb_config,
                     device_map="auto",
-                    # cache_dir=CACHE_DIR, # Removed cache_dir to use default caching
                 )
                 rprint(Panel("[bold green]Applied 4bit quantization successfully", expand=False))
-
             else:
                 self.chat_model = AutoModelForCausalLM.from_pretrained(
                     self.chat_model_id,
                     torch_dtype=torch.float16,
                     low_cpu_mem_usage=True,
                     device_map="auto",
-                    # cache_dir=CACHE_DIR, # Removed cache_dir to use default caching
                 )
                 rprint(Panel("[bold green]Chat Model loaded successfully!", expand=False))
         except Exception as e:
@@ -194,9 +190,8 @@ class BanglaRAGChain:
             )
             rprint(Panel(f"[bold green]Loaded embedding model successfully!", expand=False))
         except Exception as e:
-            rprint(Panel("
+            rprint(Panel(f"[red]embedding model loading failed: {e}", expand=False))

-
         self._db = Chroma.from_texts(texts=self._documents, embedding=embeddings)
         rprint(
             Panel("[bold green]Chroma database updated successfully!", expand=False)
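For context, `Chroma.from_texts` above rebuilds the vector store from the already-loaded document chunks. A self-contained sketch of that step; the model ID and texts are placeholders, not values from this commit:

```python
# Standalone sketch of the vector-store step (placeholder values).
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma

embeddings = HuggingFaceEmbeddings(model_name="your-bengali-embed-model")
documents = ["প্রথম নথির অংশ", "দ্বিতীয় নথির অংশ"]  # chunked Bengali text
db = Chroma.from_texts(texts=documents, embedding=embeddings)
```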
@@ -207,13 +202,10 @@ class BanglaRAGChain:
     def _create_chain(self):
         """Creates the retrieval-augmented generation (RAG) chain."""
         template = """Below is an instruction in Bengali language that describes a task, paired with an input also in Bengali language that provides further context. Write a response in Bengali that appropriately completes the request.
-
 ### Instruction:
 {question}
-
 ### Input:
 {context}
-
 ### Response:
 """
         prompt_template = ChatPromptTemplate(
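The diff cuts off before showing how `prompt_template` is wired into the chain. Since the new `__call__` below reads both `result["answer"]` and `result["context"]`, the chain presumably returns a dict; the following is one LCEL wiring that would behave that way, a sketch under that assumption rather than the commit's actual `_create_chain` body:

```python
# Sketch: assemble a chain whose invoke() output is a dict with
# "answer" and "context" keys, matching the new __call__ below.
from operator import itemgetter

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough

def build_chain(retriever, llm, format_docs, template):
    prompt = ChatPromptTemplate.from_template(template)
    retrieve = {
        "context": itemgetter("question") | retriever | format_docs,
        "question": itemgetter("question"),
    }
    # assign() keeps the retrieval keys and adds the generation as "answer"
    return retrieve | RunnablePassthrough.assign(
        answer=prompt | llm | StrOutputParser()
    )
```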
@@ -256,7 +248,13 @@ class BanglaRAGChain:

     def _get_retriever(self):
         """Creates a retriever for the vector database."""
-
+        try:
+            self._retriever = self._db.as_retriever(
+                search_type="similarity", search_kwargs={"k": self.k}
+            )
+            rprint(Panel("[bold green]Retriever created successfully!", expand=False))
+        except Exception as e:
+            rprint(Panel(f"[red]Retriever creation failed: {e}", expand=False))

     def _get_llm(self):
         """Initializes the language model using the Hugging Face pipeline."""
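Because LangChain retrievers are runnables, the object `_get_retriever` builds can be exercised directly. An illustrative check, reusing the `db` from the earlier sketch; the query string and `k` value are placeholders:

```python
# Illustrative: fetch the top-k chunks for a Bengali query and inspect them.
retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 4})
docs = retriever.invoke("বাংলাদেশের রাজধানী কোথায়?")  # list of Document objects
for doc in docs:
    print(doc.page_content[:80])
```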
@@ -280,24 +278,18 @@ class BanglaRAGChain:
             rprint(Panel("[bold green]LLM initialized successfully!", expand=False))
         except Exception as e:
             rprint(Panel(f"[red]LLM initialization failed: {e}", expand=False))
+            self._llm = None  # Ensure it's set to None on failure

-    def
-    """
-
-

-    def query(self, prompt: str) -> str:
-        """
-        Queries the RAG chain with a user prompt.
-
-        Args:
-            prompt (str): The user query in Bengali.
-
-        Returns:
-            str: The generated response from the RAG chain.
-        """
-        return self._chain.invoke({"question": prompt})
-
-    def __call__(self, prompt: str) -> str:
-        """Alias for the query method."""
-        return self.query(prompt)
+    def __call__(self, query):
+        """Runs the RAG chain on a user query and returns the generated answer."""
+        if not self._chain:
+            raise ValueError("The chain has not been initialized.")
+        result = self._chain.invoke({"question": query})
+        return result["answer"], result["context"]
+
+    def _format_docs(self, docs):
+        """Formats retrieved documents into a string format."""
+        context = ""
+        for i, doc in enumerate(docs):
+            context += f"\nDocument {i + 1}:\n{doc.page_content}\n\n"
+        return context
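Taken together, the commit replaces the old `query`/alias pair with a single `__call__` that returns both the answer and the retrieved context. A hypothetical end-to-end usage, assuming `load()` accepts the two model IDs documented earlier in the diff; its full signature is not shown here, and the model IDs are placeholders:

```python
# Hypothetical usage of the updated API; load() may take further
# arguments not visible in this diff.
from BanglaRAG.bangla_rag_pipeline import BanglaRAGChain

rag = BanglaRAGChain()
rag.load(
    chat_model_id="your-chat-model-id",        # any causal LM on the Hub
    embed_model_id="your-embedding-model-id",  # any sentence-embedding model
)
answer, context = rag("এখানে বাংলায় আপনার প্রশ্ন লিখুন")  # __call__ returns both
print(answer)
```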