Coverage for src/bartz/grove.py: 100%
83 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-27 14:46 +0000
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-27 14:46 +0000
1# bartz/src/bartz/grove.py
2#
3# Copyright (c) 2024-2025, Giacomo Petrillo
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25"""Functions to create and manipulate binary decision trees."""
27import math 1ab
28from functools import partial 1ab
29from typing import Protocol 1ab
31import jax 1ab
32from jax import jit, lax 1ab
33from jax import numpy as jnp 1ab
34from jaxtyping import Array, Bool, DTypeLike, Float32, Int32, Real, Shaped, UInt 1ab
36from bartz.jaxext import minimal_unsigned_dtype, vmap_nodoc 1ab
class TreeHeaps(Protocol):
    """Structural interface for dataclasses that store trees as heap arrays.

    Each tree is laid out in an array as a binary heap: the root sits at index
    1, and the node at index :math:`i` has its left child at index :math:`2i`
    and its right child at index :math:`2i + 1`. The element at index 0 does
    not correspond to any node.

    Parameters
    ----------
    leaf_tree
        The values in the leaves of the trees. Entries belonging to unused
        nodes may hold arbitrary values (the array can be "dirty").
    var_tree
        The axis each decision node operates on. This array can be dirty too,
        with one exception: the never-used entry at index 0 must be 0.
    split_tree
        The decision boundaries of the trees. A point belongs to the left
        child iff ``x < split``, i.e., the boundaries are open on the right. A
        split equal to 0 marks the node as a leaf; unused nodes must have
        split 0 as well, so this array can not be dirty.

    Notes
    -----
    The bottom level of a tree can hold only leaves, never decision nodes, so
    `var_tree` and `split_tree` have half the length of `leaf_tree`.
    """

    leaf_tree: Float32[Array, '* 2**d']
    var_tree: UInt[Array, '* 2**(d-1)']
    split_tree: UInt[Array, '* 2**(d-1)']
def make_tree(depth: int, dtype: DTypeLike) -> Shaped[Array, ' 2**{depth}']:
    """
    Allocate the heap array that backs a binary tree.

    Parameters
    ----------
    depth
        The maximum depth of the tree. A tree of depth 1 consists of the root
        node alone.
    dtype
        The dtype of the array.

    Returns
    -------
    A zero-filled 1D array of length ``2 ** depth``.
    """
    heap_length = 2**depth
    return jnp.zeros(heap_length, dtype)
def tree_depth(tree: Shaped[Array, '* 2**d']) -> int:
    """
    Compute the maximum depth of a tree from the length of its heap array.

    Parameters
    ----------
    tree
        A tree created by `make_tree`. For a multidimensional array, the tree
        structure is taken to lie along the last axis.

    Returns
    -------
    The maximum depth of the tree, i.e., the base-2 logarithm of the length
    of the last axis, rounded to the nearest integer.
    """
    heap_length = tree.shape[-1]
    return round(math.log2(heap_length))
def traverse_tree(
    x: Real[Array, ' p'],
    var_tree: UInt[Array, ' 2**(d-1)'],
    split_tree: UInt[Array, ' 2**(d-1)'],
) -> Int32[Array, '']:
    """
    Locate the leaf that contains a point.

    Parameters
    ----------
    x
        The coordinates to evaluate the tree at.
    var_tree
        The decision axes of the tree.
    split_tree
        The decision boundaries of the tree.

    Returns
    -------
    The index of the leaf.
    """
    # The node index may end up in the bottom level of the tree, which lies
    # beyond the end of `var_tree`/`split_tree`, hence the factor 2.
    # NOTE(review): `minimal_unsigned_dtype` is presumably the smallest
    # unsigned integer dtype able to hold its argument -- confirm in jaxext.
    index_dtype = minimal_unsigned_dtype(2 * var_tree.size - 1)
    initial_state = (
        jnp.zeros((), bool),  # no leaf reached yet
        jnp.ones((), index_dtype),  # start the descent from the root
    )

    def descend(state, _):
        done, node = state

        threshold = split_tree[node]
        axis = var_tree[node]

        # a zero split marks a leaf; once one is reached the index is frozen
        done |= threshold == 0
        go_right = x[axis] >= threshold
        next_node = (node << 1) + go_right
        node = jnp.where(done, node, next_node)

        return (done, node), None

    # one step per level that can contain decision nodes
    num_steps = tree_depth(var_tree)
    (_, leaf_index), _ = lax.scan(descend, initial_state, None, num_steps, unroll=16)
    return leaf_index
# NOTE(review): assuming `vmap_nodoc` maps like `jax.vmap`, the inner
# decorator iterates over the points (axis 1 of `X`) and the outer one over
# the trees (axis 0 of `var_trees` and `split_trees`) -- confirm in jaxext.
@partial(vmap_nodoc, in_axes=(None, 0, 0))
@partial(vmap_nodoc, in_axes=(1, None, None))
def traverse_forest(
    X: Real[Array, 'p n'],
    var_trees: UInt[Array, 'm 2**(d-1)'],
    split_trees: UInt[Array, 'm 2**(d-1)'],
) -> Int32[Array, 'm n']:
    """
    Locate, for each tree, the leaf that contains each point.

    Parameters
    ----------
    X
        The coordinates to evaluate the trees at.
    var_trees
        The decision axes of the trees.
    split_trees
        The decision boundaries of the trees.

    Returns
    -------
    The indices of the leaves.
    """
    leaf_indices = traverse_tree(X, var_trees, split_trees)
    return leaf_indices
def evaluate_forest(
    X: UInt[Array, 'p n'], trees: TreeHeaps, *, sum_trees: bool = True
) -> Float32[Array, ' n'] | Float32[Array, 'm n']:
    """
    Evaluate an ensemble of trees at an array of points.

    Parameters
    ----------
    X
        The coordinates to evaluate the trees at.
    trees
        The tree heaps, with batch shape (m,).
    sum_trees
        Whether to sum the values across trees.

    Returns
    -------
    The values of the trees at the points in `X`, summed over the trees if
    `sum_trees` is True, otherwise kept separate per tree.
    """
    leaf_indices = traverse_forest(X, trees.var_tree, trees.split_tree)
    num_trees, _ = trees.leaf_tree.shape
    tree_ids = jnp.arange(num_trees, dtype=minimal_unsigned_dtype(num_trees - 1))
    # pair each row of leaf indices with its own tree
    values = trees.leaf_tree[tree_ids[:, None], leaf_indices]
    if not sum_trees:
        return values
    # Original author's note: this sum over the tree axis suggests swapping
    # the vmaps in `traverse_forest`, but the current order is kept because it
    # is thought to be better for the copying of X.
    return jnp.sum(values, axis=0, dtype=jnp.float32)
def is_actual_leaf(
    split_tree: UInt[Array, ' 2**(d-1)'], *, add_bottom_level: bool = False
) -> Bool[Array, ' 2**(d-1)'] | Bool[Array, ' 2**d']:
    """
    Compute a mask that flags the leaf nodes of a tree.

    Parameters
    ----------
    split_tree
        The splitting points of the tree.
    add_bottom_level
        If True, the bottom level of the tree is also considered.

    Returns
    -------
    The mask marking the leaf nodes. Length doubled if `add_bottom_level` is True.
    """
    zero_split = split_tree == 0
    if add_bottom_level:
        # the bottom level has no entries in `split_tree`: append an all-True
        # block for it, since those nodes have no split of their own
        zero_split = jnp.concatenate([zero_split, jnp.ones_like(zero_split)])
        length = 2 * split_tree.size
    else:
        length = split_tree.size
    node = jnp.arange(length, dtype=minimal_unsigned_dtype(length - 1))
    parent = node >> 1
    # a node with zero split is a real leaf only if its parent does split
    parent_splits = split_tree[parent].astype(bool)
    # the lookup for the root (index 1) lands on the unused index 0: force it
    parent_splits = parent_splits.at[1].set(True)
    return zero_split & parent_splits
def is_leaves_parent(split_tree: UInt[Array, ' 2**(d-1)']) -> Bool[Array, ' 2**(d-1)']:
    """
    Compute a mask that flags the nodes whose children are both leaves.

    Parameters
    ----------
    split_tree
        The decision boundaries of the tree.

    Returns
    -------
    The mask indicating which nodes have leaf (and only leaf) children.
    """
    length = split_tree.size
    # child indices go up to 2 * length - 1, so the dtype must hold that
    node = jnp.arange(length, dtype=minimal_unsigned_dtype(2 * length - 1))
    left = node << 1
    right = left + 1

    def child_is_leaf(child):
        # children beyond the end of `split_tree` read as split 0, i.e., leaf
        return split_tree.at[child].get(mode='fill', fill_value=0) == 0

    has_split = split_tree.astype(bool)
    # the 0-th item has split == 0, so it's not counted
    return has_split & child_is_leaf(left) & child_is_leaf(right)
def tree_depths(tree_length: int) -> Int32[Array, ' {tree_length}']:
    """
    Compute the depth of every node in a binary tree stored as a heap.

    Parameters
    ----------
    tree_length
        The length of the tree array, i.e., 2 ** d.

    Returns
    -------
    The depth of each node.

    Notes
    -----
    The root node (index 1) has depth 0. The depth of a node is the position
    of the most significant non-zero bit of its index. The first element (the
    unused node) is marked as depth 0.
    """
    # for i >= 1 the position of the highest set bit is bit_length() - 1
    depths = [i.bit_length() - 1 for i in range(tree_length)]
    depths[0] = 0  # the unused node would otherwise get depth -1
    return jnp.array(depths, minimal_unsigned_dtype(max(depths)))
def is_used(split_tree: UInt[Array, ' 2**(d-1)']) -> Bool[Array, ' 2**d']:
    """
    Compute a mask that flags the nodes of a tree that are in use.

    Parameters
    ----------
    split_tree
        The decision boundaries of the tree.

    Returns
    -------
    A mask indicating which nodes are actually used, i.e., are either
    decision nodes or leaves.
    """
    leaf_mask = is_actual_leaf(split_tree, add_bottom_level=True)
    decision_mask = split_tree.astype(bool)
    # extend to the full tree length; the bottom level has no decision nodes
    padding = jnp.zeros_like(decision_mask)
    decision_mask = jnp.concatenate([decision_mask, padding])
    return decision_mask | leaf_mask
@jit
def forest_fill(split_tree: UInt[Array, 'num_trees 2**(d-1)']) -> Float32[Array, '']:
    """
    Compute the fraction of node slots in use across a set of trees.

    Parameters
    ----------
    split_tree
        The decision boundaries of the trees.

    Returns
    -------
    Number of tree nodes over the maximum number that could be stored.
    """
    num_trees, _ = split_tree.shape
    used_mask = jax.vmap(is_used)(split_tree)
    num_used = jnp.count_nonzero(used_mask)
    # one slot per tree (index 0) can never hold a node: leave it out
    capacity = used_mask.size - num_trees
    return num_used / capacity
def var_histogram(
    p: int, var_tree: UInt[Array, '* 2**(d-1)'], split_tree: UInt[Array, '* 2**(d-1)']
) -> Int32[Array, ' {p}']:
    """
    Count the number of decision nodes that use each variable.

    Parameters
    ----------
    p
        The number of variables (the maximum value that can occur in
        `var_tree` is ``p - 1``).
    var_tree
        The decision axes of the tree.
    split_tree
        The decision boundaries of the tree.

    Returns
    -------
    The histogram of the variables used in the tree.

    Notes
    -----
    If there are leading axes in the tree arrays (i.e., multiple trees), the
    returned counts are cumulative over trees.
    """
    # only nodes with a nonzero split are decision nodes; the others add 0
    weights = split_tree.astype(bool)
    counts = jnp.zeros(p, int)
    return counts.at[var_tree].add(weights)