wanderlust / wanderlust.py
Iisakki Rotko
feat: links on top bar
b243d3f
raw
history blame
11.6 kB
import json
import os
import time
from pathlib import Path
import ipyleaflet
from openai import NotFoundError, OpenAI
from openai.types.beta import Thread
import solara
HERE = Path(__file__).parent
center_default = (0, 0)
zoom_default = 2
messages = solara.reactive([])
zoom_level = solara.reactive(zoom_default)
center = solara.reactive(center_default)
markers = solara.reactive([])
url = ipyleaflet.basemaps.OpenStreetMap.Mapnik.build_url()
openai = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
model = "gpt-4-1106-preview"
app_style = (HERE / "style.css").read_text()
# Declare tools for openai assistant to use
tools = [
{
"type": "function",
"function": {
"name": "update_map",
"description": "Update map to center on a particular location",
"parameters": {
"type": "object",
"properties": {
"longitude": {
"type": "number",
"description": "Longitude of the location to center the map on",
},
"latitude": {
"type": "number",
"description": "Latitude of the location to center the map on",
},
"zoom": {
"type": "integer",
"description": "Zoom level of the map",
},
},
"required": ["longitude", "latitude", "zoom"],
},
},
},
{
"type": "function",
"function": {
"name": "add_marker",
"description": "Add marker to the map",
"parameters": {
"type": "object",
"properties": {
"longitude": {
"type": "number",
"description": "Longitude of the location to the marker",
},
"latitude": {
"type": "number",
"description": "Latitude of the location to the marker",
},
"label": {
"type": "string",
"description": "Text to display on the marker",
},
},
"required": ["longitude", "latitude", "label"],
},
},
},
]
def update_map(longitude, latitude, zoom):
center.set((latitude, longitude))
zoom_level.set(zoom)
return "Map updated"
def add_marker(longitude, latitude, label):
markers.set(markers.value + [{"location": (latitude, longitude), "label": label}])
return "Marker added"
functions = {
"update_map": update_map,
"add_marker": add_marker,
}
def assistant_tool_call(tool_call):
# actually executes the tool call the OpenAI assistant wants to perform
function = tool_call.function
name = function.name
arguments = json.loads(function.arguments)
return_value = functions[name](**arguments)
tool_outputs = {
"tool_call_id": tool_call.id,
"output": return_value,
}
return tool_outputs
@solara.component
def Map():
ipyleaflet.Map.element( # type: ignore
zoom=zoom_level.value,
center=center.value,
scroll_wheel_zoom=True,
layers=[
ipyleaflet.TileLayer.element(url=url),
*[
ipyleaflet.Marker.element(location=k["location"], draggable=False)
for k in markers.value
],
],
)
@solara.component
def ChatMessage(message):
with solara.Row(style={"align-items": "flex-start"}):
# Catch "messages" that are actually tool calls
if isinstance(message, dict):
icon = "mdi-map" if message["output"] == "Map updated" else "mdi-map-marker"
solara.v.Icon(children=[icon], style_="padding-top: 10px;")
solara.Markdown(message["output"])
elif message.role == "user":
solara.Text(message.content[0].text.value, style={"font-weight": "bold;"})
elif message.role == "assistant":
if message.content[0].text.value:
solara.v.Icon(
children=["mdi-compass-outline"], style_="padding-top: 10px;"
)
solara.Markdown(message.content[0].text.value)
elif message.content.tool_calls:
solara.v.Icon(children=["mdi-map"], style_="padding-top: 10px;")
solara.Markdown("*Calling map functions*")
else:
solara.v.Icon(
children=["mdi-compass-outline"], style_="padding-top: 10px;"
)
solara.Preformatted(repr(message))
else:
solara.v.Icon(children=["mdi-compass-outline"], style_="padding-top: 10px;")
solara.Preformatted(repr(message))
@solara.component
def ChatBox(children=[]):
# this uses a flexbox with column-reverse to reverse the order of the messages
# if we now also reverse the order of the messages, we get the correct order
# but the scroll position is at the bottom of the container automatically
with solara.Column(style={"flex-grow": "1"}):
solara.Style(
"""
.chat-box > :last-child{
padding-top: 7.5vh;
}
"""
)
# The height works effectively as `min-height`, since flex will grow the container to fill the available space
solara.Column(
style={
"flex-grow": "1",
"overflow-y": "auto",
"height": "100px",
"flex-direction": "column-reverse",
},
classes=["chat-box"],
children=list(reversed(children)),
)
@solara.component
def ChatInterface():
prompt = solara.use_reactive("")
run_id: solara.Reactive[str] = solara.use_reactive(None)
thread: Thread = solara.use_memo(openai.beta.threads.create, dependencies=[])
def add_message(value: str):
if value == "":
return
prompt.set("")
new_message = openai.beta.threads.messages.create(
thread_id=thread.id, content=value, role="user"
)
messages.set([*messages.value, new_message])
run_id.value = openai.beta.threads.runs.create(
thread_id=thread.id,
assistant_id="asst_RqVKAzaybZ8un7chIwPCIQdH",
tools=tools,
).id
def poll():
if not run_id.value:
return
completed = False
while not completed:
try:
run = openai.beta.threads.runs.retrieve(
run_id.value, thread_id=thread.id
)
# Above will raise NotFoundError when run creation is still in progress
except NotFoundError:
continue
if run.status == "requires_action":
tool_outputs = []
for tool_call in run.required_action.submit_tool_outputs.tool_calls:
tool_output = assistant_tool_call(tool_call)
tool_outputs.append(tool_output)
messages.set([*messages.value, tool_output])
openai.beta.threads.runs.submit_tool_outputs(
thread_id=thread.id,
run_id=run_id.value,
tool_outputs=tool_outputs,
)
if run.status == "completed":
messages.set(
[
*messages.value,
openai.beta.threads.messages.list(thread.id).data[0],
]
)
run_id.set(None)
completed = True
time.sleep(0.1)
result = solara.use_thread(poll, dependencies=[run_id.value])
# Create DOM for chat interface
with solara.Column(
classes=["chat-interface"],
):
if len(messages.value) > 0:
with ChatBox():
for message in messages.value:
ChatMessage(message)
with solara.Column():
solara.InputText(
label="Ask your question here",
value=prompt,
style={"flex-grow": "1"},
on_value=add_message,
disabled=result.state == solara.ResultState.RUNNING,
)
solara.ProgressLinear(result.state == solara.ResultState.RUNNING)
if result.state == solara.ResultState.ERROR:
solara.Error(repr(result.error))
@solara.component
def Page():
with solara.Column(
classes=["ui-container"],
gap="5vh",
):
with solara.Row(justify="space-between"):
with solara.Row(gap="10px", style={"align-items": "center"}):
solara.v.Icon(children=["mdi-compass-rose"], size="36px")
solara.HTML(
tag="h2",
unsafe_innerHTML="Wanderlust",
style={"display": "inline-block"},
)
with solara.Row(
gap="30px",
style={"align-items": "center"},
classes=["link-container"],
justify="end",
):
with solara.Row(gap="5px", style={"align-items": "center"}):
solara.Text("Source Code:", style="font-weight: bold;")
# target="_blank" links are still easiest to do via ipyvuetify
with solara.v.Btn(
icon=True,
tag="a",
attributes={
"href": "https://github.com/widgetti/wanderlust",
"title": "Wanderlust Source Code",
"target": "_blank",
},
):
solara.v.Icon(children=["mdi-github-circle"])
with solara.Row(gap="5px", style={"align-items": "center"}):
solara.Text("Powered by Solara:", style="font-weight: bold;")
with solara.v.Btn(
icon=True,
tag="a",
attributes={
"href": "https://solara.dev/",
"title": "Solara",
"target": "_blank",
},
):
solara.HTML(
tag="img",
attributes={
"src": "https://solara.dev/static/public/logo.svg",
"width": "24px",
},
)
with solara.v.Btn(
icon=True,
tag="a",
attributes={
"href": "https://github.com/widgetti/solara",
"title": "Solara Source Code",
"target": "_blank",
},
):
solara.v.Icon(children=["mdi-github-circle"])
with solara.Row(
justify="space-between", style={"flex-grow": "1"}, classes=["container-row"]
):
ChatInterface()
with solara.Column(classes=["map-container"]):
Map()
solara.Style(app_style)