Coverage for src/bartz/mcmcloop.py: 90%

173 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2025-07-31 16:09 +0000

1# bartz/src/bartz/mcmcloop.py 

2# 

3# Copyright (c) 2024-2025, Giacomo Petrillo 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25"""Functions that implement the full BART posterior MCMC loop. 

26 

27The entry points are `run_mcmc` and `make_default_callback`. 

28""" 

29 

30from collections.abc import Callable 1ab

31from dataclasses import fields, replace 1ab

32from functools import partial, wraps 1ab

33from typing import Any, Protocol 1ab

34 

35import jax 1ab

36import numpy 1ab

37from equinox import Module 1ab

38from jax import debug, lax, tree 1ab

39from jax import numpy as jnp 1ab

40from jax.nn import softmax 1ab

41from jaxtyping import Array, Bool, Float32, Int32, Integer, Key, PyTree, Shaped, UInt 1ab

42 

43from bartz import grove, jaxext, mcmcstep 1ab

44from bartz.mcmcstep import State 1ab

45 

46 

class BurninTrace(Module):
    """MCMC trace holding only diagnostic values, as recorded during burn-in.

    Each field has a ``*trace_length`` prefix, which is empty for the
    single-item trace built by `from_state`.
    """

    sigma2: Float32[Array, '*trace_length'] | None
    theta: Float32[Array, '*trace_length'] | None
    grow_prop_count: Int32[Array, '*trace_length']
    grow_acc_count: Int32[Array, '*trace_length']
    prune_prop_count: Int32[Array, '*trace_length']
    prune_acc_count: Int32[Array, '*trace_length']
    log_likelihood: Float32[Array, '*trace_length'] | None
    log_trans_prior: Float32[Array, '*trace_length'] | None

    @classmethod
    def from_state(cls, state: State) -> 'BurninTrace':
        """Build a single-item burn-in trace out of one MCMC state.

        ``sigma2`` is read from the state itself; every other field is read
        from the attribute of the same name of ``state.forest``.
        """
        forest = state.forest
        return cls(
            sigma2=state.sigma2,
            theta=forest.theta,
            grow_prop_count=forest.grow_prop_count,
            grow_acc_count=forest.grow_acc_count,
            prune_prop_count=forest.prune_prop_count,
            prune_acc_count=forest.prune_acc_count,
            log_likelihood=forest.log_likelihood,
            log_trans_prior=forest.log_trans_prior,
        )

72 

73 

class MainTrace(BurninTrace):
    """MCMC trace holding the trees together with the diagnostic values."""

    leaf_tree: Float32[Array, '*trace_length 2**d']
    var_tree: UInt[Array, '*trace_length 2**(d-1)']
    split_tree: UInt[Array, '*trace_length 2**(d-1)']
    offset: Float32[Array, '*trace_length']
    varprob: Float32[Array, '*trace_length p'] | None

    @classmethod
    def from_state(cls, state: State) -> 'MainTrace':
        """Build a single-item main trace out of one MCMC state.

        ``varprob`` is the softmax of ``state.forest.log_s``, evaluated only
        where ``state.forest.max_split`` is nonzero, or `None` when the state
        has no ``log_s``. The diagnostic fields are filled exactly as in
        `BurninTrace.from_state`.
        """
        forest = state.forest
        varprob = (
            None
            if forest.log_s is None
            else softmax(forest.log_s, where=forest.max_split.astype(bool))
        )
        # call BurninTrace.from_state explicitly rather than through super():
        # it must build a BurninTrace, which `cls(...)` with only the
        # diagnostic fields would not do for cls=MainTrace
        diagnostics = vars(BurninTrace.from_state(state))
        return cls(
            leaf_tree=forest.leaf_tree,
            var_tree=forest.var_tree,
            split_tree=forest.split_tree,
            offset=state.offset,
            varprob=varprob,
            **diagnostics,
        )

101 

102 

# Type alias for the user-defined state threaded through successive `Callback`
# invocations by `run_mcmc`; any pytree is accepted.
CallbackState = PyTree[Any, 'T']

104 

105 

class Callback(Protocol):
    """Callback type for `run_mcmc`."""

    def __call__(
        self,
        *,
        key: Key[Array, ''],
        bart: State,
        burnin: Bool[Array, ''],
        i_total: Int32[Array, ''],
        i_skip: Int32[Array, ''],
        callback_state: CallbackState,
        n_burn: Int32[Array, ''],
        n_save: Int32[Array, ''],
        n_skip: Int32[Array, ''],
        i_outer: Int32[Array, ''],
        inner_loop_length: int,
    ) -> tuple[State, CallbackState] | None:
        """Act on the MCMC after each of its iterations.

        The callback is traced under jit, so its arguments are abstract
        values; use `jax.debug` utilities to look at runtime values.

        Parameters
        ----------
        key
            A random key reserved for the callback.
        bart
            The MCMC state, already updated by the current iteration.
        burnin
            Whether the current iteration belongs to the burn-in phase.
        i_total
            The 0-based index of the current MCMC iteration.
        i_skip
            How many MCMC updates were done since the last saved state. The
            initial state counts as saved, even though it is not stored in
            the trace.
        callback_state
            The callback's own state: the value passed to `run_mcmc` on the
            first call, then whatever the previous call returned.
        n_burn
        n_save
        n_skip
            The arguments of the same name given to `run_mcmc`, unchanged.
        i_outer
            The 0-based index of the current outer loop iteration.
        inner_loop_length
            How many MCMC iterations each inner loop runs.

        Returns
        -------
        bart : State
            The MCMC state, possibly modified. Return the `bart` argument
            unchanged to leave the state alone.
        callback_state : CallbackState
            The state to hand to the next invocation of the callback.

        Notes
        -----
        Returning `None` is a shorthand for leaving both states unchanged.
        """
        ...

166 

167 

class _Carry(Module):
    """Carry used in the loop in `run_mcmc`.

    The field order is part of the contract: `run_mcmc` constructs this class
    positionally.
    """

    bart: State  # current MCMC state
    # index of the next iteration to run, i.e., number of iterations done;
    # starts at 0 in `run_mcmc` and is incremented by each real loop step
    i_total: Int32[Array, '']
    key: Key[Array, '']  # random key, split at every real loop step
    burnin_trace: PyTree[Shaped[Array, 'n_burn *']]  # preallocated burn-in trace
    main_trace: PyTree[Shaped[Array, 'n_save *']]  # preallocated main trace
    callback_state: CallbackState  # passed to, and replaced by, the callback

177 

178 

def run_mcmc(
    key: Key[Array, ''],
    bart: State,
    n_save: int,
    *,
    n_burn: int = 0,
    n_skip: int = 1,
    inner_loop_length: int | None = None,
    callback: Callback | None = None,
    callback_state: CallbackState = None,
    burnin_extractor: Callable[[State], PyTree] = BurninTrace.from_state,
    main_extractor: Callable[[State], PyTree] = MainTrace.from_state,
) -> tuple[State, PyTree[Shaped[Array, 'n_burn *']], PyTree[Shaped[Array, 'n_save *']]]:
    """
    Run the MCMC for the BART posterior.

    Parameters
    ----------
    key
        A key for random number generation.
    bart
        The initial MCMC state, as created and updated by the functions in
        `bartz.mcmcstep`. The MCMC loop uses buffer donation to avoid copies,
        so this variable is invalidated after running `run_mcmc`. Make a copy
        beforehand to use it again.
    n_save
        The number of iterations to save.
    n_burn
        The number of initial iterations which are not saved.
    n_skip
        The number of iterations to skip between each saved iteration, plus 1.
        Must be at least 1. The effective burn-in is ``n_burn + n_skip - 1``.
    inner_loop_length
        The MCMC loop is split into an outer and an inner loop. The outer loop
        is in Python, while the inner loop is in JAX. `inner_loop_length` is the
        number of iterations of the inner loop to run for each iteration of the
        outer loop. If not specified, the outer loop will iterate just once,
        with all iterations done in a single inner loop run. The inner stride is
        unrelated to the stride used for saving the trace.
    callback
        An arbitrary function run during the loop after updating the state. For
        the signature, see `Callback`. The callback is called under the jax jit,
        so the argument values are not available at the time the Python code is
        executed. Use the utilities in `jax.debug` to access the values at
        actual runtime. The callback may return new values for the MCMC state
        and the callback state.
    callback_state
        The initial custom state for the callback.
    burnin_extractor
    main_extractor
        Functions that extract the variables to be saved respectively in the
        burn-in trace and in the main trace, given the MCMC state as argument.
        Must return a pytree, and must be vmappable.

    Returns
    -------
    bart : State
        The final MCMC state.
    burnin_trace : PyTree[Shaped[Array, 'n_burn *']]
        The trace of the burn-in phase. For the default layout, see `BurninTrace`.
    main_trace : PyTree[Shaped[Array, 'n_save *']]
        The trace of the main phase. For the default layout, see `MainTrace`.

    Raises
    ------
    ValueError
        If `n_save` or `n_burn` is negative, if `n_skip` is less than 1, or if
        `inner_loop_length` is negative, or zero while there are iterations to
        run (which would silently skip the whole MCMC).

    Notes
    -----
    The number of MCMC updates is ``n_burn + n_skip * n_save``. The traces do
    not include the initial state, and include the final state.
    """
    # validate before allocating anything: invalid values would otherwise
    # silently produce traces that were never written to
    if n_save < 0 or n_burn < 0:
        msg = f'n_save and n_burn must be non-negative, got {n_save=}, {n_burn=}'
        raise ValueError(msg)
    if n_skip < 1:
        msg = f'n_skip must be at least 1, got {n_skip=}'
        raise ValueError(msg)

    # determine number of iterations for inner and outer loops
    n_iters = n_burn + n_skip * n_save
    if inner_loop_length is None:
        inner_loop_length = n_iters
    elif inner_loop_length < 0 or (inner_loop_length == 0 and n_iters > 0):
        msg = (
            'inner_loop_length must be positive when there are iterations to '
            f'run, got {inner_loop_length=} with {n_iters} iterations'
        )
        raise ValueError(msg)
    if inner_loop_length:
        # ceil(n_iters / inner_loop_length); steps of the last inner loop past
        # n_iters are no-ops
        n_outer = n_iters // inner_loop_length + bool(n_iters % inner_loop_length)
    else:
        # only reached when n_iters == 0
        n_outer = 1
        # setting to 0 would make for a clean noop, but it's useful to keep the
        # same code path for benchmarking and testing

    def empty_trace(length, bart, extractor):
        # broadcast one extracted item to `length` rows, which fixes the
        # structure, shapes and dtypes of the preallocated trace
        return jax.vmap(extractor, in_axes=None, out_axes=0, axis_size=length)(bart)

    burnin_trace = empty_trace(n_burn, bart, burnin_extractor)
    main_trace = empty_trace(n_save, bart, main_extractor)

    carry = _Carry(bart, jnp.int32(0), key, burnin_trace, main_trace, callback_state)
    for i_outer in range(n_outer):
        carry = _run_mcmc_inner_loop(
            carry,
            inner_loop_length,
            callback,
            burnin_extractor,
            main_extractor,
            n_burn,
            n_save,
            n_skip,
            i_outer,
            n_iters,
        )

    return carry.bart, carry.burnin_trace, carry.main_trace

281 

282 

def _compute_i_skip(
    i_total: Int32[Array, ''], n_burn: Int32[Array, ''], n_skip: Int32[Array, '']
) -> Int32[Array, '']:
    """Compute the `i_skip` argument passed to `callback`.

    `i_skip` counts the MCMC updates done since the last state kept in the
    main trace, with the initial state treated as kept. Until the first
    main-phase state is kept, the count therefore also spans the burn-in.
    """
    in_burnin = i_total < n_burn
    # 1-based position of this iteration within the main phase (<= 0 during
    # burn-in)
    n_main_done = i_total - n_burn + 1
    since_last_save = n_main_done % n_skip
    before_first_save = n_main_done < n_skip
    after_burnin = jnp.where(
        before_first_save, since_last_save + n_burn, since_last_save
    )
    return jnp.where(in_burnin, i_total + 1, after_burnin)

294 

295 

@partial(jax.jit, donate_argnums=(0,), static_argnums=(1, 2, 3, 4))
def _run_mcmc_inner_loop(
    carry: _Carry,
    inner_loop_length: int,
    callback: Callback | None,
    burnin_extractor: Callable[[State], PyTree],
    main_extractor: Callable[[State], PyTree],
    n_burn: Int32[Array, ''],
    n_save: Int32[Array, ''],
    n_skip: Int32[Array, ''],
    i_outer: Int32[Array, ''],
    n_iters: Int32[Array, ''],
) -> _Carry:
    """Run up to `inner_loop_length` MCMC iterations in a single jitted scan.

    Parameters
    ----------
    carry
        The loop state. It is donated to the computation (``donate_argnums``),
        so the buffers passed in must not be used afterwards.
    inner_loop_length
        The number of scan steps. Steps taken when ``carry.i_total >= n_iters``
        do nothing, so the last outer iteration may be longer than needed.
    callback, burnin_extractor, main_extractor
        See `run_mcmc`. Like `inner_loop_length`, these are static arguments
        of the jit, so they must be hashable.
    n_burn, n_save, n_skip, i_outer, n_iters
        Traced integers, see `run_mcmc` and `Callback`.

    Returns
    -------
    The updated carry.
    """

    def loop_impl(carry: _Carry) -> _Carry:
        """Loop body to run if i_total < n_iters."""
        # split random key; the subkeys are consumed by the `keys.pop()` calls
        # below (new carry key, MCMC step, callback if any), so the order of
        # the pops fixes which random stream goes where
        keys = jaxext.split(carry.key, 3)
        carry = replace(carry, key=keys.pop())

        # update state
        carry = replace(carry, bart=mcmcstep.step(keys.pop(), carry.bart))

        # whether the iteration just done belongs to the burn-in phase
        burnin = carry.i_total < n_burn

        # invoke callback
        if callback is not None:
            i_skip = _compute_i_skip(carry.i_total, n_burn, n_skip)
            rt = callback(
                key=keys.pop(),
                bart=carry.bart,
                burnin=burnin,
                i_total=carry.i_total,
                i_skip=i_skip,
                callback_state=carry.callback_state,
                n_burn=n_burn,
                n_save=n_save,
                n_skip=n_skip,
                i_outer=i_outer,
                inner_loop_length=inner_loop_length,
            )
            # a `None` return means "leave both states unchanged"
            if rt is not None:
                bart, callback_state = rt
                carry = replace(carry, bart=bart, callback_state=callback_state)

        def save_to_burnin_trace() -> tuple[PyTree, PyTree]:
            # burn-in iteration `i_total` is stored at row `i_total`
            return _pytree_at_set(
                carry.burnin_trace, carry.i_total, burnin_extractor(carry.bart)
            ), carry.main_trace

        def save_to_main_trace() -> tuple[PyTree, PyTree]:
            # every main-phase iteration writes to row
            # (i_total - n_burn) // n_skip, so each row ends up holding the
            # last state of its group of n_skip iterations
            idx = (carry.i_total - n_burn) // n_skip
            return carry.burnin_trace, _pytree_at_set(
                carry.main_trace, idx, main_extractor(carry.bart)
            )

        # save state to trace
        burnin_trace, main_trace = lax.cond(
            burnin, save_to_burnin_trace, save_to_main_trace
        )
        return replace(
            carry,
            i_total=carry.i_total + 1,
            burnin_trace=burnin_trace,
            main_trace=main_trace,
        )

    def loop_noop(carry: _Carry) -> _Carry:
        """Loop body to run if i_total >= n_iters; it does nothing."""
        return carry

    def loop(carry: _Carry, _) -> tuple[_Carry, None]:
        # do real work only while iterations remain
        carry = lax.cond(carry.i_total < n_iters, loop_impl, loop_noop, carry)
        return carry, None

    carry, _ = lax.scan(loop, carry, None, inner_loop_length)
    return carry

372 

373 

def _pytree_at_set(
    dest: PyTree[Array, ' T'], index: Int32[Array, ''], val: PyTree[Array]
) -> PyTree[Array, ' T']:
    """Write ``val`` at row ``index`` of ``dest``, leaf by leaf.

    This is ``dest.at[index].set(val)`` mapped over matching pytrees. Leaves
    of ``dest`` with no elements are returned as they are, because jax
    refuses to index into an array of length 0, even if just in the abstract.
    """

    def set_leaf(buffer, item):
        if not buffer.size:
            return buffer
        return buffer.at[index, ...].set(item)

    return tree.map(set_leaf, dest, val)

388 

389 

def _default_callback(*, bart, callback_state, **kwargs):
    """Callback built by `make_default_callback`: variable selection, then printing.

    It is defined at module level rather than as a closure, so that the same
    function object is handed to `run_mcmc` on every call. The callback is a
    static argument of the jitted inner loop, so a fresh closure per call
    would force a recompilation on every `run_mcmc` run.
    """
    print_state, sparse_state = callback_state
    bart, _ = sparse_callback(callback_state=sparse_state, bart=bart, **kwargs)
    print_callback(callback_state=print_state, bart=bart, **kwargs)
    # the sub-callbacks don't update their states, so the combined state is
    # returned as-is
    return bart, callback_state


def make_default_callback(
    *,
    dot_every: int | Integer[Array, ''] | None = 1,
    report_every: int | Integer[Array, ''] | None = 100,
    sparse_on_at: int | Integer[Array, ''] | None = None,
) -> dict[str, Any]:
    """
    Prepare a default callback for `run_mcmc`.

    The callback prints a dot every `dot_every` MCMC iterations and a one-line
    report every `report_every` iterations, and can do variable selection.

    Parameters
    ----------
    dot_every
        A dot is printed every `dot_every` MCMC iterations, `None` to disable.
    report_every
        A one line report is printed every `report_every` MCMC iterations,
        `None` to disable.
    sparse_on_at
        If specified, variable selection is activated starting from this
        iteration. If `None`, variable selection is not used.

    Returns
    -------
    A dictionary with the arguments to pass to `run_mcmc` as keyword arguments to set up the callback.

    Examples
    --------
    >>> run_mcmc(..., **make_default_callback())
    """

    def asarray_or_none(val: None | Any) -> None | Array:
        # arrays are traced by the jit, so changing these values does not
        # trigger a recompilation; None stays None and disables the feature
        return None if val is None else jnp.asarray(val)

    return dict(
        callback=_default_callback,
        callback_state=(
            PrintCallbackState(
                asarray_or_none(dot_every), asarray_or_none(report_every)
            ),
            SparseCallbackState(asarray_or_none(sparse_on_at)),
        ),
    )

441 

442 

class PrintCallbackState(Module):
    """State for `print_callback`.

    Parameters
    ----------
    dot_every
        A dot is printed every `dot_every` MCMC iterations, `None` to disable.
    report_every
        A one line report is printed every `report_every` MCMC iterations,
        `None` to disable.
    """

    # field order matters: `make_default_callback` builds this positionally
    dot_every: Int32[Array, ''] | None
    report_every: Int32[Array, ''] | None

457 

458 

def print_callback(
    *,
    bart: State,
    burnin: Bool[Array, ''],
    i_total: Int32[Array, ''],
    n_burn: Int32[Array, ''],
    n_save: Int32[Array, ''],
    n_skip: Int32[Array, ''],
    callback_state: PrintCallbackState,
    **_,
):
    """Print a dot and/or a report periodically during the MCMC.

    Parameters
    ----------
    bart
        The MCMC state after the current iteration.
    burnin
        Whether the current iteration is in the burn-in phase.
    i_total
        The 0-based index of the current iteration.
    n_burn, n_save, n_skip
        The `run_mcmc` arguments, used to compute the total iteration count.
    callback_state
        How often dots and reports are printed.
    **_
        Other `Callback` arguments, ignored.

    Notes
    -----
    Printing goes through `jax.debug.callback`, so it happens at runtime.
    Before a report, a newline is printed if any dot was printed since the
    previous report, so the report always starts on a fresh line.
    """
    n_done = i_total + 1  # number of completed iterations
    dot_every = callback_state.dot_every
    report_every = callback_state.report_every

    if dot_every is not None:
        lax.cond(
            n_done % dot_every == 0,
            lambda: debug.callback(
                lambda: print('.', end='', flush=True),  # noqa: T201
                ordered=True,
            ),
            # logging can't do in-line printing so I'll stick to print
            lambda: None,
        )

    if report_every is None:
        return

    def print_report():
        debug.callback(
            _print_report,
            burnin=burnin,
            i_total=i_total,
            n_iters=n_burn + n_save * n_skip,
            grow_prop_count=bart.forest.grow_prop_count,
            grow_acc_count=bart.forest.grow_acc_count,
            prune_acc_count=bart.forest.prune_acc_count,
            prop_total=len(bart.forest.leaf_tree),
            fill=grove.forest_fill(bart.forest.split_tree),
        )

    report_cond = n_done % report_every == 0

    # print a newline after dots
    if dot_every is not None:
        # A dot was printed since the previous report iff some multiple of
        # dot_every lies in (n_done - report_every, n_done]. Checking only the
        # current iteration would miss this whenever dot_every does not divide
        # report_every, gluing the report to the line of dots. When
        # report_cond holds, n_done >= report_every, so both floor divisions
        # act on non-negative values.
        dots_pending = n_done // dot_every > (n_done - report_every) // dot_every
        lax.cond(
            report_cond & dots_pending,
            lambda: debug.callback(lambda: print(), ordered=True),  # noqa: T201
            lambda: None,
        )

    lax.cond(report_cond, print_report, lambda: None)

509 

510 

def _convert_jax_arrays_in_args(func: Callable) -> Callable:
    """Wrap `func` so that it never receives jax arrays.

    Every `jax.Array` found among the positional and keyword arguments
    (looking inside pytrees) is replaced by a numpy array if it has at least
    one dimension, or by a Python scalar if it is 0-dimensional. Any other
    value is passed through unchanged.
    """

    def to_host(leaf: Any) -> Any:
        if isinstance(leaf, jax.Array):
            return numpy.array(leaf) if leaf.shape else leaf.item()
        return leaf

    @wraps(func)
    def new_func(*args, **kw):
        host_args, host_kw = tree.map(to_host, (args, kw))
        return func(*host_args, **host_kw)

    return new_func

536 

537 

@_convert_jax_arrays_in_args
# convert all jax arrays in arguments because operations on them could lead to
# deadlock with the main thread
def _print_report(
    *,
    burnin: bool,
    i_total: int,
    n_iters: int,
    grow_prop_count: int,
    grow_acc_count: int,
    prune_acc_count: int,
    prop_total: int,
    fill: float,
):
    """Print the one-line progress report of `print_callback`.

    The grow proposal count and the total acceptance count (grow plus prune)
    are shown as fractions of ``prop_total``.
    """
    parts = (
        f'Iteration {i_total + 1}/{n_iters}',
        f'grow prob: {grow_prop_count / prop_total:.0%}',
        f'move acc: {(grow_acc_count + prune_acc_count) / prop_total:.0%}',
        f'fill: {fill:.0%}',
    )
    suffix = ' (burnin)' if burnin else ''
    print(', '.join(parts) + suffix)  # noqa: T201, see print_callback for why not logging

564 

565 

class SparseCallbackState(Module):
    """State for `sparse_callback`.

    Parameters
    ----------
    sparse_on_at
        If specified, variable selection is activated starting from this
        iteration. If `None`, variable selection is not used.
    """

    # set positionally by `make_default_callback`; compared against the 0-based
    # iteration index `i_total` in `sparse_callback`
    sparse_on_at: Int32[Array, ''] | None

577 

578 

def sparse_callback(
    *,
    key: Key[Array, ''],
    bart: State,
    i_total: Int32[Array, ''],
    callback_state: SparseCallbackState,
    **_,
):
    """Perform variable selection, see `mcmcstep.step_sparse`.

    From iteration ``callback_state.sparse_on_at`` onwards, the state goes
    through `mcmcstep.step_sparse`. Before that iteration, or always if
    ``sparse_on_at`` is `None`, the state is returned as-is.

    Returns
    -------
    bart : State
        The possibly updated MCMC state.
    callback_state : SparseCallbackState
        The callback state, unchanged.
    """
    sparse_on_at = callback_state.sparse_on_at
    if sparse_on_at is None:
        return bart, callback_state

    def keep():
        return bart

    def select():
        return mcmcstep.step_sparse(key, bart)

    updated = lax.cond(i_total < sparse_on_at, keep, select)
    return updated, callback_state

595 

596 

class Trace(grove.TreeHeaps, Protocol):
    """Protocol for a MCMC trace.

    A `bartz.grove.TreeHeaps` plus a per-iteration ``offset``. The tree
    arrays are expected to carry a leading ``trace_length`` axis, as in
    `TreesTrace` (assumption based on its annotations; not enforced here).
    """

    # one value per iteration; `evaluate_trace` adds it to the sum of the
    # tree predictions
    offset: Float32[Array, ' trace_length']

601 

602 

class TreesTrace(Module):
    """Implementation of `bartz.grove.TreeHeaps` for an MCMC trace.

    Holds only the three tree heaps, each indexed first by iteration and
    then by tree.
    """

    leaf_tree: Float32[Array, 'trace_length num_trees 2**d']
    var_tree: UInt[Array, 'trace_length num_trees 2**(d-1)']
    split_tree: UInt[Array, 'trace_length num_trees 2**(d-1)']

    @classmethod
    def from_dataclass(cls, obj: grove.TreeHeaps):
        """Create a `TreesTrace` from any `bartz.grove.TreeHeaps`.

        Only the attributes named after the fields of this class are copied
        from ``obj``; anything else it carries is left behind.
        """
        kwargs = {}
        for field in fields(cls):
            kwargs[field.name] = getattr(obj, field.name)
        return cls(**kwargs)

614 

615 

@jax.jit
def evaluate_trace(
    trace: Trace, X: UInt[Array, 'p n']
) -> Float32[Array, 'trace_length n']:
    """
    Compute predictions for all iterations of the BART MCMC.

    Parameters
    ----------
    trace
        A trace of the BART MCMC, as returned by `run_mcmc`.
    X
        The predictors matrix, with `p` predictors and `n` observations.

    Returns
    -------
    The predictions for each iteration of the MCMC.
    """
    # per-tree predictions (no sum over trees), batched by `jaxext.autobatch`
    # along the tree axis while X is shared; 2**29 is presumably the memory
    # budget of each batch -- see `jaxext.autobatch`
    per_tree = jaxext.autobatch(
        partial(grove.evaluate_forest, sum_trees=False), 2**29, (None, 0)
    )

    def predict_iteration(carry, xs):
        iteration_offset, iteration_trees = xs
        contributions = per_tree(X, iteration_trees)
        total = jnp.sum(contributions, axis=0, dtype=jnp.float32)
        return carry, iteration_offset + total

    # scan sequentially over iterations, one (offset, trees) pair at a time
    per_iteration = (trace.offset, TreesTrace.from_dataclass(trace))
    _, predictions = lax.scan(predict_iteration, None, per_iteration)
    return predictions

645 

646 

@partial(jax.jit, static_argnums=(0,))
def compute_varcount(
    p: int, trace: grove.TreeHeaps
) -> Int32[Array, 'trace_length {p}']:
    """
    Count how many times each predictor is used in each MCMC state.

    Parameters
    ----------
    p
        The number of predictors.
    trace
        A trace of the BART MCMC, as returned by `run_mcmc`.

    Returns
    -------
    Histogram of predictor usage in each MCMC state.
    """
    # bind the (static) number of predictors once, then map the histogram over
    # the leading axis of both tree arrays
    histogram_of_state = jax.vmap(partial(grove.var_histogram, p))
    return histogram_of_state(trace.var_tree, trace.split_tree)