Coverage for src/bartz/grove.py: 100%

63 statements  

« prev     ^ index     » next       coverage.py v7.6.4, created at 2024-11-05 18:54 +0000

1# bartz/src/bartz/grove.py 

2# 

3# Copyright (c) 2024, Giacomo Petrillo 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25""" 

26 

27Functions to create and manipulate binary trees. 

28 

29A tree is represented with arrays as a heap. The root node is at index 1. The children nodes of a node at index :math:`i` are at indices :math:`2i` (left child) and :math:`2i + 1` (right child). The array element at index 0 is unused. 

30 

31A decision tree is represented by tree arrays: 'leaf', 'var', and 'split'. 

32 

33The 'leaf' array contains the values in the leaves. 

34 

35The 'var' array contains the axes along which the decision nodes operate. 

36 

37The 'split' array contains the decision boundaries. The boundaries are open on the right, i.e., a point belongs to the left child iff x < split. Whether a node is a leaf is indicated by the corresponding 'split' element being 0. 

38 

39Since the nodes at the bottom can only be leaves and not decision nodes, the 'var' and 'split' arrays have half the length of the 'leaf' array. 

40 

41""" 

42 

43import functools 1a

44import math 1a

45 

46import jax 1a

47from jax import numpy as jnp 1a

48from jax import lax 1a

49 

50from . import jaxext 1a

51 

def make_tree(depth, dtype):
    """
    Make an array to represent a binary tree.

    Parameters
    ----------
    depth : int
        The maximum depth of the tree. Depth 1 means that there is only a root
        node.
    dtype : dtype
        The dtype of the array.

    Returns
    -------
    tree : array
        A zero-filled array of shape (2 ** depth,). In the heap layout used by
        this module, index 0 is unused and the root sits at index 1.
    """
    length = 2 ** depth
    return jnp.zeros((length,), dtype=dtype)

70 

def tree_depth(tree):
    """
    Return the maximum depth of a tree.

    Parameters
    ----------
    tree : array
        A tree created by `make_tree`. If the array is ND, the tree structure
        is taken to lie along the last axis.

    Returns
    -------
    depth : int
        The maximum depth of the tree, i.e., log2 of the length of the last
        axis, rounded to the nearest integer.
    """
    length = tree.shape[-1]
    # round() on a float with no ndigits already returns a Python int
    return round(math.log2(length))

87 

def traverse_tree(x, var_tree, split_tree):
    """
    Find the leaf where a point falls into.

    Starting from the root (index 1), the walk descends one level per step.
    It goes to the left child 2i when x[var] < split and to the right child
    2i + 1 otherwise. It stops moving once it reaches a node whose split is 0.

    Parameters
    ----------
    x : array (p,)
        The coordinates to evaluate the tree at.
    var_tree : array (2 ** (d - 1),)
        The decision axes of the tree.
    split_tree : array (2 ** (d - 1),)
        The decision boundaries of the tree.

    Returns
    -------
    index : int
        The index of the leaf, as a 0-d unsigned integer array.
    """
    # the largest reachable index is 2 * var_tree.size - 1 (bottom level)
    index_dtype = jaxext.minimal_unsigned_dtype(2 * var_tree.size - 1)
    init = (jnp.zeros((), bool), jnp.ones((), index_dtype))

    def step(state, _):
        stopped, node = state
        split = split_tree[node]
        stopped = stopped | (split == 0)
        go_right = x[var_tree[node]] >= split
        # once stopped, the node index is frozen for the remaining steps
        node = jnp.where(stopped, node, (node << 1) + go_right)
        return (stopped, node), None

    # one step per internal level; tree_depth(var_tree) == d - 1
    (_, leaf_index), _ = lax.scan(step, init, None, tree_depth(var_tree), unroll=16)
    return leaf_index

127 

@functools.partial(jaxext.vmap_nodoc, in_axes=(None, 0, 0))
@functools.partial(jaxext.vmap_nodoc, in_axes=(1, None, None))
def traverse_forest(X, var_trees, split_trees):
    """
    Find the leaves where points fall into.

    The inner decorator maps over axis 1 of `X` (the points). The outer
    decorator maps over axis 0 of `var_trees` and `split_trees` (the trees).
    This assumes `jaxext.vmap_nodoc` maps like `jax.vmap`.

    Parameters
    ----------
    X : array (p, n)
        The coordinates to evaluate the trees at.
    var_trees : array (m, 2 ** (d - 1))
        The decision axes of the trees.
    split_trees : array (m, 2 ** (d - 1))
        The decision boundaries of the trees.

    Returns
    -------
    indices : array (m, n)
        The indices of the leaves.
    """
    # inside both maps: X is a single point, var_trees / split_trees one tree
    return traverse_tree(x=X, var_tree=var_trees, split_tree=split_trees)

149 

def evaluate_forest(X, leaf_trees, var_trees, split_trees, dtype=None, sum_trees=True):
    """
    Evaluate an ensemble of trees at an array of points.

    Parameters
    ----------
    X : array (p, n)
        The coordinates to evaluate the trees at.
    leaf_trees : array (m, 2 ** d)
        The leaf values of the trees. The first axis is the tree index.
    var_trees : array (m, 2 ** (d - 1))
        The decision axes of the trees.
    split_trees : array (m, 2 ** (d - 1))
        The decision boundaries of the trees.
    dtype : dtype, optional
        The dtype of the output. Ignored if `sum_trees` is `False`.
    sum_trees : bool, default True
        Whether to sum the values across trees.

    Returns
    -------
    out : array (n,) or (m, n)
        The values of the trees at the points in `X`. If `sum_trees` is True,
        they are summed across trees (shape (n,)). Otherwise there is one row
        per tree (shape (m, n)).
    """
    indices = traverse_forest(X, var_trees, split_trees)  # (m, n) leaf indices
    ntree, _ = leaf_trees.shape
    # smallest unsigned dtype able to hold the tree indices 0 .. ntree - 1
    tree_index = jnp.arange(ntree, dtype=jaxext.minimal_unsigned_dtype(ntree - 1))
    # the (m, 1) and (m, n) index arrays broadcast together, so
    # leaves[i, j] = leaf_trees[i, indices[i, j]]
    leaves = leaf_trees[tree_index[:, None], indices]
    if not sum_trees:
        return leaves
    # this sum suggests to swap the vmaps, but I think it's better for X
    # copying to keep it that way
    return jnp.sum(leaves, axis=0, dtype=dtype)

185 

def is_actual_leaf(split_tree, *, add_bottom_level=False):
    """
    Return a mask indicating the leaf nodes in a tree.

    A node is flagged when its own split is 0 and the split of its parent
    (index i >> 1) is nonzero. The root (index 1) is always treated as having
    a nonzero-split parent. Index 0 is never flagged.

    Parameters
    ----------
    split_tree : int array (2 ** (d - 1),)
        The splitting points of the tree.
    add_bottom_level : bool, default False
        If True, the bottom level of the tree is also considered. Its nodes
        have no split entry and are treated as leaves.

    Returns
    -------
    is_actual_leaf : bool array (2 ** (d - 1) or 2 ** d,)
        The mask indicating the leaf nodes. The length is doubled if
        `add_bottom_level` is True.
    """
    leaf_mask = split_tree == 0
    length = split_tree.size
    if add_bottom_level:
        leaf_mask = jnp.concatenate([leaf_mask, jnp.ones_like(leaf_mask)])
        length = 2 * length
    node = jnp.arange(length, dtype=jaxext.minimal_unsigned_dtype(length - 1))
    parent_is_decision = split_tree[node >> 1].astype(bool)
    # the root's "parent" is the unused slot 0, so force the root to True
    parent_is_decision = parent_is_decision.at[1].set(True)
    return leaf_mask & parent_is_decision

213 

def is_leaves_parent(split_tree):
    """
    Return a mask indicating the nodes with leaf (and only leaf) children.

    A node is flagged when its own split is nonzero and the splits of both
    children (2i and 2i + 1) are 0. Children past the end of `split_tree`
    are read as split 0, i.e., as leaves.

    Parameters
    ----------
    split_tree : int array (2 ** (d - 1),)
        The decision boundaries of the tree.

    Returns
    -------
    is_leaves_parent : bool array (2 ** (d - 1),)
        The mask indicating which nodes have leaf children.
    """
    length = split_tree.size
    # wide enough for the largest child index, 2 * length - 1
    node = jnp.arange(length, dtype=jaxext.minimal_unsigned_dtype(2 * length - 1))

    def child_is_leaf(child):
        # out-of-range children are filled with 0, so they count as leaves
        return split_tree.at[child].get(mode='fill', fill_value=0) == 0

    left = node << 1
    both_children_leaves = child_is_leaf(left) & child_is_leaf(left + 1)
    # the 0-th item has split == 0, so it's not counted
    return split_tree.astype(bool) & both_children_leaves

236 

def tree_depths(tree_length):
    """
    Return the depth of each node in a binary tree.

    Parameters
    ----------
    tree_length : int
        The length of the tree array, i.e., 2 ** d.

    Returns
    -------
    depth : array (tree_length,)
        The depth of each node, i.e., the position of the most significant
        nonzero bit of its index. The root node (index 1) has depth 0. The
        unused element at index 0 is also marked as depth 0.
    """
    # i.bit_length() - 1 is the position of the highest set bit of i;
    # for i = 0 it gives -1, which is then overwritten below
    depths = [i.bit_length() - 1 for i in range(tree_length)]
    depths[0] = 0
    return jnp.array(depths, jaxext.minimal_unsigned_dtype(max(depths)))