Coverage for src/bartz/mcmcstep.py: 96%

610 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-15 08:16 +0000

1# bartz/src/bartz/mcmcstep.py 

2# 

3# Copyright (c) 2024-2025, The Bartz Contributors 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25""" 

26Functions that implement the BART posterior MCMC initialization and update step. 

27 

28Functions that do MCMC steps operate by taking as input a bart state, and 

29outputting a new state. The inputs are not modified. 

30 

31The entry points are: 

32 

33 - `State`: The dataclass that represents a BART MCMC state. 

34 - `init`: Creates an initial `State` from data and configurations. 

35 - `step`: Performs one full MCMC step on a `State`, returning a new `State`. 

36 - `step_sparse`: Performs the MCMC update for variable selection, which is skipped in `step`. 

37""" 

38 

39import math 1ab

40from dataclasses import replace 1ab

41from functools import cache, partial 1ab

42from typing import Any, Literal 1ab

43 

44import jax 1ab

45from equinox import Module, field, tree_at 1ab

46from jax import lax, random 1ab

47from jax import numpy as jnp 1ab

48from jax.scipy.linalg import solve_triangular 1ab

49from jax.scipy.special import gammaln, logsumexp 1ab

50from jaxtyping import Array, Bool, Float32, Int32, Integer, Key, Shaped, UInt 1ab

51 

52from bartz import grove 1ab

53from bartz.jaxext import ( 1ab

54 minimal_unsigned_dtype, 

55 split, 

56 truncated_normal_onesided, 

57 vmap_nodoc, 

58) 

59 

60 

class Forest(Module):
    """
    MCMC state of the sum-of-trees component of the model.

    Parameters
    ----------
    leaf_tree
        Leaf values of each tree.
    var_tree
        Decision axes of each tree.
    split_tree
        Decision boundaries of each tree.
    affluence_tree
        Flags the leaves that can be grown.
    max_split
        Maximum split index of each predictor.
    blocked_vars
        Indices of the variables that are not used. It must contain at least
        every `i` with ``max_split[i] == 0``; behavior is undefined otherwise.
    p_nonterminal
        Prior probability that each node is nonterminal, conditional on its
        ancestors. Nodes at the maximum depth are included and should be set
        to 0.
    p_propose_grow
        Unnormalized probability of picking each leaf for a grow proposal.
    leaf_indices
        For each tree, the index of the leaf each datapoint falls into.
    min_points_per_decision_node
        Minimum number of datapoints in a decision node.
    min_points_per_leaf
        Minimum number of datapoints in a leaf node.
    resid_batch_size
    count_batch_size
        Data batch sizes used to compute the sufficient statistics. `None`
        means the computation is not batched.
    log_trans_prior
        Logarithm of the transition and prior Metropolis-Hastings ratio of
        the move proposed on each tree.
    log_likelihood
        Logarithm of the likelihood ratio.
    grow_prop_count
    prune_prop_count
        Number of grow/prune proposals made during one full MCMC cycle.
    grow_acc_count
    prune_acc_count
        Number of grow/prune moves accepted during one full MCMC cycle.
    sigma_mu2
        Prior variance of a leaf, conditional on the tree structure.
    log_s
        Unnormalized logarithm of the prior probability of choosing each
        variable to split along in a decision rule, conditional on the
        ancestors. `None` stands for a uniform distribution.
    theta
        Concentration parameter of the Dirichlet prior on the variable
        distribution `s`. Needed only to update `s`.
    a
    b
    rho
        Parameters of the prior on `theta`. Needed only to sample `theta`.
        See `step_theta`.
    """

    # NOTE: the declaration order below defines the positional order of the
    # generated constructor; do not reorder.
    leaf_tree: Float32[Array, 'num_trees 2**d']
    var_tree: UInt[Array, 'num_trees 2**(d-1)']
    split_tree: UInt[Array, 'num_trees 2**(d-1)']
    affluence_tree: Bool[Array, 'num_trees 2**(d-1)']
    max_split: UInt[Array, ' p']
    blocked_vars: UInt[Array, ' k'] | None
    p_nonterminal: Float32[Array, ' 2**d']
    p_propose_grow: Float32[Array, ' 2**(d-1)']
    leaf_indices: UInt[Array, 'num_trees n']
    min_points_per_decision_node: Int32[Array, ''] | None
    min_points_per_leaf: Int32[Array, ''] | None
    resid_batch_size: int | None = field(static=True)
    count_batch_size: int | None = field(static=True)
    log_trans_prior: Float32[Array, ' num_trees'] | None
    log_likelihood: Float32[Array, ' num_trees'] | None
    grow_prop_count: Int32[Array, '']
    prune_prop_count: Int32[Array, '']
    grow_acc_count: Int32[Array, '']
    prune_acc_count: Int32[Array, '']
    sigma_mu2: Float32[Array, '']
    log_s: Float32[Array, ' p'] | None
    theta: Float32[Array, ''] | None
    a: Float32[Array, ''] | None
    b: Float32[Array, ''] | None
    rho: Float32[Array, ''] | None

149 

150 

class State(Module):
    """
    Complete MCMC state of a BART model.

    Parameters
    ----------
    X
        The predictors.
    y
        The response. A `bool` data type selects the binary regression model.
    z
        Latent variable used in binary regression; `None` in continuous
        regression.
    offset
        Constant shift added to the sum of trees.
    resid
        The residuals, i.e., `y` (or `z`) minus the sum of trees.
    sigma2
        The error variance; `None` in binary regression.
    prec_scale
        Scale on the error precision, i.e., ``1 / error_scale ** 2``; `None`
        in binary regression.
    sigma2_alpha
    sigma2_beta
        Shape and scale parameters of the inverse gamma prior on the noise
        variance; `None` in binary regression.
    forest
        The sum of trees model.
    """

    # NOTE: the declaration order below defines the positional order of the
    # generated constructor; do not reorder.
    X: UInt[Array, 'p n']
    y: Float32[Array, ' n'] | Bool[Array, ' n']
    z: Float32[Array, ' n'] | None
    offset: Float32[Array, '']
    resid: Float32[Array, ' n']
    sigma2: Float32[Array, ''] | None
    prec_scale: Float32[Array, ' n'] | None
    sigma2_alpha: Float32[Array, ''] | None
    sigma2_beta: Float32[Array, ''] | None
    forest: Forest

191 

192 

def init(
    *,
    X: UInt[Any, 'p n'],
    y: Float32[Any, ' n'] | Bool[Any, ' n'],
    offset: float | Float32[Any, ''] = 0.0,
    max_split: UInt[Any, ' p'],
    num_trees: int,
    p_nonterminal: Float32[Any, ' d-1'],
    sigma_mu2: float | Float32[Any, ''],
    sigma2_alpha: float | Float32[Any, ''] | None = None,
    sigma2_beta: float | Float32[Any, ''] | None = None,
    error_scale: Float32[Any, ' n'] | None = None,
    min_points_per_decision_node: int | Integer[Any, ''] | None = None,
    resid_batch_size: int | None | Literal['auto'] = 'auto',
    count_batch_size: int | None | Literal['auto'] = 'auto',
    save_ratios: bool = False,
    filter_splitless_vars: bool = True,
    min_points_per_leaf: int | Integer[Any, ''] | None = None,
    log_s: Float32[Any, ' p'] | None = None,
    theta: float | Float32[Any, ''] | None = None,
    a: float | Float32[Any, ''] | None = None,
    b: float | Float32[Any, ''] | None = None,
    rho: float | Float32[Any, ''] | None = None,
) -> State:
    """
    Make a BART posterior sampling MCMC initial state.

    Parameters
    ----------
    X
        The predictors. Note this is transposed compared to the usual
        convention.
    y
        The response. If the data type is `bool`, the regression model is binary
        regression with probit.
    offset
        Constant shift added to the sum of trees. 0 if not specified.
    max_split
        The maximum split index for each variable. All split ranges start at 1.
    num_trees
        The number of trees in the forest.
    p_nonterminal
        The probability of a nonterminal node at each depth. The maximum depth
        of trees is fixed by the length of this array.
    sigma_mu2
        The prior variance of a leaf, conditional on the tree structure. The
        prior variance of the sum of trees is ``num_trees * sigma_mu2``. The
        prior mean of leaves is always zero.
    sigma2_alpha
    sigma2_beta
        The shape and scale parameters of the inverse gamma prior on the error
        variance. Leave unspecified for binary regression.
    error_scale
        Each error is scaled by the corresponding factor in `error_scale`, so
        the error variance for ``y[i]`` is ``sigma2 * error_scale[i] ** 2``.
        Not supported for binary regression. If not specified, defaults to 1 for
        all points, but potentially skipping calculations.
    min_points_per_decision_node
        The minimum number of data points in a decision node. 0 if not
        specified.
    resid_batch_size
    count_batch_size
        The batch sizes, along datapoints, for summing the residuals and
        counting the number of datapoints in each leaf. `None` for no batching.
        If 'auto', pick a value based on the device of `y`, or the default
        device.
    save_ratios
        Whether to save the Metropolis-Hastings ratios.
    filter_splitless_vars
        Whether to check `max_split` for variables without available cutpoints.
        If any are found, they are put into a list of variables to exclude from
        the MCMC. If `False`, no check is performed, but the results may be
        wrong if any variable is blocked. The function is jax-traceable only
        if this is set to `False`.
    min_points_per_leaf
        The minimum number of datapoints in a leaf node. 0 if not specified.
        Unlike `min_points_per_decision_node`, this constraint is not taken into
        account in the Metropolis-Hastings ratio because it would be expensive
        to compute. Grow moves that would violate this constraint are vetoed.
        This parameter is independent of `min_points_per_decision_node` and
        there is no check that they are coherent. It makes sense to set
        ``min_points_per_decision_node >= 2 * min_points_per_leaf``.
    log_s
        The logarithm of the prior probability for choosing a variable to split
        along in a decision rule, conditional on the ancestors. Not normalized.
        If not specified, use a uniform distribution. If not specified and
        `theta` or `rho`, `a`, `b` are, it's initialized automatically.
    theta
        The concentration parameter for the Dirichlet prior on `s`. Required
        only to update `log_s`. If not specified, and `rho`, `a`, `b` are
        specified, it's initialized automatically.
    a
    b
    rho
        Parameters of the prior on `theta`. Required only to sample `theta`.

    Returns
    -------
    An initialized BART MCMC state.

    Raises
    ------
    ValueError
        If `y` is boolean and arguments unused in binary regression are set,
        or if only some of `rho`, `a`, `b` are specified.

    Notes
    -----
    In decision nodes, the values in ``X[i, :]`` are compared to a cutpoint out
    of the range ``[1, 2, ..., max_split[i]]``. A point belongs to the left
    child iff ``X[i, j] < cutpoint``. Thus it makes sense for ``X[i, :]`` to be
    integers in the range ``[0, 1, ..., max_split[i]]``.
    """
    # append a trailing 0: nodes at the maximum depth can not be nonterminal
    p_nonterminal = jnp.asarray(p_nonterminal)
    p_nonterminal = jnp.pad(p_nonterminal, (0, 1))
    max_depth = p_nonterminal.size

    # replicate an empty tree `num_trees` times along a new leading axis
    @partial(jax.vmap, in_axes=None, out_axes=0, axis_size=num_trees)
    def make_forest(max_depth, dtype):
        return grove.make_tree(max_depth, dtype)

    # convert array-likes up front, so that attributes like `.shape` and
    # `.dtype` are available below whatever the input container was
    X = jnp.asarray(X)
    y = jnp.asarray(y)
    offset = jnp.asarray(offset)

    resid_batch_size, count_batch_size = _choose_suffstat_batch_size(
        resid_batch_size, count_batch_size, y, 2**max_depth * num_trees
    )

    is_binary = y.dtype == bool
    if is_binary:
        # check by identity: comparing with `==`/`!=` would be elementwise (and
        # thus ambiguous as a truth value) if `error_scale` is a numpy array
        if any(arg is not None for arg in (error_scale, sigma2_alpha, sigma2_beta)):
            msg = (
                'error_scale, sigma2_alpha, and sigma2_beta must be set '
                ' to `None` for binary regression.'
            )
            raise ValueError(msg)
        sigma2 = None
    else:
        sigma2_alpha = jnp.asarray(sigma2_alpha)
        sigma2_beta = jnp.asarray(sigma2_beta)
        # initial value of the error variance
        sigma2 = sigma2_beta / sigma2_alpha

    max_split = jnp.asarray(max_split)

    if filter_splitless_vars:
        # `jnp.nonzero` has a data-dependent output shape, this is why the
        # function is not traceable on this branch
        (blocked_vars,) = jnp.nonzero(max_split == 0)
        blocked_vars = blocked_vars.astype(minimal_unsigned_dtype(max_split.size))
        # see `fully_used_variables` for the type cast
    else:
        blocked_vars = None

    # check and initialize sparsity parameters
    if not _all_none_or_not_none(rho, a, b):
        msg = 'rho, a, b are not either all `None` or all set'
        raise ValueError(msg)
    if theta is None and rho is not None:
        theta = rho
    if log_s is None and theta is not None:
        log_s = jnp.zeros(max_split.size)

    return State(
        X=X,
        y=y,
        z=jnp.full(y.shape, offset) if is_binary else None,
        offset=offset,
        resid=jnp.zeros(y.shape) if is_binary else y - offset,
        sigma2=sigma2,
        prec_scale=(
            None if error_scale is None else lax.reciprocal(jnp.square(error_scale))
        ),
        sigma2_alpha=sigma2_alpha,
        sigma2_beta=sigma2_beta,
        forest=Forest(
            leaf_tree=make_forest(max_depth, jnp.float32),
            var_tree=make_forest(max_depth - 1, minimal_unsigned_dtype(X.shape[0] - 1)),
            split_tree=make_forest(max_depth - 1, max_split.dtype),
            # only the root (index 1) is marked, if it has enough datapoints
            affluence_tree=(
                make_forest(max_depth - 1, bool)
                .at[:, 1]
                .set(
                    True
                    if min_points_per_decision_node is None
                    else y.size >= min_points_per_decision_node
                )
            ),
            blocked_vars=blocked_vars,
            max_split=max_split,
            grow_prop_count=jnp.zeros((), int),
            grow_acc_count=jnp.zeros((), int),
            prune_prop_count=jnp.zeros((), int),
            prune_acc_count=jnp.zeros((), int),
            # expand the per-depth probabilities to per-node probabilities
            p_nonterminal=p_nonterminal[grove.tree_depths(2**max_depth)],
            p_propose_grow=p_nonterminal[grove.tree_depths(2 ** (max_depth - 1))],
            # all datapoints start in the root, which has index 1
            leaf_indices=jnp.ones(
                (num_trees, y.size), minimal_unsigned_dtype(2**max_depth - 1)
            ),
            min_points_per_decision_node=_asarray_or_none(min_points_per_decision_node),
            min_points_per_leaf=_asarray_or_none(min_points_per_leaf),
            resid_batch_size=resid_batch_size,
            count_batch_size=count_batch_size,
            log_trans_prior=jnp.zeros(num_trees) if save_ratios else None,
            log_likelihood=jnp.zeros(num_trees) if save_ratios else None,
            sigma_mu2=jnp.asarray(sigma_mu2),
            log_s=_asarray_or_none(log_s),
            theta=_asarray_or_none(theta),
            rho=_asarray_or_none(rho),
            a=_asarray_or_none(a),
            b=_asarray_or_none(b),
        ),
    )

401 

402 

403def _all_none_or_not_none(*args): 1ab

404 is_none = [x is None for x in args] 1ab

405 return all(is_none) or not any(is_none) 1ab

406 

407 

408def _asarray_or_none(x): 1ab

409 if x is None: 1ab

410 return None 1ab

411 return jnp.asarray(x) 1ab

412 

413 

414def _choose_suffstat_batch_size( 1ab

415 resid_batch_size, count_batch_size, y, forest_size 

416) -> tuple[int | None, ...]: 

417 @cache 1ab

418 def get_platform(): 1ab

419 try: 1ab

420 device = y.devices().pop() 1ab

421 except jax.errors.ConcretizationTypeError: 1ab

422 device = jax.devices()[0] 1ab

423 platform = device.platform 1ab

424 if platform not in ('cpu', 'gpu'): 424 ↛ 425line 424 didn't jump to line 425 because the condition on line 424 was never true1ab

425 msg = f'Unknown platform: {platform}' 

426 raise KeyError(msg) 

427 return platform 1ab

428 

429 if resid_batch_size == 'auto': 1ab

430 platform = get_platform() 1ab

431 n = max(1, y.size) 1ab

432 if platform == 'cpu': 432 ↛ 434line 432 didn't jump to line 434 because the condition on line 432 was always true1ab

433 resid_batch_size = 2 ** round(math.log2(n / 6)) # n/6 1ab

434 elif platform == 'gpu': 

435 resid_batch_size = 2 ** round((1 + math.log2(n)) / 3) # n^1/3 

436 resid_batch_size = max(1, resid_batch_size) 1ab

437 

438 if count_batch_size == 'auto': 1ab

439 platform = get_platform() 1ab

440 if platform == 'cpu': 440 ↛ 442line 440 didn't jump to line 442 because the condition on line 440 was always true1ab

441 count_batch_size = None 1ab

442 elif platform == 'gpu': 

443 n = max(1, y.size) 

444 count_batch_size = 2 ** round(math.log2(n) / 2 - 2) # n^1/2 

445 # /4 is good on V100, /2 on L4/T4, still haven't tried A100 

446 max_memory = 2**29 

447 itemsize = 4 

448 min_batch_size = math.ceil(forest_size * itemsize * n / max_memory) 

449 count_batch_size = max(count_batch_size, min_batch_size) 

450 count_batch_size = max(1, count_batch_size) 

451 

452 return resid_batch_size, count_batch_size 1ab

453 

454 

@jax.jit
def step(key: Key[Array, ''], bart: State) -> State:
    """
    Perform a single full MCMC step.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state, as created by `init`.

    Returns
    -------
    The new BART mcmc state.
    """
    keys = split(key)
    is_binary = bart.y.dtype == bool

    if is_binary:
        # binary regression: the error variance is temporarily set to 1 for
        # the update of the trees
        bart = replace(bart, sigma2=jnp.float32(1))

    bart = step_trees(keys.pop(), bart)

    if is_binary:
        # restore the missing error variance, then update the latent variable
        bart = replace(bart, sigma2=None)
        return step_z(keys.pop(), bart)

    # continuous regression: update the error variance
    return step_sigma(keys.pop(), bart)

482 

483 

def step_trees(key: Key[Array, ''], bart: State) -> State:
    """
    Sample the forest, the tree-related part of a BART MCMC step.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state, as created by `init`.

    Returns
    -------
    The new BART mcmc state.

    Notes
    -----
    This function zeroes the proposal counters.
    """
    keys = split(key)
    # first propose a move for each tree, then accept or reject them
    proposed_moves = propose_moves(keys.pop(), bart.forest)
    return accept_moves_and_sample_leaves(keys.pop(), bart, proposed_moves)

506 

507 

class Moves(Module):
    """
    The move proposed to modify each tree.

    Parameters
    ----------
    allowed
        Whether a move is possible at all. If `False`, the other values may
        not make sense. A move marked as allowed can still be vetoed later
        only if it does not satisfy `min_points_per_leaf`; for efficiency this
        is implemented post-hoc without changing the rest of the MCMC logic.
    grow
        `True` for a grow move, `False` for a prune move.
    num_growable
        Number of growable leaves in the original tree.
    node
        Index of the leaf to grow or of the node to prune.
    left
    right
        Indices of the children of 'node'.
    partial_ratio
        A factor of the Metropolis-Hastings ratio of the move. It lacks the
        likelihood ratio, the probability of proposing the prune move, and the
        probability that the children of the modified node are terminal. The
        ratio is inverted if the move is PRUNE. `None` once
        `log_trans_prior_ratio` has been computed.
    log_trans_prior_ratio
        Logarithm of the product of the transition and prior terms of the
        Metropolis-Hastings ratio for the acceptance of the proposed move.
        The log-ratio is negated if the move is PRUNE. `None` if not yet
        computed.
    grow_var
        Decision axes of the new rules.
    grow_split
        Decision boundaries of the new rules.
    var_tree
        Updated decision axes of the trees, valid whatever the move.
    affluence_tree
        A partially updated `affluence_tree`, marking non-leaf nodes that would
        become leaves if the move was accepted. Out of `propose_moves` the
        mark accounts for whether there would be available decision rules to
        grow the leaf; whether there are enough datapoints in the node is
        marked in `accept_moves_parallel_stage`.
    logu
        Logarithm of a uniform (0, 1] random variable, to be used to accept
        the move. It's in (-oo, 0].
    acc
        Whether the move was accepted. `None` if not yet computed.
    to_prune
        Whether the final operation to apply the move is pruning, i.e., the
        move is an accepted prune or a rejected grow. `None` if not yet
        computed.
    """

    # NOTE: the declaration order below defines the positional order of the
    # generated constructor; do not reorder.
    allowed: Bool[Array, ' num_trees']
    grow: Bool[Array, ' num_trees']
    num_growable: UInt[Array, ' num_trees']
    node: UInt[Array, ' num_trees']
    left: UInt[Array, ' num_trees']
    right: UInt[Array, ' num_trees']
    partial_ratio: Float32[Array, ' num_trees'] | None
    log_trans_prior_ratio: Float32[Array, ' num_trees'] | None
    grow_var: UInt[Array, ' num_trees']
    grow_split: UInt[Array, ' num_trees']
    var_tree: UInt[Array, 'num_trees 2**(d-1)']
    affluence_tree: Bool[Array, 'num_trees 2**(d-1)']
    logu: Float32[Array, ' num_trees']
    acc: Bool[Array, ' num_trees'] | None
    to_prune: Bool[Array, ' num_trees'] | None

577 

578 

def propose_moves(key: Key[Array, ''], forest: Forest) -> Moves:
    """
    Propose one move for each tree of the forest.

    A move is either a GROW (a leaf becomes a decision node with two new
    leaves beneath it) or a PRUNE (the parent of two leaves becomes a leaf and
    its children are deleted).

    Parameters
    ----------
    key
        A jax random key.
    forest
        The `forest` field of a BART MCMC state.

    Returns
    -------
    The proposed move for each tree.
    """
    num_trees, _ = forest.leaf_tree.shape
    # one key per tree for grow, one per tree for prune, one for the uniforms
    keys = split(key, 1 + 2 * num_trees)

    # a grow move and a prune move are both computed for every tree
    grow_keys = keys.pop(num_trees)
    grow_moves = propose_grow_moves(
        grow_keys,
        forest.var_tree,
        forest.split_tree,
        forest.affluence_tree,
        forest.max_split,
        forest.blocked_vars,
        forest.p_nonterminal,
        forest.p_propose_grow,
        forest.log_s,
    )
    prune_keys = keys.pop(num_trees)
    prune_moves = propose_prune_moves(
        prune_keys,
        forest.split_tree,
        grow_moves.affluence_tree,
        forest.p_nonterminal,
        forest.p_propose_grow,
    )

    u, exp1mlogu = random.uniform(keys.pop(), (2, num_trees))

    # pick grow with probability 1/2 if both moves are possible, otherwise
    # pick the only possible one
    can_grow = grow_moves.allowed
    can_prune = prune_moves.allowed
    p_grow = jnp.where(can_grow & can_prune, 0.5, can_grow)
    grow = u < p_grow  # strict inequality because u is in [0, 1)

    # indices of the node to modify and of its children
    node = jnp.where(grow, grow_moves.node, prune_moves.node)
    left = node << 1
    right = left + 1

    partial_ratio = jnp.where(
        grow, grow_moves.partial_ratio, prune_moves.partial_ratio
    )

    # 1 - exp1mlogu is in (0, 1], so its logarithm is in (-oo, 0]
    logu = jnp.log1p(-exp1mlogu)

    return Moves(
        allowed=can_grow | can_prune,
        grow=grow,
        num_growable=grow_moves.num_growable,
        node=node,
        left=left,
        right=right,
        partial_ratio=partial_ratio,
        log_trans_prior_ratio=None,  # will be set in complete_ratio
        grow_var=grow_moves.var,
        grow_split=grow_moves.split,
        # var_tree does not need to be updated if prune
        var_tree=grow_moves.var_tree,
        # affluence_tree is updated for both moves unconditionally, prune last
        affluence_tree=prune_moves.affluence_tree,
        logu=logu,
        acc=None,  # will be set in accept_moves_sequential_stage
        to_prune=None,  # will be set in accept_moves_sequential_stage
    )

655 

656 

class GrowMoves(Module):
    """
    The grow move proposed for each tree.

    Parameters
    ----------
    allowed
        Whether the move is allowed for proposal.
    num_growable
        Number of leaves that can be proposed for grow.
    node
        Index of the leaf to grow. ``2 ** d`` if there are no growable
        leaves.
    var
    split
        Decision axis and boundary of the new rule.
    partial_ratio
        A factor of the Metropolis-Hastings ratio of the move. It lacks
        the likelihood ratio and the probability of proposing the prune
        move.
    var_tree
        Updated decision axes of the tree.
    affluence_tree
        A partially updated `affluence_tree`, where each new leaf that would
        be produced is marked `True` if it would have available decision
        rules.
    """

    # NOTE: the declaration order below defines the positional order of the
    # generated constructor; do not reorder.
    allowed: Bool[Array, ' num_trees']
    num_growable: UInt[Array, ' num_trees']
    node: UInt[Array, ' num_trees']
    var: UInt[Array, ' num_trees']
    split: UInt[Array, ' num_trees']
    partial_ratio: Float32[Array, ' num_trees']
    var_tree: UInt[Array, 'num_trees 2**(d-1)']
    affluence_tree: Bool[Array, 'num_trees 2**(d-1)']

692 

693 

@partial(vmap_nodoc, in_axes=(0, 0, 0, 0, None, None, None, None, None))
def propose_grow_moves(
    key: Key[Array, ' num_trees'],
    var_tree: UInt[Array, 'num_trees 2**(d-1)'],
    split_tree: UInt[Array, 'num_trees 2**(d-1)'],
    affluence_tree: Bool[Array, 'num_trees 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    blocked_vars: UInt[Array, ' k'] | None,
    p_nonterminal: Float32[Array, ' 2**d'],
    p_propose_grow: Float32[Array, ' 2**(d-1)'],
    log_s: Float32[Array, ' p'] | None,
) -> GrowMoves:
    """
    Propose a GROW move for each tree.

    A GROW move picks a leaf node and converts it to a non-terminal node with
    two leaf children.

    Parameters
    ----------
    key
        A jax random key.
    var_tree
        The splitting axes of the tree.
    split_tree
        The splitting points of the tree.
    affluence_tree
        Whether each leaf has enough points to be grown.
    max_split
        The maximum split index for each variable.
    blocked_vars
        The indices of the variables that have no available cutpoints. The
        data type is unsigned, like `Forest.blocked_vars`.
    p_nonterminal
        The a priori probability of a node to be nonterminal conditional on the
        ancestors, including at the maximum depth where it should be zero.
    p_propose_grow
        The unnormalized probability of choosing a leaf to grow.
    log_s
        Unnormalized log-probability used to choose a variable to split on
        amongst the available ones.

    Returns
    -------
    An object representing the proposed move.

    Notes
    -----
    The annotations describe the arguments of the vectorized function; the
    body below is written for a single tree, and the first four arguments are
    mapped over their leading axis.

    The move is not proposed if each leaf is already at maximum depth, or has
    less datapoints than the requested threshold `min_points_per_decision_node`,
    or it does not have any available decision rules given its ancestors. This
    is marked by setting `allowed` to `False` and `num_growable` to 0.
    """
    keys = split(key, 3)

    leaf_to_grow, num_growable, prob_choose, num_prunable = choose_leaf(
        keys.pop(), split_tree, affluence_tree, p_propose_grow
    )

    # sample a decision rule
    var, num_available_var = choose_variable(
        keys.pop(), var_tree, split_tree, max_split, leaf_to_grow, blocked_vars, log_s
    )
    split_idx, l, r = choose_split(
        keys.pop(), var, var_tree, split_tree, max_split, leaf_to_grow
    )

    # determine if the new leaves would have available decision rules; if the
    # move is blocked, these values may not make sense
    left_growable = right_growable = num_available_var > 1
    left_growable |= l < split_idx
    right_growable |= split_idx + 1 < r
    left = leaf_to_grow << 1
    right = left + 1
    affluence_tree = affluence_tree.at[left].set(left_growable)
    affluence_tree = affluence_tree.at[right].set(right_growable)

    ratio = compute_partial_ratio(
        prob_choose, num_prunable, p_nonterminal, leaf_to_grow
    )

    return GrowMoves(
        allowed=num_growable > 0,
        num_growable=num_growable,
        node=leaf_to_grow,
        var=var,
        split=split_idx,
        partial_ratio=ratio,
        var_tree=var_tree.at[leaf_to_grow].set(var.astype(var_tree.dtype)),
        affluence_tree=affluence_tree,
    )

784 

785 

def choose_leaf(
    key: Key[Array, ''],
    split_tree: UInt[Array, ' 2**(d-1)'],
    affluence_tree: Bool[Array, ' 2**(d-1)'],
    p_propose_grow: Float32[Array, ' 2**(d-1)'],
) -> tuple[Int32[Array, ''], Int32[Array, ''], Float32[Array, ''], Int32[Array, '']]:
    """
    Pick at random the leaf of a tree to be grown.

    Parameters
    ----------
    key
        A jax random key.
    split_tree
        The splitting points of the tree.
    affluence_tree
        Marks the leaves that can be grown.
        NOTE(review): the exact datapoint threshold encoded by this mask
        (`min_points_per_leaf` vs. `min_points_per_decision_node`) is not
        determined in this function -- confirm against the code that sets it.
    p_propose_grow
        The unnormalized probability of choosing a leaf to grow.

    Returns
    -------
    leaf_to_grow : Int32[Array, '']
        The index of the leaf to grow. If ``num_growable == 0``, return
        ``2 ** d``.
    num_growable : Int32[Array, '']
        The number of leaves that can be grown, as determined by
        `growable_leaves`.
    prob_choose : Float32[Array, '']
        The (normalized) probability that this function had to choose that
        specific leaf, given the arguments.
    num_prunable : Int32[Array, '']
        The number of leaf parents that could be pruned, after converting the
        selected leaf to a non-terminal node.
    """
    growable = growable_leaves(split_tree, affluence_tree)
    num_growable = jnp.count_nonzero(growable)

    # sample amongst the growable leaves only
    weights = jnp.where(growable, p_propose_grow, 0)
    sampled_leaf, total_weight = categorical(key, weights)

    # twice the size of `split_tree` is the placeholder for "no leaf"
    leaf_to_grow = jnp.where(num_growable, sampled_leaf, 2 * split_tree.size)

    # replace a null normalization with 1 to avoid dividing by zero
    normalization = jnp.where(total_weight, total_weight, 1)
    prob_choose = weights[leaf_to_grow] / normalization

    # count the leaf parents in the tree as it would be after the grow move
    grown_split_tree = split_tree.at[leaf_to_grow].set(1)
    num_prunable = jnp.count_nonzero(grove.is_leaves_parent(grown_split_tree))

    return leaf_to_grow, num_growable, prob_choose, num_prunable

831 

832 

def growable_leaves(
    split_tree: UInt[Array, ' 2**(d-1)'], affluence_tree: Bool[Array, ' 2**(d-1)']
) -> Bool[Array, ' 2**(d-1)']:
    """
    Compute the mask of the leaf nodes that can be proposed for growth.

    A leaf qualifies if it is not at the bottom level, has available
    decision rules given its ancestors, and has at least
    `min_points_per_decision_node` points.

    Parameters
    ----------
    split_tree
        The splitting points of the tree.
    affluence_tree
        Marks leaves that can be grown.

    Returns
    -------
    The mask indicating the leaf nodes that can be proposed to grow.

    Notes
    -----
    `affluence_tree` alone is not sufficient because it can be "dirty", i.e.,
    mark unused nodes as `True`; `split_tree` is used to restrict the mask to
    the actual leaves.
    """
    is_leaf = grove.is_actual_leaf(split_tree)
    return is_leaf & affluence_tree

860 

861 

def categorical(
    key: Key[Array, ''], distr: Float32[Array, ' n']
) -> tuple[Int32[Array, ''], Float32[Array, '']]:
    """
    Draw an index with probability proportional to the entries of `distr`.

    Parameters
    ----------
    key
        A jax random key.
    distr
        An unnormalized probability distribution.

    Returns
    -------
    u : Int32[Array, '']
        A random integer in the range ``[0, n)``. If all probabilities are zero,
        return ``n``.
    norm : Float32[Array, '']
        The sum of `distr`.

    Notes
    -----
    The draw inverts the cumulative sum of `distr` instead of using the Gumbel
    trick, so it's ok only for small ranges with probabilities well greater
    than 0.
    """
    cumulative = jnp.cumsum(distr)
    norm = cumulative[-1]
    # pick a uniform point on [0, norm); the selected category is the first one
    # whose cumulative mass is strictly greater than the point
    point = random.uniform(key, (), cumulative.dtype, 0, norm)
    index = jnp.searchsorted(cumulative, point, side='right')
    return index, norm

891 

892 

def choose_variable(
    key: Key[Array, ''],
    var_tree: UInt[Array, ' 2**(d-1)'],
    split_tree: UInt[Array, ' 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    leaf_index: Int32[Array, ''],
    blocked_vars: Int32[Array, ' k'] | None,
    log_s: Float32[Array, ' p'] | None,
) -> tuple[Int32[Array, ''], Int32[Array, '']]:
    """
    Pick the splitting variable for a leaf that is being turned into a decision node.

    Parameters
    ----------
    key
        A jax random key.
    var_tree
        The variable indices of the tree.
    split_tree
        The splitting points of the tree.
    max_split
        The maximum split index for each variable.
    leaf_index
        The index of the leaf to grow.
    blocked_vars
        The indices of the variables that have no available cutpoints. If
        `None`, all variables are assumed unblocked.
    log_s
        The logarithm of the prior probability for choosing a variable. If
        `None`, use a uniform distribution.

    Returns
    -------
    var : Int32[Array, '']
        The index of the variable to split on.
    num_available_var : Int32[Array, '']
        The number of variables with available decision rules `var` was chosen
        from.
    """
    # variables that can not be picked: those exhausted by the ancestors of the
    # leaf, plus the ones blocked a priori
    excluded = fully_used_variables(var_tree, split_tree, max_split, leaf_index)
    if blocked_vars is not None:
        excluded = jnp.concatenate([excluded, blocked_vars])

    if log_s is not None:
        # weighted draw according to the variable selection prior
        return categorical_exclude(key, log_s, excluded)

    # uniform draw over the variables left
    return randint_exclude(key, max_split.size, excluded)

940 

941 

def fully_used_variables(
    var_tree: UInt[Array, ' 2**(d-1)'],
    split_tree: UInt[Array, ' 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    leaf_index: Int32[Array, ''],
) -> UInt[Array, ' d-2']:
    """
    List the ancestor variables of a node that have no split left to use.

    Parameters
    ----------
    var_tree
        The variable indices of the tree.
    split_tree
        The splitting points of the tree.
    max_split
        The maximum split index for each variable.
    leaf_index
        The index of the node, assumed to be valid for `var_tree`.

    Returns
    -------
    The indices of the variables that have an empty split range.

    Notes
    -----
    The number of fully used variables is not known in advance, so the output
    has one slot per potential ancestor. Slots that do not hold a fully used
    variable are filled with `p`. The fill values are not guaranteed to be
    placed in any particular order, and variables may appear more than once.
    """
    candidates = ancestor_variables(var_tree, max_split, leaf_index)

    # measure the allowed split range [l, r) of each candidate variable
    ranges_for = jax.vmap(split_range, in_axes=(None, None, None, None, 0))
    lower, upper = ranges_for(var_tree, split_tree, max_split, leaf_index, candidates)
    has_no_split_left = (upper - lower) == 0

    # the type of `candidates` is already sufficient to hold max_split.size,
    # see ancestor_variables()
    return jnp.where(has_no_split_left, candidates, max_split.size)

979 

980 

def ancestor_variables(
    var_tree: UInt[Array, ' 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    node_index: Int32[Array, ''],
) -> UInt[Array, ' d-2']:
    """
    Collect the splitting variables found along the path from the root to a node.

    Parameters
    ----------
    var_tree
        The variable indices of the tree.
    max_split
        The maximum split index for each variable. Used only to get `p`.
    node_index
        The index of the node, assumed to be valid for `var_tree`.

    Returns
    -------
    The variable indices of the ancestors of the node.

    Notes
    -----
    The ancestors are the nodes going from the root to the parent of the node.
    The number of ancestors is not known at tracing time; unused spots in the
    output array are filled with `p`.
    """
    num_slots = grove.tree_depth(var_tree) - 1
    out_dtype = minimal_unsigned_dtype(max_split.size)

    def climb(index, _):
        parent = index >> 1
        # parent == 0 means the walk went past the root: emit the fill value
        var = jnp.where(parent, var_tree[parent], max_split.size)
        return parent, var

    _, vars_bottom_up = lax.scan(climb, node_index, None, num_slots)

    # the walk goes from the parent up to the root, so the closest ancestor
    # must end up in the last slot and the fill values in the first ones
    return vars_bottom_up[::-1].astype(out_dtype)

1022 

1023 

def split_range(
    var_tree: UInt[Array, ' 2**(d-1)'],
    split_tree: UInt[Array, ' 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    node_index: Int32[Array, ''],
    ref_var: Int32[Array, ''],
) -> tuple[Int32[Array, ''], Int32[Array, '']]:
    """
    Compute which splits on a variable are still usable at a given node.

    The range starts as all the splits of `ref_var` and is narrowed by every
    ancestor of the node that splits on the same variable.

    Parameters
    ----------
    var_tree
        The variable indices of the tree.
    split_tree
        The splitting points of the tree.
    max_split
        The maximum split index for each variable.
    node_index
        The index of the node, assumed to be valid for `var_tree`.
    ref_var
        The variable for which to measure the split range.

    Returns
    -------
    The range of allowed splits as [l, r). If `ref_var` is out of bounds, l=r=1.
    """
    num_steps = grove.tree_depth(var_tree) - 1

    # an out of bounds variable reads max_split as 0, yielding an empty range
    num_splits = max_split.at[ref_var].get(mode='fill', fill_value=0)
    start = jnp.int32(0), 1 + num_splits.astype(jnp.int32), node_index

    def climb(state, _):
        lower, upper, index = state
        came_from_right = (index & 1).astype(bool)
        parent = index >> 1
        cut = split_tree[parent]
        # parent == 0 means the walk is already past the root
        same_var = (var_tree[parent] == ref_var) & parent.astype(bool)
        # being in the right subtree of a split raises the lower bound, being
        # in the left subtree lowers the upper bound
        lower = jnp.where(same_var & came_from_right, jnp.maximum(lower, cut), lower)
        upper = jnp.where(same_var & ~came_from_right, jnp.minimum(upper, cut), upper)
        return (lower, upper, parent), None

    (lower, upper, _), _ = lax.scan(climb, start, None, num_steps)
    return lower + 1, upper

1069 

1070 

def randint_exclude(
    key: Key[Array, ''], sup: int | Integer[Array, ''], exclude: Integer[Array, ' n']
) -> tuple[Int32[Array, ''], Int32[Array, '']]:
    """
    Draw a uniform random integer from a range with some values removed.

    Parameters
    ----------
    key
        A jax random key.
    sup
        The exclusive upper bound of the range.
    exclude
        The values to exclude from the range. Values greater than or equal to
        `sup` are ignored. Values can appear more than once.

    Returns
    -------
    u : Int32[Array, '']
        A random integer `u` in the range ``[0, sup)`` such that ``u not in
        exclude``.
    num_allowed : Int32[Array, '']
        The number of integers in the range that were not excluded.

    Notes
    -----
    If all values in the range are excluded, return `sup`.
    """
    sorted_exclude, num_allowed = _process_exclude(sup, exclude)

    # draw the rank of the result amongst the allowed values
    rank = random.randint(key, (), 0, num_allowed)

    def skip_excluded(value, excluded):
        # going through the excluded values in increasing order, each one that
        # is not above the current value shifts the value up by one
        shifted = jnp.where(excluded <= value, value + 1, value)
        return shifted, None

    u, _ = lax.scan(skip_excluded, rank, sorted_exclude)
    return u, num_allowed

1107 

1108 

def _process_exclude(sup, exclude):
    """
    Deduplicate a list of excluded values and count the values left in a range.

    Parameters
    ----------
    sup
        The exclusive upper bound of the range ``[0, sup)``.
    exclude
        The values to exclude. Values can appear more than once.

    Returns
    -------
    exclude
        The distinct excluded values as returned by `jnp.unique`, with the same
        size as the input; the spots left over by duplicates are filled with
        `sup`.
    num_allowed
        `sup` minus the number of distinct excluded values below `sup`.
    """
    deduplicated = jnp.unique(exclude, size=exclude.size, fill_value=sup)
    num_excluded = jnp.count_nonzero(deduplicated < sup)
    return deduplicated, sup - num_excluded

1113 

1114 

def categorical_exclude(
    key: Key[Array, ''], logits: Float32[Array, ' k'], exclude: Integer[Array, ' n']
) -> tuple[Int32[Array, ''], Int32[Array, '']]:
    """
    Sample a category from unnormalized log-probabilities, skipping some categories.

    Parameters
    ----------
    key
        A jax random key.
    logits
        The unnormalized log-probabilities of each category.
    exclude
        The values to exclude from the range [0, k). Values greater than or
        equal to `logits.size` are ignored. Values can appear more than once.

    Returns
    -------
    u : Int32[Array, '']
        A random integer in the range ``[0, k)`` such that ``u not in exclude``.
    num_allowed : Int32[Array, '']
        The number of integers in the range that were not excluded.

    Notes
    -----
    If all values in the range are excluded, the result is unspecified.
    """
    excluded, num_allowed = _process_exclude(logits.size, exclude)

    # push the excluded categories to the lowest finite logit rather than to
    # -inf
    lowest_logit = jnp.finfo(logits.dtype).min
    masked_logits = logits.at[excluded].set(lowest_logit)

    return random.categorical(key, masked_logits), num_allowed

1147 

1148 

def choose_split(
    key: Key[Array, ''],
    var: Int32[Array, ''],
    var_tree: UInt[Array, ' 2**(d-1)'],
    split_tree: UInt[Array, ' 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    leaf_index: Int32[Array, ''],
) -> tuple[Int32[Array, ''], Int32[Array, ''], Int32[Array, '']]:
    """
    Draw the cutpoint for a leaf that is being turned into a decision node.

    Parameters
    ----------
    key
        A jax random key.
    var
        The variable to split on.
    var_tree
        The splitting axes of the tree. Does not need to already contain `var`
        at `leaf_index`.
    split_tree
        The splitting points of the tree.
    max_split
        The maximum split index for each variable.
    leaf_index
        The index of the leaf to grow.

    Returns
    -------
    split : Int32[Array, '']
        The cutpoint.
    l : Int32[Array, '']
    r : Int32[Array, '']
        The integer range `split` was drawn from is [l, r).

    Notes
    -----
    If `var` is out of bounds, or if the available split range on that variable
    is empty, return 0.
    """
    l, r = split_range(var_tree, split_tree, max_split, leaf_index, var)
    range_is_nonempty = l < r
    drawn = random.randint(key, (), l, r)
    # 0 is the marker for "no split available"
    split = jnp.where(range_is_nonempty, drawn, 0)
    return split, l, r

1191 

1192 

def compute_partial_ratio(
    prob_choose: Float32[Array, ''],
    num_prunable: Int32[Array, ''],
    p_nonterminal: Float32[Array, ' 2**d'],
    leaf_to_grow: Int32[Array, ''],
) -> Float32[Array, '']:
    """
    Evaluate the part of the grow Metropolis-Hastings ratio known at proposal time.

    The result is the product of the (partial) transition ratio and the
    (partial) prior ratio of a grow move.

    Parameters
    ----------
    prob_choose
        The probability that the leaf had to be chosen amongst the growable
        leaves.
    num_prunable
        The number of leaf parents that could be pruned, after converting the
        leaf to be grown to a non-terminal node.
    p_nonterminal
        The a priori probability of each node being nonterminal conditional on
        its ancestors.
    leaf_to_grow
        The index of the leaf to grow.

    Returns
    -------
    The partial transition ratio times the prior ratio.

    Notes
    -----
    The transition ratio is P(new tree => old tree) / P(old tree => new tree).
    The "partial" transition ratio returned is missing the factor P(propose
    prune) in the numerator. The prior ratio is P(new tree) / P(old tree). The
    "partial" prior ratio is missing the factor P(children are leaves).
    """
    # the two ratios also contain factors num_available_split *
    # num_available_var * s[var], but they cancel out

    # p_prune and 1 - p_nonterminal[child] * I(is the child growable) can't be
    # computed here because they need the count trees, which are computed in the
    # acceptance phase

    # the leaf to grow is the root <---> the initial tree is just a root, in
    # which case prune is not possible and grow is proposed with certainty
    growing_root = leaf_to_grow == 1
    p_grow = jnp.where(growing_root, 1, 0.5)

    # denominator of the transition ratio: P(old tree => new tree), times
    # num_prunable to account for the choice of the node in the reverse move
    inv_trans_ratio = p_grow * prob_choose * num_prunable
    # a vanishing denominator marks a move that is not allowed, replace it to
    # avoid a division by zero
    safe_inv_trans_ratio = jnp.where(inv_trans_ratio, inv_trans_ratio, 1)

    # .at.get because if leaf_to_grow is out of bounds (move not allowed), this
    # would produce a 0 and then an inf when `complete_ratio` takes the log
    pnt = p_nonterminal.at[leaf_to_grow].get(mode='fill', fill_value=0.5)
    tree_ratio = pnt / (1 - pnt)

    return tree_ratio / safe_inv_trans_ratio

1247 

1248 

class PruneMoves(Module):
    """
    Represent a proposed prune move for each tree.

    Parameters
    ----------
    allowed
        Whether the move is possible.
    node
        The index of the node to prune. ``2 ** d`` if no node can be pruned.
    partial_ratio
        A factor of the Metropolis-Hastings ratio of the move. It lacks the
        likelihood ratio, the probability of proposing the prune move, and the
        prior probability that the children of the node to prune are leaves.
        This ratio is inverted, and is meant to be inverted back in
        `accept_moves_and_sample_leaves`.
    affluence_tree
        A partially updated `affluence_tree`, marking the node to prune as
        growable.
    """

    allowed: Bool[Array, ' num_trees']
    node: UInt[Array, ' num_trees']
    partial_ratio: Float32[Array, ' num_trees']
    affluence_tree: Bool[Array, 'num_trees 2**(d-1)']

1271 

1272 

@partial(vmap_nodoc, in_axes=(0, 0, 0, None, None))
def propose_prune_moves(
    key: Key[Array, ''],
    split_tree: UInt[Array, ' 2**(d-1)'],
    affluence_tree: Bool[Array, ' 2**(d-1)'],
    p_nonterminal: Float32[Array, ' 2**d'],
    p_propose_grow: Float32[Array, ' 2**(d-1)'],
) -> PruneMoves:
    """
    Propose a prune move on a tree, as part of the BART MCMC.

    The function is written for a single tree and vectorized over the first
    three arguments.

    Parameters
    ----------
    key
        A jax random key.
    split_tree
        The splitting points of the tree.
    affluence_tree
        Whether each leaf can be grown.
    p_nonterminal
        The a priori probability of a node to be nonterminal conditional on
        the ancestors, including at the maximum depth where it should be zero.
    p_propose_grow
        The unnormalized probability of choosing a leaf to grow.

    Returns
    -------
    An object representing the proposed moves.
    """
    chosen = choose_leaf_parent(key, split_tree, affluence_tree, p_propose_grow)
    node_to_prune, num_prunable, prob_choose, updated_affluence_tree = chosen

    # the move is allowed iff the tree is not a root, i.e., the root node
    # (index 1) has a split
    tree_is_not_root = split_tree[1].astype(bool)

    # the ratio is computed as for the reverse (grow) move
    partial_ratio = compute_partial_ratio(
        prob_choose, num_prunable, p_nonterminal, node_to_prune
    )

    return PruneMoves(
        allowed=tree_is_not_root,
        node=node_to_prune,
        partial_ratio=partial_ratio,
        affluence_tree=updated_affluence_tree,
    )

1316 

1317 

def choose_leaf_parent(
    key: Key[Array, ''],
    split_tree: UInt[Array, ' 2**(d-1)'],
    affluence_tree: Bool[Array, ' 2**(d-1)'],
    p_propose_grow: Float32[Array, ' 2**(d-1)'],
) -> tuple[
    Int32[Array, ''],
    Int32[Array, ''],
    Float32[Array, ''],
    Bool[Array, ' 2**(d-1)'],
]:
    """
    Pick a non-terminal node with leaf children to prune in a tree.

    Parameters
    ----------
    key
        A jax random key.
    split_tree
        The splitting points of the tree.
    affluence_tree
        Whether a leaf has enough points to be grown.
    p_propose_grow
        The unnormalized probability of choosing a leaf to grow.

    Returns
    -------
    node_to_prune : Int32[Array, '']
        The index of the node to prune. If ``num_prunable == 0``, return
        ``2 ** d``.
    num_prunable : Int32[Array, '']
        The number of leaf parents that could be pruned.
    prob_choose : Float32[Array, '']
        The (normalized) probability that `choose_leaf` would choose
        `node_to_prune` as leaf to grow, if passed the tree where
        `node_to_prune` had been pruned.
    affluence_tree : Bool[Array, ' 2**(d-1)']
        A partially updated `affluence_tree`, marking the node to prune as
        growable. It has the same shape as the input `affluence_tree`.
    """
    # sample uniformly a node to prune amongst the parents of two leaves
    is_prunable = grove.is_leaves_parent(split_tree)
    num_prunable = jnp.count_nonzero(is_prunable)
    node_to_prune = randint_masked(key, is_prunable)
    # if nothing can be pruned, use an out of bounds index as marker
    node_to_prune = jnp.where(num_prunable, node_to_prune, 2 * split_tree.size)

    # build the tree as it would be after the prune: the node becomes a leaf
    # (split 0) that can be grown back; these are the default
    # `.at[...].set(...)` updates, so an out of bounds `node_to_prune` leaves
    # the arrays unchanged
    pruned_split_tree = split_tree.at[node_to_prune].set(0)
    affluence_tree = affluence_tree.at[node_to_prune].set(True)

    # probability of picking the pruned node in the reverse (grow) move
    is_growable_leaf = growable_leaves(pruned_split_tree, affluence_tree)
    distr_norm = jnp.sum(p_propose_grow, where=is_growable_leaf)
    prob_choose = p_propose_grow.at[node_to_prune].get(mode='fill', fill_value=0)
    # guard against a division by zero when no leaf is growable
    prob_choose = prob_choose / jnp.where(distr_norm, distr_norm, 1)

    return node_to_prune, num_prunable, prob_choose, affluence_tree

1373 

1374 

def randint_masked(key: Key[Array, ''], mask: Bool[Array, ' n']) -> Int32[Array, '']:
    """
    Draw uniformly at random one of the positions where a mask is set.

    Parameters
    ----------
    key
        A jax random key.
    mask
        The mask indicating the allowed values.

    Returns
    -------
    A random integer in the range ``[0, n)`` such that ``mask[u] == True``.

    Notes
    -----
    If all values in the mask are `False`, return `n`.
    """
    # allowed_up_to[i] = number of allowed values in [0, i]
    allowed_up_to = jnp.cumsum(mask)
    num_allowed = allowed_up_to[-1]

    # draw the rank of the result amongst the allowed values, then find the
    # first position that has more than `rank` allowed values up to it
    rank = random.randint(key, (), 0, num_allowed)
    return jnp.searchsorted(allowed_up_to, rank, side='right')

1397 

1398 

def accept_moves_and_sample_leaves(
    key: Key[Array, ''], bart: State, moves: Moves
) -> State:
    """
    Decide which proposed moves to accept and draw new leaf values.

    The work is split in three stages: one that can be done in parallel across
    trees, a sequential one, and a final one.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A valid BART mcmc state.
    moves
        The proposed moves, see `propose_moves`.

    Returns
    -------
    A new (valid) BART mcmc state.
    """
    parallel_stage_out = accept_moves_parallel_stage(key, bart, moves)
    new_bart, new_moves = accept_moves_sequential_stage(parallel_stage_out)
    return accept_moves_final_stage(new_bart, new_moves)

1421 

1422 

class Counts(Module):
    """
    Number of datapoints in the nodes involved in proposed moves for each tree.

    Each field has one entry per tree. Objects of this class are created by
    `compute_count_trees`.

    Parameters
    ----------
    left
        Number of datapoints in the left child.
    right
        Number of datapoints in the right child.
    total
        Number of datapoints in the parent (``= left + right``).
    """

    left: UInt[Array, ' num_trees']
    right: UInt[Array, ' num_trees']
    total: UInt[Array, ' num_trees']

1440 

1441 

class Precs(Module):
    """
    Likelihood precision scale in the nodes involved in proposed moves for each tree.

    The "likelihood precision scale" of a tree node is the sum of the inverse
    squared error scales of the datapoints selected by the node.

    Each field has one entry per tree.

    Parameters
    ----------
    left
        Likelihood precision scale in the left child.
    right
        Likelihood precision scale in the right child.
    total
        Likelihood precision scale in the parent (``= left + right``).
    """

    left: Float32[Array, ' num_trees']
    right: Float32[Array, ' num_trees']
    total: Float32[Array, ' num_trees']

1462 

1463 

class PreLkV(Module):
    """
    Non-sequential terms of the likelihood ratio for each tree.

    These terms can be computed in parallel across trees.

    Supports both scalar and multivariate models. In the scalar case, the
    variance terms are 1D arrays of shape (num_trees,); in the multivariate
    case, they are arrays of covariance matrices with shape (num_trees, k, k).

    Parameters
    ----------
    sigma2_left
        In the scalar case, this is the noise variance in the left child of the
        leaves grown or pruned by the moves.
        In the multivariate case, this is the intermediate matrix in the
        quadratic form representing the contribution of the left child to the
        exponential term.
    sigma2_right
        In the scalar case, this is the noise variance in the right child of
        the leaves grown or pruned by the moves.
        In the multivariate case, this is the intermediate matrix in the
        quadratic form representing the contribution of the right child to the
        exponential term.
    sigma2_total
        In the scalar case, this is the noise variance in the total of the
        leaves grown or pruned by the moves.
        In the multivariate case, this is the intermediate matrix in the
        quadratic form representing the contribution of the parent node to the
        exponential term.
    sqrt_term
        The **logarithm** of the square root term of the likelihood ratio.
    """

    sigma2_left: Float32[Array, ' num_trees'] | Float32[Array, 'num_trees k k']
    sigma2_right: Float32[Array, ' num_trees'] | Float32[Array, 'num_trees k k']
    sigma2_total: Float32[Array, ' num_trees'] | Float32[Array, 'num_trees k k']
    sqrt_term: Float32[Array, ' num_trees']

1499 

1500 

class PreLk(Module):
    """
    Non-sequential terms of the likelihood ratio shared by all trees.

    Parameters
    ----------
    exp_factor
        The factor to multiply the likelihood ratio by, shared by all trees.
        It is a scalar.
    """

    exp_factor: Float32[Array, '']

1512 

1513 

class PreLf(Module):
    """
    Pre-computed terms used to sample leaves from their posterior.

    These terms can be computed in parallel across trees.

    Supports both scalar and multivariate models. In the scalar case, the
    arrays have shape (num_trees, 2**d); in the multivariate case,
    `mean_factor` has shape (num_trees, 2**d, k, k) and `centered_leaves` has
    shape (num_trees, 2**d, k).

    Parameters
    ----------
    mean_factor
        The factor to be multiplied by the sum of the scaled residuals to
        obtain the posterior mean.
    centered_leaves
        The mean-zero normal values to be added to the posterior mean to
        obtain the posterior leaf samples.
    """

    mean_factor: Float32[Array, 'num_trees 2**d'] | Float32[Array, 'num_trees 2**d k k']
    centered_leaves: (
        Float32[Array, 'num_trees 2**d'] | Float32[Array, 'num_trees 2**d k']
    )

1538 

1539 

class ParallelStageOut(Module):
    """
    The output of `accept_moves_parallel_stage`.

    Parameters
    ----------
    bart
        A partially updated BART mcmc state.
    moves
        The proposed moves, with `partial_ratio` set to `None` and
        `log_trans_prior_ratio` set to its final value.
    prec_trees
        The likelihood precision scale in each potential or actual leaf node. If
        there is no precision scale, this is the number of points in each leaf.
    move_precs
        The likelihood precision scale in each node modified by the moves. If
        `bart.prec_scale` is not set, this is set to the counts of the number
        of points in the nodes modified by the moves (a `Counts` object).
    prelkv
    prelk
    prelf
        Objects with pre-computed terms of the likelihood ratios and leaf
        samples.

    Notes
    -----
    The counts of points in the nodes modified by the moves are not stored in
    a separate field: they are available through `move_precs` only when
    `bart.prec_scale` is not set.
    """

    bart: State
    moves: Moves
    prec_trees: Float32[Array, 'num_trees 2**d'] | Int32[Array, 'num_trees 2**d']
    move_precs: Precs | Counts
    prelkv: PreLkV
    prelk: PreLk
    prelf: PreLf

1575 

1576 

def accept_moves_parallel_stage(
    key: Key[Array, ''], bart: State, moves: Moves
) -> ParallelStageOut:
    """
    Pre-compute quantities used to accept moves, in parallel across trees.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state. `bart.sigma2` must be set.
    moves
        The proposed moves, see `propose_moves`.

    Returns
    -------
    An object with all that could be done in parallel.
    """
    # where the move is grow, modify the state like the move was accepted
    bart = replace(
        bart,
        forest=replace(
            bart.forest,
            var_tree=moves.var_tree,
            leaf_indices=apply_grow_to_indices(moves, bart.forest.leaf_indices, bart.X),
            leaf_tree=adapt_leaf_trees_to_grow_indices(bart.forest.leaf_tree, moves),
        ),
    )

    # count number of datapoints per leaf; the counts are needed either to
    # enforce the minimum number of points per node, or to stand in for the
    # precisions when there is no precision scale, so this condition is the
    # union of the conditions of the blocks below that read `count_trees` or
    # `move_counts`
    if (
        bart.forest.min_points_per_decision_node is not None
        or bart.forest.min_points_per_leaf is not None
        or bart.prec_scale is None
    ):
        count_trees, move_counts = compute_count_trees(
            bart.forest.leaf_indices, moves, bart.forest.count_batch_size
        )

    # mark which leaves & potential leaves have enough points to be grown
    if bart.forest.min_points_per_decision_node is not None:
        # keep only the first half of the nodes, to match the shape of
        # `var_tree` and `affluence_tree`
        count_half_trees = count_trees[:, : bart.forest.var_tree.shape[1]]
        moves = replace(
            moves,
            affluence_tree=moves.affluence_tree
            & (count_half_trees >= bart.forest.min_points_per_decision_node),
        )

    # copy updated affluence_tree to state
    bart = tree_at(lambda bart: bart.forest.affluence_tree, bart, moves.affluence_tree)

    # veto grow move if new leaves don't have enough datapoints
    if bart.forest.min_points_per_leaf is not None:
        moves = replace(
            moves,
            allowed=moves.allowed
            & (move_counts.left >= bart.forest.min_points_per_leaf)
            & (move_counts.right >= bart.forest.min_points_per_leaf),
        )

    # count number of datapoints per leaf, weighted by error precision scale
    if bart.prec_scale is None:
        # without a precision scale, the precisions are the plain counts
        prec_trees = count_trees
        move_precs = move_counts
    else:
        prec_trees, move_precs = compute_prec_trees(
            bart.prec_scale,
            bart.forest.leaf_indices,
            moves,
            bart.forest.count_batch_size,
        )
    assert move_precs is not None

    # compute some missing information about moves
    moves = complete_ratio(moves, bart.forest.p_nonterminal)
    # the ratios are saved in the state only if the state tracks the
    # log-likelihood
    save_ratios = bart.forest.log_likelihood is not None
    bart = replace(
        bart,
        forest=replace(
            bart.forest,
            grow_prop_count=jnp.sum(moves.grow),
            prune_prop_count=jnp.sum(moves.allowed & ~moves.grow),
            log_trans_prior=moves.log_trans_prior_ratio if save_ratios else None,
        ),
    )

    # pre-compute some likelihood ratio & posterior terms
    assert bart.sigma2 is not None  # `step` shall temporarily set it to 1
    prelkv, prelk = precompute_likelihood_terms(
        bart.sigma2, bart.forest.sigma_mu2, move_precs
    )
    prelf = precompute_leaf_terms(key, prec_trees, bart.sigma2, bart.forest.sigma_mu2)

    return ParallelStageOut(
        bart=bart,
        moves=moves,
        prec_trees=prec_trees,
        move_precs=move_precs,
        prelkv=prelkv,
        prelk=prelk,
        prelf=prelf,
    )

1680 

1681 

@partial(vmap_nodoc, in_axes=(0, 0, None))
def apply_grow_to_indices(
    moves: Moves, leaf_indices: UInt[Array, 'num_trees n'], X: UInt[Array, 'p n']
) -> UInt[Array, 'num_trees n']:
    """
    Reassign the datapoints of grown leaves to the new children.

    The function is written for a single tree and vectorized over `moves` and
    `leaf_indices`.

    Parameters
    ----------
    moves
        The proposed moves, see `propose_moves`.
    leaf_indices
        The index of the leaf each datapoint falls into.
    X
        The predictors matrix.

    Returns
    -------
    The updated leaf indices.
    """
    # a datapoint in the grown leaf goes to the left child (index 2 * node) or
    # to the right child (index 2 * node + 1) depending on the new decision rule
    goes_right = X[moves.grow_var, :] >= moves.grow_split
    new_indices = (moves.node.astype(leaf_indices.dtype) << 1) + goes_right

    # if the move is not a grow, target the size of the full tree, which is
    # not a valid node index, such that no datapoint is reassigned
    out_of_range = jnp.array(2 * moves.var_tree.size)
    target_leaf = jnp.where(moves.grow, moves.node, out_of_range)

    is_in_target = leaf_indices == target_leaf
    return jnp.where(is_in_target, new_indices, leaf_indices)

1709 

1710 

def compute_count_trees(
    leaf_indices: UInt[Array, 'num_trees n'], moves: Moves, batch_size: int | None
) -> tuple[Int32[Array, 'num_trees 2**d'], Counts]:
    """
    Tally the datapoints in each leaf and in the nodes touched by the moves.

    Parameters
    ----------
    leaf_indices
        The index of the leaf each datapoint falls into, with the deeper version
        of the tree (post-GROW, pre-PRUNE).
    moves
        The proposed moves, see `propose_moves`.
    batch_size
        The data batch size to use for the summation.

    Returns
    -------
    count_trees : Int32[Array, 'num_trees 2**d']
        The number of points in each potential or actual leaf node.
    counts : Counts
        The counts of the number of points in the leaves grown or pruned by the
        moves.
    """
    # the leaf trees have twice the nodes of the decision node trees
    num_trees, half_tree_size = moves.var_tree.shape
    tree_size = 2 * half_tree_size
    each_tree = jnp.arange(num_trees)

    count_trees = count_datapoints_per_leaf(leaf_indices, tree_size, batch_size)

    # count datapoints in nodes modified by move
    in_left = count_trees[each_tree, moves.left]
    in_right = count_trees[each_tree, moves.right]
    counts = Counts(left=in_left, right=in_right, total=in_left + in_right)

    # write count into non-leaf node
    count_trees = count_trees.at[each_tree, moves.node].set(counts.total)

    return count_trees, counts

1750 

1751 

def count_datapoints_per_leaf(
    leaf_indices: UInt[Array, 'num_trees n'], tree_size: int, batch_size: int | None
) -> Int32[Array, 'num_trees {tree_size}']:
    """
    Count the number of datapoints in each leaf.

    Parameters
    ----------
    leaf_indices
        The index of the leaf each datapoint falls into.
    tree_size
        The size of the leaf tree array (2 ** d).
    batch_size
        The data batch size to use for the summation. If `None`, the trees are
        processed one at a time with a scan; otherwise a batched
        implementation that handles all trees together is used.

    Returns
    -------
    The number of points in each leaf node, in an array with `tree_size`
    entries per tree.
    """
    if batch_size is None:
        return _count_scan(leaf_indices, tree_size)
    return _count_vec(leaf_indices, tree_size, batch_size)

1775 

1776 

def _count_scan(
    leaf_indices: UInt[Array, 'num_trees n'], tree_size: int
) -> Int32[Array, 'num_trees {tree_size}']:
    """Count datapoints per leaf, scanning over the trees one at a time."""

    def count_one_tree(carry, tree_leaf_indices):
        # each datapoint contributes 1 to the leaf it falls into; the counts
        # are accumulated as uint32
        per_leaf = _aggregate_scatter(1, tree_leaf_indices, tree_size, jnp.uint32)
        return carry, per_leaf

    return lax.scan(count_one_tree, None, leaf_indices)[1]

1785 

1786 

def _aggregate_scatter(
    values: Shaped[Array, '*'],
    indices: Integer[Array, '*'],
    size: int,
    dtype: jnp.dtype,
) -> Shaped[Array, ' {size}']:
    """Sum `values` into a zero-initialized array of length `size` at `indices`."""
    accumulator = jnp.zeros(size, dtype)
    return accumulator.at[indices].add(values)

1794 

1795 

def _count_vec(
    leaf_indices: UInt[Array, 'num_trees n'], tree_size: int, batch_size: int
) -> Int32[Array, 'num_trees 2**(d-1)']:
    """Count datapoints per leaf for all trees at once, in batches."""
    # uint16 is super-slow on gpu, don't use it even if n < 2^16
    count_dtype = jnp.uint32
    return _aggregate_batched_alltrees(
        1, leaf_indices, tree_size, count_dtype, batch_size
    )

1803 

1804 

def _aggregate_batched_alltrees(
    values: Shaped[Array, '*'],
    indices: UInt[Array, 'num_trees n'],
    size: int,
    dtype: jnp.dtype,
    batch_size: int,
) -> Shaped[Array, 'num_trees {size}']:
    """
    Sum `values` per tree and per index, accumulating in separate batches.

    Datapoints are assigned round-robin to ``ceil(n / batch_size)`` partial
    accumulators, which are summed together at the end.
    """
    num_trees, n = indices.shape
    full_batches, leftover = divmod(n, batch_size)
    nbatches = full_batches + bool(leftover)

    tree_rows = jnp.arange(num_trees)[:, None]
    batch_of_point = jnp.arange(n) % nbatches

    partial_sums = jnp.zeros((num_trees, size, nbatches), dtype)
    partial_sums = partial_sums.at[tree_rows, indices, batch_of_point].add(values)
    return partial_sums.sum(axis=2)

1822 

1823 

def compute_prec_trees(
    prec_scale: Float32[Array, ' n'],
    leaf_indices: UInt[Array, 'num_trees n'],
    moves: Moves,
    batch_size: int | None,
) -> tuple[Float32[Array, 'num_trees 2**d'], Precs]:
    """
    Sum the likelihood precision scale per leaf and in the moved nodes.

    Parameters
    ----------
    prec_scale
        The scale of the precision of the error on each datapoint.
    leaf_indices
        For each tree, the leaf each datapoint lands in, using the deeper
        version of the tree (post-GROW, pre-PRUNE).
    moves
        The proposed moves, see `propose_moves`.
    batch_size
        The data batch size to use for the summation.

    Returns
    -------
    prec_trees : Float32[Array, 'num_trees 2**d']
        The likelihood precision scale in each potential or actual leaf node.
    precs : Precs
        The likelihood precision scale in the left child, in the right child,
        and in both together, for the node each move acts upon.
    """
    num_trees, num_decision_nodes = moves.var_tree.shape
    rows = jnp.arange(num_trees)

    # the per-leaf arrays are twice as long as the decision-node arrays
    prec_trees = prec_per_leaf(
        prec_scale, leaf_indices, 2 * num_decision_nodes, batch_size
    )

    # precision in the two children involved in each move
    prec_left = prec_trees[rows, moves.left]
    prec_right = prec_trees[rows, moves.right]
    prec_total = prec_left + prec_right
    precs = Precs(left=prec_left, right=prec_right, total=prec_total)

    # the parent is not a leaf in the deeper tree, so fill its value by hand
    prec_trees = prec_trees.at[rows, moves.node].set(prec_total)

    return prec_trees, precs

1867 

1868 

def prec_per_leaf(
    prec_scale: Float32[Array, ' n'],
    leaf_indices: UInt[Array, 'num_trees n'],
    tree_size: int,
    batch_size: int | None,
) -> Float32[Array, 'num_trees {tree_size}']:
    """
    Sum the likelihood precision scale of the datapoints in each leaf.

    Parameters
    ----------
    prec_scale
        The scale of the precision of the error on each datapoint.
    leaf_indices
        The index of the leaf each datapoint falls into.
    tree_size
        The size of the leaf tree array (2 ** d).
    batch_size
        The data batch size to use for the summation. If `None`, the trees
        are processed one at a time with a scan; otherwise all trees are
        processed at once with batched accumulation.

    Returns
    -------
    The likelihood precision scale in each leaf node.
    """
    if batch_size is not None:
        return _prec_vec(prec_scale, leaf_indices, tree_size, batch_size)
    return _prec_scan(prec_scale, leaf_indices, tree_size)

1897 

1898 

def _prec_scan(
    prec_scale: Float32[Array, ' n'],
    leaf_indices: UInt[Array, 'num_trees n'],
    tree_size: int,
) -> Float32[Array, 'num_trees {tree_size}']:
    """Sum the precision scale per leaf, scanning over the trees one at a time."""

    def sum_one_tree(carry, tree_leaf_indices):
        per_leaf = _aggregate_scatter(
            prec_scale, tree_leaf_indices, tree_size, jnp.float32
        )
        return carry, per_leaf

    return lax.scan(sum_one_tree, None, leaf_indices)[1]

1911 

1912 

def _prec_vec(
    prec_scale: Float32[Array, ' n'],
    leaf_indices: UInt[Array, 'num_trees n'],
    tree_size: int,
    batch_size: int,
) -> Float32[Array, 'num_trees {tree_size}']:
    """Sum the precision scale per leaf for all trees at once, in batches."""
    accum_dtype = jnp.float32
    return _aggregate_batched_alltrees(
        prec_scale, leaf_indices, tree_size, accum_dtype, batch_size
    )

1922 

1923 

def complete_ratio(moves: Moves, p_nonterminal: Float32[Array, ' 2**d']) -> Moves:
    """
    Finish computing the non-likelihood part of the MH ratio.

    Two factors are multiplied into `moves.partial_ratio`: the probability of
    picking a prune move rather than a grow move in the inverse transition,
    and the a priori probability that both children are leaves.

    Parameters
    ----------
    moves
        The proposed moves. Must have already been updated to keep into account
        the thresholds on the number of datapoints per node, this happens in
        `accept_moves_parallel_stage`.
    p_nonterminal
        The a priori probability of each node being nonterminal conditional on
        its ancestors, including at the maximum depth where it should be zero.

    Returns
    -------
    The updated moves, with `partial_ratio=None` and `log_trans_prior_ratio` set.
    """
    num_trees, _ = moves.affluence_tree.shape
    rows = jnp.arange(num_trees)

    def child_growable(child):
        # out-of-bounds child indices are read as "not growable"
        return moves.affluence_tree.at[rows, child].get(
            mode='fill', fill_value=False
        )

    left_growable = child_growable(moves.left)
    right_growable = child_growable(moves.right)

    # probability of proposing a prune after a grow: a further grow is
    # possible if another leaf or one of the new children can be grown
    can_grow_again = (moves.num_growable >= 2) | left_growable | right_growable
    p_prune_after_grow = jnp.where(can_grow_again, 0.5, 1)

    # probability of proposing a prune in the state the prune starts from
    p_prune_before_prune = jnp.where(moves.num_growable, 0.5, 1)

    p_prune = jnp.where(moves.grow, p_prune_after_grow, p_prune_before_prune)

    # prior probability that both children are terminal nodes
    pt_left = 1 - p_nonterminal[moves.left] * left_growable
    pt_right = 1 - p_nonterminal[moves.right] * right_growable
    pt_children = pt_left * pt_right

    log_ratio = jnp.log(moves.partial_ratio * pt_children * p_prune)
    return replace(moves, log_trans_prior_ratio=log_ratio, partial_ratio=None)

1977 

1978 

@vmap_nodoc
def adapt_leaf_trees_to_grow_indices(
    leaf_trees: Float32[Array, 'num_trees 2**d'], moves: Moves
) -> Float32[Array, 'num_trees 2**d']:
    """
    Make the leaf values compatible with post-grow leaf indices.

    For grow moves, the value of the leaf to be grown is duplicated into the
    two slots its children would occupy if the move was accepted. Trees with
    other moves are left untouched.

    Parameters
    ----------
    leaf_trees
        The leaf values.
    moves
        The proposed moves, see `propose_moves`.

    Returns
    -------
    The modified leaf values.
    """
    parent_value = leaf_trees[moves.node]

    # an index one past the end makes the update a no-op, because jax drops
    # out-of-bounds scatter updates by default
    skip = leaf_trees.size
    left_target = jnp.where(moves.grow, moves.left, skip)
    right_target = jnp.where(moves.grow, moves.right, skip)

    with_left = leaf_trees.at[left_target].set(parent_value)
    return with_left.at[right_target].set(parent_value)

2007 

2008 

def precompute_likelihood_terms(
    sigma2: Float32[Array, ''],
    sigma_mu2: Float32[Array, ''],
    move_precs: Precs | Counts,
) -> tuple[PreLkV, PreLk]:
    """
    Prepare the terms of the acceptance likelihood ratio known in advance.

    Parameters
    ----------
    sigma2
        The error variance, or the global error variance factor if `prec_scale`
        is set.
    sigma_mu2
        The prior variance of each leaf.
    move_precs
        The likelihood precision scale in the leaves grown or pruned by the
        moves, under attributes `left`, `right`, and `total` (left + right).

    Returns
    -------
    prelkv : PreLkV
        Pre-computed terms of the likelihood ratio, one per tree.
    prelk : PreLk
        Pre-computed terms of the likelihood ratio, shared by all trees.
    """

    def inflated_variance(prec):
        return sigma2 + prec * sigma_mu2

    var_left = inflated_variance(move_precs.left)
    var_right = inflated_variance(move_precs.right)
    var_total = inflated_variance(move_precs.total)

    # half the log of the ratio of variances appearing in the likelihood ratio
    half_log_ratio = jnp.log(sigma2 * var_total / (var_left * var_right)) / 2

    prelkv = PreLkV(
        sigma2_left=var_left,
        sigma2_right=var_right,
        sigma2_total=var_total,
        sqrt_term=half_log_ratio,
    )
    prelk = PreLk(exp_factor=sigma_mu2 / (2 * sigma2))
    return prelkv, prelk

2047 

2048 

@partial(jnp.vectorize, signature='(k,k)->(k,k)')
def _chol_with_gersh(mat: Float32[Array, '... k k']) -> Float32[Array, '... k k']:
    """
    Cholesky factorization with a small diagonal jitter, supports batching.

    The jitter is proportional to the largest absolute row sum of `mat`
    (the Gershgorin bound on its eigenvalues), to its size, and to the
    machine epsilon of its dtype.
    """
    k = mat.shape[0]
    max_row_sum = jnp.abs(mat).sum(axis=1).max()
    jitter = k * max_row_sum * jnp.finfo(mat.dtype).eps
    stabilized = mat.at[jnp.diag_indices_from(mat)].add(jitter)
    return jnp.linalg.cholesky(stabilized)

2056 

2057 

def _logdet_from_chol(L):
    """
    Compute log det(A) from a Cholesky factor of A, supports batching.

    Parameters
    ----------
    L
        Triangular Cholesky factor(s) of A, with shape (..., k, k). The
        result is the same whether A = L L' or A = L' L, since both have
        determinant det(L) ** 2.

    Returns
    -------
    The log-determinant of A with shape (...), computed as twice the sum of
    the logarithms of the diagonal entries of L.
    """
    # take the diagonal over the last two axes so that leading batch
    # dimensions are preserved (`jnp.diag` only accepts 1-D or 2-D input)
    diagonal = jnp.diagonal(L, axis1=-2, axis2=-1)
    return 2.0 * jnp.sum(jnp.log(diagonal), axis=-1)

2061 

2062 

def precompute_likelihood_terms_mv(
    error_cov_inv: Float32[Array, 'k k'],
    leaf_prior_cov_inv: Float32[Array, 'k k'],
    move_precs: Counts,
) -> tuple[PreLkV, PreLk]:
    """
    Prepare the terms of the multivariate acceptance likelihood ratio.

    This implementation assumes a homoskedastic error model (i.e., the residual
    covariance is the same for all observations). Support for heteroskedasticity
    is planned for future updates.

    Parameters
    ----------
    error_cov_inv
        The inverse of the error covariance matrix.
    leaf_prior_cov_inv
        The inverse of the prior covariance matrix of each leaf.
    move_precs
        The number of datapoints in the leaves grown or pruned by the moves,
        under attributes `left`, `right`, and `total` (left + right).

    Returns
    -------
    prelkv : PreLkV
        Pre-computed terms of the likelihood ratio, one per tree.
    prelk : PreLk
        Pre-computed terms of the likelihood ratio, shared by all trees.
    """

    def posterior_prec_chol(count):
        # Cholesky factor of the leaf posterior precision given `count` points
        return _chol_with_gersh(
            error_cov_inv * count[..., None, None] + leaf_prior_cov_inv
        )

    chol_left = posterior_prec_chol(move_precs.left)
    chol_right = posterior_prec_chol(move_precs.right)
    chol_total = posterior_prec_chol(move_precs.total)
    chol_prior = _chol_with_gersh(leaf_prior_cov_inv)

    half_logdet_ratio = 0.5 * (
        _logdet_from_chol(chol_prior)
        + _logdet_from_chol(chol_total)
        - _logdet_from_chol(chol_left)
        - _logdet_from_chol(chol_right)
    )

    def quadratic_form_matrix(chol):
        # with Y = chol^-1 error_cov_inv, Y' Y is
        # error_cov_inv' (chol chol')^-1 error_cov_inv
        solved = solve_triangular(chol, error_cov_inv, lower=True)
        return solved.T @ solved

    prelkv = PreLkV(
        sigma2_left=quadratic_form_matrix(chol_left),
        sigma2_right=quadratic_form_matrix(chol_right),
        sigma2_total=quadratic_form_matrix(chol_total),
        sqrt_term=half_logdet_ratio,
    )

    return prelkv, PreLk(exp_factor=0.5)

2121 

2122 

def precompute_leaf_terms(
    key: Key[Array, ''],
    prec_trees: Float32[Array, 'num_trees 2**d'],
    sigma2: Float32[Array, ''],
    sigma_mu2: Float32[Array, ''],
    z: Float32[Array, 'num_trees 2**d'] | None = None,
) -> PreLf:
    """
    Prepare the terms needed to draw the leaves from their posterior.

    Parameters
    ----------
    key
        A jax random key.
    prec_trees
        The likelihood precision scale in each potential or actual leaf node.
    sigma2
        The error variance, or the global error variance factor if `prec_scale`
        is set.
    sigma_mu2
        The prior variance of each leaf.
    z
        Optional standard normal noise to use for sampling the centered leaves.
        This is intended for testing purposes only.

    Returns
    -------
    Pre-computed terms for leaf sampling.
    """
    # posterior variance = 1 / (likelihood precision + prior precision)
    likelihood_prec = prec_trees / sigma2
    prior_prec = lax.reciprocal(sigma_mu2)
    posterior_var = lax.reciprocal(likelihood_prec + prior_prec)

    if z is None:
        z = random.normal(key, prec_trees.shape, sigma2.dtype)

    # The posterior mean is mean_lk * prec_lk * var_post, and the per-leaf sum
    # of scaled residuals is resid_tree = mean_lk * prec_tree, so
    #   mean / resid_tree = (prec_tree / sigma2) * var_post / prec_tree
    #                     = var_post / sigma2
    resid_to_mean = posterior_var / sigma2

    # zero-mean part of the posterior draw
    noise = z * jnp.sqrt(posterior_var)

    return PreLf(mean_factor=resid_to_mean, centered_leaves=noise)

2169 

2170 

def precompute_leaf_terms_mv(
    key: Key[Array, ''],
    prec_trees: Float32[Array, 'num_trees 2**d'],
    error_cov_inv: Float32[Array, 'k k'],
    leaf_prior_cov_inv: Float32[Array, 'k k'],
    z: Float32[Array, 'num_trees 2**d k'] | None = None,
) -> PreLf:
    """
    Pre-compute terms used to sample leaves from their posterior.

    Parameters
    ----------
    key
        A jax random key.
    prec_trees
        The likelihood precision scale in each potential or actual leaf node.
    error_cov_inv
        The inverse of the error covariance matrix.
    leaf_prior_cov_inv
        The inverse of the prior covariance matrix of each leaf.
    z
        Optional standard normal noise to use for sampling the centered leaves,
        one k-vector per node. This is intended for testing purposes only.

    Returns
    -------
    Pre-computed terms for leaf sampling in multivariate case.
    """
    num_trees, num_leaves = prec_trees.shape
    k = error_cov_inv.shape[0]
    n_k = prec_trees[..., None, None]  # Shape: [num_trees, num_leaves, 1, 1]

    # Only broadcast the inverse of error covariance matrix to satisfy JAX's batching rules
    # for `lax.linalg.solve_triangular`, which does not support implicit broadcasting.
    error_cov_inv_batched = jnp.broadcast_to(
        error_cov_inv, (num_trees, num_leaves, k, k)
    )

    posterior_precision = leaf_prior_cov_inv + n_k * error_cov_inv_batched

    # posterior_precision = L L', with L lower triangular
    L_prec = _chol_with_gersh(posterior_precision)

    # mean_factor = (L L')^-1 error_cov_inv, by two triangular solves
    Y = solve_triangular(L_prec, error_cov_inv_batched, lower=True)
    mean_factor = solve_triangular(L_prec, Y, trans='T', lower=True)

    if z is None:
        z = random.normal(key, (num_trees, num_leaves, k))

    # Solving L' x = z gives Cov[x] = (L L')^-1, i.e., the posterior
    # covariance. `lower=True` is required: by default `solve_triangular`
    # reads only the upper triangle of its first argument, which for a lower
    # Cholesky factor is just the diagonal, discarding all correlations.
    centered_leaves = solve_triangular(L_prec, z, trans='T', lower=True)

    return PreLf(
        mean_factor=mean_factor,  # Shape: [num_trees, num_leaves, k, k]
        centered_leaves=centered_leaves,  # Shape: [num_trees, num_leaves, k]
    )

2224 

2225 

def accept_moves_sequential_stage(pso: ParallelStageOut) -> tuple[State, Moves]:
    """
    Accept or reject the proposed moves, processing one tree after another.

    This is the most performance-sensitive function because it contains all and
    only the parts of the algorithm that can not be parallelized across trees.

    Parameters
    ----------
    pso
        The output of `accept_moves_parallel_stage`.

    Returns
    -------
    bart : State
        A partially updated BART mcmc state.
    moves : Moves
        The accepted/rejected moves, with `acc` and `to_prune` set.
    """
    bart = pso.bart
    forest = bart.forest

    # inputs that are the same for every tree
    shared = SeqStageInAllTrees(
        X=bart.X,
        resid_batch_size=forest.resid_batch_size,
        prec_scale=bart.prec_scale,
        save_ratios=forest.log_likelihood is not None,
        prelk=pso.prelk,
    )

    # inputs with a leading tree axis, scanned over
    per_tree = SeqStageInPerTree(
        leaf_tree=forest.leaf_tree,
        prec_tree=pso.prec_trees,
        move=pso.moves,
        move_precs=pso.move_precs,
        leaf_indices=forest.leaf_indices,
        prelkv=pso.prelkv,
        prelf=pso.prelf,
    )

    def process_tree(resid, pt):
        # the residuals are the carry: each tree sees the residuals left by
        # the trees processed before it
        resid, *outputs = accept_move_and_sample_leaves(resid, shared, pt)
        return resid, tuple(outputs)

    resid, (leaf_trees, acc, to_prune, lkratio) = lax.scan(
        process_tree, bart.resid, per_tree
    )

    new_forest = replace(forest, leaf_tree=leaf_trees, log_likelihood=lkratio)
    new_bart = replace(bart, resid=resid, forest=new_forest)
    new_moves = replace(pso.moves, acc=acc, to_prune=to_prune)

    return new_bart, new_moves

2279 

2280 

class SeqStageInAllTrees(Module):
    """
    Tree-independent inputs of `accept_move_and_sample_leaves`.

    Parameters
    ----------
    X
        The predictors.
    resid_batch_size
        The batch size for computing the sum of residuals in each leaf.
    prec_scale
        The scale of the precision of the error on each datapoint. If None, it
        is assumed to be 1.
    save_ratios
        Whether to save the acceptance ratios.
    prelk
        The pre-computed terms of the likelihood ratio which are shared across
        trees.
    """

    # array inputs are regular fields, python-level settings are static
    X: UInt[Array, 'p n']
    resid_batch_size: int | None = field(static=True)
    prec_scale: Float32[Array, ' n'] | None
    save_ratios: bool = field(static=True)
    prelk: PreLk

2306 

2307 

class SeqStageInPerTree(Module):
    """
    Tree-specific inputs of `accept_move_and_sample_leaves`.

    Parameters
    ----------
    leaf_tree
        The leaf values of the tree.
    prec_tree
        The likelihood precision scale in each potential or actual leaf node.
    move
        The proposed move, see `propose_moves`.
    move_precs
        The likelihood precision scale in each node modified by the moves.
    leaf_indices
        The leaf indices for the largest version of the tree compatible with
        the move.
    prelkv
    prelf
        The pre-computed terms of the likelihood ratio and leaf sampling which
        are specific to the tree.
    """

    # the shapes below are for a single tree
    leaf_tree: Float32[Array, ' 2**d']
    prec_tree: Float32[Array, ' 2**d']
    move: Moves
    move_precs: Precs | Counts
    leaf_indices: UInt[Array, ' n']
    prelkv: PreLkV
    prelf: PreLf

2338 

2339 

def accept_move_and_sample_leaves(
    resid: Float32[Array, ' n'], at: SeqStageInAllTrees, pt: SeqStageInPerTree
) -> tuple[
    Float32[Array, ' n'],
    Float32[Array, ' 2**d'],
    Bool[Array, ''],
    Bool[Array, ''],
    Float32[Array, ''] | None,
]:
    """
    Decide whether to accept the move of one tree, then resample its leaves.

    Parameters
    ----------
    resid
        The residuals (data minus forest value).
    at
        The inputs that are the same for all trees.
    pt
        The inputs that are separate for each tree.

    Returns
    -------
    resid : Float32[Array, 'n']
        The updated residuals (data minus forest value).
    leaf_tree : Float32[Array, '2**d']
        The new leaf values of the tree.
    acc : Bool[Array, '']
        Whether the move was accepted.
    to_prune : Bool[Array, '']
        Whether, to reflect the acceptance status of the move, the state should
        be updated by pruning the leaves involved in the move.
    log_lk_ratio : Float32[Array, ''] | None
        The logarithm of the likelihood ratio for the move. `None` if not to be
        saved.
    """
    move = pt.move

    # per-leaf sums of the (precision-scaled) residuals, using the leaf
    # assignment of the tree as proposed by the grow move
    scaled_resid = resid if at.prec_scale is None else resid * at.prec_scale
    resid_tree = sum_resid(
        scaled_resid, pt.leaf_indices, pt.leaf_tree.size, at.resid_batch_size
    )

    # remove the current tree from the forest prediction, i.e., add its leaf
    # values back into the summed residuals
    resid_tree += pt.prec_tree * pt.leaf_tree

    # the parent node collects the residuals of its two children
    resid_left = resid_tree[move.left]
    resid_right = resid_tree[move.right]
    resid_total = resid_left + resid_right
    assert move.node.dtype == jnp.int32
    resid_tree = resid_tree.at[move.node].set(resid_total)

    # log MH ratio for a grow move; a prune move uses the opposite sign
    log_lk_ratio = compute_likelihood_ratio(
        resid_total, resid_left, resid_right, pt.prelkv, at.prelk
    )
    grow_log_ratio = move.log_trans_prior_ratio + log_lk_ratio
    log_ratio = jnp.where(move.grow, grow_log_ratio, -grow_log_ratio)

    # accept if the move is allowed and the log-uniform draw is low enough
    acc = move.allowed & (move.logu <= log_ratio)

    # sample all leaves: posterior mean plus pre-drawn centered noise
    sampled = resid_tree * pt.prelf.mean_factor + pt.prelf.centered_leaves

    # When the final tree is the shallower one (rejected grow or accepted
    # prune), copy the parent's value into its children so that the leaf
    # indices, which refer to the deeper tree, still read the right value.
    # An index one past the end turns the update into a no-op.
    to_prune = acc ^ move.grow
    parent_value = sampled[move.node]
    skip = sampled.size
    left_target = jnp.where(to_prune, move.left, skip)
    right_target = jnp.where(to_prune, move.right, skip)
    leaf_tree = (
        sampled.at[left_target].set(parent_value).at[right_target].set(parent_value)
    )

    # swap the old tree for the new one in the forest prediction
    resid += (pt.leaf_tree - leaf_tree)[pt.leaf_indices]

    return resid, leaf_tree, acc, to_prune, (log_lk_ratio if at.save_ratios else None)

2424 

2425 

def sum_resid(
    scaled_resid: Float32[Array, ' n'],
    leaf_indices: UInt[Array, ' n'],
    tree_size: int,
    batch_size: int | None,
) -> Float32[Array, ' {tree_size}']:
    """
    Add up the residuals of the datapoints in each leaf of one tree.

    Parameters
    ----------
    scaled_resid
        The residuals (data minus forest value) multiplied by the error
        precision scale.
    leaf_indices
        The leaf indices of the tree (in which leaf each data point falls into).
    tree_size
        The size of the tree array (2 ** d).
    batch_size
        The data batch size for the aggregation. Batching increases numerical
        accuracy and parallelism. If `None`, a plain scatter-add is used.

    Returns
    -------
    The sum of the residuals at data points in each leaf.
    """
    if batch_size is None:
        return _aggregate_scatter(scaled_resid, leaf_indices, tree_size, jnp.float32)
    return _aggregate_batched_onetree(
        scaled_resid, leaf_indices, tree_size, jnp.float32, batch_size=batch_size
    )

2457 

2458 

def _aggregate_batched_onetree(
    values: Shaped[Array, '*'],
    indices: Integer[Array, '*'],
    size: int,
    dtype: jnp.dtype,
    batch_size: int,
) -> Float32[Array, ' {size}']:
    """
    Sum `values` per index for a single tree, accumulating in batches.

    Datapoints are assigned round-robin to ``ceil(n / batch_size)`` partial
    accumulators, which are summed together at the end.
    """
    (n,) = indices.shape
    full_batches, leftover = divmod(n, batch_size)
    nbatches = full_batches + bool(leftover)

    batch_of_point = jnp.arange(n) % nbatches
    partial_sums = jnp.zeros((size, nbatches), dtype)
    partial_sums = partial_sums.at[indices, batch_of_point].add(values)
    return partial_sums.sum(axis=1)

2475 

2476 

def compute_likelihood_ratio(
    total_resid: Float32[Array, ''],
    left_resid: Float32[Array, ''],
    right_resid: Float32[Array, ''],
    prelkv: PreLkV,
    prelk: PreLk,
) -> Float32[Array, '']:
    """
    Evaluate the log likelihood ratio of a grow move.

    Parameters
    ----------
    total_resid
    left_resid
    right_resid
        The sum of the residuals (scaled by error precision scale) of the
        datapoints falling in the nodes involved in the moves.
    prelkv
    prelk
        The pre-computed terms of the likelihood ratio, see
        `precompute_likelihood_terms`.

    Returns
    -------
    The log-likelihood ratio log P(data | new tree) - log P(data | old tree).
    """
    # squared residual sums, each normalized by the variance term of its node
    left_term = left_resid * left_resid / prelkv.sigma2_left
    right_term = right_resid * right_resid / prelkv.sigma2_right
    total_term = total_resid * total_resid / prelkv.sigma2_total

    # the split children enter with a plus sign, the unsplit parent with a minus
    exp_term = prelk.exp_factor * (left_term + right_term - total_term)
    return prelkv.sqrt_term + exp_term

2509 

2510 

def compute_likelihood_ratio_mv(
    total_resid: Float32[Array, ' k'],
    left_resid: Float32[Array, ' k'],
    right_resid: Float32[Array, ' k'],
    prelkv: PreLkV,
    prelk: PreLk,  # noqa: ARG001
) -> Float32[Array, '']:
    """
    Evaluate the log likelihood ratio of a grow move in the multivariate case.

    Parameters
    ----------
    total_resid
    left_resid
    right_resid
        The sum of the residuals (scaled by error precision scale) of the
        datapoints falling in the nodes involved in the moves.
    prelkv
    prelk
        The pre-computed terms of the likelihood ratio, see
        `precompute_likelihood_terms_mv`. `prelk` is accepted for signature
        compatibility but not used.

    Returns
    -------
    The log-likelihood ratio log P(data | new tree) - log P(data | old tree).
    """
    # quadratic forms r' M r, one per node involved in the move
    qf_left = left_resid @ prelkv.sigma2_left @ left_resid
    qf_right = right_resid @ prelkv.sigma2_right @ right_resid
    qf_total = total_resid @ prelkv.sigma2_total @ total_resid

    # the split children enter with a plus sign, the unsplit parent with a minus
    return prelkv.sqrt_term + 0.5 * (qf_left + qf_right - qf_total)

2546 

2547 

def accept_moves_final_stage(bart: State, moves: Moves) -> State:
    """
    Finalize the mcmc state once the moves have been accepted or rejected.

    This function is separate from `accept_moves_sequential_stage` to signal it
    can work in parallel across trees.

    Parameters
    ----------
    bart
        A partially updated BART mcmc state.
    moves
        The proposed moves (see `propose_moves`) as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The fully updated BART mcmc state.
    """
    forest = bart.forest

    # number of accepted moves of each kind across the trees
    accepted_grows = jnp.sum(moves.acc & moves.grow)
    accepted_prunes = jnp.sum(moves.acc & ~moves.grow)

    # bring the tree structure in line with the outcome of the moves
    leaf_indices = apply_moves_to_leaf_indices(forest.leaf_indices, moves)
    split_tree = apply_moves_to_split_trees(forest.split_tree, moves)

    new_forest = replace(
        forest,
        grow_acc_count=accepted_grows,
        prune_acc_count=accepted_prunes,
        leaf_indices=leaf_indices,
        split_tree=split_tree,
    )
    return replace(bart, forest=new_forest)

2577 

2578 

@vmap_nodoc
def apply_moves_to_leaf_indices(
    leaf_indices: UInt[Array, 'num_trees n'], moves: Moves
) -> UInt[Array, 'num_trees n']:
    """
    Make the leaf indices reflect the outcome of the moves.

    Parameters
    ----------
    leaf_indices
        The index of the leaf each datapoint falls into, if the grow move was
        accepted.
    moves
        The proposed moves (see `propose_moves`), as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The updated leaf indices.
    """
    # all bits set but the lowest: ...1111111110
    clear_lowest_bit = ~jnp.array(1, leaf_indices.dtype)

    # clearing the lowest bit maps both `moves.left` and `moves.left + 1`
    # to `moves.left` (presumably the right child is left + 1)
    in_moved_children = (leaf_indices & clear_lowest_bit) == moves.left

    # if the children are to be pruned, their datapoints go back to the parent
    parent = moves.node.astype(leaf_indices.dtype)
    return jnp.where(in_moved_children & moves.to_prune, parent, leaf_indices)

2604 

2605 

@vmap_nodoc
def apply_moves_to_split_trees(
    split_tree: UInt[Array, 'num_trees 2**(d-1)'], moves: Moves
) -> UInt[Array, 'num_trees 2**(d-1)']:
    """
    Make the split trees reflect the outcome of the moves.

    Parameters
    ----------
    split_tree
        The cutpoints of the decision nodes in the initial trees.
    moves
        The proposed moves (see `propose_moves`), as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The updated split trees.
    """
    assert moves.to_prune is not None

    # an index one past the end turns the corresponding update into a no-op
    skip = split_tree.size
    grown_node = jnp.where(moves.grow, moves.node, skip)
    pruned_node = jnp.where(moves.to_prune, moves.node, skip)

    # first write the proposed cutpoint of grow moves, then erase the cutpoint
    # wherever the node has to end up as a leaf; the order matters because a
    # rejected grow move is both written and then erased
    with_grow = split_tree.at[grown_node].set(moves.grow_split.astype(split_tree.dtype))
    return with_grow.at[pruned_node].set(0)

2632 

2633 

def step_sigma(key: Key[Array, ''], bart: State) -> State:
    """
    Draw a new value of the error variance (factor).

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state.

    Returns
    -------
    The new BART mcmc state, with an updated `sigma2`.
    """
    resid = bart.resid
    weighted_resid = resid if bart.prec_scale is None else resid * bart.prec_scale

    # inverse-gamma parameters: prior values plus the data contribution
    alpha = bart.sigma2_alpha + resid.size / 2
    beta = bart.sigma2_beta + (resid @ weighted_resid) / 2

    # beta / Gamma(alpha) is an inverse-gamma draw
    # random.gamma seems to be slow at compiling, maybe cdf inversion would
    # be better, but it's not implemented in jax
    gamma_draw = random.gamma(key, alpha)
    return replace(bart, sigma2=beta / gamma_draw)

2662 

2663 

@jax.jit
def _sample_wishart_bartlett(
    key: Key[Array, ''], df: Integer[Array, ''], scale_inv: Float32[Array, 'k k']
) -> Float32[Array, 'k k']:
    """
    Draw a precision matrix W ~ Wishart(df, scale_inv^-1) by Bartlett decomposition.

    Parameters
    ----------
    key
        A JAX random key
    df
        Degrees of freedom
    scale_inv
        Scale matrix of the corresponding Inverse Wishart distribution, i.e.,
        the inverse of the scale matrix of the Wishart distribution

    Returns
    -------
    A sample from Wishart(df, scale_inv^-1)
    """
    # `keys.pop()` is used below to obtain a separate key for each draw
    keys = split(key)

    k = scale_inv.shape[0]
    eps = jnp.finfo(scale_inv.dtype).eps

    # jitter the diagonal before factorizing, scaled by the Gershgorin
    # estimate of the largest eigenvalue
    max_row_sum = jnp.max(jnp.sum(jnp.abs(scale_inv), axis=1))
    jitter = k * max_row_sum * eps + eps
    chol_scale_inv = jnp.linalg.cholesky(
        scale_inv.at[jnp.diag_indices(k)].add(jitter)
    )

    # Bartlett factor A: lower triangular, with A_ii ~ sqrt(chi^2(df - i)) on
    # the diagonal, using chi^2(nu) = Gamma(nu / 2, scale=2), and standard
    # normals below the diagonal
    dofs = df - jnp.arange(k)
    chi2_draws = random.gamma(keys.pop(), dofs / 2.0) * 2.0
    normal_draws = random.normal(keys.pop(), (k, k))
    A = jnp.tril(normal_draws, -1) + jnp.diag(jnp.sqrt(chi2_draws))

    # with scale_inv = L L', T = L'^-1 A gives T T' = L'^-1 (A A') L^-1
    T = solve_triangular(chol_scale_inv, A, lower=True, trans='T')

    return T @ T.T

2707 

2708 

def step_z(key: Key[Array, ''], bart: State) -> State:
    """
    Re-sample the latent variable `z` used for binary regression.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART MCMC state. Its `y` must have boolean dtype.

    Returns
    -------
    The BART MCMC state with `z` and `resid` replaced.
    """
    # z minus the residuals is the sum of trees plus the offset
    mean = bart.z - bart.resid
    assert bart.y.dtype == bool
    new_resid = truncated_normal_onesided(key, (), ~bart.y, -mean)
    return replace(bart, z=mean + new_resid, resid=new_resid)

2729 

2730 

def step_s(key: Key[Array, ''], bart: State) -> State:
    """
    Re-sample `log_s` from its Dirichlet full conditional.

    With prior s ~ Dirichlet(theta/p, ..., theta/p), the posterior is
    s ~ Dirichlet(theta/p + varcount, ..., theta/p + varcount), where
    varcount counts how many times each variable is used in the current
    forest.

    Parameters
    ----------
    key
        Random key for sampling.
    bart
        The current BART state.

    Returns
    -------
    Updated BART state with re-sampled `log_s`.

    Notes
    -----
    This full conditional is approximated, because it does not take into account
    that there are forbidden decision rules.
    """
    forest = bart.forest
    assert forest.theta is not None

    # count how many times each predictor appears in the forest
    num_vars = forest.max_split.size
    usage = grove.var_histogram(num_vars, forest.var_tree, forest.split_tree)

    # Dirichlet posterior concentration; the draw is kept as log-gamma
    # variates and is not normalized here
    concentration = forest.theta / num_vars + usage
    new_log_s = random.loggamma(key, concentration)

    return replace(bart, forest=replace(forest, log_s=new_log_s))

2768 

2769 

def step_theta(key: Key[Array, ''], bart: State, *, num_grid: int = 1000) -> State:
    """
    Re-sample `theta` on a grid.

    The prior is theta / (theta + rho) ~ Beta(a, b).

    Parameters
    ----------
    key
        Random key for sampling.
    bart
        The current BART state.
    num_grid
        The number of points in the evenly-spaced grid used to sample
        theta / (theta + rho).

    Returns
    -------
    Updated BART state with re-sampled `theta`.
    """
    forest = bart.forest
    assert forest.log_s is not None
    assert forest.rho is not None
    assert forest.a is not None
    assert forest.b is not None

    # split (0, 1) into num_grid equal bins and take their midpoints
    half_bin = 1 / (2 * num_grid)
    lamda_grid = jnp.linspace(half_bin, 1 - half_bin, num_grid)

    # log_s is stored unnormalized, make exp(log_s) sum to 1
    normalized_log_s = forest.log_s - logsumexp(forest.log_s)

    # draw a grid point with probability proportional to exp(logp), and
    # pick the corresponding theta
    logp, theta_grid = _log_p_lamda(
        lamda_grid, normalized_log_s, forest.rho, forest.a, forest.b
    )
    index = random.categorical(key, logp)

    return replace(bart, forest=replace(forest, theta=theta_grid[index]))

2810 

2811 

2812def _log_p_lamda( 1ab

2813 lamda: Float32[Array, ' num_grid'], 

2814 log_s: Float32[Array, ' p'], 

2815 rho: Float32[Array, ''], 

2816 a: Float32[Array, ''], 

2817 b: Float32[Array, ''], 

2818) -> tuple[Float32[Array, ' num_grid'], Float32[Array, ' num_grid']]: 

2819 # in the following I use lamda[::-1] == 1 - lamda 

2820 theta = rho * lamda / lamda[::-1] 1ab

2821 p = log_s.size 1ab

2822 return ( 1ab

2823 (a - 1) * jnp.log1p(-lamda[::-1]) # log(lambda) 

2824 + (b - 1) * jnp.log1p(-lamda) # log(1 - lambda) 

2825 + gammaln(theta) 

2826 - p * gammaln(theta / p) 

2827 + theta / p * jnp.sum(log_s) 

2828 ), theta 

2829 

2830 

def step_sparse(key: Key[Array, ''], bart: State) -> State:
    """
    Update the parameters that control sparsity.

    `step_s` is always invoked; `step_theta` follows only if the parameters
    of the theta prior are defined, which is checked through
    `bart.forest.rho`.

    Parameters
    ----------
    key
        Random key for sampling.
    bart
        The current BART state.

    Returns
    -------
    Updated BART state with re-sampled `log_s` and `theta`.
    """
    keys = split(key)
    bart = step_s(keys.pop(), bart)

    # without a prior on theta, theta stays fixed
    if bart.forest.rho is None:
        return bart

    return step_theta(keys.pop(), bart)