Coverage for src/bartz/mcmcloop.py: 86%

179 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-07 22:47 +0000

1# bartz/src/bartz/mcmcloop.py 

2# 

3# Copyright (c) 2024-2025, Giacomo Petrillo 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25"""Functions that implement the full BART posterior MCMC loop. 

26 

27The entry points are `run_mcmc` and `make_default_callback`. 

28""" 

29 

30from collections.abc import Callable 1ab

31from dataclasses import fields, replace 1ab

32from functools import partial, wraps 1ab

33from typing import Any, Protocol 1ab

34 

35import jax 1ab

36import numpy 1ab

37from equinox import Module 1ab

38from jax import debug, lax, tree 1ab

39from jax import numpy as jnp 1ab

40from jax.nn import softmax 1ab

41from jaxtyping import Array, Bool, Float32, Int32, Integer, Key, PyTree, Shaped, UInt 1ab

42 

43from bartz import grove, jaxext, mcmcstep 1ab

44from bartz.mcmcstep import State 1ab

45 

46 

class BurninTrace(Module):
    """MCMC trace holding only diagnostic quantities (no trees).

    A single-item trace (from `from_state`) holds the values of one MCMC
    state; `run_mcmc` stacks such items along a leading trace axis. Fields
    annotated ``| None`` hold whatever the state holds, which may be `None`.
    """

    sigma2: Float32[Array, '*trace_length'] | None
    theta: Float32[Array, '*trace_length'] | None
    grow_prop_count: Int32[Array, '*trace_length']
    grow_acc_count: Int32[Array, '*trace_length']
    prune_prop_count: Int32[Array, '*trace_length']
    prune_acc_count: Int32[Array, '*trace_length']
    log_likelihood: Float32[Array, '*trace_length'] | None
    log_trans_prior: Float32[Array, '*trace_length'] | None

    @classmethod
    def from_state(cls, state: State) -> 'BurninTrace':
        """Build a single-item burn-in trace from an MCMC state.

        Parameters
        ----------
        state
            The MCMC state to read the diagnostics from.

        Returns
        -------
        A trace whose `sigma2` comes from `state` and whose other fields are
        copied from ``state.forest``.
        """
        forest = state.forest
        return cls(
            sigma2=state.sigma2,
            theta=forest.theta,
            grow_prop_count=forest.grow_prop_count,
            grow_acc_count=forest.grow_acc_count,
            prune_prop_count=forest.prune_prop_count,
            prune_acc_count=forest.prune_acc_count,
            log_likelihood=forest.log_likelihood,
            log_trans_prior=forest.log_trans_prior,
        )

72 

73 

class MainTrace(BurninTrace):
    """MCMC trace holding the trees in addition to the diagnostic values."""

    leaf_tree: Float32[Array, '*trace_length 2**d']
    var_tree: UInt[Array, '*trace_length 2**(d-1)']
    split_tree: UInt[Array, '*trace_length 2**(d-1)']
    offset: Float32[Array, '*trace_length']
    varprob: Float32[Array, '*trace_length p'] | None

    @classmethod
    def from_state(cls, state: State) -> 'MainTrace':
        """Build a single-item main trace from an MCMC state.

        The diagnostic fields are taken from `BurninTrace.from_state`. The
        `varprob` field is the softmax of ``state.forest.log_s`` normalized
        over the predictors where ``state.forest.max_split`` is nonzero, or
        `None` when ``log_s`` is `None`.
        """
        forest = state.forest
        varprob = (
            None
            if forest.log_s is None
            else softmax(forest.log_s, where=forest.max_split.astype(bool))
        )
        diagnostics = vars(BurninTrace.from_state(state))
        return cls(
            leaf_tree=forest.leaf_tree,
            var_tree=forest.var_tree,
            split_tree=forest.split_tree,
            offset=state.offset,
            varprob=varprob,
            **diagnostics,
        )

101 

102 

# Type alias for the arbitrary user pytree threaded through successive
# invocations of the `run_mcmc` callback (see `Callback`).
CallbackState = PyTree[Any, 'T']

104 

105 

class Callback(Protocol):
    """Callback type for `run_mcmc`.

    The callback is called with keyword arguments only, once per MCMC
    iteration, inside the jitted inner loop of `run_mcmc`; see `__call__`
    for the arguments.
    """

    def __call__(
        self,
        *,
        key: Key[Array, ''],
        bart: State,
        burnin: Bool[Array, ''],
        i_total: Int32[Array, ''],
        i_skip: Int32[Array, ''],
        callback_state: CallbackState,
        n_burn: Int32[Array, ''],
        n_save: Int32[Array, ''],
        n_skip: Int32[Array, ''],
        i_outer: Int32[Array, ''],
        inner_loop_length: int,
    ) -> tuple[State, CallbackState] | None:
        """Do an arbitrary action after an iteration of the MCMC.

        Parameters
        ----------
        key
            A key for random number generation.
        bart
            The MCMC state just after updating it.
        burnin
            Whether the last iteration was in the burn-in phase.
        i_total
            The index of the last MCMC iteration (0-based).
        i_skip
            The number of MCMC updates from the last saved state. The initial
            state counts as saved, even if it's not copied into the trace.
        callback_state
            The callback state, initially set to the argument passed to
            `run_mcmc`, afterwards to the value returned by the last invocation
            of the callback.
        n_burn
        n_save
        n_skip
            The corresponding `run_mcmc` arguments as-is.
        i_outer
            The index of the last outer loop iteration (0-based).
        inner_loop_length
            The number of MCMC iterations in the inner loop.

        Returns
        -------
        bart : State
            A possibly modified MCMC state. To avoid modifying the state,
            return the `bart` argument passed to the callback as-is.
        callback_state : CallbackState
            The new state to be passed on the next callback invocation.

        Notes
        -----
        For convenience, the callback may return `None`, and the states won't
        be updated.
        """
        ...

166 

167 

class _Carry(Module):
    """Carry used in the loop in `run_mcmc`.

    `run_mcmc` constructs it with positional arguments, so the field order
    below is significant.
    """

    bart: State  # current MCMC state
    i_total: Int32[Array, '']  # number of MCMC iterations completed so far
    key: Key[Array, '']  # random key, re-split at every iteration
    burnin_trace: PyTree[Shaped[Array, 'n_burn *']]  # preallocated burn-in trace
    main_trace: PyTree[Shaped[Array, 'n_save *']]  # preallocated main trace
    callback_state: CallbackState  # user state threaded through the callback

177 

178 

def run_mcmc(
    key: Key[Array, ''],
    bart: State,
    n_save: int,
    *,
    n_burn: int = 0,
    n_skip: int = 1,
    inner_loop_length: int | None = None,
    callback: Callback | None = None,
    callback_state: CallbackState = None,
    burnin_extractor: Callable[[State], PyTree] = BurninTrace.from_state,
    main_extractor: Callable[[State], PyTree] = MainTrace.from_state,
) -> tuple[State, PyTree[Shaped[Array, 'n_burn *']], PyTree[Shaped[Array, 'n_save *']]]:
    """
    Run the MCMC for the BART posterior.

    Parameters
    ----------
    key
        A key for random number generation.
    bart
        The initial MCMC state, as created and updated by the functions in
        `bartz.mcmcstep`. The MCMC loop uses buffer donation to avoid copies,
        so this variable is invalidated after running `run_mcmc`. Make a copy
        beforehand to use it again.
    n_save
        The number of iterations to save.
    n_burn
        The number of initial iterations which are not saved.
    n_skip
        The number of iterations to skip between each saved iteration, plus 1.
        The effective burn-in is ``n_burn + n_skip - 1``.
    inner_loop_length
        The MCMC loop is split into an outer and an inner loop. The outer loop
        is in Python, while the inner loop is in JAX. `inner_loop_length` is the
        number of iterations of the inner loop to run for each iteration of the
        outer loop. If not specified, the outer loop will iterate just once,
        with all iterations done in a single inner loop run. The inner stride is
        unrelated to the stride used for saving the trace.
    callback
        An arbitrary function run during the loop after updating the state. For
        the signature, see `Callback`. The callback is called under the jax jit,
        so the argument values are not available at the time the Python code is
        executed. Use the utilities in `jax.debug` to access the values at
        actual runtime. The callback may return new values for the MCMC state
        and the callback state.
    callback_state
        The initial custom state for the callback.
    burnin_extractor
    main_extractor
        Functions that extract the variables to be saved respectively in the
        burn-in trace and in the main trace, given the MCMC state as argument.
        Must return a pytree, and must be vmappable.

    Returns
    -------
    bart : State
        The final MCMC state.
    burnin_trace : PyTree[Shaped[Array, 'n_burn *']]
        The trace of the burn-in phase. For the default layout, see `BurninTrace`.
    main_trace : PyTree[Shaped[Array, 'n_save *']]
        The trace of the main phase. For the default layout, see `MainTrace`.

    Notes
    -----
    The number of MCMC updates is ``n_burn + n_skip * n_save``. The traces do
    not include the initial state, and include the final state.
    """

    def empty_trace(length, bart, extractor):
        # Broadcast one extracted item to `length` copies. The contents are
        # placeholders that the inner loop overwrites slot by slot.
        return jax.vmap(extractor, in_axes=None, out_axes=0, axis_size=length)(bart)

    burnin_trace = empty_trace(n_burn, bart, burnin_extractor)
    main_trace = empty_trace(n_save, bart, main_extractor)

    # determine number of iterations for inner and outer loops
    n_iters = n_burn + n_skip * n_save
    if inner_loop_length is None:
        inner_loop_length = n_iters
    if inner_loop_length:
        # ceiling division: the last inner run may overshoot `n_iters`, the
        # extra iterations are no-ops inside `_run_mcmc_inner_loop`
        n_outer = n_iters // inner_loop_length + bool(n_iters % inner_loop_length)
    else:
        n_outer = 1
        # setting to 0 would make for a clean noop, but it's useful to keep the
        # same code path for benchmarking and testing

    carry = _Carry(bart, jnp.int32(0), key, burnin_trace, main_trace, callback_state)
    for i_outer in range(n_outer):
        carry = _run_mcmc_inner_loop(
            carry,
            inner_loop_length,
            callback,
            burnin_extractor,
            main_extractor,
            n_burn,
            n_save,
            n_skip,
            i_outer,
            n_iters,
        )

    return carry.bart, carry.burnin_trace, carry.main_trace

281 

282 

def _compute_i_skip(
    i_total: Int32[Array, ''], n_burn: Int32[Array, ''], n_skip: Int32[Array, '']
) -> Int32[Array, '']:
    """Compute the `i_skip` argument passed to `callback`.

    During burn-in the count runs from the initial state, i.e. it is
    ``i_total + 1``. In the main phase it wraps around every `n_skip`
    iterations. Before the first main-phase wrap it still includes the
    `n_burn` burn-in iterations.
    """
    n_since_burnin = i_total - n_burn + 1
    main_phase_value = n_since_burnin % n_skip + jnp.where(
        n_since_burnin < n_skip, n_burn, 0
    )
    in_burnin = i_total < n_burn
    return jnp.where(in_burnin, i_total + 1, main_phase_value)

294 

295 

@partial(jax.jit, donate_argnums=(0,), static_argnums=(1, 2, 3, 4))
def _run_mcmc_inner_loop(
    carry: _Carry,
    inner_loop_length: int,
    callback: Callback | None,
    burnin_extractor: Callable[[State], PyTree],
    main_extractor: Callable[[State], PyTree],
    n_burn: Int32[Array, ''],
    n_save: Int32[Array, ''],
    n_skip: Int32[Array, ''],
    i_outer: Int32[Array, ''],
    n_iters: Int32[Array, ''],
):
    """Run `inner_loop_length` scan steps of the MCMC under jit.

    The carry is donated (``donate_argnums=(0,)``), so the caller's buffers
    are invalidated. The loop length, callback and extractors are static
    arguments, so a new value for any of them triggers a recompilation. The
    iteration counts and `i_outer` are traced. Steps with
    ``i_total >= n_iters`` do nothing, so the last outer-loop call may run
    past the total number of iterations.

    Parameters
    ----------
    carry
        The loop state, see `_Carry`.
    inner_loop_length
        Number of scan steps to run.
    callback
        Optional callback, see `Callback`.
    burnin_extractor
    main_extractor
        Functions extracting the pytree to save in the burn-in and main trace.
    n_burn
    n_save
    n_skip
        As in `run_mcmc`.
    i_outer
        Index of the current outer-loop iteration, forwarded to the callback.
    n_iters
        Total number of MCMC iterations, ``n_burn + n_skip * n_save``.

    Returns
    -------
    The updated carry.
    """

    def loop_impl(carry: _Carry) -> _Carry:
        """Loop body to run if i_total < n_iters."""
        # split random key: one key is kept for the next iteration, one feeds
        # the MCMC step, one the callback (assumes `pop` yields distinct keys)
        keys = jaxext.split(carry.key, 3)
        carry = replace(carry, key=keys.pop())

        # update state
        carry = replace(carry, bart=mcmcstep.step(keys.pop(), carry.bart))

        # i_total still indexes the iteration just run (incremented below)
        burnin = carry.i_total < n_burn

        # invoke callback
        if callback is not None:
            i_skip = _compute_i_skip(carry.i_total, n_burn, n_skip)
            rt = callback(
                key=keys.pop(),
                bart=carry.bart,
                burnin=burnin,
                i_total=carry.i_total,
                i_skip=i_skip,
                callback_state=carry.callback_state,
                n_burn=n_burn,
                n_save=n_save,
                n_skip=n_skip,
                i_outer=i_outer,
                inner_loop_length=inner_loop_length,
            )
            # a callback may return None to leave both states unchanged
            if rt is not None:
                bart, callback_state = rt
                carry = replace(carry, bart=bart, callback_state=callback_state)

        def save_to_burnin_trace() -> tuple[PyTree, PyTree]:
            return _pytree_at_set(
                carry.burnin_trace, carry.i_total, burnin_extractor(carry.bart)
            ), carry.main_trace

        def save_to_main_trace() -> tuple[PyTree, PyTree]:
            # every iteration of a group of `n_skip` writes to the same slot,
            # so the last iteration of the group is the one that remains
            idx = (carry.i_total - n_burn) // n_skip
            return carry.burnin_trace, _pytree_at_set(
                carry.main_trace, idx, main_extractor(carry.bart)
            )

        # save state to trace
        burnin_trace, main_trace = lax.cond(
            burnin, save_to_burnin_trace, save_to_main_trace
        )
        return replace(
            carry,
            i_total=carry.i_total + 1,
            burnin_trace=burnin_trace,
            main_trace=main_trace,
        )

    def loop_noop(carry: _Carry) -> _Carry:
        """Loop body to run if i_total >= n_iters; it does nothing."""
        return carry

    def loop(carry: _Carry, _) -> tuple[_Carry, None]:
        carry = lax.cond(carry.i_total < n_iters, loop_impl, loop_noop, carry)
        return carry, None

    carry, _ = lax.scan(loop, carry, None, inner_loop_length)
    return carry

372 

373 

def _pytree_at_set(
    dest: PyTree[Array, ' T'], index: Int32[Array, ''], val: PyTree[Array]
) -> PyTree[Array, ' T']:
    """Set ``dest[index] = val`` leaf by leaf over two matching pytrees.

    Zero-size leaves of `dest` are returned unchanged, because jax refuses
    to index into an array of length 0, even if just in the abstract.
    """

    def set_leaf(leaf, new_value):
        if not leaf.size:
            # nothing to write into, and indexing would fail
            return leaf
        return leaf.at[index, ...].set(new_value)

    return tree.map(set_leaf, dest, val)

388 

389 

def make_default_callback(
    *,
    dot_every: int | Integer[Array, ''] | None = 1,
    report_every: int | Integer[Array, ''] | None = 100,
    sparse_on_at: int | Integer[Array, ''] | None = None,
) -> dict[str, Any]:
    """
    Prepare a default callback for `run_mcmc`.

    The callback prints a dot every `dot_every` MCMC iterations and a one-line
    report every `report_every` MCMC iterations, and can do variable
    selection.

    Parameters
    ----------
    dot_every
        A dot is printed every `dot_every` MCMC iterations, `None` to disable.
    report_every
        A one line report is printed every `report_every` MCMC iterations,
        `None` to disable.
    sparse_on_at
        If specified, variable selection is activated starting from this
        iteration. If `None`, variable selection is not used.

    Returns
    -------
    A dictionary with the arguments to pass to `run_mcmc` as keyword arguments to set up the callback.

    Examples
    --------
    >>> run_mcmc(..., **make_default_callback())
    """

    def asarray_or_none(val: None | Any) -> None | Array:
        return None if val is None else jnp.asarray(val)

    def callback(*, bart, callback_state, **kwargs):
        print_state, sparse_state = callback_state
        # Variable selection runs first, so the report shows the updated state.
        # Here I assume that the sub-callbacks don't update their states: the
        # states they return are discarded and `callback_state` is passed on
        # unchanged.
        bart, _ = sparse_callback(callback_state=sparse_state, bart=bart, **kwargs)
        print_callback(callback_state=print_state, bart=bart, **kwargs)
        return bart, callback_state

    return dict(
        callback=callback,
        callback_state=(
            PrintCallbackState(
                asarray_or_none(dot_every), asarray_or_none(report_every)
            ),
            SparseCallbackState(asarray_or_none(sparse_on_at)),
        ),
    )

441 

442 

class PrintCallbackState(Module):
    """State for `print_callback`.

    Parameters
    ----------
    dot_every
        A dot is printed every `dot_every` MCMC iterations, `None` to disable.
    report_every
        A one line report is printed every `report_every` MCMC iterations,
        `None` to disable.
    """

    # `print_callback` tests these for `None` in Python (at trace time), so a
    # `None` value removes the corresponding output from the compiled loop.
    dot_every: Int32[Array, ''] | None
    report_every: Int32[Array, ''] | None

457 

458 

def print_callback(
    *,
    bart: State,
    burnin: Bool[Array, ''],
    i_total: Int32[Array, ''],
    n_burn: Int32[Array, ''],
    n_save: Int32[Array, ''],
    n_skip: Int32[Array, ''],
    callback_state: PrintCallbackState,
    **_,
):
    """Print a progress dot and/or a one-line report at regular intervals.

    The `None` checks on the `callback_state` fields happen in Python. The
    periodicity checks are done with `lax.cond`, and the printing itself goes
    through `jax.debug.callback`, so this works under jit.
    """
    dot_every = callback_state.dot_every
    report_every = callback_state.report_every
    n_done = i_total + 1  # iterations completed, including the current one

    if dot_every is not None:

        def print_dot():
            # logging can't do in-line printing so I'll stick to print
            debug.callback(lambda: print('.', end='', flush=True))  # noqa: T201

        lax.cond(n_done % dot_every == 0, print_dot, lambda: None)

    if report_every is None:
        return

    def print_report():
        forest = bart.forest
        debug.callback(
            _print_report,
            newline=dot_every is not None,
            burnin=burnin,
            i_total=i_total,
            n_iters=n_burn + n_save * n_skip,
            grow_prop_count=forest.grow_prop_count,
            grow_acc_count=forest.grow_acc_count,
            prune_prop_count=forest.prune_prop_count,
            prune_acc_count=forest.prune_acc_count,
            prop_total=len(forest.leaf_tree),
            fill=grove.forest_fill(forest.split_tree),
        )

    lax.cond(n_done % report_every == 0, print_report, lambda: None)

499 

500 

501def _convert_jax_arrays_in_args(func: Callable) -> Callable: 1ab

502 """Remove jax arrays from a function arguments. 

503 

504 Converts all `jax.Array` instances in the arguments to either Python scalars 

505 or numpy arrays. 

506 """ 

507 

508 def convert_jax_arrays(pytree: PyTree) -> PyTree: 1ab

509 def convert_jax_arrays(val: Any) -> Any: 

510 if not isinstance(val, jax.Array): 

511 return val 

512 elif val.shape: 

513 return numpy.array(val) 

514 else: 

515 return val.item() 

516 

517 return tree.map(convert_jax_arrays, pytree) 

518 

519 @wraps(func) 1ab

520 def new_func(*args, **kw): 1ab

521 args = convert_jax_arrays(args) 

522 kw = convert_jax_arrays(kw) 

523 return func(*args, **kw) 

524 

525 return new_func 1ab

526 

527 

@_convert_jax_arrays_in_args
# convert all jax arrays in arguments because operations on them could lead to
# deadlock with the main thread
def _print_report(
    *,
    newline: bool,
    burnin: bool,
    i_total: int,
    n_iters: int,
    grow_prop_count: int,
    grow_acc_count: int,
    prune_prop_count: int,
    prune_acc_count: int,
    prop_total: int,
    fill: float,
):
    """Print the one-line progress report for `print_callback`.

    The line shows the iteration number and, for grow and prune moves, the
    proposal count as a fraction of `prop_total` and the acceptance rate
    (``n/d`` when there were no proposals). It ends with the forest fill
    and a ``(burnin)`` marker during burn-in.
    """

    def rate(accepted, proposed):
        return f'{accepted / proposed:.0%}' if proposed else 'n/d'

    grow_share = grow_prop_count / prop_total
    prune_share = prune_prop_count / prop_total

    head = '\n' if newline else ''
    tail = ' (burnin)' if burnin else ''
    body = (
        f'It {i_total + 1}/{n_iters} '
        f'grow P={grow_share:.0%} A={rate(grow_acc_count, grow_prop_count)}, '
        f'prune P={prune_share:.0%} A={rate(prune_acc_count, prune_prop_count)}, '
        f'fill={fill:.0%}'
    )

    print(head + body + tail)  # noqa: T201, see print_callback for why not logging

566 

567 

class SparseCallbackState(Module):
    """State for `sparse_callback`.

    Parameters
    ----------
    sparse_on_at
        If specified, variable selection is activated starting from this
        iteration. If `None`, variable selection is not used.
    """

    # checked for `None` in Python by `sparse_callback`, i.e. at trace time
    sparse_on_at: Int32[Array, ''] | None

579 

580 

def sparse_callback(
    *,
    key: Key[Array, ''],
    bart: State,
    i_total: Int32[Array, ''],
    callback_state: SparseCallbackState,
    **_,
):
    """Perform variable selection, see `mcmcstep.step_sparse`.

    If ``callback_state.sparse_on_at`` is `None`, the MCMC state is returned
    unchanged. Otherwise `mcmcstep.step_sparse` is applied once `i_total`
    reaches ``sparse_on_at``. The callback state is always returned as-is.
    """
    activation = callback_state.sparse_on_at
    if activation is None:
        return bart, callback_state

    def unchanged():
        return bart

    def select_variables():
        return mcmcstep.step_sparse(key, bart)

    new_bart = lax.cond(i_total < activation, unchanged, select_variables)
    return new_bart, callback_state

597 

598 

class Trace(grove.TreeHeaps, Protocol):
    """Protocol for a MCMC trace.

    Besides the tree heaps of `bartz.grove.TreeHeaps` (presumably with a
    leading trace axis, as in `TreesTrace`), a trace carries one `offset`
    per iteration.
    """

    # added to the sum of the trees' outputs in `evaluate_trace`
    offset: Float32[Array, ' trace_length']

603 

604 

class TreesTrace(Module):
    """Implementation of `bartz.grove.TreeHeaps` for an MCMC trace.

    Holds only the three tree heap arrays, each indexed first by trace
    iteration and then by tree.
    """

    leaf_tree: Float32[Array, 'trace_length num_trees 2**d']
    var_tree: UInt[Array, 'trace_length num_trees 2**(d-1)']
    split_tree: UInt[Array, 'trace_length num_trees 2**(d-1)']

    @classmethod
    def from_dataclass(cls, obj: grove.TreeHeaps):
        """Create a `TreesTrace` from any `bartz.grove.TreeHeaps`.

        Only the attributes named like the fields of this class are read
        from `obj`; anything else it carries is ignored.
        """
        values = {}
        for field in fields(cls):
            values[field.name] = getattr(obj, field.name)
        return cls(**values)

616 

617 

@jax.jit
def evaluate_trace(
    trace: Trace, X: UInt[Array, 'p n']
) -> Float32[Array, 'trace_length n']:
    """
    Compute predictions for all iterations of the BART MCMC.

    Parameters
    ----------
    trace
        A trace of the BART MCMC, as returned by `run_mcmc`.
    X
        The predictors matrix, with `p` predictors and `n` observations.

    Returns
    -------
    The predictions for each iteration of the MCMC.
    """
    # per-tree outputs (not summed), batched by `jaxext.autobatch` over the
    # trees argument with the limit 2**29 (see `jaxext.autobatch` for units)
    per_tree = jaxext.autobatch(
        partial(grove.evaluate_forest, sum_trees=False), 2**29, (None, 0)
    )

    def predict_iteration(carry, iteration):
        # one scan step per MCMC iteration: sum the trees, add the offset
        offset, heaps = iteration
        tree_values = per_tree(X, heaps)
        forest_value = jnp.sum(tree_values, axis=0, dtype=jnp.float32)
        return carry, offset + forest_value

    _, predictions = lax.scan(
        predict_iteration, None, (trace.offset, TreesTrace.from_dataclass(trace))
    )
    return predictions

647 

648 

@partial(jax.jit, static_argnums=(0,))
def compute_varcount(
    p: int, trace: grove.TreeHeaps
) -> Int32[Array, 'trace_length {p}']:
    """
    Count how many times each predictor is used in each MCMC state.

    Parameters
    ----------
    p
        The number of predictors. Static under jit.
    trace
        A trace of the BART MCMC, as returned by `run_mcmc`.

    Returns
    -------
    Histogram of predictor usage in each MCMC state.
    """

    def histogram_of(var_tree, split_tree):
        # `p` is a static Python int, closed over rather than mapped
        return grove.var_histogram(p, var_tree, split_tree)

    return jax.vmap(histogram_of)(trace.var_tree, trace.split_tree)