File size: 2,704 Bytes
b5f78f9
 
 
 
 
 
a89ac38
b5f78f9
a89ac38
b5f78f9
3a583fb
b5f78f9
 
 
 
a89ac38
b5f78f9
a89ac38
 
 
 
 
 
7672bfa
ce8cedc
b5f78f9
dd1a630
b5f78f9
 
77cefa9
b5f78f9
28ca64c
6eed3cc
e5cc6c5
dd1a630
 
e5cc6c5
b5f78f9
 
 
e5cc6c5
6eed3cc
e5cc6c5
 
 
 
 
 
 
 
 
 
 
 
b5f78f9
 
a89ac38
 
b5f78f9
 
 
 
 
 
dd1a630
aeec0ed
 
 
 
 
ff66968
aeec0ed
b5f78f9
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import gradio as gr
import os

from langchain import OpenAI, ConversationChain
from langchain.prompts import PromptTemplate
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.docstore.document import Document
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.chains.conversation.memory import ConversationBufferMemory
from langchain.chains import RetrievalQAWithSourcesChain

from langchain.chains.conversation.memory import ConversationEntityMemory
from langchain.chains.conversation.prompt import ENTITY_MEMORY_CONVERSATION_TEMPLATE

from langchain import LLMChain

# --- Retrieval QA configuration ----------------------------------------------
# Each setting can be overridden through an environment variable; the defaults
# are the values this app was originally written with.

# Directory holding the persisted Chroma collection.
persist_directory = os.environ.get("CHROMA_PERSIST_DIRECTORY", "db")

# temperature=0 keeps answers as deterministic as the LLM allows.
# NOTE(review): OpenAI has retired text-davinci-003; if the default model is
# rejected, set OPENAI_MODEL_NAME (e.g. "gpt-3.5-turbo-instruct").
llm = OpenAI(
    model_name=os.environ.get("OPENAI_MODEL_NAME", "text-davinci-003"),
    temperature=0,
)

# Instructor embeddings take separate task instructions for stored documents
# and for incoming queries.
model_name = os.environ.get("EMBEDDING_MODEL_NAME", "hkunlp/instructor-large")
embed_instruction = "Represent the text from the BMW website for retrieval"
query_instruction = "Query the most relevant text from the BMW website"
embeddings = HuggingFaceInstructEmbeddings(
    model_name=model_name,
    embed_instruction=embed_instruction,
    query_instruction=query_instruction,
)

# Open the existing on-disk collection. Nothing in this file adds documents, so
# it is presumably populated by a separate ingestion step that used the same
# embedding model/instructions -- confirm before changing model_name.
vectordb = Chroma(persist_directory=persist_directory, embedding_function=embeddings)

# "stuff" chain type: all retrieved documents are placed into a single prompt.
# The chain is called with {"question": ...} and returns "answer" and "sources".
chain = RetrievalQAWithSourcesChain.from_chain_type(
    llm, chain_type="stuff", retriever=vectordb.as_retriever()
)

def chat(message, history):
    """Answer *message* with the retrieval QA chain and append it to *history*.

    Args:
        message: The user's question from the input textbox.
        history: List of ``(question, answer_markdown)`` pairs kept in a
            ``gr.State`` component, or ``None`` on the first call.

    Returns:
        The updated history twice: once for the ``gr.Chatbot`` display and
        once to be stored back into the ``gr.State``.
    """
    history = history or []
    try:
        response = chain({"question": str(message)}, return_only_outputs=True)
        print('got response')
        markdown = generate_markdown(response)
    except Exception as e:
        # UI boundary: log the failure and show a visible message instead of an
        # empty chat bubble, so the user knows the question was not answered.
        print(f"Error: {e}")
        markdown = "**Error:** the question could not be answered. Please try again."
    history.append((message, markdown))

    return history, history

def generate_markdown(obj):
    """Render a QA-with-sources chain result as Markdown for the chat window.

    Args:
        obj: Mapping that may contain ``"answer"`` (text) and ``"sources"``
            (newline-separated source list).

    Returns:
        Markdown with an "Answer" section and, when at least one non-blank
        source is present, a numbered "Sources" section. Sections are separated
        by a blank line so Markdown renders them as separate paragraphs.
        Returns ``""`` when neither key is present.
    """
    print('generating markdown')
    sections = []

    if 'answer' in obj:
        sections.append(f"**Answer:**\n\n{obj['answer']}\n")

    # Drop blank entries so an empty sources string doesn't yield a "1. " item.
    sources = [s.strip() for s in (obj.get('sources') or "").splitlines() if s.strip()]
    if sources:
        items = "".join(f"{i}. {source}\n" for i, source in enumerate(sources, start=1))
        sections.append(f"**Sources:**\n\n{items}")

    return "\n".join(sections)

# --- Gradio UI ----------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("<h3><center>BMW Chat Bot</center></h3>")
    gr.Markdown("<p><center>Ask questions about BMW</center></p>")
    chatbot = gr.Chatbot()
    with gr.Row():
        inp = gr.Textbox(placeholder="Question", label=None)
        # NOTE(review): Component.style() is deprecated in later Gradio 3.x
        # releases and removed in 4.x; this code assumes a version that has it.
        btn = gr.Button("Run").style(full_width=False)
        # Holds the chat history (list of (question, answer_markdown) pairs)
        # that chat() reads and returns.
        state = gr.State()
        # NOTE(review): not wired to any event in this file; kept so the
        # module-level name still exists for any external users.
        agent_state = gr.State()
        # Clicking "Run" or pressing Enter in the textbox both submit the question.
        btn.click(chat, inputs=[inp, state], outputs=[chatbot, state])
        inp.submit(chat, inputs=[inp, state], outputs=[chatbot, state])
    gr.Examples(
        examples=[
            "What is BMW doing about sustainability?",
            "What is the future of BMW?",
        ],
        inputs=inp,
    )

if __name__ == '__main__':
    demo.launch()