Coverage for src/bartz/grove.py: 100%
73 statements
« prev ^ index » next coverage.py v7.8.2, created at 2025-05-29 23:01 +0000
« prev ^ index » next coverage.py v7.8.2, created at 2025-05-29 23:01 +0000
1# bartz/src/bartz/grove.py
2#
3# Copyright (c) 2024-2025, Giacomo Petrillo
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25"""
27Functions to create and manipulate binary trees.
29A tree is represented with arrays as a heap. The root node is at index 1. The children nodes of a node at index :math:`i` are at indices :math:`2i` (left child) and :math:`2i + 1` (right child). The array element at index 0 is unused.
31A decision tree is represented by tree arrays: 'leaf', 'var', and 'split'.
33The 'leaf' array contains the values in the leaves.
35The 'var' array contains the axes along which the decision nodes operate.
37The 'split' array contains the decision boundaries. The boundaries are open on the right, i.e., a point belongs to the left child iff x < split. Whether a node is a leaf is indicated by the corresponding 'split' element being 0. Unused nodes also have split set to 0.
39Since the nodes at the bottom can only be leaves and not decision nodes, the 'var' and 'split' arrays have half the length of the 'leaf' array.
41"""
43import functools 1ab
44import math 1ab
46import jax 1ab
47from jax import lax 1ab
48from jax import numpy as jnp 1ab
50from . import jaxext 1ab
def make_tree(depth, dtype):
    """
    Allocate the heap array backing a binary tree.

    Parameters
    ----------
    depth : int
        Maximum depth of the tree; a depth of 1 is a tree with just the
        root node.
    dtype : dtype
        Element type of the returned array.

    Returns
    -------
    tree : array
        A zero-filled array of shape (2 ** depth,).
    """
    n_slots = 2**depth
    return jnp.zeros(n_slots, dtype)
def tree_depth(tree):
    """
    Compute the maximum depth a tree array can hold.

    Parameters
    ----------
    tree : array
        A tree created by `make_tree`. For an ND array, the tree structure
        is taken to lie along the last axis.

    Returns
    -------
    depth : int
        The base-2 logarithm of the length of the last axis, rounded to the
        nearest integer.
    """
    length = tree.shape[-1]
    return int(round(math.log2(length)))
def traverse_tree(x, var_tree, split_tree):
    """
    Locate the leaf of a tree that contains a given point.

    Parameters
    ----------
    x : array (p,)
        The coordinates of the point.
    var_tree : array (2 ** (d - 1),)
        The decision axes of the tree.
    split_tree : array (2 ** (d - 1),)
        The decision boundaries of the tree.

    Returns
    -------
    index : int
        The heap index of the leaf reached by `x`.
    """

    def descend(state, _):
        done, node = state

        threshold = split_tree[node]
        axis = var_tree[node]

        # a zero split marks a leaf; once one is reached the node stops moving
        done = done | (threshold == 0)
        # children of node i are 2i (x < split) and 2i + 1 (x >= split)
        child = (node << 1) + (x[axis] >= threshold)
        node = jnp.where(done, node, child)

        return (done, node), None

    # start at the root (index 1) with no leaf found yet; the index dtype is
    # chosen from 2 * size - 1, the last index of the full leaf array
    index_dtype = jaxext.minimal_unsigned_dtype(2 * var_tree.size - 1)
    start = (jnp.zeros((), bool), jnp.ones((), index_dtype))

    n_steps = tree_depth(var_tree)
    (_, index), _ = lax.scan(descend, start, None, n_steps, unroll=16)
    return index
# outer map: over trees (axis 0 of `var_trees` and `split_trees`), `X` shared
@functools.partial(jaxext.vmap_nodoc, in_axes=(None, 0, 0))
# inner map: over points (axis 1 of `X`), tree arrays shared
@functools.partial(jaxext.vmap_nodoc, in_axes=(1, None, None))
def traverse_forest(X, var_trees, split_trees):
    """
    Locate the leaves that a set of points reach in each tree of a forest.

    Parameters
    ----------
    X : array (p, n)
        The coordinates of the points, one point per column.
    var_trees : array (m, 2 ** (d - 1))
        The decision axes of the trees.
    split_trees : array (m, 2 ** (d - 1))
        The decision boundaries of the trees.

    Returns
    -------
    indices : array (m, n)
        The heap indices of the leaves, by tree and by point.
    """
    # NOTE(review): `vmap_nodoc` is presumably `jax.vmap` without docstring
    # rewriting, so this body sees one point and one tree at a time -- confirm
    leaf_index = traverse_tree(X, var_trees, split_trees)
    return leaf_index
def evaluate_forest(X, leaf_trees, var_trees, split_trees, dtype=None, sum_trees=True):
    """
    Evaluate an ensemble of trees at an array of points.

    Parameters
    ----------
    X : array (p, n)
        The coordinates to evaluate the trees at.
    leaf_trees : array (m, 2 ** d)
        The leaf values of the trees; the first axis is the tree index.
    var_trees : array (m, 2 ** (d - 1))
        The decision axes of the trees.
    split_trees : array (m, 2 ** (d - 1))
        The decision boundaries of the trees.
    dtype : dtype, optional
        The dtype of the output. Ignored if `sum_trees` is `False`.
    sum_trees : bool, default True
        Whether to sum the values across trees.

    Returns
    -------
    out : array (n,) or (m, n)
        The values of the trees at the points in `X`, summed over trees if
        `sum_trees` is `True`, otherwise one row per tree.
    """
    leaf_indices = traverse_forest(X, var_trees, split_trees)
    num_trees, _ = leaf_trees.shape
    tree_ids = jnp.arange(num_trees, dtype=jaxext.minimal_unsigned_dtype(num_trees - 1))
    # pair each tree id (as a column) with that tree's row of leaf indices
    values = leaf_trees[tree_ids[:, None], leaf_indices]
    if not sum_trees:
        return values
    # this sum suggests to swap the vmaps, but I think it's better for X
    # copying to keep it that way
    return jnp.sum(values, axis=0, dtype=dtype)
def is_actual_leaf(split_tree, *, add_bottom_level=False):
    """
    Build a mask that marks the leaf nodes of a tree.

    A node counts as a leaf when its own split is 0 and its parent's split
    is nonzero; the node at index 1 (the root) only needs a zero split.

    Parameters
    ----------
    split_tree : int array (2 ** (d - 1),)
        The splitting points of the tree.
    add_bottom_level : bool, default False
        If True, the bottom level of the tree is also considered.

    Returns
    -------
    is_actual_leaf : bool array (2 ** (d - 1) or 2 ** d,)
        The mask indicating the leaf nodes. The length is doubled if
        `add_bottom_level` is True.
    """
    no_split = split_tree == 0
    length = split_tree.size
    if add_bottom_level:
        # bottom-level nodes have no split entry, so they are all candidates
        no_split = jnp.concatenate([no_split, jnp.ones_like(no_split)])
        length *= 2
    node = jnp.arange(length, dtype=jaxext.minimal_unsigned_dtype(length - 1))
    parent = node >> 1
    parent_splits = split_tree[parent].astype(bool)
    # index 1 is the root: force its "parent" to count as a decision node
    parent_splits = parent_splits.at[1].set(True)
    return no_split & parent_splits
def is_leaves_parent(split_tree):
    """
    Build a mask that marks the decision nodes whose children are both leaves.

    Parameters
    ----------
    split_tree : int array (2 ** (d - 1),)
        The decision boundaries of the tree.

    Returns
    -------
    is_leaves_parent : bool array (2 ** (d - 1),)
        The mask indicating which nodes have leaf (and only leaf) children.
    """
    n_nodes = split_tree.size
    # the dtype is chosen from 2 * n_nodes - 1, the largest child index below
    node = jnp.arange(n_nodes, dtype=jaxext.minimal_unsigned_dtype(2 * n_nodes - 1))
    left_child = node << 1
    right_child = left_child + 1

    def child_is_leaf(child):
        # children past the end of the array read as split == 0, i.e. leaves
        return split_tree.at[child].get(mode='fill', fill_value=0) == 0

    left_leaf = child_is_leaf(left_child)
    right_leaf = child_is_leaf(right_child)
    has_split = split_tree.astype(bool)
    # the 0-th item has split == 0, so it's not counted
    return has_split & left_leaf & right_leaf
def tree_depths(tree_length):
    """
    Compute the depth of every node slot in a binary tree array.

    Parameters
    ----------
    tree_length : int
        The length of the tree array, i.e., 2 ** d.

    Returns
    -------
    depth : array (tree_length,)
        The depth of each node. The root node (index 1) has depth 0. The depth
        is the position of the most significant non-zero bit in the index. The
        first element (the unused node) is marked as depth 0.
    """
    # for i >= 1, the position of the highest set bit is bit_length - 1
    depths = [i.bit_length() - 1 for i in range(tree_length)]
    # index 0 has no set bit and would come out as -1
    depths[0] = 0
    return jnp.array(depths, jaxext.minimal_unsigned_dtype(max(depths)))
def is_used(split_tree):
    """
    Build a mask that marks the nodes of a tree that are in use.

    A node is in use if it is either a decision node (nonzero split) or an
    actual leaf as reported by `is_actual_leaf`.

    Parameters
    ----------
    split_tree : int array (2 ** (d - 1),)
        The decision boundaries of the tree.

    Returns
    -------
    is_used : bool array (2 ** d,)
        A mask indicating which nodes are actually used.
    """
    has_split = split_tree.astype(bool)
    # pad with False for the bottom level, which has no split entries
    decision_nodes = jnp.concatenate([has_split, jnp.zeros_like(has_split)])
    leaves = is_actual_leaf(split_tree, add_bottom_level=True)
    return decision_nodes | leaves
def forest_fill(split_trees):
    """
    Compute the fraction of node slots occupied across a set of trees.

    Parameters
    ----------
    split_trees : array (m, 2 ** (d - 1),)
        The decision boundaries of the trees.

    Returns
    -------
    fill : float
        The number of tree nodes in the forest over the maximum number that
        could be stored in the arrays.
    """
    num_trees, _ = split_trees.shape
    used_mask = jax.vmap(is_used)(split_trees)
    num_used = jnp.count_nonzero(used_mask)
    # one slot per tree (index 0) is never a node, so exclude it from capacity
    capacity = used_mask.size - num_trees
    return num_used / capacity