Py.Cafe

jackparmer/

fish-school

Docs · Pricing
  • app.py
  • requirements.txt
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import dash
import dash_html_components as html
import dash_core_components as dcc
from dash.dependencies import Input, Output
import plotly.graph_objects as go
import numpy as np

# Create the Dash application object; callbacks and the layout are attached to it below.
app = dash.Dash(name=__name__)

# Parameters
initial_num_fish = 200
width, height = 300, 300  # 2D plane dimensions
max_force = 0.05
perception_radius = 50

class Fish:
    """One member of the school, steered by boids-style flocking rules.

    Each fish keeps 2-element numpy float vectors for ``position``,
    ``velocity`` and a per-tick ``acceleration`` accumulator, plus a scalar
    ``max_speed`` that caps the velocity magnitude.  It reads the
    module-level ``width``, ``height``, ``max_force`` and
    ``perception_radius`` settings.
    """

    def __init__(self, max_speed):
        # Start anywhere on the plane with a random heading whose components
        # lie in [-max_speed/2, max_speed/2).
        plane = [width, height]
        self.position = np.random.rand(2) * plane
        self.velocity = (np.random.rand(2) - 0.5) * max_speed
        self.acceleration = np.zeros(2)
        self.max_speed = max_speed

    def update(self):
        """Advance one step: accelerate, cap the speed, move, clear acceleration."""
        self.velocity += self.acceleration
        current_speed = np.linalg.norm(self.velocity)
        if current_speed > self.max_speed:
            # Rescale to exactly max_speed while keeping the direction.
            self.velocity = self.velocity / current_speed * self.max_speed
        self.position += self.velocity
        # Forces are re-accumulated from scratch on the next tick.
        self.acceleration *= 0

    def apply_force(self, force):
        """Add ``force`` (a 2-vector) to this tick's accumulated acceleration."""
        self.acceleration += force

    def edges(self):
        """Wrap the fish to the opposite side when it leaves the plane.

        A coordinate beyond the far edge jumps to 0; a negative coordinate
        jumps to the far edge.
        """
        for axis, extent in enumerate((width, height)):
            coord = self.position[axis]
            if coord > extent:
                self.position[axis] = 0
            elif coord < 0:
                self.position[axis] = extent

    def school(self, fishes):
        """Accumulate alignment, cohesion and separation forces from neighbours.

        Neighbours are the fish in ``fishes`` at a distance strictly between
        0 and ``perception_radius``, so a fish at distance 0 (including this
        fish itself) is ignored.  Each steering force is averaged over the
        neighbours and capped by ``limit_force``; with no neighbours all
        three forces are zero vectors.
        """
        velocity_sum = np.zeros(2)
        position_sum = np.zeros(2)
        repulsion = np.zeros(2)
        count = 0
        for neighbour in fishes:
            away = self.position - neighbour.position
            gap = np.linalg.norm(away)
            if not (0 < gap < perception_radius):
                continue
            velocity_sum += neighbour.velocity
            position_sum += neighbour.position
            # away / gap**2 has magnitude 1/gap: closer fish push harder.
            repulsion += away / (gap ** 2)
            count += 1

        alignment, cohesion, separation = velocity_sum, position_sum, repulsion
        if count:
            alignment = self.limit_force(velocity_sum / count - self.velocity)
            cohesion = self.limit_force(
                position_sum / count - self.position - self.velocity
            )
            separation = self.limit_force(repulsion / count)

        for force in (alignment, cohesion, separation):
            self.apply_force(force)

    def limit_force(self, force):
        """Return ``force`` rescaled to length ``max_force`` if it is longer, else unchanged."""
        length = np.linalg.norm(force)
        if length > max_force:
            return force / length * max_force
        return force

# Initialize fish
fishes = [Fish(np.random.uniform(1, 5)) for _ in range(initial_num_fish)]

def create_fish_traces():
    """Return a one-element list holding a marker scatter of every fish.

    Reads the module-level ``fishes`` list; x/y come from each fish's
    position, markers are small and blue, and hover labels are disabled.
    """
    xs = [fish.position[0] for fish in fishes]
    ys = [fish.position[1] for fish in fishes]
    scatter = go.Scatter(
        x=xs,
        y=ys,
        mode='markers',
        marker={'size': 5, 'color': 'blue'},
        hoverinfo='none',
    )
    return [scatter]

# Dash layout: a title, two sliders (speed and school size) each followed by a
# text readout, the scatter graph, and an Interval (interval=50 ms) whose ticks
# drive the simulation callback.
_TEXT_STYLE = {'color': 'white', 'textAlign': 'center'}
_SLIDER_TOOLTIP = {"placement": "bottom", "always_visible": True}

app.layout = html.Div(
    children=[
        html.H1("Fish School Simulation", style=_TEXT_STYLE),
        dcc.Slider(
            id='velocity-slider',
            min=1,
            max=10,
            step=0.5,
            value=5,
            marks={tick: str(tick) for tick in range(1, 11)},
            tooltip=_SLIDER_TOOLTIP,
        ),
        html.Div(id='slider-output-velocity', style=_TEXT_STYLE),
        dcc.Slider(
            id='num-fish-slider',
            min=50,
            max=500,
            step=50,
            value=initial_num_fish,
            marks={tick: str(tick) for tick in range(50, 501, 50)},
            tooltip=_SLIDER_TOOLTIP,
        ),
        html.Div(id='slider-output-num-fish', style=_TEXT_STYLE),
        dcc.Graph(id='school-graph', style={'height': '80vh'}),
        dcc.Interval(id='interval-component', interval=50, n_intervals=0),
    ],
    style={'backgroundColor': 'black', 'padding': '20px'},
)

# Define the callback to update the fish positions
@app.callback(
    [Output('school-graph', 'figure'), 
     Output('slider-output-velocity', 'children'), 
     Output('slider-output-num-fish', 'children')],
    [Input('interval-component', 'n_intervals'), 
     Input('velocity-slider', 'value'), 
     Input('num-fish-slider', 'value')]
)
def update_school(n, average_speed, num_fish):
    """Advance the simulation by one tick and redraw it.

    Triggered by the interval timer and by either slider.

    Args:
        n: Interval tick count. It is unused; it is only an input so that
            each tick fires the callback.
        average_speed: Velocity-slider value. Each fish's speed cap is drawn
            uniformly from [1, 2 * average_speed - 1], whose mean is
            ``average_speed``.
        num_fish: Num-fish-slider value. The school is grown or truncated to
            this size.

    Returns:
        A tuple of (figure, velocity label, fish-count label) for the three
        outputs.
    """
    global fishes

    # A range symmetric about average_speed makes the mean cap match the
    # label (uniform(1, 2*avg) would have mean avg + 0.5). Clamp so numpy
    # never receives high < low if average_speed ever drops below 1.
    low = 1.0
    high = max(low, 2 * average_speed - 1)

    # Grow or shrink the school to the requested size.
    missing = num_fish - len(fishes)
    if missing > 0:
        fishes += [Fish(np.random.uniform(low, high)) for _ in range(missing)]
    elif missing < 0:
        fishes = fishes[:num_fish]

    # Re-sample every fish's speed cap so slider changes take effect.
    for fish in fishes:
        fish.max_speed = np.random.uniform(low, high)

    # Two-phase step: every fish computes its steering from the same snapshot
    # of positions and velocities before any fish moves. Interleaving school()
    # with update() would let later fish react to neighbours that have
    # already moved this tick, so the result would depend on list order.
    for fish in fishes:
        fish.school(fishes)
    for fish in fishes:
        fish.update()
        fish.edges()

    fig = go.Figure(data=create_fish_traces())
    fig.update_layout(
        xaxis=dict(visible=False, range=[0, width]),
        yaxis=dict(visible=False, range=[0, height]),
        paper_bgcolor='black',
        plot_bgcolor='black',
        showlegend=False
    )
    return fig, f'Current average fish velocity: {average_speed}', f'Number of fish: {num_fish}'

# Run the app: start the Dash server with debug mode on when executed as a script.
# NOTE(review): `run_server` is deprecated in later Dash 2.x releases and removed
# in Dash 3 in favour of `app.run`; confirm against the Dash version this is pinned to.
if __name__ == '__main__':
    app.run_server(debug=True)