Coverage for src/bartz/mcmcloop.py: 86%
179 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-07 22:47 +0000
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-07 22:47 +0000
1# bartz/src/bartz/mcmcloop.py
2#
3# Copyright (c) 2024-2025, Giacomo Petrillo
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25"""Functions that implement the full BART posterior MCMC loop.
27The entry points are `run_mcmc` and `make_default_callback`.
28"""
30from collections.abc import Callable 1ab
31from dataclasses import fields, replace 1ab
32from functools import partial, wraps 1ab
33from typing import Any, Protocol 1ab
35import jax 1ab
36import numpy 1ab
37from equinox import Module 1ab
38from jax import debug, lax, tree 1ab
39from jax import numpy as jnp 1ab
40from jax.nn import softmax 1ab
41from jaxtyping import Array, Bool, Float32, Int32, Integer, Key, PyTree, Shaped, UInt 1ab
43from bartz import grove, jaxext, mcmcstep 1ab
44from bartz.mcmcstep import State 1ab
class BurninTrace(Module):
    """Diagnostics-only MCMC trace.

    This is the layout produced by the default `burnin_extractor` of
    `run_mcmc`. Each field mirrors the MCMC state variable with the same name.
    """

    sigma2: Float32[Array, '*trace_length'] | None
    theta: Float32[Array, '*trace_length'] | None
    grow_prop_count: Int32[Array, '*trace_length']
    grow_acc_count: Int32[Array, '*trace_length']
    prune_prop_count: Int32[Array, '*trace_length']
    prune_acc_count: Int32[Array, '*trace_length']
    log_likelihood: Float32[Array, '*trace_length'] | None
    log_trans_prior: Float32[Array, '*trace_length'] | None

    @classmethod
    def from_state(cls, state: State) -> 'BurninTrace':
        """Build a one-item burn-in trace by copying values out of `state`."""
        # all the diagnostics but `sigma2` are read from the forest sub-state
        forest_fields = (
            'theta',
            'grow_prop_count',
            'grow_acc_count',
            'prune_prop_count',
            'prune_acc_count',
            'log_likelihood',
            'log_trans_prior',
        )
        forest = state.forest
        from_forest = {name: getattr(forest, name) for name in forest_fields}
        return cls(sigma2=state.sigma2, **from_forest)
class MainTrace(BurninTrace):
    """MCMC trace that stores the trees in addition to the diagnostics."""

    leaf_tree: Float32[Array, '*trace_length 2**d']
    var_tree: UInt[Array, '*trace_length 2**(d-1)']
    split_tree: UInt[Array, '*trace_length 2**(d-1)']
    offset: Float32[Array, '*trace_length']
    varprob: Float32[Array, '*trace_length p'] | None

    @classmethod
    def from_state(cls, state: State) -> 'MainTrace':
        """Build a one-item main trace by copying values out of `state`."""
        forest = state.forest

        # `varprob` is defined only when the state carries `log_s`; the softmax
        # is restricted to the entries where `max_split` is nonzero
        log_s = forest.log_s
        varprob = (
            None
            if log_s is None
            else softmax(log_s, where=forest.max_split.astype(bool))
        )

        # reuse the parent extractor for the diagnostic fields
        diagnostics = vars(BurninTrace.from_state(state))

        return cls(
            **diagnostics,
            leaf_tree=forest.leaf_tree,
            var_tree=forest.var_tree,
            split_tree=forest.split_tree,
            offset=state.offset,
            varprob=varprob,
        )
# Type of the user-defined state that `run_mcmc` threads through successive
# invocations of the callback: an arbitrary pytree.
CallbackState = PyTree[Any, 'T']
class Callback(Protocol):
    """Callback type for `run_mcmc`.

    The callback is invoked with keyword arguments only, once per MCMC
    iteration, right after the state has been updated.
    """

    def __call__(
        self,
        *,
        key: Key[Array, ''],
        bart: State,
        burnin: Bool[Array, ''],
        i_total: Int32[Array, ''],
        i_skip: Int32[Array, ''],
        callback_state: CallbackState,
        n_burn: Int32[Array, ''],
        n_save: Int32[Array, ''],
        n_skip: Int32[Array, ''],
        i_outer: Int32[Array, ''],
        inner_loop_length: int,
    ) -> tuple[State, CallbackState] | None:
        """Do an arbitrary action after an iteration of the MCMC.

        Parameters
        ----------
        key
            A key for random number generation, reserved to the callback.
        bart
            The MCMC state just after updating it.
        burnin
            Whether the last iteration was in the burn-in phase, i.e.,
            whether ``i_total < n_burn``.
        i_total
            The index of the last MCMC iteration (0-based).
        i_skip
            The number of MCMC updates from the last saved state. The initial
            state counts as saved, even if it's not copied into the trace.
        callback_state
            The callback state, initially set to the argument passed to
            `run_mcmc`, afterwards to the value returned by the last invocation
            of the callback.
        n_burn
        n_save
        n_skip
            The corresponding `run_mcmc` arguments as-is.
        i_outer
            The index of the last outer loop iteration (0-based).
        inner_loop_length
            The number of MCMC iterations in the inner loop.

        Returns
        -------
        bart : State
            A possibly modified MCMC state. To avoid modifying the state,
            return the `bart` argument passed to the callback as-is.
        callback_state : CallbackState
            The new state to be passed on the next callback invocation.

        Notes
        -----
        For convenience, the callback may return `None`, in which case neither
        the MCMC state nor the callback state is updated.
        """
        ...
class _Carry(Module):
    """Carry used in the loop in `run_mcmc`."""

    # the current MCMC state
    bart: State
    # the number of MCMC iterations done so far; starts at 0
    i_total: Int32[Array, '']
    # the random key, split anew at each iteration
    key: Key[Array, '']
    # the traces, preallocated by `run_mcmc` and filled as the loop progresses
    burnin_trace: PyTree[Shaped[Array, 'n_burn *']]
    main_trace: PyTree[Shaped[Array, 'n_save *']]
    # the state returned by the last invocation of the callback
    callback_state: CallbackState
def run_mcmc(
    key: Key[Array, ''],
    bart: State,
    n_save: int,
    *,
    n_burn: int = 0,
    n_skip: int = 1,
    inner_loop_length: int | None = None,
    callback: Callback | None = None,
    callback_state: CallbackState = None,
    burnin_extractor: Callable[[State], PyTree] = BurninTrace.from_state,
    main_extractor: Callable[[State], PyTree] = MainTrace.from_state,
) -> tuple[State, PyTree[Shaped[Array, 'n_burn *']], PyTree[Shaped[Array, 'n_save *']]]:
    """
    Run the MCMC for the BART posterior.

    Parameters
    ----------
    key
        A key for random number generation.
    bart
        The initial MCMC state, as created and updated by the functions in
        `bartz.mcmcstep`. The MCMC loop uses buffer donation to avoid copies,
        so this variable is invalidated after running `run_mcmc`. Make a copy
        beforehand to use it again.
    n_save
        The number of iterations to save.
    n_burn
        The number of initial iterations which are not saved.
    n_skip
        The number of iterations to skip between each saved iteration, plus 1.
        The effective burn-in is ``n_burn + n_skip - 1``.
    inner_loop_length
        The MCMC loop is split into an outer and an inner loop. The outer loop
        is in Python, while the inner loop is in JAX. `inner_loop_length` is the
        number of iterations of the inner loop to run for each iteration of the
        outer loop. If not specified, the outer loop will iterate just once,
        with all iterations done in a single inner loop run. The inner stride is
        unrelated to the stride used for saving the trace.
    callback
        An arbitrary function run during the loop after updating the state. For
        the signature, see `Callback`. The callback is called under the jax jit,
        so the argument values are not available at the time the Python code is
        executed. Use the utilities in `jax.debug` to access the values at
        actual runtime. The callback may return new values for the MCMC state
        and the callback state.
    callback_state
        The initial custom state for the callback.
    burnin_extractor
    main_extractor
        Functions that extract the variables to be saved respectively in the
        burn-in trace and in the main trace, given the MCMC state as argument.
        Must return a pytree, and must be vmappable.

    Returns
    -------
    bart : State
        The final MCMC state.
    burnin_trace : PyTree[Shaped[Array, 'n_burn *']]
        The trace of the burn-in phase. For the default layout, see `BurninTrace`.
    main_trace : PyTree[Shaped[Array, 'n_save *']]
        The trace of the main phase. For the default layout, see `MainTrace`.

    Notes
    -----
    The number of MCMC updates is ``n_burn + n_skip * n_save``. The traces do
    not include the initial state, and include the final state.
    """

    def preallocate_trace(length, extractor):
        # stack `length` copies of the items extracted from the initial state
        # to get buffers with the right structure, shapes and types
        stacked = jax.vmap(extractor, in_axes=None, out_axes=0, axis_size=length)
        return stacked(bart)

    initial_burnin_trace = preallocate_trace(n_burn, burnin_extractor)
    initial_main_trace = preallocate_trace(n_save, main_extractor)

    # determine number of iterations for inner and outer loops
    n_iters = n_burn + n_skip * n_save
    if inner_loop_length is None:
        inner_loop_length = n_iters
    if inner_loop_length:
        # ceil division: a partial last chunk still needs an outer iteration
        n_full, n_left = divmod(n_iters, inner_loop_length)
        n_outer = n_full + bool(n_left)
    else:
        # setting to 0 would make for a clean noop, but it's useful to keep the
        # same code path for benchmarking and testing
        n_outer = 1

    carry = _Carry(
        bart=bart,
        i_total=jnp.int32(0),
        key=key,
        burnin_trace=initial_burnin_trace,
        main_trace=initial_main_trace,
        callback_state=callback_state,
    )
    for i_outer in range(n_outer):
        carry = _run_mcmc_inner_loop(
            carry,
            inner_loop_length,
            callback,
            burnin_extractor,
            main_extractor,
            n_burn,
            n_save,
            n_skip,
            i_outer,
            n_iters,
        )

    return carry.bart, carry.burnin_trace, carry.main_trace
def _compute_i_skip(
    i_total: Int32[Array, ''], n_burn: Int32[Array, ''], n_skip: Int32[Array, '']
) -> Int32[Array, '']:
    """Compute the `i_skip` argument passed to `callback`.

    `i_skip` is the number of MCMC updates since the last saved state, with
    the initial state counting as saved.
    """
    # updates done since the initial state, including the current one
    n_done = i_total + 1

    # updates done in the main phase, including the current one
    n_main = i_total - n_burn + 1

    # in the main phase, count from the last multiple of `n_skip`; until the
    # first save happens, the last saved state is still the initial one, so
    # the whole burn-in adds up
    since_last_save = n_main % n_skip + jnp.where(n_main < n_skip, n_burn, 0)

    return jnp.where(i_total < n_burn, n_done, since_last_save)
@partial(jax.jit, donate_argnums=(0,), static_argnums=(1, 2, 3, 4))
def _run_mcmc_inner_loop(
    carry: _Carry,
    inner_loop_length: int,
    callback: Callback | None,
    burnin_extractor: Callable[[State], PyTree],
    main_extractor: Callable[[State], PyTree],
    n_burn: Int32[Array, ''],
    n_save: Int32[Array, ''],
    n_skip: Int32[Array, ''],
    i_outer: Int32[Array, ''],
    n_iters: Int32[Array, ''],
) -> _Carry:
    """Run up to `inner_loop_length` MCMC iterations under jit.

    Parameters
    ----------
    carry
        The loop carry. It is donated, so the caller must use the returned
        carry in its place.
    inner_loop_length
        The number of iterations of the scan (static). Iterations that would
        go beyond `n_iters` do nothing.
    callback
        The callback passed to `run_mcmc`, `None` to disable (static).
    burnin_extractor
    main_extractor
        The trace extractors passed to `run_mcmc` (static).
    n_burn
    n_save
    n_skip
        The corresponding `run_mcmc` arguments.
    i_outer
        The index of the current outer loop iteration (0-based).
    n_iters
        The total number of MCMC iterations to run across all outer loop
        iterations.

    Returns
    -------
    The updated carry.
    """

    def loop_impl(carry: _Carry) -> _Carry:
        """Loop body to run if i_total < n_iters."""
        # split random key; the three keys are consumed in a fixed order:
        # next carry key, MCMC step, callback
        keys = jaxext.split(carry.key, 3)
        carry = replace(carry, key=keys.pop())

        # update state
        carry = replace(carry, bart=mcmcstep.step(keys.pop(), carry.bart))

        burnin = carry.i_total < n_burn

        # invoke callback
        if callback is not None:
            i_skip = _compute_i_skip(carry.i_total, n_burn, n_skip)
            rt = callback(
                key=keys.pop(),
                bart=carry.bart,
                burnin=burnin,
                i_total=carry.i_total,
                i_skip=i_skip,
                callback_state=carry.callback_state,
                n_burn=n_burn,
                n_save=n_save,
                n_skip=n_skip,
                i_outer=i_outer,
                inner_loop_length=inner_loop_length,
            )
            # the callback may return None to leave the states untouched
            if rt is not None:
                bart, callback_state = rt
                carry = replace(carry, bart=bart, callback_state=callback_state)

        def save_to_burnin_trace() -> tuple[PyTree, PyTree]:
            return _pytree_at_set(
                carry.burnin_trace, carry.i_total, burnin_extractor(carry.bart)
            ), carry.main_trace

        def save_to_main_trace() -> tuple[PyTree, PyTree]:
            # every main iteration is written to its slot, so each slot ends
            # up holding the last of the `n_skip` iterations mapped to it
            idx = (carry.i_total - n_burn) // n_skip
            return carry.burnin_trace, _pytree_at_set(
                carry.main_trace, idx, main_extractor(carry.bart)
            )

        # save state to trace
        burnin_trace, main_trace = lax.cond(
            burnin, save_to_burnin_trace, save_to_main_trace
        )
        return replace(
            carry,
            i_total=carry.i_total + 1,
            burnin_trace=burnin_trace,
            main_trace=main_trace,
        )

    def loop_noop(carry: _Carry) -> _Carry:
        """Loop body to run if i_total >= n_iters; it does nothing."""
        return carry

    def loop(carry: _Carry, _) -> tuple[_Carry, None]:
        # the scan length is fixed, so skip the iterations in excess
        carry = lax.cond(carry.i_total < n_iters, loop_impl, loop_noop, carry)
        return carry, None

    carry, _ = lax.scan(loop, carry, None, inner_loop_length)
    return carry
def _pytree_at_set(
    dest: PyTree[Array, ' T'], index: Int32[Array, ''], val: PyTree[Array]
) -> PyTree[Array, ' T']:
    """Map ``dest.at[index].set(val)`` over pytrees.

    Leaves of `dest` with no elements are returned unchanged.
    """

    def set_leaf(buffer, item):
        if not buffer.size:
            # jax refuses to index into an array of length 0, even if just in
            # the abstract, so empty arrays are passed through as they are
            return buffer
        return buffer.at[index, ...].set(item)

    return tree.map(set_leaf, dest, val)
def make_default_callback(
    *,
    dot_every: int | Integer[Array, ''] | None = 1,
    report_every: int | Integer[Array, ''] | None = 100,
    sparse_on_at: int | Integer[Array, ''] | None = None,
) -> dict[str, Any]:
    """
    Prepare a default callback for `run_mcmc`.

    The callback prints a dot every `dot_every` iterations and a one line
    report every `report_every` iterations, and can do variable selection.

    Parameters
    ----------
    dot_every
        A dot is printed every `dot_every` MCMC iterations, `None` to disable.
    report_every
        A one line report is printed every `report_every` MCMC iterations,
        `None` to disable.
    sparse_on_at
        If specified, variable selection is activated starting from this
        iteration. If `None`, variable selection is not used.

    Returns
    -------
    A dictionary with the arguments to pass to `run_mcmc` as keyword arguments to set up the callback.

    Examples
    --------
    >>> run_mcmc(..., **make_default_callback())
    """

    def to_array(value: None | Any) -> None | Array:
        if value is None:
            return None
        return jnp.asarray(value)

    initial_state = (
        PrintCallbackState(to_array(dot_every), to_array(report_every)),
        SparseCallbackState(to_array(sparse_on_at)),
    )

    def callback(*, bart, callback_state, **kwargs):
        print_state, sparse_state = callback_state
        # variable selection goes first, so the printing sees the final state
        bart, _ = sparse_callback(bart=bart, callback_state=sparse_state, **kwargs)
        print_callback(bart=bart, callback_state=print_state, **kwargs)
        # here I assume that the callbacks don't update their states
        return bart, callback_state

    return {'callback': callback, 'callback_state': initial_state}
class PrintCallbackState(Module):
    """State for `print_callback`.

    Both fields are read-only settings: `print_callback` never updates them.

    Parameters
    ----------
    dot_every
        A dot is printed every `dot_every` MCMC iterations, `None` to disable.
    report_every
        A one line report is printed every `report_every` MCMC iterations,
        `None` to disable.
    """

    dot_every: Int32[Array, ''] | None
    report_every: Int32[Array, ''] | None
def print_callback(
    *,
    bart: State,
    burnin: Bool[Array, ''],
    i_total: Int32[Array, ''],
    n_burn: Int32[Array, ''],
    n_save: Int32[Array, ''],
    n_skip: Int32[Array, ''],
    callback_state: PrintCallbackState,
    **_,
):
    """Print a dot and/or a report periodically during the MCMC.

    The printing is done at runtime through `jax.debug.callback`. The function
    returns `None`, so it leaves the MCMC and callback states unchanged.
    """
    dot_every = callback_state.dot_every
    report_every = callback_state.report_every

    # `i_total` is 0-based, the periods count the iterations from 1
    n_done = i_total + 1

    if dot_every is not None:

        def print_dot():
            # logging can't do in-line printing so I'll stick to print
            debug.callback(lambda: print('.', end='', flush=True))  # noqa: T201

        lax.cond(n_done % dot_every == 0, print_dot, lambda: None)

    if report_every is not None:

        def print_report():
            forest = bart.forest
            debug.callback(
                _print_report,
                # break the line of dots, if any, before the report
                newline=dot_every is not None,
                burnin=burnin,
                i_total=i_total,
                n_iters=n_burn + n_save * n_skip,
                grow_prop_count=forest.grow_prop_count,
                grow_acc_count=forest.grow_acc_count,
                prune_prop_count=forest.prune_prop_count,
                prune_acc_count=forest.prune_acc_count,
                prop_total=len(forest.leaf_tree),
                fill=grove.forest_fill(forest.split_tree),
            )

        lax.cond(n_done % report_every == 0, print_report, lambda: None)
def _convert_jax_arrays_in_args(func: Callable) -> Callable:
    """Wrap a function to strip jax arrays from its arguments.

    Each `jax.Array` found in the positional and keyword arguments, also if
    nested in pytrees, is converted to a Python scalar if it has no axes, and
    to a numpy array otherwise. Any other value is passed through unchanged.
    """

    def to_host(leaf: Any) -> Any:
        if isinstance(leaf, jax.Array):
            return numpy.array(leaf) if leaf.shape else leaf.item()
        return leaf

    @wraps(func)
    def new_func(*args, **kw):
        return func(*tree.map(to_host, args), **tree.map(to_host, kw))

    return new_func
# convert all jax arrays in arguments because operations on them could lead to
# deadlock with the main thread
@_convert_jax_arrays_in_args
def _print_report(
    *,
    newline: bool,
    burnin: bool,
    i_total: int,
    n_iters: int,
    grow_prop_count: int,
    grow_acc_count: int,
    prune_prop_count: int,
    prune_acc_count: int,
    prop_total: int,
    fill: float,
):
    """Print the report for `print_callback`.

    The one line report shows the 1-based iteration number, the fractions of
    grow and prune moves proposed (P) out of `prop_total` and accepted (A) out
    of the proposed ones, and the fill. It is preceded by a line break if
    `newline` is set, and marked if the iteration is in the burn-in.
    """

    def acc_string(acc_count, prop_count):
        # the acceptance ratio is not defined if there were no proposals
        if not prop_count:
            return 'n/d'
        return f'{acc_count / prop_count:.0%}'

    prefix = ''
    if newline:
        prefix = '\n'

    suffix = ''
    if burnin:
        suffix = ' (burnin)'

    grow_acc = acc_string(grow_acc_count, grow_prop_count)
    prune_acc = acc_string(prune_acc_count, prune_prop_count)
    grow_prop = grow_prop_count / prop_total
    prune_prop = prune_prop_count / prop_total

    print(  # noqa: T201, see print_callback for why not logging
        f'{prefix}It {i_total + 1}/{n_iters} '
        f'grow P={grow_prop:.0%} A={grow_acc}, '
        f'prune P={prune_prop:.0%} A={prune_acc}, '
        f'fill={fill:.0%}{suffix}'
    )
class SparseCallbackState(Module):
    """State for `sparse_callback`.

    The field is a read-only setting: `sparse_callback` never updates it.

    Parameters
    ----------
    sparse_on_at
        If specified, variable selection is activated starting from this
        iteration. If `None`, variable selection is not used.
    """

    sparse_on_at: Int32[Array, ''] | None
def sparse_callback(
    *,
    key: Key[Array, ''],
    bart: State,
    i_total: Int32[Array, ''],
    callback_state: SparseCallbackState,
    **_,
):
    """Perform variable selection, see `mcmcstep.step_sparse`.

    The state is updated only from iteration `sparse_on_at` (0-based) onwards,
    and never if `sparse_on_at` is `None`. The callback state is returned
    unchanged.
    """
    sparse_on_at = callback_state.sparse_on_at
    if sparse_on_at is None:
        return bart, callback_state

    updated = lax.cond(
        i_total >= sparse_on_at,
        lambda: mcmcstep.step_sparse(key, bart),
        lambda: bart,
    )
    return updated, callback_state
class Trace(grove.TreeHeaps, Protocol):
    """Protocol for a MCMC trace.

    A trace has the attributes of `bartz.grove.TreeHeaps` plus `offset`.
    """

    # the offset, one value per item of the trace
    offset: Float32[Array, ' trace_length']
class TreesTrace(Module):
    """Container of the trees of an MCMC trace, following `bartz.grove.TreeHeaps`."""

    leaf_tree: Float32[Array, 'trace_length num_trees 2**d']
    var_tree: UInt[Array, 'trace_length num_trees 2**(d-1)']
    split_tree: UInt[Array, 'trace_length num_trees 2**(d-1)']

    @classmethod
    def from_dataclass(cls, obj: grove.TreeHeaps):
        """Create a `TreesTrace` from any `bartz.grove.TreeHeaps`.

        Only the attributes of `obj` named like the fields of this class are
        read, any other attribute is ignored.
        """
        kwargs = {}
        for field in fields(cls):
            kwargs[field.name] = getattr(obj, field.name)
        return cls(**kwargs)
@jax.jit
def evaluate_trace(
    trace: Trace, X: UInt[Array, 'p n']
) -> Float32[Array, 'trace_length n']:
    """
    Compute predictions for all iterations of the BART MCMC.

    Parameters
    ----------
    trace
        A trace of the BART MCMC, as returned by `run_mcmc`.
    X
        The predictors matrix, with `p` predictors and `n` observations.

    Returns
    -------
    The predictions for each iteration of the MCMC.
    """
    # keep only the trees, to scan over them together with the offset
    forests = TreesTrace.from_dataclass(trace)

    # evaluate each tree separately, in batches
    # NOTE(review): 2**29 is presumably a cap on the memory used by a batch,
    # confirm against `jaxext.autobatch`
    evaluate_trees = jaxext.autobatch(
        partial(grove.evaluate_forest, sum_trees=False), 2**29, (None, 0)
    )

    def predict_one(carry, item):
        item_offset, item_forest = item
        per_tree = evaluate_trees(X, item_forest)
        # sum over trees, forcing float32 accumulation
        return carry, item_offset + jnp.sum(per_tree, axis=0, dtype=jnp.float32)

    _, predictions = lax.scan(predict_one, None, (trace.offset, forests))
    return predictions
@partial(jax.jit, static_argnums=(0,))
def compute_varcount(
    p: int, trace: grove.TreeHeaps
) -> Int32[Array, 'trace_length {p}']:
    """
    Count how many times each predictor is used in each MCMC state.

    Parameters
    ----------
    p
        The number of predictors.
    trace
        A trace of the BART MCMC, as returned by `run_mcmc`.

    Returns
    -------
    Histogram of predictor usage in each MCMC state.
    """
    # `p` is shared, the trees are mapped over the leading (trace) axis
    return jax.vmap(grove.var_histogram, in_axes=(None, 0, 0))(
        p, trace.var_tree, trace.split_tree
    )