# check out https://solara.dev/ for documentation
# or https://github.com/widgetti/solara/
# And check out https://py.cafe/maartenbreddels for more examples
import solara
from pydantic_ai import Agent
from pydantic_ai.models.mistral import MistralModel
from typing_extensions import TypedDict
from typing import List, cast
from dataclasses import dataclass
from pydantic import BaseModel, Field
from pydantic_ai import Agent, RunContext
# Resolve the Mistral API key. On py.cafe it comes from the platform's secret
# store; anywhere else (pycafe not importable) it is read from the
# MISTRAL_API_KEY environment variable, which yields None when unset.
try:
    import pycafe

    reason = """You need a Mistral AI API key to run this example.
Get your [Mistral API key here](https://console.mistral.ai/api-keys/)
"""
    mistral_api_key = pycafe.get_secret("MISTRAL_API_KEY", reason)
except ModuleNotFoundError:
    # Not running on py.cafe: fall back to the process environment.
    import os

    mistral_api_key = os.getenv("MISTRAL_API_KEY")
# Only build a model when a (non-empty) key is available; otherwise `model`
# stays None and the plain `agent` is not created at all.
# NOTE(review): `agent` is not referenced elsewhere in the visible source; the
# UI below talks to `support_agent` instead.
if mistral_api_key:
    model = MistralModel('mistral-small-latest', api_key=mistral_api_key)
    agent = Agent(
        model=model,
        system_prompt='Be concise, reply with one sentence.',
    )
else:
    model = None
class DatabaseConn:
"""This is a fake database for example purposes.
In reality, you'd be connecting to an external database
(e.g. PostgreSQL) to get information about customers.
"""
@classmethod
async def customer_name(cls, *, id: int) -> str | None:
if id == 123:
return 'John'
@classmethod
async def customer_balance(cls, *, id: int, include_pending: bool) -> float:
if id == 123:
return 123.45
else:
raise ValueError('Customer not found')
@dataclass
class SupportDependencies:
    """Dependencies handed to the support agent for a run.

    The agent's system-prompt and tool functions read these via ``ctx.deps``.
    """

    customer_id: int  # id of the customer the conversation is about
    db: DatabaseConn  # data source used to look up that customer's details
class SupportResult(BaseModel):
    """Structured reply produced by the support agent (its ``result_type``).

    Field descriptions become part of the pydantic JSON schema. pydantic_ai
    presumably forwards that schema to the model, so the descriptions should
    be complete, unambiguous sentences.
    """

    support_advice: str = Field(description='Advice returned to the customer')
    # Fixed a truncated description: the object being blocked ("card") was missing.
    block_card: bool = Field(description='Whether to block their card')
    # Constrained to the inclusive range 0..10 by ge/le.
    risk: int = Field(description='Risk level of query', ge=0, le=10)
# Bank support agent. Its prompt and tool functions receive SupportDependencies
# through ctx.deps, and a run's `result.data` is expected to be a SupportResult.
support_agent = Agent(
    model=model,
    deps_type=SupportDependencies,
    result_type=SupportResult,
    system_prompt=(
        "You are a support agent in our bank, give the customer support and "
        "judge the risk level of their query. Reply using the customer's name."
    ),
)
@support_agent.system_prompt
async def add_customer_name(ctx: RunContext[SupportDependencies]) -> str:
    """Extra system-prompt line telling the model the current customer's name.

    The name is looked up in ``ctx.deps.db``. An unknown customer yields
    ``None``, which is rendered via ``repr``.
    """
    run_deps = ctx.deps
    name = await run_deps.db.customer_name(id=run_deps.customer_id)
    return f"The customer's name is {name!r}"
@support_agent.tool
async def customer_balance(
    ctx: RunContext[SupportDependencies], include_pending: bool
) -> str:
    """Returns the customer's current account balance."""
    # The docstring above is kept verbatim: pydantic_ai presumably uses a
    # tool's docstring as the description it sends to the model.
    db = ctx.deps.db
    customer_id = ctx.deps.customer_id
    amount = await db.customer_balance(id=customer_id, include_pending=include_pending)
    # Dollar sign followed by the amount rounded to two decimals, e.g. "$123.45".
    return '${:.2f}'.format(amount)
# Dependencies for every support_agent run. The demo always serves customer 123,
# the only id the fake DatabaseConn recognises.
deps = SupportDependencies(db=DatabaseConn(), customer_id=123)
# below is only the UI part using Solara
class MessageDict(TypedDict):
    """A single chat entry as kept in ``messages``."""

    # Who sent the message: "user" or "assistant".
    role: str
    # Message text; the chat UI renders it as Markdown.
    content: str


# Conversation history for the chat UI. The code in this file always assigns a
# fresh list to `.value` rather than mutating the existing one.
messages: solara.Reactive[List[MessageDict]] = solara.reactive([])
def no_api_key_message():
    """Replace the whole chat history with a notice that no API key is set."""
    notice: MessageDict = {
        "role": "assistant",
        "content": "No API key found. Please set your Mistral API key in the environment variable `MISTRAL_API_KEY`.",
    }
    messages.value = [notice]
@solara.lab.task
async def promt_ai(message: str):
    """Send ``message`` to the support agent and append the exchange to the chat.

    This is a ``solara.lab.task``; ``Page`` reads ``promt_ai.pending`` to show a
    progress indicator and disable the inputs while it runs.

    If no API key is configured, the chat history is replaced with an
    explanatory notice and the model is not called.
    """
    # Truthiness, not `is None`: `model` is only built for a truthy key, so an
    # empty key ("") must be treated the same as a missing one.
    if not mistral_api_key:
        no_api_key_message()
        return
    # Show the user's message right away, before awaiting the model.
    messages.value = [
        *messages.value,
        {"role": "user", "content": message},
    ]
    result = await support_agent.run(message, deps=deps)
    # Only the advice text of the structured result is shown in the chat.
    messages.value = [
        *messages.value,
        {"role": "assistant", "content": result.data.support_advice},
    ]
@solara.component
def Page():
    """Render the bank-assistant chat UI.

    Shows the conversation stored in ``messages``, a "thinking" indicator while
    ``promt_ai`` is pending, a free-text chat input, and two shortcut buttons
    that send canned questions. All inputs are disabled while a request runs.
    """
    busy = promt_ai.pending
    with solara.Card(
        "Bank assistant",
        style={"width": "100%", "height": "100vh", "margin": "20px"},
    ):
        with solara.lab.ChatBox():
            for item in messages.value:
                from_assistant = item["role"] == "assistant"
                with solara.lab.ChatMessage(
                    user=item["role"] == "user",
                    avatar=False,
                    # Replies come from a Mistral model, so don't label them "ChatGPT".
                    name="Assistant" if from_assistant else "User",
                    color="rgba(0,0,0, 0.06)" if from_assistant else "#ff991f",
                    avatar_background_color="primary" if from_assistant else None,
                    border_radius="20px",
                ):
                    solara.Markdown(item["content"])
        if busy:
            solara.Text("I'm thinking...", style={"font-size": "1rem", "padding-left": "20px"})
            solara.ProgressLinear()
        solara.lab.ChatInput(send_callback=promt_ai, disabled=busy)
        with solara.Row():
            solara.Button(
                "Ask about your balance",
                on_click=lambda: promt_ai("What is my balance?"),
                disabled=busy,
                text=True,
                color="primary",
            )
            solara.Button(
                "Report a lost card",
                on_click=lambda: promt_ai("I lost my card"),
                disabled=busy,
                text=True,
                color="primary",
            )