himel06 committed on
Commit
83ffd72
1 Parent(s): 0b0e325

Update BanglaRAG/bangla_rag_pipeline.py

Files changed (1)
  1. BanglaRAG/bangla_rag_pipeline.py +23 -30
BanglaRAG/bangla_rag_pipeline.py CHANGED
@@ -26,7 +26,6 @@ warnings.filterwarnings("ignore")
 class BanglaRAGChain:
     """
     Bangla Retrieval-Augmented Generation (RAG) Chain for question answering.
-
     This class uses a HuggingFace/local language model for text generation, a Chroma vector database for
     document retrieval, and a custom prompt template to create a RAG chain that can generate
     responses to user queries in Bengali.
@@ -74,7 +73,6 @@ class BanglaRAGChain:
     ):
         """
         Loads the required models and data for the RAG chain.
-
         Args:
             chat_model_id (str): The Hugging Face model ID for the chat model.
             embed_model_id (str): The Hugging Face model ID for the embedding model.
@@ -119,6 +117,7 @@ class BanglaRAGChain:
 
         rprint(Panel("[bold green]Initializing LLM...", expand=False))
         self._get_llm()
+
         rprint(Panel("[bold green]Creating chain...", expand=False))
         self._create_chain()
 
@@ -141,17 +140,14 @@ class BanglaRAGChain:
                     low_cpu_mem_usage=True,
                     quantization_config=bnb_config,
                     device_map="auto",
-                    # cache_dir=CACHE_DIR, # Removed cache_dir to use default caching
                 )
                 rprint(Panel("[bold green]Applied 4bit quantization successfully", expand=False))
-
             else:
                 self.chat_model = AutoModelForCausalLM.from_pretrained(
                     self.chat_model_id,
                     torch_dtype=torch.float16,
                     low_cpu_mem_usage=True,
                     device_map="auto",
-                    # cache_dir=CACHE_DIR, # Removed cache_dir to use default caching
                 )
                 rprint(Panel("[bold green]Chat Model loaded successfully!", expand=False))
         except Exception as e:
@@ -194,9 +190,8 @@ class BanglaRAGChain:
             )
             rprint(Panel(f"[bold green]Loaded embedding model successfully!", expand=False))
         except Exception as e:
-            rprint(Panel("f[red]embedding model loading failed: {e}", expand=False))
+            rprint(Panel(f"[red]embedding model loading failed: {e}", expand=False))
 
-
         self._db = Chroma.from_texts(texts=self._documents, embedding=embeddings)
         rprint(
             Panel("[bold green]Chroma database updated successfully!", expand=False)
@@ -207,13 +202,10 @@ class BanglaRAGChain:
     def _create_chain(self):
         """Creates the retrieval-augmented generation (RAG) chain."""
         template = """Below is an instruction in Bengali language that describes a task, paired with an input also in Bengali language that provides further context. Write a response in Bengali that appropriately completes the request.
-
 ### Instruction:
 {question}
-
 ### Input:
 {context}
-
 ### Response:
 """
         prompt_template = ChatPromptTemplate(
@@ -256,7 +248,13 @@ class BanglaRAGChain:
 
     def _get_retriever(self):
         """Creates a retriever for the vector database."""
-        self._retriever = self._db.as_retriever(search_kwargs={"k": self.k})
+        try:
+            self._retriever = self._db.as_retriever(
+                search_type="similarity", search_kwargs={"k": self.k}
+            )
+            rprint(Panel("[bold green]Retriever created successfully!", expand=False))
+        except Exception as e:
+            rprint(Panel(f"[red]Retriever creation failed: {e}", expand=False))
 
     def _get_llm(self):
         """Initializes the language model using the Hugging Face pipeline."""
@@ -280,24 +278,19 @@ class BanglaRAGChain:
             rprint(Panel("[bold green]LLM initialized successfully!", expand=False))
         except Exception as e:
             rprint(Panel(f"[red]LLM initialization failed: {e}", expand=False))
+            self._llm = None  # Ensure it’s set to None on failure
 
-    def _format_docs(self, docs):
-        """Formats the retrieved documents for the prompt."""
-        formatted_docs = "\n".join([re.sub(r"\s+", " ", doc) for doc in docs])
-        return formatted_docs
+    def __call__(self, query):
+        """Runs the RAG chain on a user query and returns the generated answer."""
+        if not self._chain:
+            raise ValueError("The chain has not been initialized.")
+        if self._chain:
+            result = self._chain.invoke({"question": query})
+            return result["answer"], result["context"]
 
-    def query(self, prompt: str) -> str:
-        """
-        Queries the RAG chain with a given prompt.
-
-        Args:
-            prompt (str): The input prompt to query the RAG chain.
-
-        Returns:
-            str: The generated response from the RAG chain.
-        """
-        return self._chain.invoke({"question": prompt})
-
-    def __call__(self, prompt: str) -> str:
-        """Alias for the query method."""
-        return self.query(prompt)
+    def _format_docs(self, docs):
+        """Formats retrieved documents into a string format."""
+        context = ""
+        for i, doc in enumerate(docs):
+            context += f"\nDocument {i + 1}:\n{doc.page_content}\n\n"
+        return context
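For context, a minimal usage sketch of the pipeline after this commit follows. It assumes the module path and the two load() arguments named in the diff above (chat_model_id, embed_model_id); the placeholder model IDs and any further load() parameters are not shown in this commit and are only illustrative.

# Minimal usage sketch (not part of this commit). Only chat_model_id and
# embed_model_id are documented in the diff above; the placeholder values
# and any additional load() arguments are assumptions.
from BanglaRAG.bangla_rag_pipeline import BanglaRAGChain

rag = BanglaRAGChain()
rag.load(
    chat_model_id="<hf-chat-model-id>",        # placeholder Hugging Face model ID
    embed_model_id="<hf-embedding-model-id>",  # placeholder embedding model ID
)

# After this commit, calling the chain returns an (answer, context) pair via
# __call__, rather than a single string from the removed query() method.
answer, context = rag("বাংলা ভাষায় একটি প্রশ্ন")
print(answer)

Note that the new __call__ expects self._chain.invoke() to return a dict with "answer" and "context" keys, so callers now receive both the generated answer and the retrieved context.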