File size: 11,985 Bytes
1ec2294
 
74c8dfc
50e2fca
1ec2294
 
74c8dfc
f228ce0
 
1ec2294
 
50e2fca
 
2fb381b
1ec2294
 
1d70dc9
1ec2294
 
 
 
 
f228ce0
1ec2294
50e2fca
1ec2294
 
1d70dc9
f228ce0
1ec2294
 
 
 
 
 
 
 
2fb381b
 
 
 
1ec2294
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2fb381b
 
 
 
1ec2294
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a271076
 
f228ce0
 
 
1ec2294
f228ce0
 
 
1ec2294
f228ce0
1ec2294
 
 
 
 
 
 
 
 
 
2fb381b
 
 
 
1ec2294
 
 
 
4153745
 
74c8dfc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4153745
15534c0
 
4153745
 
5e18627
 
 
 
 
74c8dfc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5e18627
 
1ec2294
 
 
f228ce0
 
9531d87
f228ce0
1ec2294
 
 
 
 
f228ce0
 
 
 
9531d87
 
 
f228ce0
 
 
 
 
 
 
 
1ec2294
f228ce0
 
1ec2294
f228ce0
 
1d70dc9
 
f228ce0
 
 
edfa8ce
f228ce0
a271076
edfa8ce
1d70dc9
edfa8ce
 
 
 
 
f228ce0
 
 
 
 
 
 
 
 
 
 
9531d87
f228ce0
1ec2294
1d70dc9
9531d87
2fb381b
5e18627
74c8dfc
5e18627
4153745
1ec2294
 
56ce50a
 
 
2fb381b
 
 
 
1ec2294
 
 
 
 
 
 
 
f228ce0
06e15bc
f228ce0
 
 
 
 
 
 
 
 
 
b243d3f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
06e15bc
 
 
1d70dc9
06e15bc
1d70dc9
2fb381b
440ee82
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
import json
import os
import time
from pathlib import Path

import ipyleaflet
from openai import NotFoundError, OpenAI
from openai.types.beta import Thread

import solara

HERE = Path(__file__).parent

# Initial map view. ``center`` holds (latitude, longitude) - see ``update_map``.
center_default = (0, 0)
zoom_default = 2

# Reactive application state read by the components below.
messages = solara.reactive([])  # chat history: thread messages and tool-output dicts
zoom_level = solara.reactive(zoom_default)
center = solara.reactive(center_default)
markers = solara.reactive([])  # dicts with "location" (lat, lon) and "label"

# Tile URL for the OpenStreetMap Mapnik basemap.
url = ipyleaflet.basemaps.OpenStreetMap.Mapnik.build_url()
# OpenAI client; the API key is read from the OPENAI_API_KEY environment variable.
openai = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
# NOTE(review): ``model`` is not referenced in the visible code (runs are
# created against a fixed assistant id) - confirm before removing.
model = "gpt-4-1106-preview"
# Stylesheet shipped next to this file. Decode explicitly as UTF-8 instead of
# relying on the platform's locale-dependent default encoding.
app_style = (HERE / "style.css").read_text(encoding="utf-8")


def _function_tool(name, description, properties):
    """Build an OpenAI "function" tool declaration.

    ``properties`` maps each parameter name to its JSON-schema description;
    every parameter is listed as required, in the mapping's insertion order.
    """
    parameters = {
        "type": "object",
        "properties": properties,
        "required": list(properties),
    }
    return {
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            "parameters": parameters,
        },
    }


# Tool declarations exposed to the OpenAI assistant; the names must match the
# handlers registered in ``functions``.
tools = [
    _function_tool(
        "update_map",
        "Update map to center on a particular location",
        {
            "longitude": {
                "type": "number",
                "description": "Longitude of the location to center the map on",
            },
            "latitude": {
                "type": "number",
                "description": "Latitude of the location to center the map on",
            },
            "zoom": {
                "type": "integer",
                "description": "Zoom level of the map",
            },
        },
    ),
    _function_tool(
        "add_marker",
        "Add marker to the map",
        {
            "longitude": {
                "type": "number",
                "description": "Longitude of the location to the marker",
            },
            "latitude": {
                "type": "number",
                "description": "Latitude of the location to the marker",
            },
            "label": {
                "type": "string",
                "description": "Text to display on the marker",
            },
        },
    ),
]


def update_map(longitude, latitude, zoom):
    """Re-center the map and change its zoom level (assistant tool handler).

    Note that ``center`` stores the pair as ``(latitude, longitude)``, the
    reverse of this function's argument order.

    Returns:
        The status string reported back to the assistant.
    """
    new_center = (latitude, longitude)
    center.set(new_center)
    zoom_level.set(zoom)
    return "Map updated"


def add_marker(longitude, latitude, label):
    """Append a labelled marker to the reactive marker list (assistant tool handler).

    A new list is assigned rather than mutating the existing one in place.

    Returns:
        The status string reported back to the assistant.
    """
    marker = {"location": (latitude, longitude), "label": label}
    markers.set([*markers.value, marker])
    return "Marker added"


# Dispatch table: tool name requested by the assistant -> local handler.
functions = dict(
    update_map=update_map,
    add_marker=add_marker,
)


def assistant_tool_call(tool_call):
    """Execute one tool call requested by the OpenAI assistant.

    Args:
        tool_call: a tool call from
            ``run.required_action.submit_tool_outputs.tool_calls``. Its
            ``function.name`` selects a handler from ``functions`` and its
            ``function.arguments`` is a JSON string of keyword arguments.

    Returns:
        A dict with ``tool_call_id`` and a string ``output``, in the shape
        passed to ``runs.submit_tool_outputs``. Sometimes the assistant names
        an unknown function, sends arguments that are not valid JSON, or sends
        arguments that do not fit the handler's signature. In those cases
        ``output`` describes the problem instead of raising. The assistant is
        told what went wrong, and the polling loop does not die while the run
        is still waiting for tool outputs.
    """
    function = tool_call.function
    name = function.name
    handler = functions.get(name)
    if handler is None:
        output = f"Error: unknown function {name!r}"
    else:
        try:
            # An empty arguments string is treated as "no arguments".
            arguments = json.loads(function.arguments or "{}")
        except json.JSONDecodeError as err:
            output = f"Error: could not parse arguments for {name}: {err}"
        else:
            try:
                output = handler(**arguments)
            except TypeError as err:
                # Missing/unexpected keyword arguments, or non-object JSON.
                output = f"Error: invalid arguments for {name}: {err}"
    # The API expects a string output.
    return {"tool_call_id": tool_call.id, "output": str(output)}


@solara.component
def Map():
    """Render the map from the reactive ``zoom_level``, ``center`` and ``markers``.

    Each marker's ``label`` (supplied by the assistant through ``add_marker``)
    is passed as the marker ``title`` so it appears as a hover tooltip.
    Previously the label was stored but never displayed.
    """
    ipyleaflet.Map.element(  # type: ignore
        zoom=zoom_level.value,
        center=center.value,
        scroll_wheel_zoom=True,
        layers=[
            ipyleaflet.TileLayer.element(url=url),
            *[
                ipyleaflet.Marker.element(
                    location=marker["location"],
                    title=marker["label"],
                    draggable=False,
                )
                for marker in markers.value
            ],
        ],
    )


def _message_text(message):
    """Return the first non-empty text part of a thread message, or "".

    ``message.content`` is a list of content parts. It may be empty or start
    with a non-text part, so indexing ``content[0].text`` directly can raise.
    """
    for part in message.content or []:
        text = getattr(part, "text", None)
        if text is not None and text.value:
            return text.value
    return ""


@solara.component
def ChatMessage(message):
    """Render one chat entry.

    Args:
        message: either a tool-output dict (``{"tool_call_id", "output"}``)
            or an OpenAI thread message with ``role`` and ``content``.
    """
    text = "" if isinstance(message, dict) else _message_text(message)
    with solara.Row(style={"align-items": "flex-start"}):
        # Tool outputs are stored in the message list as plain dicts.
        if isinstance(message, dict):
            icon = "mdi-map" if message["output"] == "Map updated" else "mdi-map-marker"
            solara.v.Icon(children=[icon], style_="padding-top: 10px;")
            solara.Markdown(message["output"])
        elif message.role == "user":
            solara.Text(text, style={"font-weight": "bold"})
        elif message.role == "assistant" and text:
            solara.v.Icon(children=["mdi-compass-outline"], style_="padding-top: 10px;")
            solara.Markdown(text)
        else:
            # Unknown role, or an assistant message without text: show it raw
            # instead of failing (``content`` is a list and has no
            # ``tool_calls`` attribute, so the old branch raised here).
            solara.v.Icon(children=["mdi-compass-outline"], style_="padding-top: 10px;")
            solara.Preformatted(repr(message))


@solara.component
def ChatBox(children=None):
    """Scrollable container that keeps the newest chat message in view.

    Args:
        children: elements to display, oldest first. Solara supplies them when
            the component is used as ``with ChatBox(): ...``. ``None`` (the
            default) means no children and avoids a shared mutable ``[]``
            default.
    """
    children = [] if children is None else children
    # this uses a flexbox with column-reverse to reverse the order of the messages
    # if we now also reverse the order of the messages, we get the correct order
    # but the scroll position is at the bottom of the container automatically
    with solara.Column(style={"flex-grow": "1"}):
        solara.Style(
            """
            .chat-box > :last-child{
                padding-top: 7.5vh;
            }
            """
        )
        # The height works effectively as `min-height`, since flex will grow the container to fill the available space
        solara.Column(
            style={
                "flex-grow": "1",
                "overflow-y": "auto",
                "height": "100px",
                "flex-direction": "column-reverse",
            },
            classes=["chat-box"],
            children=list(reversed(children)),
        )


@solara.component
def ChatInterface():
    """Chat panel that talks to the OpenAI assistant and drives the map.

    One assistant thread is created per component instance. Submitting a
    prompt posts it to that thread and starts a run. ``poll`` then runs in a
    background thread (``solara.use_thread``) until the run ends. It executes
    requested tool calls and appends the assistant's reply to ``messages``.

    The assistant id can be overridden with the ``OPENAI_ASSISTANT_ID``
    environment variable; it defaults to the id previously hard-coded here.
    """
    prompt = solara.use_reactive("")
    # Id of the run currently being polled; None while idle.
    run_id: solara.Reactive[str] = solara.use_reactive(None)

    # Create a thread to hold the conversation only once when this component is created
    thread: Thread = solara.use_memo(openai.beta.threads.create, dependencies=[])

    def add_message(value: str):
        """Post *value* as a user message and start a new assistant run."""
        # Skip empty and whitespace-only prompts instead of sending them.
        if not value or not value.strip():
            return
        prompt.set("")
        new_message = openai.beta.threads.messages.create(
            thread_id=thread.id, content=value, role="user"
        )
        messages.set([*messages.value, new_message])
        # Creating a run changes run_id.value, which re-renders this component.
        # run_id.value is a use_thread dependency, so poll() (re)starts below.
        run_id.value = openai.beta.threads.runs.create(
            thread_id=thread.id,
            assistant_id=os.getenv(
                "OPENAI_ASSISTANT_ID", "asst_RqVKAzaybZ8un7chIwPCIQdH"
            ),
            tools=tools,
        ).id

    def poll():
        """Follow the current run until it ends, servicing tool calls.

        Raises:
            RuntimeError: if the run ends without completing (failed,
                cancelled, expired or incomplete). use_thread reports this as
                an error state, which is rendered below.
        """
        if not run_id.value:
            return
        while True:
            try:
                run = openai.beta.threads.runs.retrieve(
                    run_id.value, thread_id=thread.id
                )
            # Above will raise NotFoundError when run creation is still in progress
            except NotFoundError:
                # Back off before retrying; a bare `continue` would skip the
                # sleep at the bottom of the loop and hammer the API.
                time.sleep(0.1)
                continue
            if run.status == "requires_action":
                tool_outputs = []
                for tool_call in run.required_action.submit_tool_outputs.tool_calls:
                    tool_output = assistant_tool_call(tool_call)
                    tool_outputs.append(tool_output)
                    messages.set([*messages.value, tool_output])
                openai.beta.threads.runs.submit_tool_outputs(
                    thread_id=thread.id,
                    run_id=run_id.value,
                    tool_outputs=tool_outputs,
                )
            elif run.status == "completed":
                # Append the first message returned by messages.list
                # (newest first under the API's default ordering).
                messages.set(
                    [
                        *messages.value,
                        openai.beta.threads.messages.list(thread.id).data[0],
                    ]
                )
                run_id.set(None)
                return
            elif run.status in ("failed", "cancelled", "expired", "incomplete"):
                # Without this the loop would poll a dead run forever and the
                # input would stay disabled. run_id is deliberately left set.
                # Clearing it would restart poll() and replace this error state
                # with a finished one.
                raise RuntimeError(f"Assistant run {run.status}: {run.last_error}")
            time.sleep(0.1)

    # run/restart a thread any time the run_id changes
    result = solara.use_thread(poll, dependencies=[run_id.value])
    running = result.state == solara.ResultState.RUNNING

    # Create DOM for chat interface
    with solara.Column(classes=["chat-interface"]):
        if messages.value:
            with ChatBox():
                for message in messages.value:
                    ChatMessage(message)

        with solara.Column():
            solara.InputText(
                label="Where do you want to go?"
                if not messages.value
                else "Ask more question here",
                value=prompt,
                style={"flex-grow": "1"},
                on_value=add_message,
                # Block new prompts while a run is being polled.
                disabled=running,
            )
            solara.ProgressLinear(running)
            if result.state == solara.ResultState.ERROR:
                solara.Error(repr(result.error))


@solara.component
def Page():
    """Top-level page layout.

    A header row holds the title plus source/Solara links. The next row holds
    the chat panel and the map side by side. The stylesheet in ``app_style``
    is injected last via ``solara.Style``.
    """
    with solara.Column(
        classes=["ui-container"],
        gap="5vh",
    ):
        # Header: app title on the left, external links on the right.
        with solara.Row(justify="space-between"):
            with solara.Row(gap="10px", style={"align-items": "center"}):
                solara.v.Icon(children=["mdi-compass-rose"], size="36px")
                solara.HTML(
                    tag="h2",
                    unsafe_innerHTML="Wanderlust",
                    style={"display": "inline-block"},
                )
            with solara.Row(
                gap="30px",
                style={"align-items": "center"},
                classes=["link-container"],
                justify="end",
            ):
                # Link to this app's repository.
                with solara.Row(gap="5px", style={"align-items": "center"}):
                    solara.Text("Source Code:", style="font-weight: bold;")
                    # target="_blank" links are still easiest to do via ipyvuetify
                    with solara.v.Btn(
                        icon=True,
                        tag="a",
                        attributes={
                            "href": "https://github.com/widgetti/wanderlust",
                            "title": "Wanderlust Source Code",
                            "target": "_blank",
                        },
                    ):
                        solara.v.Icon(children=["mdi-github-circle"])
                # Links to the Solara website (logo) and its repository.
                with solara.Row(gap="5px", style={"align-items": "center"}):
                    solara.Text("Powered by Solara:", style="font-weight: bold;")
                    with solara.v.Btn(
                        icon=True,
                        tag="a",
                        attributes={
                            "href": "https://solara.dev/",
                            "title": "Solara",
                            "target": "_blank",
                        },
                    ):
                        solara.HTML(
                            tag="img",
                            attributes={
                                "src": "https://solara.dev/static/public/logo.svg",
                                "width": "24px",
                            },
                        )
                    with solara.v.Btn(
                        icon=True,
                        tag="a",
                        attributes={
                            "href": "https://github.com/widgetti/solara",
                            "title": "Solara Source Code",
                            "target": "_blank",
                        },
                    ):
                        solara.v.Icon(children=["mdi-github-circle"])

        # Main content: chat panel next to the map; the row grows to fill the page.
        with solara.Row(
            justify="space-between", style={"flex-grow": "1"}, classes=["container-row"]
        ):
            ChatInterface()
            with solara.Column(classes=["map-container"]):
                Map()

        # Inject the stylesheet read from style.css.
        solara.Style(app_style)