wanderlust / wanderlust.py
Iisakki Rotko
fix: It's pretty now
2fb381b
raw
history blame
8.61 kB
import json
import os
import ipyleaflet
import openai
import solara
# Initial map view. The centre tuple is ordered (latitude, longitude), the same
# order update_map uses when it stores a new centre.
center_default = (0, 0)
zoom_default = 2
# Initial (empty) chat transcript.
messages_default = []

# Shared reactive state read by the Map and ChatInterface components.
messages = solara.reactive(messages_default)
zoom_level = solara.reactive(zoom_default)
center = solara.reactive(center_default)
# Entries are the {"location": (latitude, longitude), "label": ...} dicts
# appended by add_marker.
markers = solara.reactive([])

# Tile URL of the OpenStreetMap Mapnik basemap, handed to the TileLayer in Map.
url = ipyleaflet.basemaps.OpenStreetMap.Mapnik.build_url()

# The API key comes from the environment; os.getenv returns None when
# OPENAI_API_KEY is not set.
openai.api_key = os.getenv("OPENAI_API_KEY")
# Model name passed to openai.ChatCompletion.create in ChatInterface.
model = "gpt-4-1106-preview"
def _number_param(description):
    """Return the schema fragment for a numeric tool parameter."""
    return {"type": "number", "description": description}


def _tool_spec(name, description, properties):
    """Wrap a parameter table in the tool envelope.

    Every parameter in ``properties`` is marked as required, in the order
    in which it appears in the table.
    """
    return {
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            "parameters": {
                "type": "object",
                "properties": properties,
                "required": list(properties),
            },
        },
    }


# Tool schemas passed as ``tools=`` to the chat completion request. The tool
# names are the keys ai_call looks up in ``functions``.
function_descriptions = [
    _tool_spec(
        "update_map",
        "Update map to center on a particular location",
        {
            "longitude": _number_param(
                "Longitude of the location to center the map on"
            ),
            "latitude": _number_param(
                "Latitude of the location to center the map on"
            ),
            "zoom": {"type": "integer", "description": "Zoom level of the map"},
        },
    ),
    _tool_spec(
        "add_marker",
        "Add marker to the map",
        {
            "longitude": _number_param("Longitude of the location to the marker"),
            "latitude": _number_param("Latitude of the location to the marker"),
            "label": {
                "type": "string",
                "description": "Text to display on the marker",
            },
        },
    ),
]
def update_map(longitude, latitude, zoom):
    """Move the shared map view to a location and zoom level.

    Stores ``(latitude, longitude)`` in ``center`` and ``zoom`` in
    ``zoom_level``, then returns the status text used as the tool result.
    """
    print("update_map", longitude, latitude, zoom)
    # The stored centre is latitude-first although the arguments arrive
    # longitude-first.
    new_center = (latitude, longitude)
    center.set(new_center)
    zoom_level.set(zoom)
    return "Map updated"
def add_marker(longitude, latitude, label):
    """Append a labelled marker to the shared marker list.

    The marker is stored as ``{"location": (latitude, longitude), "label":
    label}``. Returns the status text used as the tool result.
    """
    entry = {"location": (latitude, longitude), "label": label}
    # A new list is published; the list currently held by `markers` is left
    # unmodified.
    markers.set([*markers.value, entry])
    return "Marker added"
# Dispatch table used by ai_call: tool name -> implementation. Each handler is
# registered under its own function name, which is the name advertised to the
# model in function_descriptions.
functions = {handler.__name__: handler for handler in (update_map, add_marker)}
def ai_call(tool_call):
    """Execute one tool call requested by the model and build the reply message.

    Args:
        tool_call: Mapping with an ``"id"`` and a ``"function"`` entry holding
            the tool ``"name"`` and its ``"arguments"`` as a JSON string.

    Returns:
        A ``{"role": "tool", ...}`` message dict carrying the handler's return
        value, or an error description when the call could not be executed.
    """
    function = tool_call["function"]
    name = function["name"]
    handler = functions.get(name)
    if handler is None:
        # The model can name a tool that does not exist; answer the call with
        # an error instead of raising so the tool_call_id still gets a reply.
        content = f"Error: unknown function {name!r}"
    else:
        try:
            # Arguments are model-generated text: they may not be valid JSON
            # and may not match the handler's parameters.
            arguments = json.loads(function["arguments"])
            content = handler(**arguments)
        except (json.JSONDecodeError, TypeError) as error:
            content = f"Error: invalid arguments for {name}: {error}"
    return {
        "role": "tool",
        "tool_call_id": tool_call["id"],
        "name": name,
        "content": content,
    }
@solara.component
def Map():
    """Render the leaflet map from the shared reactive state.

    Reads ``zoom_level``, ``center`` and ``markers`` and draws the tile layer
    plus one non-draggable marker per entry in ``markers``.
    """
    print("Map", zoom_level.value, center.value, markers.value)
    ipyleaflet.Map.element(  # type: ignore
        zoom=zoom_level.value,
        # on_zoom=zoom_level.set,
        center=center.value,
        # on_center=center.set,
        scroll_wheel_zoom=True,
        layers=[
            ipyleaflet.TileLayer.element(url=url),
            *[
                ipyleaflet.Marker.element(
                    location=marker["location"],
                    # Surface the label collected by add_marker as the marker's
                    # hover text; previously it was stored but never shown.
                    title=str(marker["label"]),
                    draggable=False,
                )
                for marker in markers.value
            ],
        ],
    )
@solara.component
def ChatInterface():
    """Chat panel: shows the transcript, takes user input and talks to the model.

    User input is appended to the shared ``messages`` transcript. Whenever the
    transcript changes, a background thread asks the model for the next reply
    if the last message came from the user or from a tool, executes any tool
    calls in that reply and appends the results.
    """
    prompt = solara.use_reactive("")

    def add_message(value: str):
        """Append non-empty user input to the transcript and clear the box."""
        if value == "":
            return
        messages.set(messages.value + [{"role": "user", "content": value}])
        prompt.set("")

    def handle_message(message):
        """Run the tool calls of an assistant message; return the tool replies."""
        print("handle", message)
        # Named differently from the module-level `messages` reactive so the
        # shared transcript is not shadowed inside this helper.
        tool_messages = []
        if message["role"] == "assistant":
            # `tool_calls` may be missing or None when the model answered in text.
            for tool_call in message.get("tool_calls") or []:
                tool_messages.append(ai_call(tool_call))
        return tool_messages

    def ask():
        """Request the next model reply when the transcript is waiting for one."""
        if not messages.value:
            return
        last_message = messages.value[-1]
        if last_message["role"] not in ("user", "tool"):
            return
        completion = openai.ChatCompletion.create(
            model=model,
            messages=messages.value,
            # Add function calling
            tools=function_descriptions,
            tool_choice="auto",
        )
        output = completion.choices[0].message
        print("received", output)
        try:
            tool_messages = handle_message(output)
        except Exception as e:
            print("errr", e)
            # Re-raise so the failure reaches `result.error` and the error
            # banner below instead of being dropped silently.
            raise
        messages.value = [*messages.value, output, *tool_messages]

    def handle_initial():
        """Rebuild the map state by replaying the tool calls in the transcript."""
        print("handle initial", messages.value)
        # Replay from a clean map: without this, re-running the replay appends
        # the same markers again on top of whatever the map already holds.
        markers.set([])
        center.set(center_default)
        zoom_level.set(zoom_default)
        for message in messages.value:
            handle_message(message)

    solara.use_effect(handle_initial, [])
    result = solara.use_thread(ask, dependencies=[messages.value])
    running = result.state == solara.ResultState.RUNNING

    with solara.Column(
        style={"height": "100%", "width": "38vw", "justify-content": "center"},
        classes=["chat-interface"],
    ):
        if messages.value:
            with solara.Column(style={"flex-grow": "1", "overflow-y": "auto"}):
                for message in messages.value:
                    role = message["role"]
                    if role == "user":
                        solara.Text(
                            message["content"],
                            classes=["chat-message", "user-message"],
                        )
                    elif role == "assistant":
                        # Use .get: a message may lack either key, and a
                        # KeyError here would break rendering of the whole chat.
                        if message.get("content"):
                            solara.Markdown(message["content"])
                        elif message.get("tool_calls"):
                            solara.Markdown("*Calling map functions*")
                        else:
                            solara.Preformatted(
                                repr(message),
                                classes=["chat-message", "assistant-message"],
                            )
                    elif role == "tool":
                        pass  # no need to display
                    else:
                        solara.Preformatted(
                            repr(message),
                            classes=["chat-message", "assistant-message"],
                        )
        with solara.Column():
            solara.InputText(
                label="Ask your ",
                value=prompt,
                style={"flex-grow": "1"},
                on_value=add_message,
                # Block new input while a reply is being fetched.
                disabled=running,
            )
            solara.ProgressLinear(running)
            if result.state == solara.ResultState.ERROR:
                solara.Error(repr(result.error))
@solara.component
def Page():
    """Top-level page: app bar with Save/Load/Soft reset, chat panel and map.

    ``reset_counter`` is part of the ChatInterface key, so bumping it gives the
    chat panel a new key on the next render.
    """
    reset_counter, set_reset_counter = solara.use_state(0)
    print("reset", reset_counter, f"chat-{reset_counter}")

    def reset_ui():
        """Bump the counter used in the ChatInterface key."""
        set_reset_counter(reset_counter + 1)

    def save():
        """Write the current transcript to log.json."""
        # Serialise before opening the file: opening with "w" truncates it, so
        # a serialisation failure must not happen after that point.
        payload = json.dumps(messages.value)
        with open("log.json", "w") as f:
            f.write(payload)

    def load():
        """Replace the transcript with the contents of log.json, if present."""
        try:
            with open("log.json", "r") as f:
                loaded = json.load(f)
        except FileNotFoundError:
            # Nothing has been saved yet; leave the current chat untouched.
            print("load: log.json not found, nothing to load")
            return
        messages.set(loaded)
        reset_ui()

    with solara.Column(style={"flex-grow": "1"}, gap=0):
        with solara.AppBar():
            solara.Button("Save", on_click=save)
            solara.Button("Load", on_click=load)
            solara.Button("Soft reset", on_click=reset_ui)
        with solara.Row(style={"height": "100%"}, justify="space-between"):
            ChatInterface().key(f"chat-{reset_counter}")
            with solara.Column(style={"width": "58vw", "justify-content": "center"}):
                Map()  # .key(f"map-{reset_counter}")
        # Page-level CSS for the leaflet widget and the autorouter content box.
        solara.Style(
            """
.jupyter-widgets.leaflet-widgets{
height: 100%;
}
.solara-autorouter-content{
display: flex;
flex-direction: column;
justify-content: stretch;
}
"""
        )
# TODO: custom layout
# @solara.component
# def Layout(children):
# with solara.v.AppBar():
# with solara.Column(children=children):
# pass