MCMC loop

Functions that implement the full BART posterior MCMC loop.

The main entry point is run_mcmc.

class bartz.mcmcloop.BurninTrace(sigma2, grow_prop_count, grow_acc_count, prune_prop_count, prune_acc_count, log_likelihood, log_trans_prior)[source]

MCMC trace with only diagnostic values.

classmethod from_state(state)[source]

Create a single-item burn-in trace from an MCMC state.

Return type:

BurninTrace

class bartz.mcmcloop.MainTrace(sigma2, grow_prop_count, grow_acc_count, prune_prop_count, prune_acc_count, log_likelihood, log_trans_prior, leaf_tree, var_tree, split_tree, offset)[source]

MCMC trace with trees and diagnostic values.

classmethod from_state(state)[source]

Create a single-item main trace from an MCMC state.

Return type:

MainTrace

class bartz.mcmcloop.Callback(*args, **kwargs)[source]

Callback type for run_mcmc.

__call__(*, bart, burnin, i_total, i_skip, callback_state, n_burn, n_save, n_skip, i_outer, inner_loop_length)[source]

Do an arbitrary action after an iteration of the MCMC.

Parameters:
  • bart (State) – The MCMC state just after updating it.

  • burnin (Bool[Array, '']) – Whether the last iteration was in the burn-in phase.

  • i_total (Int32[Array, '']) – The index of the last MCMC iteration (0-based).

  • i_skip (Int32[Array, '']) – The number of MCMC updates from the last saved state. The initial state counts as saved, even if it’s not copied into the trace.

  • callback_state (PyTree[Any, 'T']) – The callback state, initially set to the argument passed to run_mcmc, afterwards to the value returned by the last invocation of the callback.

  • n_burn (Int32[Array, ''])

  • n_save (Int32[Array, ''])

  • n_skip (Int32[Array, '']) – The corresponding run_mcmc arguments as-is.

  • i_outer (Int32[Array, '']) – The index of the last outer loop iteration (0-based).

  • inner_loop_length (int) – The number of MCMC iterations in the inner loop.

Returns:

  • bart (State) – A possibly modified MCMC state. To avoid modifying the state, return the bart argument passed to the callback as-is.

  • callback_state (CallbackState) – The new state to be passed on the next callback invocation.

Notes

For convenience, the callback may return None, and the states won’t be updated.
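As a minimal sketch of the calling convention described above, the following hypothetical callback counts its own invocations. The name counting_callback, the use of **_unused to swallow the arguments it does not need, and the placeholder dictionary standing in for a State are assumptions made for illustration; only the keyword names come from the signature documented here.

```python
def counting_callback(*, bart, callback_state, **_unused):
    """Count callback invocations without touching the MCMC state."""
    # Return the MCMC state as-is (so it is not modified) together with
    # an incremented counter, which becomes the next callback state.
    return bart, callback_state + 1


# Stand-alone check with placeholder values; no MCMC is run here.
state = {"sigma2": 1.0}  # hypothetical stand-in for a bartz State
new_state, count = counting_callback(
    bart=state, burnin=False, i_total=0, i_skip=1, callback_state=0,
    n_burn=0, n_save=1, n_skip=1, i_outer=0, inner_loop_length=1,
)
```

Such a function would presumably be passed as run_mcmc(..., callback=counting_callback, callback_state=0). Because the callback runs under the JAX jit, the counter would be a traced value inside the loop rather than a Python integer.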

bartz.mcmcloop.run_mcmc(key, bart, n_save, *, n_burn=0, n_skip=1, inner_loop_length=None, callback=None, callback_state=None, burnin_extractor=BurninTrace.from_state, main_extractor=MainTrace.from_state)[source]

Run the MCMC for the BART posterior.

Parameters:
  • key (Key[Array, '']) – A key for random number generation.

  • bart (State) – The initial MCMC state, as created and updated by the functions in bartz.mcmcstep. The MCMC loop uses buffer donation to avoid copies, so this variable is invalidated after running run_mcmc. Make a copy beforehand to use it again.

  • n_save (int) – The number of iterations to save.

  • n_burn (int, default: 0) – The number of initial iterations which are not saved.

  • n_skip (int, default: 1) – The number of iterations to skip between each saved iteration, plus 1. The effective burn-in is n_burn + n_skip - 1.

  • inner_loop_length (int | None, default: None) – The MCMC loop is split into an outer and an inner loop. The outer loop is in Python, while the inner loop is in JAX. inner_loop_length is the number of iterations of the inner loop to run for each iteration of the outer loop. If not specified, the outer loop will iterate just once, with all iterations done in a single inner loop run. The inner stride is unrelated to the stride used for saving the trace.

  • callback (Callback | None, default: None) – An arbitrary function run during the loop after updating the state. For the signature, see Callback. The callback is called under the jax jit, so the argument values are not available at the time the Python code is executed. Use the utilities in jax.debug to access the values at actual runtime. The callback may return new values for the MCMC state and the callback state.

  • callback_state (PyTree[Any, 'T'], default: None) – The initial custom state for the callback.

  • burnin_extractor (Callable[[State], PyTree], default: BurninTrace.from_state)

  • main_extractor (Callable[[State], PyTree], default: MainTrace.from_state) – Functions that extract the variables to be saved in the burn-in trace and in the main trace, respectively, given the MCMC state as argument. Each must return a pytree and must be vmappable.

Returns:

  • bart (State) – The final MCMC state.

  • burnin_trace (PyTree[Shaped[Array, 'n_burn *']]) – The trace of the burn-in phase. For the default layout, see BurninTrace.

  • main_trace (PyTree[Shaped[Array, 'n_save *']]) – The trace of the main phase. For the default layout, see MainTrace.

Notes

The number of MCMC updates is n_burn + n_skip * n_save. The traces do not include the initial state, and include the final state.
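The arithmetic in this note can be made concrete with a small pure-Python sketch. The helper mcmc_schedule is hypothetical and is not part of bartz; it only restates the formulas given here (total updates n_burn + n_skip * n_save, effective burn-in n_burn + n_skip - 1, final state included) and does not reflect how the loop is actually implemented.

```python
def mcmc_schedule(n_save, n_burn=0, n_skip=1):
    """Return the total number of updates and the saved main-trace indices.

    Indices are 0-based MCMC iteration indices, matching i_total.
    """
    total = n_burn + n_skip * n_save
    # The first saved iteration comes after n_burn + n_skip - 1 discarded
    # updates; afterwards one iteration every n_skip is saved.
    saved = [n_burn + n_skip * (k + 1) - 1 for k in range(n_save)]
    return total, saved


total, saved = mcmc_schedule(n_save=3, n_burn=10, n_skip=2)
print(total)  # 16
print(saved)  # [11, 13, 15]
```

The last saved index is total - 1, which agrees with the statement that the main trace includes the final state.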

bartz.mcmcloop.pytree_at_set(dest, index, val)[source]

Map dest.at[index].set(val) over pytrees.

Return type:

PyTree
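To illustrate what mapping dest.at[index].set(val) over pytrees means, here is a rough equivalent written with jax.tree_util.tree_map. The function name pytree_at_set_sketch is hypothetical, and the sketch assumes that dest and val share the same tree structure; the actual implementation may differ.

```python
import jax
import jax.numpy as jnp


def pytree_at_set_sketch(dest, index, val):
    """Apply d.at[index].set(v) to each pair of matching leaves."""
    return jax.tree_util.tree_map(lambda d, v: d.at[index].set(v), dest, val)


# Write one "trace entry" into slot 1 of preallocated buffers.
dest = {"a": jnp.zeros(3), "b": jnp.zeros((3, 2))}
val = {"a": jnp.float32(1.0), "b": jnp.ones(2)}
out = pytree_at_set_sketch(dest, 1, val)
```

This is the pattern one would expect for storing a single-item trace into a preallocated trace buffer at a given position.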

bartz.mcmcloop.make_print_callback(dot_every=1, report_every=100)[source]

Prepare a logging callback for run_mcmc.

The callback prints a dot every dot_every MCMC iterations and a longer one-line report every report_every MCMC iterations.

Parameters:
  • dot_every (int | Integer[Array, ''] | None, default: 1) – A dot is printed every dot_every MCMC iterations, None to disable.

  • report_every (int | Integer[Array, ''] | None, default: 100) – A one line report is printed every report_every MCMC iterations, None to disable.

Returns:

dict[str, Any] – A dictionary with the arguments to pass to run_mcmc as keyword arguments to set up the callback.

Examples

>>> run_mcmc(..., **make_print_callback())

class bartz.mcmcloop.Trace(*args, **kwargs)[source]

Protocol for an MCMC trace.

class bartz.mcmcloop.TreesTrace(leaf_tree, var_tree, split_tree)[source]

Implementation of bartz.grove.TreeHeaps for an MCMC trace.

classmethod from_dataclass(obj)[source]

Create a TreesTrace from any bartz.grove.TreeHeaps.

bartz.mcmcloop.evaluate_trace(trace, X)[source]

Compute predictions for all iterations of the BART MCMC.

Parameters:
  • trace (Trace) – A trace of the BART MCMC, as returned by run_mcmc.

  • X (UInt[Array, 'p n']) – The predictors matrix, with p predictors and n observations.

Returns:

Float32[Array, 'trace_length n'] – The predictions for each iteration of the MCMC.

bartz.mcmcloop.compute_varcount(p, trace)[source]

Count how many times each predictor is used in each MCMC state.

Parameters:
  • p (int) – The number of predictors.

  • trace (TreeHeaps) – A trace of the BART MCMC, as returned by run_mcmc.

Returns:

Int32[Array, 'trace_length {p}'] – Histogram of predictor usage in each MCMC state.