Coverage for src/bartz/mcmcloop.py: 82%

161 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-27 14:46 +0000

1# bartz/src/bartz/mcmcloop.py 

2# 

3# Copyright (c) 2024-2025, Giacomo Petrillo 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25"""Functions that implement the full BART posterior MCMC loop. 

26 

27The main entry point is `run_mcmc`. 

28""" 

29 

30from collections.abc import Callable 1ab

31from dataclasses import fields, replace 1ab

32from functools import partial, wraps 1ab

33from typing import Any, Protocol 1ab

34 

35import jax 1ab

36import numpy 1ab

37from equinox import Module 1ab

38from jax import debug, lax, tree 1ab

39from jax import numpy as jnp 1ab

40from jaxtyping import ( 1ab

41 Array, 

42 Bool, 

43 Float32, 

44 Int32, 

45 Integer, 

46 Key, 

47 PyTree, 

48 Real, 

49 Shaped, 

50 UInt, 

51) 

52 

53from bartz import grove, jaxext, mcmcstep 1ab

54from bartz.mcmcstep import State 1ab

55 

56 

class BurninTrace(Module):
    """Diagnostics-only MCMC trace, i.e., without the trees.

    `sigma2` mirrors the attribute with the same name of the MCMC state, all
    the other attributes mirror the attribute with the same name of the
    forest stored in the MCMC state.
    """

    # copied from `state.sigma2`; may be None
    sigma2: Float32[Array, '*trace_length'] | None
    # counters copied from `state.forest`
    grow_prop_count: Int32[Array, '*trace_length']
    grow_acc_count: Int32[Array, '*trace_length']
    prune_prop_count: Int32[Array, '*trace_length']
    prune_acc_count: Int32[Array, '*trace_length']
    # copied from `state.forest`; may be None
    log_likelihood: Float32[Array, '*trace_length'] | None
    log_trans_prior: Float32[Array, '*trace_length'] | None

    @classmethod
    def from_state(cls, state: State) -> 'BurninTrace':
        """Build a burn-in trace holding the single MCMC state `state`."""
        forest = state.forest
        return cls(
            sigma2=state.sigma2,
            grow_prop_count=forest.grow_prop_count,
            grow_acc_count=forest.grow_acc_count,
            prune_prop_count=forest.prune_prop_count,
            prune_acc_count=forest.prune_acc_count,
            log_likelihood=forest.log_likelihood,
            log_trans_prior=forest.log_trans_prior,
        )

80 

81 

class MainTrace(BurninTrace):
    """MCMC trace holding the trees in addition to the diagnostic values."""

    # tree arrays copied from `state.forest`
    leaf_tree: Real[Array, '*trace_length 2**d']
    var_tree: Real[Array, '*trace_length 2**(d-1)']
    split_tree: Real[Array, '*trace_length 2**(d-1)']
    # copied from `state.offset`
    offset: Float32[Array, '*trace_length']

    @classmethod
    def from_state(cls, state: State) -> 'MainTrace':
        """Build a main trace holding the single MCMC state `state`."""
        # the diagnostic fields are obtained through the parent class and
        # forwarded by name
        diagnostics = vars(BurninTrace.from_state(state))
        forest = state.forest
        return cls(
            leaf_tree=forest.leaf_tree,
            var_tree=forest.var_tree,
            split_tree=forest.split_tree,
            offset=state.offset,
            **diagnostics,
        )

100 

101 

# Type of the user-defined state threaded through successive invocations of
# the callback of `run_mcmc`; any pytree is accepted.
CallbackState = PyTree[Any, 'T']

103 

104 

class Callback(Protocol):
    """Callback type for `run_mcmc`.

    The callback is always invoked with keyword arguments only.
    """

    def __call__(
        self,
        *,
        bart: State,
        burnin: Bool[Array, ''],
        i_total: Int32[Array, ''],
        i_skip: Int32[Array, ''],
        callback_state: CallbackState,
        n_burn: Int32[Array, ''],
        n_save: Int32[Array, ''],
        n_skip: Int32[Array, ''],
        i_outer: Int32[Array, ''],
        inner_loop_length: int,
    ) -> tuple[State, CallbackState] | None:
        """Do an arbitrary action after an iteration of the MCMC.

        Parameters
        ----------
        bart
            The MCMC state just after updating it.
        burnin
            Whether the last iteration was in the burn-in phase.
        i_total
            The index of the last MCMC iteration (0-based).
        i_skip
            The number of MCMC updates from the last saved state. The initial
            state counts as saved, even if it's not copied into the trace.
        callback_state
            The callback state, initially set to the argument passed to
            `run_mcmc`, afterwards to the value returned by the last invocation
            of the callback.
        n_burn
        n_save
        n_skip
            The corresponding `run_mcmc` arguments as-is.
        i_outer
            The index of the last outer loop iteration (0-based).
        inner_loop_length
            The number of MCMC iterations in the inner loop.

        Returns
        -------
        bart : State
            A possibly modified MCMC state. To avoid modifying the state,
            return the `bart` argument passed to the callback as-is.
        callback_state : CallbackState
            The new state to be passed on the next callback invocation.

        Notes
        -----
        For convenience, the callback may return `None`, and the states won't
        be updated.
        """
        ...

162 

163 

class _Carry(Module):
    """Carry used in the loop in `run_mcmc`."""

    # current MCMC state
    bart: State
    # number of MCMC iterations done so far; starts at 0 and is incremented
    # by one after each iteration
    i_total: Int32[Array, '']
    # random key, split anew at each iteration
    key: Key[Array, '']
    # preallocated traces, written one entry per iteration
    burnin_trace: PyTree[Shaped[Array, 'n_burn *']]
    main_trace: PyTree[Shaped[Array, 'n_save *']]
    # user state passed to the callback and replaced by what it returns
    callback_state: CallbackState

173 

174 

def run_mcmc(
    key: Key[Array, ''],
    bart: State,
    n_save: int,
    *,
    n_burn: int = 0,
    n_skip: int = 1,
    inner_loop_length: int | None = None,
    callback: Callback | None = None,
    callback_state: CallbackState = None,
    burnin_extractor: Callable[[State], PyTree] = BurninTrace.from_state,
    main_extractor: Callable[[State], PyTree] = MainTrace.from_state,
) -> tuple[State, PyTree[Shaped[Array, 'n_burn *']], PyTree[Shaped[Array, 'n_save *']]]:
    """
    Run the MCMC for the BART posterior.

    Parameters
    ----------
    key
        A key for random number generation.
    bart
        The initial MCMC state, as created and updated by the functions in
        `bartz.mcmcstep`. The MCMC loop uses buffer donation to avoid copies,
        so this variable is invalidated after running `run_mcmc`. Make a copy
        beforehand to use it again.
    n_save
        The number of iterations to save in the main trace.
    n_burn
        The number of initial iterations which are not saved in the main
        trace; they are recorded in the burn-in trace instead.
    n_skip
        The number of iterations to skip between each saved iteration, plus 1.
        The effective burn-in is ``n_burn + n_skip - 1``.
    inner_loop_length
        The MCMC loop is split into an outer and an inner loop. The outer loop
        is in Python, while the inner loop is in JAX. `inner_loop_length` is the
        number of iterations of the inner loop to run for each iteration of the
        outer loop. If not specified, the outer loop will iterate just once,
        with all iterations done in a single inner loop run. The inner stride is
        unrelated to the stride used for saving the trace.
    callback
        An arbitrary function run during the loop after updating the state. For
        the signature, see `Callback`. The callback is called under the jax jit,
        so the argument values are not available at the time the Python code is
        executed. Use the utilities in `jax.debug` to access the values at
        actual runtime. The callback may return new values for the MCMC state
        and the callback state.
    callback_state
        The initial custom state for the callback.
    burnin_extractor
    main_extractor
        Functions that extract the variables to be saved respectively in the
        burn-in trace and in the main trace, given the MCMC state as argument.
        Must return a pytree, and must be vmappable.

    Returns
    -------
    bart : State
        The final MCMC state.
    burnin_trace : PyTree[Shaped[Array, 'n_burn *']]
        The trace of the burn-in phase. For the default layout, see `BurninTrace`.
    main_trace : PyTree[Shaped[Array, 'n_save *']]
        The trace of the main phase. For the default layout, see `MainTrace`.

    Notes
    -----
    The number of MCMC updates is ``n_burn + n_skip * n_save``. The traces do
    not include the initial state, and include the final state.
    """

    def allocate_trace(extractor, length):
        # preallocate a trace by replicating `length` times along a new
        # leading axis what the extractor produces on the initial state
        replicate = jax.vmap(extractor, in_axes=None, out_axes=0, axis_size=length)
        return replicate(bart)

    burnin_trace = allocate_trace(burnin_extractor, n_burn)
    main_trace = allocate_trace(main_extractor, n_save)

    # split the total number of iterations between inner and outer loops
    n_iters = n_burn + n_skip * n_save
    if inner_loop_length is None:
        inner_loop_length = n_iters
    if inner_loop_length:
        # ceil division: a last, partially used inner loop covers the remainder
        whole, leftover = divmod(n_iters, inner_loop_length)
        n_outer = whole + bool(leftover)
    else:
        n_outer = 1
        # setting to 0 would make for a clean noop, but it's useful to keep the
        # same code path for benchmarking and testing

    carry = _Carry(
        bart=bart,
        i_total=jnp.int32(0),
        key=key,
        burnin_trace=burnin_trace,
        main_trace=main_trace,
        callback_state=callback_state,
    )
    for i_outer in range(n_outer):
        carry = _run_mcmc_inner_loop(
            carry,
            inner_loop_length,
            callback,
            burnin_extractor,
            main_extractor,
            n_burn,
            n_save,
            n_skip,
            i_outer,
            n_iters,
        )

    return carry.bart, carry.burnin_trace, carry.main_trace

277 

278 

def _compute_i_skip(
    i_total: Int32[Array, ''], n_burn: Int32[Array, ''], n_skip: Int32[Array, '']
) -> Int32[Array, '']:
    """Compute the `i_skip` argument passed to `callback`.

    Parameters
    ----------
    i_total
        The index of the last MCMC iteration (0-based).
    n_burn
        The number of burn-in iterations.
    n_skip
        The thinning stride of the main trace.

    Returns
    -------
    The number of MCMC updates done since the last saved state, where the
    initial state counts as saved. During burn-in this is ``i_total + 1``.
    In the main phase it is 0 on the last iteration of each group of
    `n_skip` iterations, and the count restarts from there; until the first
    such iteration is reached, the burn-in iterations are counted as well.
    """
    burnin = i_total < n_burn
    # number of iterations done in the main phase, including the last one;
    # the thinning stride is aligned to the end of burn-in, not to iteration 0
    n_main = i_total - n_burn + 1
    return jnp.where(
        burnin,
        i_total + 1,
        # before the first saved main state, count from the initial state
        n_main % n_skip + jnp.where(n_main < n_skip, n_burn, 0),
    )

289 

290 

@partial(jax.jit, donate_argnums=(0,), static_argnums=(1, 2, 3, 4))
def _run_mcmc_inner_loop(
    carry: _Carry,
    inner_loop_length: int,
    callback: Callback | None,
    burnin_extractor: Callable[[State], PyTree],
    main_extractor: Callable[[State], PyTree],
    n_burn: Int32[Array, ''],
    n_save: Int32[Array, ''],
    n_skip: Int32[Array, ''],
    i_outer: Int32[Array, ''],
    n_iters: Int32[Array, ''],
) -> _Carry:
    """Run one jitted chunk of the MCMC loop of `run_mcmc`.

    Scans `inner_loop_length` times over the loop body. Each scan step does
    one MCMC iteration if ``carry.i_total < n_iters``, otherwise it leaves
    the carry unchanged.

    The buffers of `carry` are donated; `inner_loop_length`, `callback` and
    the two extractors are static arguments of the jit.

    Parameters
    ----------
    carry
        The loop state at the beginning of the chunk.
    inner_loop_length
        The number of scan steps.
    callback
        The callback invoked after each MCMC step, or `None`.
    burnin_extractor
    main_extractor
        The functions that produce the items written in the traces.
    n_burn
    n_save
    n_skip
        The corresponding `run_mcmc` arguments; `n_save` is only forwarded to
        the callback.
    i_outer
        The index of the outer loop iteration, only forwarded to the callback.
    n_iters
        The total number of MCMC iterations across all chunks.

    Returns
    -------
    The loop state at the end of the chunk.
    """

    def loop_impl(carry: _Carry) -> _Carry:
        """Loop body to run if i_total < n_iters."""
        # one key is kept in the carry for the next iteration, the other is
        # consumed by the MCMC step
        keys = jaxext.split(carry.key)
        carry = replace(carry, key=keys.pop())
        carry = replace(carry, bart=mcmcstep.step(keys.pop(), carry.bart))

        burnin = carry.i_total < n_burn

        if callback is not None:
            i_skip = _compute_i_skip(carry.i_total, n_burn, n_skip)
            rt = callback(
                bart=carry.bart,
                burnin=burnin,
                i_total=carry.i_total,
                i_skip=i_skip,
                callback_state=carry.callback_state,
                n_burn=n_burn,
                n_save=n_save,
                n_skip=n_skip,
                i_outer=i_outer,
                inner_loop_length=inner_loop_length,
            )
            # a callback returning None leaves both states untouched
            if rt is not None:
                bart, callback_state = rt
                carry = replace(carry, bart=bart, callback_state=callback_state)

        def save_to_burnin_trace(
            burnin_trace: PyTree, main_trace: PyTree
        ) -> tuple[PyTree, PyTree]:
            return pytree_at_set(
                burnin_trace, carry.i_total, burnin_extractor(carry.bart)
            ), main_trace

        def save_to_main_trace(
            burnin_trace: PyTree, main_trace: PyTree
        ) -> tuple[PyTree, PyTree]:
            # every main-phase iteration is written; the n_skip iterations
            # that share an index overwrite each other, so the last one stays
            idx = (carry.i_total - n_burn) // n_skip
            return burnin_trace, pytree_at_set(
                main_trace, idx, main_extractor(carry.bart)
            )

        burnin_trace, main_trace = lax.cond(
            burnin,
            save_to_burnin_trace,
            save_to_main_trace,
            carry.burnin_trace,
            carry.main_trace,
        )
        return replace(
            carry,
            i_total=carry.i_total + 1,
            burnin_trace=burnin_trace,
            main_trace=main_trace,
        )

    def loop_noop(carry: _Carry) -> _Carry:
        """Loop body to run if i_total >= n_iters; it does nothing."""
        return carry

    def loop(carry: _Carry, _) -> tuple[_Carry, None]:
        carry = lax.cond(carry.i_total < n_iters, loop_impl, loop_noop, carry)
        return carry, None

    carry, _ = lax.scan(loop, carry, None, inner_loop_length)
    return carry

369 

370 

def pytree_at_set(dest: PyTree, index: Int32[Array, ''], val: PyTree) -> PyTree:
    """Map ``dest.at[index].set(val)`` over pytrees.

    Parameters
    ----------
    dest
        The pytree of arrays to write into, indexed along the first axis.
    index
        The position along the first axis of each array in `dest`.
    val
        The pytree of values to write, with the same structure as `dest`.

    Returns
    -------
    The updated pytree. Arrays of size 0 are returned as they are.
    """

    def write_leaf(buffer, item):
        # jax refuses to index into an array of length 0, even if just in the
        # abstract, so empty arrays are passed through untouched
        if not buffer.size:
            return buffer
        return buffer.at[index, ...].set(item)

    return tree.map(write_leaf, dest, val)

383 

384 

class _PrintCallbackState(Module):
    """State used by `_print_callback`."""

    # print a dot every this many MCMC iterations, None to disable
    dot_every: Int32[Array, ''] | None
    # print a report every this many MCMC iterations, None to disable
    report_every: Int32[Array, ''] | None

390 

391 

def make_print_callback(
    dot_every: int | Integer[Array, ''] | None = 1,
    report_every: int | Integer[Array, ''] | None = 100,
) -> dict[str, Any]:
    """
    Prepare a logging callback for `run_mcmc`.

    The callback prints a dot every `dot_every` MCMC iterations, and a one
    line report every `report_every` MCMC iterations.

    Parameters
    ----------
    dot_every
        A dot is printed every `dot_every` MCMC iterations, `None` to disable.
    report_every
        A one line report is printed every `report_every` MCMC iterations,
        `None` to disable.

    Returns
    -------
    A dictionary with the arguments to pass to `run_mcmc` as keyword arguments to set up the callback.

    Examples
    --------
    >>> run_mcmc(..., **make_print_callback())
    """
    # the periods are stored as arrays, unless disabled
    periods = [
        None if period is None else jnp.asarray(period)
        for period in (dot_every, report_every)
    ]
    return {
        'callback': _print_callback,
        'callback_state': _PrintCallbackState(*periods),
    }

428 

429 

def _print_callback(
    *,
    bart: State,
    burnin: Bool[Array, ''],
    i_total: Int32[Array, ''],
    n_burn: Int32[Array, ''],
    n_save: Int32[Array, ''],
    n_skip: Int32[Array, ''],
    callback_state: _PrintCallbackState,
    **_,
):
    """Print a dot and/or a report periodically during the MCMC.

    The arguments are those of `Callback`; the ones not listed explicitly are
    ignored. Nothing is returned, so the MCMC state and the callback state
    are left unchanged.
    """
    dot_every = callback_state.dot_every
    report_every = callback_state.report_every
    # 1-based count of the iterations done, the periods refer to this
    n_done = i_total + 1

    if dot_every is not None:

        def print_dot():
            # logging can't do in-line printing so I'll stick to print
            debug.callback(lambda: print('.', end='', flush=True))  # noqa: T201

        lax.cond(n_done % dot_every == 0, print_dot, lambda: None)

    if report_every is not None:

        def print_report():
            debug.callback(
                _print_report,
                # the report goes on its own line if dots are being printed
                newline=dot_every is not None,
                burnin=burnin,
                i_total=i_total,
                n_iters=n_burn + n_save * n_skip,
                grow_prop_count=bart.forest.grow_prop_count,
                grow_acc_count=bart.forest.grow_acc_count,
                prune_prop_count=bart.forest.prune_prop_count,
                prune_acc_count=bart.forest.prune_acc_count,
                prop_total=len(bart.forest.leaf_tree),
                fill=grove.forest_fill(bart.forest.split_tree),
            )

        lax.cond(n_done % report_every == 0, print_report, lambda: None)

470 

471 

def _convert_jax_arrays_in_args(func: Callable) -> Callable:
    """Remove jax arrays from a function arguments.

    Decorator. Each `jax.Array` found in the positional and keyword arguments,
    at any depth of their pytree structure, is replaced with a Python scalar
    if it is 0-dimensional, and with a numpy array otherwise. Everything else
    is passed through as it is.
    """

    def to_host(leaf: Any) -> Any:
        if isinstance(leaf, jax.Array):
            return numpy.array(leaf) if leaf.shape else leaf.item()
        return leaf

    @wraps(func)
    def wrapper(*args, **kw):
        host_args, host_kw = tree.map(to_host, (args, kw))
        return func(*host_args, **host_kw)

    return wrapper

497 

498 

@_convert_jax_arrays_in_args
# convert all jax arrays in arguments because operations on them could lead to
# deadlock with the main thread
def _print_report(
    *,
    newline: bool,
    burnin: bool,
    i_total: int,
    n_iters: int,
    grow_prop_count: int,
    grow_acc_count: int,
    prune_prop_count: int,
    prune_acc_count: int,
    prop_total: int,
    fill: float,
):
    """Print the report for `_print_callback`.

    The report is a single line with the 1-based iteration number over the
    total, the fraction of grow and prune proposals over `prop_total`, the
    corresponding acceptance rates, and `fill`, all formatted as percentages.
    An acceptance rate is shown as 'n/d' if there were no proposals. The line
    is preceded by a newline if `newline` is set, and marked as burn-in if
    `burnin` is set.
    """

    def acc_string(acc_count, prop_count):
        # the rate is not defined without proposals
        return f'{acc_count / prop_count:.0%}' if prop_count else 'n/d'

    prefix = '\n' if newline else ''
    suffix = ' (burnin)' if burnin else ''

    grow_prop = grow_prop_count / prop_total
    grow_acc = acc_string(grow_acc_count, grow_prop_count)

    prune_prop = prune_prop_count / prop_total
    prune_acc = acc_string(prune_acc_count, prune_prop_count)

    line = (
        f'{prefix}It {i_total + 1}/{n_iters} '
        f'grow P={grow_prop:.0%} A={grow_acc}, '
        f'prune P={prune_prop:.0%} A={prune_acc}, '
        f'fill={fill:.0%}{suffix}'
    )
    print(line)  # noqa: T201, see _print_callback for why not logging

537 

538 

class Trace(grove.TreeHeaps, Protocol):
    """Protocol for a MCMC trace.

    It adds the `offset` attribute to the attributes required by
    `bartz.grove.TreeHeaps`.
    """

    # one offset per entry of the trace
    offset: Float32[Array, ' trace_length']

543 

544 

class TreesTrace(Module):
    """Implementation of `bartz.grove.TreeHeaps` for an MCMC trace.

    It holds only the arrays that represent the trees.
    """

    leaf_tree: Float32[Array, 'trace_length num_trees 2**d']
    var_tree: UInt[Array, 'trace_length num_trees 2**(d-1)']
    split_tree: UInt[Array, 'trace_length num_trees 2**(d-1)']

    @classmethod
    def from_dataclass(cls, obj: grove.TreeHeaps):
        """Create a `TreesTrace` from any `bartz.grove.TreeHeaps`.

        The attributes of `obj` named like the fields of this class are
        picked, any other attribute is ignored.
        """
        names = [field.name for field in fields(cls)]
        return cls(**{name: getattr(obj, name) for name in names})

556 

557 

@jax.jit
def evaluate_trace(
    trace: Trace, X: UInt[Array, 'p n']
) -> Float32[Array, 'trace_length n']:
    """
    Compute predictions for all iterations of the BART MCMC.

    Parameters
    ----------
    trace
        A trace of the BART MCMC, as returned by `run_mcmc`.
    X
        The predictors matrix, with `p` predictors and `n` observations.

    Returns
    -------
    The predictions for each iteration of the MCMC.
    """
    # evaluate the trees without summing them, batching over the trees;
    # NOTE(review): 2**29 is presumably the memory budget of the batching,
    # confirm against `jaxext.autobatch`
    per_tree_eval = jaxext.autobatch(
        partial(grove.evaluate_forest, sum_trees=False), 2**29, (None, 0)
    )

    def predict_one(carry, state):
        # the prediction of one MCMC state is its offset plus the sum of
        # the values of its trees
        offset, forest = state
        per_tree = per_tree_eval(X, forest)
        return carry, offset + jnp.sum(per_tree, axis=0, dtype=jnp.float32)

    all_trees = TreesTrace.from_dataclass(trace)
    _, y = lax.scan(predict_one, None, (trace.offset, all_trees))
    return y

587 

588 

@partial(jax.jit, static_argnums=(0,))
def compute_varcount(
    p: int, trace: grove.TreeHeaps
) -> Int32[Array, 'trace_length {p}']:
    """
    Count how many times each predictor is used in each MCMC state.

    Parameters
    ----------
    p
        The number of predictors. It is a static argument of the jit.
    trace
        A trace of the BART MCMC, as returned by `run_mcmc`.

    Returns
    -------
    Histogram of predictor usage in each MCMC state.
    """
    # `p` is shared, the tree arrays are mapped over the trace axis
    histogram_per_state = jax.vmap(grove.var_histogram, in_axes=(None, 0, 0))
    var_tree, split_tree = trace.var_tree, trace.split_tree
    return histogram_per_state(p, var_tree, split_tree)

608 return vmapped_var_histogram(p, trace.var_tree, trace.split_tree) 1ab