Tree manipulation
Functions to create and manipulate binary decision trees.
- class bartz.grove.TreeHeaps(*args, **kwargs)
A protocol for dataclasses that represent trees.
A tree is represented with arrays as a heap. The root node is at index 1. The children nodes of a node at index \(i\) are at indices \(2i\) (left child) and \(2i + 1\) (right child). The array element at index 0 is unused.
- Parameters:
leaf_tree – The values in the leaves of the trees. This array can be dirty, i.e., unused nodes can hold arbitrary values.
var_tree – The axes along which the decision nodes operate. This array can be dirty, except for the always-unused node at index 0, which must be set to 0.
split_tree – The decision boundaries of the trees. The boundaries are open on the right, i.e., a point belongs to the left child iff x < split. A node is a leaf iff its element in split_tree is 0. Unused nodes also have split set to 0. This array can’t be dirty.
Notes
Since the nodes at the bottom can only be leaves and not decision nodes, var_tree and split_tree are half as long as leaf_tree.
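The heap layout described above can be sketched in plain Python. The helper names and the example tree are hypothetical illustrations, and lists stand in for the arrays used by the library.

```python
def children(i):
    """Return the heap indices of the left and right children of node i."""
    return 2 * i, 2 * i + 1


def parent(i):
    """Return the heap index of the parent of node i (the root is at index 1)."""
    return i // 2


# A hypothetical tree of maximum depth 3. Index 0 is unused in every array.
# Node 1 (root) splits on variable 1 at 5, node 2 is a leaf,
# node 3 splits on variable 0 at 3, nodes 6 and 7 are bottom-level leaves.
leaf_tree = [0.0, 0.0, -1.0, 0.0, 0.0, 0.0, 0.5, 2.0]  # length 2**3
var_tree = [0, 1, 0, 0]                                # length 2**2
split_tree = [0, 5, 0, 3]                              # length 2**2
```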
- bartz.grove.make_tree(depth, dtype)
Make an array to represent a binary tree.
- Parameters:
depth (int) – The maximum depth of the tree. Depth 1 means that there is only a root node.
dtype (str | type[Any] | dtype | SupportsDType) – The dtype of the array.
- Returns:
Shaped[Array, '2**{depth}'] – An array of zeroes with the appropriate shape.
- bartz.grove.tree_depth(tree)
Return the maximum depth of a tree.
- Parameters:
tree (Shaped[Array, '* 2**d']) – A tree created by make_tree. If the array is ND, the tree structure is assumed to be along the last axis.
- Returns:
int – The maximum depth of the tree.
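The relation between depth and array length can be sketched in plain Python. The two helpers below are hypothetical stand-ins that mimic the documented behavior with lists; they are not the library functions.

```python
def make_tree_sketch(depth):
    """Return a list of 2**depth zeros (index 0 is the unused slot)."""
    return [0] * (2 ** depth)


def tree_depth_sketch(tree):
    """Recover the depth from the length, assumed to be a power of 2."""
    # For a power of two, bit_length() - 1 is the exact base-2 logarithm.
    return len(tree).bit_length() - 1


tree = make_tree_sketch(3)
```

With depth 1 the list has two slots: the unused index 0 and the root at index 1.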
- bartz.grove.traverse_tree(x, var_tree, split_tree)
Find the leaf that a point falls into.
- Parameters:
x (Real[Array, 'p']) – The coordinates to evaluate the tree at.
var_tree (UInt[Array, '2**(d-1)']) – The decision axes of the tree.
split_tree (UInt[Array, '2**(d-1)']) – The decision boundaries of the tree.
- Returns:
Int32[Array, ''] – The index of the leaf.
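As a rough illustration of the traversal rule (go left iff x < split, stop at a node whose split is 0), here is a minimal plain-Python sketch. The function name and the example tree are hypothetical, and the loop is only one way to express the rule, not necessarily how the library does it.

```python
def traverse_tree_sketch(x, var_tree, split_tree):
    """Return the heap index of the leaf that the point x falls into."""
    index = 1  # start at the root
    # Nodes at the bottom level lie past the end of split_tree and are leaves.
    while index < len(split_tree) and split_tree[index] != 0:
        go_right = x[var_tree[index]] >= split_tree[index]
        index = 2 * index + int(go_right)
    return index


# Hypothetical tree: root splits on variable 1 at 5, node 3 on variable 0 at 3.
var_tree = [0, 1, 0, 0]
split_tree = [0, 5, 0, 3]

leaf_a = traverse_tree_sketch([4, 2], var_tree, split_tree)  # 2 < 5: left
leaf_b = traverse_tree_sketch([4, 7], var_tree, split_tree)  # right, then right
leaf_c = traverse_tree_sketch([2, 5], var_tree, split_tree)  # right, then left
```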
- bartz.grove.traverse_forest(X, var_trees, split_trees)
Find the leaves that points fall into.
- Parameters:
X (Real[Array, 'p n']) – The coordinates to evaluate the trees at.
var_trees (UInt[Array, 'm 2**(d-1)']) – The decision axes of the trees.
split_trees (UInt[Array, 'm 2**(d-1)']) – The decision boundaries of the trees.
- Returns:
Int32[Array, 'm n'] – The indices of the leaves.
- bartz.grove.evaluate_forest(X, trees, *, sum_trees=True)
Evaluate an ensemble of trees at an array of points.
- Parameters:
X (UInt[Array, 'p n']) – The coordinates to evaluate the trees at.
trees (TreeHeaps) – The tree heaps, with batch shape (m,).
sum_trees (bool, default: True) – Whether to sum the values across trees.
- Returns:
Float32[Array, 'n'] | Float32[Array, 'm n'] – The values of the trees at the points in X, summed across trees if sum_trees is True.
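Conceptually, evaluating a forest means finding the leaf each point reaches in each tree, reading the leaf values, and optionally summing over trees. The sketch below illustrates this with plain lists. The function names are hypothetical, and the three separate list arguments are an assumption standing in for the fields of a TreeHeaps object.

```python
def find_leaf(x, var_tree, split_tree):
    """Walk down from the root (index 1) until reaching a leaf."""
    index = 1
    while index < len(split_tree) and split_tree[index] != 0:
        index = 2 * index + int(x[var_tree[index]] >= split_tree[index])
    return index


def evaluate_forest_sketch(X, leaf_trees, var_trees, split_trees, sum_trees=True):
    """Evaluate m trees at n points; X[j][k] is coordinate j of point k."""
    points = list(zip(*X))  # n points, each a tuple of p coordinates
    values = [
        [leaf[find_leaf(x, var, split)] for x in points]
        for leaf, var, split in zip(leaf_trees, var_trees, split_trees)
    ]
    if sum_trees:
        return [sum(column) for column in zip(*values)]  # shape (n,)
    return values  # shape (m, n)


# Two hypothetical trees: one with two decision nodes, one with a root leaf only.
leaf_trees = [[0.0, 0.0, -1.0, 0.0, 0.0, 0.0, 0.5, 2.0],
              [0.0, 10.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]
var_trees = [[0, 1, 0, 0], [0, 0, 0, 0]]
split_trees = [[0, 5, 0, 3], [0, 0, 0, 0]]
X = [[4, 4, 2], [2, 7, 5]]  # p = 2 variables, n = 3 points

summed = evaluate_forest_sketch(X, leaf_trees, var_trees, split_trees)
per_tree = evaluate_forest_sketch(X, leaf_trees, var_trees, split_trees,
                                  sum_trees=False)
```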
- bartz.grove.is_actual_leaf(split_tree, *, add_bottom_level=False)
Return a mask indicating the leaf nodes in a tree.
- Parameters:
split_tree (UInt[Array, '2**(d-1)']) – The splitting points of the tree.
add_bottom_level (bool, default: False) – If True, the bottom level of the tree is also considered.
- Returns:
Bool[Array, '2**(d-1)'] | Bool[Array, '2**d'] – The mask marking the leaf nodes. The length is doubled if add_bottom_level is True.
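Because unused nodes also have split 0, a zero split alone does not identify a leaf. The sketch below assumes that an actual leaf is a node with split 0 that is either the root or the child of a decision node; this reading, the function name, and the example tree are assumptions made for illustration.

```python
def is_actual_leaf_sketch(split_tree, add_bottom_level=False):
    """Return a list of booleans marking the leaves that are part of the tree."""
    splits = list(split_tree)
    if add_bottom_level:
        # Bottom-level nodes have no split entry; they can only be leaves.
        splits = splits + [0] * len(splits)
    mask = []
    for i, split in enumerate(splits):
        if i == 0:
            mask.append(False)  # index 0 is never used
        elif i == 1:
            mask.append(split == 0)  # the root is a leaf if it does not split
        else:
            mask.append(split == 0 and split_tree[i // 2] != 0)
    return mask


split_tree = [0, 5, 0, 3]  # hypothetical tree: nodes 1 and 3 are decision nodes
short_mask = is_actual_leaf_sketch(split_tree)
full_mask = is_actual_leaf_sketch(split_tree, add_bottom_level=True)
```

In this example, nodes 4 and 5 are not marked because their parent (node 2) is itself a leaf.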
- bartz.grove.is_leaves_parent(split_tree)
Return a mask indicating the nodes with leaf (and only leaf) children.
- Parameters:
split_tree (UInt[Array, '2**(d-1)']) – The decision boundaries of the tree.
- Returns:
Bool[Array, '2**(d-1)'] – The mask indicating which nodes have leaf children.
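A minimal plain-Python sketch of this mask follows, assuming that children at the bottom level (beyond the end of split_tree) always count as leaves. The function name and the example are hypothetical.

```python
def is_leaves_parent_sketch(split_tree):
    """Mark decision nodes whose two children are both leaves."""
    size = len(split_tree)

    def is_leaf(i):
        # Nodes past the end of split_tree are at the bottom level.
        return i >= size or split_tree[i] == 0

    return [
        i > 0 and split_tree[i] != 0 and is_leaf(2 * i) and is_leaf(2 * i + 1)
        for i in range(size)
    ]


# Hypothetical tree: the root has a decision node as right child, so only
# node 3 has leaf (and only leaf) children.
parents_mask = is_leaves_parent_sketch([0, 5, 0, 3])
```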
- bartz.grove.tree_depths(tree_length)
Return the depth of each node in a binary tree.
- Parameters:
tree_length (int) – The length of the tree array, i.e., 2 ** d.
- Returns:
Int32[Array, '{tree_length}'] – The depth of each node.
Notes
The root node (index 1) has depth 0. The depth is the position of the most significant non-zero bit in the index. The first element (the unused node) is marked as depth 0.
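The rule in the note can be checked with Python's built-in `int.bit_length`, which gives the 1-based position of the most significant set bit. This is a sketch with a hypothetical function name, returning a list rather than an array.

```python
def tree_depths_sketch(tree_length):
    """Return the depth of each heap index, with the unused index 0 set to 0."""
    # bit_length() - 1 is the 0-based position of the most significant bit;
    # for index 0 it would be -1, so clamp it to 0.
    return [max(i.bit_length() - 1, 0) for i in range(tree_length)]


depths = tree_depths_sketch(8)
```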
- bartz.grove.is_used(split_tree)
Return a mask indicating the used nodes in a tree.
- Parameters:
split_tree (UInt[Array, '2**(d-1)']) – The decision boundaries of the tree.
- Returns:
Bool[Array, '2**d'] – A mask indicating which nodes are actually used.
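One way to read "used" is: a node is used if it is a decision node, or if it is a leaf hanging from a decision node (or the root). The sketch below encodes that assumption in plain Python; the function name and example are hypothetical, and the output covers all 2**d nodes, bottom level included.

```python
def is_used_sketch(split_tree):
    """Mark decision nodes and actual leaves, including the bottom level."""
    size = len(split_tree)
    mask = [False] * (2 * size)  # index 0 stays False: it is never used
    for i in range(1, 2 * size):
        is_decision = i < size and split_tree[i] != 0
        parent_is_decision = i == 1 or split_tree[i // 2] != 0
        mask[i] = is_decision or parent_is_decision
    return mask


# Hypothetical tree: nodes 1 and 3 split, so nodes 1, 2, 3, 6, 7 are used.
used = is_used_sketch([0, 5, 0, 3])
```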
- bartz.grove.forest_fill(split_tree)
Return the fraction of used nodes in a set of trees.
- Parameters:
split_tree (UInt[Array, 'num_trees 2**(d-1)']) – The decision boundaries of the trees.
- Returns:
Float32[Array, ''] – Number of tree nodes over the maximum number that could be stored.
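The ratio can be illustrated with a small plain-Python sketch. It assumes that the maximum number of nodes per tree is 2**d - 1, i.e., that the unused slot at index 0 is excluded from the denominator; that choice, like the function names, is an assumption made for illustration and may differ from the library.

```python
def count_used(split_tree):
    """Count decision nodes plus the leaves attached to them (or the root)."""
    size = len(split_tree)
    count = 0
    for i in range(1, 2 * size):
        is_decision = i < size and split_tree[i] != 0
        parent_is_decision = i == 1 or split_tree[i // 2] != 0
        count += is_decision or parent_is_decision
    return count


def forest_fill_sketch(split_trees):
    """Fraction of used nodes over all trees (assumed denominator: 2**d - 1 each)."""
    capacity = sum(2 * len(split_tree) - 1 for split_tree in split_trees)
    return sum(count_used(split_tree) for split_tree in split_trees) / capacity


# Two hypothetical trees with 5 and 1 used nodes, out of 7 slots each.
fill = forest_fill_sketch([[0, 5, 0, 3], [0, 0, 0, 0]])
```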
- bartz.grove.var_histogram(p, var_tree, split_tree)
Count how many times each variable appears in a tree.
- Parameters:
p (int) – The number of variables (the maximum value that can occur in var_tree is p - 1).
var_tree (UInt[Array, '* 2**(d-1)']) – The decision axes of the tree.
split_tree (UInt[Array, '* 2**(d-1)']) – The decision boundaries of the tree.
- Returns:
Int32[Array, '{p}'] – The histogram of the variables used in the tree.
Notes
If there are leading axes in the tree arrays (i.e., multiple trees), the returned counts are cumulative over trees.
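The counting can be sketched in plain Python as follows. The sketch assumes that only decision nodes (non-zero split) contribute to the histogram, which is why split_tree is needed: var_tree may contain arbitrary values at other nodes. The function name and the example trees are hypothetical.

```python
def var_histogram_sketch(p, var_trees, split_trees):
    """Count how often each of the p variables is used, cumulated over trees."""
    counts = [0] * p
    for var_tree, split_tree in zip(var_trees, split_trees):
        for var, split in zip(var_tree, split_tree):
            if split != 0:  # skip leaves and unused nodes
                counts[var] += 1
    return counts


# Two hypothetical trees: the first uses variables 1 and 0 once each,
# the second uses variable 1 twice.
var_trees = [[0, 1, 0, 0], [0, 1, 1, 0]]
split_trees = [[0, 5, 0, 3], [0, 2, 4, 0]]
histogram = var_histogram_sketch(2, var_trees, split_trees)
```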