import asyncio
import panel as pn
import param
from panel.custom import JSComponent, ESMEvent
pn.extension('mathjax', template='material')
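# Display labels (with approximate download size) mapped to MLC model ids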
MODELS = {
'SmolLM (130MB)': 'SmolLM-135M-Instruct-q4f16_1-MLC',
'TinyLlama-1.1B-Chat (675 MB)': 'TinyLlama-1.1B-Chat-v1.0-q4f16_1-MLC-1k',
"Qwen2.5-Coder-1.5B-Instruct-q4f16_1-MLC (0.9 GB)": "Qwen2.5-Coder-1.5B-Instruct-q4f16_1-MLC",
'Gemma-2b (1.9 GB)': 'gemma-2-2b-it-q4f16_1-MLC',
'Mistral-7b-Instruct (4.6 GB)': 'Mistral-7B-Instruct-v0.3-q4f16_1-MLC',
    'Llama-3.2-1B-Instruct (4.6 GB)': 'Llama-3.2-1B-Instruct-q4f16_1-MLC',
}
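# WebLLM wraps the @mlc-ai/web-llm library in a Panel JSComponent: parameter
# values are mirrored between Python and the browser, and custom messages are
# used to request model loading and to stream completion chunks back.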
class WebLLM(JSComponent):
loaded = param.Boolean(default=False, doc="""
Whether the model is loaded.""")
history = param.Integer(default=3)
status = param.Dict(default={'text': '', 'progress': 0})
load_model = param.Event()
model = param.Selector(default='SmolLM-135M-Instruct-q4f16_1-MLC', objects=MODELS)
running = param.Boolean(default=False, doc="""
Whether the LLM is currently running.""")
    temperature = param.Number(default=1, bounds=(0, 2), doc="""
        The temperature is a hyperparameter that regulates the randomness,
        or creativity, of the LLM's responses. Higher values typically make
        the output more diverse and creative but also increase the
        likelihood of straying from the context.""")
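    # JavaScript module rendered in the browser; handles 'load' and
    # 'completion' messages sent from the Python side.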
_esm = """
import * as webllm from "https://esm.run/@mlc-ai/web-llm";
const engines = new Map()
export async function render({ model }) {
model.on("msg:custom", async (event) => {
if (event.type === 'load') {
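      // Create the MLC engine once per model and cache it for reuse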
if (!engines.has(model.model)) {
const initProgressCallback = (status) => {
model.status = status
}
const mlc = await webllm.CreateMLCEngine(
model.model,
{initProgressCallback}
)
engines.set(model.model, mlc)
}
model.loaded = true
} else if (event.type === 'completion') {
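      // Run a streaming chat completion and forward each chunk to Python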
const engine = engines.get(model.model)
      if (engine == null) {
        model.send_msg({'finish_reason': 'error'})
        return
      }
const chunks = await engine.chat.completions.create({
messages: event.messages,
        temperature: model.temperature,
stream: true,
})
model.running = true
for await (const chunk of chunks) {
if (!model.running) {
break
}
model.send_msg(chunk.choices[0])
}
}
})
}
"""
def __init__(self, **params):
super().__init__(**params)
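        # Keep the selected model in sync with the URL query string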
if pn.state.location:
pn.state.location.sync(self, {'model': 'model'})
self._buffer = []
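    # Button click: ask the browser to download and initialize the selected model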
@param.depends('load_model', watch=True)
def _load_model(self):
self.loading = True
self._send_msg({'type': 'load'})
@param.depends('loaded', watch=True)
def _loaded(self):
self.loading = False
@param.depends('model', watch=True)
def _update_load_model(self):
self.loaded = False
    def _handle_msg(self, msg):
        # Buffer streamed chunks while a completion is running; always accept
        # error messages so create_completion can surface them.
        if self.running or msg.get('finish_reason') == 'error':
            self._buffer.insert(0, msg)
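    # Async generator bridging the browser stream: requests a completion and
    # yields choices as they are received from the frontend.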
    async def create_completion(self, msgs):
        self._send_msg({'type': 'completion', 'messages': msgs})
        while True:
            await asyncio.sleep(0.01)
            if not self._buffer:
                continue
            choice = self._buffer.pop()
            reason = choice.get('finish_reason')
            # Raise before yielding so consumers never receive a chunk without a delta
            if reason == 'error':
                raise RuntimeError('Model not loaded')
            yield choice
            if reason:
                return
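    # ChatInterface callback: show load progress until the model is ready,
    # then stream the LLM's response to the user's message.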
async def callback(self, contents: str, user: str):
if not self.loaded:
if self.loading:
yield pn.pane.Markdown(
f'## `{self.model}`\n\n' + self.param.status.rx()['text']
)
else:
yield 'Load the model'
return
self.running = False
self._buffer.clear()
message = ""
        async for chunk in self.create_completion([{'role': 'user', 'content': contents}]):
            message += chunk.get('delta', {}).get('content', '')
yield message
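    # Sidebar widgets for selecting, configuring and loading a model,
    # plus a progress bar and status text while it downloads.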
def menu(self):
status = self.param.status.rx()
return pn.Column(
pn.widgets.Select.from_param(self.param.model, sizing_mode='stretch_width'),
pn.widgets.FloatSlider.from_param(self.param.temperature, sizing_mode='stretch_width'),
pn.widgets.Button.from_param(
self.param.load_model, sizing_mode='stretch_width',
disabled=self.param.loaded.rx().rx.or_(self.param.loading)
),
pn.indicators.Progress(
value=(status['progress']*100).rx.pipe(int), visible=self.param.loading,
sizing_mode='stretch_width'
),
pn.pane.Markdown(status['text'], visible=self.param.loading)
)
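# Instantiate the component and assemble the sidebar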
llm = WebLLM()
intro = pn.pane.Alert("""
`WebLLM` runs large-language models entirely in your browser.
When visiting the application the first time the model has
to be downloaded and loaded into memory, which may take
some time. Models are ordered by size (and capability),
e.g. SmolLLM is very quick to download but produces poor
quality output while Mistral-7b will take a while to
download but produces much higher quality output.
""".replace('\n', ' '))
pn.Column(
llm.menu(),
intro,
llm
).servable(area='sidebar')
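# Main area: a ChatInterface driven by the WebLLM callback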
chat_interface = pn.chat.ChatInterface(callback=llm.callback)
chat_interface.send(
"Load a model and start chatting.",
user="System",
respond=False,
)
def _notify_loaded(event):
    # Only announce once the model has actually finished loading (loaded -> True)
    if event.new:
        chat_interface.send(f'Loaded `{event.obj.model}`, start chatting!', user='System', respond=False)

llm.param.watch(_notify_loaded, 'loaded')
pn.Row(chat_interface).servable(title='WebLLM')