Coverage for src/bartz/mcmcstep.py: 95%

491 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-27 14:46 +0000

1# bartz/src/bartz/mcmcstep.py 

2# 

3# Copyright (c) 2024-2025, Giacomo Petrillo 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25""" 

26Functions that implement the BART posterior MCMC initialization and update step. 

27 

28Functions that do MCMC steps operate by taking as input a bart state, and 

29outputting a new state. The inputs are not modified. 

30 

31The main entry points are: 

32 

33 - `State`: The dataclass that represents a BART MCMC state. 

34 - `init`: Creates an initial `State` from data and configurations. 

35 - `step`: Performs one full MCMC step on a `State`, returning a new `State`. 

36""" 

37 

38import math 1ab

39from dataclasses import replace 1ab

40from functools import cache, partial 1ab

41from typing import Any, Literal 1ab

42 

43import jax 1ab

44from equinox import Module, field, tree_at 1ab

45from jax import lax, random 1ab

46from jax import numpy as jnp 1ab

47from jaxtyping import Array, Bool, Float32, Int32, Integer, Key, Shaped, UInt 1ab

48 

49from bartz import grove 1ab

50from bartz.jaxext import ( 1ab

51 minimal_unsigned_dtype, 

52 split, 

53 truncated_normal_onesided, 

54 vmap_nodoc, 

55) 

56 

57 

class Forest(Module):
    """
    Represents the MCMC state of a sum of trees.

    Parameters
    ----------
    leaf_tree
        The leaf values.
    var_tree
        The decision axes.
    split_tree
        The decision boundaries.
    affluence_tree
        Marks leaves that can be grown.
    max_split
        The maximum split index for each predictor.
    blocked_vars
        The indices of predictors without available cutpoints, excluded from
        the MCMC. `None` if no filtering was requested.
    p_nonterminal
        The prior probability of each node being nonterminal, conditional on
        its ancestors. Includes the nodes at maximum depth which should be set
        to 0.
    p_propose_grow
        The unnormalized probability of picking a leaf for a grow proposal.
    leaf_indices
        The index of the leaf each datapoint falls into, for each tree.
    min_points_per_decision_node
        The minimum number of data points in a decision node.
    min_points_per_leaf
        The minimum number of data points in a leaf node.
    resid_batch_size
    count_batch_size
        The data batch sizes for computing the sufficient statistics. If `None`,
        they are computed with no batching.
    log_trans_prior
        The log transition and prior Metropolis-Hastings ratio for the
        proposed move on each tree.
    log_likelihood
        The log likelihood ratio.
    grow_prop_count
    prune_prop_count
        The number of grow/prune proposals made during one full MCMC cycle.
    grow_acc_count
    prune_acc_count
        The number of grow/prune moves accepted during one full MCMC cycle.
    sigma_mu2
        The prior variance of a leaf, conditional on the tree structure.
    """

    leaf_tree: Float32[Array, 'num_trees 2**d']
    var_tree: UInt[Array, 'num_trees 2**(d-1)']
    split_tree: UInt[Array, 'num_trees 2**(d-1)']
    affluence_tree: Bool[Array, 'num_trees 2**(d-1)']
    max_split: UInt[Array, ' p']
    blocked_vars: UInt[Array, ' k'] | None
    p_nonterminal: Float32[Array, ' 2**d']
    p_propose_grow: Float32[Array, ' 2**(d-1)']
    leaf_indices: UInt[Array, 'num_trees n']
    min_points_per_decision_node: Int32[Array, ''] | None
    min_points_per_leaf: Int32[Array, ''] | None
    resid_batch_size: int | None = field(static=True)
    count_batch_size: int | None = field(static=True)
    log_trans_prior: Float32[Array, ' num_trees'] | None
    log_likelihood: Float32[Array, ' num_trees'] | None
    grow_prop_count: Int32[Array, '']
    prune_prop_count: Int32[Array, '']
    grow_acc_count: Int32[Array, '']
    prune_acc_count: Int32[Array, '']
    sigma_mu2: Float32[Array, '']

123 

124 

class State(Module):
    """
    Represents the MCMC state of BART.

    Parameters
    ----------
    X
        The predictors.
    y
        The response. If the data type is `bool`, the model is binary regression.
    z
        The latent variable for binary regression. `None` in continuous
        regression.
    offset
        Constant shift added to the sum of trees.
    resid
        The residuals (`y` or `z` minus sum of trees).
    sigma2
        The error variance. `None` in binary regression.
    prec_scale
        The scale on the error precision, i.e., ``1 / error_scale ** 2``.
        `None` in binary regression.
    sigma2_alpha
    sigma2_beta
        The shape and scale parameters of the inverse gamma prior on the noise
        variance. `None` in binary regression.
    forest
        The sum of trees model. (The per-predictor `max_split` lives on
        `Forest`, not here.)
    """

    X: UInt[Array, 'p n']
    y: Float32[Array, ' n'] | Bool[Array, ' n']
    z: None | Float32[Array, ' n']
    offset: Float32[Array, '']
    resid: Float32[Array, ' n']
    sigma2: Float32[Array, ''] | None
    prec_scale: Float32[Array, ' n'] | None
    sigma2_alpha: Float32[Array, ''] | None
    sigma2_beta: Float32[Array, ''] | None
    forest: Forest

167 

168 

def init(
    *,
    X: UInt[Any, 'p n'],
    y: Float32[Any, ' n'] | Bool[Any, ' n'],
    offset: float | Float32[Any, ''] = 0.0,
    max_split: UInt[Any, ' p'],
    num_trees: int,
    p_nonterminal: Float32[Any, ' d-1'],
    sigma_mu2: float | Float32[Any, ''],
    sigma2_alpha: float | Float32[Any, ''] | None = None,
    sigma2_beta: float | Float32[Any, ''] | None = None,
    error_scale: Float32[Any, ' n'] | None = None,
    min_points_per_decision_node: int | Integer[Any, ''] | None = None,
    resid_batch_size: int | None | Literal['auto'] = 'auto',
    count_batch_size: int | None | Literal['auto'] = 'auto',
    save_ratios: bool = False,
    filter_splitless_vars: bool = True,
    min_points_per_leaf: int | Integer[Any, ''] | None = None,
) -> State:
    """
    Make a BART posterior sampling MCMC initial state.

    Parameters
    ----------
    X
        The predictors. Note this is transposed compared to the usual
        convention.
    y
        The response. If the data type is `bool`, the regression model is
        binary regression with probit.
    offset
        Constant shift added to the sum of trees. 0 if not specified.
    max_split
        The maximum split index for each variable. All split ranges start at 1.
    num_trees
        The number of trees in the forest.
    p_nonterminal
        The probability of a nonterminal node at each depth. The maximum depth
        of trees is fixed by the length of this array.
    sigma_mu2
        The prior variance of a leaf, conditional on the tree structure. The
        prior variance of the sum of trees is ``num_trees * sigma_mu2``. The
        prior mean of leaves is always zero.
    sigma2_alpha
    sigma2_beta
        The shape and scale parameters of the inverse gamma prior on the error
        variance. Leave unspecified for binary regression.
    error_scale
        Each error is scaled by the corresponding factor in `error_scale`, so
        the error variance for ``y[i]`` is ``sigma2 * error_scale[i] ** 2``.
        Not supported for binary regression. If not specified, defaults to 1
        for all points, but potentially skipping calculations.
    min_points_per_decision_node
        The minimum number of data points in a decision node. 0 if not
        specified.
    resid_batch_size
    count_batch_size
        The batch sizes, along datapoints, for summing the residuals and
        counting the number of datapoints in each leaf. `None` for no
        batching. If 'auto', pick a value based on the device of `y`, or the
        default device.
    save_ratios
        Whether to save the Metropolis-Hastings ratios.
    filter_splitless_vars
        Whether to check `max_split` for variables without available
        cutpoints. If any are found, they are put into a list of variables to
        exclude from the MCMC. If `False`, no check is performed, but the
        results may be wrong if any variable is blocked. The function is
        jax-traceable only if this is set to `False`.
    min_points_per_leaf
        The minimum number of datapoints in a leaf node. 0 if not specified.
        Unlike `min_points_per_decision_node`, this constraint is not taken
        into account in the Metropolis-Hastings ratio because it would be
        expensive to compute. Grow moves that would violate this constraint
        are vetoed. This parameter is independent of
        `min_points_per_decision_node` and there is no check that they are
        coherent. It makes sense to set
        ``min_points_per_decision_node >= 2 * min_points_per_leaf``.

    Returns
    -------
    An initialized BART MCMC state.

    Raises
    ------
    ValueError
        If `y` is boolean and arguments unused in binary regression are set.

    Notes
    -----
    In decision nodes, the values in ``X[i, :]`` are compared to a cutpoint
    out of the range ``[1, 2, ..., max_split[i]]``. A point belongs to the
    left child iff ``X[i, j] < cutpoint``. Thus it makes sense for
    ``X[i, :]`` to be integers in the range ``[0, 1, ..., max_split[i]]``.
    """
    p_nonterminal = jnp.asarray(p_nonterminal)
    # append a 0: nodes at maximum depth are always terminal
    p_nonterminal = jnp.pad(p_nonterminal, (0, 1))
    max_depth = p_nonterminal.size

    @partial(jax.vmap, in_axes=None, out_axes=0, axis_size=num_trees)
    def make_forest(max_depth, dtype):
        return grove.make_tree(max_depth, dtype)

    y = jnp.asarray(y)
    offset = jnp.asarray(offset)

    resid_batch_size, count_batch_size = _choose_suffstat_batch_size(
        resid_batch_size, count_batch_size, y, 2**max_depth * num_trees
    )

    is_binary = y.dtype == bool
    if is_binary:
        # explicit `is not None` checks: comparing the whole tuple against
        # (None, None, None) is fragile when error_scale is an array, because
        # tuple equality would invoke the array's elementwise `==`
        if any(
            arg is not None for arg in (error_scale, sigma2_alpha, sigma2_beta)
        ):
            msg = (
                'error_scale, sigma2_alpha, and sigma2_beta must be set '
                'to `None` for binary regression.'
            )
            raise ValueError(msg)
        sigma2 = None
    else:
        sigma2_alpha = jnp.asarray(sigma2_alpha)
        sigma2_beta = jnp.asarray(sigma2_beta)
        # start sigma2 at the prior mean scale beta/alpha
        sigma2 = sigma2_beta / sigma2_alpha

    max_split = jnp.asarray(max_split)

    if filter_splitless_vars:
        (blocked_vars,) = jnp.nonzero(max_split == 0)
        blocked_vars = blocked_vars.astype(minimal_unsigned_dtype(max_split.size))
        # see `fully_used_variables` for the type cast
    else:
        blocked_vars = None

    return State(
        X=jnp.asarray(X),
        y=y,
        z=jnp.full(y.shape, offset) if is_binary else None,
        offset=offset,
        resid=jnp.zeros(y.shape) if is_binary else y - offset,
        sigma2=sigma2,
        prec_scale=(
            None if error_scale is None else lax.reciprocal(jnp.square(error_scale))
        ),
        sigma2_alpha=sigma2_alpha,
        sigma2_beta=sigma2_beta,
        forest=Forest(
            leaf_tree=make_forest(max_depth, jnp.float32),
            var_tree=make_forest(max_depth - 1, minimal_unsigned_dtype(X.shape[0] - 1)),
            split_tree=make_forest(max_depth - 1, max_split.dtype),
            affluence_tree=(
                make_forest(max_depth - 1, bool)
                .at[:, 1]
                .set(
                    True
                    if min_points_per_decision_node is None
                    else y.size >= min_points_per_decision_node
                )
            ),
            blocked_vars=blocked_vars,
            max_split=max_split,
            grow_prop_count=jnp.zeros((), int),
            grow_acc_count=jnp.zeros((), int),
            prune_prop_count=jnp.zeros((), int),
            prune_acc_count=jnp.zeros((), int),
            p_nonterminal=p_nonterminal[grove.tree_depths(2**max_depth)],
            p_propose_grow=p_nonterminal[grove.tree_depths(2 ** (max_depth - 1))],
            leaf_indices=jnp.ones(
                (num_trees, y.size), minimal_unsigned_dtype(2**max_depth - 1)
            ),
            min_points_per_decision_node=(
                None
                if min_points_per_decision_node is None
                else jnp.asarray(min_points_per_decision_node)
            ),
            min_points_per_leaf=(
                None
                if min_points_per_leaf is None
                else jnp.asarray(min_points_per_leaf)
            ),
            resid_batch_size=resid_batch_size,
            count_batch_size=count_batch_size,
            log_trans_prior=jnp.zeros(num_trees) if save_ratios else None,
            log_likelihood=jnp.zeros(num_trees) if save_ratios else None,
            sigma_mu2=jnp.asarray(sigma_mu2),
        ),
    )

353 

354 

355def _choose_suffstat_batch_size( 1ab

356 resid_batch_size, count_batch_size, y, forest_size 

357) -> tuple[int | None, ...]: 

358 @cache 1ab

359 def get_platform(): 1ab

360 try: 1ab

361 device = y.devices().pop() 1ab

362 except jax.errors.ConcretizationTypeError: 1ab

363 device = jax.devices()[0] 1ab

364 platform = device.platform 1ab

365 if platform not in ('cpu', 'gpu'): 365 ↛ 366line 365 didn't jump to line 366 because the condition on line 365 was never true1ab

366 msg = f'Unknown platform: {platform}' 

367 raise KeyError(msg) 

368 return platform 1ab

369 

370 if resid_batch_size == 'auto': 1ab

371 platform = get_platform() 1ab

372 n = max(1, y.size) 1ab

373 if platform == 'cpu': 373 ↛ 375line 373 didn't jump to line 375 because the condition on line 373 was always true1ab

374 resid_batch_size = 2 ** round(math.log2(n / 6)) # n/6 1ab

375 elif platform == 'gpu': 

376 resid_batch_size = 2 ** round((1 + math.log2(n)) / 3) # n^1/3 

377 resid_batch_size = max(1, resid_batch_size) 1ab

378 

379 if count_batch_size == 'auto': 1ab

380 platform = get_platform() 1ab

381 if platform == 'cpu': 381 ↛ 383line 381 didn't jump to line 383 because the condition on line 381 was always true1ab

382 count_batch_size = None 1ab

383 elif platform == 'gpu': 

384 n = max(1, y.size) 

385 count_batch_size = 2 ** round(math.log2(n) / 2 - 2) # n^1/2 

386 # /4 is good on V100, /2 on L4/T4, still haven't tried A100 

387 max_memory = 2**29 

388 itemsize = 4 

389 min_batch_size = math.ceil(forest_size * itemsize * n / max_memory) 

390 count_batch_size = max(count_batch_size, min_batch_size) 

391 count_batch_size = max(1, count_batch_size) 

392 

393 return resid_batch_size, count_batch_size 1ab

394 

395 

@jax.jit
def step(key: Key[Array, ''], bart: State) -> State:
    """
    Do one MCMC step.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state, as created by `init`.

    Returns
    -------
    The new BART mcmc state.
    """
    keys = split(key)

    if bart.y.dtype != bool:  # continuous regression
        bart = step_trees(keys.pop(), bart)
        return step_sigma(keys.pop(), bart)

    # binary regression: sample the trees with the variance fixed to 1,
    # then update the latent variable z
    bart = replace(bart, sigma2=jnp.float32(1))
    bart = step_trees(keys.pop(), bart)
    bart = replace(bart, sigma2=None)
    return step_z(keys.pop(), bart)

423 

424 

def step_trees(key: Key[Array, ''], bart: State) -> State:
    """
    Forest sampling step of BART MCMC.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state, as created by `init`.

    Returns
    -------
    The new BART mcmc state.

    Notes
    -----
    This function zeroes the proposal counters.
    """
    keys = split(key)
    proposed = propose_moves(keys.pop(), bart.forest)
    return accept_moves_and_sample_leaves(keys.pop(), bart, proposed)

447 

448 

class Moves(Module):
    """
    Moves proposed to modify each tree.

    Parameters
    ----------
    allowed
        Whether a move is possible at all. When `False`, the remaining values
        may be meaningless. A move marked allowed can still be vetoed, but
        only if it violates `min_points_per_leaf`: for efficiency that check
        is applied post-hoc, without changing the rest of the MCMC logic.
    grow
        `True` for a grow move, `False` for a prune move.
    num_growable
        The number of growable leaves in the original tree.
    node
        The index of the leaf to grow or of the node to prune.
    left
    right
        The indices of the children of `node`.
    partial_ratio
        A factor of the Metropolis-Hastings ratio of the move. It lacks the
        likelihood ratio, the probability of proposing the prune move, and
        the probability that the children of the modified node are terminal.
        For a PRUNE move the ratio is inverted. Set to `None` once
        `log_trans_prior_ratio` has been computed.
    log_trans_prior_ratio
        The logarithm of the product of the transition and prior terms of the
        Metropolis-Hastings acceptance ratio. `None` while not yet computed.
        Negated for PRUNE moves.
    grow_var
        The decision axes of the new rules.
    grow_split
        The decision boundaries of the new rules.
    var_tree
        The updated decision axes of the trees, valid whatever the move.
    affluence_tree
        A partially updated `affluence_tree`, marking non-leaf nodes that
        would become leaves if the move were accepted. Initially (out of
        `propose_moves`) this mark accounts only for whether decision rules
        would be available to grow the leaf; the datapoint-count condition is
        marked in `accept_moves_parallel_stage`.
    logu
        The logarithm of a uniform (0, 1] random variable to be used to
        accept the move. Lies in (-oo, 0].
    acc
        Whether the move was accepted. `None` while not yet computed.
    to_prune
        Whether the final operation applying the move is a pruning: either an
        accepted prune move or a rejected grow move. `None` while not yet
        computed.
    """

    allowed: Bool[Array, ' num_trees']
    grow: Bool[Array, ' num_trees']
    num_growable: UInt[Array, ' num_trees']
    node: UInt[Array, ' num_trees']
    left: UInt[Array, ' num_trees']
    right: UInt[Array, ' num_trees']
    partial_ratio: Float32[Array, ' num_trees'] | None
    log_trans_prior_ratio: None | Float32[Array, ' num_trees']
    grow_var: UInt[Array, ' num_trees']
    grow_split: UInt[Array, ' num_trees']
    var_tree: UInt[Array, 'num_trees 2**(d-1)']
    affluence_tree: Bool[Array, 'num_trees 2**(d-1)']
    logu: Float32[Array, ' num_trees']
    acc: None | Bool[Array, ' num_trees']
    to_prune: None | Bool[Array, ' num_trees']

518 

519 

def propose_moves(key: Key[Array, ''], forest: Forest) -> Moves:
    """
    Propose moves for all the trees.

    There are two types of moves: GROW (convert a leaf to a decision node and
    add two leaves beneath it) and PRUNE (convert the parent of two leaves to
    a leaf, deleting its children).

    Parameters
    ----------
    key
        A jax random key.
    forest
        The `forest` field of a BART MCMC state.

    Returns
    -------
    The proposed move for each tree.
    """
    num_trees, _ = forest.leaf_tree.shape
    keys = split(key, 1 + 2 * num_trees)

    # propose both a GROW and a PRUNE move for every tree
    grow_moves = propose_grow_moves(
        keys.pop(num_trees),
        forest.var_tree,
        forest.split_tree,
        forest.affluence_tree,
        forest.max_split,
        forest.blocked_vars,
        forest.p_nonterminal,
        forest.p_propose_grow,
    )
    prune_moves = propose_prune_moves(
        keys.pop(num_trees),
        forest.split_tree,
        grow_moves.affluence_tree,
        forest.p_nonterminal,
        forest.p_propose_grow,
    )

    unif, one_minus_explogu = random.uniform(keys.pop(), (2, num_trees))

    # pick grow vs. prune: probability 1/2 each when both are allowed,
    # otherwise 1 or 0 according to which one is allowed
    p_grow = jnp.where(
        grow_moves.allowed & prune_moves.allowed, 0.5, grow_moves.allowed
    )
    grow = unif < p_grow  # use < instead of <= because unif is in [0, 1)

    # indices of the children of the node acted upon
    node = jnp.where(grow, grow_moves.node, prune_moves.node)
    left = node * 2
    right = node * 2 + 1

    return Moves(
        allowed=grow_moves.allowed | prune_moves.allowed,
        grow=grow,
        num_growable=grow_moves.num_growable,
        node=node,
        left=left,
        right=right,
        partial_ratio=jnp.where(
            grow, grow_moves.partial_ratio, prune_moves.partial_ratio
        ),
        log_trans_prior_ratio=None,  # will be set in complete_ratio
        grow_var=grow_moves.var,
        grow_split=grow_moves.split,
        # var_tree does not need to be updated if prune
        var_tree=grow_moves.var_tree,
        # affluence_tree is updated for both moves unconditionally, prune last
        affluence_tree=prune_moves.affluence_tree,
        logu=jnp.log1p(-one_minus_explogu),
        acc=None,  # will be set in accept_moves_sequential_stage
        to_prune=None,  # will be set in accept_moves_sequential_stage
    )

595 

596 

class GrowMoves(Module):
    """
    Represent a proposed grow move for each tree.

    Parameters
    ----------
    allowed
        Whether the move can be proposed at all.
    num_growable
        The number of leaves that are candidates for growing.
    node
        The index of the leaf to grow, or ``2 ** d`` when there are no
        growable leaves.
    var
    split
        The decision axis and boundary of the new rule.
    partial_ratio
        A factor of the Metropolis-Hastings ratio of the move; it lacks the
        likelihood ratio and the probability of proposing the prune move.
    var_tree
        The updated decision axes of the tree.
    affluence_tree
        A partially updated `affluence_tree` that marks each would-be new
        leaf as `True` if it would have available decision rules.
    """

    allowed: Bool[Array, ' num_trees']
    num_growable: UInt[Array, ' num_trees']
    node: UInt[Array, ' num_trees']
    var: UInt[Array, ' num_trees']
    split: UInt[Array, ' num_trees']
    partial_ratio: Float32[Array, ' num_trees']
    var_tree: UInt[Array, 'num_trees 2**(d-1)']
    affluence_tree: Bool[Array, 'num_trees 2**(d-1)']

632 

633 

@partial(vmap_nodoc, in_axes=(0, 0, 0, 0, None, None, None, None))
def propose_grow_moves(
    key: Key[Array, ' num_trees'],
    var_tree: UInt[Array, 'num_trees 2**(d-1)'],
    split_tree: UInt[Array, 'num_trees 2**(d-1)'],
    affluence_tree: Bool[Array, 'num_trees 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    blocked_vars: Int32[Array, ' k'] | None,
    p_nonterminal: Float32[Array, ' 2**d'],
    p_propose_grow: Float32[Array, ' 2**(d-1)'],
) -> GrowMoves:
    """
    Propose a GROW move for each tree.

    A GROW move picks a leaf node and converts it to a non-terminal node with
    two leaf children.

    Parameters
    ----------
    key
        A jax random key.
    var_tree
        The splitting axes of the tree.
    split_tree
        The splitting points of the tree.
    affluence_tree
        Whether each leaf has enough points to be grown.
    max_split
        The maximum split index for each variable.
    blocked_vars
        The indices of the variables that have no available cutpoints.
    p_nonterminal
        The a priori probability of a node to be nonterminal conditional on
        the ancestors, including at the maximum depth where it should be zero.
    p_propose_grow
        The unnormalized probability of choosing a leaf to grow.

    Returns
    -------
    An object representing the proposed move.

    Notes
    -----
    The move is not proposed if each leaf is already at maximum depth, or has
    less datapoints than the requested threshold
    `min_points_per_decision_node`, or it does not have any available decision
    rules given its ancestors. This is marked by setting `allowed` to `False`
    and `num_growable` to 0.
    """
    subkeys = split(key, 3)

    leaf_to_grow, num_growable, prob_choose, num_prunable = choose_leaf(
        subkeys.pop(), split_tree, affluence_tree, p_propose_grow
    )

    # sample the decision rule for the grown node
    var, num_available_var = choose_variable(
        subkeys.pop(), var_tree, split_tree, max_split, leaf_to_grow, blocked_vars
    )
    split_idx, l, r = choose_split(
        subkeys.pop(), var, var_tree, split_tree, max_split, leaf_to_grow
    )

    # determine if the new leaves would have available decision rules; if the
    # move is blocked, these values may not make sense
    has_other_vars = num_available_var > 1
    left_growable = has_other_vars | (l < split_idx)
    right_growable = has_other_vars | (split_idx + 1 < r)
    left = leaf_to_grow << 1
    right = left + 1
    affluence_tree = affluence_tree.at[left].set(left_growable)
    affluence_tree = affluence_tree.at[right].set(right_growable)

    ratio = compute_partial_ratio(
        prob_choose, num_prunable, p_nonterminal, leaf_to_grow
    )

    return GrowMoves(
        allowed=num_growable > 0,
        num_growable=num_growable,
        node=leaf_to_grow,
        var=var,
        split=split_idx,
        partial_ratio=ratio,
        var_tree=var_tree.at[leaf_to_grow].set(var.astype(var_tree.dtype)),
        affluence_tree=affluence_tree,
    )

720 

721 

def choose_leaf(
    key: Key[Array, ''],
    split_tree: UInt[Array, ' 2**(d-1)'],
    affluence_tree: Bool[Array, ' 2**(d-1)'],
    p_propose_grow: Float32[Array, ' 2**(d-1)'],
) -> tuple[Int32[Array, ''], Int32[Array, ''], Float32[Array, ''], Int32[Array, '']]:
    """
    Choose a leaf node to grow in a tree.

    Parameters
    ----------
    key
        A jax random key.
    split_tree
        The splitting points of the tree.
    affluence_tree
        Whether a leaf has enough points that it could be split into two
        leaves satisfying the `min_points_per_leaf` requirement.
    p_propose_grow
        The unnormalized probability of choosing a leaf to grow.

    Returns
    -------
    leaf_to_grow : Int32[Array, '']
        The index of the leaf to grow. If ``num_growable == 0``, return
        ``2 ** d``.
    num_growable : Int32[Array, '']
        The number of leaf nodes that can be grown, i.e., are nonterminal
        and have at least twice `min_points_per_leaf`.
    prob_choose : Float32[Array, '']
        The (normalized) probability that this function had to choose that
        specific leaf, given the arguments.
    num_prunable : Int32[Array, '']
        The number of leaf parents that could be pruned, after converting the
        selected leaf to a non-terminal node.
    """
    can_grow = growable_leaves(split_tree, affluence_tree)
    num_growable = jnp.count_nonzero(can_grow)

    # sample one growable leaf proportionally to p_propose_grow
    weights = jnp.where(can_grow, p_propose_grow, 0)
    leaf_to_grow, total_weight = categorical(key, weights)
    # out-of-bounds sentinel when nothing can be grown
    leaf_to_grow = jnp.where(num_growable, leaf_to_grow, 2 * split_tree.size)
    prob_choose = weights[leaf_to_grow] / jnp.where(total_weight, total_weight, 1)

    # count the prunable parents of the tree obtained by growing the leaf
    is_parent = grove.is_leaves_parent(split_tree.at[leaf_to_grow].set(1))
    num_prunable = jnp.count_nonzero(is_parent)

    return leaf_to_grow, num_growable, prob_choose, num_prunable

767 

768 

def growable_leaves(
    split_tree: UInt[Array, ' 2**(d-1)'], affluence_tree: Bool[Array, ' 2**(d-1)']
) -> Bool[Array, ' 2**(d-1)']:
    """
    Return a mask indicating the leaf nodes that can be proposed for growth.

    The condition is that a leaf is not at the bottom level, has available
    decision rules given its ancestors, and has at least
    `min_points_per_decision_node` points.

    Parameters
    ----------
    split_tree
        The splitting points of the tree.
    affluence_tree
        Marks leaves that can be grown.

    Returns
    -------
    The mask indicating the leaf nodes that can be proposed to grow.

    Notes
    -----
    This function needs `split_tree` and not just `affluence_tree` because
    `affluence_tree` can be "dirty", i.e., mark unused nodes as `True`.
    """
    actual_leaves = grove.is_actual_leaf(split_tree)
    return affluence_tree & actual_leaves

796 

797 

def categorical(
    key: Key[Array, ''], distr: Float32[Array, ' n']
) -> tuple[Int32[Array, ''], Float32[Array, '']]:
    """
    Return a random integer from an arbitrary distribution.

    Parameters
    ----------
    key
        A jax random key.
    distr
        An unnormalized probability distribution.

    Returns
    -------
    u : Int32[Array, '']
        A random integer in the range ``[0, n)``. If all probabilities are
        zero, return ``n``.
    norm : Float32[Array, '']
        The sum of `distr`.
    """
    cdf = distr.cumsum()
    total = cdf[-1]
    u = random.uniform(key, (), cdf.dtype, 0, total)
    # 'right' side makes zero-probability cells unreachable
    return jnp.searchsorted(cdf, u, 'right'), total

822 

823 

def choose_variable(
    key: Key[Array, ''],
    var_tree: UInt[Array, ' 2**(d-1)'],
    split_tree: UInt[Array, ' 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    leaf_index: Int32[Array, ''],
    blocked_vars: Int32[Array, ' k'] | None,
) -> tuple[Int32[Array, ''], Int32[Array, '']]:
    """
    Choose a variable to split on for a new non-terminal node.

    Parameters
    ----------
    key
        A jax random key.
    var_tree
        The variable indices of the tree.
    split_tree
        The splitting points of the tree.
    max_split
        The maximum split index for each variable.
    leaf_index
        The index of the leaf to grow.
    blocked_vars
        The indices of the variables that have no available cutpoints. If
        `None`, all variables are assumed unblocked.

    Returns
    -------
    var : Int32[Array, '']
        The index of the variable to split on.
    num_available_var : Int32[Array, '']
        The number of variables with available decision rules `var` was
        chosen from.
    """
    excluded = fully_used_variables(var_tree, split_tree, max_split, leaf_index)
    if blocked_vars is not None:
        excluded = jnp.concatenate([excluded, blocked_vars])
    return randint_exclude(key, max_split.size, excluded)

863 

864 

def fully_used_variables(
    var_tree: UInt[Array, ' 2**(d-1)'],
    split_tree: UInt[Array, ' 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    leaf_index: Int32[Array, ''],
) -> UInt[Array, ' d-2']:
    """
    Find variables in the ancestors of a node that have an empty split range.

    Parameters
    ----------
    var_tree
        The variable indices of the tree.
    split_tree
        The splitting points of the tree.
    max_split
        The maximum split index for each variable.
    leaf_index
        The index of the node, assumed to be valid for `var_tree`.

    Returns
    -------
    The indices of the variables that have an empty split range.

    Notes
    -----
    The number of unused variables is not known in advance. Unused values in
    the array are filled with `p`. The fill values are not guaranteed to be
    placed in any particular order, and variables may appear more than once.
    """
    candidates = ancestor_variables(var_tree, max_split, leaf_index)
    ranges = jax.vmap(split_range, in_axes=(None, None, None, None, 0))
    lower, upper = ranges(var_tree, split_tree, max_split, leaf_index, candidates)
    # an empty range (upper == lower) means the variable is exhausted along
    # this path; fill the rest with p (= max_split.size)
    return jnp.where(upper == lower, candidates, max_split.size)
    # the type of candidates is already sufficient to hold max_split.size,
    # see ancestor_variables()

902 

903 

904def ancestor_variables( 1ab

905 var_tree: UInt[Array, ' 2**(d-1)'], 

906 max_split: UInt[Array, ' p'], 

907 node_index: Int32[Array, ''], 

908) -> UInt[Array, ' d-2']: 

909 """ 

910 Return the list of variables in the ancestors of a node. 

911 

912 Parameters 

913 ---------- 

914 var_tree 

915 The variable indices of the tree. 

916 max_split 

917 The maximum split index for each variable. Used only to get `p`. 

918 node_index 

919 The index of the node, assumed to be valid for `var_tree`. 

920 

921 Returns 

922 ------- 

923 The variable indices of the ancestors of the node. 

924 

925 Notes 

926 ----- 

927 The ancestors are the nodes going from the root to the parent of the node. 

928 The number of ancestors is not known at tracing time; unused spots in the 

929 output array are filled with `p`. 

930 """ 

931 max_num_ancestors = grove.tree_depth(var_tree) - 1 1ab

932 ancestor_vars = jnp.zeros(max_num_ancestors, minimal_unsigned_dtype(max_split.size)) 1ab

933 carry = ancestor_vars.size - 1, node_index, ancestor_vars 1ab

934 

935 def loop(carry, _): 1ab

936 i, index, ancestor_vars = carry 1ab

937 index >>= 1 1ab

938 var = var_tree[index] 1ab

939 var = jnp.where(index, var, max_split.size) 1ab

940 ancestor_vars = ancestor_vars.at[i].set(var) 1ab

941 return (i - 1, index, ancestor_vars), None 1ab

942 

943 (_, _, ancestor_vars), _ = lax.scan(loop, carry, None, ancestor_vars.size) 1ab

944 return ancestor_vars 1ab

945 

946 

def split_range( 1ab
    var_tree: UInt[Array, ' 2**(d-1)'],
    split_tree: UInt[Array, ' 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    node_index: Int32[Array, ''],
    ref_var: Int32[Array, ''],
) -> tuple[Int32[Array, ''], Int32[Array, '']]:
    """
    Return the range of allowed splits for a variable at a given node.

    Parameters
    ----------
    var_tree
        The variable indices of the tree.
    split_tree
        The splitting points of the tree.
    max_split
        The maximum split index for each variable.
    node_index
        The index of the node, assumed to be valid for `var_tree`.
    ref_var
        The variable for which to measure the split range.

    Returns
    -------
    The range of allowed splits as [l, r). If `ref_var` is out of bounds, l=r=1.
    """
    max_num_ancestors = grove.tree_depth(var_tree) - 1
    # if ref_var is out of bounds, the fill value 0 makes the final range
    # [1, 1), i.e., empty
    initial_r = 1 + max_split.at[ref_var].get(mode='fill', fill_value=0).astype(
        jnp.int32
    )
    carry = jnp.int32(0), initial_r, node_index

    def loop(carry, _):
        l, r, index = carry
        right_child = (index & 1).astype(bool)
        index >>= 1
        split = split_tree[index]
        # only true ancestors (index != 0 after walking up) that split on
        # ref_var constrain the range
        cond = (var_tree[index] == ref_var) & index.astype(bool)
        # coming from the right child forbids cutpoints <= split; coming from
        # the left child forbids cutpoints >= split
        l = jnp.where(cond & right_child, jnp.maximum(l, split), l)
        r = jnp.where(cond & ~right_child, jnp.minimum(r, split), r)
        return (l, r, index), None

    (l, r, _), _ = lax.scan(loop, carry, None, max_num_ancestors)
    # l tracked the largest split we descended right of, which is itself
    # forbidden, so the first allowed cutpoint is l + 1
    return l + 1, r

992 

993 

994def randint_exclude( 1ab

995 key: Key[Array, ''], sup: int | Integer[Array, ''], exclude: Integer[Array, ' n'] 

996) -> tuple[Int32[Array, ''], Int32[Array, '']]: 

997 """ 

998 Return a random integer in a range, excluding some values. 

999 

1000 Parameters 

1001 ---------- 

1002 key 

1003 A jax random key. 

1004 sup 

1005 The exclusive upper bound of the range. 

1006 exclude 

1007 The values to exclude from the range. Values greater than or equal to 

1008 `sup` are ignored. Values can appear more than once. 

1009 

1010 Returns 

1011 ------- 

1012 u : Int32[Array, ''] 

1013 A random integer `u` in the range ``[0, sup)`` such that ``u not in 

1014 exclude``. 

1015 num_allowed : Int32[Array, ''] 

1016 The number of integers in the range that were not excluded. 

1017 

1018 Notes 

1019 ----- 

1020 If all values in the range are excluded, return `sup`. 

1021 """ 

1022 exclude = jnp.unique(exclude, size=exclude.size, fill_value=sup) 1ab

1023 num_allowed = sup - jnp.count_nonzero(exclude < sup) 1ab

1024 u = random.randint(key, (), 0, num_allowed) 1ab

1025 

1026 def loop(u, i_excluded): 1ab

1027 return jnp.where(i_excluded <= u, u + 1, u), None 1ab

1028 

1029 u, _ = lax.scan(loop, u, exclude) 1ab

1030 return u, num_allowed 1ab

1031 

1032 

1033def choose_split( 1ab

1034 key: Key[Array, ''], 

1035 var: Int32[Array, ''], 

1036 var_tree: UInt[Array, ' 2**(d-1)'], 

1037 split_tree: UInt[Array, ' 2**(d-1)'], 

1038 max_split: UInt[Array, ' p'], 

1039 leaf_index: Int32[Array, ''], 

1040) -> tuple[Int32[Array, ''], Int32[Array, ''], Int32[Array, '']]: 

1041 """ 

1042 Choose a split point for a new non-terminal node. 

1043 

1044 Parameters 

1045 ---------- 

1046 key 

1047 A jax random key. 

1048 var 

1049 The variable to split on. 

1050 var_tree 

1051 The splitting axes of the tree. Does not need to already contain `var` 

1052 at `leaf_index`. 

1053 split_tree 

1054 The splitting points of the tree. 

1055 max_split 

1056 The maximum split index for each variable. 

1057 leaf_index 

1058 The index of the leaf to grow. 

1059 

1060 Returns 

1061 ------- 

1062 split : Int32[Array, ''] 

1063 The cutpoint. 

1064 l : Int32[Array, ''] 

1065 r : Int32[Array, ''] 

1066 The integer range `split` was drawn from is [l, r). 

1067 

1068 Notes 

1069 ----- 

1070 If `var` is out of bounds, or if the available split range on that variable 

1071 is empty, return 0. 

1072 """ 

1073 l, r = split_range(var_tree, split_tree, max_split, leaf_index, var) 1ab

1074 return jnp.where(l < r, random.randint(key, (), l, r), 0), l, r 1ab

1075 

1076 

1077def compute_partial_ratio( 1ab

1078 prob_choose: Float32[Array, ''], 

1079 num_prunable: Int32[Array, ''], 

1080 p_nonterminal: Float32[Array, ' 2**d'], 

1081 leaf_to_grow: Int32[Array, ''], 

1082) -> Float32[Array, '']: 

1083 """ 

1084 Compute the product of the transition and prior ratios of a grow move. 

1085 

1086 Parameters 

1087 ---------- 

1088 prob_choose 

1089 The probability that the leaf had to be chosen amongst the growable 

1090 leaves. 

1091 num_prunable 

1092 The number of leaf parents that could be pruned, after converting the 

1093 leaf to be grown to a non-terminal node. 

1094 p_nonterminal 

1095 The a priori probability of each node being nonterminal conditional on 

1096 its ancestors. 

1097 leaf_to_grow 

1098 The index of the leaf to grow. 

1099 

1100 Returns 

1101 ------- 

1102 The partial transition ratio times the prior ratio. 

1103 

1104 Notes 

1105 ----- 

1106 The transition ratio is P(new tree => old tree) / P(old tree => new tree). 

1107 The "partial" transition ratio returned is missing the factor P(propose 

1108 prune) in the numerator. The prior ratio is P(new tree) / P(old tree). The 

1109 "partial" prior ratio is missing the factor P(children are leaves). 

1110 """ 

1111 # the two ratios also contain factors num_available_split * 

1112 # num_available_var, but they cancel out 

1113 

1114 # p_prune and 1 - p_nonterminal[child] * I(is the child growable) can't be 

1115 # computed here because they need the count trees, which are computed in the 

1116 # acceptance phase 

1117 

1118 prune_allowed = leaf_to_grow != 1 1ab

1119 # prune allowed <---> the initial tree is not a root 

1120 # leaf to grow is root --> the tree can only be a root 

1121 # tree is a root --> the only leaf I can grow is root 

1122 p_grow = jnp.where(prune_allowed, 0.5, 1) 1ab

1123 inv_trans_ratio = p_grow * prob_choose * num_prunable 1ab

1124 

1125 # .at.get because if leaf_to_grow is out of bounds (move not allowed), this 

1126 # would produce a 0 and then an inf when `complete_ratio` takes the log 

1127 pnt = p_nonterminal.at[leaf_to_grow].get(mode='fill', fill_value=0.5) 1ab

1128 tree_ratio = pnt / (1 - pnt) 1ab

1129 

1130 return tree_ratio / jnp.where(inv_trans_ratio, inv_trans_ratio, 1) 1ab

1131 

1132 

class PruneMoves(Module): 1ab
    """
    Represent a proposed prune move for each tree.

    Parameters
    ----------
    allowed
        Whether the move is possible.
    node
        The index of the node to prune. ``2 ** d`` if no node can be pruned.
    partial_ratio
        A factor of the Metropolis-Hastings ratio of the move. It lacks the
        likelihood ratio, the probability of proposing the prune move, and the
        prior probability that the children of the node to prune are leaves.
        This ratio is inverted, and is meant to be inverted back in
        `accept_move_and_sample_leaves`.
    affluence_tree
        A partially updated copy of the state's affluence tree, where the node
        to prune is marked as growable (as needed by the reverse grow move);
        see `choose_leaf_parent`.
    """

    allowed: Bool[Array, ' num_trees']
    node: UInt[Array, ' num_trees']
    partial_ratio: Float32[Array, ' num_trees']
    affluence_tree: Bool[Array, 'num_trees 2**(d-1)']

1155 

1156 

1157@partial(vmap_nodoc, in_axes=(0, 0, 0, None, None)) 1ab

1158def propose_prune_moves( 1ab

1159 key: Key[Array, ''], 

1160 split_tree: UInt[Array, ' 2**(d-1)'], 

1161 affluence_tree: Bool[Array, ' 2**(d-1)'], 

1162 p_nonterminal: Float32[Array, ' 2**d'], 

1163 p_propose_grow: Float32[Array, ' 2**(d-1)'], 

1164) -> PruneMoves: 

1165 """ 

1166 Tree structure prune move proposal of BART MCMC. 

1167 

1168 Parameters 

1169 ---------- 

1170 key 

1171 A jax random key. 

1172 split_tree 

1173 The splitting points of the tree. 

1174 affluence_tree 

1175 Whether each leaf can be grown. 

1176 p_nonterminal 

1177 The a priori probability of a node to be nonterminal conditional on 

1178 the ancestors, including at the maximum depth where it should be zero. 

1179 p_propose_grow 

1180 The unnormalized probability of choosing a leaf to grow. 

1181 

1182 Returns 

1183 ------- 

1184 An object representing the proposed moves. 

1185 """ 

1186 node_to_prune, num_prunable, prob_choose, affluence_tree = choose_leaf_parent( 1ab

1187 key, split_tree, affluence_tree, p_propose_grow 

1188 ) 

1189 

1190 ratio = compute_partial_ratio( 1ab

1191 prob_choose, num_prunable, p_nonterminal, node_to_prune 

1192 ) 

1193 

1194 return PruneMoves( 1ab

1195 allowed=split_tree[1].astype(bool), # allowed iff the tree is not a root 

1196 node=node_to_prune, 

1197 partial_ratio=ratio, 

1198 affluence_tree=affluence_tree, 

1199 ) 

1200 

1201 

1202def choose_leaf_parent( 1ab

1203 key: Key[Array, ''], 

1204 split_tree: UInt[Array, ' 2**(d-1)'], 

1205 affluence_tree: Bool[Array, ' 2**(d-1)'], 

1206 p_propose_grow: Float32[Array, ' 2**(d-1)'], 

1207) -> tuple[ 

1208 Int32[Array, ''], 

1209 Int32[Array, ''], 

1210 Float32[Array, ''], 

1211 Bool[Array, 'num_trees 2**(d-1)'], 

1212]: 

1213 """ 

1214 Pick a non-terminal node with leaf children to prune in a tree. 

1215 

1216 Parameters 

1217 ---------- 

1218 key 

1219 A jax random key. 

1220 split_tree 

1221 The splitting points of the tree. 

1222 affluence_tree 

1223 Whether a leaf has enough points to be grown. 

1224 p_propose_grow 

1225 The unnormalized probability of choosing a leaf to grow. 

1226 

1227 Returns 

1228 ------- 

1229 node_to_prune : Int32[Array, ''] 

1230 The index of the node to prune. If ``num_prunable == 0``, return 

1231 ``2 ** d``. 

1232 num_prunable : Int32[Array, ''] 

1233 The number of leaf parents that could be pruned. 

1234 prob_choose : Float32[Array, ''] 

1235 The (normalized) probability that `choose_leaf` would chose 

1236 `node_to_prune` as leaf to grow, if passed the tree where 

1237 `node_to_prune` had been pruned. 

1238 affluence_tree : Bool[Array, 'num_trees 2**(d-1)'] 

1239 A partially updated `affluence_tree`, marking the node to prune as 

1240 growable. 

1241 """ 

1242 # sample a node to prune 

1243 is_prunable = grove.is_leaves_parent(split_tree) 1ab

1244 num_prunable = jnp.count_nonzero(is_prunable) 1ab

1245 node_to_prune = randint_masked(key, is_prunable) 1ab

1246 node_to_prune = jnp.where(num_prunable, node_to_prune, 2 * split_tree.size) 1ab

1247 

1248 # compute stuff for reverse move 

1249 split_tree = split_tree.at[node_to_prune].set(0) 1ab

1250 affluence_tree = affluence_tree.at[node_to_prune].set(True) 1ab

1251 is_growable_leaf = growable_leaves(split_tree, affluence_tree) 1ab

1252 distr_norm = jnp.sum(p_propose_grow, where=is_growable_leaf) 1ab

1253 prob_choose = p_propose_grow.at[node_to_prune].get(mode='fill', fill_value=0) 1ab

1254 prob_choose = prob_choose / jnp.where(distr_norm, distr_norm, 1) 1ab

1255 

1256 return node_to_prune, num_prunable, prob_choose, affluence_tree 1ab

1257 

1258 

1259def randint_masked(key: Key[Array, ''], mask: Bool[Array, ' n']) -> Int32[Array, '']: 1ab

1260 """ 

1261 Return a random integer in a range, including only some values. 

1262 

1263 Parameters 

1264 ---------- 

1265 key 

1266 A jax random key. 

1267 mask 

1268 The mask indicating the allowed values. 

1269 

1270 Returns 

1271 ------- 

1272 A random integer in the range ``[0, n)`` such that ``mask[u] == True``. 

1273 

1274 Notes 

1275 ----- 

1276 If all values in the mask are `False`, return `n`. 

1277 """ 

1278 ecdf = jnp.cumsum(mask) 1ab

1279 u = random.randint(key, (), 0, ecdf[-1]) 1ab

1280 return jnp.searchsorted(ecdf, u, 'right') 1ab

1281 

1282 

1283def accept_moves_and_sample_leaves( 1ab

1284 key: Key[Array, ''], bart: State, moves: Moves 

1285) -> State: 

1286 """ 

1287 Accept or reject the proposed moves and sample the new leaf values. 

1288 

1289 Parameters 

1290 ---------- 

1291 key 

1292 A jax random key. 

1293 bart 

1294 A valid BART mcmc state. 

1295 moves 

1296 The proposed moves, see `propose_moves`. 

1297 

1298 Returns 

1299 ------- 

1300 A new (valid) BART mcmc state. 

1301 """ 

1302 pso = accept_moves_parallel_stage(key, bart, moves) 1ab

1303 bart, moves = accept_moves_sequential_stage(pso) 1ab

1304 return accept_moves_final_stage(bart, moves) 1ab

1305 

1306 

class Counts(Module): 1ab
    """
    Number of datapoints in the nodes involved in proposed moves for each tree.

    Each field is one value per tree.

    Parameters
    ----------
    left
        Number of datapoints in the left child.
    right
        Number of datapoints in the right child.
    total
        Number of datapoints in the parent (``= left + right``).
    """

    left: UInt[Array, ' num_trees']
    right: UInt[Array, ' num_trees']
    total: UInt[Array, ' num_trees']

1324 

1325 

class Precs(Module): 1ab
    """
    Likelihood precision scale in the nodes involved in proposed moves for each tree.

    The "likelihood precision scale" of a tree node is the sum of the inverse
    squared error scales of the datapoints selected by the node. Each field is
    one value per tree.

    Parameters
    ----------
    left
        Likelihood precision scale in the left child.
    right
        Likelihood precision scale in the right child.
    total
        Likelihood precision scale in the parent (``= left + right``).
    """

    left: Float32[Array, ' num_trees']
    right: Float32[Array, ' num_trees']
    total: Float32[Array, ' num_trees']

1346 

1347 

class PreLkV(Module): 1ab
    """
    Non-sequential terms of the likelihood ratio for each tree.

    These terms can be computed in parallel across trees; each field is one
    value per tree.

    Parameters
    ----------
    sigma2_left
        The noise variance in the left child of the leaves grown or pruned by
        the moves.
    sigma2_right
        The noise variance in the right child of the leaves grown or pruned by
        the moves.
    sigma2_total
        The noise variance in the total of the leaves grown or pruned by the
        moves.
    sqrt_term
        The **logarithm** of the square root term of the likelihood ratio.
    """

    sigma2_left: Float32[Array, ' num_trees']
    sigma2_right: Float32[Array, ' num_trees']
    sigma2_total: Float32[Array, ' num_trees']
    sqrt_term: Float32[Array, ' num_trees']

1373 

1374 

class PreLk(Module): 1ab
    """
    Non-sequential terms of the likelihood ratio shared by all trees.

    Parameters
    ----------
    exp_factor
        The factor to multiply the likelihood ratio by, shared by all trees
        (a single scalar).
    """

    exp_factor: Float32[Array, '']

1386 

1387 

class PreLf(Module): 1ab
    """
    Pre-computed terms used to sample leaves from their posterior.

    These terms can be computed in parallel across trees; each field holds one
    value per (tree, node) pair.

    Parameters
    ----------
    mean_factor
        The factor to be multiplied by the sum of the scaled residuals to
        obtain the posterior mean.
    centered_leaves
        The mean-zero normal values to be added to the posterior mean to
        obtain the posterior leaf samples.
    """

    mean_factor: Float32[Array, 'num_trees 2**d']
    centered_leaves: Float32[Array, 'num_trees 2**d']

1406 

1407 

class ParallelStageOut(Module): 1ab
    """
    The output of `accept_moves_parallel_stage`.

    Parameters
    ----------
    bart
        A partially updated BART mcmc state.
    moves
        The proposed moves, with `partial_ratio` set to `None` and
        `log_trans_prior_ratio` set to its final value.
    prec_trees
        The likelihood precision scale in each potential or actual leaf node. If
        there is no precision scale, this is the number of points in each leaf.
    move_precs
        The likelihood precision scale in each node modified by the moves. If
        `bart.prec_scale` is not set, this is a `Counts` of the datapoints in
        those nodes instead.
    prelkv
    prelk
    prelf
        Objects with pre-computed terms of the likelihood ratios and leaf
        samples.
    """

    bart: State
    moves: Moves
    prec_trees: Float32[Array, 'num_trees 2**d'] | Int32[Array, 'num_trees 2**d']
    move_precs: Precs | Counts
    prelkv: PreLkV
    prelk: PreLk
    prelf: PreLf

1443 

1444 

def accept_moves_parallel_stage( 1ab
    key: Key[Array, ''], bart: State, moves: Moves
) -> ParallelStageOut:
    """
    Pre-compute quantities used to accept moves, in parallel across trees.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A valid BART mcmc state.
    moves
        The proposed moves, see `propose_moves`.

    Returns
    -------
    An object with all that could be done in parallel.
    """
    # where the move is grow, modify the state like the move was accepted
    bart = replace(
        bart,
        forest=replace(
            bart.forest,
            var_tree=moves.var_tree,
            leaf_indices=apply_grow_to_indices(moves, bart.forest.leaf_indices, bart.X),
            leaf_tree=adapt_leaf_trees_to_grow_indices(bart.forest.leaf_tree, moves),
        ),
    )

    # count number of datapoints per leaf (skipped only when no threshold
    # needs the counts and the precision scale will replace them)
    if (
        bart.forest.min_points_per_decision_node is not None
        or bart.forest.min_points_per_leaf is not None
        or bart.prec_scale is None
    ):
        count_trees, move_counts = compute_count_trees(
            bart.forest.leaf_indices, moves, bart.forest.count_batch_size
        )

    # mark which leaves & potential leaves have enough points to be grown
    if bart.forest.min_points_per_decision_node is not None:
        count_half_trees = count_trees[:, : bart.forest.var_tree.shape[1]]
        moves = replace(
            moves,
            affluence_tree=moves.affluence_tree
            & (count_half_trees >= bart.forest.min_points_per_decision_node),
        )

    # copy updated affluence_tree to state
    bart = tree_at(lambda bart: bart.forest.affluence_tree, bart, moves.affluence_tree)

    # veto grove move if new leaves don't have enough datapoints
    if bart.forest.min_points_per_leaf is not None:
        moves = replace(
            moves,
            allowed=moves.allowed
            & (move_counts.left >= bart.forest.min_points_per_leaf)
            & (move_counts.right >= bart.forest.min_points_per_leaf),
        )

    # count number of datapoints per leaf, weighted by error precision scale
    if bart.prec_scale is None:
        prec_trees = count_trees
        move_precs = move_counts
    else:
        prec_trees, move_precs = compute_prec_trees(
            bart.prec_scale,
            bart.forest.leaf_indices,
            moves,
            bart.forest.count_batch_size,
        )
    assert move_precs is not None

    # compute some missing information about moves
    moves = complete_ratio(moves, bart.forest.p_nonterminal)
    save_ratios = bart.forest.log_likelihood is not None
    bart = replace(
        bart,
        forest=replace(
            bart.forest,
            grow_prop_count=jnp.sum(moves.grow),
            prune_prop_count=jnp.sum(moves.allowed & ~moves.grow),
            log_trans_prior=moves.log_trans_prior_ratio if save_ratios else None,
        ),
    )

    # pre-compute some likelihood ratio & posterior terms
    assert bart.sigma2 is not None  # `step` shall temporarily set it to 1
    prelkv, prelk = precompute_likelihood_terms(
        bart.sigma2, bart.forest.sigma_mu2, move_precs
    )
    prelf = precompute_leaf_terms(key, prec_trees, bart.sigma2, bart.forest.sigma_mu2)

    return ParallelStageOut(
        bart=bart,
        moves=moves,
        prec_trees=prec_trees,
        move_precs=move_precs,
        prelkv=prelkv,
        prelk=prelk,
        prelf=prelf,
    )

1548 

1549 

1550@partial(vmap_nodoc, in_axes=(0, 0, None)) 1ab

1551def apply_grow_to_indices( 1ab

1552 moves: Moves, leaf_indices: UInt[Array, 'num_trees n'], X: UInt[Array, 'p n'] 

1553) -> UInt[Array, 'num_trees n']: 

1554 """ 

1555 Update the leaf indices to apply a grow move. 

1556 

1557 Parameters 

1558 ---------- 

1559 moves 

1560 The proposed moves, see `propose_moves`. 

1561 leaf_indices 

1562 The index of the leaf each datapoint falls into. 

1563 X 

1564 The predictors matrix. 

1565 

1566 Returns 

1567 ------- 

1568 The updated leaf indices. 

1569 """ 

1570 left_child = moves.node.astype(leaf_indices.dtype) << 1 1ab

1571 go_right = X[moves.grow_var, :] >= moves.grow_split 1ab

1572 tree_size = jnp.array(2 * moves.var_tree.size) 1ab

1573 node_to_update = jnp.where(moves.grow, moves.node, tree_size) 1ab

1574 return jnp.where( 1ab

1575 leaf_indices == node_to_update, left_child + go_right, leaf_indices 

1576 ) 

1577 

1578 

1579def compute_count_trees( 1ab

1580 leaf_indices: UInt[Array, 'num_trees n'], moves: Moves, batch_size: int | None 

1581) -> tuple[Int32[Array, 'num_trees 2**d'], Counts]: 

1582 """ 

1583 Count the number of datapoints in each leaf. 

1584 

1585 Parameters 

1586 ---------- 

1587 leaf_indices 

1588 The index of the leaf each datapoint falls into, with the deeper version 

1589 of the tree (post-GROW, pre-PRUNE). 

1590 moves 

1591 The proposed moves, see `propose_moves`. 

1592 batch_size 

1593 The data batch size to use for the summation. 

1594 

1595 Returns 

1596 ------- 

1597 count_trees : Int32[Array, 'num_trees 2**d'] 

1598 The number of points in each potential or actual leaf node. 

1599 counts : Counts 

1600 The counts of the number of points in the leaves grown or pruned by the 

1601 moves. 

1602 """ 

1603 num_trees, tree_size = moves.var_tree.shape 1ab

1604 tree_size *= 2 1ab

1605 tree_indices = jnp.arange(num_trees) 1ab

1606 

1607 count_trees = count_datapoints_per_leaf(leaf_indices, tree_size, batch_size) 1ab

1608 

1609 # count datapoints in nodes modified by move 

1610 left = count_trees[tree_indices, moves.left] 1ab

1611 right = count_trees[tree_indices, moves.right] 1ab

1612 counts = Counts(left=left, right=right, total=left + right) 1ab

1613 

1614 # write count into non-leaf node 

1615 count_trees = count_trees.at[tree_indices, moves.node].set(counts.total) 1ab

1616 

1617 return count_trees, counts 1ab

1618 

1619 

1620def count_datapoints_per_leaf( 1ab

1621 leaf_indices: UInt[Array, 'num_trees n'], tree_size: int, batch_size: int | None 

1622) -> Int32[Array, 'num_trees 2**(d-1)']: 

1623 """ 

1624 Count the number of datapoints in each leaf. 

1625 

1626 Parameters 

1627 ---------- 

1628 leaf_indices 

1629 The index of the leaf each datapoint falls into. 

1630 tree_size 

1631 The size of the leaf tree array (2 ** d). 

1632 batch_size 

1633 The data batch size to use for the summation. 

1634 

1635 Returns 

1636 ------- 

1637 The number of points in each leaf node. 

1638 """ 

1639 if batch_size is None: 1ab

1640 return _count_scan(leaf_indices, tree_size) 1ab

1641 else: 

1642 return _count_vec(leaf_indices, tree_size, batch_size) 1ab

1643 

1644 

1645def _count_scan( 1ab

1646 leaf_indices: UInt[Array, 'num_trees n'], tree_size: int 

1647) -> Int32[Array, 'num_trees {tree_size}']: 

1648 def loop(_, leaf_indices): 1ab

1649 return None, _aggregate_scatter(1, leaf_indices, tree_size, jnp.uint32) 1ab

1650 

1651 _, count_trees = lax.scan(loop, None, leaf_indices) 1ab

1652 return count_trees 1ab

1653 

1654 

1655def _aggregate_scatter( 1ab

1656 values: Shaped[Array, '*'], 

1657 indices: Integer[Array, '*'], 

1658 size: int, 

1659 dtype: jnp.dtype, 

1660) -> Shaped[Array, ' {size}']: 

1661 return jnp.zeros(size, dtype).at[indices].add(values) 1ab

1662 

1663 

def _count_vec( 1ab
    leaf_indices: UInt[Array, 'num_trees n'], tree_size: int, batch_size: int
) -> Int32[Array, 'num_trees 2**(d-1)']:
    # batched scatter-add of 1s, all trees at once
    return _aggregate_batched_alltrees(
        1, leaf_indices, tree_size, jnp.uint32, batch_size
    )
    # uint16 is super-slow on gpu, don't use it even if n < 2^16

1671 

1672 

1673def _aggregate_batched_alltrees( 1ab

1674 values: Shaped[Array, '*'], 

1675 indices: UInt[Array, 'num_trees n'], 

1676 size: int, 

1677 dtype: jnp.dtype, 

1678 batch_size: int, 

1679) -> Shaped[Array, 'num_trees {size}']: 

1680 num_trees, n = indices.shape 1ab

1681 tree_indices = jnp.arange(num_trees) 1ab

1682 nbatches = n // batch_size + bool(n % batch_size) 1ab

1683 batch_indices = jnp.arange(n) % nbatches 1ab

1684 return ( 1ab

1685 jnp.zeros((num_trees, size, nbatches), dtype) 

1686 .at[tree_indices[:, None], indices, batch_indices] 

1687 .add(values) 

1688 .sum(axis=2) 

1689 ) 

1690 

1691 

1692def compute_prec_trees( 1ab

1693 prec_scale: Float32[Array, ' n'], 

1694 leaf_indices: UInt[Array, 'num_trees n'], 

1695 moves: Moves, 

1696 batch_size: int | None, 

1697) -> tuple[Float32[Array, 'num_trees 2**d'], Precs]: 

1698 """ 

1699 Compute the likelihood precision scale in each leaf. 

1700 

1701 Parameters 

1702 ---------- 

1703 prec_scale 

1704 The scale of the precision of the error on each datapoint. 

1705 leaf_indices 

1706 The index of the leaf each datapoint falls into, with the deeper version 

1707 of the tree (post-GROW, pre-PRUNE). 

1708 moves 

1709 The proposed moves, see `propose_moves`. 

1710 batch_size 

1711 The data batch size to use for the summation. 

1712 

1713 Returns 

1714 ------- 

1715 prec_trees : Float32[Array, 'num_trees 2**d'] 

1716 The likelihood precision scale in each potential or actual leaf node. 

1717 precs : Precs 

1718 The likelihood precision scale in the nodes involved in the moves. 

1719 """ 

1720 num_trees, tree_size = moves.var_tree.shape 1ab

1721 tree_size *= 2 1ab

1722 tree_indices = jnp.arange(num_trees) 1ab

1723 

1724 prec_trees = prec_per_leaf(prec_scale, leaf_indices, tree_size, batch_size) 1ab

1725 

1726 # prec datapoints in nodes modified by move 

1727 left = prec_trees[tree_indices, moves.left] 1ab

1728 right = prec_trees[tree_indices, moves.right] 1ab

1729 precs = Precs(left=left, right=right, total=left + right) 1ab

1730 

1731 # write prec into non-leaf node 

1732 prec_trees = prec_trees.at[tree_indices, moves.node].set(precs.total) 1ab

1733 

1734 return prec_trees, precs 1ab

1735 

1736 

1737def prec_per_leaf( 1ab

1738 prec_scale: Float32[Array, ' n'], 

1739 leaf_indices: UInt[Array, 'num_trees n'], 

1740 tree_size: int, 

1741 batch_size: int | None, 

1742) -> Float32[Array, 'num_trees {tree_size}']: 

1743 """ 

1744 Compute the likelihood precision scale in each leaf. 

1745 

1746 Parameters 

1747 ---------- 

1748 prec_scale 

1749 The scale of the precision of the error on each datapoint. 

1750 leaf_indices 

1751 The index of the leaf each datapoint falls into. 

1752 tree_size 

1753 The size of the leaf tree array (2 ** d). 

1754 batch_size 

1755 The data batch size to use for the summation. 

1756 

1757 Returns 

1758 ------- 

1759 The likelihood precision scale in each leaf node. 

1760 """ 

1761 if batch_size is None: 1761 ↛ 1762line 1761 didn't jump to line 1762 because the condition on line 1761 was never true1ab

1762 return _prec_scan(prec_scale, leaf_indices, tree_size) 

1763 else: 

1764 return _prec_vec(prec_scale, leaf_indices, tree_size, batch_size) 1ab

1765 

1766 

1767def _prec_scan( 1ab

1768 prec_scale: Float32[Array, ' n'], 

1769 leaf_indices: UInt[Array, 'num_trees n'], 

1770 tree_size: int, 

1771) -> Float32[Array, 'num_trees {tree_size}']: 

1772 def loop(_, leaf_indices): 

1773 return None, _aggregate_scatter( 

1774 prec_scale, leaf_indices, tree_size, jnp.float32 

1775 ) 

1776 

1777 _, prec_trees = lax.scan(loop, None, leaf_indices) 

1778 return prec_trees 

1779 

1780 

def _prec_vec( 1ab
    prec_scale: Float32[Array, ' n'],
    leaf_indices: UInt[Array, 'num_trees n'],
    tree_size: int,
    batch_size: int,
) -> Float32[Array, 'num_trees {tree_size}']:
    # batched scatter-add of per-datapoint precision scales, all trees at once
    return _aggregate_batched_alltrees(
        prec_scale, leaf_indices, tree_size, jnp.float32, batch_size
    )

1790 

1791 

def complete_ratio(moves: Moves, p_nonterminal: Float32[Array, ' 2**d']) -> Moves: 1ab
    """
    Complete non-likelihood MH ratio calculation.

    This function adds the probability of choosing a prune move over the grow
    move in the inverse transition, and the a priori probability that the
    children nodes are leaves.

    Parameters
    ----------
    moves
        The proposed moves. Must have already been updated to keep into account
        the thresholds on the number of datapoints per node, this happens in
        `accept_moves_parallel_stage`.
    p_nonterminal
        The a priori probability of each node being nonterminal conditional on
        its ancestors, including at the maximum depth where it should be zero.

    Returns
    -------
    The updated moves, with `partial_ratio=None` and `log_trans_prior_ratio` set.
    """
    # can the leaves be grown? (out-of-bounds children read as not growable)
    num_trees, _ = moves.affluence_tree.shape
    tree_indices = jnp.arange(num_trees)
    left_growable = moves.affluence_tree.at[tree_indices, moves.left].get(
        mode='fill', fill_value=False
    )
    right_growable = moves.affluence_tree.at[tree_indices, moves.right].get(
        mode='fill', fill_value=False
    )

    # p_prune if grow: after growing, prune competes with grow only if some
    # leaf is still growable
    other_growable_leaves = moves.num_growable >= 2
    grow_again_allowed = other_growable_leaves | left_growable | right_growable
    grow_p_prune = jnp.where(grow_again_allowed, 0.5, 1)

    # p_prune if prune: after pruning, grow is possible iff there was a
    # growable leaf to begin with
    prune_p_prune = jnp.where(moves.num_growable, 0.5, 1)

    # select p_prune according to the kind of move
    p_prune = jnp.where(moves.grow, grow_p_prune, prune_p_prune)

    # prior probability of both children being terminal; a non-growable child
    # is terminal with probability 1
    pt_left = 1 - p_nonterminal[moves.left] * left_growable
    pt_right = 1 - p_nonterminal[moves.right] * right_growable
    pt_children = pt_left * pt_right

    return replace(
        moves,
        log_trans_prior_ratio=jnp.log(moves.partial_ratio * pt_children * p_prune),
        partial_ratio=None,
    )

1845 

1846 

@vmap_nodoc
def adapt_leaf_trees_to_grow_indices(
    leaf_trees: Float32[Array, 'num_trees 2**d'], moves: Moves
) -> Float32[Array, 'num_trees 2**d']:
    """
    Modify leaves such that post-grow indices work on the original tree.

    The value of the leaf to grow is copied to what would be its children if
    the grow move was accepted.

    Parameters
    ----------
    leaf_trees
        The leaf values.
    moves
        The proposed moves, see `propose_moves`.

    Returns
    -------
    The modified leaf values.
    """
    parent_value = leaf_trees[moves.node]
    # `leaf_trees.size` is out of bounds, so the .set is dropped whenever the
    # move is not a grow move
    oob = leaf_trees.size
    left_target = jnp.where(moves.grow, moves.left, oob)
    right_target = jnp.where(moves.grow, moves.right, oob)
    out = leaf_trees.at[left_target].set(parent_value)
    return out.at[right_target].set(parent_value)

1875 

1876 

def precompute_likelihood_terms(
    sigma2: Float32[Array, ''],
    sigma_mu2: Float32[Array, ''],
    move_precs: Precs | Counts,
) -> tuple[PreLkV, PreLk]:
    """
    Pre-compute terms used in the likelihood ratio of the acceptance step.

    Parameters
    ----------
    sigma2
        The error variance, or the global error variance factor if
        `prec_scale` is set.
    sigma_mu2
        The prior variance of each leaf.
    move_precs
        The likelihood precision scale in the leaves grown or pruned by the
        moves, under keys 'left', 'right', and 'total' (left + right).

    Returns
    -------
    prelkv : PreLkV
        Dictionary with pre-computed terms of the likelihood ratio, one per
        tree.
    prelk : PreLk
        Dictionary with pre-computed terms of the likelihood ratio, shared by
        all trees.
    """
    s2_left = sigma2 + move_precs.left * sigma_mu2
    s2_right = sigma2 + move_precs.right * sigma_mu2
    s2_total = sigma2 + move_precs.total * sigma_mu2
    per_tree = PreLkV(
        sigma2_left=s2_left,
        sigma2_right=s2_right,
        sigma2_total=s2_total,
        sqrt_term=jnp.log(sigma2 * s2_total / (s2_left * s2_right)) / 2,
    )
    shared = PreLk(exp_factor=sigma_mu2 / (2 * sigma2))
    return per_tree, shared

1915 

1916 

def precompute_leaf_terms(
    key: Key[Array, ''],
    prec_trees: Float32[Array, 'num_trees 2**d'],
    sigma2: Float32[Array, ''],
    sigma_mu2: Float32[Array, ''],
) -> PreLf:
    """
    Pre-compute terms used to sample leaves from their posterior.

    Parameters
    ----------
    key
        A jax random key.
    prec_trees
        The likelihood precision scale in each potential or actual leaf node.
    sigma2
        The error variance, or the global error variance factor if `prec_scale`
        is set.
    sigma_mu2
        The prior variance of each leaf.

    Returns
    -------
    Pre-computed terms for leaf sampling.
    """
    # conjugate Normal update: posterior precision = likelihood + prior precision
    prec_lk = prec_trees / sigma2
    prec_prior = lax.reciprocal(sigma_mu2)
    var_post = lax.reciprocal(prec_lk + prec_prior)
    noise = random.normal(key, prec_trees.shape, sigma2.dtype)
    # The posterior mean is mean_lk * prec_lk * var_post with
    # mean_lk = resid_tree / prec_tree, so the factor multiplying resid_tree
    # simplifies to var_post / sigma2.
    return PreLf(
        mean_factor=var_post / sigma2,
        centered_leaves=noise * jnp.sqrt(var_post),
    )

1958 

1959 

def accept_moves_sequential_stage(pso: ParallelStageOut) -> tuple[State, Moves]:
    """
    Accept/reject the moves one tree at a time.

    This is the most performance-sensitive function because it contains all and
    only the parts of the algorithm that can not be parallelized across trees.

    Parameters
    ----------
    pso
        The output of `accept_moves_parallel_stage`.

    Returns
    -------
    bart : State
        A partially updated BART mcmc state.
    moves : Moves
        The accepted/rejected moves, with `acc` and `to_prune` set.
    """
    # inputs identical for every tree; hoisted out of the scan body
    shared = SeqStageInAllTrees(
        pso.bart.X,
        pso.bart.forest.resid_batch_size,
        pso.bart.prec_scale,
        pso.bart.forest.log_likelihood is not None,
        pso.prelk,
    )

    def scan_body(resid, per_tree):
        resid, leaf_tree, acc, to_prune, lkratio = accept_move_and_sample_leaves(
            resid, shared, per_tree
        )
        return resid, (leaf_tree, acc, to_prune, lkratio)

    per_tree_inputs = SeqStageInPerTree(
        pso.bart.forest.leaf_tree,
        pso.prec_trees,
        pso.moves,
        pso.move_precs,
        pso.bart.forest.leaf_indices,
        pso.prelkv,
        pso.prelf,
    )
    resid, (leaf_trees, acc, to_prune, lkratio) = lax.scan(
        scan_body, pso.bart.resid, per_tree_inputs
    )

    new_forest = replace(
        pso.bart.forest, leaf_tree=leaf_trees, log_likelihood=lkratio
    )
    bart = replace(pso.bart, resid=resid, forest=new_forest)
    moves = replace(pso.moves, acc=acc, to_prune=to_prune)

    return bart, moves

2013 

2014 

class SeqStageInAllTrees(Module):
    """
    The inputs to `accept_move_and_sample_leaves` that are shared by all trees.

    Parameters
    ----------
    X
        The predictors.
    resid_batch_size
        The batch size for computing the sum of residuals in each leaf.
    prec_scale
        The scale of the precision of the error on each datapoint. If None, it
        is assumed to be 1.
    save_ratios
        Whether to save the acceptance ratios.
    prelk
        The pre-computed terms of the likelihood ratio which are shared across
        trees.
    """

    # predictor matrix, one row per predictor (p) and one column per datapoint (n)
    X: UInt[Array, 'p n']
    # None means the residual sums are computed without batching
    # NOTE(review): field(static=True) presumably marks a non-traced pytree
    # field (compile-time constant) — confirm against `Module`'s definition
    resid_batch_size: int | None = field(static=True)
    # None means unit precision scale for every datapoint
    prec_scale: Float32[Array, ' n'] | None
    # if True, accept_move_and_sample_leaves returns the log likelihood ratio
    save_ratios: bool = field(static=True)
    # see `precompute_likelihood_terms`
    prelk: PreLk

2040 

2041 

class SeqStageInPerTree(Module):
    """
    The inputs to `accept_move_and_sample_leaves` that are separate for each tree.

    Parameters
    ----------
    leaf_tree
        The leaf values of the tree.
    prec_tree
        The likelihood precision scale in each potential or actual leaf node.
    move
        The proposed move, see `propose_moves`.
    move_precs
        The likelihood precision scale in each node modified by the moves.
    leaf_indices
        The leaf indices for the largest version of the tree compatible with
        the move.
    prelkv
    prelf
        The pre-computed terms of the likelihood ratio and leaf sampling which
        are specific to the tree.
    """

    # leaf values, one per node of the heap-packed tree (size 2**d)
    leaf_tree: Float32[Array, ' 2**d']
    # likelihood precision scale summed over the datapoints in each node
    prec_tree: Float32[Array, ' 2**d']
    # the proposed grow/prune move for this tree
    move: Moves
    # precision scales of the nodes touched by the move (left/right/total)
    move_precs: Precs | Counts
    # which leaf each datapoint falls into (post-grow indexing)
    leaf_indices: UInt[Array, ' n']
    # see `precompute_likelihood_terms` and `precompute_leaf_terms`
    prelkv: PreLkV
    prelf: PreLf

2072 

2073 

def accept_move_and_sample_leaves(
    resid: Float32[Array, ' n'], at: SeqStageInAllTrees, pt: SeqStageInPerTree
) -> tuple[
    Float32[Array, ' n'],
    Float32[Array, ' 2**d'],
    Bool[Array, ''],
    Bool[Array, ''],
    Float32[Array, ''] | None,
]:
    """
    Accept or reject a proposed move and sample the new leaf values.

    Parameters
    ----------
    resid
        The residuals (data minus forest value).
    at
        The inputs that are the same for all trees.
    pt
        The inputs that are separate for each tree.

    Returns
    -------
    resid : Float32[Array, 'n']
        The updated residuals (data minus forest value).
    leaf_tree : Float32[Array, '2**d']
        The new leaf values of the tree.
    acc : Bool[Array, '']
        Whether the move was accepted.
    to_prune : Bool[Array, '']
        Whether, to reflect the acceptance status of the move, the state should
        be updated by pruning the leaves involved in the move.
    log_lk_ratio : Float32[Array, ''] | None
        The logarithm of the likelihood ratio for the move. `None` if not to be
        saved.
    """
    move = pt.move

    # per-leaf sums of (precision-scaled) residuals, on the tree proposed by
    # the grow move
    weighted = resid if at.prec_scale is None else resid * at.prec_scale
    leaf_resid = sum_resid(
        weighted, pt.leaf_indices, pt.leaf_tree.size, at.resid_batch_size
    )

    # subtract the starting tree from the function
    leaf_resid = leaf_resid + pt.prec_tree * pt.leaf_tree

    # aggregate residuals over the parent node modified by the move
    left_sum = leaf_resid[move.left]
    right_sum = leaf_resid[move.right]
    parent_sum = left_sum + right_sum
    assert move.node.dtype == jnp.int32
    leaf_resid = leaf_resid.at[move.node].set(parent_sum)

    # Metropolis-Hastings log ratio; defined for the grow move, the prune
    # move uses its negation
    log_lk_ratio = compute_likelihood_ratio(
        parent_sum, left_sum, right_sum, pt.prelkv, at.prelk
    )
    log_ratio = move.log_trans_prior_ratio + log_lk_ratio
    log_ratio = jnp.where(move.grow, log_ratio, -log_ratio)
    if not at.save_ratios:
        log_lk_ratio = None

    # accept/reject
    acc = move.allowed & (move.logu <= log_ratio)

    # sample the new leaves from their conjugate posterior
    new_leaves = leaf_resid * pt.prelf.mean_factor + pt.prelf.centered_leaves

    # copy leaves around such that the leaf indices point to the correct leaf
    to_prune = acc ^ move.grow
    oob = new_leaves.size  # out-of-bounds index: .set becomes a no-op
    parent_leaf = new_leaves[move.node]
    new_leaves = (
        new_leaves.at[jnp.where(to_prune, move.left, oob)]
        .set(parent_leaf)
        .at[jnp.where(to_prune, move.right, oob)]
        .set(parent_leaf)
    )

    # replace old tree with new tree in function values
    resid = resid + (pt.leaf_tree - new_leaves)[pt.leaf_indices]

    return resid, new_leaves, acc, to_prune, log_lk_ratio

2158 

2159 

def sum_resid(
    scaled_resid: Float32[Array, ' n'],
    leaf_indices: UInt[Array, ' n'],
    tree_size: int,
    batch_size: int | None,
) -> Float32[Array, ' {tree_size}']:
    """
    Sum the residuals in each leaf.

    Parameters
    ----------
    scaled_resid
        The residuals (data minus forest value) multiplied by the error
        precision scale.
    leaf_indices
        The leaf indices of the tree (in which leaf each data point falls into).
    tree_size
        The size of the tree array (2 ** d).
    batch_size
        The data batch size for the aggregation. Batching increases numerical
        accuracy and parallelism.

    Returns
    -------
    The sum of the residuals at data points in each leaf.
    """
    if batch_size is None:
        return _aggregate_scatter(scaled_resid, leaf_indices, tree_size, jnp.float32)
    return _aggregate_batched_onetree(
        scaled_resid, leaf_indices, tree_size, jnp.float32, batch_size
    )

2191 

2192 

2193def _aggregate_batched_onetree( 1ab

2194 values: Shaped[Array, '*'], 

2195 indices: Integer[Array, '*'], 

2196 size: int, 

2197 dtype: jnp.dtype, 

2198 batch_size: int, 

2199) -> Float32[Array, ' {size}']: 

2200 (n,) = indices.shape 1ab

2201 nbatches = n // batch_size + bool(n % batch_size) 1ab

2202 batch_indices = jnp.arange(n) % nbatches 1ab

2203 return ( 1ab

2204 jnp.zeros((size, nbatches), dtype) 

2205 .at[indices, batch_indices] 

2206 .add(values) 

2207 .sum(axis=1) 

2208 ) 

2209 

2210 

def compute_likelihood_ratio(
    total_resid: Float32[Array, ''],
    left_resid: Float32[Array, ''],
    right_resid: Float32[Array, ''],
    prelkv: PreLkV,
    prelk: PreLk,
) -> Float32[Array, '']:
    """
    Compute the log-likelihood ratio of a grow move.

    Parameters
    ----------
    total_resid
    left_resid
    right_resid
        The sum of the residuals (scaled by error precision scale) of the
        datapoints falling in the nodes involved in the moves.
    prelkv
    prelk
        The pre-computed terms of the likelihood ratio, see
        `precompute_likelihood_terms`.

    Returns
    -------
    The logarithm of the likelihood ratio
    log[P(data | new tree) / P(data | old tree)]; `sqrt_term` is already a
    log and `exp_term` is the exponent, so no exp/log is taken here.
    """
    exp_term = prelk.exp_factor * (
        left_resid * left_resid / prelkv.sigma2_left
        + right_resid * right_resid / prelkv.sigma2_right
        - total_resid * total_resid / prelkv.sigma2_total
    )
    return prelkv.sqrt_term + exp_term

2243 

2244 

def accept_moves_final_stage(bart: State, moves: Moves) -> State:
    """
    Post-process the mcmc state after accepting/rejecting the moves.

    This function is separate from `accept_moves_sequential_stage` to signal it
    can work in parallel across trees.

    Parameters
    ----------
    bart
        A partially updated BART mcmc state.
    moves
        The proposed moves (see `propose_moves`) as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The fully updated BART mcmc state.
    """
    accepted_grows = moves.acc & moves.grow
    accepted_prunes = moves.acc & ~moves.grow
    new_forest = replace(
        bart.forest,
        grow_acc_count=jnp.sum(accepted_grows),
        prune_acc_count=jnp.sum(accepted_prunes),
        leaf_indices=apply_moves_to_leaf_indices(bart.forest.leaf_indices, moves),
        split_tree=apply_moves_to_split_trees(bart.forest.split_tree, moves),
    )
    return replace(bart, forest=new_forest)

2274 

2275 

@vmap_nodoc
def apply_moves_to_leaf_indices(
    leaf_indices: UInt[Array, 'num_trees n'], moves: Moves
) -> UInt[Array, 'num_trees n']:
    """
    Update the leaf indices to match the accepted move.

    Parameters
    ----------
    leaf_indices
        The index of the leaf each datapoint falls into, if the grow move was
        accepted.
    moves
        The proposed moves (see `propose_moves`), as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The updated leaf indices.
    """
    # clearing the lowest bit maps both children onto the left child's index
    clear_low_bit = ~jnp.array(1, leaf_indices.dtype)  # ...1111111110
    points_to_child = (leaf_indices & clear_low_bit) == moves.left
    reassign = points_to_child & moves.to_prune
    return jnp.where(reassign, moves.node.astype(leaf_indices.dtype), leaf_indices)

2301 

2302 

@vmap_nodoc
def apply_moves_to_split_trees(
    split_tree: UInt[Array, 'num_trees 2**(d-1)'], moves: Moves
) -> UInt[Array, 'num_trees 2**(d-1)']:
    """
    Update the split trees to match the accepted move.

    Parameters
    ----------
    split_tree
        The cutpoints of the decision nodes in the initial trees.
    moves
        The proposed moves (see `propose_moves`), as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The updated split trees.
    """
    assert moves.to_prune is not None
    oob = split_tree.size  # out-of-bounds index: the .set is a no-op
    grow_target = jnp.where(moves.grow, moves.node, oob)
    prune_target = jnp.where(moves.to_prune, moves.node, oob)
    out = split_tree.at[grow_target].set(moves.grow_split.astype(split_tree.dtype))
    return out.at[prune_target].set(0)

2329 

2330 

def step_sigma(key: Key[Array, ''], bart: State) -> State:
    """
    MCMC-update the error variance (factor).

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state.

    Returns
    -------
    The new BART mcmc state, with an updated `sigma2`.
    """
    resid = bart.resid

    # conjugate inverse-gamma update of the variance
    alpha = bart.sigma2_alpha + resid.size / 2
    if bart.prec_scale is None:
        norm2 = resid @ resid
    else:
        norm2 = resid @ (resid * bart.prec_scale)
    beta = bart.sigma2_beta + norm2 / 2

    # sigma2 = beta / Gamma(alpha, 1) samples an InvGamma(alpha, beta);
    # random.gamma seems to be slow at compiling, maybe cdf inversion would
    # be better, but it's not implemented in jax
    gamma_sample = random.gamma(key, alpha)
    return replace(bart, sigma2=beta / gamma_sample)

2359 

2360 

def step_z(key: Key[Array, ''], bart: State) -> State:
    """
    MCMC-update the latent variable for binary regression.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART MCMC state.

    Returns
    -------
    The updated BART MCMC state.
    """
    # value of the forest plus offset, recovered from z and the residuals
    latent_mean = bart.z - bart.resid
    assert bart.y.dtype == bool
    # resample the residual from a one-sided truncated normal, with the
    # truncation side determined by the binary outcome
    new_resid = truncated_normal_onesided(key, (), ~bart.y, -latent_mean)
    return replace(bart, z=latent_mean + new_resid, resid=new_resid)