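# wanderlust / wanderlust.py
# Commit shown in the file viewer: 56ce50a "feat: adaptive label for InputText"
# (author: maartenbreddels), file size 12 kB.
# (Header captured from the hosting site's file viewer; kept as comments so the file parses.)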
import json
import os
import time
from pathlib import Path
import ipyleaflet
from openai import NotFoundError, OpenAI
from openai.types.beta import Thread
import solara
# Directory containing this script; style.css is loaded from here.
HERE = Path(__file__).parent
# Initial map view: center as (latitude, longitude), plus zoom level.
center_default = (0, 0)
zoom_default = 2
# Module-level reactive state, read and written by the components and tool handlers below.
# Chat history: OpenAI thread message objects plus tool-output dicts built by assistant_tool_call.
messages = solara.reactive([])
# Current map zoom and center (latitude, longitude); written by update_map and read by Map.
zoom_level = solara.reactive(zoom_default)
center = solara.reactive(center_default)
# Markers to draw: dicts {"location": (latitude, longitude), "label": str} appended by add_marker.
markers = solara.reactive([])
# Tile URL for the OpenStreetMap (Mapnik) base layer used by Map.
url = ipyleaflet.basemaps.OpenStreetMap.Mapnik.build_url()
# OpenAI API client used by ChatInterface; the key comes from the OPENAI_API_KEY environment variable.
openai = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
# NOTE(review): `model` is not referenced in the code shown here; runs use a pre-created
# assistant (see assistant_id in ChatInterface). Confirm whether it is still needed.
model = "gpt-4-1106-preview"
# Stylesheet injected by Page via solara.Style; read once at import time.
app_style = (HERE / "style.css").read_text()
# Declare tools (OpenAI function-calling schemas) for the assistant to use.
# Each entry names a handler in the `functions` table; the assistant's JSON
# arguments are passed to that handler as keyword arguments.


def _tool_param(json_type, description):
    """Return a JSON-schema property with the given type and description."""
    return {"type": json_type, "description": description}


def _function_tool(name, description, properties):
    """Wrap a parameter schema in the OpenAI ``{"type": "function", ...}`` envelope.

    Every property is marked as required, in declaration order.
    """
    parameters = {
        "type": "object",
        "properties": properties,
        "required": list(properties),
    }
    return {
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            "parameters": parameters,
        },
    }


tools = [
    _function_tool(
        "update_map",
        "Update map to center on a particular location",
        {
            "longitude": _tool_param(
                "number", "Longitude of the location to center the map on"
            ),
            "latitude": _tool_param(
                "number", "Latitude of the location to center the map on"
            ),
            "zoom": _tool_param("integer", "Zoom level of the map"),
        },
    ),
    _function_tool(
        "add_marker",
        "Add marker to the map",
        {
            "longitude": _tool_param(
                "number", "Longitude of the location to the marker"
            ),
            "latitude": _tool_param("number", "Latitude of the location to the marker"),
            "label": _tool_param("string", "Text to display on the marker"),
        },
    ),
]
def update_map(longitude, latitude, zoom):
    """Tool handler: recenter the map and set its zoom level.

    The arguments arrive as (longitude, latitude), but ``center`` stores the
    pair as (latitude, longitude). That is the value ``Map`` passes as the
    ipyleaflet ``center``.

    Returns the confirmation text sent back to the assistant. ``ChatMessage``
    compares against this exact string to pick its icon.
    """
    center.value = (latitude, longitude)
    zoom_level.value = zoom
    return "Map updated"
def add_marker(longitude, latitude, label):
    """Tool handler: append a labelled marker to the map's marker list.

    A new list is assigned, rather than mutating the existing one, so the
    reactive variable registers the change. The location is stored as
    (latitude, longitude).
    """
    new_marker = {"location": (latitude, longitude), "label": label}
    markers.value = [*markers.value, new_marker]
    return "Marker added"
# Dispatch table: tool name (as declared in `tools`) -> Python handler.
functions = dict(
    update_map=update_map,
    add_marker=add_marker,
)
def assistant_tool_call(tool_call, registry=None):
    """Execute one tool call requested by the OpenAI assistant.

    Args:
        tool_call: A tool call from ``run.required_action.submit_tool_outputs``.
            Only ``.id``, ``.function.name`` and ``.function.arguments`` (a JSON
            string) are used.
        registry: Mapping of tool name to handler. Defaults to the module-level
            ``functions`` table.

    Returns:
        A ``{"tool_call_id": ..., "output": str}`` dict for
        ``runs.submit_tool_outputs``. Some problems do not raise; instead
        ``output`` carries an error message:
        - the tool name is unknown;
        - the arguments are not valid JSON;
        - the arguments do not match the handler's signature.
        This gives the assistant feedback and keeps the run from being left
        waiting for tool outputs.
    """
    if registry is None:
        registry = functions
    function = tool_call.function
    name = function.name
    handler = registry.get(name)
    if handler is None:
        output = f"Error: unknown function {name!r}"
    else:
        try:
            # Treat missing/empty arguments as "no arguments".
            arguments = json.loads(function.arguments or "{}")
        except json.JSONDecodeError as err:
            output = f"Error: invalid JSON arguments for {name!r}: {err}"
        else:
            try:
                output = handler(**arguments)
            except TypeError as err:
                # Unknown/missing argument names, or JSON that isn't an object.
                output = f"Error: invalid arguments for {name!r}: {err}"
    # The API expects a string output.
    return {"tool_call_id": tool_call.id, "output": str(output)}
@solara.component
def Map():
    """Render the Leaflet map from the module-level reactive state.

    ``zoom_level`` and ``center`` are bound in both directions. The assistant's
    ``update_map`` tool moves the map, and user pans/zooms are written back.
    Without the write-back, a user could pan away; a later recenter to the
    previous value would then change nothing and never reach the widget.

    Each entry of ``markers`` becomes a non-draggable marker. Its label is set
    as the marker ``title``, so it appears as a hover tooltip.
    """
    ipyleaflet.Map.element(  # type: ignore
        zoom=zoom_level.value,
        on_zoom=zoom_level.set,
        center=center.value,
        on_center=center.set,
        scroll_wheel_zoom=True,
        layers=[
            ipyleaflet.TileLayer.element(url=url),
            *[
                ipyleaflet.Marker.element(
                    location=marker["location"],
                    # Previously the stored label was never shown on the map.
                    title=str(marker["label"]),
                    draggable=False,
                )
                for marker in markers.value
            ],
        ],
    )
def _message_text(message):
    """Return the text of the first text part of an OpenAI thread message.

    Returns "" when the message has no text part (an empty ``content`` list or
    only non-text parts). Callers therefore never index into an empty list or
    touch ``.text`` on a part that has none.
    """
    for part in getattr(message, "content", None) or []:
        text = getattr(part, "text", None)
        if text is not None:
            return text.value or ""
    return ""


@solara.component
def ChatMessage(message):
    """Render one entry of the chat history.

    ``message`` is either a tool-output dict as built by ``assistant_tool_call``
    (keys ``tool_call_id`` and ``output``) or an OpenAI thread message object
    with ``role`` and ``content``. User text is shown in bold and assistant
    text as Markdown. Anything else, including assistant messages without
    text, falls back to its ``repr`` so unexpected payloads stay visible.
    """
    with solara.Row(style={"align-items": "flex-start"}):
        # Catch "messages" that are actually tool calls
        if isinstance(message, dict):
            # Relies on update_map returning exactly "Map updated"; every other
            # tool output gets the marker icon.
            icon = "mdi-map" if message["output"] == "Map updated" else "mdi-map-marker"
            solara.v.Icon(children=[icon], style_="padding-top: 10px;")
            solara.Markdown(message["output"])
            return
        role = getattr(message, "role", None)
        if role == "user":
            solara.Text(_message_text(message), style={"font-weight": "bold"})
            return
        solara.v.Icon(children=["mdi-compass-outline"], style_="padding-top: 10px;")
        text = _message_text(message) if role == "assistant" else ""
        if text:
            solara.Markdown(text)
        else:
            # No text to show (unknown role or non-text content): render the raw
            # object for debugging instead of failing.
            solara.Preformatted(repr(message))
@solara.component
def ChatBox(children=[]):
    """Scrollable chat container that keeps the newest message in view.

    The inner column uses ``flex-direction: column-reverse``, which keeps the
    scroll position anchored at the bottom. That also flips the visual order,
    so the children are passed in reverse. The net effect is that messages
    still read top-to-bottom in chronological order.
    """
    # The fixed height effectively acts as a minimum: flex-grow stretches the
    # container to fill the available space.
    scroll_area_style = {
        "flex-grow": "1",
        "overflow-y": "auto",
        "height": "100px",
        "flex-direction": "column-reverse",
    }
    newest_first = [*reversed(children)]
    with solara.Column(style={"flex-grow": "1"}):
        # The last DOM child is the oldest message, shown at the visual top
        # because of column-reverse. Presumably the padding leaves room above it.
        solara.Style(
            """
            .chat-box > :last-child{
                padding-top: 7.5vh;
            }
            """
        )
        solara.Column(
            style=scroll_area_style,
            classes=["chat-box"],
            children=newest_first,
        )
@solara.component
def ChatInterface():
    """Chat panel that talks to the OpenAI assistant and renders the conversation.

    Each submitted prompt is added to a conversation thread, created once per
    component instance, and starts a new assistant run. ``poll`` runs in a
    background thread via ``solara.use_thread``. It executes any tool calls
    the assistant requests (see ``assistant_tool_call``). When the run
    completes, it appends the newest thread message to the module-level
    ``messages``.

    A run that ends failed, cancelled, expired or incomplete raises
    ``RuntimeError``. That error is shown below the input instead of polling
    forever.
    """
    prompt = solara.use_reactive("")
    # ID of the run being polled; None while no run is in progress.
    run_id: solara.Reactive[str | None] = solara.use_reactive(None)
    # Create a thread to hold the conversation only once when this component is created
    thread: Thread = solara.use_memo(openai.beta.threads.create, dependencies=[])

    def add_message(value: str):
        # Ignore empty or whitespace-only submissions.
        if not value or not value.strip():
            return
        prompt.set("")
        new_message = openai.beta.threads.messages.create(
            thread_id=thread.id, content=value, role="user"
        )
        messages.set([*messages.value, new_message])
        # Create a new run for the thread. Setting run_id also triggers a
        # rerender. run_id.value is a use_thread dependency, so the change
        # (re)starts the poll function below in a background thread.
        run_id.value = openai.beta.threads.runs.create(
            thread_id=thread.id,
            assistant_id="asst_RqVKAzaybZ8un7chIwPCIQdH",
            tools=tools,
        ).id

    def poll():
        poll_interval = 0.1  # seconds between status checks
        current_run_id = run_id.value
        if not current_run_id:
            return
        while True:
            try:
                run = openai.beta.threads.runs.retrieve(
                    current_run_id, thread_id=thread.id
                )
            except NotFoundError:
                # Raised while run creation is still in progress. Wait before
                # retrying instead of spinning on the API in a tight loop.
                time.sleep(poll_interval)
                continue
            if run.status == "requires_action":
                tool_outputs = []
                for tool_call in run.required_action.submit_tool_outputs.tool_calls:
                    tool_output = assistant_tool_call(tool_call)
                    tool_outputs.append(tool_output)
                    messages.set([*messages.value, tool_output])
                openai.beta.threads.runs.submit_tool_outputs(
                    thread_id=thread.id,
                    run_id=current_run_id,
                    tool_outputs=tool_outputs,
                )
            elif run.status == "completed":
                messages.set(
                    [
                        *messages.value,
                        openai.beta.threads.messages.list(thread.id).data[0],
                    ]
                )
                run_id.set(None)
                return
            elif run.status in ("cancelled", "failed", "expired", "incomplete"):
                # Terminal states other than "completed" never change again.
                # Stop polling and surface the problem; use_thread reports the
                # exception via result.error, rendered below.
                last_error = getattr(run, "last_error", None)
                detail = f": {last_error.message}" if last_error else ""
                raise RuntimeError(f"Assistant run {run.status}{detail}")
            time.sleep(poll_interval)

    # run/restart the poll thread any time the run_id changes
    result = solara.use_thread(poll, dependencies=[run_id.value])
    running = result.state == solara.ResultState.RUNNING

    # Create DOM for chat interface
    with solara.Column(classes=["chat-interface"]):
        if messages.value:
            with ChatBox():
                for message in messages.value:
                    ChatMessage(message)
        with solara.Column():
            solara.InputText(
                label="Where do you want to go?"
                if len(messages.value) == 0
                else "Ask more question here",
                value=prompt,
                style={"flex-grow": "1"},
                on_value=add_message,
                disabled=running,
            )
            solara.ProgressLinear(running)
            if result.state == solara.ResultState.ERROR:
                solara.Error(repr(result.error))
def _external_link(href, title):
    """Create an icon button that opens ``href`` in a new browser tab.

    Use the returned element as a context manager and create the button's
    content (icon or image) inside it.
    """
    # target="_blank" links are still easiest to do via ipyvuetify
    return solara.v.Btn(
        icon=True,
        tag="a",
        attributes={"href": href, "title": title, "target": "_blank"},
    )


@solara.component
def Page():
    """Top-level page: a header with title and links, then the chat panel and the map."""
    with solara.Column(classes=["ui-container"], gap="5vh"):
        # Header: app title, followed by source/credit links.
        with solara.Row(justify="space-between"):
            with solara.Row(gap="10px", style={"align-items": "center"}):
                solara.v.Icon(children=["mdi-compass-rose"], size="36px")
                solara.HTML(
                    tag="h2",
                    unsafe_innerHTML="Wanderlust",
                    style={"display": "inline-block"},
                )
            with solara.Row(
                gap="30px",
                style={"align-items": "center"},
                classes=["link-container"],
                justify="end",
            ):
                with solara.Row(gap="5px", style={"align-items": "center"}):
                    solara.Text("Source Code:", style="font-weight: bold;")
                    with _external_link(
                        "https://github.com/widgetti/wanderlust",
                        "Wanderlust Source Code",
                    ):
                        solara.v.Icon(children=["mdi-github-circle"])
                with solara.Row(gap="5px", style={"align-items": "center"}):
                    solara.Text("Powered by Solara:", style="font-weight: bold;")
                    with _external_link("https://solara.dev/", "Solara"):
                        solara.HTML(
                            tag="img",
                            attributes={
                                "src": "https://solara.dev/static/public/logo.svg",
                                "width": "24px",
                            },
                        )
                    with _external_link(
                        "https://github.com/widgetti/solara",
                        "Solara Source Code",
                    ):
                        solara.v.Icon(children=["mdi-github-circle"])
        # Main area: chat panel followed by the map.
        with solara.Row(
            justify="space-between", style={"flex-grow": "1"}, classes=["container-row"]
        ):
            ChatInterface()
            with solara.Column(classes=["map-container"]):
                Map()
        solara.Style(app_style)