Spaces:
Running
Running
from dotenv import load_dotenv, find_dotenv | |
from langchain.chains import LLMChain | |
import streamlit as st | |
from decouple import config | |
from langchain.llms import OpenAI | |
from langchain.document_loaders import PyPDFLoader | |
from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings | |
from langchain.vectorstores import Chroma | |
from langchain.retrievers.document_compressors import LLMChainExtractor | |
from langchain.retrievers import ContextualCompressionRetriever | |
from langchain.retrievers.self_query.base import SelfQueryRetriever | |
from langchain.chains import RetrievalQA | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from langchain.evaluation.qa import QAGenerateChain | |
from langchain.chains import RetrievalQA | |
from langchain.chat_models import ChatOpenAI | |
from langchain.document_loaders import CSVLoader | |
from langchain.indexes import VectorstoreIndexCreator | |
from langchain.vectorstores import DocArrayInMemorySearch | |
from langchain.prompts import ChatPromptTemplate | |
from langchain.document_loaders.generic import GenericLoader | |
from langchain.document_loaders.parsers import OpenAIWhisperParser | |
from langchain.document_loaders.blob_loaders.youtube_audio import YoutubeAudioLoader | |
from langchain.prompts import PromptTemplate | |
from langchain.memory import ConversationBufferMemory | |
from langchain.chains import ConversationalRetrievalChain | |
import time | |
from htmlTemplates import css, bot_template, user_template | |
from pathlib import Path | |
import pathlib | |
import platform | |
plt = platform.system() | |
if plt == 'Linux': | |
pathlib.WindowsPath = pathlib.PosixPath | |
_ = load_dotenv(find_dotenv()) # read local .env file | |
def timeit(func): | |
def wrapper(*args, **kwargs): | |
start_time = time.time() # Start time | |
result = func(*args, **kwargs) # Function execution | |
end_time = time.time() # End time | |
print( | |
f"Function {func.__name__} took {end_time - start_time} seconds to execute.") | |
return result | |
return wrapper | |
def get_llm(): | |
return OpenAI(temperature=0.1) | |
def get_memory(): | |
return ConversationBufferMemory( | |
memory_key="chat_history", | |
return_messages=True | |
) | |
def generate_response(question, vectordb, llm, memory, chat_history): | |
template = """Use the provided context to answer the user's question. | |
you are honest petroleum engineer specialist in hydraulic fracture stimulation and reservoir engineering. | |
If you don't know the answer, respond with "Sorry Sir, I do not know". | |
Context: {context} | |
Question: {question} | |
Answer: | |
""" | |
prompt = PromptTemplate( | |
template=template, | |
input_variables=[ 'question','context']) | |
qa_chain = ConversationalRetrievalChain.from_llm( | |
llm=llm, | |
retriever=vectordb.as_retriever(search_type="mmr", k=5, fetch_k=10), | |
memory=memory, | |
combine_docs_chain_kwargs={"prompt": prompt} | |
) | |
handle_userinput( | |
(qa_chain({"question": question, "chat_history": chat_history}))) | |
def create_embeding_function(): | |
# embedding_func_all_mpnet_base_v2 = SentenceTransformerEmbeddings( | |
# model_name="all-mpnet-base-v2") | |
# # embedding_func_all_MiniLM_L6_v2 = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2") | |
# embedding_func_jina_embeddings_v2_base_en = SentenceTransformerEmbeddings( | |
# model_name="jinaai/jina-embeddings-v2-base-en" | |
# ) | |
# embedding_func_jina_embeddings_v2_small_en = SentenceTransformerEmbeddings( | |
# model_name="jinaai/jina-embeddings-v2-small-en" | |
# ) | |
embedding_func_jgte_large = SentenceTransformerEmbeddings( | |
model_name="thenlper/gte-large" | |
) | |
return embedding_func_jgte_large | |
def get_vector_db(embedding_function): | |
vector_db = Chroma(persist_directory=str(Path('gte_large')), | |
embedding_function=embedding_function) | |
return vector_db | |
def handle_userinput(user_question): | |
response = user_question | |
if 'chat_history' not in st.session_state: | |
st.session_state.chat_history = [] | |
st.session_state.chat_history = response['chat_history'] | |
for i, message in enumerate(st.session_state.chat_history): | |
if i % 2 == 0: | |
st.write(user_template.replace( | |
"{{MSG}}", message.content), unsafe_allow_html=True) | |
else: | |
st.write(bot_template.replace( | |
"{{MSG}}", message.content), unsafe_allow_html=True) | |
if __name__ == "__main__": | |
st.set_page_config( | |
page_title="Hydraulic Fracture Stimulation Chat", page_icon=":books:") | |
st.write(css, unsafe_allow_html=True) | |
st.title("Hydraulic Fracture Stimulation Chat") | |
st.write( | |
"This is a chatbot that can answer questions related to petroleum engineering specially in hydraulic fracture stimulation.") | |
# get embeding function | |
embeding_function = create_embeding_function() | |
# get vector db | |
vector_db = get_vector_db(embeding_function) | |
# get llm | |
llm = get_llm() | |
# get memory | |
if 'memory' not in st.session_state: | |
st.session_state['memory'] = get_memory() | |
memory = st.session_state['memory'] | |
# chat history | |
chat_history = [] | |
prompt_question = st.chat_input("Please ask a question:") | |
if prompt_question: | |
generate_response(question=prompt_question, vectordb=vector_db, | |
llm=llm, memory=memory, chat_history=chat_history) | |