MCMC loop
Functions that implement the full BART posterior MCMC loop.
The main entry point is run_mcmc.
- class bartz.mcmcloop.BurninTrace(sigma2, grow_prop_count, grow_acc_count, prune_prop_count, prune_acc_count, log_likelihood, log_trans_prior)
MCMC trace with only diagnostic values.
- class bartz.mcmcloop.MainTrace(sigma2, grow_prop_count, grow_acc_count, prune_prop_count, prune_acc_count, log_likelihood, log_trans_prior, leaf_tree, var_tree, split_tree, offset)
MCMC trace with trees and diagnostic values.
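The count fields shared by both traces are typically used to monitor the sampler. As a minimal sketch, assuming each field holds one value per recorded iteration and that the *_prop_count and *_acc_count fields count proposed and accepted grow/prune moves in that iteration, pooled acceptance rates could be computed as below. SimpleTrace and acceptance_rates are hypothetical stand-ins written with the standard library; the real traces hold JAX arrays.

```python
from dataclasses import dataclass


@dataclass
class SimpleTrace:
    """Hypothetical stand-in for the count fields of BurninTrace/MainTrace.

    Assumption: each list holds one value per recorded MCMC iteration.
    """

    grow_prop_count: list[int]
    grow_acc_count: list[int]
    prune_prop_count: list[int]
    prune_acc_count: list[int]


def acceptance_rates(trace: SimpleTrace) -> dict[str, float]:
    """Pool the counts over all iterations and return accepted / proposed."""

    def rate(acc: list[int], prop: list[int]) -> float:
        total_prop = sum(prop)
        # Return NaN instead of dividing by zero if a move was never proposed.
        return sum(acc) / total_prop if total_prop else float("nan")

    return {
        "grow": rate(trace.grow_acc_count, trace.grow_prop_count),
        "prune": rate(trace.prune_acc_count, trace.prune_prop_count),
    }


trace = SimpleTrace(
    grow_prop_count=[10, 12, 8],
    grow_acc_count=[3, 4, 1],
    prune_prop_count=[9, 7, 11],
    prune_acc_count=[2, 2, 3],
)
print(acceptance_rates(trace))
```

Pooling the counts before dividing weights each iteration by the number of proposals it made, rather than averaging per-iteration rates.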
- class bartz.mcmcloop.Callback(*args, **kwargs)
Callback type for run_mcmc.
- __call__(*, bart, burnin, i_total, i_skip, callback_state, n_burn, n_save, n_skip, i_outer, inner_loop_length)
Do an arbitrary action after an iteration of the MCMC.
- Parameters:
  - bart (State) – The MCMC state just after updating it.
  - burnin (Bool[Array, '']) – Whether the last iteration was in the burn-in phase.
  - i_total (Int32[Array, '']) – The index of the last MCMC iteration (0-based).
  - i_skip (Int32[Array, '']) – The number of MCMC updates since the last saved state. The initial state counts as saved, even though it is not copied into the trace.
  - callback_state (PyTree[Any, 'T']) – The callback state: initially the callback_state argument passed to run_mcmc, afterwards the value returned by the previous invocation of the callback.
  - n_burn, n_save, n_skip (Int32[Array, '']) – The corresponding run_mcmc arguments, passed as-is.
  - i_outer (Int32[Array, '']) – The index of the last outer loop iteration (0-based).
  - inner_loop_length (int) – The number of MCMC iterations in the inner loop.
- Returns:
  - bart (State) – A possibly modified MCMC state. To avoid modifying the state, return the bart argument passed to the callback as-is.
  - callback_state (CallbackState) – The new state to be passed on the next callback invocation.
Notes
For convenience, the callback may return None, in which case neither the MCMC state nor the callback state is updated.
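A callback is an ordinary function accepting the keyword arguments listed above. The sketch below shows the two documented ways to leave the states alone: returning the received bart as-is, or returning None. The function apply_callback is a hypothetical plain-Python stand-in that only mimics how run_mcmc is documented to treat the return value; plain dicts and integers replace the real State and JAX arrays.

```python
from typing import Any


def count_iterations(*, bart, callback_state, **_):
    """Count MCMC iterations in callback_state; leave the MCMC state alone."""
    # Returning the received `bart` unchanged means the MCMC state is not modified.
    return bart, callback_state + 1


def do_nothing(**_):
    """Returning None keeps both the MCMC state and the callback state."""
    return None


def apply_callback(callback, bart: Any, callback_state: Any, **iteration_info):
    """Hypothetical stand-in for how run_mcmc is documented to use the result."""
    result = callback(bart=bart, callback_state=callback_state, **iteration_info)
    if result is None:
        return bart, callback_state
    return result


info = dict(burnin=False, i_total=0, i_skip=0, n_burn=0, n_save=3,
            n_skip=1, i_outer=0, inner_loop_length=3)
state, count = {"sigma2": 1.0}, 0  # plain stand-ins for a State and a counter
for i in range(3):
    info["i_total"] = i
    state, count = apply_callback(count_iterations, state, count, **info)
state, count = apply_callback(do_nothing, state, count, **info)
print(count)  # 3
```

The **_ catch-all absorbs the arguments a callback does not use. Because count_iterations only does arithmetic on callback_state, the same body would also work on a traced JAX integer under jit, whereas Python-level branching on the traced arguments would not.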
- bartz.mcmcloop.run_mcmc(key, bart, n_save, *, n_burn=0, n_skip=1, inner_loop_length=None, callback=None, callback_state=None, burnin_extractor=BurninTrace.from_state, main_extractor=MainTrace.from_state)
Run the MCMC for the BART posterior.
- Parameters:
  - key (Key[Array, '']) – A key for random number generation.
  - bart (State) – The initial MCMC state, as created and updated by the functions in bartz.mcmcstep. The MCMC loop uses buffer donation to avoid copies, so this variable is invalidated after running run_mcmc. Make a copy beforehand if you need to use it again.
  - n_save (int) – The number of iterations to save in the main trace.
  - n_burn (int, default: 0) – The number of initial iterations which are not saved in the main trace.
  - n_skip (int, default: 1) – The number of iterations to skip between each saved iteration, plus 1. The effective burn-in is n_burn + n_skip - 1.
  - inner_loop_length (int | None, default: None) – The MCMC loop is split into an outer and an inner loop. The outer loop is in Python, while the inner loop is in JAX. inner_loop_length is the number of iterations of the inner loop to run for each iteration of the outer loop. If not specified, the outer loop iterates just once, with all iterations done in a single run of the inner loop. This stride is unrelated to the stride n_skip used for saving the trace.
  - callback (Callback | None, default: None) – An arbitrary function run during the loop after updating the state. For the signature, see Callback. The callback is called under jax.jit, so the argument values are not available when the Python code is executed; use the utilities in jax.debug to access the values at actual runtime. The callback may return new values for the MCMC state and the callback state.
  - callback_state (PyTree[Any, 'T'], default: None) – The initial custom state for the callback.
  - burnin_extractor (Callable[[State], PyTree], default: BurninTrace.from_state)
  - main_extractor (Callable[[State], PyTree], default: MainTrace.from_state) – Functions that extract the variables to be saved in the burn-in trace and in the main trace, respectively, given the MCMC state as argument. Each must return a pytree and must be vmappable.
- Returns:
  - bart (State) – The final MCMC state.
  - burnin_trace (PyTree[Shaped[Array, 'n_burn *']]) – The trace of the burn-in phase. For the default layout, see BurninTrace.
  - main_trace (PyTree[Shaped[Array, 'n_save *']]) – The trace of the main phase. For the default layout, see MainTrace.
Notes
The number of MCMC updates is n_burn + n_skip * n_save. The traces do not include the initial state, but do include the final state.
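The iteration counts in these Notes can be checked with a little arithmetic. The plain-Python sketch below (mcmc_schedule is a hypothetical helper, not bartz code) lists which 0-based updates end up in each trace, assuming every burn-in update is recorded in the burn-in trace and the last update of each block of n_skip updates goes to the main trace. The outer-loop count assumes the updates are split into chunks of inner_loop_length with a possibly shorter final chunk; that detail is an assumption.

```python
import math
from typing import Optional


def mcmc_schedule(n_save: int, n_burn: int = 0, n_skip: int = 1,
                  inner_loop_length: Optional[int] = None) -> dict:
    """Plain-Python bookkeeping for the run_mcmc iteration counts (illustrative)."""
    n_total = n_burn + n_skip * n_save  # total number of MCMC updates
    # Every burn-in update is recorded in the burn-in trace.
    burnin_iters = list(range(n_burn))
    # After burn-in, keep the last update of each block of n_skip updates.
    main_iters = [n_burn + n_skip * (k + 1) - 1 for k in range(n_save)]
    # Assumption: fixed-length inner loops with a possibly shorter last chunk;
    # None means a single inner loop covering all updates.
    if inner_loop_length is None:
        n_outer = 1
    else:
        n_outer = math.ceil(n_total / inner_loop_length)
    return {
        "n_total": n_total,
        "burnin_iters": burnin_iters,
        "main_iters": main_iters,
        "effective_burnin": n_burn + n_skip - 1,
        "n_outer": n_outer,
    }


s = mcmc_schedule(n_save=3, n_burn=2, n_skip=2, inner_loop_length=3)
print(s)
```

With these settings there are 8 updates: 0 and 1 go to the burn-in trace, 2, 4 and 6 are discarded, and 3, 5 and 7 are saved. The first 3 updates (the effective burn-in) precede the first saved one, and the last saved update is the final one.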
- bartz.mcmcloop.pytree_at_set(dest, index, val)
Map dest.at[index].set(val) over pytrees.
- Return type: PyTree
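In JAX, x.at[index].set(v) returns an updated copy of the array x instead of modifying it in place, so pytree_at_set is presumably equivalent to jax.tree.map(lambda d, v: d.at[index].set(v), dest, val), which is a natural way to write one iteration's values into a preallocated trace. The following is a standard-library analogue, not the bartz implementation: dicts stand in for pytree nodes, Python lists play the role of arrays (the leaves), and tree_at_set is a hypothetical name.

```python
from typing import Any


def tree_at_set(dest: Any, index: int, val: Any) -> Any:
    """Stdlib analogue of mapping `d.at[index].set(v)` over two matching trees.

    Dicts are the tree nodes; lists play the role of arrays (the leaves).
    """
    if isinstance(dest, dict):
        # Both trees must have the same structure, as with jax.tree.map.
        if not isinstance(val, dict) or dest.keys() != val.keys():
            raise ValueError("dest and val have different structure")
        return {k: tree_at_set(dest[k], index, val[k]) for k in dest}
    updated = list(dest)  # copy: the input "array" is left untouched
    updated[index] = val
    return updated


trace = {"sigma2": [0.0, 0.0, 0.0], "counts": {"grow": [0, 0, 0]}}
step = {"sigma2": 1.5, "counts": {"grow": 4}}
new_trace = tree_at_set(trace, 1, step)
print(new_trace)
print(trace)  # unchanged
```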
- bartz.mcmcloop.make_print_callback(dot_every=1, report_every=100)
Prepare a logging callback for run_mcmc.
The callback prints a dot every dot_every MCMC iterations and a one-line report every report_every MCMC iterations.
- Parameters:
  - dot_every (int | Integer[Array, ''] | None, default: 1) – A dot is printed every dot_every MCMC iterations; None to disable.
  - report_every (int | Integer[Array, ''] | None, default: 100) – A one-line report is printed every report_every MCMC iterations; None to disable.
- Returns:
  dict[str, Any] – A dictionary with the arguments to pass to run_mcmc as keyword arguments to set up the callback.
Examples
>>> run_mcmc(..., **make_print_callback())
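Returning keyword arguments instead of a bare function lets a builder hand over a callback together with its initial callback_state in one ** unpacking. The sketch below illustrates that builder pattern with a hypothetical make_counter_callback and a stand-in fake_run_mcmc that only calls the callback in a Python loop; neither is part of bartz. The rule deciding when a dot or a report is emitted (the 1-based iteration count divisible by dot_every or report_every) is an assumption, and messages are collected in a list instead of printed. The Python if statements work only because the driver passes plain integers; under jax.jit, i_total is a traced array, so a real callback needs JAX-side control flow such as jax.lax.cond or the jax.debug utilities.

```python
from typing import Any, Optional


def make_counter_callback(dot_every: Optional[int] = 1,
                          report_every: Optional[int] = 100) -> dict[str, Any]:
    """Hypothetical builder mirroring the make_print_callback pattern."""

    def callback(*, bart, i_total, callback_state, **_):
        log = callback_state  # here the callback state is a list of messages
        step = i_total + 1  # assumption: count completed iterations from 1
        if dot_every is not None and step % dot_every == 0:
            log = log + ["."]
        if report_every is not None and step % report_every == 0:
            log = log + [f"iteration {step}"]
        return bart, log

    # Keys named after the run_mcmc keyword arguments that set up a callback.
    return {"callback": callback, "callback_state": []}


def fake_run_mcmc(bart, n_iter, *, callback=None, callback_state=None):
    """Stand-in driver: calls the callback once per iteration, nothing else."""
    for i in range(n_iter):
        if callback is not None:
            out = callback(bart=bart, i_total=i, callback_state=callback_state)
            if out is not None:
                bart, callback_state = out
    return bart, callback_state


_, log = fake_run_mcmc({"sigma2": 1.0}, 6,
                       **make_counter_callback(dot_every=2, report_every=3))
print(log)
```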
- class bartz.mcmcloop.TreesTrace(leaf_tree, var_tree, split_tree)
Implementation of bartz.grove.TreeHeaps for an MCMC trace.
- classmethod from_dataclass(obj)
Create a TreesTrace from any bartz.grove.TreeHeaps.
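A constructor like from_dataclass can be written generically: read the fields the target class declares from any object that has them, ignoring anything else. The sketch below shows that idea with dataclasses.fields; TreesTraceSketch and SomeHeaps are hypothetical stand-ins, the real bartz.grove.TreeHeaps and the JAX array fields are not reproduced, and the actual bartz implementation may differ.

```python
from dataclasses import dataclass, fields
from typing import Any


@dataclass
class TreesTraceSketch:
    """Hypothetical stand-in with the three TreesTrace fields."""

    leaf_tree: Any
    var_tree: Any
    split_tree: Any

    @classmethod
    def from_dataclass(cls, obj: Any) -> "TreesTraceSketch":
        # Copy only the fields this class declares, ignoring any extra ones.
        return cls(**{f.name: getattr(obj, f.name) for f in fields(cls)})


@dataclass
class SomeHeaps:
    """Any object exposing the tree-heap fields, plus unrelated extra data."""

    leaf_tree: Any
    var_tree: Any
    split_tree: Any
    sigma2: float


src = SomeHeaps(leaf_tree=[0.0, 0.1], var_tree=[0, 2], split_tree=[0, 5], sigma2=1.0)
trees = TreesTraceSketch.from_dataclass(src)
print(trees)
```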
- bartz.mcmcloop.evaluate_trace(trace, X)
Compute predictions for all iterations of the BART MCMC.
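Conceptually, predicting from a trace means evaluating the sum of trees stored at each saved iteration. The plain-Python sketch below assumes a heap layout for each tree (root at index 1, children of node i at 2i and 2i + 1, split_tree[i] == 0 marking a leaf, x[var_tree[i]] < split_tree[i] sending a point left, all three heaps of equal length) and adds the per-iteration offset, as suggested by the offset field of MainTrace. These conventions, the list-of-points layout of X, and the names eval_tree and evaluate_trace_sketch are assumptions for illustration; the bartz function works on JAX arrays.

```python
from typing import Sequence


def eval_tree(x: Sequence[int], leaf_tree, var_tree, split_tree) -> float:
    """Walk one heap-stored tree from the root to a leaf (assumed layout)."""
    node = 1  # assumption: root at index 1, index 0 unused
    while split_tree[node] != 0:  # assumption: split value 0 marks a leaf
        go_right = x[var_tree[node]] >= split_tree[node]
        node = 2 * node + int(go_right)
    return leaf_tree[node]


def evaluate_trace_sketch(trace: list[dict], X: list[Sequence[int]]) -> list[list[float]]:
    """Return predictions[iteration][point] = offset + sum of tree outputs."""
    predictions = []
    for it in trace:
        row = []
        for x in X:
            total = it["offset"]
            for leaf, var, split in zip(it["leaf_tree"], it["var_tree"], it["split_tree"]):
                total += eval_tree(x, leaf, var, split)
            row.append(total)
        predictions.append(row)
    return predictions


trace = [
    {  # saved iteration 0: two trees and an offset
        "offset": 10.0,
        "leaf_tree": [[0.0, 0.0, -1.0, 1.0], [0.0, 0.5, 0.0, 0.0]],
        "var_tree": [[0, 0, 0, 0], [0, 0, 0, 0]],
        "split_tree": [[0, 3, 0, 0], [0, 0, 0, 0]],
    },
    {  # saved iteration 1
        "offset": 0.0,
        "leaf_tree": [[0.0, 0.0, 2.0, -2.0], [0.0, 0.25, 0.0, 0.0]],
        "var_tree": [[0, 0, 0, 0], [0, 0, 0, 0]],
        "split_tree": [[0, 3, 0, 0], [0, 0, 0, 0]],
    },
]
X = [[1], [5]]  # two points with one (binned) predictor each
print(evaluate_trace_sketch(trace, X))
```

The first tree splits on predictor 0 at 3 and the second is a single leaf, so each prediction is the offset plus one leaf from each tree.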