MCMC setup and step
Functions that implement the BART posterior MCMC initialization and update step.
Functions that do MCMC steps take a BART state as input and output a new state. The inputs are not modified.
The main entry points are:
- class bartz.mcmcstep.Forest(leaf_tree, var_tree, split_tree, affluence_tree, max_split, blocked_vars, p_nonterminal, p_propose_grow, leaf_indices, min_points_per_decision_node, min_points_per_leaf, resid_batch_size, count_batch_size, log_trans_prior, log_likelihood, grow_prop_count, prune_prop_count, grow_acc_count, prune_acc_count, sigma_mu2)
Represents the MCMC state of a sum of trees.
- Parameters:
leaf_tree (Float32[Array, 'num_trees 2**d']) – The leaf values.
var_tree (UInt[Array, 'num_trees 2**(d-1)']) – The decision axes.
split_tree (UInt[Array, 'num_trees 2**(d-1)']) – The decision boundaries.
affluence_tree (Bool[Array, 'num_trees 2**(d-1)']) – Marks leaves that can be grown.
max_split (UInt[Array, 'p']) – The maximum split index for each predictor.
blocked_vars (Int32[Array, 'k'] | None) – The indices of the variables that have no available cutpoints.
p_nonterminal (Float32[Array, '2**d']) – The prior probability of each node being nonterminal, conditional on its ancestors. Includes the nodes at maximum depth, which should be set to 0.
p_propose_grow (Float32[Array, '2**(d-1)']) – The unnormalized probability of picking a leaf for a grow proposal.
leaf_indices (UInt[Array, 'num_trees n']) – The index of the leaf each datapoint falls into, for each tree.
min_points_per_decision_node (Int32[Array, ''] | None) – The minimum number of data points in a decision node.
min_points_per_leaf (Int32[Array, ''] | None) – The minimum number of data points in a leaf node.
resid_batch_size (int | None), count_batch_size (int | None) – The data batch sizes for computing the sufficient statistics. If None, they are computed with no batching.
log_trans_prior (Float32[Array, 'num_trees'] | None) – The log transition and prior Metropolis-Hastings ratio for the proposed move on each tree.
log_likelihood (Float32[Array, 'num_trees'] | None) – The log likelihood ratio.
grow_prop_count (Int32[Array, '']), prune_prop_count (Int32[Array, '']) – The number of grow/prune proposals made during one full MCMC cycle.
grow_acc_count (Int32[Array, '']), prune_acc_count (Int32[Array, '']) – The number of grow/prune moves accepted during one full MCMC cycle.
sigma_mu2 (Float32[Array, '']) – The prior variance of a leaf, conditional on the tree structure.
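The array shapes above suggest a heap-style layout for each tree. The following minimal pure-Python sketch evaluates one tree under assumptions this page does not state: the root is at index 1, the children of node i are at 2*i and 2*i + 1, split_tree[i] == 0 marks a leaf, and a point goes left iff its predictor code is below the cutpoint (the rule given in the Notes of init). find_leaf and evaluate_tree are hypothetical helpers, not part of bartz.

```python
# Hypothetical helpers (not part of bartz). ASSUMED layout, not stated on
# this page: root at index 1, children of node i at 2*i and 2*i + 1,
# split_tree[i] == 0 marks a leaf, and a point goes left iff x[var] < split.


def find_leaf(var_tree, split_tree, x):
    """Return the heap index of the leaf reached by predictor codes x."""
    num_rule_slots = len(split_tree)  # 2**(d-1): only these nodes can hold a rule
    node = 1  # start at the root
    while node < num_rule_slots and split_tree[node] != 0:
        var = var_tree[node]
        split = split_tree[node]
        node = 2 * node if x[var] < split else 2 * node + 1
    return node


def evaluate_tree(leaf_tree, var_tree, split_tree, x):
    """Value of one tree at x: the value stored at the reached leaf."""
    return leaf_tree[find_leaf(var_tree, split_tree, x)]


# Depth-3 example: the root splits variable 0 at 2, node 2 splits variable 1 at 1.
var_tree = [0, 0, 1, 0]
split_tree = [0, 2, 1, 0]
leaf_tree = [0.0, 0.0, 0.0, 3.0, 4.0, 5.0, 0.0, 0.0]
print(evaluate_tree(leaf_tree, var_tree, split_tree, [1, 0]))  # 4.0
```

Under the same assumptions, the model prediction for a datapoint would be offset plus the sum of evaluate_tree over all trees in the forest.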
- class bartz.mcmcstep.State(X, y, z, offset, resid, sigma2, prec_scale, sigma2_alpha, sigma2_beta, forest)
Represents the MCMC state of BART.
- Parameters:
X (UInt[Array, 'p n']) – The predictors.
y (Float32[Array, 'n'] | Bool[Array, 'n']) – The response. If the data type is bool, the model is binary regression.
resid (Float32[Array, 'n']) – The residuals (y or z minus the sum of trees).
z (None | Float32[Array, 'n']) – The latent variable for binary regression. None in continuous regression.
offset (Float32[Array, '']) – Constant shift added to the sum of trees.
sigma2 (Float32[Array, ''] | None) – The error variance. None in binary regression.
prec_scale (Float32[Array, 'n'] | None) – The scale on the error precision, i.e., 1 / error_scale ** 2. None in binary regression.
sigma2_alpha (Float32[Array, ''] | None), sigma2_beta (Float32[Array, ''] | None) – The shape and scale parameters of the inverse gamma prior on the noise variance. None in binary regression.
forest (Forest) – The sum of trees model.
- bartz.mcmcstep.init(*, X, y, offset=0.0, max_split, num_trees, p_nonterminal, sigma_mu2, sigma2_alpha=None, sigma2_beta=None, error_scale=None, min_points_per_decision_node=None, resid_batch_size='auto', count_batch_size='auto', save_ratios=False, filter_splitless_vars=True, min_points_per_leaf=None)
Make a BART posterior sampling MCMC initial state.
- Parameters:
X (UInt[Any, 'p n']) – The predictors. Note this is transposed compared to the usual convention.
y (Float32[Any, 'n'] | Bool[Any, 'n']) – The response. If the data type is bool, the regression model is binary regression with probit.
offset (float | Float32[Any, ''], default: 0.0) – Constant shift added to the sum of trees. 0 if not specified.
max_split (UInt[Any, 'p']) – The maximum split index for each variable. All split ranges start at 1.
num_trees (int) – The number of trees in the forest.
p_nonterminal (Float32[Any, 'd-1']) – The probability of a nonterminal node at each depth. The maximum depth of trees is fixed by the length of this array.
sigma_mu2 (float | Float32[Any, '']) – The prior variance of a leaf, conditional on the tree structure. The prior variance of the sum of trees is num_trees * sigma_mu2. The prior mean of leaves is always zero.
sigma2_alpha (float | Float32[Any, ''] | None, default: None), sigma2_beta (float | Float32[Any, ''] | None, default: None) – The shape and scale parameters of the inverse gamma prior on the error variance. Leave unspecified for binary regression.
error_scale (Float32[Any, 'n'] | None, default: None) – Each error is scaled by the corresponding factor in error_scale, so the error variance for y[i] is sigma2 * error_scale[i] ** 2. Not supported for binary regression. If not specified, it defaults to 1 for all points, and the corresponding calculations may be skipped.
min_points_per_decision_node (int | Integer[Any, ''] | None, default: None) – The minimum number of data points in a decision node. 0 if not specified.
resid_batch_size (int | None | Literal['auto'], default: 'auto'), count_batch_size (int | None | Literal['auto'], default: 'auto') – The batch sizes, along datapoints, for summing the residuals and counting the number of datapoints in each leaf. None for no batching. If 'auto', pick a value based on the device of y, or the default device.
save_ratios (bool, default: False) – Whether to save the Metropolis-Hastings ratios.
filter_splitless_vars (bool, default: True) – Whether to check max_split for variables without available cutpoints. If any are found, they are put into a list of variables to exclude from the MCMC. If False, no check is performed, but the results may be wrong if any variable is blocked. The function is jax-traceable only if this is set to False.
min_points_per_leaf (int | Integer[Any, ''] | None, default: None) – The minimum number of datapoints in a leaf node. 0 if not specified. Unlike min_points_per_decision_node, this constraint is not taken into account in the Metropolis-Hastings ratio because it would be expensive to compute. Grow moves that would violate this constraint are vetoed. This parameter is independent of min_points_per_decision_node and there is no check that they are coherent. It makes sense to set min_points_per_decision_node >= 2 * min_points_per_leaf.
- Returns:
State – An initialized BART MCMC state.
- Raises:
ValueError – If y is boolean and arguments unused in binary regression are set.
Notes
In decision nodes, the values in X[i, :] are compared to a cutpoint out of the range [1, 2, ..., max_split[i]]. A point belongs to the left child iff X[i, j] < cutpoint. Thus it makes sense for X[i, :] to be integers in the range [0, 1, ..., max_split[i]].
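One way to produce integer predictors that fit this convention is to bin each continuous variable with a sorted list of thresholds. The sketch below uses a hypothetical helper, encode_predictor (not a bartz function), built on the standard-library bisect module. With len(thresholds) == max_split[i], the codes fall in [0, ..., max_split[i]], and cutpoint c separates the values below thresholds[c - 1] from the rest.

```python
import bisect

# Hypothetical preprocessing helper, not a bartz function.


def encode_predictor(values, thresholds):
    """Integer code of each value: the number of thresholds <= value.

    With sorted thresholds of length max_split, codes lie in 0..max_split.
    """
    return [bisect.bisect_right(thresholds, v) for v in values]


def goes_left(code, cutpoint):
    """Decision rule from the Notes: left child iff X[i, j] < cutpoint."""
    return code < cutpoint


thresholds = [1.0, 2.5, 4.0]  # max_split[i] == 3, cutpoints 1, 2, 3
codes = encode_predictor([0.5, 1.0, 3.0, 10.0], thresholds)
print(codes)  # [0, 1, 2, 3]
```

Every cutpoint in [1, ..., max_split[i]] then splits the observed codes into two non-empty groups, as long as every code actually occurs in the data.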
- bartz.mcmcstep.step_trees(key, bart)
Forest sampling step of BART MCMC.
- Parameters:
key (Key[Array, '']) – A jax random key.
bart (State) – A BART mcmc state.
- Returns:
State – The new BART mcmc state.
Notes
This function zeroes the proposal counters.
- class bartz.mcmcstep.Moves(allowed, grow, num_growable, node, left, right, partial_ratio, log_trans_prior_ratio, grow_var, grow_split, var_tree, affluence_tree, logu, acc, to_prune)
Moves proposed to modify each tree.
- Parameters:
allowed (Bool[Array, 'num_trees']) – Whether there is a possible move. If False, the other values may not make sense. The only case in which a move is marked as allowed but is then vetoed is if it does not satisfy min_points_per_leaf, which for efficiency is implemented post-hoc without changing the rest of the MCMC logic.
grow (Bool[Array, 'num_trees']) – Whether the move is a grow move or a prune move.
num_growable (UInt[Array, 'num_trees']) – The number of growable leaves in the original tree.
node (UInt[Array, 'num_trees']) – The index of the leaf to grow or node to prune.
left (UInt[Array, 'num_trees']), right (UInt[Array, 'num_trees']) – The indices of the children of node.
partial_ratio (Float32[Array, 'num_trees'] | None) – A factor of the Metropolis-Hastings ratio of the move. It lacks the likelihood ratio, the probability of proposing the prune move, and the probability that the children of the modified node are terminal. If the move is PRUNE, the ratio is inverted. None once log_trans_prior_ratio has been computed.
log_trans_prior_ratio (None | Float32[Array, 'num_trees']) – The logarithm of the product of the transition and prior terms of the Metropolis-Hastings ratio for the acceptance of the proposed move. None if not yet computed. If PRUNE, the log-ratio is negated.
grow_var (UInt[Array, 'num_trees']) – The decision axes of the new rules.
grow_split (UInt[Array, 'num_trees']) – The decision boundaries of the new rules.
var_tree (UInt[Array, 'num_trees 2**(d-1)']) – The updated decision axes of the trees, valid for either type of move.
affluence_tree (Bool[Array, 'num_trees 2**(d-1)']) – A partially updated affluence_tree, marking non-leaf nodes that would become leaves if the move was accepted. As output by propose_moves, this mark only takes into account whether there would be available decision rules to grow the leaf; whether there are enough datapoints in the node is marked later, in accept_moves_parallel_stage.
logu (Float32[Array, 'num_trees']) – The logarithm of a uniform (0, 1] random variable to be used to accept the move. It lies in (-inf, 0].
acc (None | Bool[Array, 'num_trees']) – Whether the move was accepted. None if not yet computed.
to_prune (None | Bool[Array, 'num_trees']) – Whether the final operation to apply the move is pruning. This indicates an accepted prune move or a rejected grow move. None if not yet computed.
- bartz.mcmcstep.propose_moves(key, forest)
Propose moves for all the trees.
There are two types of moves: GROW (convert a leaf to a decision node and add two leaves beneath it) and PRUNE (convert the parent of two leaves to a leaf, deleting its children).
- class bartz.mcmcstep.GrowMoves(allowed, num_growable, node, var, split, partial_ratio, var_tree, affluence_tree)
Represent a proposed grow move for each tree.
- Parameters:
allowed (Bool[Array, 'num_trees']) – Whether the move is allowed for proposal.
num_growable (UInt[Array, 'num_trees']) – The number of leaves that can be proposed for grow.
node (UInt[Array, 'num_trees']) – The index of the leaf to grow. 2 ** d if there are no growable leaves.
var (UInt[Array, 'num_trees']), split (UInt[Array, 'num_trees']) – The decision axis and boundary of the new rule.
partial_ratio (Float32[Array, 'num_trees']) – A factor of the Metropolis-Hastings ratio of the move. It lacks the likelihood ratio and the probability of proposing the prune move.
var_tree (UInt[Array, 'num_trees 2**(d-1)']) – The updated decision axes of the tree.
affluence_tree (Bool[Array, 'num_trees 2**(d-1)']) – A partially updated affluence_tree that marks each new leaf that would be produced as True if it would have available decision rules.
- bartz.mcmcstep.propose_grow_moves(key, var_tree, split_tree, affluence_tree, max_split, blocked_vars, p_nonterminal, p_propose_grow)
Propose a GROW move for each tree.
A GROW move picks a leaf node and converts it to a non-terminal node with two leaf children.
- Parameters:
key (Key[Array, 'num_trees']) – A jax random key.
var_tree (UInt[Array, 'num_trees 2**(d-1)']) – The splitting axes of the tree.
split_tree (UInt[Array, 'num_trees 2**(d-1)']) – The splitting points of the tree.
affluence_tree (Bool[Array, 'num_trees 2**(d-1)']) – Whether each leaf has enough points to be grown.
max_split (UInt[Array, 'p']) – The maximum split index for each variable.
blocked_vars (Int32[Array, 'k'] | None) – The indices of the variables that have no available cutpoints.
p_nonterminal (Float32[Array, '2**d']) – The a priori probability of a node to be nonterminal conditional on the ancestors, including at the maximum depth where it should be zero.
p_propose_grow (Float32[Array, '2**(d-1)']) – The unnormalized probability of choosing a leaf to grow.
- Returns:
GrowMoves – An object representing the proposed move.
Notes
A move is not proposed if no leaf can be grown, i.e., every leaf is already at maximum depth, has fewer datapoints than the requested threshold min_points_per_decision_node, or has no available decision rules given its ancestors. This is marked by setting allowed to False and num_growable to 0.
- bartz.mcmcstep.choose_leaf(key, split_tree, affluence_tree, p_propose_grow)
Choose a leaf node to grow in a tree.
- Parameters:
key (Key[Array, '']) – A jax random key.
split_tree (UInt[Array, '2**(d-1)']) – The splitting points of the tree.
affluence_tree (Bool[Array, '2**(d-1)']) – Whether a leaf has enough points that it could be split into two leaves satisfying the min_points_per_leaf requirement.
p_propose_grow (Float32[Array, '2**(d-1)']) – The unnormalized probability of choosing a leaf to grow.
- Returns:
leaf_to_grow (Int32[Array, '']) – The index of the leaf to grow. If num_growable == 0, return 2 ** d.
num_growable (Int32[Array, '']) – The number of leaf nodes that can be grown, i.e., that can be made nonterminal and have at least twice min_points_per_leaf points.
prob_choose (Float32[Array, '']) – The (normalized) probability that this function had to choose that specific leaf, given the arguments.
num_prunable (Int32[Array, '']) – The number of leaf parents that could be pruned, after converting the selected leaf to a non-terminal node.
- bartz.mcmcstep.growable_leaves(split_tree, affluence_tree)
Return a mask indicating the leaf nodes that can be proposed for growth.
The condition is that a leaf is not at the bottom level, has available decision rules given its ancestors, and has at least min_points_per_decision_node points.
- Parameters:
split_tree (UInt[Array, '2**(d-1)']) – The splitting points of the tree.
affluence_tree (Bool[Array, '2**(d-1)']) – Marks leaves that can be grown.
- Returns:
Bool[Array, '2**(d-1)'] – The mask indicating the leaf nodes that can be proposed to grow.
Notes
This function needs split_tree and not just affluence_tree because affluence_tree can be "dirty", i.e., mark unused nodes as True.
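The Notes above explain why split_tree is needed. Here is a pure-Python sketch of how it can filter out dirty affluence_tree entries, assuming (this page does not say so) that the root is at index 1 with index 0 unused, that the parent of node i is i // 2, and that split_tree[i] == 0 means node i holds no rule. growable_leaves_sketch is a hypothetical name, not the bartz implementation. Bottom-level leaves are excluded automatically because they lie outside the 2**(d-1) slots.

```python
# Hypothetical sketch, not the bartz implementation. ASSUMED layout: root at
# index 1 (index 0 unused), parent of node i at i // 2, and split_tree[i] == 0
# meaning node i holds no decision rule.


def growable_leaves_sketch(split_tree, affluence_tree):
    """Mask over indices 0 .. 2**(d-1) - 1 of leaves that may be grown."""
    size = len(split_tree)  # bottom-level leaves are outside this range
    mask = [False] * size
    for i in range(1, size):
        has_no_rule = split_tree[i] == 0
        # A node without a rule is a real leaf only if it is the root or its
        # parent holds a rule; otherwise it is an unused slot, whose
        # affluence_tree entry may be "dirty" and must be ignored.
        exists = i == 1 or split_tree[i // 2] != 0
        mask[i] = has_no_rule and exists and bool(affluence_tree[i])
    return mask


# Root splits; nodes 2 and 3 are leaves, but only node 2 is affluent.
print(growable_leaves_sketch([0, 2, 0, 0], [True, True, True, False]))
# [False, False, True, False]
```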
- bartz.mcmcstep.categorical(key, distr)
Return a random integer from an arbitrary distribution.
- Parameters:
key (Key[Array, '']) – A jax random key.
distr (Float32[Array, 'n']) – An unnormalized probability distribution.
- Returns:
u (Int32[Array, '']) – A random integer in the range [0, n). If all probabilities are zero, return n.
norm (Float32[Array, '']) – The sum of distr.
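To illustrate the documented contract (an index drawn proportionally to unnormalized weights, n when all weights are zero, plus the normalizing sum), here is a pure-Python inverse-CDF sketch. categorical_sketch is a hypothetical name, and random.Random stands in for a jax key. This is not how bartz implements the function.

```python
import random


def categorical_sketch(rng, distr):
    """Draw index i with probability distr[i] / sum(distr).

    Returns (u, norm). If all weights are zero, u == len(distr), matching
    the documented fallback. rng is a random.Random, standing in for a jax key.
    """
    n = len(distr)
    norm = sum(distr)
    if norm <= 0:
        return n, norm
    target = rng.random() * norm  # uniform in [0, norm)
    cumulative = 0.0
    for i, w in enumerate(distr):
        cumulative += w
        if target < cumulative:  # zero-weight entries can never trigger this
            return i, norm
    # Round-off guard: fall back to the last index with positive weight.
    return max(i for i, w in enumerate(distr) if w > 0), norm


rng = random.Random(0)
print(categorical_sketch(rng, [0.0, 0.0]))  # (2, 0.0)
```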
- bartz.mcmcstep.choose_variable(key, var_tree, split_tree, max_split, leaf_index, blocked_vars)
Choose a variable to split on for a new non-terminal node.
- Parameters:
key (Key[Array, '']) – A jax random key.
var_tree (UInt[Array, '2**(d-1)']) – The variable indices of the tree.
split_tree (UInt[Array, '2**(d-1)']) – The splitting points of the tree.
max_split (UInt[Array, 'p']) – The maximum split index for each variable.
leaf_index (Int32[Array, '']) – The index of the leaf to grow.
blocked_vars (Int32[Array, 'k'] | None) – The indices of the variables that have no available cutpoints. If None, all variables are assumed unblocked.
- Returns:
var (Int32[Array, '']) – The index of the variable to split on.
num_available_var (Int32[Array, '']) – The number of variables with available decision rules that var was chosen from.
- bartz.mcmcstep.fully_used_variables(var_tree, split_tree, max_split, leaf_index)
Find variables in the ancestors of a node that have an empty split range.
- Parameters:
var_tree (UInt[Array, '2**(d-1)']) – The variable indices of the tree.
split_tree (UInt[Array, '2**(d-1)']) – The splitting points of the tree.
max_split (UInt[Array, 'p']) – The maximum split index for each variable.
leaf_index (Int32[Array, '']) – The index of the node, assumed to be valid for var_tree.
- Returns:
UInt[Array, 'd-2'] – The indices of the variables that have an empty split range.
Notes
The number of fully used variables is not known in advance. Unused slots in the array are filled with p. The fill values are not guaranteed to be placed in any particular order, and variables may appear more than once.
- bartz.mcmcstep.ancestor_variables(var_tree, max_split, node_index)
Return the list of variables in the ancestors of a node.
- Parameters:
var_tree (UInt[Array, '2**(d-1)']) – The variable indices of the tree.
max_split (UInt[Array, 'p']) – The maximum split index for each variable. Used only to get p.
node_index (Int32[Array, '']) – The index of the node, assumed to be valid for var_tree.
- Returns:
UInt[Array, 'd-2'] – The variable indices of the ancestors of the node.
Notes
The ancestors are the nodes going from the root to the parent of the node. The number of ancestors is not known at tracing time; unused spots in the output array are filled with p.
- bartz.mcmcstep.split_range(var_tree, split_tree, max_split, node_index, ref_var)
Return the range of allowed splits for a variable at a given node.
- Parameters:
var_tree (UInt[Array, '2**(d-1)']) – The variable indices of the tree.
split_tree (UInt[Array, '2**(d-1)']) – The splitting points of the tree.
max_split (UInt[Array, 'p']) – The maximum split index for each variable.
node_index (Int32[Array, '']) – The index of the node, assumed to be valid for var_tree.
ref_var (Int32[Array, '']) – The variable for which to measure the split range.
- Returns:
tuple[Int32[Array, ''], Int32[Array, '']] – The range of allowed splits as [l, r). If ref_var is out of bounds, l = r = 1.
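The range [l, r) follows from the ancestors' rules on the same variable. Each ancestor that splits ref_var at s caps the range at s if the node lies in its left subtree, or raises the lower bound to s + 1 if it lies in the right subtree. Below is a pure-Python sketch of that reasoning, assuming a heap layout that this page does not confirm (root at index 1, parent of i at i // 2, left children at even indices). split_range_sketch is a hypothetical name, not the bartz implementation.

```python
# Hypothetical sketch, not the bartz implementation. ASSUMED layout: root at
# index 1, parent of node i at i // 2, left children at even indices; a point
# goes left iff its code is < split, so codes in a subtree are bounded by the
# splits of its ancestors on the same variable.


def split_range_sketch(var_tree, split_tree, max_split, node_index, ref_var):
    """Return (l, r): cutpoints in [l, r) still split the node's points."""
    if not 0 <= ref_var < len(max_split):
        return 1, 1  # documented behavior for an out-of-bounds variable
    l, r = 1, max_split[ref_var] + 1  # initially every cutpoint 1..max_split
    node = node_index
    while node > 1:
        parent = node // 2
        if var_tree[parent] == ref_var:
            split = split_tree[parent]
            if node % 2 == 0:
                r = min(r, split)  # left subtree: codes < split
            else:
                l = max(l, split + 1)  # right subtree: codes >= split
        node = parent
    return l, r


# Root splits variable 0 at 3; its left child (node 2) splits variable 0 at 1.
var_tree = [0] * 8
split_tree = [0, 3, 1, 0, 0, 0, 0, 0]
print(split_range_sketch(var_tree, split_tree, [5, 3], 5, 0))  # (2, 3)
```

Node 5 receives codes 1 and 2 of variable 0, so cutpoint 2 is the only one that still separates them.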
- bartz.mcmcstep.randint_exclude(key, sup, exclude)
Return a random integer in a range, excluding some values.
- Parameters:
key (Key[Array, '']) – A jax random key.
sup (int | Integer[Array, '']) – The exclusive upper bound of the range.
exclude (Integer[Array, 'n']) – The values to exclude from the range. Values greater than or equal to sup are ignored. Values can appear more than once.
- Returns:
u (Int32[Array, '']) – A random integer u in the range [0, sup) such that u not in exclude.
num_allowed (Int32[Array, '']) – The number of integers in the range that were not excluded.
Notes
If all values in the range are excluded, return sup.
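One standard way to meet this contract is to draw a rank among the allowed values and then shift it past each excluded value. Below is a pure-Python sketch of that technique; randint_exclude_sketch is a hypothetical name, random.Random stands in for a jax key, and bartz may implement it differently.

```python
import random

# Hypothetical sketch, not the bartz implementation; rng is a random.Random
# standing in for a jax key.


def randint_exclude_sketch(rng, sup, exclude):
    """Uniform u in [0, sup) with u not in exclude; returns (u, num_allowed).

    Values of exclude >= sup are ignored and repeats are harmless. If every
    value is excluded, returns (sup, 0) as documented.
    """
    excluded = sorted({e for e in exclude if 0 <= e < sup})
    num_allowed = sup - len(excluded)
    if num_allowed == 0:
        return sup, 0
    u = rng.randrange(num_allowed)  # rank of u among the allowed values
    # Shift past each excluded value at or below the running candidate; this
    # turns the rank into the actual allowed value.
    for e in excluded:
        if e <= u:
            u += 1
        else:
            break
    return u, num_allowed


rng = random.Random(0)
print(randint_exclude_sketch(rng, 3, [0, 1, 2, 7]))  # (3, 0)
```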
- bartz.mcmcstep.choose_split(key, var, var_tree, split_tree, max_split, leaf_index)
Choose a split point for a new non-terminal node.
- Parameters:
key (Key[Array, '']) – A jax random key.
var (Int32[Array, '']) – The variable to split on.
var_tree (UInt[Array, '2**(d-1)']) – The splitting axes of the tree. Does not need to already contain var at leaf_index.
split_tree (UInt[Array, '2**(d-1)']) – The splitting points of the tree.
max_split (UInt[Array, 'p']) – The maximum split index for each variable.
leaf_index (Int32[Array, '']) – The index of the leaf to grow.
- Returns:
split (Int32[Array, '']) – The cutpoint.
l (Int32[Array, '']), r (Int32[Array, '']) – The integer range split was drawn from is [l, r).
Notes
If var is out of bounds, or if the available split range on that variable is empty, return 0.
- bartz.mcmcstep.compute_partial_ratio(prob_choose, num_prunable, p_nonterminal, leaf_to_grow)
Compute the product of the transition and prior ratios of a grow move.
- Parameters:
prob_choose (Float32[Array, '']) – The probability that the leaf had to be chosen amongst the growable leaves.
num_prunable (Int32[Array, '']) – The number of leaf parents that could be pruned, after converting the leaf to be grown to a non-terminal node.
p_nonterminal (Float32[Array, '2**d']) – The a priori probability of each node being nonterminal conditional on its ancestors.
leaf_to_grow (Int32[Array, '']) – The index of the leaf to grow.
- Returns:
Float32[Array, ''] – The partial transition ratio times the prior ratio.
Notes
The transition ratio is P(new tree => old tree) / P(old tree => new tree). The "partial" transition ratio returned is missing the factor P(propose prune) in the numerator. The prior ratio is P(new tree) / P(old tree). The "partial" prior ratio is missing the factor P(children are leaves).
- class bartz.mcmcstep.PruneMoves(allowed, node, partial_ratio, affluence_tree)
Represent a proposed prune move for each tree.
- Parameters:
allowed (Bool[Array, 'num_trees']) – Whether the move is possible.
node (UInt[Array, 'num_trees']) – The index of the node to prune. 2 ** d if no node can be pruned.
partial_ratio (Float32[Array, 'num_trees']) – A factor of the Metropolis-Hastings ratio of the move. It lacks the likelihood ratio, the probability of proposing the prune move, and the prior probability that the children of the node to prune are leaves. This ratio is inverted, and is meant to be inverted back in accept_move_and_sample_leaves.
- bartz.mcmcstep.propose_prune_moves(key, split_tree, affluence_tree, p_nonterminal, p_propose_grow)
Tree structure prune move proposal of BART MCMC.
- Parameters:
key (Key[Array, '']) – A jax random key.
split_tree (UInt[Array, '2**(d-1)']) – The splitting points of the tree.
affluence_tree (Bool[Array, '2**(d-1)']) – Whether each leaf can be grown.
p_nonterminal (Float32[Array, '2**d']) – The a priori probability of a node to be nonterminal conditional on the ancestors, including at the maximum depth where it should be zero.
p_propose_grow (Float32[Array, '2**(d-1)']) – The unnormalized probability of choosing a leaf to grow.
- Returns:
PruneMoves – An object representing the proposed moves.
- bartz.mcmcstep.choose_leaf_parent(key, split_tree, affluence_tree, p_propose_grow)
Pick a non-terminal node with leaf children to prune in a tree.
- Parameters:
key (Key[Array, '']) – A jax random key.
split_tree (UInt[Array, '2**(d-1)']) – The splitting points of the tree.
affluence_tree (Bool[Array, '2**(d-1)']) – Whether a leaf has enough points to be grown.
p_propose_grow (Float32[Array, '2**(d-1)']) – The unnormalized probability of choosing a leaf to grow.
- Returns:
node_to_prune (Int32[Array, '']) – The index of the node to prune. If num_prunable == 0, return 2 ** d.
num_prunable (Int32[Array, '']) – The number of leaf parents that could be pruned.
prob_choose (Float32[Array, '']) – The (normalized) probability that choose_leaf would choose node_to_prune as the leaf to grow, if passed the tree where node_to_prune had been pruned.
affluence_tree (Bool[Array, 'num_trees 2**(d-1)']) – A partially updated affluence_tree, marking the node to prune as growable.
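The count num_prunable refers to "leaf parents": nonterminal nodes whose two children are both leaves. The sketch below computes that mask in pure Python, assuming (not stated on this page) a heap layout with the root at index 1, children of node i at 2*i and 2*i + 1, and split_tree[i] == 0 for leaves. prunable_mask is a hypothetical helper. It does not reproduce how choose_leaf_parent uses affluence_tree or p_propose_grow to compute prob_choose.

```python
# Hypothetical sketch, not the bartz implementation. ASSUMED layout: root at
# index 1, children of node i at 2*i and 2*i + 1, split_tree[i] == 0 for leaves.


def prunable_mask(split_tree):
    """Mask of nonterminal nodes whose two children are both leaves."""
    size = len(split_tree)  # 2**(d-1)

    def is_leaf(i):
        # Indices past the end of split_tree are at maximum depth: always leaves.
        return i >= size or split_tree[i] == 0

    return [
        i >= 1 and split_tree[i] != 0 and is_leaf(2 * i) and is_leaf(2 * i + 1)
        for i in range(size)
    ]


mask = prunable_mask([0, 2, 1, 0])  # root and node 2 hold rules
print(mask, sum(mask))  # [False, False, True, False] 1
```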
- bartz.mcmcstep.randint_masked(key, mask)
Return a random integer in a range, including only some values.
- Parameters:
key (Key[Array, '']) – A jax random key.
mask (Bool[Array, 'n']) – The mask indicating the allowed values.
- Returns:
Int32[Array, ''] – A random integer u in the range [0, n) such that mask[u] == True.
Notes
If all values in the mask are False, return n.
- bartz.mcmcstep.accept_moves_and_sample_leaves(key, bart, moves)
Accept or reject the proposed moves and sample the new leaf values.
- Parameters:
key (Key[Array, '']) – A jax random key.
bart (State) – A valid BART mcmc state.
moves (Moves) – The proposed moves, see propose_moves.
- Returns:
State – A new (valid) BART mcmc state.
- class bartz.mcmcstep.Counts(left, right, total)
Number of datapoints in the nodes involved in proposed moves for each tree.
- Parameters:
left (UInt[Array, 'num_trees']) – Number of datapoints in the left child.
right (UInt[Array, 'num_trees']) – Number of datapoints in the right child.
total (UInt[Array, 'num_trees']) – Number of datapoints in the parent (= left + right).
- class bartz.mcmcstep.Precs(left, right, total)
Likelihood precision scale in the nodes involved in proposed moves for each tree.
The "likelihood precision scale" of a tree node is the sum of the inverse squared error scales of the datapoints selected by the node.
- Parameters:
left (Float32[Array, 'num_trees']) – Likelihood precision scale in the left child.
right (Float32[Array, 'num_trees']) – Likelihood precision scale in the right child.
total (Float32[Array, 'num_trees']) – Likelihood precision scale in the parent (= left + right).
- class bartz.mcmcstep.PreLkV(sigma2_left, sigma2_right, sigma2_total, sqrt_term)
Non-sequential terms of the likelihood ratio for each tree.
These terms can be computed in parallel across trees.
- Parameters:
sigma2_left (Float32[Array, 'num_trees']) – The noise variance in the left child of the leaves grown or pruned by the moves.
sigma2_right (Float32[Array, 'num_trees']) – The noise variance in the right child of the leaves grown or pruned by the moves.
sigma2_total (Float32[Array, 'num_trees']) – The noise variance in the total of the leaves grown or pruned by the moves.
sqrt_term (Float32[Array, 'num_trees']) – The logarithm of the square root term of the likelihood ratio.
- class bartz.mcmcstep.PreLk(exp_factor)
Non-sequential terms of the likelihood ratio shared by all trees.
- Parameters:
exp_factor (Float32[Array, '']) – The factor to multiply the likelihood ratio by, shared by all trees.
- class bartz.mcmcstep.PreLf(mean_factor, centered_leaves)
Pre-computed terms used to sample leaves from their posterior.
These terms can be computed in parallel across trees.
- Parameters:
mean_factor (Float32[Array, 'num_trees 2**d']) – The factor to be multiplied by the sum of the scaled residuals to obtain the posterior mean.
centered_leaves (Float32[Array, 'num_trees 2**d']) – The mean-zero normal values to be added to the posterior mean to obtain the posterior leaf samples.
- class bartz.mcmcstep.ParallelStageOut(bart, moves, prec_trees, move_precs, prelkv, prelk, prelf)
The output of accept_moves_parallel_stage.
- Parameters:
bart (State) – A partially updated BART mcmc state.
moves (Moves) – The proposed moves, with partial_ratio set to None and log_trans_prior_ratio set to its final value.
prec_trees (Float32[Array, 'num_trees 2**d'] | Int32[Array, 'num_trees 2**d']) – The likelihood precision scale in each potential or actual leaf node. If there is no precision scale, this is the number of points in each leaf.
move_counts – The counts of the number of points in the nodes modified by the moves. If bart.min_points_per_leaf is not set and bart.prec_scale is set, they are not computed.
move_precs (Precs | Counts) – The likelihood precision scale in each node modified by the moves. If bart.prec_scale is not set, this is set to move_counts.
prelkv (PreLkV), prelk (PreLk), prelf (PreLf) – Objects with pre-computed terms of the likelihood ratios and leaf samples.
- bartz.mcmcstep.accept_moves_parallel_stage(key, bart, moves)
Pre-compute quantities used to accept moves, in parallel across trees.
- Parameters:
key (Key[Array, '']) – A jax random key.
bart (State) – A BART mcmc state.
moves (Moves) – The proposed moves, see propose_moves.
- Returns:
ParallelStageOut – An object with all that could be done in parallel.
- bartz.mcmcstep.apply_grow_to_indices(moves, leaf_indices, X)
Update the leaf indices to apply a grow move.
- Parameters:
moves (Moves) – The proposed moves, see propose_moves.
leaf_indices (UInt[Array, 'num_trees n']) – The index of the leaf each datapoint falls into.
X (UInt[Array, 'p n']) – The predictors matrix.
- Returns:
UInt[Array, 'num_trees n'] – The updated leaf indices.
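For a single tree, applying a grow move to the leaf indices only touches the datapoints sitting in the grown leaf. Each one moves to the left or right child according to the new rule. The sketch below uses plain lists and assumes (not stated on this page) that the children of node are 2*node and 2*node + 1. apply_grow_one_tree is a hypothetical name, while the real function operates on all trees at once using the Moves fields.

```python
# Hypothetical single-tree sketch, not the bartz implementation. ASSUMED
# layout: the children of node are 2*node (left) and 2*node + 1 (right).
# X is indexed as X[var][j], with p rows and n columns as on this page.


def apply_grow_one_tree(leaf_indices, X, node, var, split):
    """Send the datapoints in leaf `node` to its would-be children."""
    new_indices = list(leaf_indices)
    for j, leaf in enumerate(leaf_indices):
        if leaf == node:
            # Same rule as in init's Notes: left iff X[var, j] < cutpoint.
            new_indices[j] = 2 * node if X[var][j] < split else 2 * node + 1
    return new_indices


X = [[0, 3, 1, 2], [1, 0, 0, 1]]
print(apply_grow_one_tree([1, 1, 1, 1], X, node=1, var=0, split=2))  # [2, 3, 2, 3]
```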
- bartz.mcmcstep.compute_count_trees(leaf_indices, moves, batch_size)
Count the number of datapoints in each leaf.
- Parameters:
leaf_indices (UInt[Array, 'num_trees n']) – The index of the leaf each datapoint falls into, with the deeper version of the tree (post-GROW, pre-PRUNE).
moves (Moves) – The proposed moves, see propose_moves.
batch_size (int | None) – The data batch size to use for the summation.
- Returns:
count_trees (Int32[Array, 'num_trees 2**d']) – The number of points in each potential or actual leaf node.
counts (Counts) – The counts of the number of points in the leaves grown or pruned by the moves.
- bartz.mcmcstep.count_datapoints_per_leaf(leaf_indices, tree_size, batch_size)
Count the number of datapoints in each leaf.
- Parameters:
leaf_indices (UInt[Array, 'num_trees n']) – The index of the leaf each datapoint falls into.
tree_size (int) – The size of the leaf tree array (2 ** d).
batch_size (int | None) – The data batch size to use for the summation.
- Returns:
Int32[Array, 'num_trees {tree_size}'] – The number of points in each leaf node.
- bartz.mcmcstep.compute_prec_trees(prec_scale, leaf_indices, moves, batch_size)
Compute the likelihood precision scale in each leaf.
- Parameters:
prec_scale (Float32[Array, 'n']) – The scale of the precision of the error on each datapoint.
leaf_indices (UInt[Array, 'num_trees n']) – The index of the leaf each datapoint falls into, with the deeper version of the tree (post-GROW, pre-PRUNE).
moves (Moves) – The proposed moves, see propose_moves.
batch_size (int | None) – The data batch size to use for the summation.
- Returns:
prec_trees (Float32[Array, 'num_trees 2**d']) – The likelihood precision scale in each potential or actual leaf node.
precs (Precs) – The likelihood precision scale in the nodes involved in the moves.
- bartz.mcmcstep.prec_per_leaf(prec_scale, leaf_indices, tree_size, batch_size)
Compute the likelihood precision scale in each leaf.
- Parameters:
prec_scale (Float32[Array, 'n']) – The scale of the precision of the error on each datapoint.
leaf_indices (UInt[Array, 'num_trees n']) – The index of the leaf each datapoint falls into.
tree_size (int) – The size of the leaf tree array (2 ** d).
batch_size (int | None) – The data batch size to use for the summation.
- Returns:
Float32[Array, 'num_trees {tree_size}'] – The likelihood precision scale in each leaf node.
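The per-leaf precision scale is a scatter-add of prec_scale into tree_size bins, one set of bins per tree. batch_size only changes how the sum over datapoints is split up, not the result. Here is a pure-Python sketch with explicit batches; prec_per_leaf_sketch is a hypothetical name, and the real function works on jax arrays.

```python
# Hypothetical sketch, not the bartz implementation: plain lists instead of
# jax arrays, and an explicit loop over batches instead of a vectorized sum.


def prec_per_leaf_sketch(prec_scale, leaf_indices, tree_size, batch_size=None):
    """Sum prec_scale over the datapoints in each leaf, for every tree.

    prec_scale has length n; leaf_indices holds one length-n list per tree.
    With batch_size set, each batch is reduced separately and the partial
    sums are then added together.
    """
    n = len(prec_scale)
    step = max(n if batch_size is None else batch_size, 1)
    result = []
    for tree_indices in leaf_indices:
        total = [0.0] * tree_size
        for start in range(0, n, step):
            partial = [0.0] * tree_size  # sums over this batch only
            for j in range(start, min(start + step, n)):
                partial[tree_indices[j]] += prec_scale[j]
            total = [a + b for a, b in zip(total, partial)]
        result.append(total)
    return result


# prec_scale = 1 / error_scale ** 2 for error_scale = 1, 2, 0.5, 1.
prec = [1.0, 0.25, 4.0, 1.0]
print(prec_per_leaf_sketch(prec, [[2, 3, 2, 3]], tree_size=4, batch_size=3))
# [[0.0, 0.0, 5.0, 1.25]]
```

Passing prec_scale = [1.0] * n gives per-leaf counts, which mirrors what count_datapoints_per_leaf returns (as integers).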
- bartz.mcmcstep.complete_ratio(moves, p_nonterminal)
Complete non-likelihood MH ratio calculation.
This function adds the probability of choosing a prune move over the grow move in the inverse transition, and the a priori probability that the children nodes are leaves.
- Parameters:
moves (Moves) – The proposed moves. Must have already been updated to take into account the thresholds on the number of datapoints per node; this happens in accept_moves_parallel_stage.
p_nonterminal (Float32[Array, '2**d']) – The a priori probability of each node being nonterminal conditional on its ancestors, including at the maximum depth where it should be zero.
- Returns:
Moves – The updated moves, with partial_ratio=None and log_trans_prior_ratio set.
- bartz.mcmcstep.adapt_leaf_trees_to_grow_indices(leaf_trees, moves)
Modify leaves such that post-grow indices work on the original tree.
The value of the leaf to grow is copied to what would be its children if the grow move was accepted.
- Parameters:
leaf_trees (Float32[Array, 'num_trees 2**d']) – The leaf values.
moves (Moves) – The proposed moves, see propose_moves.
- Returns:
Float32[Array, 'num_trees 2**d'] – The modified leaf values.
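For one tree, the copy described above amounts to writing the grown leaf's value into the two child slots. The sketch assumes (not stated on this page) that the children of node i are at 2*i and 2*i + 1, and it treats the sentinel index 2 ** d as "no move". adapt_leaf_tree_sketch is a hypothetical single-tree name, not the bartz implementation.

```python
# Hypothetical sketch for a single tree, not the bartz implementation.
# ASSUMED layout: the children of node i are at 2*i and 2*i + 1.


def adapt_leaf_tree_sketch(leaf_tree, node):
    """Copy leaf_tree[node] into the slots of its would-be children."""
    out = list(leaf_tree)
    if 2 * node + 1 < len(out):  # e.g. the sentinel 2**d leaves the tree unchanged
        out[2 * node] = leaf_tree[node]
        out[2 * node + 1] = leaf_tree[node]
    return out


print(adapt_leaf_tree_sketch([0.0, 0.0, 1.5, -0.5, 0.0, 0.0, 0.0, 0.0], 2))
# [0.0, 0.0, 1.5, -0.5, 1.5, 1.5, 0.0, 0.0]
```

After this copy, looking up a datapoint with its post-grow leaf index returns the same value as looking it up with its pre-grow index, which is the property stated above.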
- bartz.mcmcstep.precompute_likelihood_terms(sigma2, sigma_mu2, move_precs)
Pre-compute terms used in the likelihood ratio of the acceptance step.
- Parameters:
sigma2 (Float32[Array, '']) – The error variance, or the global error variance factor if prec_scale is set.
sigma_mu2 (Float32[Array, '']) – The prior variance of each leaf.
move_precs (Precs | Counts) – The likelihood precision scale in the leaves grown or pruned by the moves, in the fields left, right, and total (left + right).
- Returns:
prelkv (PreLkV) – Pre-computed terms of the likelihood ratio, one per tree.
prelk (PreLk) – Pre-computed terms of the likelihood ratio, shared by all trees.
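This page does not write out the likelihood ratio. Assume the standard conjugate BART model: leaf value mu ~ N(0, sigma_mu2), and datapoint i having error variance sigma2 / prec[i], i.e. sigma2 * error_scale[i] ** 2. Under that model, the log ratio of a grow move depends only on the children's sums of scaled residuals R and precision scales P. The sketch computes it two ways, once as a difference of per-leaf log marginal likelihoods and once in closed form. All names are hypothetical. The split into a log square-root term and an exponential term only loosely echoes PreLkV.sqrt_term and PreLk.exp_factor and is not a claim about their exact contents.

```python
import math

# Hypothetical sketch of the standard conjugate BART likelihood ratio; not
# taken from bartz. ASSUMED model: leaf value mu ~ N(0, sigma_mu2) and
# resid[i] ~ N(mu, sigma2 / prec[i]), independently given mu.


def log_marginal_leaf(resid, prec, sigma2, sigma_mu2):
    """Log density of one leaf's residuals with mu integrated out."""
    n = len(resid)
    P = sum(prec)  # likelihood precision scale of the leaf
    R = sum(w * r for w, r in zip(prec, resid))  # sum of scaled residuals
    S = sum(w * r * r for w, r in zip(prec, resid))
    denom = sigma2 + sigma_mu2 * P
    # Determinant and quadratic form of cov = sigma2 * diag(1/prec) + sigma_mu2 * 11^T
    log_det = sum(math.log(sigma2 / w) for w in prec) + math.log(denom / sigma2)
    quad = (S - sigma_mu2 * R * R / denom) / sigma2
    return -0.5 * (n * math.log(2 * math.pi) + log_det + quad)


def log_lik_ratio_grow(R_left, P_left, R_right, P_right, sigma2, sigma_mu2):
    """log p(resid | grown tree) - log p(resid | original tree).

    Only the children's sums R and precision scales P are needed; the
    parent's are R_left + R_right and P_left + P_right.
    """
    R_tot, P_tot = R_left + R_right, P_left + P_right
    d_left = sigma2 + sigma_mu2 * P_left
    d_right = sigma2 + sigma_mu2 * P_right
    d_tot = sigma2 + sigma_mu2 * P_tot
    log_sqrt_term = 0.5 * math.log(sigma2 * d_tot / (d_left * d_right))
    exp_term = sigma_mu2 / (2 * sigma2) * (
        R_left ** 2 / d_left + R_right ** 2 / d_right - R_tot ** 2 / d_tot
    )
    return log_sqrt_term + exp_term
```

For a prune move the sign flips. In a standard Metropolis-Hastings step, the move would then be accepted when logu is at most this log ratio plus log_trans_prior_ratio.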
- bartz.mcmcstep.precompute_leaf_terms(key, prec_trees, sigma2, sigma_mu2)[source]¶
Pre-compute terms used to sample leaves from their posterior.
- Parameters:
key (Key[Array, '']) – A JAX random key.
prec_trees (Float32[Array, 'num_trees 2**d']) – The likelihood precision scale in each potential or actual leaf node.
sigma2 (Float32[Array, '']) – The error variance, or the global error variance factor if prec_scale is set.
sigma_mu2 (Float32[Array, '']) – The prior variance of each leaf.
- Returns:
PreLf – Pre-computed terms for leaf sampling.
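This sketch assumes the standard conjugate leaf model (μ ~ N(0, σ_μ²), r_i ~ N(μ, σ²/w_i)). Under that model, the posterior of a leaf with total precision scale P and scaled residual sum S is normal, with mean σ_μ² S / (σ² + σ_μ² P) and variance σ² σ_μ² / (σ² + σ_μ² P). Only the mean depends on the residuals, so the random part can be drawn for all trees at once, before the sequential stage. The sketch uses a NumPy Generator in place of a JAX key, and its names are hypothetical.

```python
import numpy as np

def precompute_leaf_terms_sketch(rng, prec_trees, sigma2, sigma_mu2):
    """Hypothetical sketch: split each leaf posterior into a residual-dependent
    mean factor and a residual-independent centered Gaussian draw.

    prec_trees: (num_trees, 2**d) precision scale (or count) in each node.
    """
    prec_trees = np.asarray(prec_trees, dtype=float)
    denom = sigma2 + sigma_mu2 * prec_trees
    mean_factor = sigma_mu2 / denom           # posterior mean = mean_factor * resid_sum
    post_var = sigma2 * sigma_mu2 / denom     # posterior variance of each leaf
    centered = np.sqrt(post_var) * rng.standard_normal(prec_trees.shape)
    return mean_factor, centered

rng = np.random.default_rng(0)
mean_factor, centered = precompute_leaf_terms_sketch(rng, [[0.0, 4.0, 1.0, 3.0]], 1.0, 0.5)
resid_sums = np.array([[0.0, 2.0, 0.5, 1.5]])  # scaled residual sum in each node
new_leaves = mean_factor * resid_sums + centered
```

In JAX, the draw would come from `jax.random.normal(key, shape)` instead of `rng.standard_normal`.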
- bartz.mcmcstep.accept_moves_sequential_stage(pso)
Accept/reject the moves one tree at a time.
This is the most performance-sensitive function because it contains all and only the parts of the algorithm that cannot be parallelized across trees.
- Parameters:
pso (ParallelStageOut) – The output of accept_moves_parallel_stage.
- Returns:
bart (State) – A partially updated BART MCMC state.
moves (Moves) – The accepted/rejected moves, with acc and to_prune set.
- class bartz.mcmcstep.SeqStageInAllTrees(X, resid_batch_size, prec_scale, save_ratios, prelk)
The inputs to accept_move_and_sample_leaves that are shared by all trees.
- Parameters:
X (UInt[Array, 'p n']) – The predictors.
resid_batch_size (int | None) – The batch size for computing the sum of residuals in each leaf.
prec_scale (Float32[Array, 'n'] | None) – The scale of the precision of the error on each datapoint. If None, it is assumed to be 1.
save_ratios (bool) – Whether to save the acceptance ratios.
prelk (PreLk) – The pre-computed terms of the likelihood ratio that are shared across trees.
- class bartz.mcmcstep.SeqStageInPerTree(leaf_tree, prec_tree, move, move_precs, leaf_indices, prelkv, prelf)
The inputs to accept_move_and_sample_leaves that are separate for each tree.
- Parameters:
leaf_tree (Float32[Array, '2**d']) – The leaf values of the tree.
prec_tree (Float32[Array, '2**d']) – The likelihood precision scale in each potential or actual leaf node.
move (Moves) – The proposed move, see propose_moves.
move_precs (Precs | Counts) – The likelihood precision scale in each node modified by the moves.
leaf_indices (UInt[Array, 'n']) – The leaf indices for the largest version of the tree compatible with the move.
prelkv (PreLkV)
prelf (PreLf) – The pre-computed terms of the likelihood ratio and leaf sampling that are specific to the tree.
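The split between shared and per-tree inputs matches the shape of a scan over trees. Shared inputs are closed over, per-tree inputs are sliced along their first axis, and the residuals are the carry that forces the loop to run sequentially. Below, `scan` is a plain-Python stand-in for the semantics of `jax.lax.scan` (restricted to a tuple of arrays). `toy_step` is a hypothetical placeholder for accept_move_and_sample_leaves that only swaps each tree's old fit for its new one.

```python
import numpy as np

def scan(f, init, xs):
    """Plain-Python analogue of jax.lax.scan; xs is a tuple of arrays sliced along axis 0."""
    carry, ys = init, []
    for x in zip(*xs):
        carry, y = f(carry, x)
        ys.append(y)
    return carry, np.stack(ys)

def toy_step(resid, per_tree, shared):
    """Hypothetical stand-in for accept_move_and_sample_leaves on one tree."""
    old_leaves, new_leaves, leaf_indices = per_tree
    # swap this tree's old contribution for the new one; later trees see the result
    resid = resid + old_leaves[leaf_indices] - new_leaves[leaf_indices]
    # a per-tree output that also uses a shared input
    weighted_sum = np.sum(shared['prec_scale'] * resid)
    return resid, weighted_sum

shared = {'prec_scale': np.array([1.0, 2.0, 1.0])}   # same for every tree
old_leaves = np.array([[0.0, 0.0, 1.0, 2.0], [0.0, 0.5, 0.0, 0.0]])
new_leaves = np.array([[0.0, 0.0, 1.5, 1.5], [0.0, 0.0, 0.0, 0.0]])
leaf_indices = np.array([[2, 3, 3], [1, 1, 1]])      # per tree, per datapoint
resid, sums = scan(lambda r, pt: toy_step(r, pt, shared), np.zeros(3),
                   (old_leaves, new_leaves, leaf_indices))
```

The second tree's output depends on the residuals left by the first, which is why this stage cannot be vectorized across trees.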
- bartz.mcmcstep.accept_move_and_sample_leaves(resid, at, pt)
Accept or reject a proposed move and sample the new leaf values.
- Parameters:
resid (Float32[Array, 'n']) – The residuals (data minus forest value).
at (SeqStageInAllTrees) – The inputs that are the same for all trees.
pt (SeqStageInPerTree) – The inputs that are separate for each tree.
- Returns:
resid (Float32[Array, ‘n’]) – The updated residuals (data minus forest value).
leaf_tree (Float32[Array, ‘2**d’]) – The new leaf values of the tree.
acc (Bool[Array, ‘’]) – Whether the move was accepted.
to_prune (Bool[Array, ‘’]) – Whether the state should be updated by pruning the leaves involved in the move, to reflect its acceptance status.
log_lk_ratio (Float32[Array, ‘’] | None) – The logarithm of the likelihood ratio for the move, or None if it is not to be saved.
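The accept/reject part of this step can be sketched as a standard Metropolis-Hastings test. The helper is hypothetical and makes two assumptions. First, both log ratios already refer to the proposed move in its own direction. Second, leaf indices are kept in their post-grow form, as the leaf_indices field of SeqStageInPerTree describes. Under these assumptions, the children must be pruned both when a grow is rejected and when a prune is accepted.

```python
import math

def accept_decision(log_trans_prior, log_lk_ratio, is_grow, u):
    """Hypothetical sketch of the MH decision for one tree's move.

    u: a Uniform(0, 1] draw (u > 0, so its log is finite).
    Returns (acc, to_prune).
    """
    log_ratio = log_trans_prior + log_lk_ratio
    acc = math.log(u) < log_ratio
    # rejected grow -> prune back; accepted prune -> prune; otherwise keep
    to_prune = acc != is_grow
    return acc, to_prune
```

For example, `accept_decision(0.0, 0.0, True, 0.5)` accepts a grow whose total log ratio is 0, and leaves the children in place.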
- bartz.mcmcstep.sum_resid(scaled_resid, leaf_indices, tree_size, batch_size)
Sum the residuals in each leaf.
- Parameters:
scaled_resid (Float32[Array, 'n']) – The residuals (data minus forest value) multiplied by the error precision scale.
leaf_indices (UInt[Array, 'n']) – The leaf indices of the tree (the leaf each data point falls into).
tree_size (int) – The size of the tree array (2 ** d).
batch_size (int | None) – The data batch size for the aggregation. Batching increases numerical accuracy and parallelism.
- Returns:
Float32[Array, '{tree_size}'] – The sum of the residuals at data points in each leaf.
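Here is a NumPy sketch of the per-leaf reduction using `np.bincount` with weights. The batched branch reduces each batch separately and then adds the partial sums. In JAX the same scatter-add is usually written with `jax.ops.segment_sum` or `array.at[indices].add`. The function name is hypothetical.

```python
import numpy as np

def sum_resid_sketch(scaled_resid, leaf_indices, tree_size, batch_size=None):
    """Hypothetical sketch: per-leaf sum of scaled residuals, optionally in batches."""
    scaled_resid = np.asarray(scaled_resid, dtype=float)
    leaf_indices = np.asarray(leaf_indices).astype(np.intp)  # bincount needs signed ints
    if batch_size is None:
        return np.bincount(leaf_indices, weights=scaled_resid, minlength=tree_size)
    n = scaled_resid.shape[0]
    # one partial sum per batch, then add the partials together
    partial = [np.bincount(leaf_indices[i:i + batch_size],
                           weights=scaled_resid[i:i + batch_size],
                           minlength=tree_size)
               for i in range(0, n, batch_size)]
    return np.sum(partial, axis=0)

resid = np.array([1.0, -2.0, 0.5, 3.0])
idx = np.array([2, 3, 2, 3], dtype=np.uint32)
per_leaf = sum_resid_sketch(resid, idx, tree_size=4, batch_size=3)  # leaf 2: 1.5, leaf 3: 1.0
```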
- bartz.mcmcstep.compute_likelihood_ratio(total_resid, left_resid, right_resid, prelkv, prelk)
Compute the likelihood ratio of a grow move.
- Parameters:
total_resid (Float32[Array, ''])
left_resid (Float32[Array, ''])
right_resid (Float32[Array, '']) – The sum of the residuals (scaled by the error precision scale) of the datapoints falling in the nodes involved in the moves.
prelkv (PreLkV)
prelk (PreLk) – The pre-computed terms of the likelihood ratio, see precompute_likelihood_terms.
- Returns:
Float32[Array, ''] – The likelihood ratio P(data | new tree) / P(data | old tree).
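The sketch below makes the returned quantity concrete by computing log P(data | new tree) - log P(data | old tree) by brute force. It assumes the standard conjugate BART leaf model: leaf value μ ~ N(0, σ_μ²), residual r_i ~ N(μ, σ²/w_i). Only the leaf being grown matters, because leaves untouched by the move contribute identical factors to both trees and cancel. Integrating out μ makes the residuals in a leaf jointly normal with covariance σ² diag(1/w) + σ_μ² 11ᵀ. A closed form in the three residual sums avoids these O(n³) solves, so the sketch is only a reference check. All names are hypothetical.

```python
import numpy as np

def log_marginal(resid, prec_scale, sigma2, sigma_mu2):
    """Log density of the residuals in one leaf, with the leaf value integrated out."""
    n = resid.shape[0]
    cov = np.diag(sigma2 / prec_scale) + sigma_mu2 * np.ones((n, n))
    _, logdet = np.linalg.slogdet(cov)
    quad = resid @ np.linalg.solve(cov, resid)
    return -0.5 * (n * np.log(2 * np.pi) + logdet + quad)

def direct_log_lk_ratio(resid, prec_scale, goes_left, sigma2, sigma_mu2):
    """log P(data | new tree) - log P(data | old tree) for growing one leaf."""
    new = (log_marginal(resid[goes_left], prec_scale[goes_left], sigma2, sigma_mu2)
           + log_marginal(resid[~goes_left], prec_scale[~goes_left], sigma2, sigma_mu2))
    old = log_marginal(resid, prec_scale, sigma2, sigma_mu2)
    return new - old
```

Exponentiating the result gives the ratio itself.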
- bartz.mcmcstep.accept_moves_final_stage(bart, moves)
Post-process the MCMC state after accepting/rejecting the moves.
This function is separate from accept_moves_sequential_stage to signal that it can work in parallel across trees.
- Parameters:
bart (State) – A partially updated BART MCMC state.
moves (Moves) – The proposed moves (see propose_moves), as updated by accept_moves_sequential_stage.
- Returns:
State – The fully updated BART MCMC state.
- bartz.mcmcstep.apply_moves_to_leaf_indices(leaf_indices, moves)
Update the leaf indices to match the accepted move.
- Parameters:
leaf_indices (UInt[Array, 'num_trees n']) – The index of the leaf each datapoint falls into, if the grow move was accepted.
moves (Moves) – The proposed moves (see propose_moves), as updated by accept_moves_sequential_stage.
- Returns:
UInt[Array, 'num_trees n'] – The updated leaf indices.
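Because the incoming indices assume the grow happened, the update only needs to undo that assumption where the move ends in a prune. It maps datapoints in either child back to the parent. This NumPy sketch assumes heap order (children of node i at 2i and 2i + 1) and takes the moved node and the to_prune flag as plain arrays. The names are hypothetical, not the fields of Moves.

```python
import numpy as np

def apply_moves_to_leaf_indices_sketch(leaf_indices, node, to_prune):
    """Hypothetical sketch: collapse post-grow leaf indices where the move ends in a prune.

    leaf_indices: (num_trees, n) post-grow leaf index of each datapoint.
    node: (num_trees,) node involved in each move; to_prune: (num_trees,) bool.
    """
    node = np.asarray(node)[:, None]
    to_prune = np.asarray(to_prune)[:, None]
    in_children = (leaf_indices == 2 * node) | (leaf_indices == 2 * node + 1)
    return np.where(to_prune & in_children, node, leaf_indices)

idx = np.array([[4, 5, 3, 5],    # tree 0: node 2 was tentatively split into 4 and 5
                [2, 3, 2, 3]])   # tree 1: node 1 split into 2 and 3
updated = apply_moves_to_leaf_indices_sketch(idx, np.array([2, 1]), np.array([True, False]))
```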
- bartz.mcmcstep.apply_moves_to_split_trees(split_tree, moves)
Update the split trees to match the accepted move.
- Parameters:
split_tree (UInt[Array, 'num_trees 2**(d-1)']) – The cutpoints of the decision nodes in the initial trees.
moves (Moves) – The proposed moves (see propose_moves), as updated by accept_moves_sequential_stage.
- Returns:
UInt[Array, 'num_trees 2**(d-1)'] – The updated split trees.
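A NumPy sketch of the cutpoint update follows. It assumes that a value of 0 in split_tree marks a node without a decision, and that the proposed cutpoint of a grow move is available as a nonzero array. Both are assumptions, and the argument names are hypothetical rather than the fields of Moves. An accepted grow writes the new cutpoint, an accepted prune clears it, and a rejected move leaves the tree unchanged.

```python
import numpy as np

def apply_moves_to_split_trees_sketch(split_tree, node, grow, acc, split):
    """Hypothetical sketch of updating cutpoints after the accept/reject step.

    split_tree: (num_trees, 2**(d-1)) cutpoints, 0 assumed to mean "no decision".
    node, grow, acc, split: (num_trees,) arrays describing each tree's move.
    """
    out = split_tree.copy()
    rows = np.arange(split_tree.shape[0])
    current = split_tree[rows, node]
    new_value = np.where(grow, split, 0)  # a grow sets a cutpoint, a prune clears it
    out[rows, node] = np.where(acc, new_value, current)
    return out

st = np.array([[0, 5, 0, 0], [0, 2, 7, 0]], dtype=np.uint8)
updated = apply_moves_to_split_trees_sketch(
    st, node=np.array([2, 2]), grow=np.array([True, False]),
    acc=np.array([True, True]), split=np.array([3, 0], dtype=np.uint8))
```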