import dash
from dash import dcc, html, Input, Output
import plotly.graph_objects as go
import numpy as np
from scipy.integrate import solve_ivp
# Create the Dash application object.
app = dash.Dash(__name__)


def _labeled_number_input(label, input_id, value, step):
    """Return a Div holding a text label followed by a numeric input box."""
    return html.Div([
        html.Label(label),
        dcc.Input(id=input_id, type='number', value=value, step=step),
    ])


# Page layout: the 3-D surface plot, a simulation-time slider with a tick
# every 10 units, and numeric inputs for the initial pulse's amplitude and
# side length.
app.layout = html.Div([
    dcc.Graph(id='wave-plot'),
    dcc.Slider(
        id='time-slider',
        min=0.1,
        max=100,
        step=0.1,
        value=0.1,
        marks={tick: str(tick) for tick in range(0, 101, 10)},
    ),
    _labeled_number_input('Initial Amplitude', 'initial-amplitude', 1, 0.1),
    _labeled_number_input('Initial Size', 'initial-size', 10, 1),
])
# Right-hand side of the semi-discrete 2-D wave equation
def wave_equation(t, U, c=1, Nx=50, Ny=50):
    """Return dU/dt for the 2-D wave equation u_tt = c**2 * laplacian(u).

    The second-order PDE is rewritten as a first-order system on a regular
    ``Nx`` x ``Ny`` grid with unit spacing. The state ``U`` stacks the
    displacement ``u`` and the velocity ``v = u_t`` and is flattened into a
    single vector, which is the form ``solve_ivp`` works with. The discrete
    Laplacian is the 5-point stencil built with ``np.roll``, so the grid
    boundaries are periodic: a wave leaving one edge re-enters on the
    opposite edge.

    Args:
        t: Current time. It is unused because the system does not depend on
            time explicitly, but the ``solve_ivp`` callback signature
            requires it.
        U: Array of ``2 * Nx * Ny`` values: displacement followed by
            velocity, each in row-major ``(Nx, Ny)`` order.
        c: Wave speed.
        Nx, Ny: Grid dimensions. They default to the 50 x 50 grid used by
            the app.

    Returns:
        Flat array of length ``2 * Nx * Ny`` laid out like ``U``, holding
        ``[u_t, v_t] = [v, c**2 * laplacian(u)]``.
    """
    u, v = np.reshape(U, (2, Nx, Ny))
    # Sum of the four neighbours minus 4x the centre. The same term order as
    # before keeps results bit-for-bit identical.
    laplacian = (
        np.roll(u, 1, axis=0)
        + np.roll(u, -1, axis=0)
        + np.roll(u, 1, axis=1)
        + np.roll(u, -1, axis=1)
        - 4 * u
    )
    return np.concatenate((v.ravel(), (c**2 * laplacian).ravel()))
# Initial condition: a raised square pulse, initially at rest
def initial_condition(Nx, Ny, amplitude, size):
    """Return the flattened initial state for ``wave_equation``.

    The displacement equals ``amplitude`` inside a ``size`` x ``size``
    square centred on the grid and is zero elsewhere. The initial velocity
    is zero everywhere.

    Args:
        Nx, Ny: Grid dimensions.
        amplitude: Displacement value inside the square. It must be numeric.
        size: Side length of the square in grid cells. Non-integer values
            are truncated. Squares larger than the grid are clipped to the
            grid, and a non-positive size gives an all-zero state.

    Returns:
        Flat float array of length ``2 * Nx * Ny``: displacement followed by
        velocity, each in row-major ``(Nx, Ny)`` order.
    """
    U = np.zeros((2, Nx, Ny))
    # Numeric UI inputs may arrive as floats, and float slice bounds raise
    # TypeError.
    side = max(int(size), 0)
    # Start side // 2 cells before the centre and take exactly `side` cells,
    # so odd sizes are honoured. Clamp the start at 0: a negative start would
    # make Python slicing wrap around instead of clipping. The end needs no
    # clamp because slicing already truncates past the edge.
    x0 = Nx // 2 - side // 2
    y0 = Ny // 2 - side // 2
    U[0, max(x0, 0):x0 + side, max(y0, 0):y0 + side] = amplitude
    return U.flatten()
def _wave_figure(Z=None):
    """Build the 3-D figure shown in the 'wave-plot' graph.

    Args:
        Z: 2-D array of displacements to draw as a surface, or ``None`` to
            produce empty axes with the same scene settings.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    traces = []
    if Z is not None:
        traces.append(go.Surface(z=Z, colorscale='Blues', showscale=False))
    fig = go.Figure(data=traces)
    fig.update_layout(
        scene=dict(
            xaxis_title='X',
            yaxis_title='Y',
            zaxis_title='Amplitude',
            aspectratio=dict(x=1, y=1, z=0.5),
        ),
        margin=dict(l=0, r=0, t=0, b=0),
    )
    return fig


# Callback to update the plot
@app.callback(
    Output('wave-plot', 'figure'),
    Input('time-slider', 'value'),
    Input('initial-amplitude', 'value'),
    Input('initial-size', 'value')
)
def update_plot(t, amplitude, size):
    """Integrate the wave equation from 0 to ``t`` and plot the displacement.

    Args:
        t: Simulation end time from the slider.
        amplitude: Height of the initial square pulse.
        size: Side length, in grid cells, of the initial square pulse.

    Returns:
        A plotly figure showing the displacement surface at time ``t``. If
        the solver produced no output, the figure has empty axes. If any
        input is missing, ``dash.no_update`` is returned and the current
        plot is kept.
    """
    if t is None or amplitude is None or size is None:
        # A cleared or half-typed number box sends None; keep the last plot
        # instead of crashing the callback.
        return dash.no_update
    Nx, Ny = 50, 50
    # Number inputs may deliver floats; the grid indices need an int.
    initial_cond = initial_condition(Nx, Ny, amplitude, int(size))
    sol = solve_ivp(wave_equation, [0, t], initial_cond, t_eval=[t], method='RK45')
    if not sol.success or sol.y.shape[1] == 0:
        # The solver stopped before reaching t, so there is no state to show.
        return _wave_figure()
    # sol.y has one row per state component and one column per t_eval time.
    # The first Nx*Ny rows of the final column are the displacement.
    Z = sol.y[:Nx * Ny, -1].reshape(Nx, Ny)
    return _wave_figure(Z)
if __name__ == '__main__':
    # Start the development server. `app.run` replaces `app.run_server`,
    # which was removed in Dash 3 and now raises.
    app.run(debug=True)