Py.Cafe

maartenbreddels/

drawdata-sklearn-demo

Interactive Decision Boundary Visualization with sklearn

Docs | Pricing
  • app.py
  • requirements.txt
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from drawdata import ScatterWidget

import matplotlib.pyplot as plt
from IPython.core.display import HTML
from sklearn.linear_model import LogisticRegression
from sklearn.inspection import DecisionBoundaryDisplay
from sklearn.tree import DecisionTreeClassifier

# NOTE: this rebinds ``plt`` from matplotlib.pyplot (above) to matplotlib.pylab.
import matplotlib.pylab as plt
import numpy as np
import ipywidgets

# Drawing canvas the user adds points to. It is created exactly once: an
# earlier duplicate ``ScatterWidget()`` was overwritten here before ever being
# used.
widget = ScatterWidget()
# Output area that captures everything ``on_change`` displays.
output = ipywidgets.Output()


@output.capture(clear_output=True)
def on_change(change):
    """Refit the classifier on the drawn points and redraw its decision boundary.

    Everything displayed inside this function is captured by ``output``,
    which is cleared at the start of every call (``clear_output=True``).
    Nothing is drawn until the widget holds points of at least two
    different colors.

    Args:
        change: Change notification passed by the observer machinery, or
            ``None`` when called directly. It is not used; the current
            points are always re-read from ``widget``.
    """
    # ``display`` is only available as a bare name where IPython has injected
    # it into builtins, so import it explicitly.
    from IPython.display import display

    df = widget.data_as_pandas
    # Only touch the 'color' column when there are rows to look at.
    n_classes = df['color'].nunique() if len(df) else 0
    if n_classes < 2:
        return

    X = df[['x', 'y']].values
    y = df['color']
    display(HTML("<br><br><br>"))

    # Create the axes up front and hand them to the display. Without ``ax=``
    # the display opens its own default-sized figure, leaving a blank 12x12
    # figure behind.
    _, ax = plt.subplots(figsize=(12, 12))
    classifier = DecisionTreeClassifier().fit(X, y)
    DecisionBoundaryDisplay.from_estimator(
        classifier, X,
        # Probabilities for two classes; hard predictions for more.
        response_method="predict_proba" if n_classes == 2 else "predict",
        xlabel="x", ylabel="y",
        alpha=0.5,
        ax=ax,
    )
    # Overlay the drawn points, colored by their 'color' value.
    ax.scatter(X[:, 0], X[:, 1], c=y, edgecolor="k")
    ax.set_title(f"{classifier.__class__.__name__}")
    plt.show()

# Render once for whatever the widget already holds, then re-render every
# time its ``data`` trait changes.
on_change(None)
widget.observe(on_change, names=["data"])

# Solara renders whatever is bound to the module-level name ``page``:
# the drawing widget on the left, the captured plot output on the right.
page = ipywidgets.HBox(children=(widget, output))
requirements.txt
1
2
3
4
5
6
7
numpy
drawdata
matplotlib
scikit-learn
pandas