Py.Cafe

jackparmer/

wave-equation

Wave Equation Simulation Dashboard using Dash

DocsPricing
  • app.py
  • requirements.txt
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import dash
from dash import dcc, html, Input, Output
import plotly.graph_objects as go
import numpy as np
from scipy.integrate import solve_ivp

# Build the Dash application. The page shows a 3-D surface graph, a time
# slider, and two numeric inputs that control the initial square pulse.
app = dash.Dash(__name__)


def _number_field(label_text, input_id, default, increment):
    """Return an ``html.Div`` holding a label and a numeric ``dcc.Input``."""
    return html.Div([
        html.Label(label_text),
        dcc.Input(id=input_id, type='number', value=default, step=increment),
    ])


# Slider tick labels every 10 time units, from 0 to 100.
_time_marks = {tick: str(tick) for tick in range(0, 101, 10)}

app.layout = html.Div([
    dcc.Graph(id='wave-plot'),
    dcc.Slider(
        id='time-slider',
        min=0.1,
        max=100,
        step=0.1,
        value=0.1,
        marks=_time_marks,
    ),
    _number_field('Initial Amplitude', 'initial-amplitude', 1, 0.1),
    _number_field('Initial Size', 'initial-size', 10, 1),
])

# Right-hand side of the 2-D wave equation, in the fun(t, y) form solve_ivp expects
def wave_equation(t, U, c=1, Nx=50, Ny=50):
    """Time derivative of the 2-D wave equation written as a first-order system.

    The flat state packs the displacement field ``u`` followed by its
    velocity ``v = du/dt``, each on an ``Nx`` x ``Ny`` grid. The system is
    ``du/dt = v`` and ``dv/dt = c**2 * laplacian(u)``.

    The Laplacian is the 5-point finite-difference stencil with unit grid
    spacing. ``np.roll`` wraps indices around the array edges, so the
    boundaries are periodic.

    Parameters
    ----------
    t : float
        Current time. It is unused because the system does not depend on time
        explicitly, but ``scipy.integrate.solve_ivp`` passes it.
    U : array_like
        Flat state of length ``2 * Nx * Ny``. The first ``Nx * Ny`` entries
        are ``u`` and the remaining entries are ``v``.
    c : float, optional
        Wave speed (default 1).
    Nx, Ny : int, optional
        Grid dimensions. The defaults (50 x 50) match the previous
        hard-coded grid.

    Returns
    -------
    numpy.ndarray
        Flat time derivative of ``U``, with the same length as ``U``.
    """
    u, v = np.asarray(U).reshape(2, Nx, Ny)
    laplacian = (
        np.roll(u, 1, axis=0) + np.roll(u, -1, axis=0)
        + np.roll(u, 1, axis=1) + np.roll(u, -1, axis=1)
        - 4 * u
    )
    # Displacement changes at the current velocity; velocity changes with the curvature of u.
    return np.concatenate((v.ravel(), (c**2 * laplacian).ravel()))

# Initial condition
def initial_condition(Nx, Ny, amplitude, size):
    """Build the flat initial state: a square pulse at rest.

    A ``size`` x ``size`` block of cells near the grid centre is set to
    ``amplitude``. Every other cell, and the whole velocity field, is zero.
    The block is clipped to the grid. A ``size`` larger than the grid
    therefore fills the grid instead of wrapping through negative indices.

    Parameters
    ----------
    Nx, Ny : int
        Grid dimensions.
    amplitude : float
        Displacement assigned inside the pulse.
    size : int or float
        Edge length of the pulse in cells. Non-integer values are truncated
        toward zero. Values <= 0 produce no pulse.

    Returns
    -------
    numpy.ndarray
        Flat state of length ``2 * Nx * Ny``: displacement first, then velocity.
    """
    U = np.zeros((2, Nx, Ny))
    # int() also stops a float size from becoming a float slice bound.
    size = max(int(size), 0)

    def _span(n):
        # Start at n//2 - size//2 (the same start as before for even sizes) and take
        # exactly `size` cells, clamped to [0, n] so indices never go negative.
        start = n // 2 - size // 2
        return slice(max(start, 0), min(start + size, n))

    U[0, _span(Nx), _span(Ny)] = amplitude
    return U.flatten()

# Shared figure styling for every plot the callback returns
def _surface_figure(Z=None):
    """Return a dashboard-styled figure, containing a surface of ``Z`` if given."""
    data = [] if Z is None else [go.Surface(z=Z, colorscale='Blues', showscale=False)]
    fig = go.Figure(data=data)
    fig.update_layout(
        scene=dict(
            xaxis_title='X',
            yaxis_title='Y',
            zaxis_title='Amplitude',
            aspectratio=dict(x=1, y=1, z=0.5)
        ),
        margin=dict(l=0, r=0, t=0, b=0)
    )
    return fig


# Callback to update the plot
@app.callback(
    Output('wave-plot', 'figure'),
    Input('time-slider', 'value'),
    Input('initial-amplitude', 'value'),
    Input('initial-size', 'value')
)
def update_plot(t, amplitude, size):
    """Integrate the wave equation up to time ``t`` and plot the displacement.

    Parameters
    ----------
    t : float
        Target time from the slider.
    amplitude : float or None
        Pulse amplitude from the numeric input.
    size : float or None
        Pulse edge length in cells from the numeric input. It is truncated
        to an int.

    Returns
    -------
    plotly.graph_objects.Figure
        A surface of the displacement at time ``t``. If an input is missing
        or the solver does not succeed, the figure is empty but styled.

    Notes
    -----
    Integration restarts from t=0 on every call, so the cost grows with ``t``.
    """
    Nx, Ny = 50, 50

    # A cleared numeric input arrives as None (or ''), so render an empty plot instead of raising.
    if any(v is None or v == '' for v in (t, amplitude, size)):
        return _surface_figure()

    initial_cond = initial_condition(Nx, Ny, amplitude, int(size))
    sol = solve_ivp(wave_equation, [0, t], initial_cond, t_eval=[t], method='RK45')

    if not sol.success or sol.y.size == 0:
        return _surface_figure()

    # sol.y has shape (2*Nx*Ny, len(t_eval)). The first Nx*Ny rows of the last column hold u(t).
    Z = sol.y[:Nx * Ny, -1].reshape(Nx, Ny)
    return _surface_figure(Z)

if __name__ == '__main__':
    # `run_server` is deprecated in Dash 2 and removed in Dash 3. `run` is the
    # supported way to start the development server. debug=True turns on
    # Dash's dev tools; disable it for any publicly reachable deployment.
    app.run(debug=True)