# NOTE(review): the three lines below were Hugging Face Spaces UI status text
# ("Spaces: / Sleeping / Sleeping") accidentally pasted into the source file.
# Kept here as a comment so the module parses; safe to delete.
from appConfig import * | |
from langchain.chains import RetrievalQA | |
from langchain.prompts import PromptTemplate | |
from langchain.llms.huggingface_endpoint import HuggingFaceEndpoint | |
from langchain.vectorstores.mongodb_atlas import MongoDBAtlasVectorSearch | |
from langchain.vectorstores.faiss import FAISS | |
from huggingface_hub import login | |
login(token=ENV_VAR.HUGGINGFACEHUB_API_TOKEN,write_permission=True,add_to_git_credential=True) | |
class MongoChainGenerator: | |
LLM = None | |
def __init__(self, embedding_model, template_context, db_collection_name=None,tmp_vector_embedding=None): | |
if db_collection_name: | |
self._load_vectors(embedding_model, db_collection_name) | |
else: | |
self._create_tmp_retriever(tmp_vector_embedding) | |
self._initialize_prompt(template_context) | |
if MongoChainGenerator.LLM is None: | |
self._initialize_llm() | |
def _create_tmp_retriever(self, tmp_vector_embedding: FAISS): | |
self.qa_retriever = tmp_vector_embedding.as_retriever(search_type="similarity", search_kwargs={"k": 7}) | |
LOG.debug("Temporary retriever created") | |
def _load_vectors(self, embedding_model, db_collection_name): | |
self.qa_retriever = MongoDBAtlasVectorSearch.from_connection_string( | |
connection_string=ENV_VAR.MONGO_DB_URL, | |
namespace=ENV_VAR.MONGO_DB_NAME + "." + db_collection_name, | |
embedding=embedding_model, | |
).as_retriever(search_type="similarity", search_kwargs={"k": 7}) | |
LOG.debug("Retriever loaded from MongoDB Atlas") | |
def _initialize_prompt(self, template_context): | |
template = template_context + """ | |
{context} | |
Question: {question} all related details. | |
Answer:""" | |
self.prompt = PromptTemplate(template=template, input_variables=["context", "question"]) | |
LOG.debug("Prompt template initialized") | |
def _initialize_llm(self): | |
MongoChainGenerator.LLM = HuggingFaceEndpoint(repo_id=CONST_VAR.TEXT_GENERATOR_MODEL_REPO_ID, temperature=0.8, max_new_tokens=4096) | |
# MongoChainGenerator.LLM = HuggingFaceHub(repo_id=CONST_VAR.TEXT_GENERATOR_MODEL_REPO_ID, model_kwargs={"temperature": 0.85, "return_full_text": False, "max_length": 4096, "max_new_tokens": 4096}) | |
LOG.info("LLM initialized") | |
def generate_retrieval_qa_chain(self): | |
chain = RetrievalQA.from_chain_type( | |
llm=MongoChainGenerator.LLM, | |
retriever=self.qa_retriever, | |
chain_type_kwargs={"prompt": self.prompt}, | |
) | |
LOG.debug("Retrieval QA chain generated") | |
return chain | |