Coverage for src/bartz/mcmcstep.py: 95%

551 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-07 22:47 +0000

1# bartz/src/bartz/mcmcstep.py 

2# 

3# Copyright (c) 2024-2025, Giacomo Petrillo 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25""" 

26Functions that implement the BART posterior MCMC initialization and update step. 

27 

28Functions that do MCMC steps operate by taking as input a bart state, and 

29outputting a new state. The inputs are not modified. 

30 

31The entry points are: 

32 

33 - `State`: The dataclass that represents a BART MCMC state. 

34 - `init`: Creates an initial `State` from data and configurations. 

35 - `step`: Performs one full MCMC step on a `State`, returning a new `State`. 

36 - `step_sparse`: Performs the MCMC update for variable selection, which is skipped in `step`. 

37""" 

38 

39import math 1ab

40from dataclasses import replace 1ab

41from functools import cache, partial 1ab

42from typing import Any, Literal 1ab

43 

44import jax 1ab

45from equinox import Module, field, tree_at 1ab

46from jax import lax, random 1ab

47from jax import numpy as jnp 1ab

48from jax.scipy.special import gammaln, logsumexp 1ab

49from jaxtyping import Array, Bool, Float32, Int32, Integer, Key, Shaped, UInt 1ab

50 

51from bartz import grove 1ab

52from bartz.jaxext import ( 1ab

53 minimal_unsigned_dtype, 

54 split, 

55 truncated_normal_onesided, 

56 vmap_nodoc, 

57) 

58 

59 

class Forest(Module):
    """
    Represents the MCMC state of a sum of trees.

    All per-tree arrays have the tree index as leading axis. The two batch
    sizes are declared with ``field(static=True)``, i.e., they are static
    (non-array) fields of the module.

    Parameters
    ----------
    leaf_tree
        The leaf values.
    var_tree
        The decision axes.
    split_tree
        The decision boundaries.
    affluence_tree
        Marks leaves that can be grown.
    max_split
        The maximum split index for each predictor.
    blocked_vars
        Indices of variables that are not used. This shall include at least
        the `i` such that ``max_split[i] == 0``, otherwise behavior is
        undefined.
    p_nonterminal
        The prior probability of each node being nonterminal, conditional on
        its ancestors. Includes the nodes at maximum depth which should be set
        to 0.
    p_propose_grow
        The unnormalized probability of picking a leaf for a grow proposal.
    leaf_indices
        The index of the leaf each datapoint falls into, for each tree.
    min_points_per_decision_node
        The minimum number of data points in a decision node.
    min_points_per_leaf
        The minimum number of data points in a leaf node.
    resid_batch_size
    count_batch_size
        The data batch sizes for computing the sufficient statistics. If `None`,
        they are computed with no batching.
    log_trans_prior
        The log transition and prior Metropolis-Hastings ratio for the
        proposed move on each tree.
    log_likelihood
        The log likelihood ratio.
    grow_prop_count
    prune_prop_count
        The number of grow/prune proposals made during one full MCMC cycle.
    grow_acc_count
    prune_acc_count
        The number of grow/prune moves accepted during one full MCMC cycle.
    sigma_mu2
        The prior variance of a leaf, conditional on the tree structure.
    log_s
        The logarithm of the prior probability for choosing a variable to split
        along in a decision rule, conditional on the ancestors. Not normalized.
        If `None`, use a uniform distribution.
    theta
        The concentration parameter for the Dirichlet prior on the variable
        distribution `s`. Required only to update `s`.
    a
    b
    rho
        Parameters of the prior on `theta`. Required only to sample `theta`.
        See `step_theta`.
    """

    # tree structure, one row per tree
    leaf_tree: Float32[Array, 'num_trees 2**d']
    var_tree: UInt[Array, 'num_trees 2**(d-1)']
    split_tree: UInt[Array, 'num_trees 2**(d-1)']
    affluence_tree: Bool[Array, 'num_trees 2**(d-1)']
    # predictor-related settings
    max_split: UInt[Array, ' p']
    blocked_vars: UInt[Array, ' k'] | None
    # per-node probabilities, shared by all trees
    p_nonterminal: Float32[Array, ' 2**d']
    p_propose_grow: Float32[Array, ' 2**(d-1)']
    # assignment of datapoints to leaves and related constraints
    leaf_indices: UInt[Array, 'num_trees n']
    min_points_per_decision_node: Int32[Array, ''] | None
    min_points_per_leaf: Int32[Array, ''] | None
    # static configuration (not arrays)
    resid_batch_size: int | None = field(static=True)
    count_batch_size: int | None = field(static=True)
    # optional diagnostics, `None` when not saved
    log_trans_prior: Float32[Array, ' num_trees'] | None
    log_likelihood: Float32[Array, ' num_trees'] | None
    # proposal/acceptance counters
    grow_prop_count: Int32[Array, '']
    prune_prop_count: Int32[Array, '']
    grow_acc_count: Int32[Array, '']
    prune_acc_count: Int32[Array, '']
    # leaf prior and variable selection parameters
    sigma_mu2: Float32[Array, '']
    log_s: Float32[Array, ' p'] | None
    theta: Float32[Array, ''] | None
    a: Float32[Array, ''] | None
    b: Float32[Array, ''] | None
    rho: Float32[Array, ''] | None

148 

149 

class State(Module):
    """
    Represents the MCMC state of BART.

    The fields that only make sense for continuous regression (`sigma2`,
    `prec_scale`, `sigma2_alpha`, `sigma2_beta`) are `None` in binary
    regression, while `z` is `None` in continuous regression.

    Parameters
    ----------
    X
        The predictors.
    y
        The response. If the data type is `bool`, the model is binary regression.
    resid
        The residuals (`y` or `z` minus sum of trees).
    z
        The latent variable for binary regression. `None` in continuous
        regression.
    offset
        Constant shift added to the sum of trees.
    sigma2
        The error variance. `None` in binary regression.
    prec_scale
        The scale on the error precision, i.e., ``1 / error_scale ** 2``.
        `None` in binary regression.
    sigma2_alpha
    sigma2_beta
        The shape and scale parameters of the inverse gamma prior on the noise
        variance. `None` in binary regression.
    forest
        The sum of trees model.
    """

    # data
    X: UInt[Array, 'p n']
    y: Float32[Array, ' n'] | Bool[Array, ' n']
    # latent variable, offset and residuals
    z: None | Float32[Array, ' n']
    offset: Float32[Array, '']
    resid: Float32[Array, ' n']
    # error model
    sigma2: Float32[Array, ''] | None
    prec_scale: Float32[Array, ' n'] | None
    sigma2_alpha: Float32[Array, ''] | None
    sigma2_beta: Float32[Array, ''] | None
    # sum of trees
    forest: Forest

190 

191 

def init(
    *,
    X: UInt[Any, 'p n'],
    y: Float32[Any, ' n'] | Bool[Any, ' n'],
    offset: float | Float32[Any, ''] = 0.0,
    max_split: UInt[Any, ' p'],
    num_trees: int,
    p_nonterminal: Float32[Any, ' d-1'],
    sigma_mu2: float | Float32[Any, ''],
    sigma2_alpha: float | Float32[Any, ''] | None = None,
    sigma2_beta: float | Float32[Any, ''] | None = None,
    error_scale: Float32[Any, ' n'] | None = None,
    min_points_per_decision_node: int | Integer[Any, ''] | None = None,
    resid_batch_size: int | None | Literal['auto'] = 'auto',
    count_batch_size: int | None | Literal['auto'] = 'auto',
    save_ratios: bool = False,
    filter_splitless_vars: bool = True,
    min_points_per_leaf: int | Integer[Any, ''] | None = None,
    log_s: Float32[Any, ' p'] | None = None,
    theta: float | Float32[Any, ''] | None = None,
    a: float | Float32[Any, ''] | None = None,
    b: float | Float32[Any, ''] | None = None,
    rho: float | Float32[Any, ''] | None = None,
) -> State:
    """
    Make a BART posterior sampling MCMC initial state.

    Parameters
    ----------
    X
        The predictors. Note this is transposed compared to the usual
        convention.
    y
        The response. If the data type is `bool`, the regression model is binary
        regression with probit.
    offset
        Constant shift added to the sum of trees. 0 if not specified.
    max_split
        The maximum split index for each variable. All split ranges start at 1.
    num_trees
        The number of trees in the forest.
    p_nonterminal
        The probability of a nonterminal node at each depth. The maximum depth
        of trees is fixed by the length of this array.
    sigma_mu2
        The prior variance of a leaf, conditional on the tree structure. The
        prior variance of the sum of trees is ``num_trees * sigma_mu2``. The
        prior mean of leaves is always zero.
    sigma2_alpha
    sigma2_beta
        The shape and scale parameters of the inverse gamma prior on the error
        variance. Leave unspecified for binary regression.
    error_scale
        Each error is scaled by the corresponding factor in `error_scale`, so
        the error variance for ``y[i]`` is ``sigma2 * error_scale[i] ** 2``.
        Not supported for binary regression. If not specified, defaults to 1 for
        all points, but potentially skipping calculations.
    min_points_per_decision_node
        The minimum number of data points in a decision node. 0 if not
        specified.
    resid_batch_size
    count_batch_size
        The batch sizes, along datapoints, for summing the residuals and
        counting the number of datapoints in each leaf. `None` for no batching.
        If 'auto', pick a value based on the device of `y`, or the default
        device.
    save_ratios
        Whether to save the Metropolis-Hastings ratios.
    filter_splitless_vars
        Whether to check `max_split` for variables without available cutpoints.
        If any are found, they are put into a list of variables to exclude from
        the MCMC. If `False`, no check is performed, but the results may be
        wrong if any variable is blocked. The function is jax-traceable only
        if this is set to `False`.
    min_points_per_leaf
        The minimum number of datapoints in a leaf node. 0 if not specified.
        Unlike `min_points_per_decision_node`, this constraint is not taken into
        account in the Metropolis-Hastings ratio because it would be expensive
        to compute. Grow moves that would violate this constraint are vetoed.
        This parameter is independent of `min_points_per_decision_node` and
        there is no check that they are coherent. It makes sense to set
        ``min_points_per_decision_node >= 2 * min_points_per_leaf``.
    log_s
        The logarithm of the prior probability for choosing a variable to split
        along in a decision rule, conditional on the ancestors. Not normalized.
        If not specified, use a uniform distribution. If not specified and
        `theta` or `rho`, `a`, `b` are, it's initialized automatically.
    theta
        The concentration parameter for the Dirichlet prior on `s`. Required
        only to update `log_s`. If not specified, and `rho`, `a`, `b` are
        specified, it's initialized automatically.
    a
    b
    rho
        Parameters of the prior on `theta`. Required only to sample `theta`.

    Returns
    -------
    An initialized BART MCMC state.

    Raises
    ------
    ValueError
        If `y` is boolean and arguments unused in binary regression are set,
        or if only some of `rho`, `a`, `b` are specified.

    Notes
    -----
    In decision nodes, the values in ``X[i, :]`` are compared to a cutpoint out
    of the range ``[1, 2, ..., max_split[i]]``. A point belongs to the left
    child iff ``X[i, j] < cutpoint``. Thus it makes sense for ``X[i, :]`` to be
    integers in the range ``[0, 1, ..., max_split[i]]``.
    """
    # append a zero probability for the deepest level, which can not be
    # nonterminal; the padded length fixes the maximum depth
    p_nonterminal = jnp.asarray(p_nonterminal)
    p_nonterminal = jnp.pad(p_nonterminal, (0, 1))
    max_depth = p_nonterminal.size

    @partial(jax.vmap, in_axes=None, out_axes=0, axis_size=num_trees)
    def make_forest(max_depth, dtype):
        return grove.make_tree(max_depth, dtype)

    # convert array-likes up front, such that attributes like `.shape` and
    # `.dtype` are available below whatever the input container
    X = jnp.asarray(X)
    y = jnp.asarray(y)
    offset = jnp.asarray(offset)

    resid_batch_size, count_batch_size = _choose_suffstat_batch_size(
        resid_batch_size, count_batch_size, y, 2**max_depth * num_trees
    )

    is_binary = y.dtype == bool
    if is_binary:
        # use identity checks: comparing arrays to `None` with `==` may be
        # elementwise and thus not usable as a boolean
        continuous_only = (error_scale, sigma2_alpha, sigma2_beta)
        if any(arg is not None for arg in continuous_only):
            msg = (
                'error_scale, sigma2_alpha, and sigma2_beta must be set '
                'to `None` for binary regression.'
            )
            raise ValueError(msg)
        sigma2 = None
    else:
        sigma2_alpha = jnp.asarray(sigma2_alpha)
        sigma2_beta = jnp.asarray(sigma2_beta)
        sigma2 = sigma2_beta / sigma2_alpha

    max_split = jnp.asarray(max_split)

    if filter_splitless_vars:
        # `jnp.nonzero` has a data-dependent output shape, which is why this
        # branch is not traceable
        (blocked_vars,) = jnp.nonzero(max_split == 0)
        blocked_vars = blocked_vars.astype(minimal_unsigned_dtype(max_split.size))
        # see `fully_used_variables` for the type cast
    else:
        blocked_vars = None

    # check and initialize sparsity parameters
    if not _all_none_or_not_none(rho, a, b):
        msg = 'rho, a, b are not either all `None` or all set'
        raise ValueError(msg)
    if theta is None and rho is not None:
        theta = rho
    if log_s is None and theta is not None:
        log_s = jnp.zeros(max_split.size)

    if error_scale is None:
        prec_scale = None
    else:
        prec_scale = lax.reciprocal(jnp.square(jnp.asarray(error_scale)))

    return State(
        X=X,
        y=y,
        z=jnp.full(y.shape, offset) if is_binary else None,
        offset=offset,
        resid=jnp.zeros(y.shape) if is_binary else y - offset,
        sigma2=sigma2,
        prec_scale=prec_scale,
        sigma2_alpha=sigma2_alpha,
        sigma2_beta=sigma2_beta,
        forest=Forest(
            leaf_tree=make_forest(max_depth, jnp.float32),
            var_tree=make_forest(max_depth - 1, minimal_unsigned_dtype(X.shape[0] - 1)),
            split_tree=make_forest(max_depth - 1, max_split.dtype),
            # node 1 is where all the points start (see `leaf_indices`
            # below); it is growable if it holds enough points
            affluence_tree=(
                make_forest(max_depth - 1, bool)
                .at[:, 1]
                .set(
                    True
                    if min_points_per_decision_node is None
                    else y.size >= min_points_per_decision_node
                )
            ),
            blocked_vars=blocked_vars,
            max_split=max_split,
            grow_prop_count=jnp.zeros((), int),
            grow_acc_count=jnp.zeros((), int),
            prune_prop_count=jnp.zeros((), int),
            prune_acc_count=jnp.zeros((), int),
            p_nonterminal=p_nonterminal[grove.tree_depths(2**max_depth)],
            p_propose_grow=p_nonterminal[grove.tree_depths(2 ** (max_depth - 1))],
            leaf_indices=jnp.ones(
                (num_trees, y.size), minimal_unsigned_dtype(2**max_depth - 1)
            ),
            min_points_per_decision_node=_asarray_or_none(min_points_per_decision_node),
            min_points_per_leaf=_asarray_or_none(min_points_per_leaf),
            resid_batch_size=resid_batch_size,
            count_batch_size=count_batch_size,
            log_trans_prior=jnp.zeros(num_trees) if save_ratios else None,
            log_likelihood=jnp.zeros(num_trees) if save_ratios else None,
            sigma_mu2=jnp.asarray(sigma_mu2),
            log_s=_asarray_or_none(log_s),
            theta=_asarray_or_none(theta),
            rho=_asarray_or_none(rho),
            a=_asarray_or_none(a),
            b=_asarray_or_none(b),
        ),
    )

400 

401 

402def _all_none_or_not_none(*args): 1ab

403 is_none = [x is None for x in args] 1ab

404 return all(is_none) or not any(is_none) 1ab

405 

406 

407def _asarray_or_none(x): 1ab

408 if x is None: 1ab

409 return None 1ab

410 return jnp.asarray(x) 1ab

411 

412 

413def _choose_suffstat_batch_size( 1ab

414 resid_batch_size, count_batch_size, y, forest_size 

415) -> tuple[int | None, ...]: 

416 @cache 1ab

417 def get_platform(): 1ab

418 try: 1ab

419 device = y.devices().pop() 1ab

420 except jax.errors.ConcretizationTypeError: 1ab

421 device = jax.devices()[0] 1ab

422 platform = device.platform 1ab

423 if platform not in ('cpu', 'gpu'): 423 ↛ 424line 423 didn't jump to line 424 because the condition on line 423 was never true1ab

424 msg = f'Unknown platform: {platform}' 

425 raise KeyError(msg) 

426 return platform 1ab

427 

428 if resid_batch_size == 'auto': 1ab

429 platform = get_platform() 1ab

430 n = max(1, y.size) 1ab

431 if platform == 'cpu': 431 ↛ 433line 431 didn't jump to line 433 because the condition on line 431 was always true1ab

432 resid_batch_size = 2 ** round(math.log2(n / 6)) # n/6 1ab

433 elif platform == 'gpu': 

434 resid_batch_size = 2 ** round((1 + math.log2(n)) / 3) # n^1/3 

435 resid_batch_size = max(1, resid_batch_size) 1ab

436 

437 if count_batch_size == 'auto': 1ab

438 platform = get_platform() 1ab

439 if platform == 'cpu': 439 ↛ 441line 439 didn't jump to line 441 because the condition on line 439 was always true1ab

440 count_batch_size = None 1ab

441 elif platform == 'gpu': 

442 n = max(1, y.size) 

443 count_batch_size = 2 ** round(math.log2(n) / 2 - 2) # n^1/2 

444 # /4 is good on V100, /2 on L4/T4, still haven't tried A100 

445 max_memory = 2**29 

446 itemsize = 4 

447 min_batch_size = math.ceil(forest_size * itemsize * n / max_memory) 

448 count_batch_size = max(count_batch_size, min_batch_size) 

449 count_batch_size = max(1, count_batch_size) 

450 

451 return resid_batch_size, count_batch_size 1ab

452 

453 

@jax.jit
def step(key: Key[Array, ''], bart: State) -> State:
    """
    Do one MCMC step.

    The forest is updated first, then the latent variable `z` (binary
    regression) or the error variance (continuous regression).

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state, as created by `init`.

    Returns
    -------
    The new BART mcmc state.
    """
    keys = split(key)
    trees_key = keys.pop()
    second_key = keys.pop()

    is_binary = bart.y.dtype == bool
    if not is_binary:
        # continuous regression
        after_trees = step_trees(trees_key, bart)
        return step_sigma(second_key, after_trees)

    # binary regression: `sigma2` is set to 1 only for the duration of the
    # forest update, then reset to `None`
    unit_variance = replace(bart, sigma2=jnp.float32(1))
    after_trees = step_trees(trees_key, unit_variance)
    restored = replace(after_trees, sigma2=None)
    return step_z(second_key, restored)

481 

482 

def step_trees(key: Key[Array, ''], bart: State) -> State:
    """
    Forest sampling step of BART MCMC.

    A move is proposed for each tree with `propose_moves`, then the moves are
    accepted or rejected and the leaves sampled with
    `accept_moves_and_sample_leaves`.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state, as created by `init`.

    Returns
    -------
    The new BART mcmc state.

    Notes
    -----
    This function zeroes the proposal counters.
    """
    keys = split(key)
    propose_key = keys.pop()
    accept_key = keys.pop()
    proposed = propose_moves(propose_key, bart.forest)
    return accept_moves_and_sample_leaves(accept_key, bart, proposed)

505 

506 

class Moves(Module):
    """
    Moves proposed to modify each tree.

    The fields that are `None` out of `propose_moves` (`log_trans_prior_ratio`,
    `acc`, `to_prune`) are filled in by later stages of the MCMC step.

    Parameters
    ----------
    allowed
        Whether there is a possible move. If `False`, the other values may not
        make sense. The only case in which a move is marked as allowed but is
        then vetoed is if it does not satisfy `min_points_per_leaf`, which for
        efficiency is implemented post-hoc without changing the rest of the
        MCMC logic.
    grow
        Whether the move is a grow move or a prune move.
    num_growable
        The number of growable leaves in the original tree.
    node
        The index of the leaf to grow or node to prune.
    left
    right
        The indices of the children of 'node'.
    partial_ratio
        A factor of the Metropolis-Hastings ratio of the move. It lacks the
        likelihood ratio, the probability of proposing the prune move, and the
        probability that the children of the modified node are terminal. If the
        move is PRUNE, the ratio is inverted. `None` once
        `log_trans_prior_ratio` has been computed.
    log_trans_prior_ratio
        The logarithm of the product of the transition and prior terms of the
        Metropolis-Hastings ratio for the acceptance of the proposed move.
        `None` if not yet computed. If PRUNE, the log-ratio is negated.
    grow_var
        The decision axes of the new rules.
    grow_split
        The decision boundaries of the new rules.
    var_tree
        The updated decision axes of the trees, valid whatever move.
    affluence_tree
        A partially updated `affluence_tree`, marking non-leaf nodes that would
        become leaves if the move was accepted. This mark initially (out of
        `propose_moves`) takes into account if there would be available decision
        rules to grow the leaf, and whether there are enough datapoints in the
        node is marked in `accept_moves_parallel_stage`.
    logu
        The logarithm of a uniform (0, 1] random variable to be used to
        accept the move. It's in (-oo, 0].
    acc
        Whether the move was accepted. `None` if not yet computed.
    to_prune
        Whether the final operation to apply the move is pruning. This indicates
        an accepted prune move or a rejected grow move. `None` if not yet
        computed.
    """

    # which move is proposed and where
    allowed: Bool[Array, ' num_trees']
    grow: Bool[Array, ' num_trees']
    num_growable: UInt[Array, ' num_trees']
    node: UInt[Array, ' num_trees']
    left: UInt[Array, ' num_trees']
    right: UInt[Array, ' num_trees']
    # Metropolis-Hastings ratio terms
    partial_ratio: Float32[Array, ' num_trees'] | None
    log_trans_prior_ratio: None | Float32[Array, ' num_trees']
    # new decision rule, used if the move is GROW
    grow_var: UInt[Array, ' num_trees']
    grow_split: UInt[Array, ' num_trees']
    # tentatively updated trees
    var_tree: UInt[Array, 'num_trees 2**(d-1)']
    affluence_tree: Bool[Array, 'num_trees 2**(d-1)']
    # acceptance
    logu: Float32[Array, ' num_trees']
    acc: None | Bool[Array, ' num_trees']
    to_prune: None | Bool[Array, ' num_trees']

576 

577 

def propose_moves(key: Key[Array, ''], forest: Forest) -> Moves:
    """
    Propose moves for all the trees.

    Both a GROW move (convert a leaf to a decision node and add two leaves
    beneath it) and a PRUNE move (convert the parent of two leaves to a leaf,
    deleting its children) are prepared for each tree, then one of the two is
    picked at random amongst the allowed ones.

    Parameters
    ----------
    key
        A jax random key.
    forest
        The `forest` field of a BART MCMC state.

    Returns
    -------
    The proposed move for each tree.
    """
    num_trees, _ = forest.leaf_tree.shape

    # one key per tree for each kind of move, plus one for the uniforms
    keys = split(key, 1 + 2 * num_trees)
    grow_keys = keys.pop(num_trees)
    prune_keys = keys.pop(num_trees)
    uniform_key = keys.pop()

    # compute moves
    grow_moves = propose_grow_moves(
        grow_keys,
        forest.var_tree,
        forest.split_tree,
        forest.affluence_tree,
        forest.max_split,
        forest.blocked_vars,
        forest.p_nonterminal,
        forest.p_propose_grow,
        forest.log_s,
    )
    prune_moves = propose_prune_moves(
        prune_keys,
        forest.split_tree,
        grow_moves.affluence_tree,
        forest.p_nonterminal,
        forest.p_propose_grow,
    )

    # first row picks the kind of move, second row is for the acceptance
    uniforms = random.uniform(uniform_key, (2, num_trees))
    u_kind = uniforms[0]
    u_accept = uniforms[1]

    # choose between grow or prune: even odds if both are possible,
    # otherwise take the one which is allowed, if any
    can_grow = grow_moves.allowed
    can_prune = prune_moves.allowed
    p_grow = jnp.where(can_grow & can_prune, 0.5, can_grow)
    grow = u_kind < p_grow  # use < instead of <= because u is in [0, 1)

    # compute children indices
    node = jnp.where(grow, grow_moves.node, prune_moves.node)
    left = node << 1
    right = left + 1

    partial_ratio = jnp.where(
        grow, grow_moves.partial_ratio, prune_moves.partial_ratio
    )

    # u_accept is in [0, 1), so 1 - u_accept is in (0, 1]
    logu = jnp.log1p(-u_accept)

    return Moves(
        allowed=can_grow | can_prune,
        grow=grow,
        num_growable=grow_moves.num_growable,
        node=node,
        left=left,
        right=right,
        partial_ratio=partial_ratio,
        log_trans_prior_ratio=None,  # will be set in complete_ratio
        grow_var=grow_moves.var,
        grow_split=grow_moves.split,
        # var_tree does not need to be updated if prune
        var_tree=grow_moves.var_tree,
        # affluence_tree is updated for both moves unconditionally, prune last
        affluence_tree=prune_moves.affluence_tree,
        logu=logu,
        acc=None,  # will be set in accept_moves_sequential_stage
        to_prune=None,  # will be set in accept_moves_sequential_stage
    )

654 

655 

class GrowMoves(Module):
    """
    Represent a proposed grow move for each tree.

    Parameters
    ----------
    allowed
        Whether the move is allowed for proposal.
    num_growable
        The number of leaves that can be proposed for grow.
    node
        The index of the leaf to grow. ``2 ** d`` if there are no growable
        leaves.
    var
    split
        The decision axis and boundary of the new rule.
    partial_ratio
        A factor of the Metropolis-Hastings ratio of the move. It lacks
        the likelihood ratio and the probability of proposing the prune
        move.
    var_tree
        The updated decision axes of the tree.
    affluence_tree
        A partially updated `affluence_tree` that marks each new leaf that
        would be produced as `True` if it would have available decision rules.
    """

    # whether and where to grow
    allowed: Bool[Array, ' num_trees']
    num_growable: UInt[Array, ' num_trees']
    node: UInt[Array, ' num_trees']
    # new decision rule
    var: UInt[Array, ' num_trees']
    split: UInt[Array, ' num_trees']
    # Metropolis-Hastings ratio factor
    partial_ratio: Float32[Array, ' num_trees']
    # tentatively updated trees
    var_tree: UInt[Array, 'num_trees 2**(d-1)']
    affluence_tree: Bool[Array, 'num_trees 2**(d-1)']

691 

692 

@partial(vmap_nodoc, in_axes=(0, 0, 0, 0, None, None, None, None, None))
def propose_grow_moves(
    key: Key[Array, ' num_trees'],
    var_tree: UInt[Array, 'num_trees 2**(d-1)'],
    split_tree: UInt[Array, 'num_trees 2**(d-1)'],
    affluence_tree: Bool[Array, 'num_trees 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    blocked_vars: Int32[Array, ' k'] | None,
    p_nonterminal: Float32[Array, ' 2**d'],
    p_propose_grow: Float32[Array, ' 2**(d-1)'],
    log_s: Float32[Array, ' p'] | None,
) -> GrowMoves:
    """
    Propose a GROW move for each tree.

    A GROW move picks a leaf node and converts it to a non-terminal node with
    two leaf children. The first four arguments are mapped over their leading
    axis (see `in_axes` in the decorator), the other ones are shared.

    Parameters
    ----------
    key
        A jax random key.
    var_tree
        The splitting axes of the tree.
    split_tree
        The splitting points of the tree.
    affluence_tree
        Whether each leaf has enough points to be grown.
    max_split
        The maximum split index for each variable.
    blocked_vars
        The indices of the variables that have no available cutpoints.
    p_nonterminal
        The a priori probability of a node to be nonterminal conditional on the
        ancestors, including at the maximum depth where it should be zero.
    p_propose_grow
        The unnormalized probability of choosing a leaf to grow.
    log_s
        Unnormalized log-probability used to choose a variable to split on
        amongst the available ones.

    Returns
    -------
    An object representing the proposed move.

    Notes
    -----
    The move is not proposed if each leaf is already at maximum depth, or has
    less datapoints than the requested threshold `min_points_per_decision_node`,
    or it does not have any available decision rules given its ancestors. This
    is marked by setting `allowed` to `False` and `num_growable` to 0.
    """
    # keys are consumed in this order: leaf, variable, cutpoint
    keys = split(key, 3)
    leaf_key = keys.pop()
    var_key = keys.pop()
    split_key = keys.pop()

    leaf_to_grow, num_growable, prob_choose, num_prunable = choose_leaf(
        leaf_key, split_tree, affluence_tree, p_propose_grow
    )

    # sample a decision rule
    var, num_available_var = choose_variable(
        var_key, var_tree, split_tree, max_split, leaf_to_grow, blocked_vars, log_s
    )
    split_idx, l, r = choose_split(
        split_key, var, var_tree, split_tree, max_split, leaf_to_grow
    )

    # determine if the new leaves would have available decision rules; if the
    # move is blocked, these values may not make sense. A child can be split
    # further if another variable is available, or if the chosen variable
    # still has cutpoints on the child's side of the new one.
    other_vars_available = num_available_var > 1
    left_growable = other_vars_available | (l < split_idx)
    right_growable = other_vars_available | (split_idx + 1 < r)
    left = leaf_to_grow << 1
    right = left + 1
    new_affluence_tree = (
        affluence_tree.at[left].set(left_growable).at[right].set(right_growable)
    )

    ratio = compute_partial_ratio(
        prob_choose, num_prunable, p_nonterminal, leaf_to_grow
    )

    new_var_tree = var_tree.at[leaf_to_grow].set(var.astype(var_tree.dtype))

    return GrowMoves(
        allowed=num_growable > 0,
        num_growable=num_growable,
        node=leaf_to_grow,
        var=var,
        split=split_idx,
        partial_ratio=ratio,
        var_tree=new_var_tree,
        affluence_tree=new_affluence_tree,
    )

783 

784 

def choose_leaf(
    key: Key[Array, ''],
    split_tree: UInt[Array, ' 2**(d-1)'],
    affluence_tree: Bool[Array, ' 2**(d-1)'],
    p_propose_grow: Float32[Array, ' 2**(d-1)'],
) -> tuple[Int32[Array, ''], Int32[Array, ''], Float32[Array, ''], Int32[Array, '']]:
    """
    Choose a leaf node to grow in a tree.

    Parameters
    ----------
    key
        A jax random key.
    split_tree
        The splitting points of the tree.
    affluence_tree
        Marks the leaves that can be grown, see `growable_leaves`.
    p_propose_grow
        The unnormalized probability of choosing a leaf to grow.

    Returns
    -------
    leaf_to_grow : Int32[Array, '']
        The index of the leaf to grow. If ``num_growable == 0``, return
        ``2 ** d``.
    num_growable : Int32[Array, '']
        The number of leaf nodes that can be grown, as determined by
        `growable_leaves`.
    prob_choose : Float32[Array, '']
        The (normalized) probability that this function had to choose that
        specific leaf, given the arguments.
    num_prunable : Int32[Array, '']
        The number of leaf parents that could be pruned, after converting the
        selected leaf to a non-terminal node.
    """
    growable_mask = growable_leaves(split_tree, affluence_tree)
    num_growable = jnp.count_nonzero(growable_mask)

    # restrict the proposal distribution to the growable leaves and sample
    weights = jnp.where(growable_mask, p_propose_grow, 0)
    sampled_leaf, total_weight = categorical(key, weights)

    # out-of-range marker used when no leaf can be grown
    no_leaf = 2 * split_tree.size
    leaf_to_grow = jnp.where(num_growable, sampled_leaf, no_leaf)

    # avoid dividing by zero if there are no growable leaves
    safe_total = jnp.where(total_weight, total_weight, 1)
    prob_choose = weights[leaf_to_grow] / safe_total

    # count the prunable nodes in the tree as it would be after growing
    grown_split_tree = split_tree.at[leaf_to_grow].set(1)
    num_prunable = jnp.count_nonzero(grove.is_leaves_parent(grown_split_tree))

    return leaf_to_grow, num_growable, prob_choose, num_prunable

830 

831 

def growable_leaves(
    split_tree: UInt[Array, ' 2**(d-1)'], affluence_tree: Bool[Array, ' 2**(d-1)']
) -> Bool[Array, ' 2**(d-1)']:
    """
    Return a mask indicating the leaf nodes that can be proposed for growth.

    A node is selected if it is an actual leaf of the tree according to
    `split_tree` and it is marked in `affluence_tree`. The latter is meant to
    indicate that the leaf is not at the bottom level, has available decision
    rules given its ancestors, and has at least `min_points_per_decision_node`
    points.

    Parameters
    ----------
    split_tree
        The splitting points of the tree.
    affluence_tree
        Marks leaves that can be grown.

    Returns
    -------
    The mask indicating the leaf nodes that can be proposed to grow.

    Notes
    -----
    This function needs `split_tree` and not just `affluence_tree` because
    `affluence_tree` can be "dirty", i.e., mark unused nodes as `True`.
    """
    is_leaf = grove.is_actual_leaf(split_tree)
    return is_leaf & affluence_tree

859 

860 

def categorical(
    key: Key[Array, ''], distr: Float32[Array, ' n']
) -> tuple[Int32[Array, ''], Float32[Array, '']]:
    """
    Draw an index from an unnormalized discrete distribution.

    Parameters
    ----------
    key
        A jax random key.
    distr
        An unnormalized probability distribution.

    Returns
    -------
    u : Int32[Array, '']
        A random integer in the range ``[0, n)``. If all probabilities are zero,
        return ``n``.
    norm : Float32[Array, '']
        The sum of `distr`.

    Notes
    -----
    The draw inverts the cumulative sum of `distr` rather than using the
    Gumbel trick, so it is appropriate only for small ranges with
    probabilities well greater than 0.
    """
    cumulative = jnp.cumsum(distr)
    total = cumulative[-1]
    # uniform draw in [0, total), then locate it in the cumulative sums
    threshold = random.uniform(key, (), cumulative.dtype, 0, total)
    index = jnp.searchsorted(cumulative, threshold, 'right')
    return index, total

890 

891 

def choose_variable(
    key: Key[Array, ''],
    var_tree: UInt[Array, ' 2**(d-1)'],
    split_tree: UInt[Array, ' 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    leaf_index: Int32[Array, ''],
    blocked_vars: Int32[Array, ' k'] | None,
    log_s: Float32[Array, ' p'] | None,
) -> tuple[Int32[Array, ''], Int32[Array, '']]:
    """
    Pick the splitting variable for a leaf that is being turned into a decision node.

    Parameters
    ----------
    key
        A jax random key.
    var_tree
        The variable indices of the tree.
    split_tree
        The splitting points of the tree.
    max_split
        The maximum split index for each variable.
    leaf_index
        The index of the leaf to grow.
    blocked_vars
        The indices of the variables that have no available cutpoints. If
        `None`, all variables are assumed unblocked.
    log_s
        The logarithm of the prior probability for choosing a variable. If
        `None`, use a uniform distribution.

    Returns
    -------
    var : Int32[Array, '']
        The index of the variable to split on.
    num_available_var : Int32[Array, '']
        The number of variables with available decision rules `var` was chosen
        from.
    """
    # variables exhausted along the path to the leaf, plus globally blocked ones
    excluded = fully_used_variables(var_tree, split_tree, max_split, leaf_index)
    if blocked_vars is not None:
        excluded = jnp.concatenate([excluded, blocked_vars])

    if log_s is not None:
        return categorical_exclude(key, log_s, excluded)
    return randint_exclude(key, max_split.size, excluded)

939 

940 

def fully_used_variables(
    var_tree: UInt[Array, ' 2**(d-1)'],
    split_tree: UInt[Array, ' 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    leaf_index: Int32[Array, ''],
) -> UInt[Array, ' d-2']:
    """
    List the ancestor variables of a node that have no split left to use.

    Parameters
    ----------
    var_tree
        The variable indices of the tree.
    split_tree
        The splitting points of the tree.
    max_split
        The maximum split index for each variable.
    leaf_index
        The index of the node, assumed to be valid for `var_tree`.

    Returns
    -------
    The indices of the variables that have an empty split range.

    Notes
    -----
    The number of such variables is not known in advance. Slots of the output
    that do not hold one are filled with `p`. The fill values are not
    guaranteed to be placed in any particular order, and variables may appear
    more than once.
    """
    ancestors = ancestor_variables(var_tree, max_split, leaf_index)

    def range_size(ref_var):
        lo, hi = split_range(var_tree, split_tree, max_split, leaf_index, ref_var)
        return hi - lo

    num_split = jax.vmap(range_size)(ancestors)

    # the type of `ancestors` is already sufficient to hold max_split.size,
    # see ancestor_variables()
    return jnp.where(num_split == 0, ancestors, max_split.size)

978 

979 

def ancestor_variables(
    var_tree: UInt[Array, ' 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    node_index: Int32[Array, ''],
) -> UInt[Array, ' d-2']:
    """
    Collect the splitting variables found on the path from the root to a node.

    Parameters
    ----------
    var_tree
        The variable indices of the tree.
    max_split
        The maximum split index for each variable. Used only to get `p`.
    node_index
        The index of the node, assumed to be valid for `var_tree`.

    Returns
    -------
    The variable indices of the ancestors of the node.

    Notes
    -----
    The ancestors are the nodes going from the root to the parent of the node.
    The number of ancestors is not known at tracing time; unused spots in the
    output array are filled with `p`.
    """
    p = max_split.size
    max_num_ancestors = grove.tree_depth(var_tree) - 1
    collected = jnp.zeros(max_num_ancestors, minimal_unsigned_dtype(p))

    def step(state, _):
        slot, node, collected = state
        parent = node >> 1
        # index 0 means the walk went past the root: store the filler `p`
        parent_var = jnp.where(parent, var_tree[parent], p)
        collected = collected.at[slot].set(parent_var)
        return (slot - 1, parent, collected), None

    # the output is filled backwards, starting from the last slot
    init = collected.size - 1, node_index, collected
    (_, _, collected), _ = lax.scan(step, init, None, collected.size)
    return collected

1021 

1022 

def split_range(
    var_tree: UInt[Array, ' 2**(d-1)'],
    split_tree: UInt[Array, ' 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    node_index: Int32[Array, ''],
    ref_var: Int32[Array, ''],
) -> tuple[Int32[Array, ''], Int32[Array, '']]:
    """
    Compute which cutpoints of a variable are still usable at a given node.

    The range starts as the full range of `ref_var` and is narrowed by every
    ancestor of the node that splits on `ref_var`.

    Parameters
    ----------
    var_tree
        The variable indices of the tree.
    split_tree
        The splitting points of the tree.
    max_split
        The maximum split index for each variable.
    node_index
        The index of the node, assumed to be valid for `var_tree`.
    ref_var
        The variable for which to measure the split range.

    Returns
    -------
    The range of allowed splits as [l, r). If `ref_var` is out of bounds, l=r=1.
    """
    max_num_ancestors = grove.tree_depth(var_tree) - 1

    # an out-of-bounds `ref_var` reads as max_split = 0, giving an empty range
    upper = max_split.at[ref_var].get(mode='fill', fill_value=0)
    upper = 1 + upper.astype(jnp.int32)

    def step(state, _):
        lo, hi, node = state
        came_from_right = (node & 1).astype(bool)
        parent = node >> 1
        cut = split_tree[parent]
        # ignore index 0, which is reached after walking past the root
        same_var = (var_tree[parent] == ref_var) & parent.astype(bool)
        lo = jnp.where(same_var & came_from_right, jnp.maximum(lo, cut), lo)
        hi = jnp.where(same_var & ~came_from_right, jnp.minimum(hi, cut), hi)
        return (lo, hi, parent), None

    init = jnp.int32(0), upper, node_index
    (lo, hi, _), _ = lax.scan(step, init, None, max_num_ancestors)
    return lo + 1, hi

1068 

1069 

def randint_exclude(
    key: Key[Array, ''], sup: int | Integer[Array, ''], exclude: Integer[Array, ' n']
) -> tuple[Int32[Array, ''], Int32[Array, '']]:
    """
    Draw a uniform random integer from a range with some values removed.

    Parameters
    ----------
    key
        A jax random key.
    sup
        The exclusive upper bound of the range.
    exclude
        The values to exclude from the range. Values greater than or equal to
        `sup` are ignored. Values can appear more than once.

    Returns
    -------
    u : Int32[Array, '']
        A random integer `u` in the range ``[0, sup)`` such that ``u not in
        exclude``.
    num_allowed : Int32[Array, '']
        The number of integers in the range that were not excluded.

    Notes
    -----
    If all values in the range are excluded, return `sup`.
    """
    processed, num_allowed = _process_exclude(sup, exclude)

    # draw a rank among the allowed values, then turn the rank into a value by
    # stepping over each excluded value that is not above the current draw
    draw = random.randint(key, (), 0, num_allowed)

    def skip_excluded(value, excluded):
        return jnp.where(excluded <= value, value + 1, value), None

    draw, _ = lax.scan(skip_excluded, draw, processed)
    return draw, num_allowed

1106 

1107 

def _process_exclude(sup, exclude):
    """
    Deduplicate a list of excluded values and count the values left in a range.

    Parameters
    ----------
    sup
        The exclusive upper bound of the range ``[0, sup)``.
    exclude
        The values to exclude. Values can appear more than once.

    Returns
    -------
    exclude
        The unique values of `exclude`, in an array of the same size as the
        input, with the leftover slots filled with `sup`.
    num_allowed
        `sup` minus the number of unique excluded values smaller than `sup`.
    """
    unique_exclude = jnp.unique(exclude, size=exclude.size, fill_value=sup)
    num_excluded = jnp.count_nonzero(unique_exclude < sup)
    return unique_exclude, sup - num_excluded

1112 

1113 

def categorical_exclude(
    key: Key[Array, ''], logits: Float32[Array, ' k'], exclude: Integer[Array, ' n']
) -> tuple[Int32[Array, ''], Int32[Array, '']]:
    """
    Sample a category from logits after ruling out some of the categories.

    Parameters
    ----------
    key
        A jax random key.
    logits
        The unnormalized log-probabilities of each category.
    exclude
        The values to exclude from the range [0, k). Values greater than or
        equal to `logits.size` are ignored. Values can appear more than once.

    Returns
    -------
    u : Int32[Array, '']
        A random integer in the range ``[0, k)`` such that ``u not in exclude``.
    num_allowed : Int32[Array, '']
        The number of integers in the range that were not excluded.

    Notes
    -----
    If all values in the range are excluded, the result is unspecified.
    """
    processed, num_allowed = _process_exclude(logits.size, exclude)

    # use the most negative finite value of the dtype to suppress the
    # excluded categories
    lowest = jnp.finfo(logits.dtype).min
    masked_logits = logits.at[processed].set(lowest)

    return random.categorical(key, masked_logits), num_allowed

1146 

1147 

def choose_split(
    key: Key[Array, ''],
    var: Int32[Array, ''],
    var_tree: UInt[Array, ' 2**(d-1)'],
    split_tree: UInt[Array, ' 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    leaf_index: Int32[Array, ''],
) -> tuple[Int32[Array, ''], Int32[Array, ''], Int32[Array, '']]:
    """
    Pick the cutpoint for a leaf that is being turned into a decision node.

    Parameters
    ----------
    key
        A jax random key.
    var
        The variable to split on.
    var_tree
        The splitting axes of the tree. Does not need to already contain `var`
        at `leaf_index`.
    split_tree
        The splitting points of the tree.
    max_split
        The maximum split index for each variable.
    leaf_index
        The index of the leaf to grow.

    Returns
    -------
    split : Int32[Array, '']
        The cutpoint.
    l : Int32[Array, '']
    r : Int32[Array, '']
        The integer range `split` was drawn from is [l, r).

    Notes
    -----
    If `var` is out of bounds, or if the available split range on that variable
    is empty, return 0.
    """
    l, r = split_range(var_tree, split_tree, max_split, leaf_index, var)
    candidate = random.randint(key, (), l, r)
    # an empty range has no valid cutpoint: signal it with 0
    split = jnp.where(l < r, candidate, 0)
    return split, l, r

1190 

1191 

def compute_partial_ratio(
    prob_choose: Float32[Array, ''],
    num_prunable: Int32[Array, ''],
    p_nonterminal: Float32[Array, ' 2**d'],
    leaf_to_grow: Int32[Array, ''],
) -> Float32[Array, '']:
    """
    Compute the part of the grow move ratio that does not need datapoint counts.

    The result is the product of the (partial) transition ratio and the
    (partial) prior ratio of a grow move.

    Parameters
    ----------
    prob_choose
        The probability that the leaf had to be chosen amongst the growable
        leaves.
    num_prunable
        The number of leaf parents that could be pruned, after converting the
        leaf to be grown to a non-terminal node.
    p_nonterminal
        The a priori probability of each node being nonterminal conditional on
        its ancestors.
    leaf_to_grow
        The index of the leaf to grow.

    Returns
    -------
    The partial transition ratio times the prior ratio.

    Notes
    -----
    The transition ratio is P(new tree => old tree) / P(old tree => new tree).
    The "partial" transition ratio returned is missing the factor P(propose
    prune) in the numerator. The prior ratio is P(new tree) / P(old tree). The
    "partial" prior ratio is missing the factor P(children are leaves).
    """
    # The two ratios also contain the factors num_available_split *
    # num_available_var * s[var], but they cancel out.
    #
    # p_prune and 1 - p_nonterminal[child] * I(is the child growable) can't be
    # computed here because they need the count trees, which are computed in
    # the acceptance phase.

    # If the leaf to grow is the root, the tree is just a root and growing is
    # the only possible move; otherwise the tree is not a root and a prune
    # move is possible too.
    growing_the_root = leaf_to_grow == 1
    p_grow = jnp.where(growing_the_root, 1, 0.5)
    inv_trans_ratio = p_grow * prob_choose * num_prunable
    # guard against dividing by zero
    safe_inv_trans_ratio = jnp.where(inv_trans_ratio, inv_trans_ratio, 1)

    # Use .at.get with a fill value because, if leaf_to_grow is out of bounds
    # (move not allowed), a plain lookup would lead to a 0 and then to an inf
    # when `complete_ratio` takes the log.
    pnt = p_nonterminal.at[leaf_to_grow].get(mode='fill', fill_value=0.5)
    tree_ratio = pnt / (1 - pnt)

    return tree_ratio / safe_inv_trans_ratio

1246 

1247 

class PruneMoves(Module):
    """
    Represent a proposed prune move for each tree.

    Parameters
    ----------
    allowed
        Whether the move is possible.
    node
        The index of the node to prune. ``2 ** d`` if no node can be pruned.
    partial_ratio
        A factor of the Metropolis-Hastings ratio of the move. It lacks the
        likelihood ratio, the probability of proposing the prune move, and the
        prior probability that the children of the node to prune are leaves.
        This ratio is inverted, and is meant to be inverted back in
        `accept_moves_and_sample_leaves`.
    affluence_tree
        A partially updated `affluence_tree`, marking the node to prune as
        growable.
    """

    allowed: Bool[Array, ' num_trees']
    node: UInt[Array, ' num_trees']
    partial_ratio: Float32[Array, ' num_trees']
    affluence_tree: Bool[Array, 'num_trees 2**(d-1)']

1270 

1271 

@partial(vmap_nodoc, in_axes=(0, 0, 0, None, None))
def propose_prune_moves(
    key: Key[Array, ''],
    split_tree: UInt[Array, ' 2**(d-1)'],
    affluence_tree: Bool[Array, ' 2**(d-1)'],
    p_nonterminal: Float32[Array, ' 2**d'],
    p_propose_grow: Float32[Array, ' 2**(d-1)'],
) -> PruneMoves:
    """
    Propose a prune move for a tree, as part of the BART MCMC.

    Parameters
    ----------
    key
        A jax random key.
    split_tree
        The splitting points of the tree.
    affluence_tree
        Whether each leaf can be grown.
    p_nonterminal
        The a priori probability of a node to be nonterminal conditional on
        the ancestors, including at the maximum depth where it should be zero.
    p_propose_grow
        The unnormalized probability of choosing a leaf to grow.

    Returns
    -------
    An object representing the proposed moves.
    """
    node_to_prune, num_prunable, prob_choose, affluence_tree = choose_leaf_parent(
        key, split_tree, affluence_tree, p_propose_grow
    )

    partial_ratio = compute_partial_ratio(
        prob_choose, num_prunable, p_nonterminal, node_to_prune
    )

    # pruning is allowed iff the tree is not a root
    tree_is_not_root = split_tree[1].astype(bool)

    return PruneMoves(
        allowed=tree_is_not_root,
        node=node_to_prune,
        partial_ratio=partial_ratio,
        affluence_tree=affluence_tree,
    )

1315 

1316 

def choose_leaf_parent(
    key: Key[Array, ''],
    split_tree: UInt[Array, ' 2**(d-1)'],
    affluence_tree: Bool[Array, ' 2**(d-1)'],
    p_propose_grow: Float32[Array, ' 2**(d-1)'],
) -> tuple[
    Int32[Array, ''],
    Int32[Array, ''],
    Float32[Array, ''],
    Bool[Array, ' 2**(d-1)'],
]:
    """
    Pick a non-terminal node with leaf children to prune in a tree.

    Parameters
    ----------
    key
        A jax random key.
    split_tree
        The splitting points of the tree.
    affluence_tree
        Whether a leaf has enough points to be grown.
    p_propose_grow
        The unnormalized probability of choosing a leaf to grow.

    Returns
    -------
    node_to_prune : Int32[Array, '']
        The index of the node to prune. If ``num_prunable == 0``, return
        ``2 ** d``.
    num_prunable : Int32[Array, '']
        The number of leaf parents that could be pruned.
    prob_choose : Float32[Array, '']
        The (normalized) probability that `choose_leaf` would choose
        `node_to_prune` as leaf to grow, if passed the tree where
        `node_to_prune` had been pruned.
    affluence_tree : Bool[Array, ' 2**(d-1)']
        A partially updated `affluence_tree`, marking the node to prune as
        growable.
    """
    # sample a node to prune, uniformly amongst the parents of two leaves
    is_prunable = grove.is_leaves_parent(split_tree)
    num_prunable = jnp.count_nonzero(is_prunable)
    node_to_prune = randint_masked(key, is_prunable)
    # use an out-of-bounds index to signal that there is nothing to prune
    node_to_prune = jnp.where(num_prunable, node_to_prune, 2 * split_tree.size)

    # compute stuff for the reverse move: build the pruned tree and get the
    # probability that a grow proposal on it would pick the pruned node
    split_tree = split_tree.at[node_to_prune].set(0)
    affluence_tree = affluence_tree.at[node_to_prune].set(True)
    is_growable_leaf = growable_leaves(split_tree, affluence_tree)
    distr_norm = jnp.sum(p_propose_grow, where=is_growable_leaf)
    prob_choose = p_propose_grow.at[node_to_prune].get(mode='fill', fill_value=0)
    # guard against dividing by zero
    prob_choose = prob_choose / jnp.where(distr_norm, distr_norm, 1)

    return node_to_prune, num_prunable, prob_choose, affluence_tree

1372 

1373 

def randint_masked(key: Key[Array, ''], mask: Bool[Array, ' n']) -> Int32[Array, '']:
    """
    Draw a uniform random index amongst the positions allowed by a mask.

    Parameters
    ----------
    key
        A jax random key.
    mask
        The mask indicating the allowed values.

    Returns
    -------
    A random integer in the range ``[0, n)`` such that ``mask[u] == True``.

    Notes
    -----
    If all values in the mask are `False`, return `n`.
    """
    # running count of the allowed positions; the last entry is their total
    num_allowed_up_to = jnp.cumsum(mask)
    total_allowed = num_allowed_up_to[-1]
    # draw a rank amongst the allowed positions, then find where it falls
    rank = random.randint(key, (), 0, total_allowed)
    return jnp.searchsorted(num_allowed_up_to, rank, 'right')

1396 

1397 

def accept_moves_and_sample_leaves(
    key: Key[Array, ''], bart: State, moves: Moves
) -> State:
    """
    Decide which proposed moves to accept and draw new leaf values.

    The work is split in three stages, run one after the other: a stage
    parallel across trees, a sequential stage, and a final stage.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A valid BART mcmc state.
    moves
        The proposed moves, see `propose_moves`.

    Returns
    -------
    A new (valid) BART mcmc state.
    """
    parallel_out = accept_moves_parallel_stage(key, bart, moves)
    bart, moves = accept_moves_sequential_stage(parallel_out)
    return accept_moves_final_stage(bart, moves)

1420 

1421 

class Counts(Module):
    """
    Datapoint counts in the nodes touched by the proposed move of each tree.

    Parameters
    ----------
    left
        Number of datapoints in the left child.
    right
        Number of datapoints in the right child.
    total
        Number of datapoints in the parent (``= left + right``).
    """

    left: UInt[Array, ' num_trees']
    right: UInt[Array, ' num_trees']
    total: UInt[Array, ' num_trees']

1439 

1440 

class Precs(Module):
    """
    Likelihood precision scale in the nodes touched by the proposed move of each tree.

    The "likelihood precision scale" of a tree node is the sum of the inverse
    squared error scales of the datapoints selected by the node.

    Parameters
    ----------
    left
        Likelihood precision scale in the left child.
    right
        Likelihood precision scale in the right child.
    total
        Likelihood precision scale in the parent (``= left + right``).
    """

    left: Float32[Array, ' num_trees']
    right: Float32[Array, ' num_trees']
    total: Float32[Array, ' num_trees']

1461 

1462 

class PreLkV(Module):
    """
    Per-tree terms of the likelihood ratio that do not depend on the sequential pass.

    These terms can be computed in parallel across trees.

    Parameters
    ----------
    sigma2_left
        The noise variance in the left child of the leaves grown or pruned by
        the moves.
    sigma2_right
        The noise variance in the right child of the leaves grown or pruned by
        the moves.
    sigma2_total
        The noise variance in the total of the leaves grown or pruned by the
        moves.
    sqrt_term
        The **logarithm** of the square root term of the likelihood ratio.
    """

    sigma2_left: Float32[Array, ' num_trees']
    sigma2_right: Float32[Array, ' num_trees']
    sigma2_total: Float32[Array, ' num_trees']
    sqrt_term: Float32[Array, ' num_trees']

1488 

1489 

class PreLk(Module):
    """
    Terms of the likelihood ratio that are common to all trees and non-sequential.

    Parameters
    ----------
    exp_factor
        The factor to multiply the likelihood ratio by, shared by all trees.
    """

    exp_factor: Float32[Array, '']

1501 

1502 

class PreLf(Module):
    """
    Terms prepared in advance to draw the leaves from their posterior.

    These terms can be computed in parallel across trees.

    Parameters
    ----------
    mean_factor
        The factor to be multiplied by the sum of the scaled residuals to
        obtain the posterior mean.
    centered_leaves
        The mean-zero normal values to be added to the posterior mean to
        obtain the posterior leaf samples.
    """

    mean_factor: Float32[Array, 'num_trees 2**d']
    centered_leaves: Float32[Array, 'num_trees 2**d']

1521 

1522 

class ParallelStageOut(Module):
    """
    The output of `accept_moves_parallel_stage`.

    Parameters
    ----------
    bart
        A partially updated BART mcmc state.
    moves
        The proposed moves, with `partial_ratio` set to `None` and
        `log_trans_prior_ratio` set to its final value.
    prec_trees
        The likelihood precision scale in each potential or actual leaf node. If
        there is no precision scale, this is the number of points in each leaf.
    move_precs
        The likelihood precision scale in each node modified by the moves. If
        `bart.prec_scale` is not set, this is instead a `Counts` object with
        the number of points in the nodes modified by the moves.
    prelkv
    prelk
    prelf
        Objects with pre-computed terms of the likelihood ratios and leaf
        samples.
    """

    bart: State
    moves: Moves
    # the datapoint counts are accumulated in an unsigned integer type
    prec_trees: Float32[Array, 'num_trees 2**d'] | UInt[Array, 'num_trees 2**d']
    move_precs: Precs | Counts
    prelkv: PreLkV
    prelk: PreLk
    prelf: PreLf

1558 

1559 

def accept_moves_parallel_stage(
    key: Key[Array, ''], bart: State, moves: Moves
) -> ParallelStageOut:
    """
    Prepare, in parallel across trees, the quantities needed to accept moves.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state.
    moves
        The proposed moves, see `propose_moves`.

    Returns
    -------
    An object with all that could be done in parallel.
    """
    # where the move is grow, modify the state as if the move was accepted
    grown_forest = replace(
        bart.forest,
        var_tree=moves.var_tree,
        leaf_indices=apply_grow_to_indices(moves, bart.forest.leaf_indices, bart.X),
        leaf_tree=adapt_leaf_trees_to_grow_indices(bart.forest.leaf_tree, moves),
    )
    bart = replace(bart, forest=grown_forest)

    min_points_per_decision_node = bart.forest.min_points_per_decision_node
    min_points_per_leaf = bart.forest.min_points_per_leaf
    unweighted = bart.prec_scale is None

    # count the number of datapoints per leaf; this is needed by the
    # constraints on the number of points and by the unweighted likelihood
    if (
        min_points_per_decision_node is not None
        or min_points_per_leaf is not None
        or unweighted
    ):
        count_trees, move_counts = compute_count_trees(
            bart.forest.leaf_indices, moves, bart.forest.count_batch_size
        )

    # mark which leaves & potential leaves have enough points to be grown
    if min_points_per_decision_node is not None:
        num_decision_nodes = bart.forest.var_tree.shape[1]
        enough_points = (
            count_trees[:, :num_decision_nodes] >= min_points_per_decision_node
        )
        moves = replace(moves, affluence_tree=moves.affluence_tree & enough_points)

    # copy the updated affluence_tree to the state
    bart = tree_at(
        lambda state: state.forest.affluence_tree, bart, moves.affluence_tree
    )

    # veto the move if the new leaves don't have enough datapoints
    if min_points_per_leaf is not None:
        left_ok = move_counts.left >= min_points_per_leaf
        right_ok = move_counts.right >= min_points_per_leaf
        moves = replace(moves, allowed=moves.allowed & left_ok & right_ok)

    # count the number of datapoints per leaf, weighted by error precision scale
    if unweighted:
        prec_trees, move_precs = count_trees, move_counts
    else:
        prec_trees, move_precs = compute_prec_trees(
            bart.prec_scale,
            bart.forest.leaf_indices,
            moves,
            bart.forest.count_batch_size,
        )
    assert move_precs is not None

    # compute some missing information about moves
    moves = complete_ratio(moves, bart.forest.p_nonterminal)

    # the ratios are stored in the state only if the state tracks the log likelihood
    if bart.forest.log_likelihood is not None:
        log_trans_prior = moves.log_trans_prior_ratio
    else:
        log_trans_prior = None

    counted_forest = replace(
        bart.forest,
        grow_prop_count=jnp.sum(moves.grow),
        prune_prop_count=jnp.sum(moves.allowed & ~moves.grow),
        log_trans_prior=log_trans_prior,
    )
    bart = replace(bart, forest=counted_forest)

    # pre-compute some likelihood ratio & posterior terms
    assert bart.sigma2 is not None  # `step` shall temporarily set it to 1
    sigma_mu2 = bart.forest.sigma_mu2
    prelkv, prelk = precompute_likelihood_terms(bart.sigma2, sigma_mu2, move_precs)
    prelf = precompute_leaf_terms(key, prec_trees, bart.sigma2, sigma_mu2)

    return ParallelStageOut(
        bart=bart,
        moves=moves,
        prec_trees=prec_trees,
        move_precs=move_precs,
        prelkv=prelkv,
        prelk=prelk,
        prelf=prelf,
    )

1663 

1664 

@partial(vmap_nodoc, in_axes=(0, 0, None))
def apply_grow_to_indices(
    moves: Moves, leaf_indices: UInt[Array, 'num_trees n'], X: UInt[Array, 'p n']
) -> UInt[Array, 'num_trees n']:
    """
    Move the datapoints of a grown leaf into its two new children.

    Parameters
    ----------
    moves
        The proposed moves, see `propose_moves`.
    leaf_indices
        The index of the leaf each datapoint falls into.
    X
        The predictors matrix.

    Returns
    -------
    The updated leaf indices.
    """
    # index of the child each datapoint would land in: the left child is at
    # twice the parent index, the right child is right after it
    goes_right = X[moves.grow_var, :] >= moves.grow_split
    child = (moves.node.astype(leaf_indices.dtype) << 1) + goes_right

    # if the move is not a grow, target an index past the end of the tree, so
    # that no datapoint matches it
    out_of_tree = jnp.array(2 * moves.var_tree.size)
    grown_node = jnp.where(moves.grow, moves.node, out_of_tree)

    return jnp.where(leaf_indices == grown_node, child, leaf_indices)

1692 

1693 

def compute_count_trees(
    leaf_indices: UInt[Array, 'num_trees n'], moves: Moves, batch_size: int | None
) -> tuple[UInt[Array, 'num_trees 2**d'], Counts]:
    """
    Count the number of datapoints in each leaf.

    Parameters
    ----------
    leaf_indices
        The index of the leaf each datapoint falls into, with the deeper version
        of the tree (post-GROW, pre-PRUNE).
    moves
        The proposed moves, see `propose_moves`.
    batch_size
        The data batch size to use for the summation.

    Returns
    -------
    count_trees : UInt[Array, 'num_trees 2**d']
        The number of points in each potential or actual leaf node.
    counts : Counts
        The counts of the number of points in the leaves grown or pruned by the
        moves.
    """
    # the leaf tree is twice as large as the tree of decision nodes
    num_trees, tree_size = moves.var_tree.shape
    tree_size *= 2
    tree_indices = jnp.arange(num_trees)

    count_trees = count_datapoints_per_leaf(leaf_indices, tree_size, batch_size)

    # count datapoints in nodes modified by move
    left = count_trees[tree_indices, moves.left]
    right = count_trees[tree_indices, moves.right]
    counts = Counts(left=left, right=right, total=left + right)

    # write count into non-leaf node
    count_trees = count_trees.at[tree_indices, moves.node].set(counts.total)

    return count_trees, counts

1733 

1734 

def count_datapoints_per_leaf(
    leaf_indices: UInt[Array, 'num_trees n'], tree_size: int, batch_size: int | None
) -> UInt[Array, 'num_trees 2**d']:
    """
    Count the number of datapoints in each leaf.

    Parameters
    ----------
    leaf_indices
        The index of the leaf each datapoint falls into.
    tree_size
        The size of the leaf tree array (2 ** d).
    batch_size
        The data batch size to use for the summation. If `None`, the trees are
        processed one at a time without batching the datapoints.

    Returns
    -------
    The number of points in each leaf node.
    """
    if batch_size is None:
        return _count_scan(leaf_indices, tree_size)
    else:
        return _count_vec(leaf_indices, tree_size, batch_size)

1758 

1759 

def _count_scan(
    leaf_indices: UInt[Array, 'num_trees n'], tree_size: int
) -> UInt[Array, 'num_trees {tree_size}']:
    """
    Count the datapoints in each leaf, processing one tree at a time.

    Parameters
    ----------
    leaf_indices
        The index of the leaf each datapoint falls into.
    tree_size
        The size of the output along the node axis.

    Returns
    -------
    The number of points in each leaf node, as 32 bit unsigned integers.
    """

    def loop(_, leaf_indices):
        # scatter-add a 1 for each datapoint of the current tree
        return None, _aggregate_scatter(1, leaf_indices, tree_size, jnp.uint32)

    _, count_trees = lax.scan(loop, None, leaf_indices)
    return count_trees

1768 

1769 

def _aggregate_scatter(
    values: Shaped[Array, '*'],
    indices: Integer[Array, '*'],
    size: int,
    dtype: jnp.dtype,
) -> Shaped[Array, ' {size}']:
    """
    Sum values into the slots of a new array selected by indices.

    Parameters
    ----------
    values
        The values to add up.
    indices
        The slot each value is added to.
    size
        The length of the output array.
    dtype
        The data type of the output array.

    Returns
    -------
    An array that starts from zero and has `values` added at `indices`.
    """
    accumulator = jnp.zeros(size, dtype)
    return accumulator.at[indices].add(values)

1777 

1778 

def _count_vec(
    leaf_indices: UInt[Array, 'num_trees n'], tree_size: int, batch_size: int
) -> UInt[Array, 'num_trees {tree_size}']:
    """
    Count the datapoints in each leaf, for all trees at once with batching.

    Parameters
    ----------
    leaf_indices
        The index of the leaf each datapoint falls into.
    tree_size
        The size of the output along the node axis.
    batch_size
        The data batch size to use for the summation.

    Returns
    -------
    The number of points in each leaf node, as 32 bit unsigned integers.
    """
    # uint16 is super-slow on gpu, don't use it even if n < 2^16
    return _aggregate_batched_alltrees(
        1, leaf_indices, tree_size, jnp.uint32, batch_size
    )

1786 

1787 

def _aggregate_batched_alltrees(
    values: Shaped[Array, '*'],
    indices: UInt[Array, 'num_trees n'],
    size: int,
    dtype: jnp.dtype,
    batch_size: int,
) -> Shaped[Array, 'num_trees {size}']:
    """
    Sum `values` per tree and per slot, through per-batch partial sums.

    Parameters
    ----------
    values
        The values to accumulate, broadcast against `indices`.
    indices
        For each tree, the slot each datapoint is added to.
    size
        The number of slots per tree.
    dtype
        The data type of the accumulator and of the result.
    batch_size
        The data batch size; the number of partial accumulators per slot is
        the number of datapoints divided by `batch_size`, rounded up.

    Returns
    -------
    The per-tree, per-slot sums.
    """
    num_trees, n = indices.shape

    # number of batches = ceil(n / batch_size)
    whole_batches, leftover = divmod(n, batch_size)
    nbatches = whole_batches + (1 if leftover else 0)

    # datapoints are dealt round-robin to the partial accumulators
    tree_coord = jnp.arange(num_trees)[:, None]
    batch_coord = jnp.arange(n) % nbatches

    partial_sums = jnp.zeros((num_trees, size, nbatches), dtype)
    partial_sums = partial_sums.at[tree_coord, indices, batch_coord].add(values)

    # collapse the batch axis to get the final totals
    return partial_sums.sum(axis=2)

1805 

1806 

def compute_prec_trees(
    prec_scale: Float32[Array, ' n'],
    leaf_indices: UInt[Array, 'num_trees n'],
    moves: Moves,
    batch_size: int | None,
) -> tuple[Float32[Array, 'num_trees 2**d'], Precs]:
    """
    Compute the likelihood precision scale in each leaf.

    Parameters
    ----------
    prec_scale
        The scale of the precision of the error on each datapoint.
    leaf_indices
        The index of the leaf each datapoint falls into, with the deeper version
        of the tree (post-GROW, pre-PRUNE).
    moves
        The proposed moves, see `propose_moves`.
    batch_size
        The data batch size to use for the summation.

    Returns
    -------
    prec_trees : Float32[Array, 'num_trees 2**d']
        The likelihood precision scale in each potential or actual leaf node.
    precs : Precs
        The likelihood precision scale in the nodes involved in the moves.
    """
    # the leaf arrays are twice as long as the decision-node arrays
    num_trees, half_size = moves.var_tree.shape
    tree_size = 2 * half_size
    trees = jnp.arange(num_trees)

    prec_trees = prec_per_leaf(prec_scale, leaf_indices, tree_size, batch_size)

    # precision in the two children touched by each move, and in their parent
    prec_left = prec_trees[trees, moves.left]
    prec_right = prec_trees[trees, moves.right]
    precs = Precs(left=prec_left, right=prec_right, total=prec_left + prec_right)

    # store the parent total at the (non-leaf) node of the move
    return prec_trees.at[trees, moves.node].set(precs.total), precs

1850 

1851 

def prec_per_leaf(
    prec_scale: Float32[Array, ' n'],
    leaf_indices: UInt[Array, 'num_trees n'],
    tree_size: int,
    batch_size: int | None,
) -> Float32[Array, 'num_trees {tree_size}']:
    """
    Compute the likelihood precision scale in each leaf.

    Parameters
    ----------
    prec_scale
        The scale of the precision of the error on each datapoint.
    leaf_indices
        The index of the leaf each datapoint falls into.
    tree_size
        The size of the leaf tree array (2 ** d).
    batch_size
        The data batch size to use for the summation. If `None`, the trees
        are processed one at a time without batching.

    Returns
    -------
    The likelihood precision scale in each leaf node.
    """
    if batch_size is not None:
        return _prec_vec(prec_scale, leaf_indices, tree_size, batch_size)
    return _prec_scan(prec_scale, leaf_indices, tree_size)

1880 

1881 

def _prec_scan(
    prec_scale: Float32[Array, ' n'],
    leaf_indices: UInt[Array, 'num_trees n'],
    tree_size: int,
) -> Float32[Array, 'num_trees {tree_size}']:
    """
    Sum the precision scale per leaf, scanning over the trees one at a time.

    Parameters
    ----------
    prec_scale
        The scale of the precision of the error on each datapoint.
    leaf_indices
        The index of the leaf each datapoint falls into.
    tree_size
        The number of entries in the output for each tree.

    Returns
    -------
    The summed precision scale in each leaf node.
    """

    def sum_one_tree(_, tree_leaf_indices):
        # `prec_scale` is shared by all trees, only the leaf assignment varies
        sums = _aggregate_scatter(
            prec_scale, tree_leaf_indices, tree_size, jnp.float32
        )
        return None, sums

    _, prec_trees = lax.scan(sum_one_tree, None, leaf_indices)
    return prec_trees

1894 

1895 

def _prec_vec(
    prec_scale: Float32[Array, ' n'],
    leaf_indices: UInt[Array, 'num_trees n'],
    tree_size: int,
    batch_size: int,
) -> Float32[Array, 'num_trees {tree_size}']:
    """
    Sum the precision scale per leaf for all trees at once, with batching.

    Parameters
    ----------
    prec_scale
        The scale of the precision of the error on each datapoint.
    leaf_indices
        The index of the leaf each datapoint falls into.
    tree_size
        The number of entries in the output for each tree.
    batch_size
        The data batch size to use for the summation.

    Returns
    -------
    The summed precision scale in each leaf node.
    """
    return _aggregate_batched_alltrees(
        values=prec_scale,
        indices=leaf_indices,
        size=tree_size,
        dtype=jnp.float32,
        batch_size=batch_size,
    )

1905 

1906 

def complete_ratio(moves: Moves, p_nonterminal: Float32[Array, ' 2**d']) -> Moves:
    """
    Complete non-likelihood MH ratio calculation.

    This function adds the probability of choosing a prune move over the grow
    move in the inverse transition, and the a priori probability that the
    children nodes are leaves.

    Parameters
    ----------
    moves
        The proposed moves. Must have already been updated to take into account
        the thresholds on the number of datapoints per node, this happens in
        `accept_moves_parallel_stage`.
    p_nonterminal
        The a priori probability of each node being nonterminal conditional on
        its ancestors, including at the maximum depth where it should be zero.

    Returns
    -------
    The updated moves, with `partial_ratio=None` and `log_trans_prior_ratio` set.
    """
    num_trees, _ = moves.affluence_tree.shape
    trees = jnp.arange(num_trees)

    def growable(child):
        # can this child be grown? out-of-range child indices read as False
        return moves.affluence_tree.at[trees, child].get(
            mode='fill', fill_value=False
        )

    left_growable = growable(moves.left)
    right_growable = growable(moves.right)

    # probability of proposing a prune in the reverse transition: 0.5 if a
    # grow is also possible from the destination tree, 1 otherwise
    can_grow_after_grow = (moves.num_growable >= 2) | left_growable | right_growable
    p_prune = jnp.where(
        moves.grow,
        jnp.where(can_grow_after_grow, 0.5, 1),  # p_prune if grow
        jnp.where(moves.num_growable, 0.5, 1),  # p_prune if prune
    )

    # prior probability of both children being terminal
    pt_left = 1 - p_nonterminal[moves.left] * left_growable
    pt_right = 1 - p_nonterminal[moves.right] * right_growable
    pt_children = pt_left * pt_right

    log_ratio = jnp.log(moves.partial_ratio * pt_children * p_prune)
    return replace(moves, log_trans_prior_ratio=log_ratio, partial_ratio=None)

1960 

1961 

@vmap_nodoc
def adapt_leaf_trees_to_grow_indices(
    leaf_trees: Float32[Array, 'num_trees 2**d'], moves: Moves
) -> Float32[Array, 'num_trees 2**d']:
    """
    Modify leaves such that post-grow indices work on the original tree.

    The value of the leaf to grow is copied to what would be its children if the
    grow move was accepted.

    Parameters
    ----------
    leaf_trees
        The leaf values.
    moves
        The proposed moves, see `propose_moves`.

    Returns
    -------
    The modified leaf values.
    """
    parent_value = leaf_trees[moves.node]

    # for non-grow moves, aim at an index one past the end; this relies on
    # out-of-bounds updates being skipped
    past_the_end = leaf_trees.size
    left_target = jnp.where(moves.grow, moves.left, past_the_end)
    right_target = jnp.where(moves.grow, moves.right, past_the_end)

    with_left = leaf_trees.at[left_target].set(parent_value)
    return with_left.at[right_target].set(parent_value)

1990 

1991 

def precompute_likelihood_terms(
    sigma2: Float32[Array, ''],
    sigma_mu2: Float32[Array, ''],
    move_precs: Precs | Counts,
) -> tuple[PreLkV, PreLk]:
    """
    Pre-compute terms used in the likelihood ratio of the acceptance step.

    Parameters
    ----------
    sigma2
        The error variance, or the global error variance factor if `prec_scale`
        is set.
    sigma_mu2
        The prior variance of each leaf.
    move_precs
        The likelihood precision scale in the leaves grown or pruned by the
        moves, under keys 'left', 'right', and 'total' (left + right).

    Returns
    -------
    prelkv : PreLkV
        Dictionary with pre-computed terms of the likelihood ratio, one per
        tree.
    prelk : PreLk
        Dictionary with pre-computed terms of the likelihood ratio, shared by
        all trees.
    """

    def inflated_variance(prec):
        # error variance plus the leaf prior variance weighted by the precision
        return sigma2 + prec * sigma_mu2

    var_left = inflated_variance(move_precs.left)
    var_right = inflated_variance(move_precs.right)
    var_total = inflated_variance(move_precs.total)

    # log of the square-root factor of the ratio, hence the division by 2
    sqrt_term = jnp.log(sigma2 * var_total / (var_left * var_right)) / 2

    per_tree = PreLkV(
        sigma2_left=var_left,
        sigma2_right=var_right,
        sigma2_total=var_total,
        sqrt_term=sqrt_term,
    )
    shared = PreLk(exp_factor=sigma_mu2 / (2 * sigma2))
    return per_tree, shared

2030 

2031 

def precompute_leaf_terms(
    key: Key[Array, ''],
    prec_trees: Float32[Array, 'num_trees 2**d'],
    sigma2: Float32[Array, ''],
    sigma_mu2: Float32[Array, ''],
) -> PreLf:
    """
    Pre-compute terms used to sample leaves from their posterior.

    Parameters
    ----------
    key
        A jax random key.
    prec_trees
        The likelihood precision scale in each potential or actual leaf node.
    sigma2
        The error variance, or the global error variance factor if `prec_scale`
        is set.
    sigma_mu2
        The prior variance of each leaf.

    Returns
    -------
    Pre-computed terms for leaf sampling.
    """
    # posterior variance = 1 / (likelihood precision + prior precision)
    likelihood_prec = prec_trees / sigma2
    prior_prec = lax.reciprocal(sigma_mu2)
    var_post = lax.reciprocal(likelihood_prec + prior_prec)

    # standard normal draws, one per potential leaf
    noise = random.normal(key, prec_trees.shape, sigma2.dtype)

    # Derivation of mean_factor:
    # | mean = mean_lk * prec_lk * var_post
    # | resid_tree = mean_lk * prec_tree -->
    # |   --> mean_lk = resid_tree / prec_tree (kind of)
    # | mean_factor =
    # |   = mean / resid_tree =
    # |   = resid_tree / prec_tree * prec_lk * var_post / resid_tree =
    # |   = 1 / prec_tree * prec_tree / sigma2 * var_post =
    # |   = var_post / sigma2
    mean_factor = var_post / sigma2

    # posterior noise, to be added to the posterior mean later
    centered_leaves = noise * jnp.sqrt(var_post)

    return PreLf(mean_factor=mean_factor, centered_leaves=centered_leaves)

2073 

2074 

def accept_moves_sequential_stage(pso: ParallelStageOut) -> tuple[State, Moves]:
    """
    Accept/reject the moves one tree at a time.

    This is the most performance-sensitive function because it contains all and
    only the parts of the algorithm that can not be parallelized across trees.

    Parameters
    ----------
    pso
        The output of `accept_moves_parallel_stage`.

    Returns
    -------
    bart : State
        A partially updated BART mcmc state.
    moves : Moves
        The accepted/rejected moves, with `acc` and `to_prune` set.
    """
    state = pso.bart
    forest = state.forest

    # inputs that do not change from tree to tree
    shared = SeqStageInAllTrees(
        state.X,
        forest.resid_batch_size,
        state.prec_scale,
        forest.log_likelihood is not None,  # save the ratios only if tracked
        pso.prelk,
    )

    # inputs with a leading tree axis, consumed one slice per scan iteration
    per_tree = SeqStageInPerTree(
        forest.leaf_tree,
        pso.prec_trees,
        pso.moves,
        pso.move_precs,
        forest.leaf_indices,
        pso.prelkv,
        pso.prelf,
    )

    def process_tree(resid, pt):
        # the residuals are the carry: each tree sees the effect of the
        # trees updated before it
        resid, leaf_tree, acc, to_prune, lkratio = accept_move_and_sample_leaves(
            resid, shared, pt
        )
        return resid, (leaf_tree, acc, to_prune, lkratio)

    resid, (leaf_trees, acc, to_prune, lkratio) = lax.scan(
        process_tree, state.resid, per_tree
    )

    new_forest = replace(forest, leaf_tree=leaf_trees, log_likelihood=lkratio)
    bart = replace(state, resid=resid, forest=new_forest)
    moves = replace(pso.moves, acc=acc, to_prune=to_prune)

    return bart, moves

2128 

2129 

class SeqStageInAllTrees(Module):
    """
    The inputs to `accept_move_and_sample_leaves` that are shared by all trees.

    Parameters
    ----------
    X
        The predictors.
    resid_batch_size
        The batch size for computing the sum of residuals in each leaf. If
        `None`, the sum is not batched.
    prec_scale
        The scale of the precision of the error on each datapoint. If None, it
        is assumed to be 1.
    save_ratios
        Whether to save the acceptance ratios.
    prelk
        The pre-computed terms of the likelihood ratio which are shared across
        trees.
    """

    # NOTE: the field order is part of the interface, instances are built
    # with positional arguments in `accept_moves_sequential_stage`.
    X: UInt[Array, 'p n']
    # static: used in Python-level branching, not as an array
    resid_batch_size: int | None = field(static=True)
    prec_scale: Float32[Array, ' n'] | None
    # static: used in Python-level branching, not as an array
    save_ratios: bool = field(static=True)
    prelk: PreLk

2155 

2156 

class SeqStageInPerTree(Module):
    """
    The inputs to `accept_move_and_sample_leaves` that are separate for each tree.

    The shapes documented here are those of a single tree.

    Parameters
    ----------
    leaf_tree
        The leaf values of the tree.
    prec_tree
        The likelihood precision scale in each potential or actual leaf node.
    move
        The proposed move, see `propose_moves`.
    move_precs
        The likelihood precision scale in each node modified by the moves.
    leaf_indices
        The leaf indices for the largest version of the tree compatible with
        the move.
    prelkv
    prelf
        The pre-computed terms of the likelihood ratio and leaf sampling which
        are specific to the tree.
    """

    # NOTE: the field order is part of the interface, instances are built
    # with positional arguments in `accept_moves_sequential_stage`.
    leaf_tree: Float32[Array, ' 2**d']
    prec_tree: Float32[Array, ' 2**d']
    move: Moves
    move_precs: Precs | Counts
    leaf_indices: UInt[Array, ' n']
    prelkv: PreLkV
    prelf: PreLf

2187 

2188 

def accept_move_and_sample_leaves(
    resid: Float32[Array, ' n'], at: SeqStageInAllTrees, pt: SeqStageInPerTree
) -> tuple[
    Float32[Array, ' n'],
    Float32[Array, ' 2**d'],
    Bool[Array, ''],
    Bool[Array, ''],
    Float32[Array, ''] | None,
]:
    """
    Accept or reject a proposed move and sample the new leaf values.

    Parameters
    ----------
    resid
        The residuals (data minus forest value).
    at
        The inputs that are the same for all trees.
    pt
        The inputs that are separate for each tree.

    Returns
    -------
    resid : Float32[Array, 'n']
        The updated residuals (data minus forest value).
    leaf_tree : Float32[Array, '2**d']
        The new leaf values of the tree.
    acc : Bool[Array, '']
        Whether the move was accepted.
    to_prune : Bool[Array, '']
        Whether, to reflect the acceptance status of the move, the state should
        be updated by pruning the leaves involved in the move.
    log_lk_ratio : Float32[Array, ''] | None
        The logarithm of the likelihood ratio for the move. `None` if not to be
        saved.
    """
    move = pt.move

    # sum residuals in each leaf, in tree proposed by grow move
    scaled_resid = resid if at.prec_scale is None else resid * at.prec_scale
    resid_tree = sum_resid(
        scaled_resid, pt.leaf_indices, pt.leaf_tree.size, at.resid_batch_size
    )

    # subtract starting tree from function
    resid_tree = resid_tree + pt.prec_tree * pt.leaf_tree

    # sum residuals in parent node modified by move
    resid_left = resid_tree[move.left]
    resid_right = resid_tree[move.right]
    resid_total = resid_left + resid_right
    assert move.node.dtype == jnp.int32
    resid_tree = resid_tree.at[move.node].set(resid_total)

    # acceptance ratio; it is computed for the grow direction, so it is
    # inverted (negated in log space) for prune moves
    log_lk_ratio = compute_likelihood_ratio(
        resid_total, resid_left, resid_right, pt.prelkv, at.prelk
    )
    grow_log_ratio = move.log_trans_prior_ratio + log_lk_ratio
    log_ratio = jnp.where(move.grow, grow_log_ratio, -grow_log_ratio)

    # determine whether to accept the move
    acc = move.allowed & (move.logu <= log_ratio)

    # compute leaves posterior and sample leaves
    leaf_tree = resid_tree * pt.prelf.mean_factor + pt.prelf.centered_leaves

    # copy leaves around such that the leaf indices point to the correct leaf:
    # if the tree ends up in its shallower version, the children get the value
    # of the parent; otherwise the target index is one past the end, relying
    # on out-of-bounds updates being skipped
    to_prune = acc ^ move.grow
    parent_value = leaf_tree[move.node]
    past_the_end = leaf_tree.size
    left_target = jnp.where(to_prune, move.left, past_the_end)
    right_target = jnp.where(to_prune, move.right, past_the_end)
    leaf_tree = leaf_tree.at[left_target].set(parent_value)
    leaf_tree = leaf_tree.at[right_target].set(parent_value)

    # replace old tree with new tree in function values
    resid += (pt.leaf_tree - leaf_tree)[pt.leaf_indices]

    saved_ratio = log_lk_ratio if at.save_ratios else None
    return resid, leaf_tree, acc, to_prune, saved_ratio

2273 

2274 

def sum_resid(
    scaled_resid: Float32[Array, ' n'],
    leaf_indices: UInt[Array, ' n'],
    tree_size: int,
    batch_size: int | None,
) -> Float32[Array, ' {tree_size}']:
    """
    Sum the residuals in each leaf.

    Parameters
    ----------
    scaled_resid
        The residuals (data minus forest value) multiplied by the error
        precision scale.
    leaf_indices
        The leaf indices of the tree (in which leaf each data point falls into).
    tree_size
        The size of the tree array (2 ** d).
    batch_size
        The data batch size for the aggregation. Batching increases numerical
        accuracy and parallelism. If `None`, a plain scatter-add is used.

    Returns
    -------
    The sum of the residuals at data points in each leaf.
    """
    if batch_size is None:
        return _aggregate_scatter(scaled_resid, leaf_indices, tree_size, jnp.float32)
    return _aggregate_batched_onetree(
        scaled_resid, leaf_indices, tree_size, jnp.float32, batch_size=batch_size
    )

2306 

2307 

def _aggregate_batched_onetree(
    values: Shaped[Array, '*'],
    indices: Integer[Array, '*'],
    size: int,
    dtype: jnp.dtype,
    batch_size: int,
) -> Float32[Array, ' {size}']:
    """
    Sum `values` per slot for a single tree, through per-batch partial sums.

    Parameters
    ----------
    values
        The values to accumulate, broadcast against `indices`.
    indices
        The slot each datapoint is added to; must be one-dimensional.
    size
        The number of slots.
    dtype
        The data type of the accumulator and of the result.
    batch_size
        The data batch size; the number of partial accumulators per slot is
        the number of datapoints divided by `batch_size`, rounded up.

    Returns
    -------
    The per-slot sums.
    """
    (n,) = indices.shape

    # number of batches = ceil(n / batch_size)
    whole_batches, leftover = divmod(n, batch_size)
    nbatches = whole_batches + (1 if leftover else 0)

    # datapoints are dealt round-robin to the partial accumulators
    batch_coord = jnp.arange(n) % nbatches

    partial_sums = jnp.zeros((size, nbatches), dtype)
    partial_sums = partial_sums.at[indices, batch_coord].add(values)

    # collapse the batch axis to get the final totals
    return partial_sums.sum(axis=1)

2324 

2325 

def compute_likelihood_ratio(
    total_resid: Float32[Array, ''],
    left_resid: Float32[Array, ''],
    right_resid: Float32[Array, ''],
    prelkv: PreLkV,
    prelk: PreLk,
) -> Float32[Array, '']:
    """
    Compute the likelihood ratio of a grow move.

    Parameters
    ----------
    total_resid
    left_resid
    right_resid
        The sum of the residuals (scaled by error precision scale) of the
        datapoints falling in the nodes involved in the moves.
    prelkv
    prelk
        The pre-computed terms of the likelihood ratio, see
        `precompute_likelihood_terms`.

    Returns
    -------
    The likelihood ratio P(data | new tree) / P(data | old tree), expressed
    as the sum of `prelkv.sqrt_term` and the exponent term.
    """
    # squared residual sums, each divided by the matching precomputed variance
    left_term = left_resid * left_resid / prelkv.sigma2_left
    right_term = right_resid * right_resid / prelkv.sigma2_right
    total_term = total_resid * total_resid / prelkv.sigma2_total

    # the two children enter with a plus sign, the parent with a minus sign
    exp_term = prelk.exp_factor * (left_term + right_term - total_term)
    return prelkv.sqrt_term + exp_term

2358 

2359 

def accept_moves_final_stage(bart: State, moves: Moves) -> State:
    """
    Post-process the mcmc state after accepting/rejecting the moves.

    This function is separate from `accept_moves_sequential_stage` to signal it
    can work in parallel across trees.

    Parameters
    ----------
    bart
        A partially updated BART mcmc state.
    moves
        The proposed moves (see `propose_moves`) as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The fully updated BART mcmc state.
    """
    forest = bart.forest

    # number of accepted moves of each kind
    accepted_grows = jnp.sum(moves.acc & moves.grow)
    accepted_prunes = jnp.sum(moves.acc & ~moves.grow)

    # bring the tree structure in line with the accept/reject outcome
    leaf_indices = apply_moves_to_leaf_indices(forest.leaf_indices, moves)
    split_tree = apply_moves_to_split_trees(forest.split_tree, moves)

    updated_forest = replace(
        forest,
        grow_acc_count=accepted_grows,
        prune_acc_count=accepted_prunes,
        leaf_indices=leaf_indices,
        split_tree=split_tree,
    )
    return replace(bart, forest=updated_forest)

2389 

2390 

@vmap_nodoc
def apply_moves_to_leaf_indices(
    leaf_indices: UInt[Array, 'num_trees n'], moves: Moves
) -> UInt[Array, 'num_trees n']:
    """
    Update the leaf indices to match the accepted move.

    Parameters
    ----------
    leaf_indices
        The index of the leaf each datapoint falls into, if the grow move was
        accepted.
    moves
        The proposed moves (see `propose_moves`), as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The updated leaf indices.
    """
    index_dtype = leaf_indices.dtype

    # all bits set but the lowest one (...1111111110): clearing the lowest bit
    # maps both children to the left child index (presumably the right child
    # is the left one with the lowest bit set)
    clear_lowest_bit = ~jnp.array(1, index_dtype)
    in_moved_children = (leaf_indices & clear_lowest_bit) == moves.left

    # datapoints in pruned children are reassigned to the parent node
    parent = moves.node.astype(index_dtype)
    return jnp.where(in_moved_children & moves.to_prune, parent, leaf_indices)

2416 

2417 

@vmap_nodoc
def apply_moves_to_split_trees(
    split_tree: UInt[Array, 'num_trees 2**(d-1)'], moves: Moves
) -> UInt[Array, 'num_trees 2**(d-1)']:
    """
    Update the split trees to match the accepted move.

    Parameters
    ----------
    split_tree
        The cutpoints of the decision nodes in the initial trees.
    moves
        The proposed moves (see `propose_moves`), as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The updated split trees.
    """
    assert moves.to_prune is not None

    # an index one past the end turns an update into a no-op, relying on
    # out-of-bounds updates being skipped
    past_the_end = split_tree.size
    grow_target = jnp.where(moves.grow, moves.node, past_the_end)
    prune_target = jnp.where(moves.to_prune, moves.node, past_the_end)

    # first write the proposed cutpoint for grow moves, then clear the node
    # wherever it has to be pruned; the order matters because a grow move that
    # must be undone has both flags set
    grown = split_tree.at[grow_target].set(moves.grow_split.astype(split_tree.dtype))
    return grown.at[prune_target].set(0)

2444 

2445 

def step_sigma(key: Key[Array, ''], bart: State) -> State:
    """
    MCMC-update the error variance (factor).

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state.

    Returns
    -------
    The new BART mcmc state, with an updated `sigma2`.
    """
    resid = bart.resid

    # without a precision scale all datapoints have unit weight
    weighted_resid = resid if bart.prec_scale is None else resid * bart.prec_scale

    # updated shape and scale parameters
    post_alpha = bart.sigma2_alpha + resid.size / 2
    post_beta = bart.sigma2_beta + (resid @ weighted_resid) / 2

    # random.gamma seems to be slow at compiling, maybe cdf inversion would
    # be better, but it's not implemented in jax
    gamma_draw = random.gamma(key, post_alpha)

    # the new variance is the scale divided by the gamma variate
    return replace(bart, sigma2=post_beta / gamma_draw)

2474 

2475 

def step_z(key: Key[Array, ''], bart: State) -> State:
    """
    MCMC-update the latent variable for binary regression.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART MCMC state.

    Returns
    -------
    The updated BART MCMC state, with new `z` and `resid`.
    """
    # the residuals are z minus the mean, so this recovers the mean
    latent_mean = bart.z - bart.resid
    assert bart.y.dtype == bool

    # draw the new residuals from a one-sided truncated normal, with the side
    # selected by `~bart.y` and the bound at `-latent_mean`
    # NOTE(review): presumably this keeps z positive iff y is True -- confirm
    # against the definition of `truncated_normal_onesided`
    new_resid = truncated_normal_onesided(key, (), ~bart.y, -latent_mean)

    return replace(bart, z=latent_mean + new_resid, resid=new_resid)

2496 

2497 

def step_s(key: Key[Array, ''], bart: State) -> State:
    """
    Update `log_s` using Dirichlet sampling.

    The prior is s ~ Dirichlet(theta/p, ..., theta/p), and the posterior
    is s ~ Dirichlet(theta/p + varcount, ..., theta/p + varcount), where
    varcount is the count of how many times each variable is used in the
    current forest.

    Parameters
    ----------
    key
        Random key for sampling.
    bart
        The current BART state.

    Returns
    -------
    Updated BART state with re-sampled `log_s`.
    """
    forest = bart.forest
    assert forest.theta is not None

    # histogram current variable usage
    num_predictors = forest.max_split.size
    usage = grove.var_histogram(num_predictors, forest.var_tree, forest.split_tree)

    # sample from the Dirichlet posterior through log-gamma variates; the
    # result is stored as is, without normalizing it here
    concentration = forest.theta / num_predictors + usage
    new_log_s = random.loggamma(key, concentration)

    # update forest with new s
    return replace(bart, forest=replace(forest, log_s=new_log_s))

2531 

2532 

def step_theta(key: Key[Array, ''], bart: State, *, num_grid: int = 1000) -> State:
    """
    Update `theta`.

    The prior is theta / (theta + rho) ~ Beta(a, b).

    Parameters
    ----------
    key
        Random key for sampling.
    bart
        The current BART state.
    num_grid
        The number of points in the evenly-spaced grid used to sample
        theta / (theta + rho).

    Returns
    -------
    Updated BART state with re-sampled `theta`.
    """
    forest = bart.forest
    assert forest.log_s is not None
    assert forest.rho is not None
    assert forest.a is not None
    assert forest.b is not None

    # the grid points are the midpoints of num_grid bins in (0, 1)
    half_bin = 1 / (2 * num_grid)
    lamda_grid = jnp.linspace(half_bin, 1 - half_bin, num_grid)

    # normalize s
    normalized_log_s = forest.log_s - logsumexp(forest.log_s)

    # evaluate the log probability on the grid and draw one grid point
    logp, theta_grid = _log_p_lamda(
        lamda_grid, normalized_log_s, forest.rho, forest.a, forest.b
    )
    chosen = random.categorical(key, logp)

    return replace(bart, forest=replace(forest, theta=theta_grid[chosen]))

2573 

2574 

def _log_p_lamda(
    lamda: Float32[Array, ' num_grid'],
    log_s: Float32[Array, ' p'],
    rho: Float32[Array, ''],
    a: Float32[Array, ''],
    b: Float32[Array, ''],
) -> tuple[Float32[Array, ' num_grid'], Float32[Array, ' num_grid']]:
    """
    Evaluate the unnormalized log probability of lambda on a grid.

    Parameters
    ----------
    lamda
        The grid of values of lambda = theta / (theta + rho). The grid must
        be symmetric around 1/2, such that reversing it yields 1 - lambda.
    log_s
        The logarithm of the normalized variable selection probabilities.
    rho
    a
    b
        The parameters of the prior on theta.

    Returns
    -------
    logp : Float32[Array, 'num_grid']
        The unnormalized log probability at each grid point.
    theta : Float32[Array, 'num_grid']
        The value of theta corresponding to each grid point.
    """
    # in the following I use lamda[::-1] == 1 - lamda
    one_minus_lamda = lamda[::-1]
    theta = rho * lamda / one_minus_lamda
    p = log_s.size

    log_beta_prior = (
        (a - 1) * jnp.log1p(-one_minus_lamda)  # log(lambda)
        + (b - 1) * jnp.log1p(-lamda)  # log(1 - lambda)
    )
    logp = (
        log_beta_prior
        + gammaln(theta)
        - p * gammaln(theta / p)
        + theta / p * jnp.sum(log_s)
    )
    return logp, theta

2592 

2593 

def step_sparse(key: Key[Array, ''], bart: State) -> State:
    """
    Update the sparsity parameters.

    This invokes `step_s`, and then `step_theta` only if the parameters of
    the theta prior are defined.

    Parameters
    ----------
    key
        Random key for sampling.
    bart
        The current BART state.

    Returns
    -------
    Updated BART state with re-sampled `log_s` and `theta`.
    """
    keys = split(key)

    # `log_s` is always updated
    bart = step_s(keys.pop(), bart)

    # `theta` is updated only if its prior is set, as signaled by `rho`
    if bart.forest.rho is None:
        return bart
    return step_theta(keys.pop(), bart)