MCMC loop

Functions that implement the full BART posterior MCMC loop.

bartz.mcmcloop.run_mcmc(bart, n_burn, n_save, n_skip, callback, key)

Run the MCMC for the BART posterior.

Parameters:
bart : dict

The initial MCMC state, as created and updated by the functions in bartz.mcmcstep.

n_burn : int

The number of initial iterations which are not saved.

n_save : int

The number of iterations to save.

n_skip : int

The thinning interval: the number of iterations skipped between consecutive saved iterations, plus 1, so that one iteration out of every n_skip is saved.

callback : callable

An arbitrary function run at each iteration, called with the following arguments, passed by keyword:

    bart : dict

    The current MCMC state.

    burnin : bool

    Whether the last iteration was in the burn-in phase.

    i_total : int

    The index of the last iteration (0-based).

    i_skip : int

    The index of the last iteration, starting from the last saved iteration.

    n_burn, n_save, n_skip : int

    The corresponding arguments as-is.

Since the callback is called under jax.jit, its arguments are abstract tracers while the Python code is executed, so their values are not available at that time. Use the utilities in jax.debug to access the values at actual runtime.

key : jax.dtypes.prng_key array

The key for random number generation.
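The callback contract described above can be sketched as follows. This is a minimal illustration rather than code from bartz: `record_callback`, `_store`, `records`, and `fake_iteration` are hypothetical names, and the state dictionary is a stand-in that holds only a scalar 'sigma2' entry. The sketch uses jax.debug.callback to hand the traced values to a host-side function at runtime.

```python
import jax
import jax.numpy as jnp

records = []  # filled on the host when the compiled code actually runs


def _store(i_total, burnin, sigma2):
    # Host-side function: receives concrete values, not tracers
    records.append((int(i_total), bool(burnin), float(sigma2)))


def record_callback(*, bart, burnin, i_total, i_skip, n_burn, n_save, n_skip):
    # Keyword-only signature matching the arguments listed above.
    # Under jit the values are tracers here, so defer the work to runtime.
    jax.debug.callback(_store, i_total, burnin, bart['sigma2'])


@jax.jit
def fake_iteration(i_total):
    # Stand-in for one MCMC iteration (assumption: not the real bartz step)
    bart = {'sigma2': jnp.float32(1.5)}
    record_callback(bart=bart, burnin=False, i_total=i_total, i_skip=0,
                    n_burn=0, n_save=1, n_skip=1)
    return i_total + 1


fake_iteration(jnp.int32(3))
jax.effects_barrier()  # wait until pending host callbacks have run
```

Reading `bart['sigma2']` directly inside `record_callback`, for example with `float(...)`, would fail under jit, which is why the values go through jax.debug.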

Returns:
bart : dict

The final MCMC state.

burnin_trace : dict

The trace of the burn-in phase, containing the following subset of fields from the bart dictionary, each with an additional leading axis that indexes MCMC iterations: ‘sigma2’, ‘grow_prop_count’, ‘grow_acc_count’, ‘prune_prop_count’, ‘prune_acc_count’.

main_trace : dict

The trace of the main phase, containing the following subset of fields from the bart dictionary, each with an additional leading axis that indexes MCMC iterations: ‘leaf_trees’, ‘var_trees’, ‘split_trees’, plus the fields in burnin_trace.
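The relation between n_burn, n_save, and n_skip can be made concrete with a small calculation. This is a sketch under assumptions: the helper `iteration_schedule` is hypothetical, the total of n_burn + n_save * n_skip iterations follows from the parameter descriptions, and the choice that each saved iteration is the last one of its block of n_skip iterations is an assumption not stated in this page.

```python
def iteration_schedule(n_burn, n_save, n_skip):
    """Return the total iteration count and the 0-based indices of saved iterations.

    Assumption: the saved iteration is the last of each block of n_skip.
    """
    n_total = n_burn + n_save * n_skip
    saved = [n_burn + (k + 1) * n_skip - 1 for k in range(n_save)]
    return n_total, saved


n_total, saved = iteration_schedule(n_burn=100, n_save=5, n_skip=10)
print(n_total)  # 150
print(saved)    # [109, 119, 129, 139, 149]
```

Under the same reading, the leading axis of each array in burnin_trace would have length n_burn, and the one in main_trace would have length n_save.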

bartz.mcmcloop.make_simple_print_callback(printevery)

Create a logging callback function for MCMC iterations.

Parameters:
printevery : int

The number of iterations between consecutive log messages.

Returns:
callback : callable

A function in the format required by run_mcmc.
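To illustrate what a callback factory of this kind might look like, here is a minimal sketch. It is not the implementation of make_simple_print_callback: `make_every_k_callback`, `_report`, `sink`, and `fake_loop` are hypothetical names, the message format is invented for the example, and the loop is a stand-in for the MCMC. The throttling test runs on the host, where the iteration index is a concrete value.

```python
import jax
from jax import lax


def make_every_k_callback(printevery, sink):
    """Build a callback that reports once every `printevery` iterations."""

    def _report(burnin, i_total, n_total):
        # Host side: values are concrete, so plain Python control flow works
        if (int(i_total) + 1) % printevery == 0:
            phase = 'burn-in' if bool(burnin) else 'main'
            sink(f'iteration {int(i_total) + 1}/{int(n_total)} ({phase})')

    def callback(*, bart, burnin, i_total, i_skip, n_burn, n_save, n_skip):
        # Assumed total number of iterations of the whole run
        n_total = n_burn + n_save * n_skip
        jax.debug.callback(_report, burnin, i_total, n_total)

    return callback


messages = []
callback = make_every_k_callback(2, messages.append)


@jax.jit
def fake_loop():
    # Stand-in for the MCMC loop: 6 iterations, all in the main phase
    def body(i, carry):
        callback(bart={}, burnin=False, i_total=i, i_skip=0,
                 n_burn=0, n_save=6, n_skip=1)
        return carry
    return lax.fori_loop(0, 6, body, 0)


fake_loop()
jax.effects_barrier()  # wait until pending host callbacks have run
```

Passing `print` as `sink` would turn this into a simple console logger. The callbacks are unordered, so messages may not arrive strictly in iteration order.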

bartz.mcmcloop.evaluate_trace(trace, X)

Compute predictions for all iterations of the BART MCMC.

Parameters:
trace : dict

A trace of the BART MCMC, as returned by run_mcmc.

X : array (p, n)

The predictors matrix, with p predictors and n observations.

Returns:
y : array (n_trace, n)

The predictions for each iteration of the MCMC.
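The shape conventions above are easy to get wrong, because X has predictors along the first axis rather than observations. The following sketch only illustrates the array layouts and a typical way to summarize the output. It does not call bartz: the array `y` is random stand-in data with the documented shape, and all variable names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
n, p, n_trace = 8, 3, 5

# Conventional design matrix: one row per observation, shape (n, p)
X_rows = rng.normal(size=(n, p))

# Layout documented for evaluate_trace: shape (p, n)
X = X_rows.T

# Stand-in for the (n_trace, n) predictions; not real model output
y = rng.normal(size=(n_trace, n))

# Summaries over the MCMC axis (axis 0) give one value per observation
post_mean = y.mean(axis=0)
lower, upper = np.quantile(y, [0.025, 0.975], axis=0)
```

Averaging over axis 0 yields a point prediction per observation, and the quantiles give a pointwise 95% interval across the saved iterations.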