import asyncio
import io
import re

import matplotlib.pyplot as plt
import panel as pn
import param
from panel.custom import JSComponent
from panel.io.mime_render import exec_with_return
plt.ioff()  # disable interactive mode so figures only render on demand
plt.switch_backend("agg")  # non-interactive backend for server-side rendering
pn.extension("mathjax", "codeeditor", sizing_mode="stretch_width")
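
# System prompt: restate the user's goal, make targeted edits to the plotting
# code, and return the full updated script in a ```python fence.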
SYSTEM_MESSAGE = (
    "You are an expert on matplotlib's object-oriented API. "
    "To better understand the user's request, first concisely restate "
    "their goal in your own words and how you might approach it. "
    "Then, modify the plotting code to fulfill it, making only targeted, "
    "accurate edits. Do not add or remove code that is unrelated to the request. "
    "Output the full, updated code in code fences (```python)."
)
SUGGESTION_PROMPTS = [
"Make the plot labels larger",
"Remove the axes spines",
"Change title to something more relevant",
"Separate into three subplots",
"Divide the colorbar into discrete levels"
]
USER_CONTENT_FORMAT = """
Request:
{content}
Code:
```python
{code}
```
""".strip()
DEFAULT_MATPLOTLIB = """
import numpy as np
import matplotlib.pyplot as plt
fig, ax = plt.subplots()
ax.set_title("Plot Title")
ax.set_xlabel("X Label")
ax.set_ylabel("Y Label")
x = np.linspace(1, 10)
y = np.sin(x)
z = np.cos(x)
c = np.log(x)
ax.plot(x, y, c="blue", label="sin")
ax.plot(x, z, c="orange", label="cos")
img = ax.scatter(x, c, c=c, label="log")
fig.colorbar(img, label="Colorbar")
ax.legend()
fig # keep this line
""".strip()
MODELS = {
"Mistral-7b-Instruct": "Mistral-7B-Instruct-v0.3-q4f16_1-MLC",
"Llama-3-8B-Instruct": "Llama-3-8B-Instruct-q4f16_1-MLC",
}
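
# JSComponent that runs the LLM entirely in the browser via @mlc-ai/web-llm;
# Python triggers loading and completions with custom messages, and the
# front end streams completion chunks back.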
class WebLLM(JSComponent):
loaded = param.Boolean(
default=False,
doc="""
Whether the model is loaded.""",
)
history = param.Integer(default=3)
status = param.Dict(default={"text": "", "progress": 0})
load_model = param.Event()
model = param.Selector(
default="Mistral-7B-Instruct-v0.3-q4f16_1-MLC", objects=MODELS
)
running = param.Boolean(
default=False,
doc="""
Whether the LLM is currently running.""",
)
temperature = param.Number(
default=1,
bounds=(0, 2),
doc="""
Temperature of the model completions.""",
)
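
    # Front-end code: caches one engine per model and relays streamed
    # completion chunks back to Python as custom messages.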
_esm = """
import * as webllm from "https://esm.run/@mlc-ai/web-llm";
const engines = new Map()
export async function render({ model }) {
model.on("msg:custom", async (event) => {
if (event.type === 'load') {
if (!engines.has(model.model)) {
const initProgressCallback = (status) => {
model.status = status
}
const mlc = await webllm.CreateMLCEngine(
model.model,
{initProgressCallback}
)
engines.set(model.model, mlc)
}
model.loaded = true
        } else if (event.type === 'completion') {
          const engine = engines.get(model.model)
          if (engine == null) {
            // no engine loaded for this model: report the error and bail out
            model.send_msg({'finish_reason': 'error'})
            return
          }
          const chunks = await engine.chat.completions.create({
            messages: event.messages,
            temperature: model.temperature,
            stream: true,
          })
model.running = true
for await (const chunk of chunks) {
if (!model.running) {
break
}
model.send_msg(chunk.choices[0])
}
}
})
}
"""
def __init__(self, **params):
super().__init__(**params)
if pn.state.location:
pn.state.location.sync(self, {"model": "model"})
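        # streamed chunks relayed from the browser are buffered here until
        # create_completion consumes them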
self._buffer = []
status = self.param.status.rx()
self._menu = pn.Column(
pn.widgets.Select.from_param(self.param.model, sizing_mode="stretch_width"),
pn.widgets.FloatSlider.from_param(
self.param.temperature, sizing_mode="stretch_width"
),
pn.widgets.Button.from_param(
self.param.load_model,
sizing_mode="stretch_width",
disabled=self.param.loaded.rx().rx.or_(self.param.loading),
),
pn.indicators.Progress(
value=(status["progress"] * 100).rx.pipe(int),
visible=self.param.loading,
sizing_mode="stretch_width",
),
pn.pane.Markdown(status["text"], visible=self.param.loading),
)
@param.depends("load_model", watch=True)
def _load_model(self):
self.loading = True
chat_interface.loading = True
self._send_msg({"type": "load"})
@param.depends("loaded", watch=True)
def _loaded(self):
self.loading = False
chat_interface.loading = False
chat_interface.disabled = False
self._menu.visible = False
chat_interface.clear()
chat_interface.stream(
"Model is ready; ask me to edit the plot!",
user="Help",
footer_objects=[suggestion_buttons.clone()],
)
@param.depends("model", watch=True)
def _update_load_model(self):
self.loaded = False
    def _handle_msg(self, msg):
        # drop chunks that arrive after a completion was cancelled
        if self.running:
            self._buffer.insert(0, msg)
    async def create_completion(self, msgs):
        self._send_msg({"type": "completion", "messages": msgs})
        while True:
            # poll the buffer for chunks relayed from the browser
            await asyncio.sleep(0.01)
            if not self._buffer:
                continue
            choice = self._buffer.pop()
            reason = choice.get("finish_reason")
            if reason == "error":
                # raise before yielding: the error sentinel has no "delta" key
                raise RuntimeError("Model not loaded")
            yield choice
            if reason:
                return
async def callback(self, contents: str, user: str, instance: pn.chat.ChatInterface):
if not self.loaded:
if self.loading:
yield pn.pane.Markdown(
f"## `{self.model}`\n\n" + self.param.status.rx()["text"]
)
else:
yield "Load the model"
return
self.running = False
self._buffer.clear()
        messages = [
            {"role": "system", "content": SYSTEM_MESSAGE},
            *instance.serialize()[-self.history:],
        ]
        # rewrite the latest user message to embed the current editor code
messages[-1] = {
"role": "user",
"content": USER_CONTENT_FORMAT.format(
content=contents, code=code_editor.value
),
}
message = ""
        async for chunk in self.create_completion(messages):
            message += chunk.get("delta", {}).get("content", "")
yield message
        # extract the fenced code block from the response
        llm_code = re.findall(r"```python\n(.*?)\n```", message, re.DOTALL)
        if not llm_code:
            llm_code = re.findall(r"```\n(.*)", message, re.DOTALL)
        if not llm_code:
            return
        if len(llm_code) > 1:
            instance.send("Please provide the final version in code fences (```python).")
            return
        llm_code = llm_code[-1].replace("```python", "").strip().strip("`").strip()
        if not llm_code:
            return
        if llm_code.splitlines()[-1].strip() != "fig":
            llm_code += "\nfig"  # the script must end with `fig` so the pane re-renders
        code_editor.value = llm_code
@property
def menu(self):
return self._menu
async def use_suggestion(event):
    button = event.obj
    if event.new > 1:  # ignore double clicks
        return
    with button.param.update(loading=True):
        chat_interface.send(button.name)
def update_plot(code):
    stderr = io.StringIO()
    try:
        out = exec_with_return(code, stderr=stderr)
        if out is None:
            raise RuntimeError("The code did not return a figure.")
        return out
    except Exception:
        # let the llm know it failed so it can attempt a fix
        stderr.seek(0)
        exc = stderr.read()
        chat_interface.send(
            "Please take a moment to identify the error and fix it:"
            f"\n```python\n{exc}\n```"
        )
def load_model():
    # programmatic alternative to clicking the load-model button
    llm.param.trigger("load_model")
llm = WebLLM()
code_editor = pn.widgets.CodeEditor(
value=DEFAULT_MATPLOTLIB,
on_keyup=False,
language="python",
sizing_mode="stretch_both",
)
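
# Re-render the figure whenever the editor contents change; pn.bind passes
# the editor's current value to update_plot.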
matplotlib_pane = pn.pane.Matplotlib(
object=pn.bind(update_plot, code_editor),
sizing_mode="stretch_both",
tight=True,
)
chat_interface = pn.chat.ChatInterface(
show_rerun=False,
show_clear=False,
show_button_name=False,
show_reaction_icons=False,
stylesheets=[
"""
:host(.chat-interface) {
height: calc(100vh - 100px);
}
"""
],
margin=0,
disabled=True,
callback=llm.callback,
help_text="First load the model, then ask me to edit the plot!",
)
suggestion_buttons = pn.FlexBox(
*[
pn.widgets.Button(
name=suggestion,
button_style="outline",
on_click=use_suggestion,
margin=5,
)
for suggestion in SUGGESTION_PROMPTS
],
margin=(5, 5),
)
# lay them out
tabs = pn.Tabs(
("Plot", matplotlib_pane),
("Code", code_editor),
)
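
# The WebLLM component renders nothing visible but must be part of the page
# for its front-end code to load.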
sidebar = [llm.menu, chat_interface, llm]
main = [tabs]
template = pn.template.FastListTemplate(
sidebar=sidebar,
main=main,
sidebar_width=500,
main_layout=None,
accent_base_color="#fd7000",
header_background="#fd7000",
title="Chat with Plot",
)
template.servable()