beer-sommelier / chain.py
sooolee's picture
Initial Model
ba34941
raw
history blame
2.18 kB
"""Generate QA Chain to answer the question given matches (db)"""
import os
import openai
from langchain.chains.question_answering import load_qa_chain
from langchain.chat_models import ChatOpenAI
from langchain.llms import OpenAI
from langchain.memory import ConversationBufferMemory, ConversationBufferWindowMemory
from langchain.prompts import PromptTemplate
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
# Configure the openai client from the environment at import time.
# os.environ[...] raises KeyError if OPENAI_API_KEY is not set, so importing
# this module fails early when the key is missing.
openai.api_key = os.environ['OPENAI_API_KEY']
def get_chain():
    """Build the beer-recommendation question-answering chain.

    Assembles a prompt with ``context``, ``query`` and ``chat_history``
    placeholders, a ``ConversationBufferMemory`` and a streaming
    ``ChatOpenAI`` model (temperature 0), then combines them through
    ``load_qa_chain`` with the "stuff" chain type.

    Returns:
        The chain object produced by ``load_qa_chain``. Each call constructs
        new prompt, memory and model objects.
    """
    # The prompt body is kept flush-left: any leading whitespace inside the
    # triple-quoted literal would become part of the text sent to the model.
    prompt_text = """
You are a ChatBot at a restaurant with expertise in beers and good at making recommendations based on the user's preferences.
You are a given a context delimited by "=========", which is extracted part of the restaurant's beer list.
Customer's question is given below delimited by "```````".
Based on the context, question and chat history, please respond in a friendly, conversational manner based on the context, describing features of beers.
When asked to make recommendations, make three to four. Do not mention ABV or IBU unless asked about it. Remember you are trying to promote your beers.
If asked about something that is not related to beers in any absolute way, say you are only good at recommending beers in a witty way and redirect the conversation
by asking about customer's preferences such as beer style, flavor, etc..
If you don't find the answer to the user's question with the context provided to you below,
answer that you don't have the requested beer and ask about customer's preferences such as beer style, flavor, etc..
Finish by proposing your help for anything else.
CONTEXT:
=========
{context}
=========
QUESTION:
````````
{query}
````````
CHAT HISTORY:
{chat_history}
ANSWER:
"""

    # Placeholders in the template above that are filled in at run time.
    qa_prompt = PromptTemplate(
        template=prompt_text,
        input_variables=["chat_history", "query", "context"],
    )

    # Memory is stored under the prompt's "chat_history" variable; "query" is
    # named as the input key so the memory knows which input to record.
    history = ConversationBufferMemory(
        memory_key="chat_history",
        input_key="query",
    )

    # temperature=0 with streaming enabled on the chat model.
    llm = ChatOpenAI(temperature=0, streaming=True)

    # "stuff" chain type, wired to the prompt, memory and model built above.
    return load_qa_chain(
        llm=llm,
        chain_type="stuff",
        memory=history,
        prompt=qa_prompt,
    )