Coverage for src/bartz/grove.py: 100%

83 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-27 14:46 +0000

1# bartz/src/bartz/grove.py 

2# 

3# Copyright (c) 2024-2025, Giacomo Petrillo 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25"""Functions to create and manipulate binary decision trees.""" 

26 

27import math 1ab

28from functools import partial 1ab

29from typing import Protocol 1ab

30 

31import jax 1ab

32from jax import jit, lax 1ab

33from jax import numpy as jnp 1ab

34from jaxtyping import Array, Bool, DTypeLike, Float32, Int32, Real, Shaped, UInt 1ab

35 

36from bartz.jaxext import minimal_unsigned_dtype, vmap_nodoc 1ab

37 

38 

class TreeHeaps(Protocol):
    """A protocol for dataclasses that represent trees.

    A tree is represented with arrays as a heap. The root node is at index 1.
    The children nodes of a node at index :math:`i` are at indices :math:`2i`
    (left child) and :math:`2i + 1` (right child). The array element at index 0
    is unused.

    Parameters
    ----------
    leaf_tree
        The values in the leaves of the trees. This array can be dirty, i.e.,
        unused nodes can have whatever value.
    var_tree
        The axes along which the decision nodes operate. This array can be
        dirty but for the always unused node at index 0 which must be set to 0.
    split_tree
        The decision boundaries of the trees. The boundaries are open on the
        right, i.e., a point belongs to the left child iff x < split. Whether a
        node is a leaf is indicated by the corresponding 'split' element being
        0. Unused nodes also have split set to 0. This array can't be dirty.

    Notes
    -----
    Since the nodes at the bottom can only be leaves and not decision nodes,
    `var_tree` and `split_tree` are half as long as `leaf_tree`.
    """

    # Leaf values; the last axis is the heap, leading axes are batch axes.
    leaf_tree: Float32[Array, '* 2**d']
    # Decision axis of each node; half as long as `leaf_tree`.
    var_tree: UInt[Array, '* 2**(d-1)']
    # Decision boundary of each node; 0 marks a leaf or an unused node.
    split_tree: UInt[Array, '* 2**(d-1)']

70 

71 

def make_tree(depth: int, dtype: DTypeLike) -> Shaped[Array, ' 2**{depth}']:
    """
    Make an array to represent a binary tree.

    Parameters
    ----------
    depth
        The maximum depth of the tree. Depth 1 means that there is only a root
        node.
    dtype
        The dtype of the array.

    Returns
    -------
    An array of zeroes with the appropriate shape.
    """
    # Heap layout: 2**depth slots, so depth 1 gives [unused slot 0, root].
    return jnp.zeros(2**depth, dtype)

89 

90 

def tree_depth(tree: Shaped[Array, '* 2**d']) -> int:
    """
    Return the maximum depth of a tree.

    Parameters
    ----------
    tree
        A tree created by `make_tree`. If the array is ND, the tree structure is
        assumed to be along the last axis.

    Returns
    -------
    The maximum depth of the tree.

    Notes
    -----
    The depth is the base-2 logarithm of the length of the last axis, rounded
    to the nearest integer. The length is expected to be a power of two, as
    produced by `make_tree`; other lengths are silently rounded.
    """
    # Only the static shape is read, so this returns a plain Python int.
    return round(math.log2(tree.shape[-1]))

106 

107 

def traverse_tree(
    x: Real[Array, ' p'],
    var_tree: UInt[Array, ' 2**(d-1)'],
    split_tree: UInt[Array, ' 2**(d-1)'],
) -> Int32[Array, '']:
    """
    Find the leaf where a point falls into.

    Parameters
    ----------
    x
        The coordinates to evaluate the tree at.
    var_tree
        The decision axes of the tree.
    split_tree
        The decision boundaries of the tree.

    Returns
    -------
    The index of the leaf.

    Notes
    -----
    The tree is descended from the root (index 1) for a fixed number of steps
    equal to ``tree_depth(var_tree)``. Once a node with ``split == 0`` is
    reached, the index stops changing for the remaining steps.
    """
    # NOTE(review): the returned index carries the dtype picked by
    # `minimal_unsigned_dtype`, which presumably is unsigned rather than the
    # int32 in the return annotation -- confirm against `bartz.jaxext`.
    carry = (
        # whether a leaf has already been reached
        jnp.zeros((), bool),
        # current node, starting from the root; the dtype is sized for the
        # largest reachable index, 2 * var_tree.size - 1
        jnp.ones((), minimal_unsigned_dtype(2 * var_tree.size - 1)),
    )

    def loop(carry, _):
        leaf_found, index = carry

        split = split_tree[index]
        var = var_tree[index]

        # split == 0 marks a leaf; the flag is sticky across iterations
        leaf_found |= split == 0
        # left child is 2i (x < split), right child is 2i + 1 (x >= split)
        child_index = (index << 1) + (x[var] >= split)
        # freeze the index once a leaf has been found
        index = jnp.where(leaf_found, index, child_index)

        return (leaf_found, index), None

    # a fixed iteration count keeps the loop expressible as a `lax.scan`
    depth = tree_depth(var_tree)
    (_, index), _ = lax.scan(loop, carry, None, depth, unroll=16)
    return index

149 

150 

# Outer map: over the leading (tree) axis of `var_trees` and `split_trees`.
@partial(vmap_nodoc, in_axes=(None, 0, 0))
# Inner map: over axis 1 of `X`, i.e., over the points.
@partial(vmap_nodoc, in_axes=(1, None, None))
def traverse_forest(
    X: Real[Array, 'p n'],
    var_trees: UInt[Array, 'm 2**(d-1)'],
    split_trees: UInt[Array, 'm 2**(d-1)'],
) -> Int32[Array, 'm n']:
    """
    Find the leaves where points fall into.

    Parameters
    ----------
    X
        The coordinates to evaluate the trees at.
    var_trees
        The decision axes of the trees.
    split_trees
        The decision boundaries of the trees.

    Returns
    -------
    The indices of the leaves.
    """
    # Inside the two maps this sees a single point and a single tree.
    return traverse_tree(X, var_trees, split_trees)

175 

176 

def evaluate_forest(
    X: UInt[Array, 'p n'], trees: TreeHeaps, *, sum_trees: bool = True
) -> Float32[Array, ' n'] | Float32[Array, 'm n']:
    """
    Evaluate an ensemble of trees at an array of points.

    Parameters
    ----------
    X
        The coordinates to evaluate the trees at.
    trees
        The tree heaps, with batch shape (m,).
    sum_trees
        Whether to sum the values across trees.

    Returns
    -------
    The (sum of) the values of the trees at the points in `X`.
    """
    indices = traverse_forest(X, trees.var_tree, trees.split_tree)
    ntree, _ = trees.leaf_tree.shape
    tree_index = jnp.arange(ntree, dtype=minimal_unsigned_dtype(ntree - 1))
    # pair each tree with its own row of leaf indices: result has shape (m, n)
    leaves = trees.leaf_tree[tree_index[:, None], indices]
    if sum_trees:
        # this sum suggests to swap the vmaps, but I think it's better for X
        # copying to keep it that way
        return jnp.sum(leaves, axis=0, dtype=jnp.float32)
    return leaves

206 

207 

def is_actual_leaf(
    split_tree: UInt[Array, ' 2**(d-1)'], *, add_bottom_level: bool = False
) -> Bool[Array, ' 2**(d-1)'] | Bool[Array, ' 2**d']:
    """
    Return a mask indicating the leaf nodes in a tree.

    Parameters
    ----------
    split_tree
        The splitting points of the tree.
    add_bottom_level
        If True, the bottom level of the tree is also considered.

    Returns
    -------
    The mask marking the leaf nodes. Length doubled if `add_bottom_level` is True.
    """
    size = split_tree.size
    is_leaf = split_tree == 0
    if add_bottom_level:
        # nodes in the bottom level have no split entry: they can only be leaves
        size *= 2
        is_leaf = jnp.concatenate([is_leaf, jnp.ones_like(is_leaf)])
    index = jnp.arange(size, dtype=minimal_unsigned_dtype(size - 1))
    # in the heap layout the parent of node i is i // 2
    parent_index = index >> 1
    parent_nonleaf = split_tree[parent_index].astype(bool)
    # the root has no real parent (its parent index is the unused slot 0),
    # so it is always treated as reachable
    parent_nonleaf = parent_nonleaf.at[1].set(True)
    # a node is an actual leaf if it does not split and its parent does
    return is_leaf & parent_nonleaf

235 

236 

def is_leaves_parent(split_tree: UInt[Array, ' 2**(d-1)']) -> Bool[Array, ' 2**(d-1)']:
    """
    Return a mask indicating the nodes with leaf (and only leaf) children.

    Parameters
    ----------
    split_tree
        The decision boundaries of the tree.

    Returns
    -------
    The mask indicating which nodes have leaf children.
    """
    index = jnp.arange(
        split_tree.size, dtype=minimal_unsigned_dtype(2 * split_tree.size - 1)
    )
    left_index = index << 1  # left child
    right_index = left_index + 1  # right child
    # children that fall beyond the end of `split_tree` read as split == 0,
    # i.e., they count as leaves
    left_leaf = split_tree.at[left_index].get(mode='fill', fill_value=0) == 0
    right_leaf = split_tree.at[right_index].get(mode='fill', fill_value=0) == 0
    is_not_leaf = split_tree.astype(bool)
    # the 0-th item has split == 0, so it's not counted
    return is_not_leaf & left_leaf & right_leaf

260 

261 

def tree_depths(tree_length: int) -> Int32[Array, ' {tree_length}']:
    """
    Return the depth of each node in a binary tree.

    Parameters
    ----------
    tree_length
        The length of the tree array, i.e., 2 ** d.

    Returns
    -------
    The depth of each node.

    Notes
    -----
    The root node (index 1) has depth 0. The depth is the position of the most
    significant non-zero bit in the index. The first element (the unused node)
    is marked as depth 0.
    """
    # `i.bit_length() - 1` is the position of the most significant set bit of
    # i; it evaluates to -1 for i == 0, which is clamped to 0 for the unused
    # node.
    depths = [max(i.bit_length() - 1, 0) for i in range(tree_length)]
    # `default=0` lets an empty tree produce an empty array instead of raising
    return jnp.array(depths, minimal_unsigned_dtype(max(depths, default=0)))

289 

290 

def is_used(split_tree: UInt[Array, ' 2**(d-1)']) -> Bool[Array, ' 2**d']:
    """
    Return a mask indicating the used nodes in a tree.

    Parameters
    ----------
    split_tree
        The decision boundaries of the tree.

    Returns
    -------
    A mask indicating which nodes are actually used.
    """
    # decision nodes are those with a nonzero split; pad with False for the
    # bottom level, which has no split entries
    internal_node = split_tree.astype(bool)
    internal_node = jnp.concatenate([internal_node, jnp.zeros_like(internal_node)])
    actual_leaf = is_actual_leaf(split_tree, add_bottom_level=True)
    # a used node is either a decision node or an actual leaf
    return internal_node | actual_leaf

308 

309 

@jit
def forest_fill(split_tree: UInt[Array, 'num_trees 2**(d-1)']) -> Float32[Array, '']:
    """
    Return the fraction of used nodes in a set of trees.

    Parameters
    ----------
    split_tree
        The decision boundaries of the trees.

    Returns
    -------
    Number of tree nodes over the maximum number that could be stored.
    """
    num_trees, _ = split_tree.shape
    # one mask per tree, each twice as long as the corresponding split tree
    used = jax.vmap(is_used)(split_tree)
    count = jnp.count_nonzero(used)
    # one slot per tree is removed from the denominator, presumably the
    # always-unused element at index 0 of each heap
    return count / (used.size - num_trees)

328 

329 

def var_histogram(
    p: int, var_tree: UInt[Array, '* 2**(d-1)'], split_tree: UInt[Array, '* 2**(d-1)']
) -> Int32[Array, ' {p}']:
    """
    Count how many times each variable appears in a tree.

    Parameters
    ----------
    p
        The number of variables (the maximum value that can occur in
        `var_tree` is ``p - 1``).
    var_tree
        The decision axes of the tree.
    split_tree
        The decision boundaries of the tree.

    Returns
    -------
    The histogram of the variables used in the tree.

    Notes
    -----
    If there are leading axes in the tree arrays (i.e., multiple trees), the
    returned counts are cumulative over trees.
    """
    # only nodes with a nonzero split contribute: every other node adds 0
    is_internal = split_tree.astype(bool)
    # scatter-add, so repeated variables accumulate their counts
    return jnp.zeros(p, int).at[var_tree].add(is_internal)

356 return jnp.zeros(p, int).at[var_tree].add(is_internal) 1ab