MCMC setup and step
Functions that implement the BART posterior MCMC initialization and update step.
Functions that perform MCMC steps take a BART state as input and return a new dictionary holding the updated state. The input dict and its arrays are not modified.
In general, integer types are chosen to be the minimal types that contain the range of possible values.
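The minimal-integer-type rule can be illustrated with a small sketch. The helper `minimal_unsigned_bits` below is hypothetical and not part of bartz; it only shows the idea of picking the narrowest width that can hold a known maximum value, and the library's own selection logic may differ.

```python
def minimal_unsigned_bits(max_value):
    """Return the smallest width among 8, 16, 32, 64 bits that holds max_value.

    Hypothetical helper for illustration only.
    """
    if max_value < 0:
        raise ValueError("max_value must be non-negative")
    for bits in (8, 16, 32, 64):
        if max_value < 2 ** bits:
            return bits
    raise ValueError("max_value does not fit in 64 bits")


# A predictor with at most 255 split indices fits in 8 bits,
# while 256 split indices already require 16 bits.
widths = [minimal_unsigned_bits(v) for v in (255, 256, 70000)]
```

Narrower integer types reduce memory traffic for large arrays such as the predictors and the leaf indices, which is presumably the motivation for this rule.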
- bartz.mcmcstep.init(*, X, y, max_split, num_trees, p_nonterminal, sigma2_alpha, sigma2_beta, small_float=jnp.float32, large_float=jnp.float32, min_points_per_leaf=None, resid_batch_size='auto', count_batch_size='auto', save_ratios=False)
Make a BART posterior sampling MCMC initial state.
- Parameters:
- Xint array (p, n)
The predictors. Note this is transposed compared to the usual convention.
- yfloat array (n,)
The response.
- max_splitint array (p,)
The maximum split index for each variable. All split ranges start at 1.
- num_treesint
The number of trees in the forest.
- p_nonterminalfloat array (d - 1,)
The probability of a nonterminal node at each depth. The maximum depth of trees is fixed by the length of this array.
- sigma2_alphafloat
The shape parameter of the inverse gamma prior on the noise variance.
- sigma2_betafloat
The scale parameter of the inverse gamma prior on the noise variance.
- small_floatdtype, default float32
The dtype for large arrays used in the algorithm.
- large_floatdtype, default float32
The dtype for scalars, small arrays, and arrays which require accuracy.
- min_points_per_leafint, optional
The minimum number of data points in a leaf node. 0 if not specified.
- resid_batch_size, count_batch_sizeint, None, str, default ‘auto’
The batch sizes, along datapoints, for summing the residuals and counting the number of datapoints in each leaf. None for no batching. If ‘auto’, pick a value based on the device of y, or the default device.
- save_ratiosbool, default False
Whether to save the Metropolis-Hastings ratios.
- Returns:
- bartdict
A dictionary with array values, representing a BART mcmc state. The keys are:
- ‘leaf_trees’small_float array (num_trees, 2 ** d)
The leaf values.
- ‘var_trees’int array (num_trees, 2 ** (d - 1))
The decision axes.
- ‘split_trees’int array (num_trees, 2 ** (d - 1))
The decision boundaries.
- ‘resid’large_float array (n,)
The residuals (data minus forest value). Large float to avoid roundoff.
- ‘sigma2’large_float
The noise variance.
- ‘grow_prop_count’, ‘prune_prop_count’int
The number of grow/prune proposals made during one full MCMC cycle.
- ‘grow_acc_count’, ‘prune_acc_count’int
The number of grow/prune moves accepted during one full MCMC cycle.
- ‘p_nonterminal’large_float array (d,)
The probability of a nonterminal node at each depth, padded with a zero.
- ‘p_propose_grow’large_float array (2 ** (d - 1),)
The unnormalized probability of picking a leaf for a grow proposal.
- ‘sigma2_alpha’large_float
The shape parameter of the inverse gamma prior on the noise variance.
- ‘sigma2_beta’large_float
The scale parameter of the inverse gamma prior on the noise variance.
- ‘max_split’int array (p,)
The maximum split index for each variable.
- ‘y’small_float array (n,)
The response.
- ‘X’int array (p, n)
The predictors.
- ‘leaf_indices’int array (num_trees, n)
The index of the leaf each datapoint falls into, for each tree.
- ‘min_points_per_leaf’int or None
The minimum number of data points in a leaf node.
- ‘affluence_trees’bool array (num_trees, 2 ** (d - 1)) or None
Whether a non-bottom leaf node contains at least twice min_points_per_leaf datapoints. If min_points_per_leaf is not specified, this is None.
- ‘opt’LeafDict
A dictionary with config values:
- ‘small_float’dtype
The dtype for large arrays used in the algorithm.
- ‘large_float’dtype
The dtype for scalars, small arrays, and arrays which require accuracy.
- ‘require_min_points’bool
Whether the min_points_per_leaf parameter is specified.
- ‘resid_batch_size’, ‘count_batch_size’int or None
The data batch sizes for computing the sufficient statistics.
- ‘ratios’dict, optional
If save_ratios is True, this field is present. It has the fields:
- ‘log_trans_prior’large_float array (num_trees,)
The log transition and prior Metropolis-Hastings ratio for the proposed move on each tree.
- ‘log_likelihood’large_float array (num_trees,)
The log likelihood ratio.
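The array shapes listed above all follow from the maximum depth d, which is one more than the length of p_nonterminal. The sketch below tabulates a few of them. The function name `state_shapes` is hypothetical and does not exist in bartz; it only restates the shapes documented for init.

```python
def state_shapes(num_trees, n, p_nonterminal_len):
    """Tabulate some documented state array shapes (illustrative helper)."""
    d = p_nonterminal_len + 1  # p_nonterminal has d - 1 entries
    return {
        'leaf_trees': (num_trees, 2 ** d),
        'var_trees': (num_trees, 2 ** (d - 1)),
        'split_trees': (num_trees, 2 ** (d - 1)),
        'resid': (n,),
        'leaf_indices': (num_trees, n),
        'p_nonterminal': (d,),  # padded with a trailing zero
    }


# Example: 200 trees, 1000 datapoints, 5 depth probabilities, so d = 6.
shapes = state_shapes(num_trees=200, n=1000, p_nonterminal_len=5)
```

The decision arrays are half the size of the leaf array because nodes on the bottom level can only be leaves and never carry a decision rule.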
- bartz.mcmcstep.step(bart, key)
Perform one full MCMC step on a BART state.
- Parameters:
- bartdict
A BART mcmc state, as created by init.
- keyjax.dtypes.prng_key array
A jax random key.
- Returns:
- bartdict
The new BART mcmc state.
- bartz.mcmcstep.sample_trees(bart, key)
Forest sampling step of BART MCMC.
- Parameters:
- bartdict
A BART mcmc state, as created by init.
- keyjax.dtypes.prng_key array
A jax random key.
- Returns:
- bartdict
The new BART mcmc state.
Notes
This function zeroes the proposal counters.
- bartz.mcmcstep.sample_moves(bart, key)
Propose moves for all the trees.
- Parameters:
- bartdict
BART mcmc state.
- keyjax.dtypes.prng_key array
A jax random key.
- Returns:
- movesdict
A dictionary with fields:
- ‘allowed’bool array (num_trees,)
Whether the move is possible.
- ‘grow’bool array (num_trees,)
Whether the move is a grow move or a prune move.
- ‘num_growable’int array (num_trees,)
The number of growable leaves in the original tree.
- ‘node’int array (num_trees,)
The index of the leaf to grow or node to prune.
- ‘left’, ‘right’int array (num_trees,)
The indices of the children of ‘node’.
- ‘partial_ratio’float array (num_trees,)
A factor of the Metropolis-Hastings ratio of the move. It lacks the likelihood ratio and the probability of proposing the prune move. If the move is Prune, the ratio is inverted.
- ‘grow_var’int array (num_trees,)
The decision axes of the new rules.
- ‘grow_split’int array (num_trees,)
The decision boundaries of the new rules.
- ‘var_trees’int array (num_trees, 2 ** (d - 1))
The updated decision axes of the trees, valid whatever move.
- ‘logu’float array (num_trees,)
The logarithm of a uniform (0, 1] random variable to be used to accept the move. It’s in (-oo, 0].
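The role of ‘logu’ can be illustrated with a minimal Metropolis-Hastings acceptance sketch in plain Python. The names `sample_logu` and `accept` are hypothetical, and whether the library compares with `<` or `<=` is an assumption here; the two choices differ only on a set of probability zero.

```python
import math
import random


def sample_logu(rng):
    """Log of a uniform variate in (0, 1], hence a finite value in (-oo, 0]."""
    # rng.random() lies in [0, 1), so 1 - rng.random() lies in (0, 1].
    return math.log(1.0 - rng.random())


def accept(logu, log_ratio):
    """Accept when u <= ratio, compared in log space (assumed convention)."""
    return logu <= log_ratio


rng = random.Random(0)
logu = sample_logu(rng)
always = accept(logu, 0.0)    # ratio 1: always accepted
rejected = accept(-1.0, -2.0)  # u = e^-1 exceeds ratio = e^-2: rejected
```

Drawing from (0, 1] rather than [0, 1) keeps the logarithm finite, and comparing in log space avoids exponentiating a possibly very large or very small ratio.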
- bartz.mcmcstep.grow_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, p_propose_grow, key)
Tree structure grow move proposal of BART MCMC.
This move picks a leaf node and converts it to a non-terminal node with two leaf children. The move is not possible if all the leaves are already at maximum depth.
- Parameters:
- var_treearray (2 ** (d - 1),)
The variable indices of the tree.
- split_treearray (2 ** (d - 1),)
The splitting points of the tree.
- affluence_treebool array (2 ** (d - 1),) or None
Whether a leaf has enough points to be grown.
- max_splitarray (p,)
The maximum split index for each variable.
- p_nonterminalarray (d,)
The probability of a nonterminal node at each depth.
- p_propose_growarray (2 ** (d - 1),)
The unnormalized probability of choosing a leaf to grow.
- keyjax.dtypes.prng_key array
A jax random key.
- Returns:
- grow_movedict
A dictionary with fields:
- ‘num_growable’int
The number of growable leaves.
- ‘node’int
The index of the leaf to grow. 2 ** d if there are no growable leaves.
- ‘var’, ‘split’int
The decision axis and boundary of the new rule.
- ‘partial_ratio’float
A factor of the Metropolis-Hastings ratio of the move. It lacks the likelihood ratio and the probability of proposing the prune move.
- ‘var_tree’array (2 ** (d - 1),)
The updated decision axes of the tree.
- bartz.mcmcstep.choose_leaf(split_tree, affluence_tree, p_propose_grow, key)
Choose a leaf node to grow in a tree.
- Parameters:
- split_treearray (2 ** (d - 1),)
The splitting points of the tree.
- affluence_treebool array (2 ** (d - 1),) or None
Whether a leaf has enough points to be grown.
- p_propose_growarray (2 ** (d - 1),)
The unnormalized probability of choosing a leaf to grow.
- keyjax.dtypes.prng_key array
A jax random key.
- Returns:
- leaf_to_growint
The index of the leaf to grow. If num_growable == 0, return 2 ** d.
- num_growableint
The number of leaf nodes that can be grown.
- prob_choosefloat
The normalized probability of choosing the selected leaf.
- num_prunableint
The number of leaf parents that could be pruned, after converting the selected leaf to a non-terminal node.
- bartz.mcmcstep.growable_leaves(split_tree, affluence_tree)
Return a mask indicating the leaf nodes that can be proposed for growth.
- Parameters:
- split_treearray (2 ** (d - 1),)
The splitting points of the tree.
- affluence_treebool array (2 ** (d - 1),) or None
Whether a leaf has enough points to be grown.
- Returns:
- is_growablebool array (2 ** (d - 1),)
The mask indicating the leaf nodes that can be proposed to grow, i.e., that are not at the bottom level and have at least two times the number of minimum points per leaf.
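A pure-Python sketch of the growable-leaf mask follows. It rests on assumptions not stated in this section: trees are stored in heap layout with the root at index 1 and index 0 unused, the children of node i are 2 * i and 2 * i + 1, and a split value of 0 marks a node without a decision rule. The function name is hypothetical.

```python
def growable_leaves_sketch(split_tree, affluence_tree=None):
    """Mask of leaves that are not on the bottom level and are affluent.

    Assumes heap layout (root at index 1) and that split 0 means "no rule".
    """
    size = len(split_tree)  # 2 ** (d - 1): covers all non-bottom levels
    mask = [False] * size
    for i in range(1, size):
        parent_is_decision = i == 1 or split_tree[i // 2] != 0
        is_leaf = split_tree[i] == 0 and parent_is_decision
        affluent = True if affluence_tree is None else bool(affluence_tree[i])
        mask[i] = is_leaf and affluent
    return mask


# d = 3: the root splits, node 2 is a leaf, node 3 splits again.
mask = growable_leaves_sketch([0, 3, 0, 2])
```

Because the split array only covers the non-bottom levels, leaves on the bottom level never appear in the mask, which matches the statement that they cannot be grown.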
- bartz.mcmcstep.categorical(key, distr)
Return a random integer from an arbitrary distribution.
- Parameters:
- keyjax.dtypes.prng_key array
A jax random key.
- distrfloat array (n,)
An unnormalized probability distribution.
- Returns:
- uint
A random integer in the range [0, n). If all probabilities are zero, return n.
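One common way to sample from an unnormalized distribution is inversion of the cumulative sum. The sketch below shows that technique in plain Python with the documented fallback of returning n when every weight is zero. It is an illustration under that assumption, not necessarily how the library implements categorical; the function name is hypothetical.

```python
import random
from bisect import bisect_right
from itertools import accumulate


def categorical_sketch(rng, distr):
    """Sample an index with probability proportional to distr[index]."""
    cumulative = list(accumulate(distr))
    total = cumulative[-1] if cumulative else 0.0
    if total <= 0:
        return len(distr)  # documented fallback: no valid outcome
    u = rng.random() * total  # uniform in [0, total)
    # First index whose cumulative weight exceeds u; zero-weight entries
    # add nothing to the cumulative sum and are never selected.
    return bisect_right(cumulative, u)


rng = random.Random(0)
draws = [categorical_sketch(rng, [1.0, 0.0, 3.0]) for _ in range(200)]
```

No normalization pass is needed because the uniform variate is scaled by the total weight instead.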
- bartz.mcmcstep.choose_variable(var_tree, split_tree, max_split, leaf_index, key)
Choose a variable to split on for a new non-terminal node.
- Parameters:
- var_treeint array (2 ** (d - 1),)
The variable indices of the tree.
- split_treeint array (2 ** (d - 1),)
The splitting points of the tree.
- max_splitint array (p,)
The maximum split index for each variable.
- leaf_indexint
The index of the leaf to grow.
- keyjax.dtypes.prng_key array
A jax random key.
- Returns:
- varint
The index of the variable to split on.
Notes
The variable is chosen among the variables that have a non-empty range of allowed splits. If no variable has a non-empty range, return p.
- bartz.mcmcstep.fully_used_variables(var_tree, split_tree, max_split, leaf_index)
Return a list of variables that have an empty split range at a given node.
- Parameters:
- var_treeint array (2 ** (d - 1),)
The variable indices of the tree.
- split_treeint array (2 ** (d - 1),)
The splitting points of the tree.
- max_splitint array (p,)
The maximum split index for each variable.
- leaf_indexint
The index of the node, assumed to be valid for var_tree.
- Returns:
- var_to_ignoreint array (d - 2,)
The indices of the variables that have an empty split range. Since the number of such variables is not fixed, unused values in the array are filled with p. The fill values are not guaranteed to be placed in any particular order. Variables may appear more than once.
- bartz.mcmcstep.ancestor_variables(var_tree, max_split, node_index)
Return the list of variables in the ancestors of a node.
- Parameters:
- var_treeint array (2 ** (d - 1),)
The variable indices of the tree.
- max_splitint array (p,)
The maximum split index for each variable. Used only to get p.
- node_indexint
The index of the node, assumed to be valid for var_tree.
- Returns:
- ancestor_varsint array (d - 2,)
The variable indices of the ancestors of the node, from the root to the parent. Unused spots are filled with p.
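The ancestor walk can be sketched in plain Python as follows. This assumes a heap layout with the root at index 1, so that the parent of node i is i // 2. Placing the fill values before the ancestors is also an assumption, since the text above does not say where the unused spots go. The function name is hypothetical.

```python
def ancestor_variables_sketch(var_tree, p, node_index):
    """Variables of the ancestors of a node, root first, padded with p.

    Assumes heap layout: root at index 1, parent of i is i // 2.
    """
    # len(var_tree) == 2 ** (d - 1), so the deepest node stored in
    # var_tree has at most d - 2 ancestors.
    max_ancestors = len(var_tree).bit_length() - 2
    ancestors = []
    index = node_index >> 1
    while index >= 1:
        ancestors.append(var_tree[index])
        index >>= 1
    ancestors.reverse()  # the walk goes parent to root; flip to root first
    padding = [p] * (max_ancestors - len(ancestors))
    return padding + ancestors


# d = 4: the root splits on variable 2 and node 3 splits on variable 1.
var_tree = [0, 2, 0, 1, 0, 0, 0, 0]
path = ancestor_variables_sketch(var_tree, p=5, node_index=6)
```

A fixed-length output padded with p keeps the result shape static, which is the usual requirement for code that is compiled with JAX.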
- bartz.mcmcstep.split_range(var_tree, split_tree, max_split, node_index, ref_var)
Return the range of allowed splits for a variable at a given node.
- Parameters:
- var_treeint array (2 ** (d - 1),)
The variable indices of the tree.
- split_treeint array (2 ** (d - 1),)
The splitting points of the tree.
- max_splitint array (p,)
The maximum split index for each variable.
- node_indexint
The index of the node, assumed to be valid for var_tree.
- ref_varint
The variable for which to measure the split range.
- Returns:
- l, rint
The range of allowed splits is [l, r).
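The allowed range can be computed by walking from the node up to the root and tightening the bounds at every ancestor that splits on ref_var. The sketch below makes three assumptions that this section does not confirm: heap layout with the root at index 1, odd indices being right children, and a decision rule that sends points with x < split to the left child. The function name is hypothetical.

```python
def split_range_sketch(var_tree, split_tree, max_split, node_index, ref_var):
    """Half-open range [l, r) of splits still available for ref_var."""
    l, r = 1, max_split[ref_var] + 1  # all split ranges start at 1
    child = node_index
    while child > 1:
        parent = child >> 1
        if var_tree[parent] == ref_var:
            split = split_tree[parent]
            if child & 1:
                # Right subtree (assumed x >= split): smaller splits are useless.
                l = max(l, split + 1)
            else:
                # Left subtree (assumed x < split): larger splits are useless.
                r = min(r, split)
        child = parent
    return l, r


# One predictor with splits 1..10; the root splits it at 5.
var_tree, split_tree, max_split = [0, 0, 0, 0], [0, 5, 0, 0], [10]
left_range = split_range_sketch(var_tree, split_tree, max_split, 2, 0)
right_range = split_range_sketch(var_tree, split_tree, max_split, 3, 0)
```

When l == r the range is empty, which is the condition under which a variable counts as fully used at that node.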
- bartz.mcmcstep.randint_exclude(key, sup, exclude)
Return a random integer in a range, excluding some values.
- Parameters:
- keyjax.dtypes.prng_key array
A jax random key.
- supint
The exclusive upper bound of the range.
- excludeint array (n,)
The values to exclude from the range. Values greater than or equal to sup are ignored. Values can appear more than once.
- Returns:
- uint
A random integer in the range [0, sup) that satisfies u not in exclude. If all values in the range are excluded, return sup.
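The contract of randint_exclude can be restated as a short reference sketch in plain Python. It enumerates the allowed values explicitly, which is simple but costs time proportional to sup; the library presumably uses a scheme better suited to JAX. The function name is hypothetical.

```python
import random


def randint_exclude_sketch(rng, sup, exclude):
    """Uniform draw from range(sup) minus the excluded values."""
    # Out-of-range entries are ignored; duplicates collapse in the set.
    excluded = {v for v in exclude if 0 <= v < sup}
    allowed = [u for u in range(sup) if u not in excluded]
    if not allowed:
        return sup  # documented fallback: everything is excluded
    return rng.choice(allowed)


rng = random.Random(0)
draws = [randint_exclude_sketch(rng, 5, [1, 1, 3, 7]) for _ in range(200)]
```

In choose_variable the fill value p plays the role of an ignored entry, which is consistent with values greater than or equal to sup being skipped.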
- bartz.mcmcstep.choose_split(var_tree, split_tree, max_split, leaf_index, key)
Choose a split point for a new non-terminal node.
- Parameters:
- var_treeint array (2 ** (d - 1),)
The variable indices of the tree.
- split_treeint array (2 ** (d - 1),)
The splitting points of the tree.
- max_splitint array (p,)
The maximum split index for each variable.
- leaf_indexint
The index of the leaf to grow. It is assumed that var_tree already contains the target variable at this index.
- keyjax.dtypes.prng_key array
A jax random key.
- Returns:
- splitint
The split point.
- bartz.mcmcstep.compute_partial_ratio(prob_choose, num_prunable, p_nonterminal, leaf_to_grow)
Compute the product of the transition and prior ratios of a grow move.
- Parameters:
- prob_choosefloat
The normalized probability of choosing the selected leaf.
- num_prunableint
The number of leaf parents that could be pruned, after converting the leaf to be grown to a non-terminal node.
- p_nonterminalarray (d,)
The probability of a nonterminal node at each depth.
- leaf_to_growint
The index of the leaf to grow.
- Returns:
- ratiofloat
The transition ratio P(new tree -> old tree) / P(old tree -> new tree) times the prior ratio P(new tree) / P(old tree), but the transition ratio is missing the factor P(propose prune) in the numerator.
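The structure of this ratio can be sketched in plain Python. Several details are assumptions rather than facts taken from this section: the depth of a node is derived from a heap index with the root at 1, a grow move is proposed with probability 1 on a root-only tree and 1/2 otherwise, and the variable and split choice probabilities cancel between the transition and prior ratios. The function name is hypothetical.

```python
def partial_ratio_sketch(prob_choose, num_prunable, p_nonterminal, leaf_to_grow):
    """Transition ratio (without P(propose prune)) times prior ratio."""
    # Assumption: a root-only tree cannot be pruned, so grow is certain there.
    p_grow = 1.0 if leaf_to_grow == 1 else 0.5
    # Forward move: propose grow, then pick this leaf.
    forward = p_grow * prob_choose
    # Reverse move without P(propose prune): pick this node among the
    # prunable ones of the grown tree.
    reverse = 1.0 / num_prunable
    # Prior ratio: the node becomes nonterminal and gains two leaf children.
    depth = leaf_to_grow.bit_length() - 1  # heap index -> depth, root is 0
    p_parent = p_nonterminal[depth]
    p_child_is_leaf = 1.0 - p_nonterminal[depth + 1]
    prior_ratio = p_parent * p_child_is_leaf ** 2 / (1.0 - p_parent)
    return reverse / forward * prior_ratio


# p_nonterminal padded with a trailing zero, as in the state dictionary.
p_nonterminal = [0.5, 0.25, 0.0]
root_ratio = partial_ratio_sketch(1.0, 1, p_nonterminal, leaf_to_grow=1)
child_ratio = partial_ratio_sketch(0.5, 1, p_nonterminal, leaf_to_grow=2)
```

The trailing zero in p_nonterminal makes the lookup at depth + 1 valid even when the new children land on the bottom level, where they are leaves with probability 1.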
- bartz.mcmcstep.prune_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, p_propose_grow, key)
Tree structure prune move proposal of BART MCMC.
- Parameters:
- var_treeint array (2 ** (d - 1),)
The variable indices of the tree.
- split_treeint array (2 ** (d - 1),)
The splitting points of the tree.
- affluence_treebool array (2 ** (d - 1),) or None
Whether a leaf has enough points to be grown.
- max_splitint array (p,)
The maximum split index for each variable.
- p_nonterminalfloat array (d,)
The probability of a nonterminal node at each depth.
- p_propose_growfloat array (2 ** (d - 1),)
The unnormalized probability of choosing a leaf to grow.
- keyjax.dtypes.prng_key array
A jax random key.
- Returns:
- prune_movedict
A dictionary with fields:
- ‘allowed’bool
Whether the move is possible.
- ‘node’int
The index of the node to prune. 2 ** d if no node can be pruned.
- ‘partial_ratio’float
A factor of the Metropolis-Hastings ratio of the move. It lacks the likelihood ratio and the probability of proposing the prune move. This ratio is inverted.
- bartz.mcmcstep.choose_leaf_parent(split_tree, affluence_tree, p_propose_grow, key)
Pick a non-terminal node with leaf children to prune in a tree.
- Parameters:
- split_treearray (2 ** (d - 1),)
The splitting points of the tree.
- affluence_treebool array (2 ** (d - 1),) or None
Whether a leaf has enough points to be grown.
- p_propose_growarray (2 ** (d - 1),)
The unnormalized probability of choosing a leaf to grow.
- keyjax.dtypes.prng_key array
A jax random key.
- Returns:
- node_to_pruneint
The index of the node to prune. If num_prunable == 0, return 2 ** d.
- num_prunableint
The number of leaf parents that could be pruned.
- prob_choosefloat
The normalized probability of choosing the node to prune for growth.
- bartz.mcmcstep.randint_masked(key, mask)
Return a random integer in a range, including only some values.
- Parameters:
- keyjax.dtypes.prng_key array
A jax random key.
- maskbool array (n,)
The mask indicating the allowed values.
- Returns:
- uint
A random integer in the range [0, n) that satisfies mask[u] == True. If all values in the mask are False, return n.
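One way to draw uniformly among the allowed positions is to pick a rank k among the True entries and then locate the k-th True entry. The plain-Python sketch below uses that approach. It illustrates the documented contract only, and the function name is hypothetical.

```python
import random


def randint_masked_sketch(rng, mask):
    """Uniform draw among the indices where mask is True."""
    count = sum(1 for allowed in mask if allowed)
    if count == 0:
        return len(mask)  # documented fallback: nothing is allowed
    k = rng.randrange(count)  # rank of the chosen True entry
    seen = -1
    for index, allowed in enumerate(mask):
        if allowed:
            seen += 1
            if seen == k:
                return index
    raise AssertionError("unreachable: k is smaller than the number of True entries")


rng = random.Random(0)
draws = [randint_masked_sketch(rng, [False, True, False, True]) for _ in range(200)]
```

In an array setting the same idea is commonly expressed with a cumulative sum of the mask followed by a search for the first position whose cumulative count exceeds k.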
- bartz.mcmcstep.accept_moves_and_sample_leaves(bart, moves, key)
Accept or reject the proposed moves and sample the new leaf values.
- Parameters:
- bartdict
A BART mcmc state.
- movesdict
The proposed moves, see sample_moves.
- keyjax.dtypes.prng_key array
A jax random key.
- Returns:
- bartdict
The new BART mcmc state.
- bartz.mcmcstep.accept_moves_parallel_stage(bart, moves, key)
Pre-computes quantities used to accept moves, in parallel across trees.
- Parameters:
- bartdict
A BART mcmc state.
- movesdict
The proposed moves, see sample_moves.
- keyjax.dtypes.prng_key array
A jax random key.
- Returns:
- bartdict
A partially updated BART mcmc state.
- movesdict
The proposed moves, with the field ‘partial_ratio’ replaced by ‘log_trans_prior_ratio’.
- count_treesarray (num_trees, 2 ** d)
The number of points in each potential or actual leaf node.
- move_countsdict
The counts of the number of points in the nodes modified by the moves.
- prelkv, prelk, prelfdict
Dictionaries with pre-computed terms of the likelihood ratios and leaf samples.
- bartz.mcmcstep.apply_grow_to_indices(moves, leaf_indices, X)
Update the leaf indices to apply a grow move.
- Parameters:
- movesdict
The proposed moves, see sample_moves.
- leaf_indicesarray (num_trees, n)
The index of the leaf each datapoint falls into.
- Xarray (p, n)
The predictors matrix.
- Returns:
- grow_leaf_indicesarray (num_trees, n)
The updated leaf indices.
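For a single tree, applying a grow move to the leaf indices amounts to rerouting the datapoints that sit in the grown leaf to one of its two new children. The sketch below shows this in plain Python. The direction of the decision rule (x >= split goes right) is an assumption, and the function and argument names are hypothetical.

```python
def apply_grow_to_indices_sketch(leaf_indices, X, node, left, right, var, split):
    """Reroute datapoints from the grown leaf to its children (one tree)."""
    updated = []
    for j, leaf in enumerate(leaf_indices):
        if leaf == node:
            # Assumed rule: values at or above the split go to the right child.
            leaf = right if X[var][j] >= split else left
        updated.append(leaf)
    return updated


# One predictor, four datapoints, all initially in the root (index 1).
X = [[1, 5, 7, 3]]  # shape (p, n), transposed as documented for init
grown = apply_grow_to_indices_sketch([1, 1, 1, 1], X, node=1, left=2,
                                     right=3, var=0, split=4)
```

Datapoints outside the grown leaf keep their index, so the result is valid for the tree as it would look if the grow move were accepted.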
- bartz.mcmcstep.compute_count_trees(leaf_indices, moves, batch_size)
Count the number of datapoints in each leaf.
- Parameters:
- leaf_indicesint array (num_trees, n)
The index of the leaf each datapoint falls into, if the grow move is accepted.
- movesdict
The proposed moves, see sample_moves.
- batch_sizeint or None
The data batch size to use for the summation.
- Returns:
- count_treesint array (num_trees, 2 ** d)
The number of points in each potential or actual leaf node.
- countsdict
The counts of the number of points in the nodes modified by the moves, organized as two dictionaries ‘grow’ and ‘prune’, with subfields ‘left’, ‘right’, and ‘total’.
- bartz.mcmcstep.count_datapoints_per_leaf(leaf_indices, tree_size, batch_size)
Count the number of datapoints in each leaf.
- Parameters:
- leaf_indicesint array (num_trees, n)
The index of the leaf each datapoint falls into.
- tree_sizeint
The size of the leaf tree array (2 ** d).
- batch_sizeint or None
The data batch size to use for the summation.
- Returns:
- count_treesint array (num_trees, 2 ** d)
The number of points in each leaf node.
- bartz.mcmcstep.complete_ratio(moves, move_counts, min_points_per_leaf)
Complete non-likelihood MH ratio calculation.
This function adds the probability of choosing the prune move.
- Parameters:
- movesdict
The proposed moves, see sample_moves.
- move_countsdict
The counts of the number of points in the nodes modified by the moves.
- min_points_per_leafint or None
The minimum number of data points in a leaf node.
- Returns:
- movesdict
The updated moves, with the field ‘partial_ratio’ replaced by ‘log_trans_prior_ratio’.
- bartz.mcmcstep.compute_p_prune(moves, left_count, right_count, min_points_per_leaf)
Compute the probability of proposing a prune move.
- Parameters:
- movesdict
The proposed moves, see sample_moves.
- left_count, right_countint
The number of datapoints in the proposed children of the leaf to grow.
- min_points_per_leafint or None
The minimum number of data points in a leaf node.
- Returns:
- p_prunefloat
The probability of proposing a prune move. For a grow move, this is the probability in the tree obtained after accepting the grow move; for a prune move, it is the probability in the current tree.
- bartz.mcmcstep.adapt_leaf_trees_to_grow_indices(leaf_trees, moves)
Modify leaf values such that the indices of the grow moves work on the original tree.
- Parameters:
- leaf_treesfloat array (num_trees, 2 ** d)
The leaf values.
- movesdict
The proposed moves, see sample_moves.
- Returns:
- leaf_treesfloat array (num_trees, 2 ** d)
The modified leaf values. The value of the leaf to grow is copied to what would be its children if the grow move was accepted.
- bartz.mcmcstep.precompute_likelihood_terms(count_trees, sigma2, move_counts)
Pre-compute terms used in the likelihood ratio of the acceptance step.
- Parameters:
- count_treesarray (num_trees, 2 ** d)
The number of points in each potential or actual leaf node.
- sigma2float
The noise variance.
- move_countsdict
The counts of the number of points in the nodes modified by the moves.
- Returns:
- prelkvdict
Dictionary with pre-computed terms of the likelihood ratio, one per tree.
- prelkdict
Dictionary with pre-computed terms of the likelihood ratio, shared by all trees.
- bartz.mcmcstep.precompute_leaf_terms(count_trees, sigma2, key)
Pre-compute terms used to sample leaves from their posterior.
- Parameters:
- count_treesarray (num_trees, 2 ** d)
The number of points in each potential or actual leaf node.
- sigma2float
The noise variance.
- keyjax.dtypes.prng_key array
A jax random key.
- Returns:
- prelfdict
Dictionary with pre-computed terms of the leaf sampling, with fields:
- ‘mean_factor’float array (num_trees, 2 ** d)
The factor to be multiplied by the sum of residuals to obtain the posterior mean.
- ‘centered_leaves’float array (num_trees, 2 ** d)
The mean-zero normal values to be added to the posterior mean to obtain the posterior leaf samples.
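The two fields can be related to the standard conjugate normal update for a single leaf. The sketch below assumes a leaf prior N(0, sigma_mu2). The parameter sigma_mu2 does not appear in the signature above, so its name and the way the library fixes its value are assumptions, and the function name is hypothetical.

```python
import math
import random


def leaf_terms_sketch(count, sigma2, sigma_mu2, rng):
    """Posterior terms for one leaf holding `count` datapoints.

    With prior mu ~ N(0, sigma_mu2) and residuals r_i ~ N(mu, sigma2):
      posterior mean     = sigma_mu2 / (sigma2 + count * sigma_mu2) * sum(r_i)
      posterior variance = sigma2 * sigma_mu2 / (sigma2 + count * sigma_mu2)
    """
    denom = sigma2 + count * sigma_mu2
    mean_factor = sigma_mu2 / denom
    variance = sigma2 * sigma_mu2 / denom
    # Mean-zero draw with the posterior variance; adding the posterior mean
    # later yields the leaf sample.
    centered_leaf = math.sqrt(variance) * rng.gauss(0.0, 1.0)
    return mean_factor, variance, centered_leaf


rng = random.Random(0)
mean_factor, variance, centered_leaf = leaf_terms_sketch(4, 1.0, 0.25, rng)
leaf_sample = mean_factor * 6.0 + centered_leaf  # sum of residuals = 6.0
```

Neither term depends on the residuals themselves, only on the counts and the noise variance, which is why they can be computed for all trees before the sequential stage.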
- bartz.mcmcstep.accept_moves_sequential_stage(bart, count_trees, moves, move_counts, prelkv, prelk, prelf)
The part of accepting the moves that has to be done one tree at a time.
- Parameters:
- bartdict
A partially updated BART mcmc state.
- count_treesarray (num_trees, 2 ** d)
The number of points in each potential or actual leaf node.
- movesdict
The proposed moves, see sample_moves.
- move_countsdict
The counts of the number of points in the nodes modified by the moves.
- prelkv, prelk, prelfdict
Dictionaries with pre-computed terms of the likelihood ratios and leaf samples.
- Returns:
- bartdict
A partially updated BART mcmc state.
- movesdict
The proposed moves, with these additional fields:
- ‘acc’bool array (num_trees,)
Whether the move was accepted.
- ‘to_prune’bool array (num_trees,)
Whether, to reflect the acceptance status of the move, the state should be updated by pruning the leaves involved in the move.
- bartz.mcmcstep.accept_move_and_sample_leaves(X, ntree, resid_batch_size, resid, min_points_per_leaf, save_ratios, prelk, leaf_tree, count_tree, move, move_counts, leaf_indices, prelkv, prelf)
Accept or reject a proposed move and sample the new leaf values.
- Parameters:
- Xint array (p, n)
The predictors.
- ntreeint
The number of trees in the forest.
- resid_batch_sizeint, None
The batch size for computing the sum of residuals in each leaf.
- residfloat array (n,)
The residuals (data minus forest value).
- min_points_per_leafint or None
The minimum number of data points in a leaf node.
- save_ratiosbool
Whether to save the acceptance ratios.
- prelkdict
The pre-computed terms of the likelihood ratio which are shared across trees.
- leaf_treefloat array (2 ** d,)
The leaf values of the tree.
- count_treeint array (2 ** d,)
The number of datapoints in each leaf.
- movedict
The proposed move, see sample_moves.
- leaf_indicesint array (n,)
The leaf indices for the largest version of the tree compatible with the move.
- prelkv, prelfdict
The pre-computed terms of the likelihood ratio and leaf sampling which are specific to the tree.
- Returns:
- residfloat array (n,)
The updated residuals (data minus forest value).
- leaf_treefloat array (2 ** d,)
The new leaf values of the tree.
- accbool
Whether the move was accepted.
- to_prunebool
Whether, to reflect the acceptance status of the move, the state should be updated by pruning the leaves involved in the move.
- ratiosdict
The acceptance ratios for the moves. Empty if not to be saved.
- bartz.mcmcstep.sum_resid(resid, leaf_indices, tree_size, batch_size)
Sum the residuals in each leaf.
- Parameters:
- residfloat array (n,)
The residuals (data minus forest value).
- leaf_indicesint array (n,)
The leaf indices of the tree (the leaf each data point falls into).
- tree_sizeint
The size of the tree array (2 ** d).
- batch_sizeint, None
The data batch size for the aggregation. Batching increases numerical accuracy and parallelism.
- Returns:
- resid_treefloat array (2 ** d,)
The sum of the residuals at data points in each leaf.
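The batched aggregation can be sketched in plain Python as a scatter-add performed separately on each batch, followed by a sum over the batch results. This mirrors the description above but is not the library's JAX implementation, and the function name is hypothetical.

```python
def sum_resid_sketch(resid, leaf_indices, tree_size, batch_size=None):
    """Sum residuals per leaf, accumulating one batch at a time."""
    n = len(resid)
    if batch_size is None:
        batch_size = max(n, 1)  # no batching: a single batch with all points
    totals = [0.0] * tree_size
    for start in range(0, n, batch_size):
        stop = start + batch_size
        # Scatter-add restricted to the current batch.
        partial = [0.0] * tree_size
        for value, leaf in zip(resid[start:stop], leaf_indices[start:stop]):
            partial[leaf] += value
        # Combine the per-batch sums.
        for i in range(tree_size):
            totals[i] += partial[i]
    return totals


resid = [1.0, 2.0, -0.5, 4.0]
leaf_indices = [2, 3, 2, 3]
batched = sum_resid_sketch(resid, leaf_indices, tree_size=4, batch_size=2)
unbatched = sum_resid_sketch(resid, leaf_indices, tree_size=4)
```

Summing within batches first keeps each partial sum short, which limits the growth of roundoff error, and independent batches can be processed in parallel.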
- bartz.mcmcstep.compute_likelihood_ratio(total_resid, left_resid, right_resid, prelkv, prelk)
Compute the likelihood ratio of a grow move.
- Parameters:
- total_residfloat
The sum of the residuals in the leaf to grow.
- left_resid, right_residfloat
The sum of the residuals in the left/right child of the leaf to grow.
- prelkv, prelkdict
The pre-computed terms of the likelihood ratio, see precompute_likelihood_terms.
- Returns:
- ratiofloat
The likelihood ratio P(data | new tree) / P(data | old tree).
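For reference, the grow likelihood ratio has a closed form under the usual BART assumptions of Gaussian noise with variance sigma2 and a leaf prior N(0, sigma_mu2), with the leaf value integrated out. The sketch below computes it directly instead of using pre-computed terms. The arguments n_left, n_right, sigma2 and sigma_mu2 are assumptions introduced for illustration and are not parameters of the function documented above; the function name is hypothetical.

```python
import math


def likelihood_ratio_sketch(total_resid, left_resid, right_resid,
                            n_left, n_right, sigma2, sigma_mu2):
    """Marginal likelihood ratio of splitting one leaf into two children."""

    def log_marginal(resid_sum, count):
        # Log marginal likelihood of one leaf, up to terms that depend only
        # on the squared residuals and therefore cancel in the ratio.
        denom = sigma2 + count * sigma_mu2
        return (-0.5 * math.log(denom / sigma2)
                + 0.5 * sigma_mu2 * resid_sum ** 2 / (sigma2 * denom))

    log_ratio = (log_marginal(left_resid, n_left)
                 + log_marginal(right_resid, n_right)
                 - log_marginal(total_resid, n_left + n_right))
    return math.exp(log_ratio)


# With zero residual sums only the complexity penalty remains.
penalty = likelihood_ratio_sketch(0.0, 0.0, 0.0, 1, 1, 1.0, 1.0)
# An empty right child leaves the likelihood unchanged.
unchanged = likelihood_ratio_sketch(3.0, 3.0, 0.0, 5, 0, 2.0, 0.5)
```

Only the sums of residuals and the counts enter the formula, which explains why the function takes residual sums as inputs and why count-dependent factors can be pre-computed.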
- bartz.mcmcstep.accept_moves_final_stage(bart, moves)
The final part of accepting the moves, in parallel across trees.
- Parameters:
- bartdict
A partially updated BART mcmc state.
- countsdict
The indicators of proposals and acceptances for grow and prune moves.
- movesdict
The proposed moves (see sample_moves) as updated by accept_moves_sequential_stage.
- Returns:
- bartdict
The fully updated BART mcmc state.
- bartz.mcmcstep.apply_moves_to_leaf_indices(leaf_indices, moves)
Vectorized version of apply_moves_to_leaf_indices. Takes similar arguments as apply_moves_to_leaf_indices but with additional array axes over which apply_moves_to_leaf_indices is mapped.
Original documentation:
Update the leaf indices to match the accepted move.
- Parameters:
- leaf_indicesint array (num_trees, n)
The index of the leaf each datapoint falls into, if the grow move was accepted.
- movesdict
The proposed moves (see sample_moves), as updated by accept_moves_sequential_stage.
- Returns:
- leaf_indicesint array (num_trees, n)
The updated leaf indices.
- bartz.mcmcstep.apply_moves_to_split_trees(split_trees, moves)
Vectorized version of apply_moves_to_split_trees. Takes similar arguments as apply_moves_to_split_trees but with additional array axes over which apply_moves_to_split_trees is mapped.
Original documentation:
Update the split trees to match the accepted move.
- Parameters:
- split_treesint array (num_trees, 2 ** (d - 1))
The cutpoints of the decision nodes in the initial trees.
- movesdict
The proposed moves (see sample_moves), as updated by accept_moves_sequential_stage.
- Returns:
- split_treesint array (num_trees, 2 ** (d - 1))
The updated split trees.