Coverage for src/bartz/mcmcstep.py: 95%

567 statements  

coverage.py v7.10.1, created at 2025-09-06 16:14 +0000

1# bartz/src/bartz/mcmcstep.py 

2# 

3# Copyright (c) 2024-2025, The Bartz Contributors 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25""" 

26Functions that implement the BART posterior MCMC initialization and update step. 

27 

28Functions that do MCMC steps take a BART state as input and return a new

29state. The inputs are not modified.

30 

31The entry points are: 

32 

33 - `State`: The dataclass that represents a BART MCMC state. 

34 - `init`: Creates an initial `State` from data and configurations. 

35 - `step`: Performs one full MCMC step on a `State`, returning a new `State`. 

36 - `step_sparse`: Performs the MCMC update for variable selection, which is skipped in `step`. 

37""" 

38 

39import math 1ab

40from dataclasses import replace 1ab

41from functools import cache, partial 1ab

42from typing import Any, Literal 1ab

43 

44import jax 1ab

45from equinox import Module, field, tree_at 1ab

46from jax import lax, random 1ab

47from jax import numpy as jnp 1ab

48from jax.scipy.linalg import solve_triangular 1ab

49from jax.scipy.special import gammaln, logsumexp 1ab

50from jaxtyping import Array, Bool, Float32, Int32, Integer, Key, Shaped, UInt 1ab

51 

52from bartz import grove 1ab

53from bartz.jaxext import ( 1ab

54 minimal_unsigned_dtype, 

55 split, 

56 truncated_normal_onesided, 

57 vmap_nodoc, 

58) 

59 

60 

61class Forest(Module): 1ab

62 """ 

63 Represents the MCMC state of a sum of trees. 

64 

65 Parameters 

66 ---------- 

67 leaf_tree 

68 The leaf values. 

69 var_tree 

70 The decision axes. 

71 split_tree 

72 The decision boundaries. 

73 affluence_tree 

74 Marks leaves that can be grown. 

75 max_split 

76 The maximum split index for each predictor. 

77 blocked_vars 

78 Indices of variables that are not used. This shall include at least 

79 the `i` such that ``max_split[i] == 0``, otherwise behavior is 

80 undefined. 

81 p_nonterminal 

82 The prior probability of each node being nonterminal, conditional on 

83 its ancestors. Includes the nodes at maximum depth which should be set 

84 to 0. 

85 p_propose_grow 

86 The unnormalized probability of picking a leaf for a grow proposal. 

87 leaf_indices 

88 The index of the leaf each datapoint falls into, for each tree.

89 min_points_per_decision_node 

90 The minimum number of data points in a decision node. 

91 min_points_per_leaf 

92 The minimum number of data points in a leaf node. 

93 resid_batch_size 

94 count_batch_size 

95 The data batch sizes for computing the sufficient statistics. If `None`, 

96 they are computed with no batching. 

97 log_trans_prior 

98 The log transition and prior Metropolis-Hastings ratio for the 

99 proposed move on each tree. 

100 log_likelihood 

101 The log likelihood ratio. 

102 grow_prop_count 

103 prune_prop_count 

104 The number of grow/prune proposals made during one full MCMC cycle. 

105 grow_acc_count 

106 prune_acc_count 

107 The number of grow/prune moves accepted during one full MCMC cycle. 

108 sigma_mu2 

109 The prior variance of a leaf, conditional on the tree structure. 

110 log_s 

111 The logarithm of the prior probability for choosing a variable to split 

112 along in a decision rule, conditional on the ancestors. Not normalized. 

113 If `None`, use a uniform distribution. 

114 theta 

115 The concentration parameter for the Dirichlet prior on the variable 

116 distribution `s`. Required only to update `s`. 

117 a 

118 b 

119 rho 

120 Parameters of the prior on `theta`. Required only to sample `theta`. 

121 See `step_theta`. 

122 """ 

123 

124 leaf_tree: Float32[Array, 'num_trees 2**d'] 1ab

125 var_tree: UInt[Array, 'num_trees 2**(d-1)'] 1ab

126 split_tree: UInt[Array, 'num_trees 2**(d-1)'] 1ab

127 affluence_tree: Bool[Array, 'num_trees 2**(d-1)'] 1ab

128 max_split: UInt[Array, ' p'] 1ab

129 blocked_vars: UInt[Array, ' k'] | None 1ab

130 p_nonterminal: Float32[Array, ' 2**d'] 1ab

131 p_propose_grow: Float32[Array, ' 2**(d-1)'] 1ab

132 leaf_indices: UInt[Array, 'num_trees n'] 1ab

133 min_points_per_decision_node: Int32[Array, ''] | None 1ab

134 min_points_per_leaf: Int32[Array, ''] | None 1ab

135 resid_batch_size: int | None = field(static=True) 1ab

136 count_batch_size: int | None = field(static=True) 1ab

137 log_trans_prior: Float32[Array, ' num_trees'] | None 1ab

138 log_likelihood: Float32[Array, ' num_trees'] | None 1ab

139 grow_prop_count: Int32[Array, ''] 1ab

140 prune_prop_count: Int32[Array, ''] 1ab

141 grow_acc_count: Int32[Array, ''] 1ab

142 prune_acc_count: Int32[Array, ''] 1ab

143 sigma_mu2: Float32[Array, ''] 1ab

144 log_s: Float32[Array, ' p'] | None 1ab

145 theta: Float32[Array, ''] | None 1ab

146 a: Float32[Array, ''] | None 1ab

147 b: Float32[Array, ''] | None 1ab

148 rho: Float32[Array, ''] | None 1ab

149 

150 

151class State(Module): 1ab

152 """ 

153 Represents the MCMC state of BART. 

154 

155 Parameters 

156 ---------- 

157 X 

158 The predictors. 

159 y 

160 The response. If the data type is `bool`, the model is binary regression. 

161 resid 

162 The residuals (`y` or `z` minus sum of trees). 

163 z 

164 The latent variable for binary regression. `None` in continuous 

165 regression. 

166 offset 

167 Constant shift added to the sum of trees. 

168 sigma2 

169 The error variance. `None` in binary regression. 

170 prec_scale 

171 The scale on the error precision, i.e., ``1 / error_scale ** 2``. 

172 `None` in binary regression. 

173 sigma2_alpha 

174 sigma2_beta 

175 The shape and scale parameters of the inverse gamma prior on the noise 

176 variance. `None` in binary regression. 

177 forest 

178 The sum of trees model. 

179 """ 

180 

181 X: UInt[Array, 'p n'] 1ab

182 y: Float32[Array, ' n'] | Bool[Array, ' n'] 1ab

183 z: None | Float32[Array, ' n'] 1ab

184 offset: Float32[Array, ''] 1ab

185 resid: Float32[Array, ' n'] 1ab

186 sigma2: Float32[Array, ''] | None 1ab

187 prec_scale: Float32[Array, ' n'] | None 1ab

188 sigma2_alpha: Float32[Array, ''] | None 1ab

189 sigma2_beta: Float32[Array, ''] | None 1ab

190 forest: Forest 1ab

191 

192 

193def init( 1ab

194 *, 

195 X: UInt[Any, 'p n'], 

196 y: Float32[Any, ' n'] | Bool[Any, ' n'], 

197 offset: float | Float32[Any, ''] = 0.0, 

198 max_split: UInt[Any, ' p'], 

199 num_trees: int, 

200 p_nonterminal: Float32[Any, ' d-1'], 

201 sigma_mu2: float | Float32[Any, ''], 

202 sigma2_alpha: float | Float32[Any, ''] | None = None, 

203 sigma2_beta: float | Float32[Any, ''] | None = None, 

204 error_scale: Float32[Any, ' n'] | None = None, 

205 min_points_per_decision_node: int | Integer[Any, ''] | None = None, 

206 resid_batch_size: int | None | Literal['auto'] = 'auto', 

207 count_batch_size: int | None | Literal['auto'] = 'auto', 

208 save_ratios: bool = False, 

209 filter_splitless_vars: bool = True, 

210 min_points_per_leaf: int | Integer[Any, ''] | None = None, 

211 log_s: Float32[Any, ' p'] | None = None, 

212 theta: float | Float32[Any, ''] | None = None, 

213 a: float | Float32[Any, ''] | None = None, 

214 b: float | Float32[Any, ''] | None = None, 

215 rho: float | Float32[Any, ''] | None = None, 

216) -> State: 

217 """ 

218 Make a BART posterior sampling MCMC initial state. 

219 

220 Parameters 

221 ---------- 

222 X 

223 The predictors. Note this is transposed compared to the usual convention.

224 y 

225 The response. If the data type is `bool`, the model is binary probit

226 regression.

227 offset 

228 Constant shift added to the sum of trees. 0 if not specified. 

229 max_split 

230 The maximum split index for each variable. All split ranges start at 1. 

231 num_trees 

232 The number of trees in the forest. 

233 p_nonterminal 

234 The probability of a nonterminal node at each depth. The maximum depth 

235 of trees is fixed by the length of this array. 

236 sigma_mu2 

237 The prior variance of a leaf, conditional on the tree structure. The 

238 prior variance of the sum of trees is ``num_trees * sigma_mu2``. The 

239 prior mean of leaves is always zero. 

240 sigma2_alpha 

241 sigma2_beta 

242 The shape and scale parameters of the inverse gamma prior on the error 

243 variance. Leave unspecified for binary regression. 

244 error_scale 

245 Each error is scaled by the corresponding factor in `error_scale`, so 

246 the error variance for ``y[i]`` is ``sigma2 * error_scale[i] ** 2``. 

247 Not supported for binary regression. If not specified, defaults to 1 for

248 all points, in which case some calculations are skipped.

249 min_points_per_decision_node 

250 The minimum number of data points in a decision node. 0 if not 

251 specified. 

252 resid_batch_size 

253 count_batch_size 

254 The batch sizes, along datapoints, for summing the residuals and 

255 counting the number of datapoints in each leaf. `None` for no batching. 

256 If 'auto', pick a value based on the device of `y`, or the default 

257 device. 

258 save_ratios 

259 Whether to save the Metropolis-Hastings ratios. 

260 filter_splitless_vars 

261 Whether to check `max_split` for variables without available cutpoints. 

262 If any are found, they are put into a list of variables to exclude from 

263 the MCMC. If `False`, no check is performed, but the results may be 

264 wrong if any variable is blocked. The function is jax-traceable only 

265 if this is set to `False`. 

266 min_points_per_leaf 

267 The minimum number of datapoints in a leaf node. 0 if not specified. 

268 Unlike `min_points_per_decision_node`, this constraint is not taken into 

269 account in the Metropolis-Hastings ratio because it would be expensive 

270 to compute. Grow moves that would violate this constraint are vetoed. 

271 This parameter is independent of `min_points_per_decision_node` and 

272 there is no check that they are coherent. It makes sense to set 

273 ``min_points_per_decision_node >= 2 * min_points_per_leaf``. 

274 log_s 

275 The logarithm of the prior probability for choosing a variable to split 

276 along in a decision rule, conditional on the ancestors. Not normalized. 

277 If not specified, use a uniform distribution. If not specified but

278 `theta` (or `rho`, `a`, `b`) is, it is initialized automatically.

279 theta 

280 The concentration parameter for the Dirichlet prior on `s`. Required 

281 only to update `log_s`. If not specified, and `rho`, `a`, `b` are 

282 specified, it's initialized automatically. 

283 a 

284 b 

285 rho 

286 Parameters of the prior on `theta`. Required only to sample `theta`. 

287 

288 Returns 

289 ------- 

290 An initialized BART MCMC state. 

291 

292 Raises 

293 ------ 

294 ValueError 

295 If `y` is boolean and arguments unused in binary regression are set. 

296 

297 Notes 

298 ----- 

299 In decision nodes, the values in ``X[i, :]`` are compared to a cutpoint out 

300 of the range ``[1, 2, ..., max_split[i]]``. A point belongs to the left 

301 child iff ``X[i, j] < cutpoint``. Thus it makes sense for ``X[i, :]`` to be 

302 integers in the range ``[0, 1, ..., max_split[i]]``. 

303 """ 

304 p_nonterminal = jnp.asarray(p_nonterminal) 1ab

305 p_nonterminal = jnp.pad(p_nonterminal, (0, 1)) 1ab

306 max_depth = p_nonterminal.size 1ab

307 

308 @partial(jax.vmap, in_axes=None, out_axes=0, axis_size=num_trees) 1ab

309 def make_forest(max_depth, dtype): 1ab

310 return grove.make_tree(max_depth, dtype) 1ab

311 

312 y = jnp.asarray(y) 1ab

313 offset = jnp.asarray(offset) 1ab

314 

315 resid_batch_size, count_batch_size = _choose_suffstat_batch_size( 1ab

316 resid_batch_size, count_batch_size, y, 2**max_depth * num_trees 

317 ) 

318 

319 is_binary = y.dtype == bool 1ab

320 if is_binary: 1ab

321 if (error_scale, sigma2_alpha, sigma2_beta) != 3 * (None,): 1ab [branch 321 ↛ 322 never taken: condition never true]

322 msg = ( 

323 'error_scale, sigma2_alpha, and sigma2_beta must be set ' 

324 'to `None` for binary regression.'

325 ) 

326 raise ValueError(msg) 

327 sigma2 = None 1ab

328 else: 

329 sigma2_alpha = jnp.asarray(sigma2_alpha) 1ab

330 sigma2_beta = jnp.asarray(sigma2_beta) 1ab

331 sigma2 = sigma2_beta / sigma2_alpha 1ab

332 

333 max_split = jnp.asarray(max_split) 1ab

334 

335 if filter_splitless_vars: 1ab

336 (blocked_vars,) = jnp.nonzero(max_split == 0) 1ab

337 blocked_vars = blocked_vars.astype(minimal_unsigned_dtype(max_split.size)) 1ab

338 # see `fully_used_variables` for the type cast 

339 else: 

340 blocked_vars = None 1ab

341 

342 # check and initialize sparsity parameters 

343 if not _all_none_or_not_none(rho, a, b): 1ab [branch 343 ↛ 344 never taken: condition never true]

344 msg = 'rho, a, b are not either all `None` or all set' 

345 raise ValueError(msg) 

346 if theta is None and rho is not None: 1ab

347 theta = rho 1ab

348 if log_s is None and theta is not None: 1ab

349 log_s = jnp.zeros(max_split.size) 1ab

350 

351 return State( 1ab

352 X=jnp.asarray(X), 

353 y=y, 

354 z=jnp.full(y.shape, offset) if is_binary else None, 

355 offset=offset, 

356 resid=jnp.zeros(y.shape) if is_binary else y - offset, 

357 sigma2=sigma2, 

358 prec_scale=( 

359 None if error_scale is None else lax.reciprocal(jnp.square(error_scale)) 

360 ), 

361 sigma2_alpha=sigma2_alpha, 

362 sigma2_beta=sigma2_beta, 

363 forest=Forest( 

364 leaf_tree=make_forest(max_depth, jnp.float32), 

365 var_tree=make_forest(max_depth - 1, minimal_unsigned_dtype(X.shape[0] - 1)), 

366 split_tree=make_forest(max_depth - 1, max_split.dtype), 

367 affluence_tree=( 

368 make_forest(max_depth - 1, bool) 

369 .at[:, 1] 

370 .set( 

371 True 

372 if min_points_per_decision_node is None 

373 else y.size >= min_points_per_decision_node 

374 ) 

375 ), 

376 blocked_vars=blocked_vars, 

377 max_split=max_split, 

378 grow_prop_count=jnp.zeros((), int), 

379 grow_acc_count=jnp.zeros((), int), 

380 prune_prop_count=jnp.zeros((), int), 

381 prune_acc_count=jnp.zeros((), int), 

382 p_nonterminal=p_nonterminal[grove.tree_depths(2**max_depth)], 

383 p_propose_grow=p_nonterminal[grove.tree_depths(2 ** (max_depth - 1))], 

384 leaf_indices=jnp.ones( 

385 (num_trees, y.size), minimal_unsigned_dtype(2**max_depth - 1) 

386 ), 

387 min_points_per_decision_node=_asarray_or_none(min_points_per_decision_node), 

388 min_points_per_leaf=_asarray_or_none(min_points_per_leaf), 

389 resid_batch_size=resid_batch_size, 

390 count_batch_size=count_batch_size, 

391 log_trans_prior=jnp.zeros(num_trees) if save_ratios else None, 

392 log_likelihood=jnp.zeros(num_trees) if save_ratios else None, 

393 sigma_mu2=jnp.asarray(sigma_mu2), 

394 log_s=_asarray_or_none(log_s), 

395 theta=_asarray_or_none(theta), 

396 rho=_asarray_or_none(rho), 

397 a=_asarray_or_none(a), 

398 b=_asarray_or_none(b), 

399 ), 

400 ) 
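
# --- Editor's example (not part of the original module) ---
# A sketch of the cutpoint convention from the Notes of `init`: with
# max_split[i] == 3, the coded values of predictor i are expected in
# {0, 1, 2, 3}, cutpoints are drawn from {1, 2, 3}, and a point goes to the
# left child iff X[i, j] < cutpoint. The helper name is hypothetical.
def _example_cutpoint_convention():
    x_i = jnp.arange(4, dtype=jnp.uint8)  # predictor i coded as 0..max_split[i]
    cutpoint = 2
    return x_i < cutpoint  # -> [True, True, False, False]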

401 

402 

403def _all_none_or_not_none(*args): 1ab

404 is_none = [x is None for x in args] 1ab

405 return all(is_none) or not any(is_none) 1ab

406 

407 

408def _asarray_or_none(x): 1ab

409 if x is None: 1ab

410 return None 1ab

411 return jnp.asarray(x) 1ab

412 

413 

414def _choose_suffstat_batch_size( 1ab

415 resid_batch_size, count_batch_size, y, forest_size 

416) -> tuple[int | None, ...]: 

417 @cache 1ab

418 def get_platform(): 1ab

419 try: 1ab

420 device = y.devices().pop() 1ab

421 except jax.errors.ConcretizationTypeError: 1ab

422 device = jax.devices()[0] 1ab

423 platform = device.platform 1ab

424 if platform not in ('cpu', 'gpu'): 1ab [branch 424 ↛ 425 never taken: condition never true]

425 msg = f'Unknown platform: {platform}' 

426 raise KeyError(msg) 

427 return platform 1ab

428 

429 if resid_batch_size == 'auto': 1ab

430 platform = get_platform() 1ab

431 n = max(1, y.size) 1ab

432 if platform == 'cpu': 1ab [branch 432 ↛ 434 never taken: condition always true]

433 resid_batch_size = 2 ** round(math.log2(n / 6)) # n/6 1ab

434 elif platform == 'gpu': 

435 resid_batch_size = 2 ** round((1 + math.log2(n)) / 3) # n^1/3 

436 resid_batch_size = max(1, resid_batch_size) 1ab

437 

438 if count_batch_size == 'auto': 1ab

439 platform = get_platform() 1ab

440 if platform == 'cpu': 1ab [branch 440 ↛ 442 never taken: condition always true]

441 count_batch_size = None 1ab

442 elif platform == 'gpu': 

443 n = max(1, y.size) 

444 count_batch_size = 2 ** round(math.log2(n) / 2 - 2) # n^1/2 

445 # /4 is good on V100, /2 on L4/T4, still haven't tried A100 

446 max_memory = 2**29 

447 itemsize = 4 

448 min_batch_size = math.ceil(forest_size * itemsize * n / max_memory) 

449 count_batch_size = max(count_batch_size, min_batch_size) 

450 count_batch_size = max(1, count_batch_size) 

451 

452 return resid_batch_size, count_batch_size 1ab
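
# --- Editor's example (not part of the original module) ---
# Worked instances of the 'auto' heuristics above, assuming n = 1000
# datapoints. The helper name is hypothetical; the formulas are copied from
# the code.
def _example_batch_size_heuristics():
    n = 1000
    cpu_resid = 2 ** round(math.log2(n / 6))  # log2(166.7) ~ 7.38 -> 2**7 = 128
    gpu_resid = 2 ** round((1 + math.log2(n)) / 3)  # ~2**3.66 -> 2**4 = 16
    gpu_count = 2 ** round(math.log2(n) / 2 - 2)  # ~2**2.98 -> 2**3 = 8
    return cpu_resid, gpu_resid, gpu_count  # -> (128, 16, 8)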

453 

454 

455@jax.jit 1ab

456def step(key: Key[Array, ''], bart: State) -> State: 1ab

457 """ 

458 Do one MCMC step. 

459 

460 Parameters 

461 ---------- 

462 key 

463 A jax random key. 

464 bart 

465 A BART mcmc state, as created by `init`. 

466 

467 Returns 

468 ------- 

469 The new BART mcmc state. 

470 """ 

471 keys = split(key) 1ab

472 

473 if bart.y.dtype == bool: # binary regression 1ab

474 bart = replace(bart, sigma2=jnp.float32(1)) 1ab

475 bart = step_trees(keys.pop(), bart) 1ab

476 bart = replace(bart, sigma2=None) 1ab

477 return step_z(keys.pop(), bart) 1ab

478 

479 else: # continuous regression 

480 bart = step_trees(keys.pop(), bart) 1ab

481 return step_sigma(keys.pop(), bart) 1ab

482 

483 

484def step_trees(key: Key[Array, ''], bart: State) -> State: 1ab

485 """ 

486 Forest sampling step of BART MCMC. 

487 

488 Parameters 

489 ---------- 

490 key 

491 A jax random key. 

492 bart 

493 A BART mcmc state, as created by `init`. 

494 

495 Returns 

496 ------- 

497 The new BART mcmc state. 

498 

499 Notes 

500 ----- 

501 This function zeroes the proposal counters. 

502 """ 

503 keys = split(key) 1ab

504 moves = propose_moves(keys.pop(), bart.forest) 1ab

505 return accept_moves_and_sample_leaves(keys.pop(), bart, moves) 1ab

506 

507 

508class Moves(Module): 1ab

509 """ 

510 Moves proposed to modify each tree. 

511 

512 Parameters 

513 ---------- 

514 allowed 

515 Whether there is a possible move. If `False`, the other values may not 

516 make sense. The only case in which a move is marked as allowed but is 

517 then vetoed is if it does not satisfy `min_points_per_leaf`, which for 

518 efficiency is implemented post-hoc without changing the rest of the 

519 MCMC logic. 

520 grow 

521 Whether the move is a grow move or a prune move. 

522 num_growable 

523 The number of growable leaves in the original tree. 

524 node 

525 The index of the leaf to grow or node to prune. 

526 left 

527 right 

528 The indices of the children of 'node'. 

529 partial_ratio 

530 A factor of the Metropolis-Hastings ratio of the move. It lacks the 

531 likelihood ratio, the probability of proposing the prune move, and the 

532 probability that the children of the modified node are terminal. If the 

533 move is PRUNE, the ratio is inverted. `None` once 

534 `log_trans_prior_ratio` has been computed. 

535 log_trans_prior_ratio 

536 The logarithm of the product of the transition and prior terms of the 

537 Metropolis-Hastings ratio for the acceptance of the proposed move. 

538 `None` if not yet computed. If PRUNE, the log-ratio is negated. 

539 grow_var 

540 The decision axes of the new rules. 

541 grow_split 

542 The decision boundaries of the new rules. 

543 var_tree 

544 The updated decision axes of the trees, valid whatever move. 

545 affluence_tree 

546 A partially updated `affluence_tree`, marking non-leaf nodes that would 

547 become leaves if the move was accepted. The mark set by `propose_moves`

548 only takes into account whether there would be available decision rules

549 to grow the leaf; whether there are enough datapoints in the node is

550 marked later, in `accept_moves_parallel_stage`.

551 logu 

552 The logarithm of a uniform (0, 1] random variable to be used to 

553 accept the move. It's in (-oo, 0]. 

554 acc 

555 Whether the move was accepted. `None` if not yet computed. 

556 to_prune 

557 Whether the final operation to apply the move is pruning. This indicates 

558 an accepted prune move or a rejected grow move. `None` if not yet 

559 computed. 

560 """ 

561 

562 allowed: Bool[Array, ' num_trees'] 1ab

563 grow: Bool[Array, ' num_trees'] 1ab

564 num_growable: UInt[Array, ' num_trees'] 1ab

565 node: UInt[Array, ' num_trees'] 1ab

566 left: UInt[Array, ' num_trees'] 1ab

567 right: UInt[Array, ' num_trees'] 1ab

568 partial_ratio: Float32[Array, ' num_trees'] | None 1ab

569 log_trans_prior_ratio: None | Float32[Array, ' num_trees'] 1ab

570 grow_var: UInt[Array, ' num_trees'] 1ab

571 grow_split: UInt[Array, ' num_trees'] 1ab

572 var_tree: UInt[Array, 'num_trees 2**(d-1)'] 1ab

573 affluence_tree: Bool[Array, 'num_trees 2**(d-1)'] 1ab

574 logu: Float32[Array, ' num_trees'] 1ab

575 acc: None | Bool[Array, ' num_trees'] 1ab

576 to_prune: None | Bool[Array, ' num_trees'] 1ab

577 

578 

579def propose_moves(key: Key[Array, ''], forest: Forest) -> Moves: 1ab

580 """ 

581 Propose moves for all the trees. 

582 

583 There are two types of moves: GROW (convert a leaf to a decision node and 

584 add two leaves beneath it) and PRUNE (convert the parent of two leaves to a 

585 leaf, deleting its children). 

586 

587 Parameters 

588 ---------- 

589 key 

590 A jax random key. 

591 forest 

592 The `forest` field of a BART MCMC state. 

593 

594 Returns 

595 ------- 

596 The proposed move for each tree. 

597 """ 

598 num_trees, _ = forest.leaf_tree.shape 1ab

599 keys = split(key, 1 + 2 * num_trees) 1ab

600 

601 # compute moves 

602 grow_moves = propose_grow_moves( 1ab

603 keys.pop(num_trees), 

604 forest.var_tree, 

605 forest.split_tree, 

606 forest.affluence_tree, 

607 forest.max_split, 

608 forest.blocked_vars, 

609 forest.p_nonterminal, 

610 forest.p_propose_grow, 

611 forest.log_s, 

612 ) 

613 prune_moves = propose_prune_moves( 1ab

614 keys.pop(num_trees), 

615 forest.split_tree, 

616 grow_moves.affluence_tree, 

617 forest.p_nonterminal, 

618 forest.p_propose_grow, 

619 ) 

620 

621 u, exp1mlogu = random.uniform(keys.pop(), (2, num_trees)) 1ab

622 

623 # choose between grow or prune 

624 p_grow = jnp.where( 1ab

625 grow_moves.allowed & prune_moves.allowed, 0.5, grow_moves.allowed 

626 ) 

627 grow = u < p_grow # use < instead of <= because u is in [0, 1) 1ab

628 

629 # compute children indices 

630 node = jnp.where(grow, grow_moves.node, prune_moves.node) 1ab

631 left = node << 1 1ab

632 right = left + 1 1ab

633 

634 return Moves( 1ab

635 allowed=grow_moves.allowed | prune_moves.allowed, 

636 grow=grow, 

637 num_growable=grow_moves.num_growable, 

638 node=node, 

639 left=left, 

640 right=right, 

641 partial_ratio=jnp.where( 

642 grow, grow_moves.partial_ratio, prune_moves.partial_ratio 

643 ), 

644 log_trans_prior_ratio=None, # will be set in complete_ratio 

645 grow_var=grow_moves.var, 

646 grow_split=grow_moves.split, 

647 # var_tree does not need to be updated if prune 

648 var_tree=grow_moves.var_tree, 

649 # affluence_tree is updated for both moves unconditionally, prune last 

650 affluence_tree=prune_moves.affluence_tree, 

651 logu=jnp.log1p(-exp1mlogu), 

652 acc=None, # will be set in accept_moves_sequential_stage 

653 to_prune=None, # will be set in accept_moves_sequential_stage 

654 ) 
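
# --- Editor's example (not part of the original module) ---
# A sketch of the grow/prune choice above: when both moves are allowed the
# proposal is a fair coin, when only grow is allowed p_grow is 1, and when
# grow is not allowed p_grow is 0. The helper name is hypothetical.
def _example_choose_grow_or_prune():
    grow_allowed = jnp.array([True, True, False, False])
    prune_allowed = jnp.array([True, False, True, False])
    p_grow = jnp.where(grow_allowed & prune_allowed, 0.5, grow_allowed)
    return p_grow  # -> [0.5, 1.0, 0.0, 0.0]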

655 

656 

657class GrowMoves(Module): 1ab

658 """ 

659 Represent a proposed grow move for each tree. 

660 

661 Parameters 

662 ---------- 

663 allowed 

664 Whether the move is allowed for proposal. 

665 num_growable 

666 The number of leaves that can be proposed for grow. 

667 node 

668 The index of the leaf to grow. ``2 ** d`` if there are no growable 

669 leaves. 

670 var 

671 split 

672 The decision axis and boundary of the new rule. 

673 partial_ratio 

674 A factor of the Metropolis-Hastings ratio of the move. It lacks 

675 the likelihood ratio and the probability of proposing the prune 

676 move. 

677 var_tree 

678 The updated decision axes of the tree. 

679 affluence_tree 

680 A partially updated `affluence_tree` that marks each new leaf that 

681 would be produced as `True` if it would have available decision rules. 

682 """ 

683 

684 allowed: Bool[Array, ' num_trees'] 1ab

685 num_growable: UInt[Array, ' num_trees'] 1ab

686 node: UInt[Array, ' num_trees'] 1ab

687 var: UInt[Array, ' num_trees'] 1ab

688 split: UInt[Array, ' num_trees'] 1ab

689 partial_ratio: Float32[Array, ' num_trees'] 1ab

690 var_tree: UInt[Array, 'num_trees 2**(d-1)'] 1ab

691 affluence_tree: Bool[Array, 'num_trees 2**(d-1)'] 1ab

692 

693 

694@partial(vmap_nodoc, in_axes=(0, 0, 0, 0, None, None, None, None, None)) 1ab

695def propose_grow_moves( 1ab

696 key: Key[Array, ' num_trees'], 

697 var_tree: UInt[Array, 'num_trees 2**(d-1)'], 

698 split_tree: UInt[Array, 'num_trees 2**(d-1)'], 

699 affluence_tree: Bool[Array, 'num_trees 2**(d-1)'], 

700 max_split: UInt[Array, ' p'], 

701 blocked_vars: Int32[Array, ' k'] | None, 

702 p_nonterminal: Float32[Array, ' 2**d'], 

703 p_propose_grow: Float32[Array, ' 2**(d-1)'], 

704 log_s: Float32[Array, ' p'] | None, 

705) -> GrowMoves: 

706 """ 

707 Propose a GROW move for each tree. 

708 

709 A GROW move picks a leaf node and converts it to a non-terminal node with 

710 two leaf children. 

711 

712 Parameters 

713 ---------- 

714 key 

715 A jax random key. 

716 var_tree 

717 The splitting axes of the tree. 

718 split_tree 

719 The splitting points of the tree. 

720 affluence_tree 

721 Whether each leaf has enough points to be grown. 

722 max_split 

723 The maximum split index for each variable. 

724 blocked_vars 

725 The indices of the variables that have no available cutpoints. 

726 p_nonterminal 

727 The a priori probability of a node to be nonterminal conditional on the 

728 ancestors, including at the maximum depth where it should be zero. 

729 p_propose_grow 

730 The unnormalized probability of choosing a leaf to grow. 

731 log_s 

732 Unnormalized log-probability used to choose a variable to split on 

733 amongst the available ones. 

734 

735 Returns 

736 ------- 

737 An object representing the proposed move. 

738 

739 Notes 

740 ----- 

741 The move is not proposed if every leaf is already at maximum depth, has

742 fewer datapoints than the requested threshold `min_points_per_decision_node`,

743 or does not have any available decision rules given its ancestors. This

744 is marked by setting `allowed` to `False` and `num_growable` to 0.

745 """ 

746 keys = split(key, 3) 1ab

747 

748 leaf_to_grow, num_growable, prob_choose, num_prunable = choose_leaf( 1ab

749 keys.pop(), split_tree, affluence_tree, p_propose_grow 

750 ) 

751 

752 # sample a decision rule 

753 var, num_available_var = choose_variable( 1ab

754 keys.pop(), var_tree, split_tree, max_split, leaf_to_grow, blocked_vars, log_s 

755 ) 

756 split_idx, l, r = choose_split( 1ab

757 keys.pop(), var, var_tree, split_tree, max_split, leaf_to_grow 

758 ) 

759 

760 # determine if the new leaves would have available decision rules; if the 

761 # move is blocked, these values may not make sense 

762 left_growable = right_growable = num_available_var > 1 1ab

763 left_growable |= l < split_idx 1ab

764 right_growable |= split_idx + 1 < r 1ab

765 left = leaf_to_grow << 1 1ab

766 right = left + 1 1ab

767 affluence_tree = affluence_tree.at[left].set(left_growable) 1ab

768 affluence_tree = affluence_tree.at[right].set(right_growable) 1ab

769 

770 ratio = compute_partial_ratio( 1ab

771 prob_choose, num_prunable, p_nonterminal, leaf_to_grow 

772 ) 

773 

774 return GrowMoves( 1ab

775 allowed=num_growable > 0, 

776 num_growable=num_growable, 

777 node=leaf_to_grow, 

778 var=var, 

779 split=split_idx, 

780 partial_ratio=ratio, 

781 var_tree=var_tree.at[leaf_to_grow].set(var.astype(var_tree.dtype)), 

782 affluence_tree=affluence_tree, 

783 ) 

784 

785 

786def choose_leaf( 1ab

787 key: Key[Array, ''], 

788 split_tree: UInt[Array, ' 2**(d-1)'], 

789 affluence_tree: Bool[Array, ' 2**(d-1)'], 

790 p_propose_grow: Float32[Array, ' 2**(d-1)'], 

791) -> tuple[Int32[Array, ''], Int32[Array, ''], Float32[Array, ''], Int32[Array, '']]: 

792 """ 

793 Choose a leaf node to grow in a tree. 

794 

795 Parameters 

796 ---------- 

797 key 

798 A jax random key. 

799 split_tree 

800 The splitting points of the tree. 

801 affluence_tree 

802 Whether a leaf has enough points that it could be split into two leaves 

803 satisfying the `min_points_per_leaf` requirement. 

804 p_propose_grow 

805 The unnormalized probability of choosing a leaf to grow. 

806 

807 Returns 

808 ------- 

809 leaf_to_grow : Int32[Array, ''] 

810 The index of the leaf to grow. If ``num_growable == 0``, return 

811 ``2 ** d``. 

812 num_growable : Int32[Array, ''] 

813 The number of leaf nodes that can be grown, i.e., are nonterminal 

814 and have at least twice `min_points_per_leaf`. 

815 prob_choose : Float32[Array, ''] 

816 The (normalized) probability that this function would choose that

817 specific leaf, given the arguments.

818 num_prunable : Int32[Array, ''] 

819 The number of leaf parents that could be pruned, after converting the 

820 selected leaf to a non-terminal node. 

821 """ 

822 is_growable = growable_leaves(split_tree, affluence_tree) 1ab

823 num_growable = jnp.count_nonzero(is_growable) 1ab

824 distr = jnp.where(is_growable, p_propose_grow, 0) 1ab

825 leaf_to_grow, distr_norm = categorical(key, distr) 1ab

826 leaf_to_grow = jnp.where(num_growable, leaf_to_grow, 2 * split_tree.size) 1ab

827 prob_choose = distr[leaf_to_grow] / jnp.where(distr_norm, distr_norm, 1) 1ab

828 is_parent = grove.is_leaves_parent(split_tree.at[leaf_to_grow].set(1)) 1ab

829 num_prunable = jnp.count_nonzero(is_parent) 1ab

830 return leaf_to_grow, num_growable, prob_choose, num_prunable 1ab

831 

832 

833def growable_leaves( 1ab

834 split_tree: UInt[Array, ' 2**(d-1)'], affluence_tree: Bool[Array, ' 2**(d-1)'] 

835) -> Bool[Array, ' 2**(d-1)']: 

836 """ 

837 Return a mask indicating the leaf nodes that can be proposed for growth. 

838 

839 The condition is that a leaf is not at the bottom level, has available 

840 decision rules given its ancestors, and has at least 

841 `min_points_per_decision_node` points. 

842 

843 Parameters 

844 ---------- 

845 split_tree 

846 The splitting points of the tree. 

847 affluence_tree 

848 Marks leaves that can be grown. 

849 

850 Returns 

851 ------- 

852 The mask indicating the leaf nodes that can be proposed to grow. 

853 

854 Notes 

855 ----- 

856 This function needs `split_tree` and not just `affluence_tree` because 

857 `affluence_tree` can be "dirty", i.e., mark unused nodes as `True`. 

858 """ 

859 return grove.is_actual_leaf(split_tree) & affluence_tree 1ab

860 

861 

862def categorical( 1ab

863 key: Key[Array, ''], distr: Float32[Array, ' n'] 

864) -> tuple[Int32[Array, ''], Float32[Array, '']]: 

865 """ 

866 Return a random integer from an arbitrary distribution. 

867 

868 Parameters 

869 ---------- 

870 key 

871 A jax random key. 

872 distr 

873 An unnormalized probability distribution. 

874 

875 Returns 

876 ------- 

877 u : Int32[Array, ''] 

878 A random integer in the range ``[0, n)``. If all probabilities are zero, 

879 return ``n``. 

880 norm : Float32[Array, ''] 

881 The sum of `distr`. 

882 

883 Notes 

884 ----- 

885 This function uses a cumsum instead of the Gumbel trick, so it is suitable

886 only for small ranges with probabilities well above zero.

887 """ 

888 ecdf = jnp.cumsum(distr) 1ab

889 u = random.uniform(key, (), ecdf.dtype, 0, ecdf[-1]) 1ab

890 return jnp.searchsorted(ecdf, u, 'right'), ecdf[-1] 1ab
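
# --- Editor's example (not part of the original module) ---
# A sketch of the cumsum draw in `categorical`: with weights [1, 2, 1] the
# cumulative sum is [1, 3, 4], u is uniform on [0, 4), and searchsorted maps
# u to index 0, 1, or 2 with probabilities 1/4, 2/4, 1/4. The helper name is
# hypothetical.
def _example_categorical():
    idx, norm = categorical(random.key(0), jnp.array([1.0, 2.0, 1.0]))
    return idx, norm  # norm == 4.0; idx == distr.size only if all weights are 0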

891 

892 

893def choose_variable( 1ab

894 key: Key[Array, ''], 

895 var_tree: UInt[Array, ' 2**(d-1)'], 

896 split_tree: UInt[Array, ' 2**(d-1)'], 

897 max_split: UInt[Array, ' p'], 

898 leaf_index: Int32[Array, ''], 

899 blocked_vars: Int32[Array, ' k'] | None, 

900 log_s: Float32[Array, ' p'] | None, 

901) -> tuple[Int32[Array, ''], Int32[Array, '']]: 

902 """ 

903 Choose a variable to split on for a new non-terminal node. 

904 

905 Parameters 

906 ---------- 

907 key 

908 A jax random key. 

909 var_tree 

910 The variable indices of the tree. 

911 split_tree 

912 The splitting points of the tree. 

913 max_split 

914 The maximum split index for each variable. 

915 leaf_index 

916 The index of the leaf to grow. 

917 blocked_vars 

918 The indices of the variables that have no available cutpoints. If 

919 `None`, all variables are assumed unblocked. 

920 log_s 

921 The logarithm of the prior probability for choosing a variable. If 

922 `None`, use a uniform distribution. 

923 

924 Returns 

925 ------- 

926 var : Int32[Array, ''] 

927 The index of the variable to split on. 

928 num_available_var : Int32[Array, ''] 

929 The number of variables with available decision rules `var` was chosen 

930 from. 

931 """ 

932 var_to_ignore = fully_used_variables(var_tree, split_tree, max_split, leaf_index) 1ab

933 if blocked_vars is not None: 1ab

934 var_to_ignore = jnp.concatenate([var_to_ignore, blocked_vars]) 1ab

935 

936 if log_s is None: 1ab

937 return randint_exclude(key, max_split.size, var_to_ignore) 1ab

938 else: 

939 return categorical_exclude(key, log_s, var_to_ignore) 1ab

940 

941 

942def fully_used_variables( 1ab

943 var_tree: UInt[Array, ' 2**(d-1)'], 

944 split_tree: UInt[Array, ' 2**(d-1)'], 

945 max_split: UInt[Array, ' p'], 

946 leaf_index: Int32[Array, ''], 

947) -> UInt[Array, ' d-2']: 

948 """ 

949 Find variables in the ancestors of a node that have an empty split range. 

950 

951 Parameters 

952 ---------- 

953 var_tree 

954 The variable indices of the tree. 

955 split_tree 

956 The splitting points of the tree. 

957 max_split 

958 The maximum split index for each variable. 

959 leaf_index 

960 The index of the node, assumed to be valid for `var_tree`. 

961 

962 Returns 

963 ------- 

964 The indices of the variables that have an empty split range. 

965 

966 Notes 

967 ----- 

968 The number of such variables is not known in advance. Unused slots in the

969 returned array are filled with `p`. The fill values are not guaranteed to be

970 placed in any particular order, and variables may appear more than once.

971 """ 

972 var_to_ignore = ancestor_variables(var_tree, max_split, leaf_index) 1ab

973 split_range_vec = jax.vmap(split_range, in_axes=(None, None, None, None, 0)) 1ab

974 l, r = split_range_vec(var_tree, split_tree, max_split, leaf_index, var_to_ignore) 1ab

975 num_split = r - l 1ab

976 return jnp.where(num_split == 0, var_to_ignore, max_split.size) 1ab

977 # the type of var_to_ignore is already sufficient to hold max_split.size, 

978 # see ancestor_variables() 

979 

980 

981def ancestor_variables( 1ab

982 var_tree: UInt[Array, ' 2**(d-1)'], 

983 max_split: UInt[Array, ' p'], 

984 node_index: Int32[Array, ''], 

985) -> UInt[Array, ' d-2']: 

986 """ 

987 Return the list of variables in the ancestors of a node. 

988 

989 Parameters 

990 ---------- 

991 var_tree 

992 The variable indices of the tree. 

993 max_split 

994 The maximum split index for each variable. Used only to get `p`. 

995 node_index 

996 The index of the node, assumed to be valid for `var_tree`. 

997 

998 Returns 

999 ------- 

1000 The variable indices of the ancestors of the node. 

1001 

1002 Notes 

1003 ----- 

1004 The ancestors are the nodes going from the root to the parent of the node. 

1005 The number of ancestors is not known at tracing time; unused spots in the 

1006 output array are filled with `p`. 

1007 """ 

1008 max_num_ancestors = grove.tree_depth(var_tree) - 1 1ab

1009 ancestor_vars = jnp.zeros(max_num_ancestors, minimal_unsigned_dtype(max_split.size)) 1ab

1010 carry = ancestor_vars.size - 1, node_index, ancestor_vars 1ab

1011 

1012 def loop(carry, _): 1ab

1013 i, index, ancestor_vars = carry 1ab

1014 index >>= 1 1ab

1015 var = var_tree[index] 1ab

1016 var = jnp.where(index, var, max_split.size) 1ab

1017 ancestor_vars = ancestor_vars.at[i].set(var) 1ab

1018 return (i - 1, index, ancestor_vars), None 1ab

1019 

1020 (_, _, ancestor_vars), _ = lax.scan(loop, carry, None, ancestor_vars.size) 1ab

1021 return ancestor_vars 1ab
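
# --- Editor's example (not part of the original module) ---
# A sketch of the ancestor walk above, assuming depth d = 4 (so var_tree has
# 2**3 = 8 slots) and p = 5 variables. Node 5 has heap ancestors 2 and 1, so
# their variables fill the output; node 2's only ancestor is the root, and
# the unused slot is filled with p. The helper name is hypothetical.
def _example_ancestor_variables():
    var_tree = jnp.array([0, 0, 1, 2, 0, 0, 0, 0], jnp.uint8)
    max_split = jnp.full(5, 10, jnp.uint8)
    deep = ancestor_variables(var_tree, max_split, jnp.int32(5))  # -> [0, 1]
    shallow = ancestor_variables(var_tree, max_split, jnp.int32(2))  # -> [5, 0]
    return deep, shallow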

1022 

1023 

1024def split_range( 1ab

1025 var_tree: UInt[Array, ' 2**(d-1)'], 

1026 split_tree: UInt[Array, ' 2**(d-1)'], 

1027 max_split: UInt[Array, ' p'], 

1028 node_index: Int32[Array, ''], 

1029 ref_var: Int32[Array, ''], 

1030) -> tuple[Int32[Array, ''], Int32[Array, '']]: 

1031 """ 

1032 Return the range of allowed splits for a variable at a given node. 

1033 

1034 Parameters 

1035 ---------- 

1036 var_tree 

1037 The variable indices of the tree. 

1038 split_tree 

1039 The splitting points of the tree. 

1040 max_split 

1041 The maximum split index for each variable. 

1042 node_index 

1043 The index of the node, assumed to be valid for `var_tree`. 

1044 ref_var 

1045 The variable for which to measure the split range. 

1046 

1047 Returns 

1048 ------- 

1049 The range of allowed splits as [l, r). If `ref_var` is out of bounds, l=r=1. 

1050 """ 

1051 max_num_ancestors = grove.tree_depth(var_tree) - 1 1ab

1052 initial_r = 1 + max_split.at[ref_var].get(mode='fill', fill_value=0).astype( 1ab

1053 jnp.int32 

1054 ) 

1055 carry = jnp.int32(0), initial_r, node_index 1ab

1056 

1057 def loop(carry, _): 1ab

1058 l, r, index = carry 1ab

1059 right_child = (index & 1).astype(bool) 1ab

1060 index >>= 1 1ab

1061 split = split_tree[index] 1ab

1062 cond = (var_tree[index] == ref_var) & index.astype(bool) 1ab

1063 l = jnp.where(cond & right_child, jnp.maximum(l, split), l) 1ab

1064 r = jnp.where(cond & ~right_child, jnp.minimum(r, split), r) 1ab

1065 return (l, r, index), None 1ab

1066 

1067 (l, r, _), _ = lax.scan(loop, carry, None, max_num_ancestors) 1ab

1068 return l + 1, r 1ab
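
# --- Editor's example (not part of the original module) ---
# A sketch of `split_range`, assuming depth d = 4 trees and max_split = [10]:
# the root (node 1) splits on variable 0 at cutpoint 5, so variable 0 keeps
# cutpoints [1, 5) in the left child (node 2) and [6, 11) in the right child
# (node 3). The helper name is hypothetical.
def _example_split_range():
    var_tree = jnp.zeros(8, jnp.uint8)  # all decisions on variable 0
    split_tree = jnp.zeros(8, jnp.uint8).at[1].set(5)  # root cutpoint = 5
    max_split = jnp.array([10], jnp.uint8)
    left = split_range(var_tree, split_tree, max_split, jnp.int32(2), jnp.int32(0))
    right = split_range(var_tree, split_tree, max_split, jnp.int32(3), jnp.int32(0))
    return left, right  # -> (1, 5) and (6, 11)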

1069 

1070 

1071def randint_exclude( 1ab

1072 key: Key[Array, ''], sup: int | Integer[Array, ''], exclude: Integer[Array, ' n'] 

1073) -> tuple[Int32[Array, ''], Int32[Array, '']]: 

1074 """ 

1075 Return a random integer in a range, excluding some values. 

1076 

1077 Parameters 

1078 ---------- 

1079 key 

1080 A jax random key. 

1081 sup 

1082 The exclusive upper bound of the range. 

1083 exclude 

1084 The values to exclude from the range. Values greater than or equal to 

1085 `sup` are ignored. Values can appear more than once. 

1086 

1087 Returns 

1088 ------- 

1089 u : Int32[Array, ''] 

1090 A random integer `u` in the range ``[0, sup)`` such that ``u not in 

1091 exclude``. 

1092 num_allowed : Int32[Array, ''] 

1093 The number of integers in the range that were not excluded. 

1094 

1095 Notes 

1096 ----- 

1097 If all values in the range are excluded, return `sup`. 

1098 """ 

1099 exclude, num_allowed = _process_exclude(sup, exclude) 1ab

1100 u = random.randint(key, (), 0, num_allowed) 1ab

1101 

1102 def loop(u, i_excluded): 1ab

1103 return jnp.where(i_excluded <= u, u + 1, u), None 1ab

1104 

1105 u, _ = lax.scan(loop, u, exclude) 1ab

1106 return u, num_allowed 1ab
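
# --- Editor's example (not part of the original module) ---
# A sketch of the shifting scan in `randint_exclude`: with sup = 5 and
# exclude = [1, 3], a draw u in {0, 1, 2} is shifted past each excluded value
# that is <= u, mapping u to {0, 2, 4}. The helper name is hypothetical.
def _example_randint_exclude():
    u, num_allowed = randint_exclude(random.key(0), 5, jnp.array([1, 3]))
    return u, num_allowed  # u in {0, 2, 4}; num_allowed == 3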

1107 

1108 

1109def _process_exclude(sup, exclude): 1ab

1110 exclude = jnp.unique(exclude, size=exclude.size, fill_value=sup) 1ab

1111 num_allowed = sup - jnp.count_nonzero(exclude < sup) 1ab

1112 return exclude, num_allowed 1ab

1113 

1114 

1115def categorical_exclude( 1ab

1116 key: Key[Array, ''], logits: Float32[Array, ' k'], exclude: Integer[Array, ' n'] 

1117) -> tuple[Int32[Array, ''], Int32[Array, '']]: 

1118 """ 

1119 Draw from a categorical distribution, excluding a set of values. 

1120 

1121 Parameters 

1122 ---------- 

1123 key 

1124 A jax random key. 

1125 logits 

1126 The unnormalized log-probabilities of each category. 

1127 exclude 

1128 The values to exclude from the range [0, k). Values greater than or 

1129 equal to `logits.size` are ignored. Values can appear more than once. 

1130 

1131 Returns 

1132 ------- 

1133 u : Int32[Array, ''] 

1134 A random integer in the range ``[0, k)`` such that ``u not in exclude``. 

1135 num_allowed : Int32[Array, ''] 

1136 The number of integers in the range that were not excluded. 

1137 

1138 Notes 

1139 ----- 

1140 If all values in the range are excluded, the result is unspecified. 

1141 """ 

1142 exclude, num_allowed = _process_exclude(logits.size, exclude) 1ab

1143 kinda_neg_inf = jnp.finfo(logits.dtype).min 1ab

1144 logits = logits.at[exclude].set(kinda_neg_inf) 1ab

1145 u = random.categorical(key, logits) 1ab

1146 return u, num_allowed 1ab

1147 

1148 

1149def choose_split( 1ab

1150 key: Key[Array, ''], 

1151 var: Int32[Array, ''], 

1152 var_tree: UInt[Array, ' 2**(d-1)'], 

1153 split_tree: UInt[Array, ' 2**(d-1)'], 

1154 max_split: UInt[Array, ' p'], 

1155 leaf_index: Int32[Array, ''], 

1156) -> tuple[Int32[Array, ''], Int32[Array, ''], Int32[Array, '']]: 

1157 """ 

1158 Choose a split point for a new non-terminal node. 

1159 

1160 Parameters 

1161 ---------- 

1162 key 

1163 A jax random key. 

1164 var 

1165 The variable to split on. 

1166 var_tree 

1167 The splitting axes of the tree. Does not need to already contain `var` 

1168 at `leaf_index`. 

1169 split_tree 

1170 The splitting points of the tree. 

1171 max_split 

1172 The maximum split index for each variable. 

1173 leaf_index 

1174 The index of the leaf to grow. 

1175 

1176 Returns 

1177 ------- 

1178 split : Int32[Array, ''] 

1179 The cutpoint. 

1180 l : Int32[Array, ''] 

1181 r : Int32[Array, ''] 

1182 The integer range `split` was drawn from is [l, r). 

1183 

1184 Notes 

1185 ----- 

1186 If `var` is out of bounds, or if the available split range on that variable

1187 is empty, the returned `split` is 0.

1188 """ 

1189 l, r = split_range(var_tree, split_tree, max_split, leaf_index, var) 1ab

1190 return jnp.where(l < r, random.randint(key, (), l, r), 0), l, r 1ab

1191 

1192 

1193def compute_partial_ratio( 1ab

1194 prob_choose: Float32[Array, ''], 

1195 num_prunable: Int32[Array, ''], 

1196 p_nonterminal: Float32[Array, ' 2**d'], 

1197 leaf_to_grow: Int32[Array, ''], 

1198) -> Float32[Array, '']: 

1199 """ 

1200 Compute the product of the transition and prior ratios of a grow move. 

1201 

1202 Parameters 

1203 ---------- 

1204 prob_choose 

1205 The probability that the leaf had to be chosen amongst the growable 

1206 leaves. 

1207 num_prunable 

1208 The number of leaf parents that could be pruned, after converting the 

1209 leaf to be grown to a non-terminal node. 

1210 p_nonterminal 

1211 The a priori probability of each node being nonterminal conditional on 

1212 its ancestors. 

1213 leaf_to_grow 

1214 The index of the leaf to grow. 

1215 

1216 Returns 

1217 ------- 

1218 The partial transition ratio times the prior ratio. 

1219 

1220 Notes 

1221 ----- 

1222 The transition ratio is P(new tree => old tree) / P(old tree => new tree). 

1223 The "partial" transition ratio returned is missing the factor P(propose 

1224 prune) in the numerator. The prior ratio is P(new tree) / P(old tree). The 

1225 "partial" prior ratio is missing the factor P(children are leaves). 

1226 """ 

1227 # the two ratios also contain factors num_available_split * 

1228 # num_available_var * s[var], but they cancel out 

1229 

1230 # p_prune and 1 - p_nonterminal[child] * I(is the child growable) can't be 

1231 # computed here because they need the count trees, which are computed in the 

1232 # acceptance phase 

1233 

1234 prune_allowed = leaf_to_grow != 1 1ab

1235 # prune allowed <---> the initial tree is not a root 

1236 # leaf to grow is root --> the tree can only be a root 

1237 # tree is a root --> the only leaf I can grow is root 

1238 p_grow = jnp.where(prune_allowed, 0.5, 1) 1ab

1239 inv_trans_ratio = p_grow * prob_choose * num_prunable 1ab

1240 

1241 # .at.get because if leaf_to_grow is out of bounds (move not allowed), this 

1242 # would produce a 0 and then an inf when `complete_ratio` takes the log 

1243 pnt = p_nonterminal.at[leaf_to_grow].get(mode='fill', fill_value=0.5) 1ab

1244 tree_ratio = pnt / (1 - pnt) 1ab

1245 

1246 return tree_ratio / jnp.where(inv_trans_ratio, inv_trans_ratio, 1) 1ab
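
# --- Editor's example (not part of the original module) ---
# A worked instance of the partial ratio: growing the root of a single-node
# tree. The leaf is chosen with probability 1, the new tree has one prunable
# node, and prune was not a valid proposal on the old tree (p_grow = 1), so
# only the prior factor pnt / (1 - pnt) survives. The helper name and the
# numbers are hypothetical.
def _example_partial_ratio_root_grow():
    p_nonterminal = jnp.array([0.0, 0.95, 0.5, 0.5])  # indexed by heap node
    ratio = compute_partial_ratio(
        prob_choose=jnp.float32(1),
        num_prunable=jnp.int32(1),
        p_nonterminal=p_nonterminal,
        leaf_to_grow=jnp.int32(1),
    )
    return ratio  # -> 0.95 / 0.05 = 19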

1247 

1248 

1249class PruneMoves(Module): 1ab

1250 """ 

1251 Represent a proposed prune move for each tree. 

1252 

1253 Parameters 

1254 ---------- 

1255 allowed 

1256 Whether the move is possible. 

1257 node 

1258 The index of the node to prune. ``2 ** d`` if no node can be pruned. 

1259 partial_ratio 

1260 A factor of the Metropolis-Hastings ratio of the move. It lacks the 

1261 likelihood ratio, the probability of proposing the prune move, and the 

1262 prior probability that the children of the node to prune are leaves. 

1263 This ratio is inverted, and is meant to be inverted back in 

1264 `accept_moves_and_sample_leaves`.

1265 """ 

1266 

1267 allowed: Bool[Array, ' num_trees'] 1ab

1268 node: UInt[Array, ' num_trees'] 1ab

1269 partial_ratio: Float32[Array, ' num_trees'] 1ab

1270 affluence_tree: Bool[Array, 'num_trees 2**(d-1)'] 1ab

1271 

1272 

1273@partial(vmap_nodoc, in_axes=(0, 0, 0, None, None)) 1ab

1274def propose_prune_moves( 1ab

1275 key: Key[Array, ''], 

1276 split_tree: UInt[Array, ' 2**(d-1)'], 

1277 affluence_tree: Bool[Array, ' 2**(d-1)'], 

1278 p_nonterminal: Float32[Array, ' 2**d'], 

1279 p_propose_grow: Float32[Array, ' 2**(d-1)'], 

1280) -> PruneMoves: 

1281 """ 

1282 Tree structure prune move proposal of BART MCMC. 

1283 

1284 Parameters 

1285 ---------- 

1286 key 

1287 A jax random key. 

1288 split_tree 

1289 The splitting points of the tree. 

1290 affluence_tree 

1291 Whether each leaf can be grown. 

1292 p_nonterminal 

1293 The a priori probability of a node to be nonterminal conditional on 

1294 the ancestors, including at the maximum depth where it should be zero. 

1295 p_propose_grow 

1296 The unnormalized probability of choosing a leaf to grow. 

1297 

1298 Returns 

1299 ------- 

1300 An object representing the proposed moves. 

1301 """ 

1302 node_to_prune, num_prunable, prob_choose, affluence_tree = choose_leaf_parent( 1ab

1303 key, split_tree, affluence_tree, p_propose_grow 

1304 ) 

1305 

1306 ratio = compute_partial_ratio( 1ab

1307 prob_choose, num_prunable, p_nonterminal, node_to_prune 

1308 ) 

1309 

1310 return PruneMoves( 1ab

1311 allowed=split_tree[1].astype(bool), # allowed iff the tree is not a root 

1312 node=node_to_prune, 

1313 partial_ratio=ratio, 

1314 affluence_tree=affluence_tree, 

1315 ) 

1316 

1317 

1318def choose_leaf_parent( 1ab

1319 key: Key[Array, ''], 

1320 split_tree: UInt[Array, ' 2**(d-1)'], 

1321 affluence_tree: Bool[Array, ' 2**(d-1)'], 

1322 p_propose_grow: Float32[Array, ' 2**(d-1)'], 

1323) -> tuple[ 

1324 Int32[Array, ''], 

1325 Int32[Array, ''], 

1326 Float32[Array, ''], 

1327 Bool[Array, 'num_trees 2**(d-1)'], 

1328]: 

1329 """ 

1330 Pick a non-terminal node with leaf children to prune in a tree. 

1331 

1332 Parameters 

1333 ---------- 

1334 key 

1335 A jax random key. 

1336 split_tree 

1337 The splitting points of the tree. 

1338 affluence_tree 

1339 Whether a leaf has enough points to be grown. 

1340 p_propose_grow 

1341 The unnormalized probability of choosing a leaf to grow. 

1342 

1343 Returns 

1344 ------- 

1345 node_to_prune : Int32[Array, ''] 

1346 The index of the node to prune. If ``num_prunable == 0``, return 

1347 ``2 ** d``. 

1348 num_prunable : Int32[Array, ''] 

1349 The number of leaf parents that could be pruned. 

1350 prob_choose : Float32[Array, ''] 

1351 The (normalized) probability that `choose_leaf` would choose

1352 `node_to_prune` as leaf to grow, if passed the tree where 

1353 `node_to_prune` had been pruned. 

1354 affluence_tree : Bool[Array, 'num_trees 2**(d-1)'] 

1355 A partially updated `affluence_tree`, marking the node to prune as 

1356 growable. 

1357 """ 

1358 # sample a node to prune 

1359 is_prunable = grove.is_leaves_parent(split_tree) 1ab

1360 num_prunable = jnp.count_nonzero(is_prunable) 1ab

1361 node_to_prune = randint_masked(key, is_prunable) 1ab

1362 node_to_prune = jnp.where(num_prunable, node_to_prune, 2 * split_tree.size) 1ab

1363 

1364 # compute stuff for reverse move 

1365 split_tree = split_tree.at[node_to_prune].set(0) 1ab

1366 affluence_tree = affluence_tree.at[node_to_prune].set(True) 1ab

1367 is_growable_leaf = growable_leaves(split_tree, affluence_tree) 1ab

1368 distr_norm = jnp.sum(p_propose_grow, where=is_growable_leaf) 1ab

1369 prob_choose = p_propose_grow.at[node_to_prune].get(mode='fill', fill_value=0) 1ab

1370 prob_choose = prob_choose / jnp.where(distr_norm, distr_norm, 1) 1ab

1371 

1372 return node_to_prune, num_prunable, prob_choose, affluence_tree 1ab

1373 

1374 

1375def randint_masked(key: Key[Array, ''], mask: Bool[Array, ' n']) -> Int32[Array, '']: 1ab

1376 """ 

1377 Return a random integer in a range, including only some values. 

1378 

1379 Parameters 

1380 ---------- 

1381 key 

1382 A jax random key. 

1383 mask 

1384 The mask indicating the allowed values. 

1385 

1386 Returns 

1387 ------- 

1388 A random integer in the range ``[0, n)`` such that ``mask[u] == True``. 

1389 

1390 Notes 

1391 ----- 

1392 If all values in the mask are `False`, return `n`. 

1393 """ 

1394 ecdf = jnp.cumsum(mask) 1ab

1395 u = random.randint(key, (), 0, ecdf[-1]) 1ab

1396 return jnp.searchsorted(ecdf, u, 'right') 1ab
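
# --- Editor's example (not part of the original module) ---
# A sketch of `randint_masked`: with mask [False, True, False, True], the
# cumsum is [0, 1, 1, 2], u is drawn from {0, 1}, and searchsorted maps u to
# index 1 or 3, each with probability 1/2. The helper name is hypothetical.
def _example_randint_masked():
    mask = jnp.array([False, True, False, True])
    return randint_masked(random.key(0), mask)  # -> 1 or 3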

1397 

1398 

1399def accept_moves_and_sample_leaves( 1ab

1400 key: Key[Array, ''], bart: State, moves: Moves 

1401) -> State: 

1402 """ 

1403 Accept or reject the proposed moves and sample the new leaf values. 

1404 

1405 Parameters 

1406 ---------- 

1407 key 

1408 A jax random key. 

1409 bart 

1410 A valid BART mcmc state. 

1411 moves 

1412 The proposed moves, see `propose_moves`. 

1413 

1414 Returns 

1415 ------- 

1416 A new (valid) BART mcmc state. 

1417 """ 

1418 pso = accept_moves_parallel_stage(key, bart, moves) 1ab

1419 bart, moves = accept_moves_sequential_stage(pso) 1ab

1420 return accept_moves_final_stage(bart, moves) 1ab

1421 

1422 

1423class Counts(Module): 1ab

1424 """ 

1425 Number of datapoints in the nodes involved in proposed moves for each tree. 

1426 

1427 Parameters 

1428 ---------- 

1429 left 

1430 Number of datapoints in the left child. 

1431 right 

1432 Number of datapoints in the right child. 

1433 total 

1434 Number of datapoints in the parent (``= left + right``). 

1435 """ 

1436 

1437 left: UInt[Array, ' num_trees'] 1ab

1438 right: UInt[Array, ' num_trees'] 1ab

1439 total: UInt[Array, ' num_trees'] 1ab

1440 

1441 

1442class Precs(Module): 1ab

1443 """ 

1444 Likelihood precision scale in the nodes involved in proposed moves for each tree. 

1445 

1446 The "likelihood precision scale" of a tree node is the sum of the inverse 

1447 squared error scales of the datapoints selected by the node. 

1448 

1449 Parameters 

1450 ---------- 

1451 left 

1452 Likelihood precision scale in the left child. 

1453 right 

1454 Likelihood precision scale in the right child. 

1455 total 

1456 Likelihood precision scale in the parent (``= left + right``). 

1457 """ 

1458 

1459 left: Float32[Array, ' num_trees'] 1ab

1460 right: Float32[Array, ' num_trees'] 1ab

1461 total: Float32[Array, ' num_trees'] 1ab

1462 

1463 

1464class PreLkV(Module): 1ab

1465 """ 

1466 Non-sequential terms of the likelihood ratio for each tree. 

1467 

1468 These terms can be computed in parallel across trees. 

1469 

1470 Parameters 

1471 ---------- 

1472 sigma2_left 

1473 The noise variance in the left child of the leaves grown or pruned by 

1474 the moves. 

1475 sigma2_right 

1476 The noise variance in the right child of the leaves grown or pruned by 

1477 the moves. 

1478 sigma2_total 

1479 The noise variance in the total of the leaves grown or pruned by the 

1480 moves. 

1481 sqrt_term 

1482 The **logarithm** of the square root term of the likelihood ratio. 

1483 """ 

1484 

1485 sigma2_left: Float32[Array, ' num_trees'] 1ab

1486 sigma2_right: Float32[Array, ' num_trees'] 1ab

1487 sigma2_total: Float32[Array, ' num_trees'] 1ab

1488 sqrt_term: Float32[Array, ' num_trees'] 1ab

1489 

1490 

1491class PreLk(Module): 1ab

1492 """ 

1493 Non-sequential terms of the likelihood ratio shared by all trees. 

1494 

1495 Parameters 

1496 ---------- 

1497 exp_factor 

1498 The factor to multiply the likelihood ratio by, shared by all trees. 

1499 """ 

1500 

1501 exp_factor: Float32[Array, ''] 1ab

1502 

1503 

1504class PreLf(Module): 1ab

1505 """ 

1506 Pre-computed terms used to sample leaves from their posterior. 

1507 

1508 These terms can be computed in parallel across trees. 

1509 

1510 Parameters 

1511 ---------- 

1512 mean_factor 

1513 The factor to be multiplied by the sum of the scaled residuals to 

1514 obtain the posterior mean. 

1515 centered_leaves 

1516 The mean-zero normal values to be added to the posterior mean to 

1517 obtain the posterior leaf samples. 

1518 """ 

1519 

1520 mean_factor: Float32[Array, 'num_trees 2**d'] 1ab

1521 centered_leaves: Float32[Array, 'num_trees 2**d'] 1ab

1522 

1523 

1524class ParallelStageOut(Module): 1ab

1525 """ 

1526 The output of `accept_moves_parallel_stage`. 

1527 

1528 Parameters 

1529 ---------- 

1530 bart 

1531 A partially updated BART mcmc state. 

1532 moves 

1533 The proposed moves, with `partial_ratio` set to `None` and 

1534 `log_trans_prior_ratio` set to its final value. 

1535 prec_trees 

1536 The likelihood precision scale in each potential or actual leaf node. If 

1537 there is no precision scale, this is the number of points in each leaf. 

1538 move_counts 

1539 The counts of the number of points in the nodes modified by the

1540 moves. If `bart.min_points_per_leaf` is not set and 

1541 `bart.prec_scale` is set, they are not computed. 

1542 move_precs 

1543 The likelihood precision scale in each node modified by the moves. If 

1544 `bart.prec_scale` is not set, this is set to `move_counts`. 

1545 prelkv 

1546 prelk 

1547 prelf 

1548 Objects with pre-computed terms of the likelihood ratios and leaf 

1549 samples. 

1550 """ 

1551 

1552 bart: State 1ab

1553 moves: Moves 1ab

1554 prec_trees: Float32[Array, 'num_trees 2**d'] | Int32[Array, 'num_trees 2**d'] 1ab

1555 move_precs: Precs | Counts 1ab

1556 prelkv: PreLkV 1ab

1557 prelk: PreLk 1ab

1558 prelf: PreLf 1ab

1559 

1560 

1561def accept_moves_parallel_stage( 1ab

1562 key: Key[Array, ''], bart: State, moves: Moves 

1563) -> ParallelStageOut: 

1564 """ 

1565 Pre-compute quantities used to accept moves, in parallel across trees. 

1566 

1567 Parameters 

1568 ---------- 

1569 key

1570 A jax random key.

1571 bart

1572 A BART mcmc state.

1573 moves

1574 The proposed moves, see `propose_moves`.

1575 

1576 Returns 

1577 ------- 

1578 An object collecting all the quantities that could be computed in parallel.

1579 """ 

1580 # where the move is grow, modify the state like the move was accepted 

1581 bart = replace( 1ab

1582 bart, 

1583 forest=replace( 

1584 bart.forest, 

1585 var_tree=moves.var_tree, 

1586 leaf_indices=apply_grow_to_indices(moves, bart.forest.leaf_indices, bart.X), 

1587 leaf_tree=adapt_leaf_trees_to_grow_indices(bart.forest.leaf_tree, moves), 

1588 ), 

1589 ) 

1590 

1591 # count number of datapoints per leaf 

1592 if ( 1592 ↛ 1602 (line 1592 didn't jump to line 1602 because the condition on line 1592 was always true) 

1593 bart.forest.min_points_per_decision_node is not None 

1594 or bart.forest.min_points_per_leaf is not None 

1595 or bart.prec_scale is None 

1596 ): 

1597 count_trees, move_counts = compute_count_trees( 1ab

1598 bart.forest.leaf_indices, moves, bart.forest.count_batch_size 

1599 ) 

1600 

1601 # mark which leaves & potential leaves have enough points to be grown 

1602 if bart.forest.min_points_per_decision_node is not None: 1ab

1603 count_half_trees = count_trees[:, : bart.forest.var_tree.shape[1]] 1ab

1604 moves = replace( 1ab

1605 moves, 

1606 affluence_tree=moves.affluence_tree 

1607 & (count_half_trees >= bart.forest.min_points_per_decision_node), 

1608 ) 

1609 

1610 # copy updated affluence_tree to state 

1611 bart = tree_at(lambda bart: bart.forest.affluence_tree, bart, moves.affluence_tree) 1ab

1612 

1613 # veto grow move if new leaves don't have enough datapoints 

1614 if bart.forest.min_points_per_leaf is not None: 1ab

1615 moves = replace( 1ab

1616 moves, 

1617 allowed=moves.allowed 

1618 & (move_counts.left >= bart.forest.min_points_per_leaf) 

1619 & (move_counts.right >= bart.forest.min_points_per_leaf), 

1620 ) 

1621 

1622 # count number of datapoints per leaf, weighted by error precision scale 

1623 if bart.prec_scale is None: 1ab

1624 prec_trees = count_trees 1ab

1625 move_precs = move_counts 1ab

1626 else: 

1627 prec_trees, move_precs = compute_prec_trees( 1ab

1628 bart.prec_scale, 

1629 bart.forest.leaf_indices, 

1630 moves, 

1631 bart.forest.count_batch_size, 

1632 ) 

1633 assert move_precs is not None 1ab

1634 

1635 # compute some missing information about moves 

1636 moves = complete_ratio(moves, bart.forest.p_nonterminal) 1ab

1637 save_ratios = bart.forest.log_likelihood is not None 1ab

1638 bart = replace( 1ab

1639 bart, 

1640 forest=replace( 

1641 bart.forest, 

1642 grow_prop_count=jnp.sum(moves.grow), 

1643 prune_prop_count=jnp.sum(moves.allowed & ~moves.grow), 

1644 log_trans_prior=moves.log_trans_prior_ratio if save_ratios else None, 

1645 ), 

1646 ) 

1647 

1648 # pre-compute some likelihood ratio & posterior terms 

1649 assert bart.sigma2 is not None # `step` shall temporarily set it to 1 1ab

1650 prelkv, prelk = precompute_likelihood_terms( 1ab

1651 bart.sigma2, bart.forest.sigma_mu2, move_precs 

1652 ) 

1653 prelf = precompute_leaf_terms(key, prec_trees, bart.sigma2, bart.forest.sigma_mu2) 1ab

1654 

1655 return ParallelStageOut( 1ab

1656 bart=bart, 

1657 moves=moves, 

1658 prec_trees=prec_trees, 

1659 move_precs=move_precs, 

1660 prelkv=prelkv, 

1661 prelk=prelk, 

1662 prelf=prelf, 

1663 ) 

1664 

1665 

1666@partial(vmap_nodoc, in_axes=(0, 0, None)) 1ab

1667def apply_grow_to_indices( 1ab

1668 moves: Moves, leaf_indices: UInt[Array, 'num_trees n'], X: UInt[Array, 'p n'] 

1669) -> UInt[Array, 'num_trees n']: 

1670 """ 

1671 Update the leaf indices to apply a grow move. 

1672 

1673 Parameters 

1674 ---------- 

1675 moves 

1676 The proposed moves, see `propose_moves`. 

1677 leaf_indices 

1678 The index of the leaf each datapoint falls into. 

1679 X 

1680 The predictors matrix. 

1681 

1682 Returns 

1683 ------- 

1684 The updated leaf indices. 

1685 """ 

1686 left_child = moves.node.astype(leaf_indices.dtype) << 1 1ab

1687 go_right = X[moves.grow_var, :] >= moves.grow_split 1ab

1688 tree_size = jnp.array(2 * moves.var_tree.size) 1ab

1689 node_to_update = jnp.where(moves.grow, moves.node, tree_size) 1ab

1690 return jnp.where( 1ab

1691 leaf_indices == node_to_update, left_child + go_right, leaf_indices 

1692 ) 

1693 

1694 
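
The update above exploits the heap layout of the tree arrays: the children of node i live at indices 2*i and 2*i + 1, so the destination leaf is simply `left_child + go_right`. A minimal standalone sketch of the same trick, with hypothetical toy values:

import jax.numpy as jnp

# Toy data: grow leaf 3; its children are 6 and 7 in heap order.
node = 3
leaf_indices = jnp.array([3, 3, 5, 3], jnp.uint32)  # leaf of each datapoint
x = jnp.array([0, 7, 2, 4])                         # chosen predictor values
split = 4                                           # proposed cutpoint
go_right = x >= split                               # False -> 2*i, True -> 2*i + 1
left_child = jnp.array(node << 1, jnp.uint32)
new_indices = jnp.where(leaf_indices == node, left_child + go_right, leaf_indices)
# new_indices == [6 7 5 7]
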

1695def compute_count_trees( 1ab

1696 leaf_indices: UInt[Array, 'num_trees n'], moves: Moves, batch_size: int | None 

1697) -> tuple[Int32[Array, 'num_trees 2**d'], Counts]: 

1698 """ 

1699 Count the number of datapoints in each leaf. 

1700 

1701 Parameters 

1702 ---------- 

1703 leaf_indices 

1704 The index of the leaf each datapoint falls into, according to the 

1705 deeper version of the tree (post-GROW, pre-PRUNE). 

1706 moves 

1707 The proposed moves, see `propose_moves`. 

1708 batch_size 

1709 The data batch size to use for the summation. 

1710 

1711 Returns 

1712 ------- 

1713 count_trees : Int32[Array, 'num_trees 2**d'] 

1714 The number of points in each potential or actual leaf node. 

1715 counts : Counts 

1716 The counts of the number of points in the leaves grown or pruned by the 

1717 moves. 

1718 """ 

1719 num_trees, tree_size = moves.var_tree.shape 1ab

1720 tree_size *= 2 1ab

1721 tree_indices = jnp.arange(num_trees) 1ab

1722 

1723 count_trees = count_datapoints_per_leaf(leaf_indices, tree_size, batch_size) 1ab

1724 

1725 # count datapoints in nodes modified by move 

1726 left = count_trees[tree_indices, moves.left] 1ab

1727 right = count_trees[tree_indices, moves.right] 1ab

1728 counts = Counts(left=left, right=right, total=left + right) 1ab

1729 

1730 # write count into non-leaf node 

1731 count_trees = count_trees.at[tree_indices, moves.node].set(counts.total) 1ab

1732 

1733 return count_trees, counts 1ab

1734 

1735 

1736def count_datapoints_per_leaf( 1ab

1737 leaf_indices: UInt[Array, 'num_trees n'], tree_size: int, batch_size: int | None 

1738) -> Int32[Array, 'num_trees 2**(d-1)']: 

1739 """ 

1740 Count the number of datapoints in each leaf. 

1741 

1742 Parameters 

1743 ---------- 

1744 leaf_indices 

1745 The index of the leaf each datapoint falls into. 

1746 tree_size 

1747 The size of the leaf tree array (2 ** d). 

1748 batch_size 

1749 The data batch size to use for the summation. 

1750 

1751 Returns 

1752 ------- 

1753 The number of points in each leaf node. 

1754 """ 

1755 if batch_size is None: 1ab

1756 return _count_scan(leaf_indices, tree_size) 1ab

1757 else: 

1758 return _count_vec(leaf_indices, tree_size, batch_size) 1ab

1759 

1760 

1761def _count_scan( 1ab

1762 leaf_indices: UInt[Array, 'num_trees n'], tree_size: int 

1763) -> Int32[Array, 'num_trees {tree_size}']: 

1764 def loop(_, leaf_indices): 1ab

1765 return None, _aggregate_scatter(1, leaf_indices, tree_size, jnp.uint32) 1ab

1766 

1767 _, count_trees = lax.scan(loop, None, leaf_indices) 1ab

1768 return count_trees 1ab

1769 

1770 

1771def _aggregate_scatter( 1ab

1772 values: Shaped[Array, '*'], 

1773 indices: Integer[Array, '*'], 

1774 size: int, 

1775 dtype: jnp.dtype, 

1776) -> Shaped[Array, ' {size}']: 

1777 return jnp.zeros(size, dtype).at[indices].add(values) 1ab

1778 

1779 
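
`_aggregate_scatter` is a plain scatter-add, and the per-leaf datapoint count used above is the special case `values=1`. A minimal sketch with made-up indices:

import jax.numpy as jnp

# Each occurrence of i in `indices` adds one unit to bucket i of the output.
indices = jnp.array([0, 2, 2, 3, 0, 2])
counts = jnp.zeros(8, jnp.uint32).at[indices].add(1)
# counts == [2 0 3 1 0 0 0 0]
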

1780def _count_vec( 1ab

1781 leaf_indices: UInt[Array, 'num_trees n'], tree_size: int, batch_size: int 

1782) -> Int32[Array, 'num_trees 2**(d-1)']: 

1783 return _aggregate_batched_alltrees( 1ab

1784 1, leaf_indices, tree_size, jnp.uint32, batch_size 

1785 ) 

1786 # uint16 is super-slow on gpu, don't use it even if n < 2^16 

1787 

1788 

1789def _aggregate_batched_alltrees( 1ab

1790 values: Shaped[Array, '*'], 

1791 indices: UInt[Array, 'num_trees n'], 

1792 size: int, 

1793 dtype: jnp.dtype, 

1794 batch_size: int, 

1795) -> Shaped[Array, 'num_trees {size}']: 

1796 num_trees, n = indices.shape 1ab

1797 tree_indices = jnp.arange(num_trees) 1ab

1798 nbatches = n // batch_size + bool(n % batch_size) 1ab

1799 batch_indices = jnp.arange(n) % nbatches 1ab

1800 return ( 1ab

1801 jnp.zeros((num_trees, size, nbatches), dtype) 

1802 .at[tree_indices[:, None], indices, batch_indices] 

1803 .add(values) 

1804 .sum(axis=2) 

1805 ) 

1806 

1807 
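
The batched variant assigns datapoints round-robin to `nbatches` accumulator columns and sums the columns at the end. The result matches the unbatched scatter-add, but the shorter per-column summation chains expose more parallelism and accumulate less floating-point error. A one-tree sketch with hypothetical sizes:

import jax.numpy as jnp

values = jnp.arange(10.0)
indices = jnp.array([0, 1, 0, 1, 2, 2, 0, 3, 3, 3])
size, batch_size = 4, 4
n = indices.size
nbatches = n // batch_size + bool(n % batch_size)  # ceil(10 / 4) == 3
batch_indices = jnp.arange(n) % nbatches           # round-robin assignment
out = (
    jnp.zeros((size, nbatches), jnp.float32)
    .at[indices, batch_indices]
    .add(values)
    .sum(axis=1)
)
# out matches jnp.zeros(size, jnp.float32).at[indices].add(values)
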

1808def compute_prec_trees( 1ab

1809 prec_scale: Float32[Array, ' n'], 

1810 leaf_indices: UInt[Array, 'num_trees n'], 

1811 moves: Moves, 

1812 batch_size: int | None, 

1813) -> tuple[Float32[Array, 'num_trees 2**d'], Precs]: 

1814 """ 

1815 Compute the likelihood precision scale in each leaf. 

1816 

1817 Parameters 

1818 ---------- 

1819 prec_scale 

1820 The scale of the precision of the error on each datapoint. 

1821 leaf_indices 

1822 The index of the leaf each datapoint falls into, with the deeper version 

1823 of the tree (post-GROW, pre-PRUNE). 

1824 moves 

1825 The proposed moves, see `propose_moves`. 

1826 batch_size 

1827 The data batch size to use for the summation. 

1828 

1829 Returns 

1830 ------- 

1831 prec_trees : Float32[Array, 'num_trees 2**d'] 

1832 The likelihood precision scale in each potential or actual leaf node. 

1833 precs : Precs 

1834 The likelihood precision scale in the nodes involved in the moves. 

1835 """ 

1836 num_trees, tree_size = moves.var_tree.shape 1ab

1837 tree_size *= 2 1ab

1838 tree_indices = jnp.arange(num_trees) 1ab

1839 

1840 prec_trees = prec_per_leaf(prec_scale, leaf_indices, tree_size, batch_size) 1ab

1841 

1842 # prec datapoints in nodes modified by move 

1843 left = prec_trees[tree_indices, moves.left] 1ab

1844 right = prec_trees[tree_indices, moves.right] 1ab

1845 precs = Precs(left=left, right=right, total=left + right) 1ab

1846 

1847 # write prec into non-leaf node 

1848 prec_trees = prec_trees.at[tree_indices, moves.node].set(precs.total) 1ab

1849 

1850 return prec_trees, precs 1ab

1851 

1852 

1853def prec_per_leaf( 1ab

1854 prec_scale: Float32[Array, ' n'], 

1855 leaf_indices: UInt[Array, 'num_trees n'], 

1856 tree_size: int, 

1857 batch_size: int | None, 

1858) -> Float32[Array, 'num_trees {tree_size}']: 

1859 """ 

1860 Compute the likelihood precision scale in each leaf. 

1861 

1862 Parameters 

1863 ---------- 

1864 prec_scale 

1865 The scale of the precision of the error on each datapoint. 

1866 leaf_indices 

1867 The index of the leaf each datapoint falls into. 

1868 tree_size 

1869 The size of the leaf tree array (2 ** d). 

1870 batch_size 

1871 The data batch size to use for the summation. 

1872 

1873 Returns 

1874 ------- 

1875 The likelihood precision scale in each leaf node. 

1876 """ 

1877 if batch_size is None: 1877 ↛ 1880 (line 1877 didn't jump to line 1880 because the condition on line 1877 was always true) 1ab

1878 return _prec_scan(prec_scale, leaf_indices, tree_size) 1ab

1879 else: 

1880 return _prec_vec(prec_scale, leaf_indices, tree_size, batch_size) 

1881 

1882 

1883def _prec_scan( 1ab

1884 prec_scale: Float32[Array, ' n'], 

1885 leaf_indices: UInt[Array, 'num_trees n'], 

1886 tree_size: int, 

1887) -> Float32[Array, 'num_trees {tree_size}']: 

1888 def loop(_, leaf_indices): 1ab

1889 return None, _aggregate_scatter( 1ab

1890 prec_scale, leaf_indices, tree_size, jnp.float32 

1891 ) 

1892 

1893 _, prec_trees = lax.scan(loop, None, leaf_indices) 1ab

1894 return prec_trees 1ab

1895 

1896 

1897def _prec_vec( 1ab

1898 prec_scale: Float32[Array, ' n'], 

1899 leaf_indices: UInt[Array, 'num_trees n'], 

1900 tree_size: int, 

1901 batch_size: int, 

1902) -> Float32[Array, 'num_trees {tree_size}']: 

1903 return _aggregate_batched_alltrees( 

1904 prec_scale, leaf_indices, tree_size, jnp.float32, batch_size 

1905 ) 

1906 

1907 

1908def complete_ratio(moves: Moves, p_nonterminal: Float32[Array, ' 2**d']) -> Moves: 1ab

1909 """ 

1910 Complete non-likelihood MH ratio calculation. 

1911 

1912 This function adds the probability of choosing a prune move over the grow 

1913 move in the inverse transition, and the a priori probability that the 

1914 children nodes are leaves. 

1915 

1916 Parameters 

1917 ---------- 

1918 moves 

1919 The proposed moves. Must have already been updated to take into 

1920 account the thresholds on the number of datapoints per node; this 

1921 happens in `accept_moves_parallel_stage`. 

1922 p_nonterminal 

1923 The a priori probability of each node being nonterminal conditional on 

1924 its ancestors, including at the maximum depth where it should be zero. 

1925 

1926 Returns 

1927 ------- 

1928 The updated moves, with `partial_ratio=None` and `log_trans_prior_ratio` set. 

1929 """ 

1930 # can the leaves be grown? 

1931 num_trees, _ = moves.affluence_tree.shape 1ab

1932 tree_indices = jnp.arange(num_trees) 1ab

1933 left_growable = moves.affluence_tree.at[tree_indices, moves.left].get( 1ab

1934 mode='fill', fill_value=False 

1935 ) 

1936 right_growable = moves.affluence_tree.at[tree_indices, moves.right].get( 1ab

1937 mode='fill', fill_value=False 

1938 ) 

1939 

1940 # p_prune if grow 

1941 other_growable_leaves = moves.num_growable >= 2 1ab

1942 grow_again_allowed = other_growable_leaves | left_growable | right_growable 1ab

1943 grow_p_prune = jnp.where(grow_again_allowed, 0.5, 1) 1ab

1944 

1945 # p_prune if prune 

1946 prune_p_prune = jnp.where(moves.num_growable, 0.5, 1) 1ab

1947 

1948 # select p_prune 

1949 p_prune = jnp.where(moves.grow, grow_p_prune, prune_p_prune) 1ab

1950 

1951 # prior probability of both children being terminal 

1952 pt_left = 1 - p_nonterminal[moves.left] * left_growable 1ab

1953 pt_right = 1 - p_nonterminal[moves.right] * right_growable 1ab

1954 pt_children = pt_left * pt_right 1ab

1955 

1956 return replace( 1ab

1957 moves, 

1958 log_trans_prior_ratio=jnp.log(moves.partial_ratio * pt_children * p_prune), 

1959 partial_ratio=None, 

1960 ) 

1961 

1962 
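
The p_prune bookkeeping is easiest to check on small cases: after a GROW, the reverse PRUNE would be selected with probability 0.5 if the new tree still has a growable leaf, and with probability 1 otherwise. A toy check of that logic, with hypothetical inputs:

import jax.numpy as jnp

# num_growable counts growable leaves before the move, including the one
# being grown; left/right_growable say whether the new children are growable.
num_growable = jnp.array([1, 1, 2])
left_growable = jnp.array([False, True, False])
right_growable = jnp.array([False, False, False])
other_growable_leaves = num_growable >= 2
grow_again_allowed = other_growable_leaves | left_growable | right_growable
grow_p_prune = jnp.where(grow_again_allowed, 0.5, 1)
# grow_p_prune == [1.  0.5  0.5]
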

1963@vmap_nodoc 1ab

1964def adapt_leaf_trees_to_grow_indices( 1ab

1965 leaf_trees: Float32[Array, 'num_trees 2**d'], moves: Moves 

1966) -> Float32[Array, 'num_trees 2**d']: 

1967 """ 

1968 Modify leaves such that post-grow indices work on the original tree. 

1969 

1970 The value of the leaf to grow is copied to what would be its children if the 

1971 grow move was accepted. 

1972 

1973 Parameters 

1974 ---------- 

1975 leaf_trees 

1976 The leaf values. 

1977 moves 

1978 The proposed moves, see `propose_moves`. 

1979 

1980 Returns 

1981 ------- 

1982 The modified leaf values. 

1983 """ 

1984 values_at_node = leaf_trees[moves.node] 1ab

1985 return ( 1ab

1986 leaf_trees.at[jnp.where(moves.grow, moves.left, leaf_trees.size)] 

1987 .set(values_at_node) 

1988 .at[jnp.where(moves.grow, moves.right, leaf_trees.size)] 

1989 .set(values_at_node) 

1990 ) 

1991 

1992 

1993def precompute_likelihood_terms( 1ab

1994 sigma2: Float32[Array, ''], 

1995 sigma_mu2: Float32[Array, ''], 

1996 move_precs: Precs | Counts, 

1997) -> tuple[PreLkV, PreLk]: 

1998 """ 

1999 Pre-compute terms used in the likelihood ratio of the acceptance step. 

2000 

2001 Parameters 

2002 ---------- 

2003 sigma2 

2004 The error variance, or the global error variance factor if `prec_scale` 

2005 is set. 

2006 sigma_mu2 

2007 The prior variance of each leaf. 

2008 move_precs 

2009 The likelihood precision scale in the leaves grown or pruned by the 

2010 moves, under keys 'left', 'right', and 'total' (left + right). 

2011 

2012 Returns 

2013 ------- 

2014 prelkv : PreLkV 

2015 Object with pre-computed terms of the likelihood ratio, one per 

2016 tree. 

2017 prelk : PreLk 

2018 Object with pre-computed terms of the likelihood ratio, shared by 

2019 all trees. 

2020 """ 

2021 sigma2_left = sigma2 + move_precs.left * sigma_mu2 1ab

2022 sigma2_right = sigma2 + move_precs.right * sigma_mu2 1ab

2023 sigma2_total = sigma2 + move_precs.total * sigma_mu2 1ab

2024 prelkv = PreLkV( 1ab

2025 sigma2_left=sigma2_left, 

2026 sigma2_right=sigma2_right, 

2027 sigma2_total=sigma2_total, 

2028 sqrt_term=jnp.log(sigma2 * sigma2_total / (sigma2_left * sigma2_right)) / 2, 

2029 ) 

2030 return prelkv, PreLk(exp_factor=sigma_mu2 / (2 * sigma2)) 1ab

2031 

2032 

2033def precompute_leaf_terms( 1ab

2034 key: Key[Array, ''], 

2035 prec_trees: Float32[Array, 'num_trees 2**d'], 

2036 sigma2: Float32[Array, ''], 

2037 sigma_mu2: Float32[Array, ''], 

2038) -> PreLf: 

2039 """ 

2040 Pre-compute terms used to sample leaves from their posterior. 

2041 

2042 Parameters 

2043 ---------- 

2044 key 

2045 A jax random key. 

2046 prec_trees 

2047 The likelihood precision scale in each potential or actual leaf node. 

2048 sigma2 

2049 The error variance, or the global error variance factor if `prec_scale` 

2050 is set. 

2051 sigma_mu2 

2052 The prior variance of each leaf. 

2053 

2054 Returns 

2055 ------- 

2056 Pre-computed terms for leaf sampling. 

2057 """ 

2058 prec_lk = prec_trees / sigma2 1ab

2059 prec_prior = lax.reciprocal(sigma_mu2) 1ab

2060 var_post = lax.reciprocal(prec_lk + prec_prior) 1ab

2061 z = random.normal(key, prec_trees.shape, sigma2.dtype) 1ab

2062 return PreLf( 1ab

2063 mean_factor=var_post / sigma2, 

2064 # | mean = mean_lk * prec_lk * var_post 

2065 # | resid_tree = mean_lk * prec_tree --> 

2066 # | --> mean_lk = resid_tree / prec_tree (kind of) 

2067 # | mean_factor = 

2068 # | = mean / resid_tree = 

2069 # | = resid_tree / prec_tree * prec_lk * var_post / resid_tree = 

2070 # | = 1 / prec_tree * prec_tree / sigma2 * var_post = 

2071 # | = var_post / sigma2 

2072 centered_leaves=z * jnp.sqrt(var_post), 

2073 ) 

2074 

2075 
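
These are the standard conjugate-normal posterior terms: a leaf with prior N(0, sigma_mu2) and aggregated likelihood precision prec / sigma2 has posterior variance 1 / (prec / sigma2 + 1 / sigma_mu2), and its posterior mean is that variance times resid_sum / sigma2. A scalar sketch with hypothetical numbers:

import jax
import jax.numpy as jnp

sigma2, sigma_mu2 = 1.0, 0.5
prec, resid_sum = 4.0, 2.0          # e.g. 4 unit-precision datapoints
var_post = 1 / (prec / sigma2 + 1 / sigma_mu2)
mean_factor = var_post / sigma2     # times resid_sum gives the posterior mean
z = jax.random.normal(jax.random.key(0))
leaf = mean_factor * resid_sum + z * jnp.sqrt(var_post)
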

2076def accept_moves_sequential_stage(pso: ParallelStageOut) -> tuple[State, Moves]: 1ab

2077 """ 

2078 Accept/reject the moves one tree at a time. 

2079 

2080 This is the most performance-sensitive function because it contains all and 

2081 only the parts of the algorithm that cannot be parallelized across trees. 

2082 

2083 Parameters 

2084 ---------- 

2085 pso 

2086 The output of `accept_moves_parallel_stage`. 

2087 

2088 Returns 

2089 ------- 

2090 bart : State 

2091 A partially updated BART MCMC state. 

2092 moves : Moves 

2093 The accepted/rejected moves, with `acc` and `to_prune` set. 

2094 """ 

2095 

2096 def loop(resid, pt): 1ab

2097 resid, leaf_tree, acc, to_prune, lkratio = accept_move_and_sample_leaves( 1ab

2098 resid, 

2099 SeqStageInAllTrees( 

2100 pso.bart.X, 

2101 pso.bart.forest.resid_batch_size, 

2102 pso.bart.prec_scale, 

2103 pso.bart.forest.log_likelihood is not None, 

2104 pso.prelk, 

2105 ), 

2106 pt, 

2107 ) 

2108 return resid, (leaf_tree, acc, to_prune, lkratio) 1ab

2109 

2110 pts = SeqStageInPerTree( 1ab

2111 pso.bart.forest.leaf_tree, 

2112 pso.prec_trees, 

2113 pso.moves, 

2114 pso.move_precs, 

2115 pso.bart.forest.leaf_indices, 

2116 pso.prelkv, 

2117 pso.prelf, 

2118 ) 

2119 resid, (leaf_trees, acc, to_prune, lkratio) = lax.scan(loop, pso.bart.resid, pts) 1ab

2120 

2121 bart = replace( 1ab

2122 pso.bart, 

2123 resid=resid, 

2124 forest=replace(pso.bart.forest, leaf_tree=leaf_trees, log_likelihood=lkratio), 

2125 ) 

2126 moves = replace(pso.moves, acc=acc, to_prune=to_prune) 1ab

2127 

2128 return bart, moves 1ab

2129 

2130 
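
The `lax.scan` above threads the residual vector through the trees as the carry, while the per-tree outputs (new leaves, acceptance flags, ratios) are stacked along the leading axis. The same pattern in miniature, with hypothetical shapes:

import jax.numpy as jnp
from jax import lax

def loop(resid, delta):
    resid = resid + delta        # sequential update of the shared carry
    return resid, resid.sum()    # per-step output, stacked by lax.scan

resid0 = jnp.zeros(4)
deltas = jnp.ones((3, 4))        # one row per "tree"
resid, sums = lax.scan(loop, resid0, deltas)
# resid == [3. 3. 3. 3.], sums == [4. 8. 12.]
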

2131class SeqStageInAllTrees(Module): 1ab

2132 """ 

2133 The inputs to `accept_move_and_sample_leaves` that are shared by all trees. 

2134 

2135 Parameters 

2136 ---------- 

2137 X 

2138 The predictors. 

2139 resid_batch_size 

2140 The batch size for computing the sum of residuals in each leaf. 

2141 prec_scale 

2142 The scale of the precision of the error on each datapoint. If None, it 

2143 is assumed to be 1. 

2144 save_ratios 

2145 Whether to save the acceptance ratios. 

2146 prelk 

2147 The pre-computed terms of the likelihood ratio which are shared across 

2148 trees. 

2149 """ 

2150 

2151 X: UInt[Array, 'p n'] 1ab

2152 resid_batch_size: int | None = field(static=True) 1ab

2153 prec_scale: Float32[Array, ' n'] | None 1ab

2154 save_ratios: bool = field(static=True) 1ab

2155 prelk: PreLk 1ab

2156 

2157 

2158class SeqStageInPerTree(Module): 1ab

2159 """ 

2160 The inputs to `accept_move_and_sample_leaves` that are separate for each tree. 

2161 

2162 Parameters 

2163 ---------- 

2164 leaf_tree 

2165 The leaf values of the tree. 

2166 prec_tree 

2167 The likelihood precision scale in each potential or actual leaf node. 

2168 move 

2169 The proposed move, see `propose_moves`. 

2170 move_precs 

2171 The likelihood precision scale in each node modified by the moves. 

2172 leaf_indices 

2173 The leaf indices for the largest version of the tree compatible with 

2174 the move. 

2175 prelkv 

2176 prelf 

2177 The pre-computed terms of the likelihood ratio and leaf sampling which 

2178 are specific to the tree. 

2179 """ 

2180 

2181 leaf_tree: Float32[Array, ' 2**d'] 1ab

2182 prec_tree: Float32[Array, ' 2**d'] 1ab

2183 move: Moves 1ab

2184 move_precs: Precs | Counts 1ab

2185 leaf_indices: UInt[Array, ' n'] 1ab

2186 prelkv: PreLkV 1ab

2187 prelf: PreLf 1ab

2188 

2189 

2190def accept_move_and_sample_leaves( 1ab

2191 resid: Float32[Array, ' n'], at: SeqStageInAllTrees, pt: SeqStageInPerTree 

2192) -> tuple[ 

2193 Float32[Array, ' n'], 

2194 Float32[Array, ' 2**d'], 

2195 Bool[Array, ''], 

2196 Bool[Array, ''], 

2197 Float32[Array, ''] | None, 

2198]: 

2199 """ 

2200 Accept or reject a proposed move and sample the new leaf values. 

2201 

2202 Parameters 

2203 ---------- 

2204 resid 

2205 The residuals (data minus forest value). 

2206 at 

2207 The inputs that are the same for all trees. 

2208 pt 

2209 The inputs that are separate for each tree. 

2210 

2211 Returns 

2212 ------- 

2213 resid : Float32[Array, 'n'] 

2214 The updated residuals (data minus forest value). 

2215 leaf_tree : Float32[Array, '2**d'] 

2216 The new leaf values of the tree. 

2217 acc : Bool[Array, ''] 

2218 Whether the move was accepted. 

2219 to_prune : Bool[Array, ''] 

2220 Whether the state should be updated by pruning the leaves involved in 

2221 the move, to reflect the acceptance status of the move. 

2222 log_lk_ratio : Float32[Array, ''] | None 

2223 The logarithm of the likelihood ratio for the move. `None` if not to be 

2224 saved. 

2225 """ 

2226 # sum residuals in each leaf, in tree proposed by grow move 

2227 if at.prec_scale is None: 1ab

2228 scaled_resid = resid 1ab

2229 else: 

2230 scaled_resid = resid * at.prec_scale 1ab

2231 resid_tree = sum_resid( 1ab

2232 scaled_resid, pt.leaf_indices, pt.leaf_tree.size, at.resid_batch_size 

2233 ) 

2234 

2235 # subtract starting tree from function 

2236 resid_tree += pt.prec_tree * pt.leaf_tree 1ab

2237 

2238 # sum residuals in parent node modified by move 

2239 resid_left = resid_tree[pt.move.left] 1ab

2240 resid_right = resid_tree[pt.move.right] 1ab

2241 resid_total = resid_left + resid_right 1ab

2242 assert pt.move.node.dtype == jnp.int32 1ab

2243 resid_tree = resid_tree.at[pt.move.node].set(resid_total) 1ab

2244 

2245 # compute acceptance ratio 

2246 log_lk_ratio = compute_likelihood_ratio( 1ab

2247 resid_total, resid_left, resid_right, pt.prelkv, at.prelk 

2248 ) 

2249 log_ratio = pt.move.log_trans_prior_ratio + log_lk_ratio 1ab

2250 log_ratio = jnp.where(pt.move.grow, log_ratio, -log_ratio) 1ab

2251 if not at.save_ratios: 1ab

2252 log_lk_ratio = None 1ab

2253 

2254 # determine whether to accept the move 

2255 acc = pt.move.allowed & (pt.move.logu <= log_ratio) 1ab

2256 

2257 # compute leaves posterior and sample leaves 

2258 mean_post = resid_tree * pt.prelf.mean_factor 1ab

2259 leaf_tree = mean_post + pt.prelf.centered_leaves 1ab

2260 

2261 # copy leaves around such that the leaf indices point to the correct leaf 

2262 to_prune = acc ^ pt.move.grow 1ab

2263 leaf_tree = ( 1ab

2264 leaf_tree.at[jnp.where(to_prune, pt.move.left, leaf_tree.size)] 

2265 .set(leaf_tree[pt.move.node]) 

2266 .at[jnp.where(to_prune, pt.move.right, leaf_tree.size)] 

2267 .set(leaf_tree[pt.move.node]) 

2268 ) 

2269 

2270 # replace old tree with new tree in function values 

2271 resid += (pt.leaf_tree - leaf_tree)[pt.leaf_indices] 1ab

2272 

2273 return resid, leaf_tree, acc, to_prune, log_lk_ratio 1ab

2274 

2275 

2276def sum_resid( 1ab

2277 scaled_resid: Float32[Array, ' n'], 

2278 leaf_indices: UInt[Array, ' n'], 

2279 tree_size: int, 

2280 batch_size: int | None, 

2281) -> Float32[Array, ' {tree_size}']: 

2282 """ 

2283 Sum the residuals in each leaf. 

2284 

2285 Parameters 

2286 ---------- 

2287 scaled_resid 

2288 The residuals (data minus forest value) multiplied by the error 

2289 precision scale. 

2290 leaf_indices 

2291 The leaf indices of the tree (which leaf each datapoint falls into). 

2292 tree_size 

2293 The size of the tree array (2 ** d). 

2294 batch_size 

2295 The data batch size for the aggregation. Batching increases numerical 

2296 accuracy and parallelism. 

2297 

2298 Returns 

2299 ------- 

2300 The sum of the residuals at data points in each leaf. 

2301 """ 

2302 if batch_size is None: 1ab

2303 aggr_func = _aggregate_scatter 1ab

2304 else: 

2305 aggr_func = partial(_aggregate_batched_onetree, batch_size=batch_size) 1ab

2306 return aggr_func(scaled_resid, leaf_indices, tree_size, jnp.float32) 1ab

2307 

2308 

2309def _aggregate_batched_onetree( 1ab

2310 values: Shaped[Array, '*'], 

2311 indices: Integer[Array, '*'], 

2312 size: int, 

2313 dtype: jnp.dtype, 

2314 batch_size: int, 

2315) -> Float32[Array, ' {size}']: 

2316 (n,) = indices.shape 1ab

2317 nbatches = n // batch_size + bool(n % batch_size) 1ab

2318 batch_indices = jnp.arange(n) % nbatches 1ab

2319 return ( 1ab

2320 jnp.zeros((size, nbatches), dtype) 

2321 .at[indices, batch_indices] 

2322 .add(values) 

2323 .sum(axis=1) 

2324 ) 

2325 

2326 

2327def compute_likelihood_ratio( 1ab

2328 total_resid: Float32[Array, ''], 

2329 left_resid: Float32[Array, ''], 

2330 right_resid: Float32[Array, ''], 

2331 prelkv: PreLkV, 

2332 prelk: PreLk, 

2333) -> Float32[Array, '']: 

2334 """ 

2335 Compute the log likelihood ratio of a grow move. 

2336 

2337 Parameters 

2338 ---------- 

2339 total_resid 

2340 left_resid 

2341 right_resid 

2342 The sum of the residuals (scaled by error precision scale) of the 

2343 datapoints falling in the nodes involved in the moves. 

2344 prelkv 

2345 prelk 

2346 The pre-computed terms of the likelihood ratio, see 

2347 `precompute_likelihood_terms`. 

2348 

2349 Returns 

2350 ------- 

2351 The logarithm of the likelihood ratio P(data | new tree) / P(data | old tree). 

2352 """ 

2353 exp_term = prelk.exp_factor * ( 1ab

2354 left_resid * left_resid / prelkv.sigma2_left 

2355 + right_resid * right_resid / prelkv.sigma2_right 

2356 - total_resid * total_resid / prelkv.sigma2_total 

2357 ) 

2358 return prelkv.sqrt_term + exp_term 1ab

2359 

2360 
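
Reconstructing the math from `precompute_likelihood_terms` together with this function, the returned quantity is (with $r$ the precision-scaled residual sums and each variance defined as in the code):

\[
\log \frac{p(r \mid T_{\mathrm{new}})}{p(r \mid T_{\mathrm{old}})}
= \frac{1}{2} \log \frac{\sigma^2 \, \sigma^2_{\mathrm{total}}}
                        {\sigma^2_{\mathrm{left}} \, \sigma^2_{\mathrm{right}}}
+ \frac{\sigma_\mu^2}{2\sigma^2} \left(
    \frac{r_{\mathrm{left}}^2}{\sigma^2_{\mathrm{left}}}
  + \frac{r_{\mathrm{right}}^2}{\sigma^2_{\mathrm{right}}}
  - \frac{r_{\mathrm{total}}^2}{\sigma^2_{\mathrm{total}}}
\right),
\qquad
\sigma^2_\bullet = \sigma^2 + \mathrm{prec}_\bullet \, \sigma_\mu^2 .
\]
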

2361def accept_moves_final_stage(bart: State, moves: Moves) -> State: 1ab

2362 """ 

2363 Post-process the MCMC state after accepting/rejecting the moves. 

2364 

2365 This function is separate from `accept_moves_sequential_stage` to signal it 

2366 can work in parallel across trees. 

2367 

2368 Parameters 

2369 ---------- 

2370 bart 

2371 A partially updated BART MCMC state. 

2372 moves 

2373 The proposed moves (see `propose_moves`) as updated by 

2374 `accept_moves_sequential_stage`. 

2375 

2376 Returns 

2377 ------- 

2378 The fully updated BART mcmc state. 

2379 """ 

2380 return replace( 1ab

2381 bart, 

2382 forest=replace( 

2383 bart.forest, 

2384 grow_acc_count=jnp.sum(moves.acc & moves.grow), 

2385 prune_acc_count=jnp.sum(moves.acc & ~moves.grow), 

2386 leaf_indices=apply_moves_to_leaf_indices(bart.forest.leaf_indices, moves), 

2387 split_tree=apply_moves_to_split_trees(bart.forest.split_tree, moves), 

2388 ), 

2389 ) 

2390 

2391 

2392@vmap_nodoc 1ab

2393def apply_moves_to_leaf_indices( 1ab

2394 leaf_indices: UInt[Array, 'num_trees n'], moves: Moves 

2395) -> UInt[Array, 'num_trees n']: 

2396 """ 

2397 Update the leaf indices to match the accepted move. 

2398 

2399 Parameters 

2400 ---------- 

2401 leaf_indices 

2402 The index of the leaf each datapoint falls into, if the grow move was 

2403 accepted. 

2404 moves 

2405 The proposed moves (see `propose_moves`), as updated by 

2406 `accept_moves_sequential_stage`. 

2407 

2408 Returns 

2409 ------- 

2410 The updated leaf indices. 

2411 """ 

2412 mask = ~jnp.array(1, leaf_indices.dtype) # ...1111111110 1ab

2413 is_child = (leaf_indices & mask) == moves.left 1ab

2414 return jnp.where( 1ab

2415 is_child & moves.to_prune, moves.node.astype(leaf_indices.dtype), leaf_indices 

2416 ) 

2417 

2418 
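
The mask trick works because the children of node k sit at 2*k and 2*k + 1: clearing the lowest bit maps both to 2*k == moves.left, so one comparison finds the datapoints in either child of a pruned node. A toy check:

import jax.numpy as jnp

leaf_indices = jnp.array([6, 7, 5, 6], jnp.uint32)
mask = ~jnp.array(1, jnp.uint32)       # ...11111110
is_child = (leaf_indices & mask) == 6  # 6 == left child of pruned node 3
# is_child == [ True  True False  True]
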

2419@vmap_nodoc 1ab

2420def apply_moves_to_split_trees( 1ab

2421 split_tree: UInt[Array, 'num_trees 2**(d-1)'], moves: Moves 

2422) -> UInt[Array, 'num_trees 2**(d-1)']: 

2423 """ 

2424 Update the split trees to match the accepted move. 

2425 

2426 Parameters 

2427 ---------- 

2428 split_tree 

2429 The cutpoints of the decision nodes in the initial trees. 

2430 moves 

2431 The proposed moves (see `propose_moves`), as updated by 

2432 `accept_moves_sequential_stage`. 

2433 

2434 Returns 

2435 ------- 

2436 The updated split trees. 

2437 """ 

2438 assert moves.to_prune is not None 1ab

2439 return ( 1ab

2440 split_tree.at[jnp.where(moves.grow, moves.node, split_tree.size)] 

2441 .set(moves.grow_split.astype(split_tree.dtype)) 

2442 .at[jnp.where(moves.to_prune, moves.node, split_tree.size)] 

2443 .set(0) 

2444 ) 

2445 

2446 

2447def step_sigma(key: Key[Array, ''], bart: State) -> State: 1ab

2448 """ 

2449 MCMC-update the error variance (factor). 

2450 

2451 Parameters 

2452 ---------- 

2453 key 

2454 A jax random key. 

2455 bart 

2456 A BART MCMC state. 

2457 

2458 Returns 

2459 ------- 

2460 The new BART mcmc state, with an updated `sigma2`. 

2461 """ 

2462 resid = bart.resid 1ab

2463 alpha = bart.sigma2_alpha + resid.size / 2 1ab

2464 if bart.prec_scale is None: 1ab

2465 scaled_resid = resid 1ab

2466 else: 

2467 scaled_resid = resid * bart.prec_scale 1ab

2468 norm2 = resid @ scaled_resid 1ab

2469 beta = bart.sigma2_beta + norm2 / 2 1ab

2470 

2471 sample = random.gamma(key, alpha) 1ab

2472 # random.gamma seems to be slow at compiling, maybe cdf inversion would 

2473 # be better, but it's not implemented in jax 

2474 return replace(bart, sigma2=beta / sample) 1ab

2475 

2476 
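
This is the usual conjugate draw: the full conditional of `sigma2` is InvGamma(alpha, beta) with the alpha and beta computed above, and dividing beta by a Gamma(alpha) sample yields an inverse-gamma sample. A minimal sketch with hypothetical numbers:

from jax import random

# If X ~ Gamma(alpha, 1) then beta / X ~ InvGamma(alpha, beta).
alpha, beta = 51.0, 12.0   # e.g. sigma2_alpha + n / 2, sigma2_beta + norm2 / 2
sigma2 = beta / random.gamma(random.key(0), alpha)
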

2477@jax.jit 1ab

2478def _sample_wishart_bartlett( 1ab

2479 key: Key[Array, ''], df: Integer[Array, ''], scale_inv: Float32[Array, 'k k'] 

2480) -> Float32[Array, 'k k']: 

2481 """ 

2482 Sample a precision matrix W ~ Wishart(df, scale_inv^-1) using Bartlett decomposition. 

2483 

2484 Parameters 

2485 ---------- 

2486 key 

2487 A JAX random key. 

2488 df 

2489 Degrees of freedom. 

2490 scale_inv 

2491 Scale matrix of the corresponding Inverse Wishart distribution. 

2492 

2493 Returns 

2494 ------- 

2495 A sample from Wishart(df, scale_inv^-1) 

2496 """ 

2497 keys = split(key) 1ab

2498 

2499 k = scale_inv.shape[0] 1ab

2500 

2501 # Gershgorin estimate for max eigenvalue 

2502 rho = jnp.max(jnp.sum(jnp.abs(scale_inv), axis=1)) 1ab

2503 u = k * rho * jnp.finfo(scale_inv.dtype).eps + jnp.finfo(scale_inv.dtype).eps 1ab

2504 

2505 # Stabilize the matrix 

2506 scale_inv = scale_inv.at[jnp.diag_indices(k)].add(u) 1ab

2507 L = jnp.linalg.cholesky(scale_inv) 1ab

2508 

2509 # Diagonal elements: A_ii ~ sqrt(chi^2(df - i)) 

2510 # chi^2(k) = Gamma(k/2, scale=2) 

2511 df_vector = df - jnp.arange(k) 1ab

2512 chi2_samples = random.gamma(keys.pop(), df_vector / 2.0) * 2.0 1ab

2513 diag_A = jnp.sqrt(chi2_samples) 1ab

2514 

2515 off_diag_A = random.normal(keys.pop(), (k, k)) 1ab

2516 A = jnp.tril(off_diag_A, -1) + jnp.diag(diag_A) 1ab

2517 T = solve_triangular(L, A, lower=True, trans='T') 1ab

2518 

2519 return T @ T.T 1ab

2520 

2521 
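
A rough Monte Carlo sanity check of the sampler (an editorial sketch, not part of the module): with identity `scale_inv` the draws are W ~ Wishart(df, I), whose mean is df * I, so the average of many samples should be close to df times the identity.

import jax
import jax.numpy as jnp
from jax import random

k, df = 3, 10
scale_inv = jnp.eye(k)
keys = random.split(random.key(0), 1000)
samples = jax.vmap(
    lambda key: _sample_wishart_bartlett(key, jnp.int32(df), scale_inv)
)(keys)
# samples.mean(axis=0) is approximately df * jnp.eye(k)
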

2522def step_z(key: Key[Array, ''], bart: State) -> State: 1ab

2523 """ 

2524 MCMC-update the latent variable for binary regression. 

2525 

2526 Parameters 

2527 ---------- 

2528 key 

2529 A jax random key. 

2530 bart 

2531 A BART MCMC state. 

2532 

2533 Returns 

2534 ------- 

2535 The updated BART MCMC state. 

2536 """ 

2537 trees_plus_offset = bart.z - bart.resid 1ab

2538 assert bart.y.dtype == bool 1ab

2539 resid = truncated_normal_onesided(key, (), ~bart.y, -trees_plus_offset) 1ab

2540 z = trees_plus_offset + resid 1ab

2541 return replace(bart, z=z, resid=resid) 1ab

2542 

2543 
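
This is the Albert-Chib update for probit regression: given the forest value f and the binary response, the latent z is drawn from N(f, 1) truncated to the positive half-line where y is true and to the negative one where it is false. A sketch using `jax.random.truncated_normal` in place of the `bartz.jaxext` helper used above:

import jax.numpy as jnp
from jax import random

f = jnp.array([0.5, -1.0, 0.2])   # forest value (z minus residual)
y = jnp.array([True, False, True])
# the standardized residual z - f is a standard normal truncated at -f
lower = jnp.where(y, -f, -jnp.inf)
upper = jnp.where(y, jnp.inf, -f)
resid = random.truncated_normal(random.key(0), lower, upper)
z = f + resid
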

2544def step_s(key: Key[Array, ''], bart: State) -> State: 1ab

2545 """ 

2546 Update `log_s` using Dirichlet sampling. 

2547 

2548 The prior is s ~ Dirichlet(theta/p, ..., theta/p), and the posterior 

2549 is s ~ Dirichlet(theta/p + varcount, ..., theta/p + varcount), where 

2550 varcount is the count of how many times each variable is used in the 

2551 current forest. 

2552 

2553 Parameters 

2554 ---------- 

2555 key 

2556 Random key for sampling. 

2557 bart 

2558 The current BART state. 

2559 

2560 Returns 

2561 ------- 

2562 Updated BART state with re-sampled `log_s`. 

2563 

2564 Notes 

2565 ----- 

2566 This full conditional is approximated, because it does not take into account 

2567 that there are forbidden decision rules. 

2568 """ 

2569 assert bart.forest.theta is not None 1ab

2570 

2571 # histogram current variable usage 

2572 p = bart.forest.max_split.size 1ab

2573 varcount = grove.var_histogram(p, bart.forest.var_tree, bart.forest.split_tree) 1ab

2574 

2575 # sample from Dirichlet posterior 

2576 alpha = bart.forest.theta / p + varcount 1ab

2577 log_s = random.loggamma(key, alpha) 1ab

2578 

2579 # update forest with new s 

2580 return replace(bart, forest=replace(bart.forest, log_s=log_s)) 1ab

2581 

2582 
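
`random.loggamma` is the standard gamma construction of a Dirichlet carried out in log space: if G_i ~ Gamma(alpha_i) then G / sum(G) ~ Dirichlet(alpha), and keeping log G_i avoids underflow when alpha_i = theta / p is tiny. Note that `log_s` is stored unnormalized; a sketch including the normalization:

import jax.numpy as jnp
from jax import random
from jax.scipy.special import logsumexp

alpha = jnp.array([0.01, 0.02, 1.0, 3.0])  # hypothetical posterior alphas
log_g = random.loggamma(random.key(0), alpha)
log_s = log_g - logsumexp(log_g)           # normalized log Dirichlet draw
# jnp.exp(log_s) sums to 1
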

2583def step_theta(key: Key[Array, ''], bart: State, *, num_grid: int = 1000) -> State: 1ab

2584 """ 

2585 Update `theta`. 

2586 

2587 The prior is theta / (theta + rho) ~ Beta(a, b). 

2588 

2589 Parameters 

2590 ---------- 

2591 key 

2592 Random key for sampling. 

2593 bart 

2594 The current BART state. 

2595 num_grid 

2596 The number of points in the evenly-spaced grid used to sample 

2597 theta / (theta + rho). 

2598 

2599 Returns 

2600 ------- 

2601 Updated BART state with re-sampled `theta`. 

2602 """ 

2603 assert bart.forest.log_s is not None 1ab

2604 assert bart.forest.rho is not None 1ab

2605 assert bart.forest.a is not None 1ab

2606 assert bart.forest.b is not None 1ab

2607 

2608 # the grid points are the midpoints of num_grid bins in (0, 1) 

2609 padding = 1 / (2 * num_grid) 1ab

2610 lamda_grid = jnp.linspace(padding, 1 - padding, num_grid) 1ab

2611 

2612 # normalize s 

2613 log_s = bart.forest.log_s - logsumexp(bart.forest.log_s) 1ab

2614 

2615 # sample lambda 

2616 logp, theta_grid = _log_p_lamda( 1ab

2617 lamda_grid, log_s, bart.forest.rho, bart.forest.a, bart.forest.b 

2618 ) 

2619 i = random.categorical(key, logp) 1ab

2620 theta = theta_grid[i] 1ab

2621 

2622 return replace(bart, forest=replace(bart.forest, theta=theta)) 1ab

2623 

2624 

2625def _log_p_lamda( 1ab

2626 lamda: Float32[Array, ' num_grid'], 

2627 log_s: Float32[Array, ' p'], 

2628 rho: Float32[Array, ''], 

2629 a: Float32[Array, ''], 

2630 b: Float32[Array, ''], 

2631) -> tuple[Float32[Array, ' num_grid'], Float32[Array, ' num_grid']]: 

2632 # in the following I use lamda[::-1] == 1 - lamda 

2633 theta = rho * lamda / lamda[::-1] 1ab

2634 p = log_s.size 1ab

2635 return ( 1ab

2636 (a - 1) * jnp.log1p(-lamda[::-1]) # log(lambda) 

2637 + (b - 1) * jnp.log1p(-lamda) # log(1 - lambda) 

2638 + gammaln(theta) 

2639 - p * gammaln(theta / p) 

2640 + theta / p * jnp.sum(log_s) 

2641 ), theta 

2642 

2643 
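
`step_theta` is a grid sampler: it evaluates the unnormalized log posterior of lambda = theta / (theta + rho) at the bin midpoints and draws an index with `random.categorical`, which normalizes internally. The pattern with a hypothetical log density:

import jax.numpy as jnp
from jax import random

num_grid = 1000
padding = 1 / (2 * num_grid)
lamda_grid = jnp.linspace(padding, 1 - padding, num_grid)
logp = -((lamda_grid - 0.3) ** 2) / 0.01  # hypothetical unnormalized log posterior
i = random.categorical(random.key(0), logp)
lamda = lamda_grid[i]
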

2644def step_sparse(key: Key[Array, ''], bart: State) -> State: 1ab

2645 """ 

2646 Update the sparsity parameters. 

2647 

2648 This invokes `step_s`, and then `step_theta` only if the parameters of 

2649 the theta prior are defined. 

2650 

2651 Parameters 

2652 ---------- 

2653 key 

2654 Random key for sampling. 

2655 bart 

2656 The current BART state. 

2657 

2658 Returns 

2659 ------- 

2660 Updated BART state with re-sampled `log_s` and `theta`. 

2661 """ 

2662 keys = split(key) 1ab

2663 bart = step_s(keys.pop(), bart) 1ab

2664 if bart.forest.rho is not None: 1ab

2665 bart = step_theta(keys.pop(), bart) 1ab

2666 return bart 1ab