MCMC loop
Functions that implement the full BART posterior MCMC loop.
- bartz.mcmcloop.run_mcmc(bart, n_burn, n_save, n_skip, callback, key)
Run the MCMC for the BART posterior.
- Parameters:
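- bart : dict
The initial MCMC state, as created and updated by the functions in bartz.mcmcstep.
- n_burn : int
The number of initial (burn-in) iterations. Their states are not saved in main_trace; only the summary fields listed under burnin_trace are recorded for them.
- n_save : int
The number of iterations to save.
- n_skip : int
The thinning interval: the number of iterations to skip between consecutive saved iterations, plus 1. With n_skip = 1, every iteration after the burn-in is saved.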
- callback : callable
An arbitrary function run at each iteration, called with the following arguments, passed by keyword:
- bart : dict
The current MCMC state.
- burnin : bool
Whether the last iteration was in the burn-in phase.
- i_total : int
The index of the last iteration (0-based).
- i_skip : int
The index of the last iteration, counted from the last saved iteration.
- n_burn, n_save, n_skip : int
The corresponding arguments of run_mcmc, as-is.
Since the callback is called under jax.jit, the values of its arguments are not available when the Python code is executed. Use the utilities in jax.debug, such as jax.debug.print or jax.debug.callback, to access the values at run time.
- key : jax.dtypes.prng_key array
The key for random number generation.
- Returns:
- bart : dict
The final MCMC state.
- burnin_trace : dict
The trace of the burn-in phase, containing the following subset of fields from the bart dictionary, each with an additional leading axis that runs over the n_burn burn-in iterations: ‘sigma2’, ‘grow_prop_count’, ‘grow_acc_count’, ‘prune_prop_count’, ‘prune_acc_count’.
- main_trace : dict
The trace of the main phase, containing the following subset of fields from the bart dictionary, each with an additional leading axis that runs over the n_save saved iterations: ‘leaf_trees’, ‘var_trees’, ‘split_trees’, plus the fields in burnin_trace.
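To make the interplay of n_burn, n_save and n_skip concrete, the pure-Python sketch below computes how many iterations a run performs and which 0-based iteration indices (i_total) end up in main_trace. It assumes that the burn-in comes first and that each saved draw is the last of a block of n_skip consecutive main-phase iterations; mcmc_schedule is a hypothetical helper, not part of bartz.

```python
def mcmc_schedule(n_burn, n_save, n_skip):
    """Hypothetical sketch of the iteration bookkeeping behind run_mcmc's arguments.

    Assumption (not taken from the bartz source): the n_burn burn-in
    iterations run first, then each saved draw is the last iteration of a
    block of n_skip consecutive main-phase iterations.
    """
    if n_skip < 1:
        raise ValueError("n_skip must be at least 1 (1 = keep every draw)")
    total = n_burn + n_save * n_skip
    # 0-based i_total of each iteration whose state goes into main_trace.
    saved = [n_burn + k * n_skip + (n_skip - 1) for k in range(n_save)]
    return {
        "total_iterations": total,
        "saved_indices": saved,
        # Leading-axis lengths of the two returned traces.
        "burnin_trace_len": n_burn,
        "main_trace_len": n_save,
    }


sched = mcmc_schedule(n_burn=3, n_save=4, n_skip=2)
print(sched["total_iterations"])  # 11
print(sched["saved_indices"])     # [4, 6, 8, 10]
```

Under this convention the last saved draw is always the final iteration of the run, and doubling n_skip doubles the cost of the main phase without changing the size of main_trace.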
- bartz.mcmcloop.make_simple_print_callback(printevery)
Create a logging callback function for MCMC iterations.
- Parameters:
- printevery : int
The number of iterations between each log.
- Returns:
- callback : callable
A function in the format required by the callback argument of run_mcmc.
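Because the callback runs under jax.jit, Python-level branching on i_total or burnin would fail on tracers. The sketch below shows one way to honor the callback contract described for run_mcmc: it defers the work to run time with jax.debug.callback, which receives concrete NumPy values. make_logging_callback is a hypothetical factory in the spirit of make_simple_print_callback, not its actual implementation; it appends to a Python list instead of printing and reads only bart["sigma2"], one of the fields recorded in the traces.

```python
import jax
import jax.numpy as jnp


def make_logging_callback(printevery, log):
    """Hypothetical callback factory, a sketch not taken from bartz.

    Appends one line to the Python list `log` every `printevery` iterations.
    """

    def _host_log(i_total, burnin, sigma2):
        # Runs on the host with concrete NumPy values, outside the jit trace.
        if int(i_total) % printevery == 0:
            phase = "burn-in" if bool(burnin) else "main"
            log.append(f"iter {int(i_total)} ({phase}): sigma2={float(sigma2):.3g}")

    def callback(*, bart, burnin, i_total, i_skip, n_burn, n_save, n_skip):
        # The arguments may be tracers here, so hand them to a host callback
        # that executes when the compiled code actually runs.
        jax.debug.callback(
            _host_log, jnp.asarray(i_total), jnp.asarray(burnin), bart["sigma2"]
        )

    return callback
```

Passing make_logging_callback(100, records) as the callback argument of run_mcmc would, under these assumptions, append one line to records every 100 iterations. Collecting into a list rather than printing makes the output easy to inspect after the run; the same pattern works with print inside _host_log.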
- bartz.mcmcloop.evaluate_trace(trace, X)
Compute predictions for all iterations of the BART MCMC.
- Parameters:
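- trace : dict
A trace of the BART MCMC, as returned by run_mcmc. It must contain the tree fields ‘leaf_trees’, ‘var_trees’ and ‘split_trees’, i.e., it is the main_trace (burnin_trace does not contain the trees).
- X : array (p, n)
The predictors matrix, with p predictors and n observations. Note that observations run along the second axis.
- Returns:
- y : array (n_trace, n)
The predictions for each iteration in the trace, where n_trace is the length of the trace's leading axis.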
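Since the first axis of y runs over MCMC iterations, per-observation posterior summaries are reductions over axis 0. The NumPy sketch below computes a posterior mean and an equal-tailed interval for each observation; summarize_predictions is a hypothetical helper, and in practice y would come from a call such as evaluate_trace(main_trace, X), with X laid out as (p, n).

```python
import numpy as np


def summarize_predictions(y, level=0.95):
    """Hypothetical helper: per-observation summaries of an (n_trace, n) array.

    Each row of y is assumed to hold the predictions of one saved MCMC
    iteration, each column one observation.
    """
    y = np.asarray(y)
    if y.ndim != 2:
        raise ValueError("expected a 2-D array of shape (n_trace, n)")
    tail = (1.0 - level) / 2.0
    mean = y.mean(axis=0)  # posterior mean, shape (n,)
    # Equal-tailed interval from the empirical quantiles of the draws.
    lower, upper = np.quantile(y, [tail, 1.0 - tail], axis=0)
    return mean, lower, upper
```

Reducing over the iteration axis keeps one value per observation, so the three returned arrays all have shape (n,).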