module mltree.tree_structure
Short summary
module mlinsights.mltree.tree_structure
Helpers to investigate a tree structure.
Functions
function | truncated documentation |
---|---|
_get_tree | Returns the tree object. |
predict_leaves | Returns the leaf each observation of X falls into. |
tree_find_common_node | Finds the node common to nodes i and j. |
tree_find_path_to_root | Lists the nodes on the path leading to node i. |
tree_leave_index | Returns the indices of every leaf in a tree. |
tree_leave_neighbors | The function determines which leaves are neighbors. The method uses some memory as it creates a grid of … |
tree_node_parents | Returns a dictionary {node_id: parent_id}. |
tree_node_range | Determines the range of a node along every dimension. |
Documentation
Helpers to investigate a tree structure.
- mlinsights.mltree.tree_structure._get_tree(obj)
Returns the tree object.
- mlinsights.mltree.tree_structure.predict_leaves(model, X)
Returns the leaf each observation of X falls into.
- Parameters:
model – a decision tree
X – observations
- Returns:
array of leaves
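The routing that predict_leaves relies on can be sketched in plain Python. This is a minimal sketch, not the library's implementation: it assumes the tree is stored as parallel arrays laid out like a fitted scikit-learn `tree_` (`children_left`, `children_right`, `feature`, `threshold`, with -1 marking a missing child), and `ToyTree`, `TREE` and `predict_leaves_sketch` are hypothetical names. It returns node ids, which may not match the indexing convention used by predict_leaves itself.

```python
from dataclasses import dataclass
from typing import List, Sequence


@dataclass
class ToyTree:
    """Hypothetical stand-in for the arrays of a fitted scikit-learn ``tree_``."""
    children_left: List[int]
    children_right: List[int]
    feature: List[int]
    threshold: List[float]


# Node 0 splits on x0 <= 0.5 (its left child 1 is a leaf),
# node 2 splits on x1 <= 1.5 (its children 3 and 4 are leaves).
TREE = ToyTree(
    children_left=[1, -1, 3, -1, -1],
    children_right=[2, -1, 4, -1, -1],
    feature=[0, -2, 1, -2, -2],
    threshold=[0.5, -2.0, 1.5, -2.0, -2.0],
)


def predict_leaves_sketch(tree: ToyTree, X: Sequence[Sequence[float]]) -> List[int]:
    """Returns, for every row of X, the id of the leaf it falls into."""
    leaves = []
    for x in X:
        node = 0  # start from the root
        while tree.children_left[node] != -1:  # -1 means the node is a leaf
            # scikit-learn sends x to the left child when x[feature] <= threshold
            if x[tree.feature[node]] <= tree.threshold[node]:
                node = tree.children_left[node]
            else:
                node = tree.children_right[node]
        leaves.append(node)
    return leaves


print(predict_leaves_sketch(TREE, [[0.0, 5.0], [1.0, 1.0], [1.0, 2.0]]))  # [1, 3, 4]
```

For a fitted scikit-learn tree, the estimator's own `apply(X)` method returns these leaf node ids directly.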
- mlinsights.mltree.tree_structure.tree_find_common_node(tree, i, j, parents=None)
Finds the node common to nodes i and j.
- Parameters:
tree – tree
i – node index (tree.nodes[i])
j – node index (tree.nodes[j])
parents – precomputed parents (None -> calls tree_node_parents)
- Returns:
common root, remaining path to i, remaining path to j
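A minimal sketch of how a common node can be found from precomputed parents, assuming a hypothetical `PARENTS` mapping `{node_id: parent_id}` and hypothetical helpers `ancestors` and `common_node`. It returns the deepest shared ancestor together with the two remaining paths, mirroring the documented return value; the library's own ordering and edge-case handling may differ.

```python
from typing import Dict, List, Tuple

# Hypothetical precomputed parents ({node_id: parent_id}) of a 5-node tree:
# node 0 is the root with children 1 and 2, node 2 has children 3 and 4.
PARENTS: Dict[int, int] = {1: 0, 2: 0, 3: 2, 4: 2}


def ancestors(node: int, parents: Dict[int, int]) -> List[int]:
    """Returns node, its parent, its grandparent, ... up to the root."""
    chain = [node]
    while chain[-1] in parents:  # the root has no parent entry
        chain.append(parents[chain[-1]])
    return chain


def common_node(i: int, j: int, parents: Dict[int, int]) -> Tuple[int, List[int], List[int]]:
    """Returns the deepest node shared by the root paths of i and j,
    followed by the remaining paths from that node down to i and to j."""
    up_i = ancestors(i, parents)
    up_j = ancestors(j, parents)
    seen = set(up_i)
    # Climbing from j, the first node that also lies above i is the common node.
    common = next(n for n in up_j if n in seen)
    # Keep the nodes strictly below the common node, ordered top-down.
    rest_i = up_i[:up_i.index(common)][::-1]
    rest_j = up_j[:up_j.index(common)][::-1]
    return common, rest_i, rest_j


print(common_node(3, 4, PARENTS))  # (2, [3], [4])
print(common_node(1, 4, PARENTS))  # (0, [1], [2, 4])
```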
- mlinsights.mltree.tree_structure.tree_find_path_to_root(tree, i, parents=None)
Lists the nodes on the path leading to node i.
- Parameters:
tree – tree
i – node index (tree.nodes[i])
parents – precomputed parents (None -> calls tree_node_parents)
- Returns:
list of the node indices on the path between node i and the root
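Climbing from a node to the root only needs the parents dictionary. The sketch below assumes a hypothetical precomputed `PARENTS` mapping and a hypothetical `path_to_root` helper; it lists the node first and the root last, which may differ from the order used by tree_find_path_to_root.

```python
from typing import Dict, List

# Hypothetical precomputed parents of a 5-node tree:
# 0 is the root, 1 and 2 are its children, 3 and 4 are children of 2.
PARENTS: Dict[int, int] = {1: 0, 2: 0, 3: 2, 4: 2}


def path_to_root(i: int, parents: Dict[int, int]) -> List[int]:
    """Lists the nodes met when climbing from node i up to the root."""
    path = [i]
    # The root is the only node without an entry in ``parents``.
    while path[-1] in parents:
        path.append(parents[path[-1]])
    return path


print(path_to_root(4, PARENTS))  # [4, 2, 0]
print(path_to_root(0, PARENTS))  # [0]
```

Reversing the list gives the root-to-node order, which is the natural order for comparing two paths.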
- mlinsights.mltree.tree_structure.tree_leave_index(model)
Returns the indices of every leaf in a tree.
- Parameters:
model – any object with a tree_ attribute
- Returns:
leaf indices
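In scikit-learn's tree arrays, a leaf is a node without children, encoded as -1 in `children_left`. The sketch below assumes that encoding; `get_tree` mimics what _get_tree is assumed to do (fall back to the object itself when it has no `tree_` attribute), and the mocked `model` built with `SimpleNamespace` is hypothetical.

```python
from types import SimpleNamespace
from typing import List


def get_tree(obj):
    """Returns obj.tree_ when it exists, otherwise obj itself (assumed behavior)."""
    return getattr(obj, "tree_", obj)


def leave_index(model) -> List[int]:
    """Returns the ids of all leaves, i.e. nodes without a left child."""
    tree = get_tree(model)
    # scikit-learn encodes a missing child as -1, and only leaves lack children.
    return [node for node, left in enumerate(tree.children_left) if left == -1]


# Hypothetical fitted model: only ``tree_.children_left`` is mocked.
model = SimpleNamespace(tree_=SimpleNamespace(children_left=[1, -1, 3, -1, -1]))
print(leave_index(model))  # [1, 3, 4]
```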
- mlinsights.mltree.tree_structure.tree_leave_neighbors(model)
The function determines which leaves are neighbors. The method may need a lot of memory because it creates a grid of the feature space, and each split multiplies the number of cells by two.
- Parameters:
model – a sklearn.tree.DecisionTreeRegressor, a sklearn.tree.DecisionTreeClassifier, or any model with a tree_ attribute
- Returns:
a dictionary {(i, j): (dimension, x1, x2)} where i and j (i < j) are the indices of two neighboring leaves: an observation falls into node i on one side of the border and into node j on the other. The border lies somewhere in the segment [x1, x2]. In the example below, each value is a list of such tuples and x1, x2 are points of the feature space.
The following example shows what the function returns for a simple two-dimensional grid.
<<<
import pprint

import numpy
from sklearn.tree import DecisionTreeClassifier
from mlinsights.mltree import tree_leave_neighbors

# a 3x3 grid of points, each point with its own class
X = numpy.array([[0, 0], [0, 1], [0, 2],
                 [1, 0], [1, 1], [1, 2],
                 [2, 0], [2, 1], [2, 2]])
y = list(range(X.shape[0]))
clr = DecisionTreeClassifier(max_depth=4)
clr.fit(X, y)
nei = tree_leave_neighbors(clr)
pprint.pprint(nei)
>>>
{(2, 4): [(0, (0.0, 0.0), (1.0, 0.0))],
 (2, 8): [(1, (0.0, 0.0), (0.0, 1.0))],
 (4, 5): [(0, (1.0, 0.0), (2.0, 0.0))],
 (4, 12): [(1, (1.0, 0.0), (1.0, 1.0))],
 (5, 13): [(1, (2.0, 0.0), (2.0, 1.0))],
 (8, 9): [(1, (0.0, 1.0), (0.0, 2.0))],
 (8, 12): [(0, (0.0, 1.0), (1.0, 1.0))],
 (9, 15): [(0, (0.0, 2.0), (1.0, 2.0))],
 (12, 13): [(0, (1.0, 1.0), (2.0, 1.0))],
 (12, 15): [(1, (1.0, 1.0), (1.0, 2.0))],
 (13, 16): [(1, (2.0, 1.0), (2.0, 2.0))],
 (15, 16): [(0, (1.0, 2.0), (2.0, 2.0))]}
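The grid idea described above can be sketched without scikit-learn. This is a simplified illustration under stated assumptions, not the library's algorithm: the toy arrays (`LEFT`, `RIGHT`, `FEATURE`, `THRESHOLD`) and all function names are hypothetical, each grid cell is represented by one arbitrary point (midpoints between consecutive thresholds, plus points one unit beyond the extreme thresholds), and two leaves are reported as neighbors when adjacent cells along one dimension fall into different leaves. The points it reports therefore differ from those in the output above, but each value has the same shape, a list of `(dimension, x1, x2)`.

```python
import itertools
from typing import Dict, List, Sequence, Tuple

# Hypothetical toy tree in the layout of scikit-learn's ``tree_`` (-1 marks a leaf):
# node 0 splits on x0 <= 0.5, node 2 splits on x1 <= 1.5; leaves are 1, 3 and 4.
LEFT = [1, -1, 3, -1, -1]
RIGHT = [2, -1, 4, -1, -1]
FEATURE = [0, -2, 1, -2, -2]
THRESHOLD = [0.5, -2.0, 1.5, -2.0, -2.0]
N_FEATURES = 2

Point = Tuple[float, ...]


def leaf_of(x: Sequence[float]) -> int:
    """Routes one observation from the root to its leaf."""
    node = 0
    while LEFT[node] != -1:
        node = LEFT[node] if x[FEATURE[node]] <= THRESHOLD[node] else RIGHT[node]
    return node


def cell_centers(d: int) -> List[float]:
    """Picks one coordinate inside every interval cut by the thresholds on feature d."""
    cuts = sorted({THRESHOLD[n] for n in range(len(LEFT))
                   if LEFT[n] != -1 and FEATURE[n] == d})
    if not cuts:
        return [0.0]  # feature never used: a single cell
    centers = [cuts[0] - 1.0]  # arbitrary point below the first threshold
    centers += [(a + b) / 2 for a, b in zip(cuts, cuts[1:])]
    centers.append(cuts[-1] + 1.0)  # arbitrary point above the last threshold
    return centers


def leave_neighbors() -> Dict[Tuple[int, int], List[Tuple[int, Point, Point]]]:
    axes = [cell_centers(d) for d in range(N_FEATURES)]
    result: Dict[Tuple[int, int], List[Tuple[int, Point, Point]]] = {}
    # Visit every cell of the grid; their number is the product of the axis sizes.
    for idx in itertools.product(*(range(len(axis)) for axis in axes)):
        point = tuple(axes[d][k] for d, k in enumerate(idx))
        for d in range(N_FEATURES):
            if idx[d] + 1 == len(axes[d]):
                continue  # last cell along dimension d, no neighbor after it
            step = list(idx)
            step[d] += 1
            other = tuple(axes[e][k] for e, k in enumerate(step))
            i, j = leaf_of(point), leaf_of(other)
            if i != j:  # adjacent cells in different leaves: the leaves touch
                result.setdefault((min(i, j), max(i, j)), []).append((d, point, other))
    return result


nei = leave_neighbors()
print(sorted(nei))  # [(1, 3), (1, 4), (3, 4)]
```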
- mlinsights.mltree.tree_structure.tree_node_parents(tree)
Returns a dictionary {node_id: parent_id}.
- Parameters:
tree – tree
- Returns:
parents
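Building the `{node_id: parent_id}` dictionary only requires the child arrays. A minimal sketch, assuming scikit-learn's encoding where -1 marks a missing child; `node_parents` is a hypothetical name, and in this sketch the root simply has no entry.

```python
from typing import Dict, Sequence


def node_parents(children_left: Sequence[int], children_right: Sequence[int]) -> Dict[int, int]:
    """Returns {node_id: parent_id}; the root gets no entry."""
    parents: Dict[int, int] = {}
    for node, (left, right) in enumerate(zip(children_left, children_right)):
        # -1 means "no child": in scikit-learn's encoding the node is a leaf.
        if left != -1:
            parents[left] = node
        if right != -1:
            parents[right] = node
    return parents


# Hypothetical 5-node tree: 0 -> (1, 2), 2 -> (3, 4).
print(node_parents([1, -1, 3, -1, -1], [2, -1, 4, -1, -1]))  # {1: 0, 2: 0, 3: 2, 4: 2}
```

With a fitted estimator, the arrays would come from `model.tree_.children_left` and `model.tree_.children_right`.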
- mlinsights.mltree.tree_structure.tree_node_range(tree, i, parents=None)
Determines the range of a node along every dimension. nan means infinity, that is, the node is unbounded on that side.
- Parameters:
tree – tree
i – node index (tree.nodes[i])
parents – precomputed parents (None -> calls tree_node_parents)
- Returns:
one array of size (D, 2) where D is the number of dimensions; row d holds the lower and upper bounds along dimension d
The following example shows what the function returns for a simple two-dimensional grid.
<<<
import numpy
from sklearn.tree import DecisionTreeClassifier
from mlinsights.mltree import tree_leave_index, tree_node_range

# a 3x3 grid of points, each point with its own class
X = numpy.array([[0, 0], [0, 1], [0, 2],
                 [1, 0], [1, 1], [1, 2],
                 [2, 0], [2, 1], [2, 2]])
y = list(range(X.shape[0]))
clr = DecisionTreeClassifier(max_depth=4)
clr.fit(X, y)
leaves = tree_leave_index(clr)
# range of the first leaf
ra = tree_node_range(clr, leaves[0])
print(ra)
>>>
[[nan 0.5]
 [nan 0.5]]
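The range of a node can be recovered by climbing from the node to the root: a left child inherits the parent's threshold as an upper bound, a right child as a lower bound, and the deepest split on a feature gives the tightest bound. The sketch below assumes a hypothetical toy tree stored as scikit-learn-style arrays and a hypothetical `node_range` helper; like the output above, it uses nan for an unbounded side and one `[lower, upper]` row per feature, but it returns nested lists rather than a numpy array.

```python
import math
from typing import Dict, List

# Hypothetical toy tree in the layout of scikit-learn's ``tree_`` (-1 marks a leaf):
# node 0 splits on x0 <= 0.5, node 2 splits on x1 <= 1.5.
LEFT = [1, -1, 3, -1, -1]
RIGHT = [2, -1, 4, -1, -1]
FEATURE = [0, -2, 1, -2, -2]
THRESHOLD = [0.5, -2.0, 1.5, -2.0, -2.0]
N_FEATURES = 2


def node_range(i: int) -> List[List[float]]:
    """Returns one [lower, upper] pair per feature; nan marks an unbounded side."""
    parents: Dict[int, int] = {}
    for node, (left, right) in enumerate(zip(LEFT, RIGHT)):
        if left != -1:  # internal nodes always have both children
            parents[left] = node
            parents[right] = node
    bounds = [[math.nan, math.nan] for _ in range(N_FEATURES)]
    child = i
    # Climb to the root; the first bound found per side is the deepest, hence tightest.
    while child in parents:
        parent = parents[child]
        d, t = FEATURE[parent], THRESHOLD[parent]
        # left child: x[d] <= t gives an upper bound; right child: x[d] > t a lower one
        side = 1 if child == LEFT[parent] else 0
        if math.isnan(bounds[d][side]):
            bounds[d][side] = t
        child = parent
    return bounds


print(node_range(1))  # [[nan, 0.5], [nan, nan]]
print(node_range(3))  # [[0.5, nan], [nan, 1.5]]
```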