Visualiser un arbre de décision

Links: notebook, html, PDF, python, slides, GitHub

Les arbres de décision sont des modèles intéressants car ils peuvent être interprétés. Encore faut-il pouvoir les voir.

from sklearn import datasets
iris = datasets.load_iris()
X = iris.data[:, :2]  # we only take the first two features.
y = iris.target
from sklearn.tree import DecisionTreeClassifier
clf = DecisionTreeClassifier()
clf.fit(X, y)
DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,
            max_features=None, max_leaf_nodes=None,
            min_impurity_decrease=0.0, min_impurity_split=None,
            min_samples_leaf=1, min_samples_split=2,
            min_weight_fraction_leaf=0.0, presort=False, random_state=None,
            splitter='best')

scikit-learn implémente une méthode qui permet d’exporter de graphe au format DOT : export_graphviz. Ce graphe peut être visualiser avec l’outil graphviz ou des modules comme pydot mais cela passe par l’installation graphviz.

from sklearn.tree import export_graphviz
dot = export_graphviz(clf, out_file=None)
print("\n".join(dot.split('\n')[:10]) + "\n...")
digraph Tree {
node [shape=box] ;
0 [label="X[0] <= 5.45ngini = 0.667nsamples = 150nvalue = [50, 50, 50]"] ;
1 [label="X[1] <= 2.8ngini = 0.237nsamples = 52nvalue = [45, 6, 1]"] ;
0 -> 1 [labeldistance=2.5, labelangle=45, headlabel="True"] ;
2 [label="X[0] <= 4.7ngini = 0.449nsamples = 7nvalue = [1, 5, 1]"] ;
1 -> 2 ;
3 [label="gini = 0.0nsamples = 1nvalue = [1, 0, 0]"] ;
2 -> 3 ;
4 [label="X[0] <= 4.95ngini = 0.278nsamples = 6nvalue = [0, 5, 1]"] ;
...

La libraire viz.js est une version javascript de graphviz. Avec un wrapper disponible RenderJsDot, cela devient :

from jyquickhelper import RenderJsDot
RenderJsDot(dot)

C’est encore lisible mais cela risque de ne plus le devenir pour de gros arbres. On utilise alors la librairie vis.js et le wrapper RenderJsVis.

from jyquickhelper import RenderJsVis
RenderJsVis(dot=dot, height="400px", layout='hierarchical')