# Page-scrape residue ("Spaces: Runtime error"); kept as a comment so the file parses.
import os
import re

import panel as pn
from mistralai.async_client import MistralAsyncClient
from mistralai.models.chat_completion import ChatMessage
from panel.io.mime_render import exec_with_return
# Load the Panel extension for the code editor widget; components stretch
# horizontally unless they override sizing_mode themselves.
pn.extension("codeeditor", sizing_mode="stretch_width")

# Model name passed to the Mistral chat endpoint on every request.
LLM_MODEL = "mistral-small"

# System prompt sent as the first message of every request.
# NOTE(review): "in edit the code" reads like a typo for "in editing the
# code"; left untouched because changing the prompt changes model input.
SYSTEM_MESSAGE = ChatMessage(
    role="system",
    content=(
        "You are a renowned data visualization expert with a strong "
        "background in matplotlib. "
        "Your primary goal is to assist the user in edit the code "
        "based on user request using best practices. "
        "Simply provide code in code fences (```python). "
        "You must have `fig` as the last line of code"
    ),
)
# Template for the user turn sent to the LLM: the user's request followed by
# the current editor code inside a python fence. Filled in with
# ``.format(content=..., code=...)``.
USER_CONTENT_FORMAT = """
Request:
{content}
Code:
```python
{code}
```
""".strip()
# Starter script shown in the code editor and rendered as the initial plot.
# It is executed as Python source, so it must stay syntactically valid and
# end with a bare ``fig`` expression, whose value becomes the plot.
DEFAULT_MATPLOTLIB = """
import numpy as np
import matplotlib.pyplot as plt
fig = plt.figure()
ax = plt.axes(title="Plot Title", xlabel="X Label", ylabel="Y Label")
x = np.linspace(1, 10)
y = np.sin(x)
z = np.cos(x)
c = np.log(x)
ax.plot(x, y, c="blue", label="sin")
ax.plot(x, z, c="orange", label="cos")
img = ax.scatter(x, c, c=c, label="log")
plt.colorbar(img, label="Colorbar")
plt.legend()
# must have fig at the end!
fig
""".strip()
# Matches the body of a ```python fenced block. Non-greedy, so a reply with
# several fences yields the first block only instead of one span running from
# the first opening fence to the last closing fence.
_PYTHON_FENCE = re.compile(r"```python\n(.*?)\n```", re.DOTALL)


def _extract_code(message: str) -> str:
    """Return the first ```python block of *message*, ending with a ``fig`` line.

    Returns an empty string when the message has no such block or the block
    is blank, so the caller can leave the editor untouched.
    """
    match = _PYTHON_FENCE.search(message)
    if match is None:
        return ""
    code = match.group(1).rstrip()
    if not code:
        return ""
    # The plot is taken from the value of the final expression, so make sure
    # the snippet ends with a bare `fig`.
    if code.splitlines()[-1].strip() != "fig":
        code += "\nfig"
    return code


async def callback(content: str, user: str, instance: pn.chat.ChatInterface):
    """Stream the LLM reply to a chat message and load its code into the editor.

    Builds the request from the system prompt, the prior chat history and the
    new request combined with the current editor code, then yields the
    accumulated reply after each streamed chunk. Once streaming ends, the
    first python code fence in the reply (if any) replaces the editor content.

    Args:
        content: Text the user typed into the chat.
        user: Name of the sending user (unused).
        instance: Chat interface whose serialized messages supply the history.

    Yields:
        The reply text accumulated so far.
    """
    # system
    messages = [SYSTEM_MESSAGE]
    # history
    # NOTE(review): the slice drops the newest entry (presumably the message
    # being answered, which is re-sent below with the code attached) and also
    # the oldest entry -- confirm that skipping the first message is intended.
    messages.extend(
        ChatMessage(**message) for message in instance.serialize()[1:-1]
    )
    # new user contents
    user_content = USER_CONTENT_FORMAT.format(
        content=content, code=code_editor.value
    )
    messages.append(ChatMessage(role="user", content=user_content))

    # stream LLM tokens
    reply = ""
    async for chunk in client.chat_stream(model=LLM_MODEL, messages=messages):
        delta = chunk.choices[0].delta.content
        if delta is not None:
            reply += delta
            yield reply

    # extract code; a reply without a usable python fence leaves the editor
    # as it is instead of raising IndexError.
    # NOTE(review): assigning to code_editor.value fires the watcher that
    # runs update_plot, which executes this LLM-generated code.
    llm_code = _extract_code(reply)
    if llm_code:
        code_editor.value = llm_code
def update_plot(event):
    """Re-render the plot pane from the code carried by a change event.

    Passes ``event.new`` to ``exec_with_return`` and stores the result as the
    Matplotlib pane's object.
    """
    # NOTE(review): this executes whatever source is in the editor, including
    # code written there from an LLM reply -- treat it as untrusted input.
    rendered = exec_with_return(event.new)
    matplotlib_pane.object = rendered
# --- widgets and panes ---
# The API key is read from the environment; a missing MISTRAL_API_KEY raises
# KeyError here at start-up.
client = MistralAsyncClient(api_key=os.environ["MISTRAL_API_KEY"])

chat_interface = pn.chat.ChatInterface(
    callback=callback,
    callback_exception="verbose",
    height=700,
    show_clear=False,
    show_undo=False,
    show_button_name=False,
    message_params={
        "show_reaction_icons": False,
        "show_copy_icon": False,
    },
)

# The initial figure comes from executing the default script once.
matplotlib_pane = pn.pane.Matplotlib(
    exec_with_return(DEFAULT_MATPLOTLIB),
    tight=True,
    sizing_mode="stretch_both",
)

code_editor = pn.widgets.CodeEditor(
    language="python",
    value=DEFAULT_MATPLOTLIB,
    sizing_mode="stretch_both",
)

# --- re-render the plot whenever the editor content changes ---
code_editor.param.watch(update_plot, "value")

# --- layout: chat in the sidebar, plot/code tabs in the main area ---
tabs = pn.Tabs(("Plot", matplotlib_pane), ("Code", code_editor))
sidebar = [chat_interface]
main = [tabs]

_ACCENT = "#fd7000"  # shared by the accent and the header background
template = pn.template.FastListTemplate(
    title="Chat with Plot",
    sidebar=sidebar,
    sidebar_width=600,
    main=main,
    main_layout=None,
    accent_base_color=_ACCENT,
    header_background=_ACCENT,
)
template.servable()