MCMC loop
Functions that implement the full BART posterior MCMC loop.
- bartz.mcmcloop.default_onlymain_extractor(state)
Extract variables for the main trace, to be used in run_mcmc.
- Return type:
dict[str, Real[Array, 'samples *']]
- bartz.mcmcloop.default_both_extractor(state)
Extract variables for main & burn-in traces, to be used in run_mcmc.
- Return type:
dict[str, Real[Array, 'samples *'] | None]
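A custom extractor with the same interface can be supplied to run_mcmc in place of the defaults. The sketch below is a hypothetical illustration, not part of bartz: the name sigma2_only_extractor is made up, and it assumes the MCMC state behaves like a dict with a 'sigma2' entry. The state used to exercise it is a stand-in, not a real bartz state.

```python
def sigma2_only_extractor(state):
    """Hypothetical extractor: keep only the 'sigma2' field of the MCMC state."""
    # Assumes `state` is a dict-like MCMC state with a 'sigma2' entry.
    # Returning a dict keeps the result a pytree.
    return {'sigma2': state['sigma2']}


# Stand-in state for illustration only; a real state comes from bartz.mcmcstep.
fake_state = {'sigma2': 1.5, 'leaf_trees': [0.1, -0.2]}
extracted = sigma2_only_extractor(fake_state)
```

Because the function only indexes into the state and builds a dict, it should remain compatible with the requirement that extractors be vmappable.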
- bartz.mcmcloop.run_mcmc(key, bart, n_save, *, n_burn=0, n_skip=1, inner_loop_length=None, allow_overflow=False, inner_callback=None, outer_callback=None, callback_state=None, onlymain_extractor=default_onlymain_extractor, both_extractor=default_both_extractor)
Run the MCMC for the BART posterior.
- Parameters:
key (jax.dtypes.prng_key array) – A key for random number generation.
bart (dict) – The initial MCMC state, as created and updated by the functions in bartz.mcmcstep. The MCMC loop uses buffer donation to avoid copies, so this variable is invalidated after running run_mcmc. Make a copy beforehand to use it again.
n_save (int) – The number of iterations to save.
n_burn (int, default 0) – The number of initial iterations which are not saved.
n_skip (int, default 1) – The number of iterations to skip between each saved iteration, plus 1. The effective burn-in is n_burn + n_skip - 1.
inner_loop_length (int, optional) – The MCMC loop is split into an outer and an inner loop. The outer loop is in Python, while the inner loop is in JAX. inner_loop_length is the number of iterations of the inner loop to run for each iteration of the outer loop. If not specified, the outer loop will iterate just once, with all iterations done in a single inner loop run. The inner stride is unrelated to the stride used for saving the trace.
allow_overflow (bool, default False) – If False, inner_loop_length must be a divisor of the total number of iterations n_burn + n_skip * n_save. If True and inner_loop_length is not a divisor, some of the MCMC iterations in the last outer loop iteration will not be saved to the trace.
inner_callback (callable, optional)
outer_callback (callable, optional) – Arbitrary functions run during the loop after updating the state. inner_callback is called after each update, while outer_callback is called after completing an inner loop. The callbacks are invoked with the following arguments, passed by keyword:
- bart (dict) – The MCMC state just after updating it.
- burnin (bool) – Whether the last iteration was in the burn-in phase.
- overflow (bool) – Whether the last iteration was in the overflow phase (iterations not saved due to inner_loop_length not being a divisor of the total number of iterations).
- i_total (int) – The index of the last MCMC iteration (0-based).
- i_skip (int) – The number of MCMC updates from the last saved state. The initial state counts as saved, even if it’s not copied into the trace.
- callback_state (jax pytree) – The callback state, initially set to the argument passed to run_mcmc, afterwards to the value returned by the last invocation of inner_callback or outer_callback.
- n_burn, n_save, n_skip (int) – The corresponding arguments as-is.
- i_outer (int) – The index of the last outer loop iteration (0-based).
- inner_loop_length (int) – The number of MCMC iterations in the inner loop.
inner_callback is called under the jax jit, so the argument values are not available at the time the Python code is executed. Use the utilities in jax.debug to access the values at actual runtime.
The callbacks must return two values:
- bart (dict) – A possibly modified MCMC state. To avoid modifying the state, return the bart argument passed to the callback as-is.
- callback_state (jax pytree) – The new state to be passed on the next callback invocation.
For convenience, if a callback returns None, the states are not updated.
callback_state (jax pytree, optional) – The initial state for the callbacks.
onlymain_extractor (callable, optional)
both_extractor (callable, optional) – Functions that extract the variables to be saved respectively only in the main trace and in both traces, given the MCMC state as argument. Must return a pytree, and must be vmappable.
- Returns:
bart (dict) – The final MCMC state.
burnin_trace (dict of (n_burn, …) arrays) – The trace of the burn-in phase, containing the following subset of fields from the bart dictionary, with an additional leading index that runs over MCMC iterations: ‘sigma2’, ‘grow_prop_count’, ‘grow_acc_count’, ‘prune_prop_count’, ‘prune_acc_count’ (or, if specified, the fields returned by both_extractor).
main_trace (dict of (n_save, …) arrays) – The trace of the main phase, containing the following subset of fields from the bart dictionary, with an additional leading index that runs over MCMC iterations: ‘leaf_trees’, ‘var_trees’, ‘split_trees’ (or, if specified, the fields returned by onlymain_extractor), plus the fields in burnin_trace.
- Raises:
ValueError – If inner_loop_length is not a divisor of the total number of iterations and allow_overflow is False.
Notes
The number of MCMC updates is n_burn + n_skip * n_save. The traces do not include the initial state, and include the final state.
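To make the bookkeeping concrete, the sketch below lists which 0-based iterations would land in each trace. It is inferred from the statements in this entry (the burn-in trace has n_burn entries, the final state is saved, and the effective burn-in is n_burn + n_skip - 1), not taken from the bartz source; the function name is hypothetical.

```python
def saved_iteration_indices(n_burn, n_save, n_skip=1):
    """Hypothetical helper: 0-based indices of the iterations kept in each trace."""
    # Every burn-in iteration goes into the burn-in trace.
    burnin = list(range(n_burn))
    # After burn-in, one iteration out of every n_skip is kept, ending on the
    # very last update so that the final state is included.
    main = [n_burn + n_skip * (k + 1) - 1 for k in range(n_save)]
    return burnin, main


burnin_idx, main_idx = saved_iteration_indices(n_burn=2, n_save=3, n_skip=2)
n_total = 2 + 2 * 3  # n_burn + n_skip * n_save
```

Under these assumptions the first saved main iteration has index n_burn + n_skip - 1, which matches the effective burn-in quoted for n_skip, and the last saved index is n_total - 1.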
- bartz.mcmcloop.make_print_callbacks(dot_every_inner=1, report_every_outer=1)
Prepare logging callbacks for run_mcmc.
Prepare callbacks which print a dot on every iteration, and a longer report on every outer loop iteration.
- Parameters:
dot_every_inner (int, default 1) – A dot is printed every dot_every_inner MCMC iterations.
report_every_outer (int, default 1) – A report is printed every report_every_outer outer loop iterations.
- Returns:
kwargs (dict) – A dictionary with the arguments to pass to run_mcmc as keyword arguments to set up the callbacks.
Examples
>>> run_mcmc(..., **make_print_callbacks())
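When the printing callbacks are not enough, a hand-written callback can follow the keyword-argument contract documented for run_mcmc. The sketch below is a hypothetical outer_callback that counts completed outer loop iterations; the function name, the use of a plain integer as callback_state, and the stand-in argument values are assumptions for illustration only.

```python
def counting_outer_callback(*, bart, callback_state, **kwargs):
    """Hypothetical outer_callback: count completed outer loop iterations."""
    # All arguments arrive by keyword; the ones not needed here (burnin,
    # overflow, i_total, ...) are collected in **kwargs and ignored.
    # Return the state unchanged together with the updated callback state.
    return bart, callback_state + 1


# Stand-in invocation with arbitrary values; in practice run_mcmc would make
# this call after each inner loop completes.
state = {'sigma2': 1.0}
new_state, count = counting_outer_callback(
    bart=state, callback_state=0, burnin=True, overflow=False,
    i_total=9, i_skip=0, n_burn=10, n_save=5, n_skip=1,
    i_outer=0, inner_loop_length=10,
)
```

Such a callback would presumably be passed as run_mcmc(..., outer_callback=counting_outer_callback, callback_state=0). An inner_callback written the same way would run under the jax jit, so it could not inspect concrete values without the jax.debug utilities.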
- bartz.mcmcloop.evaluate_trace(trace, X)
Compute predictions for all iterations of the BART MCMC.
- Parameters:
trace (dict) – A trace of the BART MCMC, as returned by run_mcmc.
X (array (p, n)) – The predictors matrix, with p predictors and n observations.
- Returns:
y (array (n_trace, n)) – The predictions for each iteration of the MCMC.