File size: 10,204 Bytes
d1a829e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
from src import streamlit_utils
from src.prompts import AGENT_SYSTEM_PROMPT, AGENT_USER_PROMPT, RAG_USER_PROMPT, TRAVERSIALAI_USER_PROMPT
from src.retriever import Retriever

import streamlit as st

from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.memory import ChatMessageHistory
import re
import requests
import os
from qdrant_client import QdrantClient

# Name of the Qdrant collection holding the hotel records; read by get_parameters()
# and passed to the Retriever in HotelsSearchChatbot.setup_chain().
collection_name = 'hotels'

# Page header. Streamlit requires set_page_config to be the first st.* command of the
# script run, so keep it ahead of every other Streamlit call below.
st.set_page_config(page_title="Hotels search chatbot", page_icon="⭐")
st.header('Hotels search chatbot')
st.write('[![view source code and description](https://img.shields.io/badge/view_source_code-gray?logo=github)](https://github.com/Maksimov-Dmitry/traversaal-ai-hackathon)')
# NOTE(review): the "[email protected]" fragments in this string look like e-mail
# obfuscation residue from a web scrape; restore the authors' real addresses.
st.write('Developed by [Dmitry Maksimov](https://www.linkedin.com/in/maksimov-dmitry/), [email protected] and [Ilya Dudnik](https://www.linkedin.com/in/ilia-dudnik-5b8018271/), [email protected]')

# Sidebar preferences. `n_hotels` is a module-level global read by
# HotelsSearchChatbot._make_action as the retriever's `top_k`.
st.sidebar.header('Choose your preferences')
n_hotels = st.sidebar.number_input('Number of hotels', min_value=1, max_value=10, value=3)


@st.cache_resource
def get_db_client(path='data/db'):
    """Open the local Qdrant database stored under ``path``.

    Decorated with ``st.cache_resource`` so the same client object is reused
    across script reruns instead of reopening the storage on every interaction.

    Args:
        path (str): directory of the on-disk Qdrant storage.

    Returns:
        QdrantClient: client bound to ``path``.
    """
    return QdrantClient(path=path)


def add_new_info(chat_history, queries):
    """Tell the Agent about preference changes by appending them to the chat history.

    Every statement is stored as a user turn immediately followed by a fixed AI
    acknowledgement, so the next Agent call sees the new constraints in context.

    Args:
        chat_history: message history exposing ``add_user_message`` and
            ``add_ai_message``.
        queries (list): natural-language statements describing the changes
            (as produced by ``check_params``).
    """
    acknowledgement = 'Ok, got it!'
    for statement in queries:
        chat_history.add_user_message(statement)
        chat_history.add_ai_message(acknowledgement)


def check_params(params):
    """Detect which search preferences changed since the previous query.

    The last-seen preferences are kept in ``st.session_state.prev_params``. On the
    first call they are seeded with the ``'<BLANK>'`` placeholder, so every
    preference is reported as changed. After the comparison the current
    preferences become the new baseline.

    Args:
        params (dict): current preferences with keys ``'city'``, ``'price'`` and
            ``'rating'``; a falsy city/price is phrased as "any".

    Returns:
        list: one natural-language statement per changed preference, in
        city, price, rating order (empty when nothing changed).
    """
    if 'prev_params' not in st.session_state:
        st.session_state.prev_params = {'city': '<BLANK>', 'price': '<BLANK>', 'rating': '<BLANK>'}
    previous = st.session_state.prev_params

    # (key, phrasing) pairs, checked in city -> price -> rating order.
    phrasings = (
        ('city', lambda value: f'I want to find hotels in {value}' if value else 'I want to find hotels in any city'),
        ('price', lambda value: f'I want to find hotels in price range {value}' if value else 'I want to find hotels in any price range'),
        ('rating', lambda value: f'I want to find hotels with rating greater than {value}'),
    )
    statements = [phrase(params[key]) for key, phrase in phrasings if previous[key] != params[key]]

    st.session_state.prev_params = params
    return statements


def get_parameters(db_client):
    """Render the sidebar preference widgets and return the chosen values.

    The selectable cities and price ranges are collected from the payloads of every
    point in the ``collection_name`` collection. The returned preferences are used
    as metadata filters by the Retriever (the MixedRetrieval over the Qdrant vector DB).

    Args:
        db_client (QdrantClient): client used to read the hotel payloads.

    Returns:
        dict: ``city`` and ``price`` (``None`` when the user picked "Does not matter")
        and ``rating`` (float, minimum hotel rating from the slider).
    """
    any_option = 'Does not matter'
    batch_size = 1000

    # Page through the whole collection: scroll() returns at most `limit` points plus
    # the offset of the next page, which is None once the collection is exhausted.
    points, offset = [], None
    while True:
        batch, offset = db_client.scroll(
            collection_name=collection_name,
            limit=batch_size,
            offset=offset,
            with_payload=True,
            with_vectors=False,
        )
        points.extend(batch)
        if offset is None:
            break

    def options(field):
        # Sorted so the dropdown order is stable across reruns; points lacking the
        # field are skipped instead of raising KeyError.
        values = {point.payload.get(field) for point in points} - {None}
        return [any_option] + sorted(values, key=str)

    city = st.sidebar.selectbox('City', options('city'), index=0)
    if city == any_option:
        city = None

    price = st.sidebar.selectbox('Price', options('price'), index=0)
    if price == any_option:
        price = None

    rating = st.sidebar.slider('Min hotel rating', min_value=.0, max_value=5.0, value=4.5, step=.5)
    return dict(city=city, price=price, rating=rating)


class HotelsSearchChatbot:
    """
        This is the Agent class. It is responsible for the decision-making during conversation with the user.
        Based on the user's query, the Agent decides which action to take and how to present result to the user.

        The conversation memory is kept per browser session in ``st.session_state``; only the
        stateless, expensive-to-build pieces (LLM chain and retriever) are cached and shared.
    """
    # Seconds to wait for the Traversaal.ai (ARES) API before giving up.
    # NOTE(review): conservative guess; tune against the API's real latency.
    ares_timeout = 60
    # st.session_state key under which each session stores its own ChatMessageHistory.
    chat_history_key = 'agent_chat_history'

    def __init__(self, db_client):
        streamlit_utils.configure_api_keys()

        self.llm_model = "gpt-4-1106-preview"
        self.temperature = 0.6

        self.embeedings_model = "text-embedding-3-large"
        self.rerank_model = 'rerank-multilingual-v2.0'

        self.ares_api_key = os.environ.get("ARES_API_KEY")
        self.db_client = db_client

    def _traversialai(self, query):
        """Acquire information from the internet using the Traversaal.ai (ARES) API.

        Args:
            query (str): search query

        Returns:
            str | None: the API's ``response_text``, or ``None`` if the request fails,
            times out, returns an HTTP error, or the body lacks the expected fields.
        """
        url = "https://api-ares.traversaal.ai/live/predict"

        payload = {"query": [query]}
        headers = {
            "x-api-key": self.ares_api_key,
            "content-type": "application/json"
        }

        try:
            response = requests.post(url, json=payload, headers=headers, timeout=self.ares_timeout)
            response.raise_for_status()
            return response.json()['data']['response_text']
        except (requests.RequestException, ValueError, KeyError, TypeError):
            # Network/HTTP errors, undecodable JSON (ValueError) or an unexpected body
            # shape are all treated as "no information available".
            return None

    def _get_action(self, text):
        """Parse the action and the action input from the Agent's decision response.

        'action' and 'action_input' indicate whether we need to query additional tools
        (vector DB, Traversaal AI) and how. Only the first line after each label is
        captured; surrounding whitespace (including a stray ``\\r``) is stripped.

        Args:
            text (str): response of the Agent, which contains the action and the action input

        Returns:
            tuple: (action, action_input); each is ``None`` when its label is missing.
        """
        # No trailing "\n" is required, so an "Action:" line at the very end still matches.
        action_match = re.search(r"Action:\s*(.*)", text)
        action_input_match = re.search(r"Action Input:\s*(.*)", text)

        action = action_match.group(1).strip() if action_match else None
        action_input = action_input_match.group(1).strip() if action_input_match else None
        return action, action_input

    def _make_action(self, action, action_input, retriever, chain, chat_history, config, retriever_params):
        """Take the action corresponding to 'action' and 'action input'. The 'action' can be one of the following:
            'nothing' - Agent is capable of dealing on its own without use of additional tools,
            'hotels_data_base' - Agent decides to get the information from the hotels vector DB,
            'ares_api' - Agent requires additional information from the internet using the Traversaal.ai.

        Args:
            action (str): action to make
            action_input (str): action input (formulated by Agent search query)
            retriever (Retriever): Retriever object
            chain (Chain): Chain object
            chat_history (ChatMessageHistory): history of the chat
            config (dict): handlers for a LangChain invoke method
            retriever_params (dict): parameters for the Retriever

        Returns:
            str | None: the answer text, or ``None`` when the action is unknown, the
            input is missing, or the web search returned nothing.
        """
        if not action_input:
            # Nothing to display or search for; the caller shows a fallback message.
            return None

        if action == 'nothing':
            st.markdown(action_input)
            return action_input

        if action == 'hotels_data_base':
            context = retriever(action_input, top_k=n_hotels, **retriever_params)
            return self._answer_with_context(RAG_USER_PROMPT, context, action_input, chain, chat_history, config)

        if action == 'ares_api':
            context = self._traversialai(action_input)
            if context is None:
                # Answering from a missing context would only invite hallucinations.
                return None
            return self._answer_with_context(TRAVERSIALAI_USER_PROMPT, context, action_input, chain, chat_history, config)

        return None

    @staticmethod
    def _answer_with_context(template, context, query, chain, chat_history, config):
        """Ask the LLM to answer ``query`` using ``context`` and return the answer text.

        The context-laden prompt is appended to the history only for this one call and
        is always removed again, even if the LLM call raises.
        """
        chat_history.add_user_message(template.format(context=context, query=query))
        try:
            response = chain.invoke({"messages": chat_history.messages}, config)
        finally:
            chat_history.messages.pop()
        return response.content

    @st.cache_resource
    def _build_shared_components(_self):
        """Build the retriever and the prompt | LLM chain once and share them.

        ``st.cache_resource`` shares the result across reruns and sessions, so nothing
        session-specific (such as chat history) may be created here. The leading
        underscore on ``_self`` keeps Streamlit from hashing the instance.
        """
        retriever = Retriever(embedding_model=_self.embeedings_model, llm_model=_self.llm_model,
                              rerank_model=_self.rerank_model, db_client=_self.db_client, db_collection=collection_name)

        prompt = ChatPromptTemplate.from_messages(
            [
                (
                    "system",
                    AGENT_SYSTEM_PROMPT,
                ),
                MessagesPlaceholder(variable_name="messages"),
            ]
        )
        chat = ChatOpenAI(model=_self.llm_model, temperature=_self.temperature, streaming=True)
        return prompt | chat, retriever

    def setup_chain(self):
        """Return ``(chain, chat_history, retriever)`` for the current session.

        The chain and retriever are shared (cached); the chat history is created once per
        Streamlit session and stored in ``st.session_state`` so conversations of different
        users never mix.
        """
        chain, retriever = self._build_shared_components()
        if self.chat_history_key not in st.session_state:
            st.session_state[self.chat_history_key] = ChatMessageHistory()
        return chain, st.session_state[self.chat_history_key], retriever

    @streamlit_utils.enable_chat_history
    def main(self, params):
        """Handle one chat turn: decide on an action, run it, and record the exchange.

        Args:
            params (dict): sidebar preferences (city, price, rating) from get_parameters().
        """
        chain, chat_history, retriever = self.setup_chain()
        user_query = st.chat_input(placeholder="Ask me anything!")
        if user_query:
            streamlit_utils.display_msg(user_query, 'user')

            # add new info to the chat history
            queries = check_params(params)
            add_new_info(chat_history, queries)

            # get the action and the action input based on the user's query; the decision
            # prompt is temporary and is removed even if the LLM call fails
            chat_history.add_user_message(AGENT_USER_PROMPT.format(input=user_query))
            try:
                action_response = chain.invoke({"messages": chat_history.messages})
            finally:
                chat_history.messages.pop()
            action, action_input = self._get_action(action_response.content)

            with st.chat_message("assistant"):
                st_cb = streamlit_utils.StreamHandler(st.empty())

                # create response on the user's query
                response = self._make_action(action, action_input,
                                             retriever, chain, chat_history, {"callbacks": [st_cb]}, params)
                chat_history.add_user_message(user_query)
                if response is None:
                    response = 'Sorry, I cannot help you with it. Could you rephrase your question?'
                    st.markdown(response)

                chat_history.add_ai_message(response)
                st.session_state.messages.append({"role": "assistant", "content": response})


if __name__ == "__main__":
    # Sidebar filters are rendered before the chatbot is constructed.
    # NOTE(review): HotelsSearchChatbot.__init__ calls streamlit_utils.configure_api_keys(),
    # which may render widgets of its own; keep this order to preserve widget layout.
    client = get_db_client()
    user_params = get_parameters(client)
    HotelsSearchChatbot(client).main(user_params)