|
import gradio as gr |
|
import os |
|
|
|
from langchain import OpenAI, ConversationChain |
|
from langchain.prompts import PromptTemplate |
|
from langchain.text_splitter import CharacterTextSplitter |
|
from langchain.vectorstores import Chroma |
|
from langchain.docstore.document import Document |
|
from langchain.embeddings import HuggingFaceInstructEmbeddings |
|
from langchain.chains.conversation.memory import ConversationBufferMemory |
|
from langchain.chains import RetrievalQAWithSourcesChain |
|
|
|
from langchain.chains.conversation.memory import ConversationEntityMemory |
|
from langchain.chains.conversation.prompt import ENTITY_MEMORY_CONVERSATION_TEMPLATE |
|
|
|
from langchain import LLMChain |
|
|
|
# Directory holding the persisted Chroma index that answers are retrieved from.
persist_directory = "db"

# OpenAI completion model used to compose answers; temperature 0 keeps the
# output as deterministic as the API allows.
# NOTE(review): OpenAI has retired "text-davinci-003"; confirm this model is
# still reachable with the configured API key.
llm = OpenAI(
    model_name="text-davinci-003",
    temperature=0,
)

# Instructor embedding model and the task instructions used for documents
# and for queries (presumably the same settings used when the index in
# `persist_directory` was built — confirm).
model_name = "hkunlp/instructor-large"
embed_instruction = "Represent the text from the BMW website for retrieval"
query_instruction = "Query the most relevant text from the BMW website"
embeddings = HuggingFaceInstructEmbeddings(
    model_name=model_name,
    embed_instruction=embed_instruction,
    query_instruction=query_instruction,
)

# Open the on-disk vector store with the same embedding function.
vectordb = Chroma(
    persist_directory=persist_directory,
    embedding_function=embeddings,
)

# Retrieval QA chain that returns both an answer and its sources; `chat`
# calls it with a {"question": ...} input.
chain = RetrievalQAWithSourcesChain.from_chain_type(
    llm,
    chain_type="stuff",
    retriever=vectordb.as_retriever(),
)
|
|
|
def chat(message, history):
    """Answer *message* with the retrieval QA chain and append it to the chat.

    Args:
        message: The question typed into the input textbox.
        history: List of ``(question, answer_markdown)`` pairs kept in the
            Gradio ``State`` component, or ``None`` on the first call.

    Returns:
        The updated history twice: once for the ``Chatbot`` display and once
        to be stored back into the ``State`` component.
    """
    import logging

    history = history or []

    # Don't spend an LLM call on an empty / whitespace-only question.
    if not message or not message.strip():
        return history, history

    try:
        response = chain({"question": message}, return_only_outputs=True)
        markdown = generate_markdown(response)
    except Exception:  # UI boundary: one failed query must not break the handler
        # Log the full traceback for the operator, but show the user a
        # readable message instead of an empty reply.
        logging.getLogger(__name__).exception("Error while answering %r", message)
        markdown = (
            "Sorry, something went wrong while answering your question. "
            "Please try again."
        )

    history.append((message, markdown))
    return history, history
|
|
|
def generate_markdown(obj):
    """Render a QA-with-sources chain result as Markdown for the chat window.

    Args:
        obj: Mapping returned by the chain. An optional ``"answer"`` value is
            rendered under an **Answer:** heading; an optional ``"sources"``
            value is treated as newline-separated text and rendered as a
            numbered list under a **Sources:** heading.

    Returns:
        The Markdown string, or ``""`` when neither key is present. The
        Sources section is omitted when ``"sources"`` contains no non-blank
        lines.
    """
    sections = []

    if "answer" in obj:
        answer = str(obj["answer"]).strip()
        sections.append(f"**Answer:**\n\n{answer}\n")

    if "sources" in obj:
        # Drop blank lines so an empty "sources" value does not render as a
        # lone "1." item; splitlines() also handles "\r\n" endings.
        # NOTE(review): LangChain's QA-with-sources prompt may put several
        # comma-separated sources on one line; confirm whether those should
        # be split into separate items too.
        sources = [
            line.strip()
            for line in str(obj["sources"]).splitlines()
            if line.strip()
        ]
        if sources:
            items = "".join(
                f"{number}. {source}\n"
                for number, source in enumerate(sources, start=1)
            )
            sections.append(f"**Sources:**\n\n{items}")

    # Joining with "\n" leaves a blank line between sections; without it the
    # Markdown renderer folds "**Sources:**" into the answer's paragraph.
    return "\n".join(sections)
|
|
|
with gr.Blocks() as demo:
    gr.Markdown("<h3><center>BMW Chat Bot</center></h3>")
    gr.Markdown("<p><center>Ask questions about BMW</center></p>")

    chatbot = gr.Chatbot()

    with gr.Row():
        inp = gr.Textbox(placeholder="Question", label=None)
        btn = gr.Button("Run").style(full_width=False)

    # Conversation history passed into and returned from `chat`.
    state = gr.State()
    # NOTE(review): not wired to any event in this file; appears unused.
    agent_state = gr.State()

    # Submit the question either by clicking "Run" or by pressing Enter in
    # the textbox; both routes use the same handler, inputs and outputs.
    btn.click(chat, [inp, state], [chatbot, state])
    inp.submit(chat, [inp, state], [chatbot, state])

    gr.Examples(
        examples=[
            "What is BMW doing about sustainability?",
            "What is the future of BMW?",
        ],
        inputs=inp,
    )

if __name__ == '__main__':
    demo.launch()
|
|