Coverage for src/bartz/mcmcstep.py: 95%

360 statements  

« prev     ^ index     » next       coverage.py v7.5.4, created at 2024-06-28 20:44 +0000

1# bartz/src/bartz/mcmcstep.py 

2# 

3# Copyright (c) 2024, Giacomo Petrillo 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25""" 

26Functions that implement the BART posterior MCMC initialization and update step. 

27 

28Functions that do MCMC steps operate by taking as input a bart state, and 

29outputting a new dictionary with the new state. The input dict/arrays are not 

30modified. 

31 

32In general, integer types are chosen to be the minimal types that contain the 

33range of possible values. 

34""" 

35 

36import functools 1a

37import math 1a

38 

39import jax 1a

40from jax import random 1a

41from jax import numpy as jnp 1a

42from jax import lax 1a

43 

44from . import jaxext 1a

45from . import grove 1a

46 

def init(*,
    X,
    y,
    max_split,
    num_trees,
    p_nonterminal,
    sigma2_alpha,
    sigma2_beta,
    small_float=jnp.float32,
    large_float=jnp.float32,
    min_points_per_leaf=None,
    resid_batch_size='auto',
    count_batch_size='auto',
    save_ratios=False,
    ):
    """
    Make a BART posterior sampling MCMC initial state.

    Parameters
    ----------
    X : int array (p, n)
        The predictors. Note this is transposed compared to the usual
        convention.
    y : float array (n,)
        The response.
    max_split : int array (p,)
        The maximum split index for each variable. All split ranges start at 1.
    num_trees : int
        The number of trees in the forest.
    p_nonterminal : float array (d - 1,)
        The probability of a nonterminal node at each depth. The maximum depth
        of trees is fixed by the length of this array.
    sigma2_alpha : float
        The shape parameter of the inverse gamma prior on the noise variance.
    sigma2_beta : float
        The scale parameter of the inverse gamma prior on the noise variance.
    small_float : dtype, default float32
        The dtype for large arrays used in the algorithm.
    large_float : dtype, default float32
        The dtype for scalars, small arrays, and arrays which require accuracy.
    min_points_per_leaf : int, optional
        The minimum number of data points in a leaf node. If not specified, no
        minimum is enforced.
    resid_batch_size, count_batch_size : int, None, str, default 'auto'
        The batch sizes, along datapoints, for summing the residuals and
        counting the number of datapoints in each leaf. `None` for no batching.
        If 'auto', pick a value based on the device of `y`, or the default
        device.
    save_ratios : bool, default False
        Whether to save the Metropolis-Hastings ratios.

    Returns
    -------
    bart : dict
        A dictionary with array values, representing a BART mcmc state. The
        keys are:

        'leaf_trees' : small_float array (num_trees, 2 ** d)
            The leaf values.
        'var_trees' : int array (num_trees, 2 ** (d - 1))
            The decision axes.
        'split_trees' : int array (num_trees, 2 ** (d - 1))
            The decision boundaries.
        'resid' : large_float array (n,)
            The residuals (data minus forest value). Large float to avoid
            roundoff.
        'sigma2' : large_float
            The noise variance. Initialized to ``sigma2_beta / sigma2_alpha``,
            or to 1 if that is not finite and positive.
        'grow_prop_count', 'prune_prop_count' : int
            The number of grow/prune proposals made during one full MCMC cycle.
        'grow_acc_count', 'prune_acc_count' : int
            The number of grow/prune moves accepted during one full MCMC cycle.
        'p_nonterminal' : large_float array (d,)
            The probability of a nonterminal node at each depth, padded with a
            zero.
        'p_propose_grow' : large_float array (2 ** (d - 1),)
            The unnormalized probability of picking a leaf for a grow proposal.
        'sigma2_alpha' : large_float
            The shape parameter of the inverse gamma prior on the noise variance.
        'sigma2_beta' : large_float
            The scale parameter of the inverse gamma prior on the noise variance.
        'max_split' : int array (p,)
            The maximum split index for each variable.
        'y' : small_float array (n,)
            The response.
        'X' : int array (p, n)
            The predictors.
        'leaf_indices' : int array (num_trees, n)
            The index of the leaf each datapoint falls into, for each tree.
        'min_points_per_leaf' : int or None
            The minimum number of data points in a leaf node.
        'affluence_trees' : bool array (num_trees, 2 ** (d - 1)) or None
            Whether each non-bottom leaf node contains at least twice
            `min_points_per_leaf` datapoints. If `min_points_per_leaf` is not
            specified, this is None.
        'opt' : LeafDict
            A dictionary with config values:

            'small_float' : dtype
                The dtype for large arrays used in the algorithm.
            'large_float' : dtype
                The dtype for scalars, small arrays, and arrays which require
                accuracy.
            'require_min_points' : bool
                Whether the `min_points_per_leaf` parameter is specified.
            'resid_batch_size', 'count_batch_size' : int or None
                The data batch sizes for computing the sufficient statistics.
        'ratios' : dict, optional
            If `save_ratios` is True, this field is present. It has the fields:

            'log_trans_prior' : large_float array (num_trees,)
                The log transition and prior Metropolis-Hastings ratio for the
                proposed move on each tree.
            'log_likelihood' : large_float array (num_trees,)
                The log likelihood ratio.
    """

    p_nonterminal = jnp.asarray(p_nonterminal, large_float)
    # append a zero probability for the last depth level
    p_nonterminal = jnp.pad(p_nonterminal, (0, 1))
    max_depth = p_nonterminal.size

    @functools.partial(jax.vmap, in_axes=None, out_axes=0, axis_size=num_trees)
    def make_forest(max_depth, dtype):
        return grove.make_tree(max_depth, dtype)

    small_float = jnp.dtype(small_float)
    large_float = jnp.dtype(large_float)
    y = jnp.asarray(y, small_float)
    # convert up front: `.shape` and `.dtype` are read below, which would fail
    # on plain sequences
    X = jnp.asarray(X)
    max_split = jnp.asarray(max_split)
    resid_batch_size, count_batch_size = _choose_suffstat_batch_size(
        resid_batch_size, count_batch_size, y, 2 ** max_depth * num_trees)

    # initial noise variance: prior scale / shape, falling back to 1 if that
    # is not a finite positive number
    sigma2 = jnp.array(sigma2_beta / sigma2_alpha, large_float)
    sigma2 = jnp.where(jnp.isfinite(sigma2) & (sigma2 > 0), sigma2, 1)

    bart = dict(
        leaf_trees=make_forest(max_depth, small_float),
        var_trees=make_forest(max_depth - 1, jaxext.minimal_unsigned_dtype(X.shape[0] - 1)),
        split_trees=make_forest(max_depth - 1, max_split.dtype),
        resid=jnp.asarray(y, large_float),
        sigma2=sigma2,
        grow_prop_count=jnp.zeros((), int),
        grow_acc_count=jnp.zeros((), int),
        prune_prop_count=jnp.zeros((), int),
        prune_acc_count=jnp.zeros((), int),
        p_nonterminal=p_nonterminal,
        p_propose_grow=p_nonterminal[grove.tree_depths(2 ** (max_depth - 1))],
        sigma2_alpha=jnp.asarray(sigma2_alpha, large_float),
        sigma2_beta=jnp.asarray(sigma2_beta, large_float),
        max_split=max_split,
        y=y,
        X=X,
        # every datapoint starts in node 1, the root of each tree
        leaf_indices=jnp.ones((num_trees, y.size), jaxext.minimal_unsigned_dtype(2 ** max_depth - 1)),
        min_points_per_leaf=(
            None if min_points_per_leaf is None else
            jnp.asarray(min_points_per_leaf)
        ),
        # only the root is a leaf initially; it is affluent iff the whole
        # dataset has at least twice the minimum number of points
        affluence_trees=(
            None if min_points_per_leaf is None else
            make_forest(max_depth - 1, bool).at[:, 1].set(y.size >= 2 * min_points_per_leaf)
        ),
        opt=jaxext.LeafDict(
            small_float=small_float,
            large_float=large_float,
            require_min_points=min_points_per_leaf is not None,
            resid_batch_size=resid_batch_size,
            count_batch_size=count_batch_size,
        ),
    )

    if save_ratios:
        # use large_float explicitly, as documented, instead of the default
        # float dtype (which differs from large_float when x64 is enabled)
        bart['ratios'] = dict(
            log_trans_prior=jnp.full(num_trees, jnp.nan, large_float),
            log_likelihood=jnp.full(num_trees, jnp.nan, large_float),
        )

    return bart

218 

def _choose_suffstat_batch_size(resid_batch_size, count_batch_size, y, forest_size):
    """
    Resolve the 'auto' batch sizes used to compute the sufficient statistics.

    Parameters
    ----------
    resid_batch_size, count_batch_size : int, None, or str
        The requested batch sizes. The string 'auto' asks for a heuristic
        value depending on the platform; any other value is returned as is.
    y : array (n,)
        The response. Only its size and its device are used.
    forest_size : int
        The total number of nodes in the forest, used on GPU to bound the
        count batch size from below.

    Returns
    -------
    resid_batch_size, count_batch_size : int or None
        The resolved batch sizes.

    Raises
    ------
    KeyError
        If some value is 'auto' and the platform is neither 'cpu' nor 'gpu'.
    """

    def lookup_platform():
        try:
            device = y.devices().pop()
        except jax.errors.ConcretizationTypeError:
            # y is being traced and has no device: use the default one
            device = jax.devices()[0]
        if device.platform not in ('cpu', 'gpu'):
            raise KeyError(f'Unknown platform: {device.platform}')
        return device.platform

    auto_resid = resid_batch_size == 'auto'
    auto_count = count_batch_size == 'auto'
    platform = lookup_platform() if (auto_resid or auto_count) else None
    n = max(1, y.size)

    if auto_resid:
        if platform == 'cpu':
            exponent = int(round(math.log2(n / 6)))  # n/6
        else:  # gpu
            exponent = int(round((1 + math.log2(n)) / 3))  # n^1/3
        resid_batch_size = max(1, 2 ** exponent)

    if auto_count:
        if platform == 'cpu':
            count_batch_size = None
        else:  # gpu
            # n^1/2; /4 is good on V100, /2 on L4/T4, still haven't tried A100
            heuristic = 2 ** int(round(math.log2(n) / 2 - 2))
            # lower bound tying forest_size * n * itemsize bytes to max_memory;
            # presumably a memory cap for the counting step (TODO confirm)
            max_memory = 2 ** 29
            itemsize = 4
            min_batch_size = int(math.ceil(forest_size * itemsize * n / max_memory))
            count_batch_size = max(1, heuristic, min_batch_size)

    return resid_batch_size, count_batch_size

256 

def step(bart, key):
    """
    Perform one full MCMC step on a BART state.

    The trees are updated first, then the noise variance is sampled given the
    new trees.

    Parameters
    ----------
    bart : dict
        A BART mcmc state, as created by `init`.
    key : jax.dtypes.prng_key array
        A jax random key.

    Returns
    -------
    bart : dict
        The new BART mcmc state.
    """
    sigma_key, trees_key = random.split(key)
    updated = sample_trees(bart, trees_key)
    return sample_sigma(updated, sigma_key)

276 

def sample_trees(bart, key):
    """
    Forest sampling step of BART MCMC.

    Moves are proposed for every tree, then accepted or rejected, and the leaf
    values are resampled.

    Parameters
    ----------
    bart : dict
        A BART mcmc state, as created by `init`.
    key : jax.dtypes.prng_key array
        A jax random key.

    Returns
    -------
    bart : dict
        The new BART mcmc state.

    Notes
    -----
    This function zeroes the proposal counters.
    """
    accept_key, propose_key = random.split(key)
    proposals = sample_moves(bart, propose_key)
    return accept_moves_and_sample_leaves(bart, proposals, accept_key)

300 

def sample_moves(bart, key):
    """
    Propose moves for all the trees.

    Parameters
    ----------
    bart : dict
        BART mcmc state.
    key : jax.dtypes.prng_key array
        A jax random key.

    Returns
    -------
    moves : dict
        A dictionary with fields:

        'allowed' : bool array (num_trees,)
            Whether the move is possible.
        'grow' : bool array (num_trees,)
            Whether the move is a grow move or a prune move.
        'num_growable' : int array (num_trees,)
            The number of growable leaves in the original tree.
        'node' : int array (num_trees,)
            The index of the leaf to grow or node to prune.
        'left', 'right' : int array (num_trees,)
            The indices of the children of 'node'.
        'partial_ratio' : float array (num_trees,)
            A factor of the Metropolis-Hastings ratio of the move. It lacks
            the likelihood ratio and the probability of proposing the prune
            move. If the move is Prune, the ratio is inverted.
        'grow_var' : int array (num_trees,)
            The decision axes of the new rules.
        'grow_split' : int array (num_trees,)
            The decision boundaries of the new rules.
        'var_trees' : int array (num_trees, 2 ** (d - 1))
            The updated decision axes of the trees, valid whatever move.
        'logu' : float array (num_trees,)
            The logarithm of a uniform (0, 1] random variable to be used to
            accept the move. It's in (-oo, 0].
    """
    num_trees = bart['leaf_trees'].shape[0]
    keys = random.split(key, 1 + num_trees)
    uniform_key, tree_keys = keys[0], keys[1:]

    # propose both a grow and a prune move on every tree
    grow_moves, prune_moves = _sample_moves_vmap_trees(
        bart['var_trees'],
        bart['split_trees'],
        bart['affluence_trees'],
        bart['max_split'],
        bart['p_nonterminal'],
        bart['p_propose_grow'],
        tree_keys,
    )

    u_choice, u_accept = random.uniform(uniform_key, (2, num_trees), bart['opt']['large_float'])

    # pick grow vs prune: 1/2 each when both are possible, otherwise the
    # probability of grow is 1 or 0 depending on whether growing is possible
    can_grow = grow_moves['num_growable'].astype(bool)
    can_prune = prune_moves['allowed']
    p_grow = jnp.where(can_grow & can_prune, 0.5, can_grow)
    grow = u_choice < p_grow  # strict < because u_choice is in [0, 1)

    # heap layout: the children of node i are 2i and 2i + 1
    node = jnp.where(grow, grow_moves['node'], prune_moves['node'])
    left = node << 1
    right = left + 1

    return dict(
        allowed=grow | can_prune,
        grow=grow,
        num_growable=grow_moves['num_growable'],
        node=node,
        left=left,
        right=right,
        partial_ratio=jnp.where(grow, grow_moves['partial_ratio'], prune_moves['partial_ratio']),
        grow_var=grow_moves['var'],
        grow_split=grow_moves['split'],
        var_trees=grow_moves['var_tree'],
        # 1 - u is in (0, 1], so its log is in (-oo, 0]
        logu=jnp.log1p(-u_accept),
    )

373 

@functools.partial(jaxext.vmap_nodoc, in_axes=(0, 0, 0, None, None, None, 0))
def _sample_moves_vmap_trees(*args):
    # Propose a grow and a prune move for each tree. The arguments are those
    # of `grow_move`/`prune_move`: per-tree arrays and the key are mapped
    # over, while max_split, p_nonterminal and p_propose_grow are shared.
    *tree_args, key = args
    grow_key, prune_key = random.split(key)
    return grow_move(*tree_args, grow_key), prune_move(*tree_args, prune_key)

381 

def grow_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, p_propose_grow, key):
    """
    Tree structure grow move proposal of BART MCMC.

    This move picks a leaf node and converts it to a non-terminal node with
    two leaf children. The move is not possible if all the leaves are already
    at maximum depth.

    Parameters
    ----------
    var_tree : array (2 ** (d - 1),)
        The variable indices of the tree.
    split_tree : array (2 ** (d - 1),)
        The splitting points of the tree.
    affluence_tree : bool array (2 ** (d - 1),) or None
        Whether a leaf has enough points to be grown.
    max_split : array (p,)
        The maximum split index for each variable.
    p_nonterminal : array (d,)
        The probability of a nonterminal node at each depth.
    p_propose_grow : array (2 ** (d - 1),)
        The unnormalized probability of choosing a leaf to grow.
    key : jax.dtypes.prng_key array
        A jax random key.

    Returns
    -------
    grow_move : dict
        A dictionary with fields:

        'num_growable' : int
            The number of growable leaves.
        'node' : int
            The index of the leaf to grow. ``2 ** d`` if there are no growable
            leaves.
        'var', 'split' : int
            The decision axis and boundary of the new rule.
        'partial_ratio' : float
            A factor of the Metropolis-Hastings ratio of the move. It lacks
            the likelihood ratio and the probability of proposing the prune
            move.
        'var_tree' : array (2 ** (d - 1),)
            The updated decision axes of the tree.
    """
    leaf_key, var_key, split_key = random.split(key, 3)

    leaf_to_grow, num_growable, prob_choose, num_prunable = choose_leaf(
        split_tree, affluence_tree, p_propose_grow, leaf_key)

    var = choose_variable(var_tree, split_tree, max_split, leaf_to_grow, var_key)
    # store the new variable first: choose_split reads it back from the tree
    new_var_tree = var_tree.at[leaf_to_grow].set(var.astype(var_tree.dtype))
    split = choose_split(new_var_tree, split_tree, max_split, leaf_to_grow, split_key)

    return dict(
        num_growable=num_growable,
        node=leaf_to_grow,
        var=var,
        split=split,
        partial_ratio=compute_partial_ratio(prob_choose, num_prunable, p_nonterminal, leaf_to_grow),
        var_tree=new_var_tree,
    )

446 

def choose_leaf(split_tree, affluence_tree, p_propose_grow, key):
    """
    Choose a leaf node to grow in a tree.

    Parameters
    ----------
    split_tree : array (2 ** (d - 1),)
        The splitting points of the tree.
    affluence_tree : bool array (2 ** (d - 1),) or None
        Whether a leaf has enough points to be grown.
    p_propose_grow : array (2 ** (d - 1),)
        The unnormalized probability of choosing a leaf to grow.
    key : jax.dtypes.prng_key array
        A jax random key.

    Returns
    -------
    leaf_to_grow : int
        The index of the leaf to grow. If ``num_growable == 0``, return
        ``2 ** d``.
    num_growable : int
        The number of leaf nodes that can be grown.
    prob_choose : float
        The normalized probability of choosing the selected leaf.
    num_prunable : int
        The number of leaf parents that could be pruned, after converting the
        selected leaf to a non-terminal node.
    """
    growable = growable_leaves(split_tree, affluence_tree)
    num_growable = jnp.count_nonzero(growable)
    weights = jnp.where(growable, p_propose_grow, 0)
    sampled_leaf, total_weight = categorical(key, weights)

    no_leaf = 2 * split_tree.size  # == 2 ** d, the "none" sentinel
    leaf_to_grow = jnp.where(num_growable, sampled_leaf, no_leaf)
    prob_choose = weights[leaf_to_grow] / total_weight

    # count the prunable nodes in the tree as it would be after growing:
    # mark the chosen leaf as split (any nonzero split value)
    grown_split_tree = split_tree.at[leaf_to_grow].set(1)
    num_prunable = jnp.count_nonzero(grove.is_leaves_parent(grown_split_tree))

    return leaf_to_grow, num_growable, prob_choose, num_prunable

484 

def growable_leaves(split_tree, affluence_tree):
    """
    Return a mask indicating the leaf nodes that can be proposed for growth.

    Parameters
    ----------
    split_tree : array (2 ** (d - 1),)
        The splitting points of the tree.
    affluence_tree : bool array (2 ** (d - 1),) or None
        Whether a leaf has enough points to be grown. If None, every leaf is
        considered affluent.

    Returns
    -------
    is_growable : bool array (2 ** (d - 1),)
        The mask of the leaf nodes that can be proposed to grow, i.e., that
        are not at the bottom level and, if `affluence_tree` is given, are
        marked as affluent in it.
    """
    leaf_mask = grove.is_actual_leaf(split_tree)
    if affluence_tree is None:
        return leaf_mask
    return leaf_mask & affluence_tree

507 

def categorical(key, distr):
    """
    Return a random integer from an arbitrary distribution.

    Parameters
    ----------
    key : jax.dtypes.prng_key array
        A jax random key.
    distr : float array (n,)
        An unnormalized probability distribution.

    Returns
    -------
    u : int
        A random integer in the range ``[0, n)``. If all probabilities are
        zero, return ``n``.
    norm : float
        The sum of `distr`, i.e., the normalization of the distribution.
    """
    cdf = jnp.cumsum(distr)
    total = cdf[-1]
    draw = random.uniform(key, (), cdf.dtype, 0, total)
    # side='right' returns the first index whose cdf exceeds the draw, so
    # entries with zero weight (flat steps of the cdf) are never selected
    index = jnp.searchsorted(cdf, draw, 'right')
    return index, total

528 

def choose_variable(var_tree, split_tree, max_split, leaf_index, key):
    """
    Choose a variable to split on for a new non-terminal node.

    Parameters
    ----------
    var_tree : int array (2 ** (d - 1),)
        The variable indices of the tree.
    split_tree : int array (2 ** (d - 1),)
        The splitting points of the tree.
    max_split : int array (p,)
        The maximum split index for each variable.
    leaf_index : int
        The index of the leaf to grow.
    key : jax.dtypes.prng_key array
        A jax random key.

    Returns
    -------
    var : int
        The index of the variable to split on.

    Notes
    -----
    The variable is drawn uniformly among those that still have a non-empty
    range of allowed splits at the leaf. If no variable qualifies, `p` is
    intended to be returned (see `randint_exclude`).
    """
    exhausted = fully_used_variables(var_tree, split_tree, max_split, leaf_index)
    num_vars = max_split.size
    return randint_exclude(key, num_vars, exhausted)

558 

def fully_used_variables(var_tree, split_tree, max_split, leaf_index):
    """
    Return a list of variables that have an empty split range at a given node.

    Only the variables used by the ancestors of the node are checked, since
    the others keep their full range.

    Parameters
    ----------
    var_tree : int array (2 ** (d - 1),)
        The variable indices of the tree.
    split_tree : int array (2 ** (d - 1),)
        The splitting points of the tree.
    max_split : int array (p,)
        The maximum split index for each variable.
    leaf_index : int
        The index of the node, assumed to be valid for `var_tree`.

    Returns
    -------
    var_to_ignore : int array (d - 2,)
        The indices of the variables that have an empty split range. Since the
        number of such variables is not fixed, unused values in the array are
        filled with `p`. The fill values are not guaranteed to be placed in any
        particular order. Variables may appear more than once.
    """
    candidates = ancestor_variables(var_tree, max_split, leaf_index)
    ranges_of = jax.vmap(split_range, in_axes=(None, None, None, None, 0))
    lower, upper = ranges_of(var_tree, split_tree, max_split, leaf_index, candidates)
    is_empty = upper - lower == 0
    return jnp.where(is_empty, candidates, max_split.size)

588 

def ancestor_variables(var_tree, max_split, node_index):
    """
    Return the list of variables in the ancestors of a node.

    Parameters
    ----------
    var_tree : int array (2 ** (d - 1),)
        The variable indices of the tree.
    max_split : int array (p,)
        The maximum split index for each variable. Used only to get `p`.
    node_index : int
        The index of the node, assumed to be valid for `var_tree`.

    Returns
    -------
    ancestor_vars : int array (d - 2,)
        The variable indices of the ancestors of the node, from the root to
        the parent, aligned to the end of the array. Unused spots are filled
        with `p`.
    """
    num_slots = grove.tree_depth(var_tree) - 1
    slots = jnp.zeros(num_slots, jaxext.minimal_unsigned_dtype(max_split.size))

    def visit_parent(state, _):
        slot, node, slots = state
        parent = node >> 1
        parent_var = var_tree[parent]
        # index 0 is not a node (the root's "parent"): record p instead
        parent_var = jnp.where(parent, parent_var, max_split.size)
        slots = slots.at[slot].set(parent_var)
        return (slot - 1, parent, slots), None

    # fill from the last slot backwards while walking up the tree
    initial_state = slots.size - 1, node_index, slots
    (_, _, slots), _ = lax.scan(visit_parent, initial_state, None, slots.size)
    return slots

620 

def split_range(var_tree, split_tree, max_split, node_index, ref_var):
    """
    Return the range of allowed splits for a variable at a given node.

    Parameters
    ----------
    var_tree : int array (2 ** (d - 1),)
        The variable indices of the tree.
    split_tree : int array (2 ** (d - 1),)
        The splitting points of the tree.
    max_split : int array (p,)
        The maximum split index for each variable.
    node_index : int
        The index of the node, assumed to be valid for `var_tree`.
    ref_var : int
        The variable for which to measure the split range. If out of bounds
        for `max_split`, its maximum split is taken as 0, giving an empty
        range.

    Returns
    -------
    l, r : int
        The range of allowed splits is [l, r).
    """
    num_ancestors = grove.tree_depth(var_tree) - 1
    upper_start = 1 + max_split.at[ref_var].get(mode='fill', fill_value=0).astype(jnp.int32)

    def climb(state, _):
        lower, upper, node = state
        came_from_right = (node & 1).astype(bool)
        parent = node >> 1
        cut = split_tree[parent]
        # only real ancestors (index > 0) that split on ref_var narrow the range
        relevant = (var_tree[parent] == ref_var) & parent.astype(bool)
        lower = jnp.where(relevant & came_from_right, jnp.maximum(lower, cut), lower)
        upper = jnp.where(relevant & ~came_from_right, jnp.minimum(upper, cut), upper)
        return (lower, upper, parent), None

    (lower, upper, _), _ = lax.scan(climb, (0, upper_start, node_index), None, num_ancestors)
    return lower + 1, upper

657 

def randint_exclude(key, sup, exclude):
    """
    Return a random integer in a range, excluding some values.

    Parameters
    ----------
    key : jax.dtypes.prng_key array
        A jax random key.
    sup : int
        The exclusive upper bound of the range.
    exclude : int array (n,)
        The values to exclude from the range. Values greater than or equal to
        `sup` are ignored. Values can appear more than once.

    Returns
    -------
    u : int
        A random integer in the range ``[0, sup)``, and which satisfies
        ``u not in exclude``. If all values in the range are excluded, return
        `sup`.
    """
    # sorted, deduplicated, padded with `sup` up to the original size
    exclude = jnp.unique(exclude, size=exclude.size, fill_value=sup)
    num_allowed = sup - jnp.count_nonzero(exclude < sup)
    # rank of the result among the allowed values
    u = random.randint(key, (), 0, num_allowed)

    # Convert the rank to a value by skipping over each excluded value that is
    # <= the running candidate (requires `exclude` sorted). Values >= sup,
    # including the padding, must not shift u: otherwise, when everything is
    # excluded, u would overshoot past `sup`.
    def loop(u, i):
        return jnp.where((i < sup) & (i <= u), u + 1, u), None

    u, _ = lax.scan(loop, u, exclude)
    return u

686 

def choose_split(var_tree, split_tree, max_split, leaf_index, key):
    """
    Choose a split point for a new non-terminal node.

    The split is drawn uniformly from the range allowed by the ancestors of
    the node for the variable stored at `leaf_index`.

    Parameters
    ----------
    var_tree : int array (2 ** (d - 1),)
        The variable indices of the tree.
    split_tree : int array (2 ** (d - 1),)
        The splitting points of the tree.
    max_split : int array (p,)
        The maximum split index for each variable.
    leaf_index : int
        The index of the leaf to grow. It is assumed that `var_tree` already
        contains the target variable at this index.
    key : jax.dtypes.prng_key array
        A jax random key.

    Returns
    -------
    split : int
        The split point.
    """
    target_var = var_tree[leaf_index]
    low, high = split_range(var_tree, split_tree, max_split, leaf_index, target_var)
    return random.randint(key, (), low, high)

713 

def compute_partial_ratio(prob_choose, num_prunable, p_nonterminal, leaf_to_grow):
    """
    Compute the product of the transition and prior ratios of a grow move.

    Parameters
    ----------
    prob_choose : float
        The normalized probability of choosing the leaf to grow, among the
        growable leaves of the initial tree.
    num_prunable : int
        The number of leaf parents that could be pruned, after converting the
        leaf to be grown to a non-terminal node.
    p_nonterminal : array (d,)
        The probability of a nonterminal node at each depth.
    leaf_to_grow : int
        The index of the leaf to grow.

    Returns
    -------
    ratio : float
        The transition ratio P(new tree -> old tree) / P(old tree -> new tree)
        times the prior ratio P(new tree) / P(old tree), but the transition
        ratio is missing the factor P(propose prune) in the numerator.
    """

    # the two ratios also contain factors num_available_split *
    # num_available_var, but they cancel out

    # p_prune can't be computed here because it needs the count trees, which are
    # computed in the acceptance phase

    prune_allowed = leaf_to_grow != 1
    # prune allowed <---> the initial tree is not a root
    # leaf to grow is root --> the tree can only be a root
    # tree is a root --> the only leaf I can grow is root

    # probability of having chosen a grow move from the initial tree
    p_grow = jnp.where(prune_allowed, 0.5, 1)

    # denominator of the transition ratio, times the 1 / num_prunable factor
    # of choosing this node when pruning back from the new tree
    inv_trans_ratio = p_grow * prob_choose * num_prunable

    # prior ratio: the leaf becomes nonterminal and gets two leaf children
    depth = grove.tree_depths(2 ** (p_nonterminal.size - 1))[leaf_to_grow]
    p_parent = p_nonterminal[depth]
    cp_children = 1 - p_nonterminal[depth + 1]
    tree_ratio = cp_children * cp_children * p_parent / (1 - p_parent)

    return tree_ratio / inv_trans_ratio

759 

def prune_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, p_propose_grow, key):
    """
    Tree structure prune move proposal of BART MCMC.

    Parameters
    ----------
    var_tree : int array (2 ** (d - 1),)
        The variable indices of the tree. Unused, accepted for a common
        signature with `grow_move`.
    split_tree : int array (2 ** (d - 1),)
        The splitting points of the tree.
    affluence_tree : bool array (2 ** (d - 1),) or None
        Whether a leaf has enough points to be grown.
    max_split : int array (p,)
        The maximum split index for each variable. Unused, accepted for a
        common signature with `grow_move`.
    p_nonterminal : float array (d,)
        The probability of a nonterminal node at each depth.
    p_propose_grow : float array (2 ** (d - 1),)
        The unnormalized probability of choosing a leaf to grow.
    key : jax.dtypes.prng_key array
        A jax random key.

    Returns
    -------
    prune_move : dict
        A dictionary with fields:

        'allowed' : bool
            Whether the move is possible.
        'node' : int
            The index of the node to prune. ``2 ** d`` if no node can be pruned.
        'partial_ratio' : float
            A factor of the Metropolis-Hastings ratio of the move. It lacks
            the likelihood ratio and the probability of proposing the prune
            move. This ratio is inverted.
    """
    node_to_prune, num_prunable, prob_choose = choose_leaf_parent(
        split_tree, affluence_tree, p_propose_grow, key)

    # the ratio is that of the reverse grow move; it is inverted in
    # accept_moves_and_sample_leaves
    reverse_ratio = compute_partial_ratio(prob_choose, num_prunable, p_nonterminal, node_to_prune)

    return dict(
        # a nonzero split at the root (index 1) means the tree is not a root
        allowed=split_tree[1].astype(bool),
        node=node_to_prune,
        partial_ratio=reverse_ratio,
    )

805 

def choose_leaf_parent(split_tree, affluence_tree, p_propose_grow, key):
    """
    Pick a non-terminal node with leaf children to prune in a tree.

    Parameters
    ----------
    split_tree : array (2 ** (d - 1),)
        The splitting points of the tree.
    affluence_tree : bool array (2 ** (d - 1),) or None
        Whether a leaf has enough points to be grown.
    p_propose_grow : array (2 ** (d - 1),)
        The unnormalized probability of choosing a leaf to grow.
    key : jax.dtypes.prng_key array
        A jax random key.

    Returns
    -------
    node_to_prune : int
        The index of the node to prune. If ``num_prunable == 0``, return
        ``2 ** d``.
    num_prunable : int
        The number of leaf parents that could be pruned.
    prob_choose : float
        The normalized probability of choosing the pruned node for growth in
        the pruned tree.
    """
    prunable = grove.is_leaves_parent(split_tree)
    num_prunable = jnp.count_nonzero(prunable)
    sampled = randint_masked(key, prunable)
    node_to_prune = jnp.where(num_prunable, sampled, 2 * split_tree.size)

    # build the tree as it would be after pruning: the node becomes a leaf,
    # and is marked as affluent
    pruned_split_tree = split_tree.at[node_to_prune].set(0)
    if affluence_tree is not None:
        affluence_tree = affluence_tree.at[node_to_prune].set(True)
    growable_after = growable_leaves(pruned_split_tree, affluence_tree)

    total = jnp.sum(p_propose_grow, where=growable_after)
    prob_choose = p_propose_grow[node_to_prune] / total

    return node_to_prune, num_prunable, prob_choose

846 

def randint_masked(key, mask):
    """
    Draw a uniformly random index among the positions where `mask` is set.

    Parameters
    ----------
    key : jax.dtypes.prng_key array
        A jax random key.
    mask : bool array (n,)
        Which indices may be returned.

    Returns
    -------
    u : int
        An index in ``[0, n)`` with ``mask[u] == True``, or ``n`` if `mask`
        has no `True` entries.
    """
    # running count of allowed positions; its last entry is their total
    cumulative = jnp.cumsum(mask)
    num_allowed = cumulative[-1]
    # draw the rank of the chosen allowed position, then map rank -> index
    rank = random.randint(key, (), 0, num_allowed)
    return jnp.searchsorted(cumulative, rank, side='right')

867 

def accept_moves_and_sample_leaves(bart, moves, key):
    """
    Decide on the proposed moves and draw new leaf values.

    The work is split in three stages: a parallel pre-computation across
    trees, a sequential pass over the trees, and a final parallel update.

    Parameters
    ----------
    bart : dict
        A BART mcmc state.
    moves : dict
        The proposed moves, see `sample_moves`.
    key : jax.dtypes.prng_key array
        A jax random key.

    Returns
    -------
    bart : dict
        The new BART mcmc state.
    """
    parallel_out = accept_moves_parallel_stage(bart, moves, key)
    # precomputed = [prelkv, prelk, prelf]
    bart, moves, count_trees, move_counts, *precomputed = parallel_out
    bart, moves = accept_moves_sequential_stage(
        bart, count_trees, moves, move_counts, *precomputed
    )
    return accept_moves_final_stage(bart, moves)

889 

def accept_moves_parallel_stage(bart, moves, key):
    """
    Compute, in parallel over the trees, the quantities needed to accept moves.

    Parameters
    ----------
    bart : dict
        A BART mcmc state.
    moves : dict
        The proposed moves, see `sample_moves`.
    key : jax.dtypes.prng_key array
        A jax random key.

    Returns
    -------
    bart : dict
        A partially updated BART mcmc state (shallow copy of the input).
    moves : dict
        The proposed moves, where 'partial_ratio' has been replaced by
        'log_trans_prior_ratio'.
    count_trees : array (num_trees, 2 ** d)
        The number of points in each potential or actual leaf node.
    move_counts : dict
        The number of points in the nodes touched by each move.
    prelkv, prelk, prelf : dict
        Pre-computed terms of the likelihood ratios and of the leaf sampling.
    """
    bart = bart.copy()
    opt = bart['opt']

    # act as if every grow move had been accepted: the state now describes
    # the largest tree compatible with each move
    bart['var_trees'] = moves['var_trees']
    bart['leaf_indices'] = apply_grow_to_indices(moves, bart['leaf_indices'], bart['X'])
    bart['leaf_trees'] = adapt_leaf_trees_to_grow_indices(bart['leaf_trees'], moves)

    # datapoints per node, for whole trees and for the nodes of the moves
    count_trees, move_counts = compute_count_trees(
        bart['leaf_indices'], moves, opt['count_batch_size']
    )
    if opt['require_min_points']:
        half_size = bart['var_trees'].shape[1]
        bart['affluence_trees'] = (
            count_trees[:, :half_size] >= 2 * bart['min_points_per_leaf']
        )

    # finish the non-likelihood part of the ratio and tally the proposals
    moves = complete_ratio(moves, move_counts, bart['min_points_per_leaf'])
    is_grow = moves['grow']
    bart['grow_prop_count'] = jnp.sum(is_grow)
    bart['prune_prop_count'] = jnp.sum(moves['allowed'] & ~is_grow)

    sigma2 = bart['sigma2']
    prelkv, prelk = precompute_likelihood_terms(count_trees, sigma2, move_counts)
    prelf = precompute_leaf_terms(count_trees, sigma2, key)

    return bart, moves, count_trees, move_counts, prelkv, prelk, prelf

941 

@functools.partial(jaxext.vmap_nodoc, in_axes=(0, 0, None))
def apply_grow_to_indices(moves, leaf_indices, X):
    """
    Move the datapoints of a grown leaf into its new children.

    Parameters
    ----------
    moves : dict
        The proposed moves, see `sample_moves`.
    leaf_indices : array (num_trees, n)
        The index of the leaf each datapoint falls into.
    X : array (p, n)
        The predictors matrix.

    Returns
    -------
    grow_leaf_indices : array (num_trees, n)
        The leaf indices after the grow move; unchanged where the move is
        not a grow.
    """
    parent = moves['node']
    # for non-grow moves, match against 2 ** d, which no leaf index equals
    no_match = jnp.array(2 * moves['var_trees'].size)
    target = jnp.where(moves['grow'], parent, no_match)
    # children of node k are 2k (left) and 2k + 1 (right)
    goes_right = X[moves['grow_var'], :] >= moves['grow_split']
    child = (parent.astype(leaf_indices.dtype) << 1) + goes_right
    return jnp.where(leaf_indices == target, child, leaf_indices)

970 

def compute_count_trees(leaf_indices, moves, batch_size):
    """
    Count the number of datapoints in each leaf.

    Parameters
    ----------
    leaf_indices : int array (num_trees, n)
        The index of the leaf each datapoint falls into, if the grow move is
        accepted.
    moves : dict
        The proposed moves, see `sample_moves`.
    batch_size : int or None
        The data batch size to use for the summation.

    Returns
    -------
    count_trees : int array (num_trees, 2 ** d)
        The number of points in each potential or actual leaf node. The entry
        at the node modified by each move holds the total count of its two
        children.
    counts : dict
        The counts of the number of points in the nodes modified by the
        moves, with fields 'left', 'right' and 'total', each an int array
        (num_trees,).
    """
    ntree, half_tree_size = moves['var_trees'].shape
    tree_size = 2 * half_tree_size  # == 2 ** d
    tree_indices = jnp.arange(ntree)

    count_trees = count_datapoints_per_leaf(leaf_indices, tree_size, batch_size)

    # count datapoints in the children of the node modified by each move
    counts = dict()
    counts['left'] = count_trees[tree_indices, moves['left']]
    counts['right'] = count_trees[tree_indices, moves['right']]
    counts['total'] = counts['left'] + counts['right']

    # write the total count into the (non-leaf) node modified by the move
    count_trees = count_trees.at[tree_indices, moves['node']].set(counts['total'])

    return count_trees, counts

1011 

def count_datapoints_per_leaf(leaf_indices, tree_size, batch_size):
    """
    Count the number of datapoints in each leaf.

    Parameters
    ----------
    leaf_indices : int array (num_trees, n)
        The index of the leaf each datapoint falls into.
    tree_size : int
        The size of the leaf tree array (2 ** d).
    batch_size : int or None
        The data batch size to use for the summation. If None, trees are
        processed one at a time with `lax.scan`; otherwise all trees are
        counted at once with batched accumulation.

    Returns
    -------
    count_trees : int array (num_trees, tree_size)
        The number of points in each leaf node.
    """
    if batch_size is None:
        return _count_scan(leaf_indices, tree_size)
    return _count_vec(leaf_indices, tree_size, batch_size)

1034 

def _count_scan(leaf_indices, tree_size):
    """Count datapoints per leaf one tree at a time, using `lax.scan`."""
    def count_one_tree(carry, tree_leaf_indices):
        per_leaf = _aggregate_scatter(1, tree_leaf_indices, tree_size, jnp.uint32)
        return carry, per_leaf

    _, count_trees = lax.scan(count_one_tree, None, leaf_indices)
    return count_trees

1040 

1041def _aggregate_scatter(values, indices, size, dtype): 1a

1042 return (jnp 1a

1043 .zeros(size, dtype) 

1044 .at[indices] 

1045 .add(values) 

1046 ) 

1047 

def _count_vec(leaf_indices, tree_size, batch_size):
    """Count datapoints per leaf for all trees at once, batching over points."""
    # uint16 is super-slow on gpu, don't use it even if n < 2^16
    counter_dtype = jnp.uint32
    return _aggregate_batched_alltrees(
        1, leaf_indices, tree_size, counter_dtype, batch_size
    )

1051 

1052def _aggregate_batched_alltrees(values, indices, size, dtype, batch_size): 1a

1053 ntree, n = indices.shape 1a

1054 tree_indices = jnp.arange(ntree) 1a

1055 nbatches = n // batch_size + bool(n % batch_size) 1a

1056 batch_indices = jnp.arange(n) % nbatches 1a

1057 return (jnp 1a

1058 .zeros((ntree, size, nbatches), dtype) 

1059 .at[tree_indices[:, None], indices, batch_indices] 

1060 .add(values) 

1061 .sum(axis=2) 

1062 ) 

1063 

def complete_ratio(moves, move_counts, min_points_per_leaf):
    """
    Finish the non-likelihood part of the Metropolis-Hastings ratio.

    Multiplies the partial ratio by the probability of proposing the prune
    move and takes the logarithm.

    Parameters
    ----------
    moves : dict
        The proposed moves, see `sample_moves`.
    move_counts : dict
        The number of points in the nodes touched by each move.
    min_points_per_leaf : int or None
        The minimum number of data points in a leaf node.

    Returns
    -------
    moves : dict
        A copy of `moves` without 'partial_ratio' and with the new field
        'log_trans_prior_ratio'.
    """
    completed = moves.copy()
    p_prune = compute_p_prune(
        completed, move_counts['left'], move_counts['right'], min_points_per_leaf
    )
    partial_ratio = completed.pop('partial_ratio')
    completed['log_trans_prior_ratio'] = jnp.log(partial_ratio * p_prune)
    return completed

1090 

def compute_p_prune(moves, left_count, right_count, min_points_per_leaf):
    """
    Probability of proposing a prune move, for each proposed move.

    Parameters
    ----------
    moves : dict
        The proposed moves, see `sample_moves`.
    left_count, right_count : int
        The number of datapoints in the proposed children of the leaf to grow.
    min_points_per_leaf : int or None
        The minimum number of data points in a leaf node.

    Returns
    -------
    p_prune : float
        For a grow move, the probability of proposing a prune once the grow
        is accepted; for a prune move, the probability of proposing it in
        the current tree. It is 0.5 when a grow could also be proposed,
        otherwise 1.
    """
    num_growable = moves['num_growable']
    half_tree_size = moves['var_trees'].shape[1]

    # grow case: after the grow, another grow is possible if some other leaf
    # was growable, or if a new child can still be split
    children_can_split = moves['node'] < half_tree_size // 2
    if min_points_per_leaf is not None:
        threshold = 2 * min_points_per_leaf
        enough_points = (left_count >= threshold) | (right_count >= threshold)
        children_can_split &= enough_points
    can_grow_after = (num_growable >= 2) | children_can_split
    p_prune_if_grow = jnp.where(can_grow_after, 0.5, 1)

    # prune case: the tree is left as it is now
    p_prune_if_prune = jnp.where(num_growable, 0.5, 1)

    return jnp.where(moves['grow'], p_prune_if_grow, p_prune_if_prune)

1125 

@jaxext.vmap_nodoc
def adapt_leaf_trees_to_grow_indices(leaf_trees, moves):
    """
    Copy the value of each leaf to grow into its would-be children.

    This way, leaf indices computed as if the grow move were accepted still
    select the values of the original tree.

    Parameters
    ----------
    leaf_trees : float array (num_trees, 2 ** d)
        The leaf values.
    moves : dict
        The proposed moves, see `sample_moves`.

    Returns
    -------
    leaf_trees : float array (num_trees, 2 ** d)
        The leaf values, with the grown leaf's value duplicated into its
        children where the move is a grow.
    """
    parent_value = leaf_trees[moves['node']]
    # out-of-bounds index for non-grow moves; JAX drops such scatter updates
    skip = leaf_trees.size
    left_target = jnp.where(moves['grow'], moves['left'], skip)
    right_target = jnp.where(moves['grow'], moves['right'], skip)
    leaf_trees = leaf_trees.at[left_target].set(parent_value)
    return leaf_trees.at[right_target].set(parent_value)

1152 

def precompute_likelihood_terms(count_trees, sigma2, move_counts):
    """
    Pre-compute the pieces of the likelihood ratio used in the acceptance step.

    Parameters
    ----------
    count_trees : array (num_trees, 2 ** d)
        The number of points in each potential or actual leaf node; only its
        length (the number of trees) is used.
    sigma2 : float
        The noise variance.
    move_counts : dict
        The number of points in the nodes touched by each move.

    Returns
    -------
    prelkv : dict
        Per-tree terms: 'sigma2_left', 'sigma2_right', 'sigma2_total' and
        'sqrt_term'.
    prelk : dict
        Terms shared by all trees: 'exp_factor'.
    """
    # leaf prior variance is 1 / num_trees
    sigma_mu2 = 1 / len(count_trees)

    def node_variance(count):
        return sigma2 + count * sigma_mu2

    var_left = node_variance(move_counts['left'])
    var_right = node_variance(move_counts['right'])
    var_total = node_variance(move_counts['total'])

    prelkv = dict(
        sigma2_left=var_left,
        sigma2_right=var_right,
        sigma2_total=var_total,
    )
    prelkv['sqrt_term'] = jnp.log(sigma2 * var_total / (var_left * var_right)) / 2
    prelk = dict(exp_factor=sigma_mu2 / (2 * sigma2))
    return prelkv, prelk

1189 

def precompute_leaf_terms(count_trees, sigma2, key):
    """
    Pre-compute the pieces needed to sample the leaves from their posterior.

    Parameters
    ----------
    count_trees : array (num_trees, 2 ** d)
        The number of points in each potential or actual leaf node.
    sigma2 : float
        The noise variance.
    key : jax.dtypes.prng_key array
        A jax random key.

    Returns
    -------
    prelf : dict
        A dictionary with these fields:

        'mean_factor' : float array (num_trees, 2 ** d)
            Multiply by the sum of residuals in a leaf to get the posterior
            mean of that leaf.
        'centered_leaves' : float array (num_trees, 2 ** d)
            Zero-mean normal draws with the posterior variance, to be added
            to the posterior mean.
    """
    num_trees = len(count_trees)
    # posterior precision = likelihood precision + prior precision (num_trees)
    var_post = lax.reciprocal(count_trees / sigma2 + num_trees)
    noise = random.normal(key, count_trees.shape, sigma2.dtype)
    return dict(
        mean_factor=var_post / sigma2,
        centered_leaves=noise * jnp.sqrt(var_post),
    )

1223 

def accept_moves_sequential_stage(bart, count_trees, moves, move_counts, prelkv, prelk, prelf):
    """
    The part of accepting the moves that has to be done one tree at a time.

    Parameters
    ----------
    bart : dict
        A partially updated BART mcmc state.
    count_trees : array (num_trees, 2 ** d)
        The number of points in each potential or actual leaf node.
    moves : dict
        The proposed moves, see `sample_moves`.
    move_counts : dict
        The counts of the number of points in the nodes modified by the
        moves.
    prelkv, prelk, prelf : dict
        Dictionaries with pre-computed terms of the likelihood ratios and leaf
        samples.

    Returns
    -------
    bart : dict
        A partially updated BART mcmc state. The input state, including its
        'ratios' sub-dictionary if present, is not modified.
    moves : dict
        The proposed moves, with these additional fields:

        'acc' : bool array (num_trees,)
            Whether the move was accepted.
        'to_prune' : bool array (num_trees,)
            Whether, to reflect the acceptance status of the move, the state
            should be updated by pruning the leaves involved in the move.
    """
    bart = bart.copy()
    moves = moves.copy()

    def loop(resid, item):
        resid, leaf_tree, acc, to_prune, ratios = accept_move_and_sample_leaves(
            bart['X'],
            len(bart['leaf_trees']),
            bart['opt']['resid_batch_size'],
            resid,
            bart['min_points_per_leaf'],
            'ratios' in bart,
            prelk,
            *item,
        )
        return resid, (leaf_tree, acc, to_prune, ratios)

    # per-tree inputs, scanned over their leading (tree) axis
    items = (
        bart['leaf_trees'], count_trees,
        moves, move_counts,
        bart['leaf_indices'],
        prelkv, prelf,
    )
    resid, (leaf_trees, acc, to_prune, ratios) = lax.scan(loop, bart['resid'], items)

    bart['resid'] = resid
    bart['leaf_trees'] = leaf_trees
    if 'ratios' in bart:
        # build a new dict rather than updating in place: `bart` is a shallow
        # copy, so bart['ratios'] is the same object held by the caller
        bart['ratios'] = {**bart['ratios'], **ratios}
    moves['acc'] = acc
    moves['to_prune'] = to_prune

    return bart, moves

1287 

def accept_move_and_sample_leaves(X, ntree, resid_batch_size, resid, min_points_per_leaf, save_ratios, prelk, leaf_tree, count_tree, move, move_counts, leaf_indices, prelkv, prelf):
    """
    Accept or reject a proposed move and sample the new leaf values.

    Parameters
    ----------
    X : int array (p, n)
        The predictors. Not used in the body of this function.
    ntree : int
        The number of trees in the forest. Not used in the body of this
        function.
    resid_batch_size : int, None
        The batch size for computing the sum of residuals in each leaf.
    resid : float array (n,)
        The residuals (data minus forest value).
    min_points_per_leaf : int or None
        The minimum number of data points in a leaf node. If not None, the
        move is rejected when either child has fewer points.
    save_ratios : bool
        Whether to save the acceptance ratios.
    prelk : dict
        The pre-computed terms of the likelihood ratio which are shared across
        trees.
    leaf_tree : float array (2 ** d,)
        The leaf values of the tree.
    count_tree : int array (2 ** d,)
        The number of datapoints in each leaf.
    move : dict
        The proposed move, see `sample_moves`.
    move_counts : dict
        The number of datapoints in the nodes modified by the move, with
        fields 'left', 'right' and 'total'.
    leaf_indices : int array (n,)
        The leaf indices for the largest version of the tree compatible with
        the move.
    prelkv, prelf : dict
        The pre-computed terms of the likelihood ratio and leaf sampling which
        are specific to the tree.

    Returns
    -------
    resid : float array (n,)
        The updated residuals (data minus forest value).
    leaf_tree : float array (2 ** d,)
        The new leaf values of the tree.
    acc : bool
        Whether the move was accepted.
    to_prune : bool
        Whether, to reflect the acceptance status of the move, the state should
        be updated by pruning the leaves involved in the move.
    ratios : dict
        The acceptance ratios for the moves. Empty if not to be saved.
    """

    # sum residuals and count units per leaf, in tree proposed by grow move
    resid_tree = sum_resid(resid, leaf_indices, leaf_tree.size, resid_batch_size)

    # add back this tree's contribution, so that resid_tree holds the sums of
    # the residuals with the current tree removed from the forest
    resid_tree += count_tree * leaf_tree

    # get indices of move
    node = move['node']
    assert node.dtype == jnp.int32
    left = move['left']
    right = move['right']

    # sum residuals in parent node modified by move
    resid_left = resid_tree[left]
    resid_right = resid_tree[right]
    resid_total = resid_left + resid_right
    resid_tree = resid_tree.at[node].set(resid_total)

    # compute acceptance ratio; it is computed in the grow direction and
    # negated (i.e. the ratio is inverted) for prune moves
    log_lk_ratio = compute_likelihood_ratio(resid_total, resid_left, resid_right, prelkv, prelk)
    log_ratio = move['log_trans_prior_ratio'] + log_lk_ratio
    log_ratio = jnp.where(move['grow'], log_ratio, -log_ratio)
    ratios = {}
    if save_ratios:
        ratios.update(
            log_trans_prior=move['log_trans_prior_ratio'],
            log_likelihood=log_lk_ratio,
        )

    # determine whether to accept the move
    acc = move['allowed'] & (move['logu'] <= log_ratio)
    if min_points_per_leaf is not None:
        acc &= move_counts['left'] >= min_points_per_leaf
        acc &= move_counts['right'] >= min_points_per_leaf

    # compute leaves posterior and sample leaves
    initial_leaf_tree = leaf_tree
    mean_post = resid_tree * prelf['mean_factor']
    leaf_tree = mean_post + prelf['centered_leaves']

    # copy leaves around such that the leaf indices select the right leaf;
    # to_prune is True when a grow is rejected or a prune is accepted
    to_prune = acc ^ move['grow']
    leaf_tree = (leaf_tree
        .at[jnp.where(to_prune, left, leaf_tree.size)]
        .set(leaf_tree[node])
        .at[jnp.where(to_prune, right, leaf_tree.size)]
        .set(leaf_tree[node])
    )

    # replace old tree with new tree in function values
    resid += (initial_leaf_tree - leaf_tree)[leaf_indices]

    return resid, leaf_tree, acc, to_prune, ratios

1390 

def sum_resid(resid, leaf_indices, tree_size, batch_size):
    """
    Sum the residuals in each leaf.

    Parameters
    ----------
    resid : float array (n,)
        The residuals (data minus forest value).
    leaf_indices : int array (n,)
        The leaf indices of the tree (in which leaf each data point falls into).
    tree_size : int
        The size of the tree array (2 ** d).
    batch_size : int, None
        The data batch size for the aggregation. Batching increases numerical
        accuracy and parallelism.

    Returns
    -------
    resid_tree : float array (2 ** d,)
        The sum of the residuals at data points in each leaf. The sums are
        accumulated in float32, or in ``resid.dtype`` if that is wider
        (e.g. float64).
    """
    # accumulate in at least float32, but never below the precision of the
    # residuals: a fixed float32 accumulator would truncate float64 inputs
    dtype = jnp.promote_types(resid.dtype, jnp.float32)
    if batch_size is None:
        aggr_func = _aggregate_scatter
    else:
        aggr_func = functools.partial(_aggregate_batched_onetree, batch_size=batch_size)
    return aggr_func(resid, leaf_indices, tree_size, dtype)

1417 

1418def _aggregate_batched_onetree(values, indices, size, dtype, batch_size): 1a

1419 n, = indices.shape 1a

1420 nbatches = n // batch_size + bool(n % batch_size) 1a

1421 batch_indices = jnp.arange(n) % nbatches 1a

1422 return (jnp 1a

1423 .zeros((size, nbatches), dtype) 

1424 .at[indices, batch_indices] 

1425 .add(values) 

1426 .sum(axis=1) 

1427 ) 

1428 

def compute_likelihood_ratio(total_resid, left_resid, right_resid, prelkv, prelk):
    """
    Compute the log likelihood ratio of a grow move.

    Parameters
    ----------
    total_resid : float
        The sum of the residuals in the leaf to grow.
    left_resid, right_resid : float
        The sum of the residuals in the left/right child of the leaf to grow.
    prelkv, prelk : dict
        The pre-computed terms of the likelihood ratio, see
        `precompute_likelihood_terms`.

    Returns
    -------
    log_ratio : float
        The logarithm of the likelihood ratio
        P(data | new tree) / P(data | old tree).
    """
    exp_term = prelk['exp_factor'] * (
        left_resid * left_resid / prelkv['sigma2_left']
        + right_resid * right_resid / prelkv['sigma2_right']
        - total_resid * total_resid / prelkv['sigma2_total']
    )
    return prelkv['sqrt_term'] + exp_term

1454 

def accept_moves_final_stage(bart, moves):
    """
    The final part of accepting the moves, in parallel across trees.

    Parameters
    ----------
    bart : dict
        A partially updated BART mcmc state.
    moves : dict
        The proposed moves (see `sample_moves`) as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    bart : dict
        The fully updated BART mcmc state (a shallow copy of the input), with
        'grow_acc_count' and 'prune_acc_count' set, and 'leaf_indices' and
        'split_trees' updated to reflect the accepted moves.
    """
    bart = bart.copy()
    bart['grow_acc_count'] = jnp.sum(moves['acc'] & moves['grow'])
    bart['prune_acc_count'] = jnp.sum(moves['acc'] & ~moves['grow'])
    bart['leaf_indices'] = apply_moves_to_leaf_indices(bart['leaf_indices'], moves)
    bart['split_trees'] = apply_moves_to_split_trees(bart['split_trees'], moves)
    return bart

1480 

@jax.vmap
def apply_moves_to_leaf_indices(leaf_indices, moves):
    """
    Collapse the children of pruned nodes back into their parent.

    Parameters
    ----------
    leaf_indices : int array (num_trees, n)
        The index of the leaf each datapoint falls into, if the grow move was
        accepted.
    moves : dict
        The proposed moves (see `sample_moves`), as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    leaf_indices : int array (num_trees, n)
        The leaf indices consistent with the accepted moves.
    """
    # clearing the lowest bit maps both children (2k, 2k + 1) onto 2k,
    # which is the 'left' index of the move
    sibling_mask = ~jnp.array(1, leaf_indices.dtype)
    in_moved_pair = (leaf_indices & sibling_mask) == moves['left']
    parent = moves['node'].astype(leaf_indices.dtype)
    return jnp.where(in_moved_pair & moves['to_prune'], parent, leaf_indices)

1507 

@jax.vmap
def apply_moves_to_split_trees(split_trees, moves):
    """
    Write grow splits and clear pruned nodes in the split trees.

    Parameters
    ----------
    split_trees : int array (num_trees, 2 ** (d - 1))
        The cutpoints of the decision nodes in the initial trees.
    moves : dict
        The proposed moves (see `sample_moves`), as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    split_trees : int array (num_trees, 2 ** (d - 1))
        The split trees consistent with the accepted moves. A rejected grow
        has its split written and then cleared again by the prune step.
    """
    node = moves['node']
    # out-of-bounds index; JAX drops scatter updates at such indices
    skip = split_trees.size
    grow_at = jnp.where(moves['grow'], node, skip)
    prune_at = jnp.where(moves['to_prune'], node, skip)
    new_split = moves['grow_split'].astype(split_trees.dtype)
    with_grow = split_trees.at[grow_at].set(new_split)
    return with_grow.at[prune_at].set(0)

1540 

def sample_sigma(bart, key):
    """
    Draw a new noise variance for the BART MCMC.

    The new ``sigma2`` is drawn from an inverse gamma distribution with
    shape ``sigma2_alpha + n / 2`` and scale ``sigma2_beta + ||resid||^2 / 2``.

    Parameters
    ----------
    bart : dict
        A BART mcmc state, as created by `init`.
    key : jax.dtypes.prng_key array
        A jax random key.

    Returns
    -------
    bart : dict
        A shallow copy of the state with the updated 'sigma2'.
    """
    resid = bart['resid']
    shape = bart['sigma2_alpha'] + resid.size / 2
    sum_sq = jnp.dot(resid, resid, preferred_element_type=bart['opt']['large_float'])
    scale = bart['sigma2_beta'] + sum_sq / 2

    new_bart = bart.copy()
    # beta / Gamma(alpha) is distributed as InvGamma(alpha, beta)
    new_bart['sigma2'] = scale / random.gamma(key, shape)
    return new_bart