import os
import shutil
import sys
from typing import Any, Dict, List, Optional

import torch
import yaml
from dotenv import load_dotenv
from langchain.chains.base import Chain
from langchain.docstore.document import Document
from langchain.prompts import BasePromptTemplate, load_prompt
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import StrOutputParser
from langchain_core.retrievers import BaseRetriever
from transformers import AutoModelForSequenceClassification, AutoTokenizer

current_dir = os.path.dirname(os.path.abspath(__file__))  # src/ directory
kit_dir = os.path.abspath(os.path.join(current_dir, '..'))  # EKR/ directory
repo_dir = os.path.abspath(os.path.join(kit_dir, '..'))

sys.path.append(kit_dir)
sys.path.append(repo_dir)

# import streamlit as st
from utils.model_wrappers.api_gateway import APIGateway
from utils.vectordb.vector_db import VectorDb
from utils.visual.env_utils import get_wandb_key
from utils.parsing.sambaparse import parse_doc_universal

CONFIG_PATH = os.path.join(kit_dir, 'config.yaml')
PERSIST_DIRECTORY = os.path.join(kit_dir, 'data/my-vector-db')

# load_dotenv(os.path.join(kit_dir, '.env'))

# Handle the WANDB_API_KEY resolution before importing weave
# wandb_api_key = get_wandb_key()

# If WANDB_API_KEY is set, proceed with weave initialization
# if wandb_api_key:
#     import weave
#     # Initialize Weave with your project name
#     weave.init('sambanova_ekr')
# else:
#     print('WANDB_API_KEY is not set. Weave initialization skipped.')


class RetrievalQAChain(Chain):
    """Retrieval chain for question answering over retrieved documents."""

    retriever: BaseRetriever
    rerank: bool = True
    llm: BaseLanguageModel
    qa_prompt: BasePromptTemplate
    final_k_retrieved_documents: int = 3

    @property
    def input_keys(self) -> List[str]:
        """Input keys.

        :meta private:
        """
        return ['question']

    @property
    def output_keys(self) -> List[str]:
        """Output keys.

        :meta private:
        """
        return ['answer', 'source_documents']

    def _format_docs(self, docs: List[Document]) -> str:
        return '\n\n'.join(doc.page_content for doc in docs)

    def rerank_docs(self, query: str, docs: List[Document], final_k: int) -> List[Document]:
        """Re-rank retrieved documents with a cross-encoder and keep the top final_k."""
        # Lazy hardcoding for now: the reranker model and tokenizer are reloaded on every call
        tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-reranker-large')
        reranker = AutoModelForSequenceClassification.from_pretrained('BAAI/bge-reranker-large')

        # Score each (query, document) pair with the cross-encoder
        pairs = [[query, d.page_content] for d in docs]
        with torch.no_grad():
            inputs = tokenizer(
                pairs,
                padding=True,
                truncation=True,
                return_tensors='pt',
                max_length=512,
            )
            scores = reranker(**inputs, return_dict=True).logits.view(-1).float()

        # Sort documents by descending relevance score and keep the top final_k
        scores_list = scores.tolist()
        scores_sorted_idx = sorted(range(len(scores_list)), key=lambda k: scores_list[k], reverse=True)
        docs_sorted = [docs[k] for k in scores_sorted_idx]
        # docs_sorted = [docs[k] for k in scores_sorted_idx if scores_list[k] > 0]
        docs_sorted = docs_sorted[:final_k]

        return docs_sorted

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        qa_chain = self.qa_prompt | self.llm | StrOutputParser()
        response: Dict[str, Any] = {}

        # Retrieve candidate documents and optionally re-rank them
        documents = self.retriever.invoke(inputs['question'])
        if self.rerank:
            documents = self.rerank_docs(inputs['question'], documents, self.final_k_retrieved_documents)

        # Answer the question using the formatted documents as context
        docs = self._format_docs(documents)
        response['answer'] = qa_chain.invoke({'question': inputs['question'], 'context': docs})
        response['source_documents'] = documents
        return response
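

# A minimal usage sketch (illustrative only): `my_retriever`, `my_llm`, and `my_prompt`
# are placeholders for objects built elsewhere, e.g. by DocumentRetrieval below.
#
#   chain = RetrievalQAChain(
#       retriever=my_retriever,
#       llm=my_llm,
#       qa_prompt=my_prompt,
#       rerank=True,
#       final_k_retrieved_documents=3,
#   )
#   result = chain.invoke({'question': 'What does the quarterly report conclude?'})
#   print(result['answer'])
#   for doc in result['source_documents']:
#       print(doc.metadata)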


class DocumentRetrieval:
    def __init__(self, sambanova_api_key):
        self.vectordb = VectorDb()
        config_info = self.get_config_info()
        self.api_info = config_info[0]
        self.llm_info = config_info[1]
        self.embedding_model_info = config_info[2]
        self.retrieval_info = config_info[3]
        self.prompts = config_info[4]
        self.prod_mode = config_info[5]
        self.retriever = None
        self.llm = self.set_llm(sambanova_api_key)

    def get_config_info(self):
        """
        Loads the YAML config file.
        """
        # Read config file
        with open(CONFIG_PATH, 'r') as yaml_file:
            config = yaml.safe_load(yaml_file)
        api_info = config['api']
        llm_info = config['llm']
        embedding_model_info = config['embedding_model']
        retrieval_info = config['retrieval']
        prompts = config['prompts']
        prod_mode = config['prod_mode']

        return api_info, llm_info, embedding_model_info, retrieval_info, prompts, prod_mode
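
    # The config file is expected to provide the keys read above. An illustrative
    # config.yaml sketch (values are examples, not the kit's defaults) might look like:
    #
    #   api: sncloud
    #   llm:
    #     coe: true
    #     do_sample: false
    #     max_tokens_to_generate: 1200
    #     temperature: 0.0
    #     select_expert: Meta-Llama-3-70B-Instruct
    #   embedding_model:
    #     type: cpu
    #     batch_size: 1
    #     coe: false
    #     select_expert: e5-mistral-7b-instruct
    #   retrieval:
    #     rerank: false
    #     score_threshold: 0.2
    #     k_retrieved_documents: 15
    #     final_k_retrieved_documents: 5
    #   prompts:
    #     qa_prompt: prompts/qa_prompt.yaml
    #   prod_mode: false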

    def set_llm(self, sambanova_api_key):
        # if self.prod_mode:
        #     sambanova_api_key = st.session_state.SAMBANOVA_API_KEY
        # else:
        #     if 'SAMBANOVA_API_KEY' in st.session_state:
        #         sambanova_api_key = os.environ.get('SAMBANOVA_API_KEY') or st.session_state.SAMBANOVA_API_KEY
        #     else:
        #         sambanova_api_key = os.environ.get('SAMBANOVA_API_KEY')
        # sambanova_api_key = os.environ.get('SAMBANOVA_API_KEY')
        llm = APIGateway.load_llm(
            type=self.api_info,
            streaming=True,
            coe=self.llm_info['coe'],
            do_sample=self.llm_info['do_sample'],
            max_tokens_to_generate=self.llm_info['max_tokens_to_generate'],
            temperature=self.llm_info['temperature'],
            select_expert=self.llm_info['select_expert'],
            process_prompt=False,
            sambanova_api_key=sambanova_api_key,
        )
        return llm

    def parse_doc(self, docs: List, additional_metadata: Optional[Dict] = None) -> List[Document]:
        """
        Parse the uploaded documents and return a list of LangChain documents.

        Args:
            docs (List): A list of (file_name, file_object) tuples for the uploaded files.
            additional_metadata (Optional[Dict], optional): Additional metadata to include in the processed
                documents. Defaults to an empty dictionary.

        Returns:
            List[Document]: A list of LangChain documents.
        """
        if additional_metadata is None:
            additional_metadata = {}

        # Create the data/tmp folder if it doesn't exist
        temp_folder = os.path.join(kit_dir, 'data/tmp')
        if not os.path.exists(temp_folder):
            os.makedirs(temp_folder)
        else:
            # If there are already files there, delete them
            for filename in os.listdir(temp_folder):
                file_path = os.path.join(temp_folder, filename)
                try:
                    if os.path.isfile(file_path) or os.path.islink(file_path):
                        os.unlink(file_path)
                    elif os.path.isdir(file_path):
                        shutil.rmtree(file_path)
                except Exception as e:
                    print(f'Failed to delete {file_path}. Reason: {e}')

        # Save all selected files to the tmp dir with their file names
        # for doc in docs:
        #     temp_file = os.path.join(temp_folder, doc.name)
        #     with open(temp_file, 'wb') as f:
        #         f.write(doc.getvalue())
        for doc_info in docs:
            file_name, file_obj = doc_info
            temp_file = os.path.join(temp_folder, file_name)
            with open(temp_file, 'wb') as f:
                f.write(file_obj.read())

        # Pass the temp folder into parse_doc_universal for processing
        _, _, langchain_docs = parse_doc_universal(doc=temp_folder, additional_metadata=additional_metadata)
        return langchain_docs
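
    # Example call (illustrative; the file name and handle are placeholders):
    #   with open('report.pdf', 'rb') as f:
    #       langchain_docs = document_retrieval.parse_doc([('report.pdf', f)])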

    def load_embedding_model(self):
        embeddings = APIGateway.load_embedding_model(
            type=self.embedding_model_info['type'],
            batch_size=self.embedding_model_info['batch_size'],
            coe=self.embedding_model_info['coe'],
            select_expert=self.embedding_model_info['select_expert'],
        )
        return embeddings

    def create_vector_store(self, text_chunks, embeddings, output_db=None, collection_name=None):
        print(f'Collection name is {collection_name}')
        vectorstore = self.vectordb.create_vector_store(
            text_chunks, embeddings, output_db=output_db, collection_name=collection_name, db_type='chroma'
        )
        return vectorstore

    def load_vdb(self, db_path, embeddings, collection_name=None):
        print(f'Loading collection {collection_name}')
        vectorstore = self.vectordb.load_vdb(db_path, embeddings, db_type='chroma', collection_name=collection_name)
        return vectorstore

    def init_retriever(self, vectorstore):
        if self.retrieval_info['rerank']:
            # With reranking enabled, retrieve a larger candidate set; the reranker trims it to final_k later
            self.retriever = vectorstore.as_retriever(
                search_type='similarity_score_threshold',
                search_kwargs={
                    'score_threshold': self.retrieval_info['score_threshold'],
                    'k': self.retrieval_info['k_retrieved_documents'],
                },
            )
        else:
            # Without reranking, retrieve only the final number of documents
            self.retriever = vectorstore.as_retriever(
                search_type='similarity_score_threshold',
                search_kwargs={
                    'score_threshold': self.retrieval_info['score_threshold'],
                    'k': self.retrieval_info['final_k_retrieved_documents'],
                },
            )

    def get_qa_retrieval_chain(self):
        """
        Generate a qa_retrieval chain using a language model.

        This function uses a language model, specifically a SambaNova LLM, to generate a qa_retrieval chain
        based on the previously initialized retriever over the vector store of text chunks.

        Returns:
            RetrievalQAChain: A chain ready for QA without memory
        """
        # customprompt = load_prompt(os.path.join(kit_dir, self.prompts["qa_prompt"]))
        # qa_chain = customprompt | self.llm | StrOutputParser()
        # response = {}
        # documents = self.retriever.invoke(question)
        # if self.retrieval_info["rerank"]:
        #     documents = self.rerank_docs(question, documents, self.retrieval_info["final_k_retrieved_documents"])
        # docs = self._format_docs(documents)
        # response["answer"] = qa_chain.invoke({"question": question, "context": docs})
        # response["source_documents"] = documents
        retrievalQAChain = RetrievalQAChain(
            retriever=self.retriever,
            llm=self.llm,
            qa_prompt=load_prompt(os.path.join(kit_dir, self.prompts['qa_prompt'])),
            rerank=self.retrieval_info['rerank'],
            final_k_retrieved_documents=self.retrieval_info['final_k_retrieved_documents'],
        )
        return retrievalQAChain
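
    # Typical flow (illustrative; variable names are placeholders and the parsed documents
    # are assumed to be used directly as chunks):
    #   document_retrieval = DocumentRetrieval(sambanova_api_key='...')
    #   embeddings = document_retrieval.load_embedding_model()
    #   chunks = document_retrieval.parse_doc(uploaded_docs)
    #   vectorstore = document_retrieval.create_vector_store(chunks, embeddings, collection_name='demo')
    #   document_retrieval.init_retriever(vectorstore)
    #   qa_chain = document_retrieval.get_qa_retrieval_chain()
    #   result = qa_chain.invoke({'question': 'What is covered in these documents?'})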

    def get_conversational_qa_retrieval_chain(self):
        """
        Generate a conversational retrieval qa chain using a language model.

        This function uses a language model, specifically a SambaNova LLM, to generate a conversational_qa_retrieval
        chain based on the chat history and the relevant retrieved content from the input vector store of text chunks.

        Returns:
            RetrievalQA: A chain ready for QA with memory
        """