MCMC loop

Functions that implement the full BART posterior MCMC loop.

bartz.mcmcloop.default_onlymain_extractor(state)

Extract variables for the main trace, to be used in run_mcmc.

Return type:

dict[str, Real[Array, 'samples *']]

bartz.mcmcloop.default_both_extractor(state)

Extract variables for main & burn-in traces, to be used in run_mcmc.

Return type:

dict[str, Real[Array, 'samples *'] | None]

bartz.mcmcloop.run_mcmc(key, bart, n_save, *, n_burn=0, n_skip=1, inner_loop_length=None, allow_overflow=False, inner_callback=None, outer_callback=None, callback_state=None, onlymain_extractor=default_onlymain_extractor, both_extractor=default_both_extractor)

Run the MCMC for the BART posterior.

Parameters:
  • key (jax.dtypes.prng_key array) – A key for random number generation.

  • bart (dict) – The initial MCMC state, as created and updated by the functions in bartz.mcmcstep. The MCMC loop uses buffer donation to avoid copies, so this variable is invalidated after running run_mcmc. Make a copy beforehand to use it again.

  • n_save (int) – The number of iterations to save.

  • n_burn (int, default 0) – The number of initial iterations which are not saved.

  • n_skip (int, default 1) – The number of iterations to skip between each saved iteration, plus 1. The effective burn-in is n_burn + n_skip - 1.

  • inner_loop_length (int, optional) – The MCMC loop is split into an outer and an inner loop. The outer loop is in Python, while the inner loop is in JAX. inner_loop_length is the number of iterations of the inner loop to run for each iteration of the outer loop. If not specified, the outer loop will iterate just once, with all iterations done in a single inner loop run. The inner stride is unrelated to the stride used for saving the trace.

  • allow_overflow (bool, default False) – If False, inner_loop_length must be a divisor of the total number of iterations n_burn + n_skip * n_save. If True and inner_loop_length is not a divisor, some of the MCMC iterations in the last outer loop iteration will not be saved to the trace.

  • inner_callback (callable, optional)

  • outer_callback (callable, optional) –

    Arbitrary functions run during the loop after updating the state. inner_callback is called after each update, while outer_callback is called after completing an inner loop. The callbacks are invoked with the following arguments, passed by keyword:

    bart (dict)

    The MCMC state just after updating it.

    burnin (bool)

    Whether the last iteration was in the burn-in phase.

    overflow (bool)

    Whether the last iteration was in the overflow phase (iterations not saved due to inner_loop_length not being a divisor of the total number of iterations).

    i_total (int)

    The index of the last MCMC iteration (0-based).

    i_skip (int)

    The number of MCMC updates from the last saved state. The initial state counts as saved, even if it’s not copied into the trace.

    callback_state (jax pytree)

    The callback state, initially set to the argument passed to run_mcmc, afterwards to the value returned by the last invocation of inner_callback or outer_callback.

    n_burn, n_save, n_skip (int)

    The corresponding arguments as-is.

    i_outer (int)

    The index of the last outer loop iteration (0-based).

    inner_loop_length (int)

    The number of MCMC iterations in the inner loop.

    inner_callback is called inside jit-compiled code (jax.jit), so the argument values are not available at the time the Python code is executed. Use the utilities in jax.debug to access the values at actual runtime.

    The callbacks must return two values:

    bart (dict)

    A possibly modified MCMC state. To avoid modifying the state, return the bart argument passed to the callback as-is.

    callback_state (jax pytree)

    The new state to be passed on the next callback invocation.

    For convenience, if a callback returns None, the states are not updated.

  • callback_state (jax pytree, optional) – The initial state for the callbacks.

  • onlymain_extractor (callable, optional)

  • both_extractor (callable, optional) – Functions that extract the variables to be saved respectively only in the main trace and in both traces, given the MCMC state as argument. Must return a pytree, and must be vmappable.
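As an illustration of the callback protocol described above, here is a minimal sketch of an outer_callback that keeps a running tally in callback_state. The function name tally_outer_callback, the keys of the state dictionary, and the hand-made argument values are assumptions made for this example; in an actual run the arguments are supplied by run_mcmc.

```python
def tally_outer_callback(*, bart, burnin, i_total, callback_state, **_unused):
    """Hypothetical outer_callback: count completed inner loops per phase."""
    state = dict(callback_state)  # copy, so the previous state is left untouched
    phase = 'burnin' if burnin else 'main'
    state[phase] += 1
    state['last_i_total'] = i_total
    # Return the MCMC state unchanged, plus the updated callback state.
    return bart, state


# Placeholder values for the remaining documented keyword arguments.
placeholders = dict(overflow=False, i_skip=0, n_burn=50, n_save=100,
                    n_skip=1, inner_loop_length=50)

# Stand-in for the MCMC state; a real one is created by bartz.mcmcstep.
fake_bart = {'sigma2': 1.0}
state = {'burnin': 0, 'main': 0, 'last_i_total': -1}

# Simulate the invocations after the first two inner loops of 50 iterations.
fake_bart, state = tally_outer_callback(
    bart=fake_bart, burnin=True, i_total=49, i_outer=0,
    callback_state=state, **placeholders)
fake_bart, state = tally_outer_callback(
    bart=fake_bart, burnin=False, i_total=99, i_outer=1,
    callback_state=state, **placeholders)
print(state)
```

Such a function would be passed as outer_callback=tally_outer_callback, together with an initial callback_state of the same structure. An inner_callback follows the same protocol but runs under jit, so plain Python branching and printing on its arguments would not work there.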

Returns:

  • bart (dict) – The final MCMC state.

  • burnin_trace (dict of (n_burn, …) arrays) – The trace of the burn-in phase, containing the following subset of fields from the bart dictionary, with an additional leading axis that runs over MCMC iterations: ‘sigma2’, ‘grow_prop_count’, ‘grow_acc_count’, ‘prune_prop_count’, ‘prune_acc_count’ (or, if specified, the fields returned by both_extractor).

  • main_trace (dict of (n_save, …) arrays) – The trace of the main phase, containing the following subset of fields from the bart dictionary, with an additional leading axis that runs over MCMC iterations: ‘leaf_trees’, ‘var_trees’, ‘split_trees’ (or, if specified, the fields returned by onlymain_extractor), plus the fields in burnin_trace.

Raises:

ValueError – If inner_loop_length is not a divisor of the total number of iterations and allow_overflow is False.
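The divisibility rule can be checked before launching a long run. The following is a sketch that re-implements the documented rule in plain Python; plan_outer_loop is a hypothetical helper, not part of bartz, and the overflow count reflects one reading of the allow_overflow description (the last inner loop is padded up to a full inner_loop_length).

```python
def plan_outer_loop(n_burn, n_save, n_skip=1, inner_loop_length=None,
                    allow_overflow=False):
    """Return (n_outer, n_overflow) following the documented splitting rule."""
    n_total = n_burn + n_skip * n_save  # total number of MCMC updates
    if inner_loop_length is None:
        # Documented default: a single outer iteration does everything.
        return 1, 0
    n_outer, remainder = divmod(n_total, inner_loop_length)
    if remainder == 0:
        return n_outer, 0
    if not allow_overflow:
        raise ValueError(
            f'inner_loop_length={inner_loop_length} does not divide '
            f'n_burn + n_skip * n_save = {n_total}')
    # One extra outer iteration; its trailing updates are not saved.
    return n_outer + 1, inner_loop_length - remainder


print(plan_outer_loop(100, 1000, inner_loop_length=100))
print(plan_outer_loop(100, 1000, inner_loop_length=300, allow_overflow=True))
```

With n_burn=100 and n_save=1000 there are 1100 updates, so an inner loop of 100 divides evenly, while an inner loop of 300 needs a fourth outer iteration whose last 100 updates fall in the overflow phase.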

Notes

The number of MCMC updates is n_burn + n_skip * n_save. The traces do not include the initial state, and include the final state.
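To make the bookkeeping concrete, the sketch below lists which 0-based iteration indices would end up in the main trace. The formula is inferred from the statements above (effective burn-in of n_burn + n_skip - 1, final state included) rather than taken from the bartz source, and both helper names are hypothetical.

```python
def total_updates(n_burn, n_save, n_skip=1):
    """Total number of MCMC updates, as stated in the notes."""
    return n_burn + n_skip * n_save


def saved_iteration_indices(n_burn, n_save, n_skip=1):
    """0-based indices of the iterations stored in the main trace (inferred)."""
    return [n_burn + n_skip * (k + 1) - 1 for k in range(n_save)]


indices = saved_iteration_indices(n_burn=2, n_save=3, n_skip=2)
print(indices)
# The last saved index is the final update, so the final state is in the trace.
print(indices[-1] == total_updates(2, 3, 2) - 1)
```

In this example the first saved iteration has index 3, which means 3 = n_burn + n_skip - 1 updates precede it, matching the effective burn-in given for n_skip.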

bartz.mcmcloop.make_print_callbacks(dot_every_inner=1, report_every_outer=1)

Prepare logging callbacks for run_mcmc.

Prepare callbacks which print a dot on every iteration, and a longer report on every outer loop iteration.

Parameters:
  • dot_every_inner (int, default 1) – A dot is printed every dot_every_inner MCMC iterations.

  • report_every_outer (int, default 1) – A report is printed every report_every_outer outer loop iterations.

Returns:

kwargs (dict) – A dictionary with the arguments to pass to run_mcmc as keyword arguments to set up the callbacks.

Examples

>>> run_mcmc(..., **make_print_callbacks())

bartz.mcmcloop.evaluate_trace(trace, X)

Compute predictions for all iterations of the BART MCMC.

Parameters:
  • trace (dict) – A trace of the BART MCMC, as returned by run_mcmc.

  • X (array (p, n)) – The predictors matrix, with p predictors and n observations.

Returns:

y (array (n_trace, n)) – The predictions for each iteration of the MCMC.
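The shapes involved can be illustrated without running the sampler. In this sketch, nested Python lists stand in for arrays and the numbers are made up; it only shows how a conventional (n, p) table of observations maps to the (p, n) layout expected for X, and how an (n_trace, n) prediction array reduces to one posterior mean per observation.

```python
# Conventional layout: one row per observation, one column per predictor.
rows = [[0.1, 1.0],
        [0.2, 2.0],
        [0.3, 3.0]]  # shape (n, p) = (3, 2)

# Transpose to the documented (p, n) layout: one row per predictor.
X = [list(column) for column in zip(*rows)]

# Made-up stand-in for the output of evaluate_trace, shape (n_trace, n) = (2, 3).
y = [[1.0, 2.0, 3.0],
     [3.0, 4.0, 5.0]]

# Average over MCMC iterations (the leading axis) for each observation.
y_mean = [sum(column) / len(column) for column in zip(*y)]

print(X)
print(y_mean)
```

With array types the same two steps would be a transpose and a mean over the first axis.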