Coverage for src/bartz/mcmcstep.py: 91%

489 statements  

« prev     ^ index     » next       coverage.py v7.8.2, created at 2025-05-29 23:01 +0000

1# bartz/src/bartz/mcmcstep.py 

2# 

3# Copyright (c) 2024-2025, Giacomo Petrillo 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25""" 

26Functions that implement the BART posterior MCMC initialization and update step. 

27 

28Functions that do MCMC steps operate by taking as input a bart state, and 

29outputting a new state. The inputs are not modified. 

30 

31The main entry points are: 

32 

33 - `State`: The dataclass that represents a BART MCMC state. 

34 - `init`: Creates an initial `State` from data and configurations. 

35 - `step`: Performs one full MCMC step on a `State`, returning a new `State`. 

36""" 

37 

38import math 1ab

39from dataclasses import replace 1ab

40from functools import cache, partial 1ab

41from typing import Any 1ab

42 

43import jax 1ab

44from equinox import Module, field 1ab

45from jax import lax, random 1ab

46from jax import numpy as jnp 1ab

47from jaxtyping import Array, Bool, Float32, Int32, Integer, Key, Shaped, UInt 1ab

48 

49from . import grove 1ab

50from .jaxext import minimal_unsigned_dtype, split, vmap_nodoc 1ab

51 

52 

class Forest(Module):
    """
    Represents the MCMC state of a sum of trees.

    Parameters
    ----------
    leaf_trees
        The leaf values.
    var_trees
        The decision axes.
    split_trees
        The decision boundaries.
    p_nonterminal
        The probability of a nonterminal node at each depth, padded with a
        zero (the bottom level can never be nonterminal).
    p_propose_grow
        The unnormalized probability of picking a leaf for a grow proposal.
    leaf_indices
        The index of the leaf each datapoint falls into, for each tree.
    min_points_per_leaf
        The minimum number of data points in a leaf node.
    affluence_trees
        Whether each non-bottom leaf node contains at least twice
        `min_points_per_leaf` datapoints (i.e., could be split). If
        `min_points_per_leaf` is not specified, this is None.
    resid_batch_size
    count_batch_size
        The data batch sizes for computing the sufficient statistics. If
        `None`, they are computed with no batching. Static (non-traced)
        fields.
    log_trans_prior
        The log transition and prior Metropolis-Hastings ratio for the
        proposed move on each tree. `None` unless ratio saving is enabled.
    log_likelihood
        The log likelihood ratio. `None` unless ratio saving is enabled.
    grow_prop_count
    prune_prop_count
        The number of grow/prune proposals made during one full MCMC cycle.
    grow_acc_count
    prune_acc_count
        The number of grow/prune moves accepted during one full MCMC cycle.
    sigma_mu2
        The prior variance of a leaf, conditional on the tree structure.
    """

    leaf_trees: Float32[Array, 'num_trees 2**d']
    var_trees: UInt[Array, 'num_trees 2**(d-1)']
    split_trees: UInt[Array, 'num_trees 2**(d-1)']
    p_nonterminal: Float32[Array, 'd']
    p_propose_grow: Float32[Array, '2**(d-1)']
    leaf_indices: UInt[Array, 'num_trees n']
    min_points_per_leaf: Int32[Array, ''] | None
    affluence_trees: Bool[Array, 'num_trees 2**(d-1)'] | None
    resid_batch_size: int | None = field(static=True)
    count_batch_size: int | None = field(static=True)
    log_trans_prior: Float32[Array, 'num_trees'] | None
    log_likelihood: Float32[Array, 'num_trees'] | None
    grow_prop_count: Int32[Array, '']
    prune_prop_count: Int32[Array, '']
    grow_acc_count: Int32[Array, '']
    prune_acc_count: Int32[Array, '']
    sigma_mu2: Float32[Array, '']

113 

114 

class State(Module):
    """
    Represents the MCMC state of BART.

    Parameters
    ----------
    X
        The predictors. Note this is transposed compared to the usual
        convention: variables along the first axis, datapoints along the
        second.
    max_split
        The maximum split index for each predictor.
    y
        The response. If the data type is `bool`, the model is binary
        regression.
    z
        The latent variable for binary regression. `None` in continuous
        regression.
    offset
        Constant shift added to the sum of trees.
    resid
        The residuals (`y` or `z` minus sum of trees).
    sigma2
        The error variance. `None` in binary regression.
    prec_scale
        The scale on the error precision, i.e., ``1 / error_scale ** 2``.
        `None` in binary regression.
    sigma2_alpha
    sigma2_beta
        The shape and scale parameters of the inverse gamma prior on the
        noise variance. `None` in binary regression.
    forest
        The sum of trees model.
    """

    X: UInt[Array, 'p n']
    max_split: UInt[Array, 'p']
    y: Float32[Array, 'n'] | Bool[Array, 'n']
    z: None | Float32[Array, 'n']
    offset: Float32[Array, '']
    resid: Float32[Array, 'n']
    sigma2: Float32[Array, ''] | None
    prec_scale: Float32[Array, 'n'] | None
    sigma2_alpha: Float32[Array, ''] | None
    sigma2_beta: Float32[Array, ''] | None
    forest: Forest

158 

159 

def init(
    *,
    X: UInt[Any, 'p n'],
    y: Float32[Any, 'n'] | Bool[Any, 'n'],
    offset: float | Float32[Any, ''] = 0.0,
    max_split: UInt[Any, 'p'],
    num_trees: int,
    p_nonterminal: Float32[Any, 'd-1'],
    sigma_mu2: float | Float32[Any, ''],
    sigma2_alpha: float | Float32[Any, ''] | None = None,
    sigma2_beta: float | Float32[Any, ''] | None = None,
    error_scale: Float32[Any, 'n'] | None = None,
    min_points_per_leaf: int | None = None,
    resid_batch_size: int | None | str = 'auto',
    count_batch_size: int | None | str = 'auto',
    save_ratios: bool = False,
) -> State:
    """
    Make a BART posterior sampling MCMC initial state.

    Parameters
    ----------
    X
        The predictors. Note this is transposed compared to the usual
        convention.
    y
        The response. If the data type is `bool`, the regression model is
        binary regression with probit.
    offset
        Constant shift added to the sum of trees. 0 if not specified.
    max_split
        The maximum split index for each variable. All split ranges start at 1.
    num_trees
        The number of trees in the forest.
    p_nonterminal
        The probability of a nonterminal node at each depth. The maximum depth
        of trees is fixed by the length of this array.
    sigma_mu2
        The prior variance of a leaf, conditional on the tree structure. The
        prior variance of the sum of trees is ``num_trees * sigma_mu2``. The
        prior mean of leaves is always zero.
    sigma2_alpha
    sigma2_beta
        The shape and scale parameters of the inverse gamma prior on the error
        variance. Leave unspecified for binary regression.
    error_scale
        Each error is scaled by the corresponding factor in `error_scale`, so
        the error variance for ``y[i]`` is ``sigma2 * error_scale[i] ** 2``.
        Not supported for binary regression. If not specified, defaults to 1
        for all points, but potentially skipping calculations.
    min_points_per_leaf
        The minimum number of data points in a leaf node. 0 if not specified.
    resid_batch_size
    count_batch_size
        The batch sizes, along datapoints, for summing the residuals and
        counting the number of datapoints in each leaf. `None` for no
        batching. If 'auto', pick a value based on the device of `y`, or the
        default device.
    save_ratios
        Whether to save the Metropolis-Hastings ratios.

    Returns
    -------
    An initialized BART MCMC state.

    Raises
    ------
    ValueError
        If `y` is boolean and arguments unused in binary regression are set.
    """
    p_nonterminal = jnp.asarray(p_nonterminal)
    # pad with a zero: the bottom level can never be nonterminal
    p_nonterminal = jnp.pad(p_nonterminal, (0, 1))
    max_depth = p_nonterminal.size

    # convert all array inputs up front, so that list/tuple arguments work:
    # `X.shape`, `max_split.dtype`, and `y.dtype` are accessed below
    X = jnp.asarray(X)
    max_split = jnp.asarray(max_split)
    y = jnp.asarray(y)
    offset = jnp.asarray(offset)

    @partial(jax.vmap, in_axes=None, out_axes=0, axis_size=num_trees)
    def make_forest(max_depth, dtype):
        # one identical empty tree per forest member, replicated by vmap
        return grove.make_tree(max_depth, dtype)

    resid_batch_size, count_batch_size = _choose_suffstat_batch_size(
        resid_batch_size, count_batch_size, y, 2**max_depth * num_trees
    )

    is_binary = y.dtype == bool
    if is_binary:
        if (error_scale, sigma2_alpha, sigma2_beta) != 3 * (None,):
            raise ValueError(
                'error_scale, sigma2_alpha, and sigma2_beta must be set '
                'to `None` for binary regression.'
            )
        sigma2 = None
    else:
        sigma2_alpha = jnp.asarray(sigma2_alpha)
        sigma2_beta = jnp.asarray(sigma2_beta)
        # initialize the error variance at the prior mean of the precision's
        # reciprocal scale
        sigma2 = sigma2_beta / sigma2_alpha
        # TODO: a previous revision clamped sigma2 to a finite positive value
        # here; decide whether that guard belongs in this low-level function.

    bart = State(
        X=X,
        max_split=max_split,
        y=y,
        z=jnp.full(y.shape, offset) if is_binary else None,
        offset=offset,
        resid=jnp.zeros(y.shape) if is_binary else y - offset,
        sigma2=sigma2,
        prec_scale=(
            None if error_scale is None else lax.reciprocal(jnp.square(error_scale))
        ),
        sigma2_alpha=sigma2_alpha,
        sigma2_beta=sigma2_beta,
        forest=Forest(
            leaf_trees=make_forest(max_depth, jnp.float32),
            var_trees=make_forest(
                max_depth - 1, minimal_unsigned_dtype(X.shape[0] - 1)
            ),
            split_trees=make_forest(max_depth - 1, max_split.dtype),
            grow_prop_count=jnp.zeros((), int),
            grow_acc_count=jnp.zeros((), int),
            prune_prop_count=jnp.zeros((), int),
            prune_acc_count=jnp.zeros((), int),
            p_nonterminal=p_nonterminal,
            p_propose_grow=p_nonterminal[grove.tree_depths(2 ** (max_depth - 1))],
            # everything starts in the root leaf, which has heap index 1
            leaf_indices=jnp.ones(
                (num_trees, y.size), minimal_unsigned_dtype(2**max_depth - 1)
            ),
            min_points_per_leaf=(
                None
                if min_points_per_leaf is None
                else jnp.asarray(min_points_per_leaf)
            ),
            affluence_trees=(
                None
                if min_points_per_leaf is None
                else make_forest(max_depth - 1, bool)
                .at[:, 1]
                .set(y.size >= 2 * min_points_per_leaf)
            ),
            resid_batch_size=resid_batch_size,
            count_batch_size=count_batch_size,
            log_trans_prior=jnp.full(num_trees, jnp.nan) if save_ratios else None,
            log_likelihood=jnp.full(num_trees, jnp.nan) if save_ratios else None,
            sigma_mu2=jnp.asarray(sigma_mu2),
        ),
    )

    return bart

309 

310 

311def _choose_suffstat_batch_size( 1ab

312 resid_batch_size, count_batch_size, y, forest_size 

313) -> tuple[int | None, ...]: 

314 @cache 1ab

315 def get_platform(): 1ab

316 try: 

317 device = y.devices().pop() 

318 except jax.errors.ConcretizationTypeError: 

319 device = jax.devices()[0] 

320 platform = device.platform 

321 if platform not in ('cpu', 'gpu'): 

322 raise KeyError(f'Unknown platform: {platform}') 

323 return platform 

324 

325 if resid_batch_size == 'auto': 325 ↛ 326line 325 didn't jump to line 326 because the condition on line 325 was never true1ab

326 platform = get_platform() 

327 n = max(1, y.size) 

328 if platform == 'cpu': 

329 resid_batch_size = 2 ** int(round(math.log2(n / 6))) # n/6 

330 elif platform == 'gpu': 

331 resid_batch_size = 2 ** int(round((1 + math.log2(n)) / 3)) # n^1/3 

332 resid_batch_size = max(1, resid_batch_size) 

333 

334 if count_batch_size == 'auto': 334 ↛ 335line 334 didn't jump to line 335 because the condition on line 334 was never true1ab

335 platform = get_platform() 

336 if platform == 'cpu': 

337 count_batch_size = None 

338 elif platform == 'gpu': 

339 n = max(1, y.size) 

340 count_batch_size = 2 ** int(round(math.log2(n) / 2 - 2)) # n^1/2 

341 # /4 is good on V100, /2 on L4/T4, still haven't tried A100 

342 max_memory = 2**29 

343 itemsize = 4 

344 min_batch_size = int(math.ceil(forest_size * itemsize * n / max_memory)) 

345 count_batch_size = max(count_batch_size, min_batch_size) 

346 count_batch_size = max(1, count_batch_size) 

347 

348 return resid_batch_size, count_batch_size 1ab

349 

350 

@jax.jit
def step(key: Key[Array, ''], bart: State) -> State:
    """
    Do one MCMC step.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state, as created by `init`.

    Returns
    -------
    The new BART mcmc state.
    """
    keys = split(key)

    if bart.y.dtype == bool:
        # binary regression: sample the trees under the probit model with
        # unit variance, then refresh the latent variable z
        unit_variance = replace(bart, sigma2=jnp.float32(1))
        sampled = replace(step_trees(keys.pop(), unit_variance), sigma2=None)
        return step_z(keys.pop(), sampled)

    # continuous regression: sample the trees, then the error variance
    sampled = step_trees(keys.pop(), bart)
    return step_sigma(keys.pop(), sampled)

378 

379 

def step_trees(key: Key[Array, ''], bart: State) -> State:
    """
    Forest sampling step of BART MCMC.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state, as created by `init`.

    Returns
    -------
    The new BART mcmc state.

    Notes
    -----
    This function zeroes the proposal counters.
    """
    keys = split(key)
    proposal_key = keys.pop()
    acceptance_key = keys.pop()
    proposal = propose_moves(proposal_key, bart.forest, bart.max_split)
    return accept_moves_and_sample_leaves(acceptance_key, bart, proposal)

402 

403 

class Moves(Module):
    """
    Moves proposed to modify each tree.

    Parameters
    ----------
    allowed
        Whether the move is possible in the first place. There are additional
        constraints that could forbid it, but they are computed at acceptance
        time.
    grow
        Whether the move is a grow move or a prune move.
    num_growable
        The number of growable leaves in the original tree.
    node
        The index of the leaf to grow or node to prune.
    left
    right
        The indices of the children of 'node'.
    partial_ratio
        A factor of the Metropolis-Hastings ratio of the move. It lacks
        the likelihood ratio and the probability of proposing the prune
        move. If the move is PRUNE, the ratio is inverted. `None` once
        `log_trans_prior_ratio` has been computed.
    log_trans_prior_ratio
        The logarithm of the product of the transition and prior terms of the
        Metropolis-Hastings ratio for the acceptance of the proposed move.
        `None` if not yet computed.
    grow_var
        The decision axes of the new rules.
    grow_split
        The decision boundaries of the new rules.
    var_trees
        The updated decision axes of the trees, valid whatever the move.
    logu
        The logarithm of a uniform (0, 1] random variable to be used to
        accept the move. It's in (-oo, 0].
    acc
        Whether the move was accepted. `None` if not yet computed.
    to_prune
        Whether the final operation to apply the move is pruning. This
        indicates an accepted prune move or a rejected grow move. `None` if
        not yet computed.
    """

    allowed: Bool[Array, 'num_trees']
    grow: Bool[Array, 'num_trees']
    num_growable: UInt[Array, 'num_trees']
    node: UInt[Array, 'num_trees']
    left: UInt[Array, 'num_trees']
    right: UInt[Array, 'num_trees']
    partial_ratio: Float32[Array, 'num_trees'] | None
    log_trans_prior_ratio: None | Float32[Array, 'num_trees']
    grow_var: UInt[Array, 'num_trees']
    grow_split: UInt[Array, 'num_trees']
    var_trees: UInt[Array, 'num_trees 2**(d-1)']
    logu: Float32[Array, 'num_trees']
    acc: None | Bool[Array, 'num_trees']
    to_prune: None | Bool[Array, 'num_trees']

463 

464 

def propose_moves(
    key: Key[Array, ''], forest: Forest, max_split: UInt[Array, 'p']
) -> Moves:
    """
    Propose moves for all the trees.

    There are two types of moves: GROW (convert a leaf to a decision node and
    add two leaves beneath it) and PRUNE (convert the parent of two leaves to
    a leaf, deleting its children).

    Parameters
    ----------
    key
        A jax random key.
    forest
        The `forest` field of a BART MCMC state.
    max_split
        The maximum split index for each variable, found in `State`.

    Returns
    -------
    The proposed move for each tree.
    """
    num_trees, _ = forest.leaf_trees.shape
    keys = split(key, 1 + 2 * num_trees)

    # propose both kinds of move for every tree; the choice between them
    # is made afterwards
    grow_moves = propose_grow_moves(
        keys.pop(num_trees),
        forest.var_trees,
        forest.split_trees,
        forest.affluence_trees,
        max_split,
        forest.p_nonterminal,
        forest.p_propose_grow,
    )
    prune_moves = propose_prune_moves(
        keys.pop(num_trees),
        forest.split_trees,
        forest.affluence_trees,
        forest.p_nonterminal,
        forest.p_propose_grow,
    )

    u, logu = random.uniform(keys.pop(), (2, num_trees), jnp.float32)

    # pick grow vs. prune: 1/2 each when both are possible, otherwise
    # whichever is possible (grow iff growable when prune is impossible)
    can_grow = grow_moves.num_growable.astype(bool)
    both_possible = can_grow & prune_moves.allowed
    p_grow = jnp.where(both_possible, 0.5, can_grow)
    is_grow = u < p_grow  # use < instead of <= because u is in [0, 1)

    # heap indices of the children of the chosen node
    chosen = jnp.where(is_grow, grow_moves.node, prune_moves.node)
    left_child = chosen << 1
    right_child = left_child + 1

    return Moves(
        allowed=is_grow | prune_moves.allowed,
        grow=is_grow,
        num_growable=grow_moves.num_growable,
        node=chosen,
        left=left_child,
        right=right_child,
        partial_ratio=jnp.where(
            is_grow, grow_moves.partial_ratio, prune_moves.partial_ratio
        ),
        log_trans_prior_ratio=None,  # will be set in complete_ratio
        grow_var=grow_moves.var,
        grow_split=grow_moves.split,
        var_trees=grow_moves.var_tree,
        logu=jnp.log1p(-logu),
        acc=None,  # will be set in accept_moves_sequential_stage
        to_prune=None,  # will be set in accept_moves_sequential_stage
    )

539 

540 

class GrowMoves(Module):
    """
    Represent a proposed grow move for each tree.

    Parameters
    ----------
    num_growable
        The number of growable leaves.
    node
        The index of the leaf to grow. ``2 ** d`` if there are no growable
        leaves.
    var
    split
        The decision axis and boundary of the new rule.
    partial_ratio
        A factor of the Metropolis-Hastings ratio of the move. It lacks
        the likelihood ratio and the probability of proposing the prune
        move.
    var_tree
        The updated decision axes of the tree.
    """

    num_growable: UInt[Array, 'num_trees']
    node: UInt[Array, 'num_trees']
    var: UInt[Array, 'num_trees']
    split: UInt[Array, 'num_trees']
    partial_ratio: Float32[Array, 'num_trees']
    var_tree: UInt[Array, 'num_trees 2**(d-1)']

569 

570 

@partial(vmap_nodoc, in_axes=(0, 0, 0, 0, None, None, None))
def propose_grow_moves(
    key: Key[Array, ''],
    var_tree: UInt[Array, '2**(d-1)'],
    split_tree: UInt[Array, '2**(d-1)'],
    affluence_tree: Bool[Array, '2**(d-1)'] | None,
    max_split: UInt[Array, 'p'],
    p_nonterminal: Float32[Array, 'd'],
    p_propose_grow: Float32[Array, '2**(d-1)'],
) -> GrowMoves:
    """
    Propose a GROW move for each tree.

    A GROW move picks a leaf node and converts it to a non-terminal node with
    two leaf children.

    Parameters
    ----------
    key
        A jax random key.
    var_tree
        The splitting axes of the tree.
    split_tree
        The splitting points of the tree.
    affluence_tree
        Whether a leaf has enough points to be grown.
    max_split
        The maximum split index for each variable.
    p_nonterminal
        The probability of a nonterminal node at each depth.
    p_propose_grow
        The unnormalized probability of choosing a leaf to grow.

    Returns
    -------
    An object representing the proposed move.

    Notes
    -----
    The move is not proposed if a leaf is already at maximum depth, or if a
    leaf has less than twice the requested minimum number of datapoints per
    leaf. This is marked by returning `num_growable` set to 0.

    The move is also not possible if the ancestors of a leaf have exhausted
    the possible decision rules that lead to a non-empty selection. This is
    marked by returning `var` set to `p` and `split` set to 0. But this does
    not block the move from counting as "proposed", even though it is
    predictably going to be rejected. This simplifies the MCMC and should not
    reduce efficiency if not in unrealistic corner cases.
    """
    keys = split(key, 3)

    # pick the leaf, then the decision rule (axis first, then boundary)
    leaf, num_growable, prob_choose, num_prunable = choose_leaf(
        keys.pop(), split_tree, affluence_tree, p_propose_grow
    )

    new_var = choose_variable(keys.pop(), var_tree, split_tree, max_split, leaf)
    var_tree = var_tree.at[leaf].set(new_var.astype(var_tree.dtype))

    new_split = choose_split(keys.pop(), var_tree, split_tree, max_split, leaf)

    mh_factor = compute_partial_ratio(prob_choose, num_prunable, p_nonterminal, leaf)

    return GrowMoves(
        num_growable=num_growable,
        node=leaf,
        var=new_var,
        split=new_split,
        partial_ratio=mh_factor,
        var_tree=var_tree,
    )

    # TODO clarify why var=p and split=0 for an impossible move lead to
    # correct behavior downstream: the move is proposed but is a no-op, so
    # acceptance/rejection makes no observable difference even if the
    # likelihood ratio and other quantities are wrong for it.

650 

651 

def choose_leaf(
    key: Key[Array, ''],
    split_tree: UInt[Array, '2**(d-1)'],
    affluence_tree: Bool[Array, '2**(d-1)'] | None,
    p_propose_grow: Float32[Array, '2**(d-1)'],
) -> tuple[Int32[Array, ''], Int32[Array, ''], Float32[Array, ''], Int32[Array, '']]:
    """
    Choose a leaf node to grow in a tree.

    Parameters
    ----------
    key
        A jax random key.
    split_tree
        The splitting points of the tree.
    affluence_tree
        Whether a leaf has enough points that it could be split into two
        leaves satisfying the `min_points_per_leaf` requirement.
    p_propose_grow
        The unnormalized probability of choosing a leaf to grow.

    Returns
    -------
    leaf_to_grow : int
        The index of the leaf to grow. If ``num_growable == 0``, return
        ``2 ** d``.
    num_growable : int
        The number of leaf nodes that can be grown, i.e., are nonterminal
        and have at least twice `min_points_per_leaf` if set.
    prob_choose : float
        The (normalized) probability that this function had to choose that
        specific leaf, given the arguments.
    num_prunable : int
        The number of leaf parents that could be pruned, after converting the
        selected leaf to a non-terminal node.
    """
    can_grow = growable_leaves(split_tree, affluence_tree)
    num_growable = jnp.count_nonzero(can_grow)
    weights = jnp.where(can_grow, p_propose_grow, 0)
    leaf_to_grow, total_weight = categorical(key, weights)
    # mark "nothing growable" with an out-of-bounds index
    leaf_to_grow = jnp.where(num_growable, leaf_to_grow, 2 * split_tree.size)
    prob_choose = weights[leaf_to_grow] / total_weight
    # count prunable parents in the tree as it would be after the grow
    grown_tree = split_tree.at[leaf_to_grow].set(1)
    num_prunable = jnp.count_nonzero(grove.is_leaves_parent(grown_tree))
    return leaf_to_grow, num_growable, prob_choose, num_prunable

697 

698 

def growable_leaves(
    split_tree: UInt[Array, '2**(d-1)'],
    affluence_tree: Bool[Array, '2**(d-1)'] | None,
) -> Bool[Array, '2**(d-1)']:
    """
    Return a mask indicating the leaf nodes that can be proposed for growth.

    The condition is that a leaf is not at the bottom level and, when the
    affluence mask is provided, has at least twice the minimum number of
    points per leaf.

    Parameters
    ----------
    split_tree
        The splitting points of the tree.
    affluence_tree
        Whether a leaf has enough points to be grown.

    Returns
    -------
    The mask indicating the leaf nodes that can be proposed to grow.
    """
    mask = grove.is_actual_leaf(split_tree)
    if affluence_tree is None:
        return mask
    return mask & affluence_tree

724 

725 

def categorical(
    key: Key[Array, ''], distr: Float32[Array, 'n']
) -> tuple[Int32[Array, ''], Float32[Array, '']]:
    """
    Return a random integer from an arbitrary distribution.

    Parameters
    ----------
    key
        A jax random key.
    distr
        An unnormalized probability distribution.

    Returns
    -------
    u : Int32[Array, '']
        A random integer in the range ``[0, n)``. If all probabilities are
        zero, return ``n``.
    norm : Float32[Array, '']
        The sum of `distr`.
    """
    cdf = jnp.cumsum(distr)
    total = cdf[-1]
    # invert the cdf at a uniform draw; 'right' makes zero-probability
    # entries unselectable
    u = random.uniform(key, (), cdf.dtype, 0, total)
    index = jnp.searchsorted(cdf, u, 'right')
    return index, total

750 

751 

def choose_variable(
    key: Key[Array, ''],
    var_tree: UInt[Array, '2**(d-1)'],
    split_tree: UInt[Array, '2**(d-1)'],
    max_split: UInt[Array, 'p'],
    leaf_index: Int32[Array, ''],
) -> Int32[Array, '']:
    """
    Choose a variable to split on for a new non-terminal node.

    Parameters
    ----------
    key
        A jax random key.
    var_tree
        The variable indices of the tree.
    split_tree
        The splitting points of the tree.
    max_split
        The maximum split index for each variable.
    leaf_index
        The index of the leaf to grow.

    Returns
    -------
    The index of the variable to split on.

    Notes
    -----
    The variable is chosen among the variables that have a non-empty range of
    allowed splits. If no variable has a non-empty range, return `p`.
    """
    exhausted_vars = fully_used_variables(var_tree, split_tree, max_split, leaf_index)
    return randint_exclude(key, max_split.size, exhausted_vars)

786 

787 

def fully_used_variables(
    var_tree: UInt[Array, '2**(d-1)'],
    split_tree: UInt[Array, '2**(d-1)'],
    max_split: UInt[Array, 'p'],
    leaf_index: Int32[Array, ''],
) -> UInt[Array, 'd-2']:
    """
    Return a list of variables that have an empty split range at a given node.

    Parameters
    ----------
    var_tree
        The variable indices of the tree.
    split_tree
        The splitting points of the tree.
    max_split
        The maximum split index for each variable.
    leaf_index
        The index of the node, assumed to be valid for `var_tree`.

    Returns
    -------
    The indices of the variables that have an empty split range.

    Notes
    -----
    The number of unused variables is not known in advance. Unused values in
    the array are filled with `p`. The fill values are not guaranteed to be
    placed in any particular order, and variables may appear more than once.
    """
    candidates = ancestor_variables(var_tree, max_split, leaf_index)
    # measure the remaining split range of each ancestor variable at this node
    ranges = jax.vmap(split_range, in_axes=(None, None, None, None, 0))
    lo, hi = ranges(var_tree, split_tree, max_split, leaf_index, candidates)
    exhausted = (hi - lo) == 0
    return jnp.where(exhausted, candidates, max_split.size)

823 

824 

def ancestor_variables(
    var_tree: UInt[Array, '2**(d-1)'],
    max_split: UInt[Array, 'p'],
    node_index: Int32[Array, ''],
) -> UInt[Array, 'd-2']:
    """
    Return the list of variables in the ancestors of a node.

    Parameters
    ----------
    var_tree : int array (2 ** (d - 1),)
        The variable indices of the tree.
    max_split : int array (p,)
        The maximum split index for each variable. Used only to get `p`.
    node_index : int
        The index of the node, assumed to be valid for `var_tree`.

    Returns
    -------
    The variable indices of the ancestors of the node.

    Notes
    -----
    The ancestors are the nodes going from the root to the parent of the node.
    The number of ancestors is not known at tracing time; unused spots in the
    output array are filled with `p`.
    """
    max_num_ancestors = grove.tree_depth(var_tree) - 1
    ancestor_vars = jnp.zeros(max_num_ancestors, minimal_unsigned_dtype(max_split.size))
    # fill from the back: the last slot gets the immediate parent, earlier
    # slots the higher ancestors, so unused slots end up at the front
    carry = ancestor_vars.size - 1, node_index, ancestor_vars

    def loop(carry, _):
        i, index, ancestor_vars = carry
        # move up one level in the heap-indexed tree
        index >>= 1
        var = var_tree[index]
        # index 0 means we have walked past the root: emit the fill value `p`
        var = jnp.where(index, var, max_split.size)
        ancestor_vars = ancestor_vars.at[i].set(var)
        return (i - 1, index, ancestor_vars), None

    (_, _, ancestor_vars), _ = lax.scan(loop, carry, None, ancestor_vars.size)
    return ancestor_vars

866 

867 

def split_range(
    var_tree: UInt[Array, '2**(d-1)'],
    split_tree: UInt[Array, '2**(d-1)'],
    max_split: UInt[Array, 'p'],
    node_index: Int32[Array, ''],
    ref_var: Int32[Array, ''],
) -> tuple[Int32[Array, ''], Int32[Array, '']]:
    """
    Return the range of allowed splits for a variable at a given node.

    Parameters
    ----------
    var_tree
        The variable indices of the tree.
    split_tree
        The splitting points of the tree.
    max_split
        The maximum split index for each variable.
    node_index
        The index of the node, assumed to be valid for `var_tree`.
    ref_var
        The variable for which to measure the split range.

    Returns
    -------
    The range of allowed splits as [l, r). If `ref_var` is out of bounds, l=r=0.
    """
    max_num_ancestors = grove.tree_depth(var_tree) - 1
    # start from the variable's full range; out-of-bounds ref_var yields 0
    # via mode='fill', which makes the final range empty (l=r=0)
    initial_r = 1 + max_split.at[ref_var].get(mode='fill', fill_value=0).astype(
        jnp.int32
    )
    carry = 0, initial_r, node_index

    def loop(carry, _):
        l, r, index = carry
        # heap indexing: even index = left child, odd index = right child
        right_child = (index & 1).astype(bool)
        index >>= 1
        split = split_tree[index]
        # only ancestors (index != 0) that split on ref_var constrain the range
        cond = (var_tree[index] == ref_var) & index.astype(bool)
        # coming from the right child raises the lower bound, from the left
        # child lowers the upper bound
        l = jnp.where(cond & right_child, jnp.maximum(l, split), l)
        r = jnp.where(cond & ~right_child, jnp.minimum(r, split), r)
        return (l, r, index), None

    (l, r, _), _ = lax.scan(loop, carry, None, max_num_ancestors)
    return l + 1, r

913 

914 

def randint_exclude(
    key: Key[Array, ''], sup: int, exclude: Integer[Array, 'n']
) -> Int32[Array, '']:
    """
    Draw a uniform random integer from ``[0, sup)``, skipping given values.

    Parameters
    ----------
    key
        A jax random key.
    sup
        The exclusive upper bound of the range.
    exclude
        Values to skip. Duplicates are allowed, and values ``>= sup`` have
        no effect.

    Returns
    -------
    A random integer `u` with ``0 <= u < sup`` and ``u not in exclude``.

    Notes
    -----
    If every value in the range is excluded, the result is `sup`.
    """
    # deduplicate in sorted order; fill slots freed by duplicates with `sup`,
    # which is inert both for counting and for the shifting pass below
    exclude = jnp.unique(exclude, size=exclude.size, fill_value=sup)
    num_forbidden = jnp.count_nonzero(exclude < sup)
    draw = random.randint(key, (), 0, sup - num_forbidden)

    def shift(draw, excluded):
        # each excluded value at or below the draw pushes it one slot up;
        # scanning in sorted order keeps the mapping consistent
        return jnp.where(excluded <= draw, draw + 1, draw), None

    draw, _ = lax.scan(shift, draw, exclude)
    return draw

948 

949 

def choose_split(
    key: Key[Array, ''],
    var_tree: UInt[Array, '2**(d-1)'],
    split_tree: UInt[Array, '2**(d-1)'],
    max_split: UInt[Array, 'p'],
    leaf_index: Int32[Array, ''],
) -> Int32[Array, '']:
    """
    Choose a split point for a new non-terminal node.

    Parameters
    ----------
    key
        A jax random key.
    var_tree
        The splitting axes of the tree.
    split_tree
        The splitting points of the tree.
    max_split
        The maximum split index for each variable.
    leaf_index
        The index of the leaf to grow. It is assumed that `var_tree` already
        contains the target variable at this index.

    Returns
    -------
    The cutpoint, drawn uniformly from the range allowed by the ancestors.
    If ``var_tree[leaf_index]`` is out of bounds, `split_range` produces an
    empty range, and the drawn value is presumably ignored by the caller --
    NOTE(review): confirm; see the TODO below.
    """
    var = var_tree[leaf_index]
    # [l, r) is the half-open interval of cutpoints not already fixed by an
    # ancestor splitting on the same variable
    l, r = split_range(var_tree, split_tree, max_split, leaf_index, var)
    return random.randint(key, (), l, r)

    # TODO what happens if leaf_index is out of bounds? And is the value used
    # in that case?

984 

985 

def compute_partial_ratio(
    prob_choose: Float32[Array, ''],
    num_prunable: Int32[Array, ''],
    p_nonterminal: Float32[Array, 'd'],
    leaf_to_grow: Int32[Array, ''],
) -> Float32[Array, '']:
    """
    Compute the product of the transition and prior ratios of a grow move.

    Parameters
    ----------
    prob_choose
        The probability that the leaf had to be chosen amongst the growable
        leaves.
    num_prunable
        The number of leaf parents that could be pruned, after converting the
        leaf to be grown to a non-terminal node.
    p_nonterminal
        The probability of a nonterminal node at each depth.
    leaf_to_grow
        The index of the leaf to grow.

    Returns
    -------
    The partial transition ratio times the prior ratio.

    Notes
    -----
    The transition ratio is P(new tree => old tree) / P(old tree => new tree).
    The returned value is "partial" because it lacks the factor
    P(propose prune) in the numerator: that factor needs the count trees,
    which only become available in the acceptance phase. The prior ratio is
    P(new tree) / P(old tree). Both ratios also contain factors
    num_available_split * num_available_var, but those cancel out and are
    omitted here.
    """
    # A prune proposal back is possible iff the initial tree is not a bare
    # root: leaf_to_grow == 1 (the root) iff the tree is a root.
    prune_allowed = leaf_to_grow != 1
    inv_trans_ratio = jnp.where(prune_allowed, 0.5, 1) * prob_choose * num_prunable

    # prior ratio: the grown node turns nonterminal, its children terminal
    depth = grove.tree_depths(2 ** (p_nonterminal.size - 1))[leaf_to_grow]
    p_parent = p_nonterminal[depth]
    cp_child = 1 - p_nonterminal[depth + 1]
    prior_ratio = cp_child * cp_child * p_parent / (1 - p_parent)

    return prior_ratio / inv_trans_ratio

1039 

1040 

class PruneMoves(Module):
    """
    Represent a proposed prune move for each tree.

    Parameters
    ----------
    allowed
        Whether the move is possible, i.e., whether the tree is not a bare
        root.
    node
        The index of the node to prune. ``2 ** d`` if no node can be pruned.
    partial_ratio
        A factor of the Metropolis-Hastings ratio of the move. It lacks
        the likelihood ratio and the probability of proposing the prune
        move. This ratio is inverted, and is meant to be inverted back in
        `accept_move_and_sample_leaves`.
    """

    allowed: Bool[Array, 'num_trees']
    node: UInt[Array, 'num_trees']
    partial_ratio: Float32[Array, 'num_trees']

1061 

1062 

@partial(vmap_nodoc, in_axes=(0, 0, 0, None, None))
def propose_prune_moves(
    key: Key[Array, ''],
    split_tree: UInt[Array, '2**(d-1)'],
    affluence_tree: Bool[Array, '2**(d-1)'] | None,
    p_nonterminal: Float32[Array, 'd'],
    p_propose_grow: Float32[Array, '2**(d-1)'],
) -> PruneMoves:
    """
    Tree structure prune move proposal of BART MCMC.

    Vectorized over the first three arguments (one tree per batch element).

    Parameters
    ----------
    key
        A jax random key.
    split_tree
        The splitting points of the tree.
    affluence_tree
        Whether a leaf has enough points to be grown.
    p_nonterminal
        The probability of a nonterminal node at each depth.
    p_propose_grow
        The unnormalized probability of choosing a leaf to grow.

    Returns
    -------
    An object representing the proposed moves.
    """
    # pick a candidate node and the terms of the reverse (grow) proposal
    node_to_prune, num_prunable, prob_choose = choose_leaf_parent(
        key, split_tree, affluence_tree, p_propose_grow
    )
    allowed = split_tree[1].astype(bool)  # allowed iff the tree is not a root

    # transition x prior ratio, still missing the likelihood ratio and
    # P(propose prune); completed later in the acceptance phase
    ratio = compute_partial_ratio(
        prob_choose, num_prunable, p_nonterminal, node_to_prune
    )

    return PruneMoves(
        allowed=allowed,
        node=node_to_prune,
        partial_ratio=ratio,
    )

1105 

1106 

def choose_leaf_parent(
    key: Key[Array, ''],
    split_tree: UInt[Array, '2**(d-1)'],
    affluence_tree: Bool[Array, '2**(d-1)'] | None,
    p_propose_grow: Float32[Array, '2**(d-1)'],
) -> tuple[Int32[Array, ''], Int32[Array, ''], Float32[Array, '']]:
    """
    Pick a non-terminal node with leaf children to prune in a tree.

    Parameters
    ----------
    key
        A jax random key.
    split_tree
        The splitting points of the tree.
    affluence_tree
        Whether a leaf has enough points to be grown.
    p_propose_grow
        The unnormalized probability of choosing a leaf to grow.

    Returns
    -------
    node_to_prune : Int32[Array, '']
        The index of the node to prune. If ``num_prunable == 0``, return
        ``2 ** d``.
    num_prunable : Int32[Array, '']
        The number of leaf parents that could be pruned.
    prob_choose : Float32[Array, '']
        The (normalized) probability that `choose_leaf` would pick
        `node_to_prune` as leaf to grow, if it were passed the tree where
        `node_to_prune` had been pruned.
    """
    # candidates are non-terminal nodes whose children are both leaves
    prunable_mask = grove.is_leaves_parent(split_tree)
    num_prunable = jnp.count_nonzero(prunable_mask)
    node_to_prune = randint_masked(key, prunable_mask)
    # out-of-bounds sentinel when nothing can be pruned (bare root)
    node_to_prune = jnp.where(num_prunable, node_to_prune, 2 * split_tree.size)

    # simulate the prune, then measure the probability of the grow proposal
    # picking exactly this node back
    split_tree = split_tree.at[node_to_prune].set(0)
    if affluence_tree is not None:
        affluence_tree = affluence_tree.at[node_to_prune].set(True)
    growable_mask = growable_leaves(split_tree, affluence_tree)
    prob_choose = p_propose_grow[node_to_prune] / jnp.sum(
        p_propose_grow, where=growable_mask
    )

    return node_to_prune, num_prunable, prob_choose

1152 

1153 

def randint_masked(key: Key[Array, ''], mask: Bool[Array, 'n']) -> Int32[Array, '']:
    """
    Draw a uniform random integer from the positions allowed by a mask.

    Parameters
    ----------
    key
        A jax random key.
    mask
        Boolean array; `True` marks the allowed positions.

    Returns
    -------
    A random integer `u` in ``[0, n)`` with ``mask[u] == True``.

    Notes
    -----
    If no position is allowed, the result is `n`.
    """
    # cumulative count of allowed positions; the last entry is the total
    cum_allowed = jnp.cumsum(mask)
    # draw a rank among the allowed positions, then map rank -> position:
    # the right-sided search finds the first index whose cumulative count
    # exceeds the drawn rank
    rank = random.randint(key, (), 0, cum_allowed[-1])
    return jnp.searchsorted(cum_allowed, rank, 'right')

1176 

1177 

def accept_moves_and_sample_leaves(
    key: Key[Array, ''], bart: State, moves: Moves
) -> State:
    """
    Accept or reject the proposed moves and sample the new leaf values.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A valid BART mcmc state.
    moves
        The proposed moves, see `propose_moves`.

    Returns
    -------
    A new (valid) BART mcmc state.
    """
    # three phases: tree-parallel precomputation, the sequential
    # accept/reject sweep, then tree-parallel finalization
    parallel_out = accept_moves_parallel_stage(key, bart, moves)
    updated_bart, completed_moves = accept_moves_sequential_stage(parallel_out)
    return accept_moves_final_stage(updated_bart, completed_moves)

1200 

1201 

class Counts(Module):
    """
    Number of datapoints in the nodes involved in proposed moves for each tree.

    Parameters
    ----------
    left
        Number of datapoints in the left child.
    right
        Number of datapoints in the right child.
    total
        Number of datapoints in the parent (``= left + right``).
    """

    left: UInt[Array, 'num_trees']
    right: UInt[Array, 'num_trees']
    total: UInt[Array, 'num_trees']

1219 

1220 

class Precs(Module):
    """
    Likelihood precision scale in the nodes involved in proposed moves for each tree.

    The "likelihood precision scale" of a tree node is the sum of the inverse
    squared error scales of the datapoints selected by the node.

    Parameters
    ----------
    left
        Likelihood precision scale in the left child.
    right
        Likelihood precision scale in the right child.
    total
        Likelihood precision scale in the parent (``= left + right``).
    """

    left: Float32[Array, 'num_trees']
    right: Float32[Array, 'num_trees']
    total: Float32[Array, 'num_trees']

1241 

1242 

class PreLkV(Module):
    """
    Non-sequential terms of the likelihood ratio for each tree.

    These terms can be computed in parallel across trees.

    Parameters
    ----------
    sigma2_left
        The noise variance in the left child of the leaves grown or pruned by
        the moves.
    sigma2_right
        The noise variance in the right child of the leaves grown or pruned by
        the moves.
    sigma2_total
        The noise variance in the total of the leaves grown or pruned by the
        moves.
    sqrt_term
        The **logarithm** of the square root term of the likelihood ratio.
    """

    sigma2_left: Float32[Array, 'num_trees']
    sigma2_right: Float32[Array, 'num_trees']
    sigma2_total: Float32[Array, 'num_trees']
    sqrt_term: Float32[Array, 'num_trees']

1268 

1269 

class PreLk(Module):
    """
    Non-sequential terms of the likelihood ratio shared by all trees.

    Parameters
    ----------
    exp_factor
        The factor to multiply the likelihood ratio by, shared by all trees.
    """

    exp_factor: Float32[Array, '']

1281 

1282 

class PreLf(Module):
    """
    Pre-computed terms used to sample leaves from their posterior.

    These terms can be computed in parallel across trees.

    Parameters
    ----------
    mean_factor
        The factor to be multiplied by the sum of the scaled residuals to
        obtain the posterior mean.
    centered_leaves
        The mean-zero normal values to be added to the posterior mean to
        obtain the posterior leaf samples.
    """

    mean_factor: Float32[Array, 'num_trees 2**d']
    centered_leaves: Float32[Array, 'num_trees 2**d']

1301 

1302 

class ParallelStageOut(Module):
    """
    The output of `accept_moves_parallel_stage`.

    Parameters
    ----------
    bart
        A partially updated BART mcmc state.
    moves
        The proposed moves, with `partial_ratio` set to `None` and
        `log_trans_prior_ratio` set to its final value.
    prec_trees
        The likelihood precision scale in each potential or actual leaf node.
        If there is no precision scale, this is the number of points in each
        leaf.
    move_counts
        The counts of the number of points in the nodes modified by the
        moves. If `bart.min_points_per_leaf` is not set and
        `bart.prec_scale` is set, they are not computed.
    move_precs
        The likelihood precision scale in each node modified by the moves. If
        `bart.prec_scale` is not set, this is set to `move_counts`.
    prelkv
    prelk
    prelf
        Objects with pre-computed terms of the likelihood ratios and leaf
        samples.
    """

    bart: State
    moves: Moves
    prec_trees: Float32[Array, 'num_trees 2**d'] | Int32[Array, 'num_trees 2**d']
    move_counts: Counts | None
    move_precs: Precs | Counts
    prelkv: PreLkV
    prelk: PreLk
    prelf: PreLf

1339 

1340 

def accept_moves_parallel_stage(
    key: Key[Array, ''], bart: State, moves: Moves
) -> ParallelStageOut:
    """
    Pre-compute quantities used to accept moves, in parallel across trees.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A valid BART mcmc state.
    moves
        The proposed moves, see `propose_moves`.

    Returns
    -------
    An object with all that could be done in parallel.
    """
    # where the move is grow, modify the state like the move was accepted
    bart = replace(
        bart,
        forest=replace(
            bart.forest,
            var_trees=moves.var_trees,
            leaf_indices=apply_grow_to_indices(moves, bart.forest.leaf_indices, bart.X),
            leaf_trees=adapt_leaf_trees_to_grow_indices(bart.forest.leaf_trees, moves),
        ),
    )

    # count number of datapoints per leaf
    if bart.forest.min_points_per_leaf is not None or bart.prec_scale is None:
        count_trees, move_counts = compute_count_trees(
            bart.forest.leaf_indices, moves, bart.forest.count_batch_size
        )
    else:
        # move_counts is passed later to a function, but then is unused under
        # this condition
        move_counts = None

    # Check if some nodes can't surely be grown because they don't have enough
    # datapoints. This check is not actually used now, it will be used at the
    # beginning of the next step to propose moves.
    if bart.forest.min_points_per_leaf is not None:
        # count_trees is always bound here: this condition implies the
        # condition of the branch that computed it above
        count_half_trees = count_trees[:, : bart.forest.var_trees.shape[1]]
        bart = replace(
            bart,
            forest=replace(
                bart.forest,
                affluence_trees=count_half_trees >= 2 * bart.forest.min_points_per_leaf,
            ),
        )

    # count number of datapoints per leaf, weighted by error precision scale
    if bart.prec_scale is None:
        prec_trees = count_trees
        move_precs = move_counts
    else:
        prec_trees, move_precs = compute_prec_trees(
            bart.prec_scale,
            bart.forest.leaf_indices,
            moves,
            bart.forest.count_batch_size,
        )

    # fill in P(propose prune), the last non-likelihood factor of the MH ratio
    moves = complete_ratio(moves, move_counts, bart.forest.min_points_per_leaf)
    # record proposal counts (used as MCMC diagnostics)
    bart = replace(
        bart,
        forest=replace(
            bart.forest,
            grow_prop_count=jnp.sum(moves.grow),
            prune_prop_count=jnp.sum(moves.allowed & ~moves.grow),
        ),
    )

    prelkv, prelk = precompute_likelihood_terms(
        bart.sigma2, bart.forest.sigma_mu2, move_precs
    )
    prelf = precompute_leaf_terms(key, prec_trees, bart.sigma2, bart.forest.sigma_mu2)

    return ParallelStageOut(
        bart=bart,
        moves=moves,
        prec_trees=prec_trees,
        move_counts=move_counts,
        move_precs=move_precs,
        prelkv=prelkv,
        prelk=prelk,
        prelf=prelf,
    )

1432 

1433 

@partial(vmap_nodoc, in_axes=(0, 0, None))
def apply_grow_to_indices(
    moves: Moves, leaf_indices: UInt[Array, 'num_trees n'], X: UInt[Array, 'p n']
) -> UInt[Array, 'num_trees n']:
    """
    Update the leaf indices to apply a grow move.

    Parameters
    ----------
    moves
        The proposed moves, see `propose_moves`.
    leaf_indices
        The index of the leaf each datapoint falls into.
    X
        The predictors matrix.

    Returns
    -------
    The updated leaf indices.
    """
    left_child = moves.node.astype(leaf_indices.dtype) << 1
    # 0/1 flag: does each datapoint cross to the right child of the new split?
    goes_right = X[moves.grow_var, :] >= moves.grow_split
    # when the move is not a grow, compare against an index that no datapoint
    # can have, so the whole update is a no-op
    out_of_range = jnp.array(2 * moves.var_trees.size)
    target_node = jnp.where(moves.grow, moves.node, out_of_range)
    return jnp.where(
        leaf_indices == target_node,
        left_child + goes_right,
        leaf_indices,
    )

1463 

1464 

def compute_count_trees(
    leaf_indices: UInt[Array, 'num_trees n'], moves: Moves, batch_size: int | None
) -> tuple[Int32[Array, 'num_trees 2**d'], Counts]:
    """
    Count the number of datapoints in each leaf.

    Parameters
    ----------
    leaf_indices
        The index of the leaf each datapoint falls into, with the deeper
        version of the tree (post-GROW, pre-PRUNE).
    moves
        The proposed moves, see `propose_moves`.
    batch_size
        The data batch size to use for the summation.

    Returns
    -------
    count_trees : Int32[Array, 'num_trees 2**d']
        The number of points in each potential or actual leaf node.
    counts : Counts
        The counts of the number of points in the leaves grown or pruned by
        the moves.
    """
    num_trees, half_size = moves.var_trees.shape
    full_size = 2 * half_size
    tree_idx = jnp.arange(num_trees)

    count_trees = count_datapoints_per_leaf(leaf_indices, full_size, batch_size)

    # gather the counts of the two children touched by each move
    left = count_trees[tree_idx, moves.left]
    right = count_trees[tree_idx, moves.right]
    counts = Counts(left=left, right=right, total=left + right)

    # store the parent's total at the moved node's (non-leaf) position
    count_trees = count_trees.at[tree_idx, moves.node].set(counts.total)

    return count_trees, counts

1504 

1505 

def count_datapoints_per_leaf(
    leaf_indices: UInt[Array, 'num_trees n'], tree_size: int, batch_size: int | None
) -> Int32[Array, 'num_trees 2**(d-1)']:
    """
    Count the number of datapoints in each leaf.

    Parameters
    ----------
    leaf_indices
        The index of the leaf each datapoint falls into.
    tree_size
        The size of the leaf tree array (2 ** d).
    batch_size
        The data batch size to use for the summation; `None` selects the
        sequential per-tree implementation.

    Returns
    -------
    The number of points in each leaf node.
    """
    if batch_size is None:
        return _count_scan(leaf_indices, tree_size)
    return _count_vec(leaf_indices, tree_size, batch_size)

1529 

1530 

def _count_scan(
    leaf_indices: UInt[Array, 'num_trees n'], tree_size: int
) -> Int32[Array, 'num_trees {tree_size}']:
    """Count points per leaf, one tree at a time via `lax.scan`."""

    def count_one_tree(carry, tree_leaf_indices):
        counts = _aggregate_scatter(1, tree_leaf_indices, tree_size, jnp.uint32)
        return carry, counts

    _, count_trees = lax.scan(count_one_tree, None, leaf_indices)
    return count_trees

1539 

1540 

1541def _aggregate_scatter( 1ab

1542 values: Shaped[Array, '*'], 

1543 indices: Integer[Array, '*'], 

1544 size: int, 

1545 dtype: jnp.dtype, 

1546) -> Shaped[Array, '{size}']: 

1547 return jnp.zeros(size, dtype).at[indices].add(values) 1ab

1548 

1549 

def _count_vec(
    leaf_indices: UInt[Array, 'num_trees n'], tree_size: int, batch_size: int
) -> Int32[Array, 'num_trees 2**(d-1)']:
    """Count points per leaf for all trees at once, batched over datapoints."""
    # uint16 is super-slow on gpu, don't use it even if n < 2^16
    return _aggregate_batched_alltrees(
        1, leaf_indices, tree_size, jnp.uint32, batch_size
    )

1557 

1558 

1559def _aggregate_batched_alltrees( 1ab

1560 values: Shaped[Array, '*'], 

1561 indices: UInt[Array, 'num_trees n'], 

1562 size: int, 

1563 dtype: jnp.dtype, 

1564 batch_size: int, 

1565) -> Shaped[Array, 'num_trees {size}']: 

1566 num_trees, n = indices.shape 1ab

1567 tree_indices = jnp.arange(num_trees) 1ab

1568 nbatches = n // batch_size + bool(n % batch_size) 1ab

1569 batch_indices = jnp.arange(n) % nbatches 1ab

1570 return ( 1ab

1571 jnp.zeros((num_trees, size, nbatches), dtype) 

1572 .at[tree_indices[:, None], indices, batch_indices] 

1573 .add(values) 

1574 .sum(axis=2) 

1575 ) 

1576 

1577 

def compute_prec_trees(
    prec_scale: Float32[Array, 'n'],
    leaf_indices: UInt[Array, 'num_trees n'],
    moves: Moves,
    batch_size: int | None,
) -> tuple[Float32[Array, 'num_trees 2**d'], Precs]:
    """
    Compute the likelihood precision scale in each leaf.

    Parameters
    ----------
    prec_scale
        The scale of the precision of the error on each datapoint.
    leaf_indices
        The index of the leaf each datapoint falls into, with the deeper
        version of the tree (post-GROW, pre-PRUNE).
    moves
        The proposed moves, see `propose_moves`.
    batch_size
        The data batch size to use for the summation.

    Returns
    -------
    prec_trees : Float32[Array, 'num_trees 2**d']
        The likelihood precision scale in each potential or actual leaf node.
    precs : Precs
        The likelihood precision scale in the nodes involved in the moves.
    """
    num_trees, half_size = moves.var_trees.shape
    full_size = 2 * half_size
    tree_idx = jnp.arange(num_trees)

    prec_trees = prec_per_leaf(prec_scale, leaf_indices, full_size, batch_size)

    # gather the precision scales of the two children touched by each move
    left = prec_trees[tree_idx, moves.left]
    right = prec_trees[tree_idx, moves.right]
    precs = Precs(left=left, right=right, total=left + right)

    # store the parent's total at the moved node's (non-leaf) position
    prec_trees = prec_trees.at[tree_idx, moves.node].set(precs.total)

    return prec_trees, precs

1621 

1622 

def prec_per_leaf(
    prec_scale: Float32[Array, 'n'],
    leaf_indices: UInt[Array, 'num_trees n'],
    tree_size: int,
    batch_size: int | None,
) -> Float32[Array, 'num_trees {tree_size}']:
    """
    Compute the likelihood precision scale in each leaf.

    Parameters
    ----------
    prec_scale
        The scale of the precision of the error on each datapoint.
    leaf_indices
        The index of the leaf each datapoint falls into.
    tree_size
        The size of the leaf tree array (2 ** d).
    batch_size
        The data batch size to use for the summation; `None` selects the
        sequential per-tree implementation.

    Returns
    -------
    The likelihood precision scale in each leaf node.
    """
    if batch_size is None:
        return _prec_scan(prec_scale, leaf_indices, tree_size)
    return _prec_vec(prec_scale, leaf_indices, tree_size, batch_size)

1651 

1652 

def _prec_scan(
    prec_scale: Float32[Array, 'n'],
    leaf_indices: UInt[Array, 'num_trees n'],
    tree_size: int,
) -> Float32[Array, 'num_trees {tree_size}']:
    """Sum precision scales per leaf, one tree at a time via `lax.scan`."""

    def sum_one_tree(carry, tree_leaf_indices):
        sums = _aggregate_scatter(prec_scale, tree_leaf_indices, tree_size, jnp.float32)
        return carry, sums

    _, prec_trees = lax.scan(sum_one_tree, None, leaf_indices)
    return prec_trees

1665 

1666 

def _prec_vec(
    prec_scale: Float32[Array, 'n'],
    leaf_indices: UInt[Array, 'num_trees n'],
    tree_size: int,
    batch_size: int,
) -> Float32[Array, 'num_trees {tree_size}']:
    """Sum precision scales per leaf for all trees at once, batched."""
    return _aggregate_batched_alltrees(
        prec_scale, leaf_indices, tree_size, jnp.float32, batch_size
    )

1676 

1677 

def complete_ratio(
    moves: Moves, move_counts: Counts | None, min_points_per_leaf: int | None
) -> Moves:
    """
    Complete non-likelihood MH ratio calculation.

    This function multiplies in the probability of choosing the prune move,
    the one factor `propose_moves` could not compute.

    Parameters
    ----------
    moves
        The proposed moves, see `propose_moves`.
    move_counts
        The counts of the number of points in the nodes modified by the
        moves.
    min_points_per_leaf
        The minimum number of data points in a leaf node.

    Returns
    -------
    The updated moves, with `partial_ratio=None` and `log_trans_prior_ratio`
    set.
    """
    p_prune = compute_p_prune(moves, move_counts, min_points_per_leaf)
    log_ratio = jnp.log(moves.partial_ratio * p_prune)
    return replace(moves, log_trans_prior_ratio=log_ratio, partial_ratio=None)

1706 

1707 

def compute_p_prune(
    moves: Moves, move_counts: Counts | None, min_points_per_leaf: int | None
) -> Float32[Array, 'num_trees']:
    """
    Compute the probability of proposing a prune move for each tree.

    Parameters
    ----------
    moves
        The proposed moves, see `propose_moves`.
    move_counts
        The number of datapoints in the proposed children of the leaf to grow.
        Not used if `min_points_per_leaf` is `None`.
    min_points_per_leaf
        The minimum number of data points in a leaf node.

    Returns
    -------
    The probability of proposing a prune move.

    Notes
    -----
    This probability is computed for going from the state with the deeper tree
    to the one with the shallower one. This means, if grow: after accepting
    the grow move, if prune: right away.
    """
    # case grow: after accepting, a further grow is possible either on one of
    # the other growable leaves, or on the new children (when they are not at
    # maximum depth and, if required, have enough datapoints)
    can_grow_elsewhere = moves.num_growable >= 2
    new_children_growable = moves.node < moves.var_trees.shape[1] // 2
    if min_points_per_leaf is not None:
        assert move_counts is not None
        child_has_points = move_counts.left >= 2 * min_points_per_leaf
        child_has_points |= move_counts.right >= 2 * min_points_per_leaf
        new_children_growable &= child_has_points
    grow_again_allowed = can_grow_elsewhere | new_children_growable
    grow_p_prune = jnp.where(grow_again_allowed, 0.5, 1)

    # case prune: grow is an alternative iff some leaf is growable right now
    prune_p_prune = jnp.where(moves.num_growable, 0.5, 1)

    return jnp.where(moves.grow, grow_p_prune, prune_p_prune)

1749 

1750 

@vmap_nodoc
def adapt_leaf_trees_to_grow_indices(
    leaf_trees: Float32[Array, 'num_trees 2**d'], moves: Moves
) -> Float32[Array, 'num_trees 2**d']:
    """
    Modify leaves such that post-grow indices work on the original tree.

    The value of the leaf to grow is copied to what would be its children if
    the grow move was accepted.

    Parameters
    ----------
    leaf_trees
        The leaf values.
    moves
        The proposed moves, see `propose_moves`.

    Returns
    -------
    The modified leaf values.
    """
    parent_value = leaf_trees[moves.node]
    # writing at an out-of-bounds index is a dropped (no-op) update, which
    # neutralizes the copy when the move is not a grow
    noop = leaf_trees.size
    left_target = jnp.where(moves.grow, moves.left, noop)
    right_target = jnp.where(moves.grow, moves.right, noop)
    leaf_trees = leaf_trees.at[left_target].set(parent_value)
    return leaf_trees.at[right_target].set(parent_value)

1779 

1780 

def precompute_likelihood_terms(
    sigma2: Float32[Array, ''],
    sigma_mu2: Float32[Array, ''],
    move_precs: Precs | Counts,
) -> tuple[PreLkV, PreLk]:
    """
    Pre-compute terms used in the likelihood ratio of the acceptance step.

    Parameters
    ----------
    sigma2
        The error variance, or the global error variance factor if
        `prec_scale` is set.
    sigma_mu2
        The prior variance of each leaf.
    move_precs
        The likelihood precision scale in the leaves grown or pruned by the
        moves, under keys 'left', 'right', and 'total' (left + right).

    Returns
    -------
    prelkv : PreLkV
        Pre-computed per-tree terms of the likelihood ratio.
    prelk : PreLk
        Pre-computed terms of the likelihood ratio shared by all trees.
    """
    # marginal noise variances of the nodes involved in each move
    var_left = sigma2 + sigma_mu2 * move_precs.left
    var_right = sigma2 + sigma_mu2 * move_precs.right
    var_total = sigma2 + sigma_mu2 * move_precs.total
    prelkv = PreLkV(
        sigma2_left=var_left,
        sigma2_right=var_right,
        sigma2_total=var_total,
        # kept in log space so it can be summed with other log terms
        sqrt_term=jnp.log(sigma2 * var_total / (var_left * var_right)) / 2,
    )
    prelk = PreLk(exp_factor=sigma_mu2 / (2 * sigma2))
    return prelkv, prelk

1821 

1822 

def precompute_leaf_terms(
    key: Key[Array, ''],
    prec_trees: Float32[Array, 'num_trees 2**d'],
    sigma2: Float32[Array, ''],
    sigma_mu2: Float32[Array, ''],
) -> PreLf:
    """
    Pre-compute terms used to sample leaves from their posterior.

    Parameters
    ----------
    key
        A jax random key.
    prec_trees
        The likelihood precision scale in each potential or actual leaf node.
    sigma2
        The error variance, or the global error variance factor if
        `prec_scale` is set.
    sigma_mu2
        The prior variance of each leaf.

    Returns
    -------
    Pre-computed terms for leaf sampling.
    """
    # standard conjugate-normal update: posterior precision is the sum of
    # the likelihood and prior precisions
    prec_lk = prec_trees / sigma2
    prec_prior = lax.reciprocal(sigma_mu2)
    var_post = lax.reciprocal(prec_lk + prec_prior)
    z = random.normal(key, prec_trees.shape, sigma2.dtype)
    # The posterior mean is mean_lk * prec_lk * var_post. Since the residual
    # sum satisfies resid_tree = mean_lk * prec_tree (up to the same caveat
    # as in the original), the factor that multiplies resid_tree is
    #   mean / resid_tree
    #   = (resid_tree / prec_tree) * (prec_tree / sigma2) * var_post / resid_tree
    #   = var_post / sigma2.
    return PreLf(
        mean_factor=var_post / sigma2,
        centered_leaves=z * jnp.sqrt(var_post),
    )

1864 

1865 

def accept_moves_sequential_stage(pso: ParallelStageOut) -> tuple[State, Moves]:
    """
    Accept/reject the moves one tree at a time.

    This is the most performance-sensitive function because it contains all and
    only the parts of the algorithm that can not be parallelized across trees.

    Parameters
    ----------
    pso
        The output of `accept_moves_parallel_stage`.

    Returns
    -------
    bart : State
        A partially updated BART mcmc state.
    moves : Moves
        The accepted/rejected moves, with `acc` and `to_prune` set.
    """
    save_ratios = pso.bart.forest.log_likelihood is not None

    # inputs identical for every iteration of the scan
    shared = SeqStageInAllTrees(
        pso.bart.X,
        pso.bart.forest.resid_batch_size,
        pso.bart.prec_scale,
        pso.bart.forest.min_points_per_leaf,
        save_ratios,
        pso.prelk,
    )

    # per-tree inputs, stacked along the leading axis that is scanned over
    per_tree = SeqStageInPerTree(
        pso.bart.forest.leaf_trees,
        pso.prec_trees,
        pso.moves,
        pso.move_counts,
        pso.move_precs,
        pso.bart.forest.leaf_indices,
        pso.prelkv,
        pso.prelf,
    )

    def scan_body(carry_resid, tree_inputs):
        new_resid, leaf_tree, acc, to_prune, ratios = (
            accept_move_and_sample_leaves(carry_resid, shared, tree_inputs)
        )
        return new_resid, (leaf_tree, acc, to_prune, ratios)

    resid, (leaf_trees, acc, to_prune, ratios) = lax.scan(
        scan_body, pso.bart.resid, per_tree
    )

    new_forest = replace(
        pso.bart.forest,
        leaf_trees=leaf_trees,
        log_likelihood=ratios['log_likelihood'] if save_ratios else None,
        log_trans_prior=ratios['log_trans_prior'] if save_ratios else None,
    )
    bart = replace(pso.bart, resid=resid, forest=new_forest)
    moves = replace(pso.moves, acc=acc, to_prune=to_prune)

    return bart, moves

1927 

1928 

class SeqStageInAllTrees(Module):
    """
    The inputs to `accept_move_and_sample_leaves` that are the same for all trees.

    These values are closed over by the sequential (scan) stage, as opposed to
    `SeqStageInPerTree` which is scanned over one tree at a time.

    Parameters
    ----------
    X
        The predictors.
    resid_batch_size
        The batch size for computing the sum of residuals in each leaf.
    prec_scale
        The scale of the precision of the error on each datapoint. If None, it
        is assumed to be 1.
    min_points_per_leaf
        The minimum number of data points in a leaf node.
    save_ratios
        Whether to save the acceptance ratios.
    prelk
        The pre-computed terms of the likelihood ratio which are shared across
        trees.
    """

    X: UInt[Array, 'p n']
    resid_batch_size: int | None
    prec_scale: Float32[Array, 'n'] | None
    min_points_per_leaf: Int32[Array, ''] | None
    save_ratios: bool
    prelk: PreLk

1957 

1958 

class SeqStageInPerTree(Module):
    """
    The inputs to `accept_move_and_sample_leaves` that are separate for each tree.

    Instances hold arrays stacked across trees; `lax.scan` slices them one
    tree at a time in `accept_moves_sequential_stage`.

    Parameters
    ----------
    leaf_tree
        The leaf values of the tree.
    prec_tree
        The likelihood precision scale in each potential or actual leaf node.
    move
        The proposed move, see `propose_moves`.
    move_counts
        The counts of the number of points in the nodes modified by the
        moves.
    move_precs
        The likelihood precision scale in each node modified by the moves.
    leaf_indices
        The leaf indices for the largest version of the tree compatible with
        the move.
    prelkv
    prelf
        The pre-computed terms of the likelihood ratio and leaf sampling which
        are specific to the tree.
    """

    leaf_tree: Float32[Array, '2**d']
    prec_tree: Float32[Array, '2**d']
    move: Moves
    move_counts: Counts | None
    move_precs: Precs | Counts
    leaf_indices: UInt[Array, 'n']
    prelkv: PreLkV
    prelf: PreLf

1993 

1994 

def accept_move_and_sample_leaves(
    resid: Float32[Array, 'n'],
    at: SeqStageInAllTrees,
    pt: SeqStageInPerTree,
) -> tuple[
    Float32[Array, 'n'],
    Float32[Array, '2**d'],
    Bool[Array, ''],
    Bool[Array, ''],
    dict[str, Float32[Array, '']],
]:
    """
    Accept or reject a proposed move and sample the new leaf values.

    Parameters
    ----------
    resid
        The residuals (data minus forest value).
    at
        The inputs that are the same for all trees.
    pt
        The inputs that are separate for each tree.

    Returns
    -------
    resid : Float32[Array, 'n']
        The updated residuals (data minus forest value).
    leaf_tree : Float32[Array, '2**d']
        The new leaf values of the tree.
    acc : Bool[Array, '']
        Whether the move was accepted.
    to_prune : Bool[Array, '']
        Whether, to reflect the acceptance status of the move, the state should
        be updated by pruning the leaves involved in the move.
    ratios : dict[str, Float32[Array, '']]
        The acceptance ratios for the moves. Empty if not to be saved.
    """
    # sum residuals in each leaf, in tree proposed by grow move
    if at.prec_scale is None:
        scaled_resid = resid
    else:
        scaled_resid = resid * at.prec_scale
    resid_tree = sum_resid(
        scaled_resid, pt.leaf_indices, pt.leaf_tree.size, at.resid_batch_size
    )

    # subtract starting tree from function
    # (resid excludes this tree's contribution after this adjustment)
    resid_tree += pt.prec_tree * pt.leaf_tree

    # get indices of move
    node = pt.move.node
    assert node.dtype == jnp.int32
    left = pt.move.left
    right = pt.move.right

    # sum residuals in parent node modified by move
    resid_left = resid_tree[left]
    resid_right = resid_tree[right]
    resid_total = resid_left + resid_right
    resid_tree = resid_tree.at[node].set(resid_total)

    # compute acceptance ratio
    log_lk_ratio = compute_likelihood_ratio(
        resid_total, resid_left, resid_right, pt.prelkv, at.prelk
    )
    log_ratio = pt.move.log_trans_prior_ratio + log_lk_ratio
    # the ratio is computed in the "grow" direction; a prune is the reverse
    # move, so its log ratio is the negation
    log_ratio = jnp.where(pt.move.grow, log_ratio, -log_ratio)
    ratios = {}
    if at.save_ratios:
        ratios.update(
            log_trans_prior=pt.move.log_trans_prior_ratio,
            # TODO save log_trans_prior_ratio as a vector outside of this loop,
            # then change the option everywhere to `save_likelihood_ratio`.
            log_likelihood=log_lk_ratio,
        )

    # determine whether to accept the move (Metropolis-Hastings test with the
    # pre-drawn log-uniform variate logu)
    acc = pt.move.allowed & (pt.move.logu <= log_ratio)
    if at.min_points_per_leaf is not None:
        # veto moves that would leave too few datapoints in a child leaf
        assert pt.move_counts is not None
        acc &= pt.move_counts.left >= at.min_points_per_leaf
        acc &= pt.move_counts.right >= at.min_points_per_leaf

    # compute leaves posterior and sample leaves
    initial_leaf_tree = pt.leaf_tree
    mean_post = resid_tree * pt.prelf.mean_factor
    leaf_tree = mean_post + pt.prelf.centered_leaves

    # copy leaves around such that the leaf indices point to the correct leaf
    # to_prune is True for a rejected grow (undo the speculative split) or an
    # accepted prune; either way the children must mirror the parent's value
    to_prune = acc ^ pt.move.grow
    # when to_prune is False the write index is leaf_tree.size, which is out
    # of bounds: JAX drops out-of-bounds scatter updates, making it a no-op
    leaf_tree = (
        leaf_tree.at[jnp.where(to_prune, left, leaf_tree.size)]
        .set(leaf_tree[node])
        .at[jnp.where(to_prune, right, leaf_tree.size)]
        .set(leaf_tree[node])
    )

    # replace old tree with new tree in function values
    resid += (initial_leaf_tree - leaf_tree)[pt.leaf_indices]

    return resid, leaf_tree, acc, to_prune, ratios

2096 

2097 

def sum_resid(
    scaled_resid: Float32[Array, 'n'],
    leaf_indices: UInt[Array, 'n'],
    tree_size: int,
    batch_size: int | None,
) -> Float32[Array, '{tree_size}']:
    """
    Sum the residuals in each leaf.

    Parameters
    ----------
    scaled_resid
        The residuals (data minus forest value) multiplied by the error
        precision scale.
    leaf_indices
        The leaf indices of the tree (in which leaf each data point falls into).
    tree_size
        The size of the tree array (2 ** d).
    batch_size
        The data batch size for the aggregation. Batching increases numerical
        accuracy and parallelism.

    Returns
    -------
    The sum of the residuals at data points in each leaf.
    """
    # pick the plain scatter-add, or the batched variant when a batch size
    # is configured
    aggregate = (
        _aggregate_scatter
        if batch_size is None
        else partial(_aggregate_batched_onetree, batch_size=batch_size)
    )
    return aggregate(scaled_resid, leaf_indices, tree_size, jnp.float32)

2129 

2130 

def _aggregate_batched_onetree(
    values: Shaped[Array, '*'],
    indices: Integer[Array, '*'],
    size: int,
    dtype: jnp.dtype,
    batch_size: int,
) -> Float32[Array, '{size}']:
    """
    Scatter-sum `values` into `size` buckets keyed by `indices`, in batches.

    Splitting the additions over `nbatches` interleaved groups and reducing at
    the end exposes parallelism and reduces float round-off compared to one
    long sequential accumulation.
    """
    (n,) = indices.shape
    # number of batches needed to cover n items, batch_size at a time
    nbatches, remainder = divmod(n, batch_size)
    if remainder:
        nbatches += 1
    # assign datapoints to batches round-robin
    batch_indices = jnp.arange(n) % nbatches
    accumulator = jnp.zeros((size, nbatches), dtype)
    accumulator = accumulator.at[indices, batch_indices].add(values)
    return accumulator.sum(axis=1)

2147 

2148 

def compute_likelihood_ratio(
    total_resid: Float32[Array, ''],
    left_resid: Float32[Array, ''],
    right_resid: Float32[Array, ''],
    prelkv: PreLkV,
    prelk: PreLk,
) -> Float32[Array, '']:
    """
    Compute the log likelihood ratio of a grow move.

    Parameters
    ----------
    total_resid
    left_resid
    right_resid
        The sum of the residuals (scaled by error precision scale) of the
        datapoints falling in the nodes involved in the moves.
    prelkv
    prelk
        The pre-computed terms of the likelihood ratio, see
        `precompute_likelihood_terms`.

    Returns
    -------
    The logarithm of the likelihood ratio
    P(data | new tree) / P(data | old tree).
    """
    # quadratic form of each node's residual sum, weighted by its
    # pre-computed marginal variance
    left_term = left_resid * left_resid / prelkv.sigma2_left
    right_term = right_resid * right_resid / prelkv.sigma2_right
    total_term = total_resid * total_resid / prelkv.sigma2_total
    return prelkv.sqrt_term + prelk.exp_factor * (
        left_term + right_term - total_term
    )

2181 

2182 

def accept_moves_final_stage(bart: State, moves: Moves) -> State:
    """
    Post-process the mcmc state after accepting/rejecting the moves.

    This function is separate from `accept_moves_sequential_stage` to signal it
    can work in parallel across trees.

    Parameters
    ----------
    bart
        A partially updated BART mcmc state.
    moves
        The proposed moves (see `propose_moves`) as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The fully updated BART mcmc state.
    """
    forest = bart.forest
    # tally accepted moves by kind, and propagate the accept/reject decision
    # into the leaf indices and decision trees
    updated_forest = replace(
        forest,
        grow_acc_count=jnp.sum(moves.acc & moves.grow),
        prune_acc_count=jnp.sum(moves.acc & ~moves.grow),
        leaf_indices=apply_moves_to_leaf_indices(forest.leaf_indices, moves),
        split_trees=apply_moves_to_split_trees(forest.split_trees, moves),
    )
    return replace(bart, forest=updated_forest)

2212 

2213 

@vmap_nodoc
def apply_moves_to_leaf_indices(
    leaf_indices: UInt[Array, 'num_trees n'], moves: Moves
) -> UInt[Array, 'num_trees n']:
    """
    Update the leaf indices to match the accepted move.

    Parameters
    ----------
    leaf_indices
        The index of the leaf each datapoint falls into, if the grow move was
        accepted.
    moves
        The proposed moves (see `propose_moves`), as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The updated leaf indices.
    """
    # Clearing the lowest bit maps a left-child index to itself and a
    # right-child index to its sibling, so comparing against moves.left
    # detects datapoints in either child of the move's node (assumes
    # right == left + 1 — consistent with how `moves.right` is used elsewhere;
    # TODO confirm against propose_moves).
    mask = ~jnp.array(1, leaf_indices.dtype)  # ...1111111110
    is_child = (leaf_indices & mask) == moves.left
    # if the move ends in a prune, datapoints in either child move up to the
    # parent node; otherwise the indices are left untouched
    return jnp.where(
        is_child & moves.to_prune,
        moves.node.astype(leaf_indices.dtype),
        leaf_indices,
    )

2241 

2242 

@vmap_nodoc
def apply_moves_to_split_trees(
    split_trees: UInt[Array, 'num_trees 2**(d-1)'], moves: Moves
) -> UInt[Array, 'num_trees 2**(d-1)']:
    """
    Update the split trees to match the accepted move.

    Parameters
    ----------
    split_trees
        The cutpoints of the decision nodes in the initial trees.
    moves
        The proposed moves (see `propose_moves`), as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The updated split trees.
    """
    assert moves.to_prune is not None
    # Both scatter writes use an out-of-bounds index (split_trees.size) to
    # mean "no write": JAX drops out-of-bounds scatter updates.
    # The order of the two writes matters: a rejected grow has grow=True and
    # to_prune=True, so the first write stores the proposed cutpoint at the
    # node and the second write erases it again (sets 0), leaving it a leaf.
    return (
        split_trees.at[
            jnp.where(
                moves.grow,
                moves.node,
                split_trees.size,
            )
        ]
        .set(moves.grow_split.astype(split_trees.dtype))
        .at[
            jnp.where(
                moves.to_prune,
                moves.node,
                split_trees.size,
            )
        ]
        .set(0)
    )

2281 

2282 

2283def step_sigma(key: Key[Array, ''], bart: State) -> State: 1ab

2284 """ 

2285 MCMC-update the error variance (factor). 

2286 

2287 Parameters 

2288 ---------- 

2289 key 

2290 A jax random key. 

2291 bart 

2292 A BART mcmc state. 

2293 

2294 Returns 

2295 ------- 

2296 The new BART mcmc state, with an updated `sigma2`. 

2297 """ 

2298 resid = bart.resid 1ab

2299 alpha = bart.sigma2_alpha + resid.size / 2 1ab

2300 if bart.prec_scale is None: 1ab

2301 scaled_resid = resid 1ab

2302 else: 

2303 scaled_resid = resid * bart.prec_scale 1ab

2304 norm2 = resid @ scaled_resid 1ab

2305 beta = bart.sigma2_beta + norm2 / 2 1ab

2306 

2307 sample = random.gamma(key, alpha) 1ab

2308 return replace(bart, sigma2=beta / sample) 1ab

2309 

2310 

2311def step_z(key: Key[Array, ''], bart: State) -> State: 1ab

2312 """ 

2313 MCMC-update the latent variable for binary regression. 

2314 

2315 Parameters 

2316 ---------- 

2317 key 

2318 A jax random key. 

2319 bart 

2320 A BART MCMC state. 

2321 

2322 Returns 

2323 ------- 

2324 The updated BART MCMC state. 

2325 """ 

2326 trees_plus_offset = bart.z - bart.resid 1ab

2327 lower = jnp.where(bart.y, -trees_plus_offset, -jnp.inf) 1ab

2328 upper = jnp.where(bart.y, jnp.inf, -trees_plus_offset) 1ab

2329 resid = random.truncated_normal(key, lower, upper) 1ab

2330 # TODO jax's implementation of truncated_normal is not good, it just does 

2331 # cdf inversion with erf and erf_inv. I can do better, at least avoiding to 

2332 # compute one of the boundaries, and maybe also flipping and using ndtr 

2333 # instead of erf for numerical stability (open an issue in jax?) 

2334 z = trees_plus_offset + resid 1ab

2335 return replace(bart, z=z, resid=resid) 1ab