Coverage for src/bartz/mcmcloop.py: 82%
161 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-27 14:46 +0000
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-27 14:46 +0000
1# bartz/src/bartz/mcmcloop.py
2#
3# Copyright (c) 2024-2025, Giacomo Petrillo
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25"""Functions that implement the full BART posterior MCMC loop.
27The main entry point is `run_mcmc`.
28"""
30from collections.abc import Callable 1ab
31from dataclasses import fields, replace 1ab
32from functools import partial, wraps 1ab
33from typing import Any, Protocol 1ab
35import jax 1ab
36import numpy 1ab
37from equinox import Module 1ab
38from jax import debug, lax, tree 1ab
39from jax import numpy as jnp 1ab
40from jaxtyping import ( 1ab
41 Array,
42 Bool,
43 Float32,
44 Int32,
45 Integer,
46 Key,
47 PyTree,
48 Real,
49 Shaped,
50 UInt,
51)
53from bartz import grove, jaxext, mcmcstep 1ab
54from bartz.mcmcstep import State 1ab
class BurninTrace(Module):
    """Diagnostics-only MCMC trace, the default layout of the burn-in trace.

    Each field mirrors an attribute of the MCMC state, with the trace axes
    as leading dimensions. Fields annotated as optional may be `None`.
    """

    # copied from `state.sigma2`
    sigma2: Float32[Array, '*trace_length'] | None
    # the remaining fields are copied from the same-named attributes of
    # `state.forest`
    grow_prop_count: Int32[Array, '*trace_length']
    grow_acc_count: Int32[Array, '*trace_length']
    prune_prop_count: Int32[Array, '*trace_length']
    prune_acc_count: Int32[Array, '*trace_length']
    log_likelihood: Float32[Array, '*trace_length'] | None
    log_trans_prior: Float32[Array, '*trace_length'] | None

    @classmethod
    def from_state(cls, state: State) -> 'BurninTrace':
        """Build a burn-in trace holding the values of a single MCMC state.

        Parameters
        ----------
        state
            The MCMC state to read the diagnostics from.

        Returns
        -------
        A trace whose fields are the corresponding values of `state`.
        """
        forest = state.forest
        forest_diagnostics = (
            'grow_prop_count',
            'grow_acc_count',
            'prune_prop_count',
            'prune_acc_count',
            'log_likelihood',
            'log_trans_prior',
        )
        return cls(
            sigma2=state.sigma2,
            **{name: getattr(forest, name) for name in forest_diagnostics},
        )
class MainTrace(BurninTrace):
    """MCMC trace holding the trees in addition to the diagnostic values.

    This is the default layout of the main trace. It extends `BurninTrace`
    with the tree arrays of the forest and with the offset.
    """

    # copied from the same-named attributes of `state.forest`
    leaf_tree: Real[Array, '*trace_length 2**d']
    var_tree: Real[Array, '*trace_length 2**(d-1)']
    split_tree: Real[Array, '*trace_length 2**(d-1)']
    # copied from `state.offset`
    offset: Float32[Array, '*trace_length']

    @classmethod
    def from_state(cls, state: State) -> 'MainTrace':
        """Build a main trace holding the values of a single MCMC state.

        Parameters
        ----------
        state
            The MCMC state to read the trees and diagnostics from.

        Returns
        -------
        A trace whose fields are the corresponding values of `state`.
        """
        # reuse the parent constructor to collect the diagnostic fields
        diagnostics = vars(BurninTrace.from_state(state))
        forest = state.forest
        return cls(
            leaf_tree=forest.leaf_tree,
            var_tree=forest.var_tree,
            split_tree=forest.split_tree,
            offset=state.offset,
            **diagnostics,
        )
# Type of the user-defined state threaded through successive invocations of
# the `run_mcmc` callback: an arbitrary pytree.
CallbackState = PyTree[Any, 'T']
class Callback(Protocol):
    """Signature of the callback accepted by `run_mcmc`."""

    def __call__(
        self,
        *,
        bart: State,
        burnin: Bool[Array, ''],
        i_total: Int32[Array, ''],
        i_skip: Int32[Array, ''],
        callback_state: CallbackState,
        n_burn: Int32[Array, ''],
        n_save: Int32[Array, ''],
        n_skip: Int32[Array, ''],
        i_outer: Int32[Array, ''],
        inner_loop_length: int,
    ) -> tuple[State, CallbackState] | None:
        """Perform an arbitrary action after one iteration of the MCMC.

        All arguments are passed by keyword.

        Parameters
        ----------
        bart
            The MCMC state right after it has been updated.
        burnin
            Whether the iteration just completed belongs to the burn-in
            phase.
        i_total
            The 0-based index of the MCMC iteration just completed.
        i_skip
            How many MCMC updates have been done since the last saved state.
            The initial state counts as saved, even though it is not copied
            into the trace.
        callback_state
            The callback state. On the first invocation it is the value
            passed to `run_mcmc`; afterwards it is whatever the previous
            invocation of the callback returned.
        n_burn
        n_save
        n_skip
            The `run_mcmc` arguments with the same names, unmodified.
        i_outer
            The 0-based index of the current outer loop iteration.
        inner_loop_length
            The number of MCMC iterations in each run of the inner loop.

        Returns
        -------
        bart : State
            The MCMC state to continue from, possibly modified. Return the
            `bart` argument as-is to leave the state untouched.
        callback_state : CallbackState
            The state handed to the next invocation of the callback.

        Notes
        -----
        As a shortcut, the callback may return `None`; in that case neither
        the MCMC state nor the callback state is updated.
        """
        ...
class _Carry(Module):
    """Loop-carried values of the MCMC loop run by `run_mcmc`."""

    # current MCMC state
    bart: State
    # number of MCMC iterations completed so far
    i_total: Int32[Array, '']
    # random key, replaced at each iteration
    key: Key[Array, '']
    # traces, filled in as the loop proceeds
    burnin_trace: PyTree[Shaped[Array, 'n_burn *']]
    main_trace: PyTree[Shaped[Array, 'n_save *']]
    # user state of the callback
    callback_state: CallbackState
def run_mcmc(
    key: Key[Array, ''],
    bart: State,
    n_save: int,
    *,
    n_burn: int = 0,
    n_skip: int = 1,
    inner_loop_length: int | None = None,
    callback: Callback | None = None,
    callback_state: CallbackState = None,
    burnin_extractor: Callable[[State], PyTree] = BurninTrace.from_state,
    main_extractor: Callable[[State], PyTree] = MainTrace.from_state,
) -> tuple[State, PyTree[Shaped[Array, 'n_burn *']], PyTree[Shaped[Array, 'n_save *']]]:
    """
    Run the MCMC for the BART posterior.

    Parameters
    ----------
    key
        The key for random number generation.
    bart
        The starting MCMC state, as created and updated by the functions in
        `bartz.mcmcstep`. Buffers are donated to the loop to avoid copies, so
        this object is invalidated by the call; copy it beforehand if it is
        needed afterwards.
    n_save
        How many iterations to save in the main trace.
    n_burn
        How many initial iterations to run before the main phase.
    n_skip
        The number of iterations to skip between each saved iteration, plus 1.
        The effective burn-in is ``n_burn + n_skip - 1``.
    inner_loop_length
        The loop is split in two levels: an outer loop in Python and an inner
        loop in JAX. This is the number of inner loop iterations executed per
        outer loop iteration. If `None`, a single inner loop run performs all
        the iterations. This stride is unrelated to `n_skip`.
    callback
        An arbitrary function invoked after each update of the state; see
        `Callback` for the signature. It runs under the jax jit, so argument
        values are not available while the Python code executes; use the
        utilities in `jax.debug` to access them at actual runtime. It may
        return new values for the MCMC state and for the callback state.
    callback_state
        The initial custom state passed to the callback.
    burnin_extractor
    main_extractor
        Functions that, given the MCMC state, return the variables to store
        respectively in the burn-in trace and in the main trace. They must
        return a pytree and be vmappable.

    Returns
    -------
    bart : State
        The final MCMC state.
    burnin_trace : PyTree[Shaped[Array, 'n_burn *']]
        The trace of the burn-in phase. For the default layout, see `BurninTrace`.
    main_trace : PyTree[Shaped[Array, 'n_save *']]
        The trace of the main phase. For the default layout, see `MainTrace`.

    Notes
    -----
    The number of MCMC updates is ``n_burn + n_skip * n_save``. The traces do
    not include the initial state, and include the final state.
    """

    def allocate_trace(extractor, length):
        # vmapping over no input axis replicates the extracted values
        # `length` times along a new leading axis, which gives buffers with
        # the right structure, shapes and dtypes to be overwritten later
        replicate = jax.vmap(extractor, in_axes=None, out_axes=0, axis_size=length)
        return replicate(bart)

    burnin_trace = allocate_trace(burnin_extractor, n_burn)
    main_trace = allocate_trace(main_extractor, n_save)

    # split the total number of iterations between outer and inner loop
    n_iters = n_burn + n_skip * n_save
    if inner_loop_length is None:
        inner_loop_length = n_iters
    if inner_loop_length:
        # ceil division: a partial last run still takes an outer iteration
        full_runs, leftover = divmod(n_iters, inner_loop_length)
        n_outer = full_runs + bool(leftover)
    else:
        # setting to 0 would make for a clean noop, but it's useful to keep
        # the same code path for benchmarking and testing
        n_outer = 1

    carry = _Carry(bart, jnp.int32(0), key, burnin_trace, main_trace, callback_state)
    for i_outer in range(n_outer):
        carry = _run_mcmc_inner_loop(
            carry,
            inner_loop_length,
            callback,
            burnin_extractor,
            main_extractor,
            n_burn,
            n_save,
            n_skip,
            i_outer,
            n_iters,
        )

    return carry.bart, carry.burnin_trace, carry.main_trace
def _compute_i_skip(
    i_total: Int32[Array, ''], n_burn: Int32[Array, ''], n_skip: Int32[Array, '']
) -> Int32[Array, '']:
    """Compute the `i_skip` argument passed to `callback`.

    Parameters
    ----------
    i_total
        The 0-based index of the MCMC iteration just completed.
    n_burn
        The number of burn-in iterations.
    n_skip
        The stride used to save iterations in the main trace.

    Returns
    -------
    The number of MCMC updates done since the last saved state, counting the
    initial state as saved.

    Notes
    -----
    A state is saved in the main trace whenever the number of updates done
    after the burn-in, ``i_total - n_burn + 1``, is a multiple of `n_skip`.
    The count must therefore be taken relative to the end of the burn-in, not
    to the start of the chain. Until the first save, the reference is the
    initial state, so all the burn-in updates are added on top.
    """
    n_done = i_total + 1
    # updates done in the main phase; not positive during the burn-in, where
    # the value derived from it is masked out by the outer `where`
    n_main = n_done - n_burn
    since_last_save = n_main % n_skip + jnp.where(n_main < n_skip, n_burn, 0)
    return jnp.where(i_total < n_burn, n_done, since_last_save)
@partial(jax.jit, donate_argnums=(0,), static_argnums=(1, 2, 3, 4))
def _run_mcmc_inner_loop(
    carry: _Carry,
    inner_loop_length: int,
    callback: Callback | None,
    burnin_extractor: Callable[[State], PyTree],
    main_extractor: Callable[[State], PyTree],
    n_burn: Int32[Array, ''],
    n_save: Int32[Array, ''],
    n_skip: Int32[Array, ''],
    i_outer: Int32[Array, ''],
    n_iters: Int32[Array, ''],
):
    """Run `inner_loop_length` iterations of the MCMC under jit.

    The buffers of `carry` are donated. Iterations with index at or beyond
    `n_iters` leave the carry unchanged, so the last run of the inner loop
    can be longer than the number of iterations left. Returns the updated
    carry.
    """

    def advance(carry: _Carry) -> _Carry:
        """Do one MCMC iteration; used while i_total < n_iters."""
        keys = jaxext.split(carry.key)
        # the first key popped is kept for the next iterations, the second
        # one is consumed by this MCMC step
        next_key = keys.pop()
        step_key = keys.pop()
        carry = replace(
            carry, key=next_key, bart=mcmcstep.step(step_key, carry.bart)
        )

        in_burnin = carry.i_total < n_burn

        if callback is not None:
            outcome = callback(
                bart=carry.bart,
                burnin=in_burnin,
                i_total=carry.i_total,
                i_skip=_compute_i_skip(carry.i_total, n_burn, n_skip),
                callback_state=carry.callback_state,
                n_burn=n_burn,
                n_save=n_save,
                n_skip=n_skip,
                i_outer=i_outer,
                inner_loop_length=inner_loop_length,
            )
            # a callback returning None leaves both states as they are
            if outcome is not None:
                new_bart, new_callback_state = outcome
                carry = replace(
                    carry, bart=new_bart, callback_state=new_callback_state
                )

        i_total = carry.i_total
        bart = carry.bart

        def write_burnin(
            burnin_trace: PyTree, main_trace: PyTree
        ) -> tuple[PyTree, PyTree]:
            updated = pytree_at_set(burnin_trace, i_total, burnin_extractor(bart))
            return updated, main_trace

        def write_main(
            burnin_trace: PyTree, main_trace: PyTree
        ) -> tuple[PyTree, PyTree]:
            # every main-phase iteration writes into its slot, so the slot
            # ends up holding the last of each group of n_skip iterations
            slot = (i_total - n_burn) // n_skip
            updated = pytree_at_set(main_trace, slot, main_extractor(bart))
            return burnin_trace, updated

        burnin_trace, main_trace = lax.cond(
            in_burnin,
            write_burnin,
            write_main,
            carry.burnin_trace,
            carry.main_trace,
        )
        return replace(
            carry,
            i_total=i_total + 1,
            burnin_trace=burnin_trace,
            main_trace=main_trace,
        )

    def skip(carry: _Carry) -> _Carry:
        """Leave the carry untouched; used once i_total >= n_iters."""
        return carry

    def scan_body(carry: _Carry, _) -> tuple[_Carry, None]:
        carry = lax.cond(carry.i_total < n_iters, advance, skip, carry)
        return carry, None

    carry, _ = lax.scan(scan_body, carry, None, inner_loop_length)
    return carry
def pytree_at_set(dest: PyTree, index: Int32[Array, ''], val: PyTree) -> PyTree:
    """Apply ``dest.at[index].set(val)`` leaf by leaf over two pytrees.

    Parameters
    ----------
    dest
        The pytree of arrays to write into, indexed along the first axis.
    index
        The position along the first axis of each leaf of `dest`.
    val
        The pytree of values to write, with the same structure as `dest`.

    Returns
    -------
    A pytree like `dest` with the values written in; empty leaves are
    returned unchanged.
    """

    def write_leaf(target, value):
        if not target.size:
            # jax refuses to index into an array of length 0, even if just
            # in the abstract, so empty arrays are passed through untouched
            return target
        return target.at[index, ...].set(value)

    return tree.map(write_leaf, dest, val)
class _PrintCallbackState(Module):
    """Settings of `_print_callback`, passed around as its callback state."""

    # period in MCMC iterations of the dots, `None` to print no dots
    dot_every: Int32[Array, ''] | None
    # period in MCMC iterations of the reports, `None` to print no reports
    report_every: Int32[Array, ''] | None
def make_print_callback(
    dot_every: int | Integer[Array, ''] | None = 1,
    report_every: int | Integer[Array, ''] | None = 100,
) -> dict[str, Any]:
    """
    Prepare a logging callback for `run_mcmc`.

    The callback periodically prints a dot, and, with its own period, a
    longer one-line report.

    Parameters
    ----------
    dot_every
        A dot is printed every `dot_every` MCMC iterations, `None` to disable.
    report_every
        A one line report is printed every `report_every` MCMC iterations,
        `None` to disable.

    Returns
    -------
    A dictionary with the arguments to pass to `run_mcmc` as keyword arguments to set up the callback.

    Examples
    --------
    >>> run_mcmc(..., **make_print_callback())
    """

    def to_array(value: None | Any) -> None | Array:
        # `None` is kept as a marker that disables the corresponding output
        if value is None:
            return None
        return jnp.asarray(value)

    state = _PrintCallbackState(to_array(dot_every), to_array(report_every))
    return {'callback': _print_callback, 'callback_state': state}
def _print_callback(
    *,
    bart: State,
    burnin: Bool[Array, ''],
    i_total: Int32[Array, ''],
    n_burn: Int32[Array, ''],
    n_save: Int32[Array, ''],
    n_skip: Int32[Array, ''],
    callback_state: _PrintCallbackState,
    **_,
):
    """Print a dot and/or a report periodically during the MCMC.

    Implements the `Callback` protocol, ignoring the arguments it does not
    need. It returns `None`, so neither the MCMC state nor the callback
    state is modified.
    """
    dot_every = callback_state.dot_every
    report_every = callback_state.report_every
    n_done = i_total + 1

    def do_nothing():
        return None

    if dot_every is not None:

        def print_dot():
            # logging can't do in-line printing so I'll stick to print
            debug.callback(lambda: print('.', end='', flush=True))  # noqa: T201

        lax.cond(n_done % dot_every == 0, print_dot, do_nothing)

    if report_every is not None:
        forest = bart.forest

        def print_report():
            debug.callback(
                _print_report,
                # when dots are printed, the report must start on a new line
                newline=dot_every is not None,
                burnin=burnin,
                i_total=i_total,
                n_iters=n_burn + n_save * n_skip,
                grow_prop_count=forest.grow_prop_count,
                grow_acc_count=forest.grow_acc_count,
                prune_prop_count=forest.prune_prop_count,
                prune_acc_count=forest.prune_acc_count,
                prop_total=len(forest.leaf_tree),
                fill=grove.forest_fill(forest.split_tree),
            )

        lax.cond(n_done % report_every == 0, print_report, do_nothing)
def _convert_jax_arrays_in_args(func: Callable) -> Callable:
    """Decorate a function to strip jax arrays from its arguments.

    Every `jax.Array` found in the positional and keyword arguments, at any
    depth of their pytree structure, is replaced by a Python scalar if it is
    0-dimensional, and by a numpy array otherwise. Any other value is passed
    through unchanged.
    """

    def to_host(leaf: Any) -> Any:
        if isinstance(leaf, jax.Array):
            # an empty shape tuple means a 0-dimensional array
            return numpy.array(leaf) if leaf.shape else leaf.item()
        return leaf

    @wraps(func)
    def wrapper(*args, **kw):
        host_args, host_kw = tree.map(to_host, (args, kw))
        return func(*host_args, **host_kw)

    return wrapper
# convert all jax arrays in arguments because operations on them could lead to
# deadlock with the main thread
@_convert_jax_arrays_in_args
def _print_report(
    *,
    newline: bool,
    burnin: bool,
    i_total: int,
    n_iters: int,
    grow_prop_count: int,
    grow_acc_count: int,
    prune_prop_count: int,
    prune_acc_count: int,
    prop_total: int,
    fill: float,
):
    """Print the report for `_print_callback`.

    The report is a single line with the progress of the MCMC, the fraction
    of grow and prune moves proposed (P) out of `prop_total`, the fraction
    of them that were accepted (A), and the fill of the trees.
    """

    def acceptance(accepted, proposed):
        # without proposals the acceptance rate is not defined
        return f'{accepted / proposed:.0%}' if proposed else 'n/d'

    grow_acc = acceptance(grow_acc_count, grow_prop_count)
    prune_acc = acceptance(prune_acc_count, prune_prop_count)
    grow_prop = grow_prop_count / prop_total
    prune_prop = prune_prop_count / prop_total

    lead = '\n' if newline else ''
    tail = ' (burnin)' if burnin else ''

    message = (
        f'{lead}It {i_total + 1}/{n_iters} '
        f'grow P={grow_prop:.0%} A={grow_acc}, '
        f'prune P={prune_prop:.0%} A={prune_acc}, '
        f'fill={fill:.0%}{tail}'
    )
    print(message)  # noqa: T201, see _print_callback for why not logging
class Trace(grove.TreeHeaps, Protocol):
    """Structural type of a MCMC trace: tree heaps plus an offset."""

    # one offset per saved MCMC state
    offset: Float32[Array, ' trace_length']
class TreesTrace(Module):
    """Concrete `bartz.grove.TreeHeaps` holding the trees of an MCMC trace."""

    leaf_tree: Float32[Array, 'trace_length num_trees 2**d']
    var_tree: UInt[Array, 'trace_length num_trees 2**(d-1)']
    split_tree: UInt[Array, 'trace_length num_trees 2**(d-1)']

    @classmethod
    def from_dataclass(cls, obj: grove.TreeHeaps):
        """Build a `TreesTrace` out of any `bartz.grove.TreeHeaps`.

        Only the attributes of `obj` named like the fields of this class are
        read; anything else `obj` carries is dropped.
        """
        names = (field.name for field in fields(cls))
        return cls(**{name: getattr(obj, name) for name in names})
@jax.jit
def evaluate_trace(
    trace: Trace, X: UInt[Array, 'p n']
) -> Float32[Array, 'trace_length n']:
    """
    Compute predictions for all iterations of the BART MCMC.

    Parameters
    ----------
    trace
        A trace of the BART MCMC, as returned by `run_mcmc`.
    X
        The predictors matrix, with `p` predictors and `n` observations.

    Returns
    -------
    The predictions for each iteration of the MCMC.
    """
    # evaluate the trees one by one, without summing them
    per_tree = partial(grove.evaluate_forest, sum_trees=False)
    # NOTE(review): 2**29 is presumably a cap on the memory used per batch;
    # confirm against `jaxext.autobatch`
    per_tree = jaxext.autobatch(per_tree, 2**29, (None, 0))
    heaps = TreesTrace.from_dataclass(trace)

    def predict_state(_, offset_and_trees):
        # prediction of one MCMC state: the offset plus the sum over trees
        offset, trees = offset_and_trees
        tree_values = per_tree(X, trees)
        return None, offset + jnp.sum(tree_values, axis=0, dtype=jnp.float32)

    _, predictions = lax.scan(predict_state, None, (trace.offset, heaps))
    return predictions
@partial(jax.jit, static_argnums=(0,))
def compute_varcount(
    p: int, trace: grove.TreeHeaps
) -> Int32[Array, 'trace_length {p}']:
    """
    Count how many times each predictor is used in each MCMC state.

    Parameters
    ----------
    p
        The number of predictors.
    trace
        A trace of the BART MCMC, as returned by `run_mcmc`.

    Returns
    -------
    Histogram of predictor usage in each MCMC state.
    """
    # map over the leading (trace) axis of the trees, keeping `p` fixed
    histogram_per_state = jax.vmap(grove.var_histogram, in_axes=(None, 0, 0))
    return histogram_per_state(p, trace.var_tree, trace.split_tree)