Spaces:
Sleeping
Sleeping
from MongoChainGenerator import * | |
from MongoEmbeddingGenerator import * | |
from DATABASE import * | |
from appConfig import LOG | |
class Main:
    """Process-wide registry of retrieval-QA chains keyed by collection name.

    All state lives on the class rather than on instances, so every
    ``Main()`` shares one embedding generator and one chain registry.
    Constructing an instance initialises the database, creates the shared
    embedding generator if needed, and builds a chain for every document
    stored in ``chatData.chats``.
    """

    # Collection name -> chain returned by
    # MongoChainGenerator.generate_retrieval_qa_chain(). Class-level on
    # purpose: shared by every Main instance.
    qa_chains = {}
    # Shared MongoEmbeddingGenerator, created lazily by the first Main().
    embedding_generator = None

    def __init__(self) -> None:
        # NOTE(review): DATABASE() is presumably what makes DATABASE.client
        # usable in _load_existing_qa_chains -- confirm in the DATABASE module.
        DATABASE()
        self._initialize_embedding_generator()
        self._load_existing_qa_chains()

    def _initialize_embedding_generator(self) -> None:
        """Create the shared embedding generator once; later calls are no-ops."""
        if Main.embedding_generator is None:
            Main.embedding_generator = MongoEmbeddingGenerator(
                repo_id=CONST_VAR.EMBEDDING_MODEL_REPO_ID
            )
            LOG.debug("Embedding generator initialized")

    def _load_existing_qa_chains(self) -> None:
        """Build a chain for every document in the ``chatData.chats`` collection.

        Collections that already have a chain are skipped by
        ``create_exist_chains``, so no extra membership check is needed here.
        """
        for chat in DATABASE.client["chatData"]["chats"].find():
            self.create_exist_chains(chat)

    def create_exist_chains(self, chat) -> None:
        """Create and register the chain for one stored chat document.

        Args:
            chat: Mapping read from ``chatData.chats``; must contain the keys
                ``"collectionName"`` and ``"templateContext"`` (a missing key
                raises ``KeyError``).
        """
        collection_name = chat["collectionName"]
        if collection_name in Main.qa_chains:
            LOG.debug(f"Chain already exists for collection {collection_name}")
            return

        qa_generator = MongoChainGenerator(
            embedding_model=Main.embedding_generator.embedding_model,
            db_collection_name=collection_name,
            template_context=chat["templateContext"],
        )
        Main.qa_chains[collection_name] = qa_generator.generate_retrieval_qa_chain()
        LOG.debug(f"Chain created for collection {collection_name}")

    def generate_embedding(self, content: str, file_name: str, collection_name: str):
        """Delegate to the shared generator's ``generate_embeddings`` and return its result."""
        return Main.embedding_generator.generate_embeddings(
            content, file_name, collection_name
        )

    def generate_tmp_embedding_and_chain(self, contents: str, tmp_collection_name) -> None:
        """Build a chain over ad-hoc ``contents`` and register it.

        The chain uses ``CONST_VAR.TEMPLATE_CONTEXT`` and embeddings produced
        by ``generate_tmp_embeddings``. It is stored under
        ``tmp_collection_name`` and replaces any chain already registered
        under that name.

        NOTE(review): ``contents`` is annotated ``str`` but is passed as
        ``pdf_bytes`` -- confirm the expected type with the generator.
        """
        qa_generator = MongoChainGenerator(
            embedding_model=Main.embedding_generator.embedding_model,
            template_context=CONST_VAR.TEMPLATE_CONTEXT,
            tmp_vector_embedding=Main.embedding_generator.generate_tmp_embeddings(
                pdf_bytes=contents
            ),
        )
        Main.qa_chains[tmp_collection_name] = qa_generator.generate_retrieval_qa_chain()
        # f-string instead of '+' so a non-str name cannot raise here after
        # the chain has already been registered.
        LOG.debug(f"{tmp_collection_name} chain created")

    def ask_question(self, question: str, collection_name) -> str:
        """Answer ``question`` with the chain registered for ``collection_name``.

        Returns:
            The chain's ``"result"`` value. If no chain is registered, returns
            a "not found" message. If the chain call fails, returns an error
            message containing the exception text; nothing is raised.
        """
        if collection_name not in Main.qa_chains:
            LOG.warning(f"Chain for collection '{collection_name}' not found.")
            return f"Chain for collection '{collection_name}' not found."

        try:
            LOG.debug(f"{collection_name} answering")
            # NOTE(review): early_stopping / min_length / max_tokens are passed
            # as chain *inputs* next to "query"; whether the chain forwards them
            # to the LLM is not visible here -- they may be silently ignored.
            response = Main.qa_chains[collection_name](
                {
                    "query": question,
                    "early_stopping": True,
                    "min_length": 2000,
                    "max_tokens": 5000,
                }
            )
            return response["result"]
        except Exception as e:
            # Boundary handler: report the failure to the caller as text
            # instead of propagating it.
            LOG.error(f"An error occurred while answering question: {e}")
            return f"Retry to ask question! An error occurred: {e}"

    def check_collection_name(self, collection_name) -> bool:
        """Return True if a chain is registered for ``collection_name``."""
        # Read the shared class-level registry, consistent with every other
        # method; an instance attribute must not shadow it.
        return collection_name in Main.qa_chains