maartenbreddels commited on
Commit
1ec2294
0 Parent(s):

initial commit

Browse files
Files changed (1) hide show
  1. wanderlust.py +228 -0
wanderlust.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+
4
+ import ipyleaflet
5
+ import openai
6
+
7
+ import solara
8
+
9
+ center_default = (53.2305799, 6.5323552)
10
+ zoom_default = 2
11
+
12
+ messages_default = []
13
+
14
+ messages = solara.reactive(messages_default)
15
+ zoom_level = solara.reactive(zoom_default)
16
+ center = solara.reactive(center_default)
17
+ markers = solara.reactive([])
18
+
19
+ url = ipyleaflet.basemaps.OpenStreetMap.Mapnik.build_url()
20
+ openai.api_key = os.getenv("OPENAI_API_KEY")
21
+ model = "gpt-4-1106-preview"
22
+
23
+
24
+ function_descriptions = [
25
+ {
26
+ "type": "function",
27
+ "function": {
28
+ "name": "update_map",
29
+ "description": "Update map to center on a particular location",
30
+ "parameters": {
31
+ "type": "object",
32
+ "properties": {
33
+ "longitude": {"type": "number", "description": "Longitude of the location to center the map on"},
34
+ "latitude": {
35
+ "type": "number",
36
+ "description": "Latitude of the location to center the map on",
37
+ },
38
+ "zoom": {
39
+ "type": "integer",
40
+ "description": "Zoom level of the map",
41
+ },
42
+ },
43
+ "required": ["longitude", "latitude", "zoom"],
44
+ },
45
+ },
46
+ },
47
+ {
48
+ "type": "function",
49
+ "function": {
50
+ "name": "add_marker",
51
+ "description": "Add marker to the map",
52
+ "parameters": {
53
+ "type": "object",
54
+ "properties": {
55
+ "longitude": {"type": "number", "description": "Longitude of the location to the marker"},
56
+ "latitude": {
57
+ "type": "number",
58
+ "description": "Latitude of the location to the marker",
59
+ },
60
+ "label": {
61
+ "type": "string",
62
+ "description": "Text to display on the marker",
63
+ },
64
+ },
65
+ "required": ["longitude", "latitude", "label"],
66
+ },
67
+ },
68
+ },
69
+ ]
70
+
71
+
72
+ def update_map(longitude, latitude, zoom):
73
+ print("update_map", longitude, latitude, zoom)
74
+ center.set((latitude, longitude))
75
+ zoom_level.set(zoom)
76
+ return "Map updated"
77
+
78
+
79
+ def add_marker(longitude, latitude, label):
80
+ markers.set(markers.value + [{"location": (latitude, longitude), "label": label}])
81
+ return "Marker added"
82
+
83
+
84
+ functions = {
85
+ "update_map": update_map,
86
+ "add_marker": add_marker,
87
+ }
88
+
89
+
90
+ def ai_call(tool_call):
91
+ function = tool_call["function"]
92
+ name = function["name"]
93
+ arguments = json.loads(function["arguments"])
94
+ return_value = functions[name](**arguments)
95
+ message = {
96
+ "role": "tool",
97
+ "tool_call_id": tool_call["id"],
98
+ "name": tool_call["function"]["name"],
99
+ "content": return_value,
100
+ }
101
+ return message
102
+
103
+
104
+ @solara.component
105
+ def Map():
106
+ print("Map", zoom_level.value, center.value, markers.value)
107
+ ipyleaflet.Map.element( # type: ignore
108
+ zoom=zoom_level.value,
109
+ # on_zoom=zoom_level.set,
110
+ center=center.value,
111
+ # on_center=center.set,
112
+ scroll_wheel_zoom=True,
113
+ layers=[
114
+ ipyleaflet.TileLayer.element(url=url),
115
+ *[ipyleaflet.Marker.element(location=k["location"], draggable=False) for k in markers.value],
116
+ ],
117
+ )
118
+
119
+
120
+ @solara.component
121
+ def ChatInterface():
122
+ prompt = solara.use_reactive("")
123
+
124
+ def add_message(value: str):
125
+ if value == "":
126
+ return
127
+ messages.set(messages.value + [{"role": "user", "content": value}])
128
+ prompt.set("")
129
+
130
+ def ask():
131
+ if not messages.value:
132
+ return
133
+ last_message = messages.value[-1]
134
+ if last_message["role"] == "user" or last_message["role"] == "tool":
135
+ completion = openai.ChatCompletion.create(
136
+ model=model,
137
+ messages=messages.value,
138
+ # Add function calling
139
+ tools=function_descriptions,
140
+ tool_choice="auto",
141
+ )
142
+
143
+ output = completion.choices[0].message
144
+ print("received", output)
145
+ try:
146
+ handled_messages = handle_message(output)
147
+ messages.value = [*messages.value, output, *handled_messages]
148
+
149
+ except Exception as e:
150
+ print("errr", e)
151
+
152
+ def handle_message(message):
153
+ print("handle", message)
154
+ messages = []
155
+ if message["role"] == "assistant":
156
+ tools_calls = message.get("tool_calls", [])
157
+ for tool_call in tools_calls:
158
+ messages.append(ai_call(tool_call))
159
+ return messages
160
+
161
+ def handle_initial():
162
+ print("handle initial", messages.value)
163
+ for message in messages.value:
164
+ handle_message(message)
165
+
166
+ solara.use_effect(handle_initial, [])
167
+ result = solara.use_thread(ask, dependencies=[messages.value])
168
+ with solara.Column(style={"height": "100%"}):
169
+ with solara.Column(style={"height": "100%", "overflow-y": "auto"}, classes=["chat-interface"]):
170
+ for message in messages.value:
171
+ if message["role"] == "user":
172
+ solara.Text(message["content"], classes=["chat-message", "user-message"])
173
+ elif message["role"] == "assistant":
174
+ if message["content"]:
175
+ solara.Markdown(message["content"])
176
+ elif message["tool_calls"]:
177
+ solara.Markdown("*Calling map functions*")
178
+ else:
179
+ solara.Preformatted(repr(message), classes=["chat-message", "assistant-message"])
180
+ elif message["role"] == "tool":
181
+ pass # no need to display
182
+ else:
183
+ solara.Preformatted(repr(message), classes=["chat-message", "assistant-message"])
184
+ # solara.Text(message, classes=["chat-message"])
185
+ with solara.Column():
186
+ solara.InputText(
187
+ label="Ask your ", value=prompt, style={"flex-grow": "1"}, on_value=add_message, disabled=result.state == solara.ResultState.RUNNING
188
+ )
189
+ solara.ProgressLinear(result.state == solara.ResultState.RUNNING)
190
+ if result.state == solara.ResultState.ERROR:
191
+ solara.Error(repr(result.error))
192
+ # solara.Text("Thinking...")
193
+ # solara.Button("Send", on_click=lambda: messages.set(messages.value + [message_input.value]))
194
+
195
+
196
+ @solara.component
197
+ def Page():
198
+ reset_counter, set_reset_counter = solara.use_state(0)
199
+ print("reset", reset_counter, f"chat-{reset_counter}")
200
+
201
+ def reset_ui():
202
+ set_reset_counter(reset_counter + 1)
203
+
204
+ def save():
205
+ with open("log.json", "w") as f:
206
+ json.dump(messages.value, f)
207
+
208
+ def load():
209
+ with open("log.json", "r") as f:
210
+ messages.set(json.load(f))
211
+ reset_ui()
212
+
213
+ with solara.Column(style={"height": "100%"}):
214
+ with solara.AppBar():
215
+ solara.Button("Save", on_click=save)
216
+ solara.Button("Load", on_click=load)
217
+ solara.Button("Soft reset", on_click=reset_ui)
218
+ with solara.Columns(style={"height": "100%"}):
219
+ ChatInterface().key(f"chat-{reset_counter}")
220
+ Map() # .key(f"map-{reset_counter}")
221
+
222
+
223
+ # TODO: custom layout
224
+ # @solara.component
225
+ # def Layout(children):
226
+ # with solara.v.AppBar():
227
+ # with solara.Column(children=children):
228
+ # pass