import solara
import shutil
from transformers_js_py import import_transformers_js
import numpy as np
import js
from pathlib import Path
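# reactive state shared between the UI and the tasks below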
clicks = solara.reactive(0)
input_text = solara.reactive("How are you doing today?")
progress_tokenizer = solara.reactive(0)
progress_model = solara.reactive(0)
messages = solara.reactive([])
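# verbose Pyodide debugging makes JS<->Python conversion errors easier to trace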
js.pyodide.setDebug(True)
model_id = 'Xenova/Phi-3-mini-4k-instruct'
# hack to fetch the file from the webworker; relies on an internal py.cafe URL scheme and will break in the future
pageId = js._pageId
url = 'https://py.cafe/_app/static/public/transformers3.js?_pageId=' + pageId
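# copy the local transformers.js bundle into /public so it is reachable at the URL above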
public = Path(__file__).parent.parent / 'public'
public.mkdir(parents=True, exist_ok=True)
name = "transformers3.js"
jsfile = Path(name)
shutil.copyfile(jsfile, public / name)
# bug in solara: after saving, refresh the browser preview,
# otherwise the old reference to the function is kept
@solara.lab.task
async def run(message: str):
    transformers = await import_transformers_js(url)
    AutoTokenizer = transformers.AutoTokenizer
    AutoModelForCausalLM = transformers.AutoModelForCausalLM
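    # forward tokenizer download progress events to the UI progress bar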
    def progress_callback(*args):
        info = args[0].to_py()
        if info['status'] == 'progress':
            progress_tokenizer.value = info['progress']

    tokenizer = await AutoTokenizer.from_pretrained(model_id, {
        "legacy": True,
        "progress_callback": progress_callback,
    })
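    # redefine the callback so the model download updates its own progress bar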
    def progress_callback(*args):
        info = args[0].to_py()
        if info['status'] == 'progress':
            progress_model.value = info['progress']

    # 'q4' = 4-bit quantized weights (smaller download), 'webgpu' runs inference on the GPU
    model = await AutoModelForCausalLM.from_pretrained(model_id, {
        "dtype": 'q4',
        "device": 'webgpu',
        "use_external_data_format": True,
        "progress_callback": progress_callback,
    })
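    # append the user's message to the chat history before building the prompt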
    messages.value = [*messages.value, {"role": "user", "content": message}]
    print("messages.value", messages.value)
    inputs = tokenizer.apply_chat_template(messages.value, {
        "add_generation_prompt": True,
        "return_dict": True,
    })
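    # with "return_dict": True this returns an object holding input_ids and attention_mask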
    from pyodide.ffi import create_proxy  # only used by the commented-out experiments below
    # is this needed? Seems to have no effect:
    # inputs["input_ids"] = create_proxy(inputs["input_ids"])
    # inputs["attention_mask"] = create_proxy(inputs["attention_mask"])
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]
    print("attention_mask", attention_mask, attention_mask.dims)
    print("inputs", inputs, input_ids.dims)
    print("inputs", inputs, input_ids.tolist())
    # convert to uint32, otherwise pyodide complains with:
    # Unknown typed array type 'BigInt64Array'. This is a problem with Pyodide, please open an issue about it here: https://github.com/pyodide/pyodide/issues/new
    if 0:  # disabled: both uint32 conversion attempts below still fail, see the errors quoted
        if 1:
            # rebuild the tensors as uint32 via numpy
            ar = np.array(input_ids.tolist(), dtype=np.uint32)
            inputs["input_ids"] = transformers.Tensor(ar.flatten()).unsqueeze(1)
            ar = np.array(attention_mask.tolist(), dtype=np.uint32)
            inputs["attention_mask"] = transformers.Tensor(ar.flatten()).unsqueeze(1)
            # but after calling model.generate() we then get:
            # TypeError: Cannot mix BigInt and other types, use explicit conversions
        else:
            # Not sure why we cannot do this; we get this error:
            # pyodide.ffi.JsException: TypeError: Cannot convert a BigInt value to a number
            inputs["input_ids"] = inputs["input_ids"].to('uint32')
            inputs["attention_mask"] = inputs["attention_mask"].to('uint32')
    TextStreamer = transformers.TextStreamer
    # class MyStreamer(TextStreamer):
    #     def on_finalized_text(self, text):
    #         self.cb(text)
    # streamer = TextStreamer.new(create_proxy(tokenizer), {
    #     "skip_prompt": True,
    #     "skip_special_tokens": True,
    # })
    # stopping_criteria = transformers.StoppingCriteria.new()
    arg = {
        **inputs,
        "max_new_tokens": 512,
        # "streamer": streamer,
        # "stopping_criteria": stopping_criteria,
    }
    print("arg", arg)
    outputs = await model.generate(arg)  # the BigInt error described above happens here
    # print(tokenizer, streamer, stopping_criteria, inputs)
    print(tokenizer, model)
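    # sketch for when generate() succeeds, assuming the transformers.js batch_decode API
    # is usable through the FFI here (not verified in this environment):
    # texts = tokenizer.batch_decode(outputs, {"skip_special_tokens": True})
    # return texts[0]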
return "dummpy"
@solara.lab.task
async def has_shader_f16():
    if not js.navigator.gpu:
        return False
    adapter = await js.navigator.gpu.requestAdapter()
    if not adapter:
        return False
    return adapter.features.has('shader-f16')
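# the UI below is gated on WebGPU's 'shader-f16' feature, since this setup relies on fp16 shaders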
@solara.component
def Page():
    # run the WebGPU capability check once, on first render
    solara.use_memo(lambda: has_shader_f16(), [])
    if has_shader_f16.pending:
        solara.ProgressLinear()
    elif has_shader_f16.value:
        with solara.Card("Test LLM"):
            solara.ProgressLinear(run.pending)
            with solara.Div():
                solara.InputText(label="Input", value=input_text)
                solara.Button(label="Respond", on_click=lambda: run(input_text.value), color="primary", filled=True)
            if run.finished:
                solara.Text(repr(run.value))
    else:
        solara.Error("no fp16 support")