Coverage for src/bartz/grove.py: 100%
63 statements
« prev ^ index » next coverage.py v7.6.4, created at 2024-11-05 18:54 +0000
« prev ^ index » next coverage.py v7.6.4, created at 2024-11-05 18:54 +0000
1# bartz/src/bartz/grove.py
2#
3# Copyright (c) 2024, Giacomo Petrillo
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25"""
27Functions to create and manipulate binary trees.
29A tree is represented with arrays as a heap. The root node is at index 1. The children nodes of a node at index :math:`i` are at indices :math:`2i` (left child) and :math:`2i + 1` (right child). The array element at index 0 is unused.
31A decision tree is represented by tree arrays: 'leaf', 'var', and 'split'.
33The 'leaf' array contains the values in the leaves.
35The 'var' array contains the axes along which the decision nodes operate.
37The 'split' array contains the decision boundaries. The boundaries are open on the right, i.e., a point belongs to the left child iff x < split. Whether a node is a leaf is indicated by the corresponding 'split' element being 0.
39Since the nodes at the bottom can only be leaves and not decision nodes, the 'var' and 'split' arrays have half the length of the 'leaf' array.
41"""
43import functools 1a
44import math 1a
46import jax 1a
47from jax import numpy as jnp 1a
48from jax import lax 1a
50from . import jaxext 1a
def make_tree(depth, dtype):
    """
    Make an array to represent a binary tree.

    Parameters
    ----------
    depth : int
        The maximum depth of the tree. Depth 1 means that there is only a root
        node.
    dtype : dtype
        The dtype of the array.

    Returns
    -------
    tree : array
        An array of zeroes with shape (2 ** depth,).
    """
    # heap layout: slot 0 is unused and the root sits at index 1, so a tree
    # with `depth` levels needs 2 ** depth slots in total
    return jnp.zeros(2 ** depth, dtype)
def tree_depth(tree):
    """
    Return the maximum depth of a tree.

    Parameters
    ----------
    tree : array
        A tree created by `make_tree`. If the array is ND, the tree structure is
        assumed to be along the last axis.

    Returns
    -------
    depth : int
        The maximum depth of the tree.
    """
    length = tree.shape[-1]
    # `make_tree` allocates 2 ** depth slots, so the depth is the base-2
    # logarithm of the length; `round` turns the float into a Python int
    return round(math.log2(length))
def traverse_tree(x, var_tree, split_tree):
    """
    Find the leaf where a point falls into.

    Parameters
    ----------
    x : array (p,)
        The coordinates to evaluate the tree at.
    var_tree : array (2 ** (d - 1),)
        The decision axes of the tree.
    split_tree : array (2 ** (d - 1),)
        The decision boundaries of the tree.

    Returns
    -------
    index : int
        The index of the leaf.
    """

    # The leaf array is twice as long as `var_tree`, so the largest index that
    # can be returned is 2 * var_tree.size - 1.
    # NOTE(review): `jaxext.minimal_unsigned_dtype` presumably returns the
    # smallest unsigned dtype able to hold that value -- confirm in jaxext.
    index_dtype = jaxext.minimal_unsigned_dtype(2 * var_tree.size - 1)

    carry = (
        jnp.zeros((), bool),  # leaf_found: no leaf reached yet
        jnp.ones((), index_dtype),  # index: start from the root (index 1)
    )

    def loop(carry, _):
        leaf_found, index = carry

        split = split_tree[index]
        var = var_tree[index]

        # a zero split marks a leaf; once set, the flag stays set
        leaf_found |= split == 0
        # left child at 2i if x < split, right child at 2i + 1 otherwise
        child_index = (index << 1) + (x[var] >= split)
        # keep the index frozen after a leaf has been reached
        index = jnp.where(leaf_found, index, child_index)

        return (leaf_found, index), None

    # one step per level of decision nodes
    depth = tree_depth(var_tree)
    (_, index), _ = lax.scan(loop, carry, None, depth, unroll=16)
    return index
# Decorators apply bottom-up: the inner one maps over axis 1 of `X` (the
# points), the outer one over axis 0 of `var_trees` and `split_trees` (the
# trees), giving an output with trees first and points second.
# NOTE(review): `jaxext.vmap_nodoc` is presumably `jax.vmap` without docstring
# rewriting -- confirm in jaxext.
@functools.partial(jaxext.vmap_nodoc, in_axes=(None, 0, 0))
@functools.partial(jaxext.vmap_nodoc, in_axes=(1, None, None))
def traverse_forest(X, var_trees, split_trees):
    """
    Find the leaves where points fall into.

    Parameters
    ----------
    X : array (p, n)
        The coordinates to evaluate the trees at.
    var_trees : array (m, 2 ** (d - 1))
        The decision axes of the trees.
    split_trees : array (m, 2 ** (d - 1))
        The decision boundaries of the trees.

    Returns
    -------
    indices : array (m, n)
        The indices of the leaves.
    """
    # inside the two mappings this body sees a single point and a single tree
    return traverse_tree(X, var_trees, split_trees)
def evaluate_forest(X, leaf_trees, var_trees, split_trees, dtype=None, sum_trees=True):
    """
    Evaluate an ensemble of trees at an array of points.

    Parameters
    ----------
    X : array (p, n)
        The coordinates to evaluate the trees at.
    leaf_trees : array (m, 2 ** d)
        The leaf values of the tree or forest. If the input is a forest, the
        first axis is the tree index, and the values are summed.
    var_trees : array (m, 2 ** (d - 1))
        The decision axes of the trees.
    split_trees : array (m, 2 ** (d - 1))
        The decision boundaries of the trees.
    dtype : dtype, optional
        The dtype of the output. Ignored if `sum_trees` is `False`.
    sum_trees : bool, default True
        Whether to sum the values across trees.

    Returns
    -------
    out : array (n,) or (m, n)
        The (sum of) the values of the trees at the points in `X`. The shape is
        (n,) if `sum_trees` is `True`, (m, n) otherwise.
    """
    # (m, n) leaf index of each point in each tree
    indices = traverse_forest(X, var_trees, split_trees)

    ntree, _ = leaf_trees.shape
    tree_index = jnp.arange(ntree, dtype=jaxext.minimal_unsigned_dtype(ntree - 1))
    # pair each row of leaf indices with its own tree; the (m, 1) tree index
    # broadcasts against the (m, n) leaf indices
    leaves = leaf_trees[tree_index[:, None], indices]

    if not sum_trees:
        return leaves

    # this sum suggests to swap the vmaps, but I think it's better for X
    # copying to keep it that way
    return jnp.sum(leaves, axis=0, dtype=dtype)
def is_actual_leaf(split_tree, *, add_bottom_level=False):
    """
    Return a mask indicating the leaf nodes in a tree.

    A node is an actual leaf if its own split is 0 (or it lies in the bottom
    level) and its parent is a decision node.

    Parameters
    ----------
    split_tree : int array (2 ** (d - 1),)
        The splitting points of the tree.
    add_bottom_level : bool, default False
        If True, the bottom level of the tree is also considered.

    Returns
    -------
    is_actual_leaf : bool array (2 ** (d - 1) or 2 ** d,)
        The mask indicating the leaf nodes. The length is doubled if
        `add_bottom_level` is True.
    """
    size = split_tree.size
    is_leaf = split_tree == 0
    if add_bottom_level:
        # nodes in the bottom level have no split entry and can only be leaves
        size *= 2
        is_leaf = jnp.concatenate([is_leaf, jnp.ones_like(is_leaf)])

    index = jnp.arange(size, dtype=jaxext.minimal_unsigned_dtype(size - 1))
    parent_index = index >> 1
    parent_nonleaf = split_tree[parent_index].astype(bool)
    # the root (index 1) has no parent: its "parent" is the unused slot 0, so
    # force the flag to True to let the root count as a leaf when split == 0
    parent_nonleaf = parent_nonleaf.at[1].set(True)

    return is_leaf & parent_nonleaf
def is_leaves_parent(split_tree):
    """
    Return a mask indicating the nodes with leaf (and only leaf) children.

    Parameters
    ----------
    split_tree : int array (2 ** (d - 1),)
        The decision boundaries of the tree.

    Returns
    -------
    is_leaves_parent : bool array (2 ** (d - 1),)
        The mask indicating which nodes have leaf children.
    """
    # the dtype must be able to hold the largest child index, 2 * size - 1
    index = jnp.arange(split_tree.size, dtype=jaxext.minimal_unsigned_dtype(2 * split_tree.size - 1))
    left_index = index << 1  # left child
    right_index = left_index + 1  # right child

    # children beyond the end of `split_tree` are in the bottom level; the
    # fill value 0 makes those out-of-bounds reads count as leaves
    left_leaf = split_tree.at[left_index].get(mode='fill', fill_value=0) == 0
    right_leaf = split_tree.at[right_index].get(mode='fill', fill_value=0) == 0

    is_not_leaf = split_tree.astype(bool)
    # the 0-th item has split == 0, so it's not counted
    return is_not_leaf & left_leaf & right_leaf
def tree_depths(tree_length):
    """
    Return the depth of each node in a binary tree.

    Parameters
    ----------
    tree_length : int
        The length of the tree array, i.e., 2 ** d.

    Returns
    -------
    depth : array (tree_length,)
        The depth of each node. The root node (index 1) has depth 0. The depth
        is the position of the most significant non-zero bit in the index. The
        first element (the unused node) is marked as depth 0.
    """
    # `i.bit_length() - 1` is the position of the most significant set bit of
    # `i`, i.e., the depth of heap node `i` (1 -> 0, 2..3 -> 1, 4..7 -> 2, ...)
    depths = [i.bit_length() - 1 for i in range(tree_length)]
    # index 0 has no set bit (the formula yields -1); mark the unused slot as 0
    depths[0] = 0
    return jnp.array(depths, jaxext.minimal_unsigned_dtype(max(depths)))