Coverage for src/bartz/grove.py: 100%

73 statements  

« prev     ^ index     » next       coverage.py v7.8.2, created at 2025-05-29 23:01 +0000

1# bartz/src/bartz/grove.py 

2# 

3# Copyright (c) 2024-2025, Giacomo Petrillo 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25""" 

26 

27Functions to create and manipulate binary trees. 

28 

29A tree is represented with arrays as a heap. The root node is at index 1. The children nodes of a node at index :math:`i` are at indices :math:`2i` (left child) and :math:`2i + 1` (right child). The array element at index 0 is unused. 

30 

31A decision tree is represented by tree arrays: 'leaf', 'var', and 'split'. 

32 

33The 'leaf' array contains the values in the leaves. 

34 

35The 'var' array contains the axes along which the decision nodes operate. 

36 

37The 'split' array contains the decision boundaries. The boundaries are open on the right, i.e., a point belongs to the left child iff x < split. Whether a node is a leaf is indicated by the corresponding 'split' element being 0. Unused nodes also have split set to 0. 

38 

39Since the nodes at the bottom can only be leaves and not decision nodes, the 'var' and 'split' arrays have half the length of the 'leaf' array. 

40 

41""" 

42 

43import functools 1ab

44import math 1ab

45 

46import jax 1ab

47from jax import lax 1ab

48from jax import numpy as jnp 1ab

49 

50from . import jaxext 1ab

51 

52 

def make_tree(depth, dtype):
    """
    Allocate the array backing a heap-layout binary tree.

    Parameters
    ----------
    depth : int
        The maximum depth of the tree. Depth 1 means that there is only a
        root node.
    dtype : dtype
        The dtype of the array.

    Returns
    -------
    tree : array
        An array of zeros with shape (2 ** depth,).
    """
    # One slot per heap index 0 .. 2**depth - 1.
    length = 2**depth
    return jnp.zeros(length, dtype)

71 

72 

def tree_depth(tree):
    """
    Compute the maximum depth a tree array can hold.

    Parameters
    ----------
    tree : array
        A tree created by `make_tree`. If the array is ND, the tree structure
        is assumed to be along the last axis.

    Returns
    -------
    depth : int
        The maximum depth of the tree, i.e., the base-2 logarithm of the
        length of the last axis, rounded to the nearest integer.
    """
    length = tree.shape[-1]
    # round() on a float with no ndigits already returns a Python int.
    return round(math.log2(length))

89 

90 

def traverse_tree(x, var_tree, split_tree):
    """
    Locate the leaf of a decision tree that contains a point.

    Parameters
    ----------
    x : array (p,)
        The coordinates to evaluate the tree at.
    var_tree : array (2 ** (d - 1),)
        The decision axes of the tree.
    split_tree : array (2 ** (d - 1),)
        The decision boundaries of the tree.

    Returns
    -------
    index : int
        The index of the leaf.
    """
    # The leaf index can reach the bottom level, whose largest heap index is
    # 2 * var_tree.size - 1, so the index dtype is sized for that value.
    index_dtype = jaxext.minimal_unsigned_dtype(2 * var_tree.size - 1)
    start_at_root = jnp.ones((), index_dtype)
    not_yet_found = jnp.zeros((), bool)

    def descend(state, _):
        done, node = state

        threshold = split_tree[node]
        axis = var_tree[node]

        # A zero split marks a leaf; once reached, the index is frozen.
        done = done | (threshold == 0)
        go_right = x[axis] >= threshold
        child = (node << 1) + go_right
        node = jnp.where(done, node, child)

        return (done, node), None

    num_steps = tree_depth(var_tree)
    (_, leaf_index), _ = lax.scan(
        descend,
        (not_yet_found, start_at_root),
        xs=None,
        length=num_steps,
        unroll=16,
    )
    return leaf_index

129 

130 

# NOTE(review): `jaxext.vmap_nodoc` is presumably `jax.vmap` without docstring
# rewriting -- confirm in jaxext. The outer map runs over trees (axis 0 of
# the tree arrays), the inner one over points (axis 1 of X).
@functools.partial(jaxext.vmap_nodoc, in_axes=(None, 0, 0))
@functools.partial(jaxext.vmap_nodoc, in_axes=(1, None, None))
def traverse_forest(X, var_trees, split_trees):
    """
    Locate, for every tree and every point, the leaf containing the point.

    Parameters
    ----------
    X : array (p, n)
        The coordinates to evaluate the trees at.
    var_trees : array (m, 2 ** (d - 1))
        The decision axes of the trees.
    split_trees : array (m, 2 ** (d - 1))
        The decision boundaries of the trees.

    Returns
    -------
    indices : array (m, n)
        The indices of the leaves.
    """
    # Inside the two mapping decorators this body handles one point and one
    # tree at a time.
    leaf_index = traverse_tree(X, var_trees, split_trees)
    return leaf_index

152 

153 

def evaluate_forest(X, leaf_trees, var_trees, split_trees, dtype=None, sum_trees=True):
    """
    Evaluate an ensemble of trees at an array of points.

    Parameters
    ----------
    X : array (p, n)
        The coordinates to evaluate the trees at.
    leaf_trees : array (m, 2 ** d)
        The leaf values of the trees. The first axis is the tree index.
    var_trees : array (m, 2 ** (d - 1))
        The decision axes of the trees.
    split_trees : array (m, 2 ** (d - 1))
        The decision boundaries of the trees.
    dtype : dtype, optional
        The dtype of the output. Ignored if `sum_trees` is `False`.
    sum_trees : bool, default True
        Whether to sum the values across trees.

    Returns
    -------
    out : array (n,) or (m, n)
        The values of the trees at the points in `X`: summed over trees if
        `sum_trees` is `True`, otherwise one row per tree.
    """
    leaf_indices = traverse_forest(X, var_trees, split_trees)

    num_trees, _ = leaf_trees.shape
    tree_id_dtype = jaxext.minimal_unsigned_dtype(num_trees - 1)
    tree_ids = jnp.arange(num_trees, dtype=tree_id_dtype)

    # Pair each tree with the leaf indices found for it.
    values = leaf_trees[tree_ids[:, None], leaf_indices]

    if not sum_trees:
        return values

    # This sum over the tree axis suggests swapping the vmaps in
    # `traverse_forest`, but the current order is kept because it is thought
    # to be better with respect to copying X.
    return jnp.sum(values, axis=0, dtype=dtype)

189 

190 

def is_actual_leaf(split_tree, *, add_bottom_level=False):
    """
    Build a mask marking the nodes that are leaves of the tree.

    A node counts as an actual leaf when its own split is 0 and its parent
    is a decision node (non-zero split).

    Parameters
    ----------
    split_tree : int array (2 ** (d - 1),)
        The splitting points of the tree.
    add_bottom_level : bool, default False
        If True, the bottom level of the tree is also considered.

    Returns
    -------
    is_actual_leaf : bool array (2 ** (d - 1) or 2 ** d,)
        The mask indicating the leaf nodes. The length is doubled if
        `add_bottom_level` is True.
    """
    has_no_split = split_tree == 0
    num_nodes = split_tree.size
    if add_bottom_level:
        # Bottom-level nodes have no entry in `split_tree`; they can only be
        # leaves, so they are appended as all-True.
        has_no_split = jnp.concatenate([has_no_split, jnp.ones_like(has_no_split)])
        num_nodes = 2 * num_nodes

    node_ids = jnp.arange(num_nodes, dtype=jaxext.minimal_unsigned_dtype(num_nodes - 1))
    parent_ids = node_ids >> 1

    # The parent of the root (index 1) is the unused slot 0, so the root is
    # explicitly marked as having a decision-node parent.
    parent_is_decision = split_tree[parent_ids].astype(bool).at[1].set(True)

    return has_no_split & parent_is_decision

218 

219 

def is_leaves_parent(split_tree):
    """
    Build a mask marking the decision nodes whose children are both leaves.

    Parameters
    ----------
    split_tree : int array (2 ** (d - 1),)
        The decision boundaries of the tree.

    Returns
    -------
    is_leaves_parent : bool array (2 ** (d - 1),)
        The mask indicating which nodes have leaf (and only leaf) children.
        The element at index 0 has split 0, so it is never counted.
    """
    num_nodes = split_tree.size
    # Child indices go up to 2 * num_nodes - 1, hence the dtype sizing.
    node_ids = jnp.arange(
        num_nodes, dtype=jaxext.minimal_unsigned_dtype(2 * num_nodes - 1)
    )
    left_child = node_ids << 1
    right_child = left_child + 1

    def has_zero_split(child_ids):
        # Children past the end of `split_tree` (the bottom level) read as
        # the fill value 0, i.e., they are treated as leaves.
        return split_tree.at[child_ids].get(mode='fill', fill_value=0) == 0

    is_decision = split_tree.astype(bool)
    return is_decision & has_zero_split(left_child) & has_zero_split(right_child)

244 

245 

def tree_depths(tree_length):
    """
    Compute the depth of every node of a heap-layout binary tree.

    Parameters
    ----------
    tree_length : int
        The length of the tree array, i.e., 2 ** d.

    Returns
    -------
    depth : array (tree_length,)
        The depth of each node. The root node (index 1) has depth 0. The
        depth is the position of the most significant non-zero bit in the
        index. The first element (the unused node) is marked as depth 0.
    """
    # bit_length() - 1 is the position of the highest set bit of the index;
    # it evaluates to -1 for index 0, which is overwritten right below.
    levels = [node.bit_length() - 1 for node in range(tree_length)]
    levels[0] = 0
    return jnp.array(levels, jaxext.minimal_unsigned_dtype(max(levels)))

270 

271 

def is_used(split_tree):
    """
    Build a mask marking the nodes that are part of the tree.

    A node is used if it is either a decision node or an actual leaf.

    Parameters
    ----------
    split_tree : int array (2 ** (d - 1),)
        The decision boundaries of the tree.

    Returns
    -------
    is_used : bool array (2 ** d,)
        A mask indicating which nodes are actually used.
    """
    decision_nodes = split_tree.astype(bool)
    # Pad with False for the bottom level, which has no entry in `split_tree`.
    bottom_level = jnp.zeros_like(decision_nodes)
    decision_nodes = jnp.concatenate([decision_nodes, bottom_level])

    leaf_nodes = is_actual_leaf(split_tree, add_bottom_level=True)
    return decision_nodes | leaf_nodes

290 

291 

def forest_fill(split_trees):
    """
    Compute the fraction of used nodes in a set of trees.

    Parameters
    ----------
    split_trees : array (m, 2 ** (d - 1),)
        The decision boundaries of the trees.

    Returns
    -------
    fill : float
        The number of tree nodes in the forest over the maximum number that
        could be stored in the arrays.
    """
    num_trees, _ = split_trees.shape
    used_mask = jax.vmap(is_used)(split_trees)
    num_used = jnp.count_nonzero(used_mask)
    # One slot per tree is subtracted from the capacity; this matches the
    # unused index-0 element of each tree array.
    capacity = used_mask.size - num_trees
    return num_used / capacity