MCMC setup and step

Functions that implement the BART posterior MCMC initialization and update step.

Functions that perform MCMC steps take a BART state as input and return a new dictionary containing the updated state. The input dictionary and its arrays are not modified.
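
The non-mutating convention can be illustrated with a minimal pure-Python sketch. The function `toy_step` and the key `'counter'` are hypothetical names used only for illustration; they are not part of bartz.

```python
def toy_step(state):
    """Return an updated copy of `state`, leaving the input untouched."""
    # Shallow-copy the dict so that rebinding keys does not affect the caller.
    new_state = dict(state)
    # Build a new value instead of modifying the old one in place.
    new_state['counter'] = state['counter'] + 1
    return new_state

old = {'counter': 0}
new = toy_step(old)
```

After the call, `old` still holds its original value, so earlier states can be kept or compared freely.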

In general, integer types are chosen to be the minimal types that contain the range of possible values.

bartz.mcmcstep.init(*, X, y, max_split, num_trees, p_nonterminal, sigma2_alpha, sigma2_beta, small_float=jnp.float32, large_float=jnp.float32, min_points_per_leaf=None, resid_batch_size='auto', count_batch_size='auto', save_ratios=False)

Make a BART posterior sampling MCMC initial state.

Parameters:
Xint array (p, n)

The predictors. Note that this is transposed compared to the usual convention.

yfloat array (n,)

The response.

max_splitint array (p,)

The maximum split index for each variable. All split ranges start at 1.

num_treesint

The number of trees in the forest.

p_nonterminalfloat array (d - 1,)

The probability of a nonterminal node at each depth. The maximum depth of trees is fixed by the length of this array.

sigma2_alphafloat

The shape parameter of the inverse gamma prior on the noise variance.

sigma2_betafloat

The scale parameter of the inverse gamma prior on the noise variance.

small_floatdtype, default float32

The dtype for large arrays used in the algorithm.

large_floatdtype, default float32

The dtype for scalars, small arrays, and arrays which require accuracy.

min_points_per_leafint, optional

The minimum number of data points in a leaf node. 0 if not specified.

resid_batch_size, count_batch_sizeint, None, str, default ‘auto’

The batch sizes, along datapoints, for summing the residuals and counting the number of datapoints in each leaf. None for no batching. If ‘auto’, pick a value based on the device of y, or the default device.

save_ratiosbool, default False

Whether to save the Metropolis-Hastings ratios.

Returns:
bartdict

A dictionary with array values, representing a BART mcmc state. The keys are:

‘leaf_trees’small_float array (num_trees, 2 ** d)

The leaf values.

‘var_trees’int array (num_trees, 2 ** (d - 1))

The decision axes.

‘split_trees’int array (num_trees, 2 ** (d - 1))

The decision boundaries.

‘resid’large_float array (n,)

The residuals (data minus forest value). Large float to avoid roundoff.

‘sigma2’large_float

The noise variance.

‘grow_prop_count’, ‘prune_prop_count’int

The number of grow/prune proposals made during one full MCMC cycle.

‘grow_acc_count’, ‘prune_acc_count’int

The number of grow/prune moves accepted during one full MCMC cycle.

‘p_nonterminal’large_float array (d,)

The probability of a nonterminal node at each depth, padded with a zero.

‘p_propose_grow’large_float array (2 ** (d - 1),)

The unnormalized probability of picking a leaf for a grow proposal.

‘sigma2_alpha’large_float

The shape parameter of the inverse gamma prior on the noise variance.

‘sigma2_beta’large_float

The scale parameter of the inverse gamma prior on the noise variance.

‘max_split’int array (p,)

The maximum split index for each variable.

‘y’small_float array (n,)

The response.

‘X’int array (p, n)

The predictors.

‘leaf_indices’int array (num_trees, n)

The index of the leaf each datapoint falls into, for each tree.

‘min_points_per_leaf’int or None

The minimum number of data points in a leaf node.

‘affluence_trees’bool array (num_trees, 2 ** (d - 1)) or None

Whether a non-bottom leaf node contains at least twice min_points_per_leaf datapoints. If min_points_per_leaf is not specified, this is None.

‘opt’LeafDict

A dictionary with config values:

‘small_float’dtype

The dtype for large arrays used in the algorithm.

‘large_float’dtype

The dtype for scalars, small arrays, and arrays which require accuracy.

‘require_min_points’bool

Whether the min_points_per_leaf parameter is specified.

‘resid_batch_size’, ‘count_batch_size’int or None

The data batch sizes for computing the sufficient statistics.

‘ratios’dict, optional

If save_ratios is True, this field is present. It has the fields:

‘log_trans_prior’large_float array (num_trees,)

The log transition and prior Metropolis-Hastings ratio for the proposed move on each tree.

‘log_likelihood’large_float array (num_trees,)

The log likelihood ratio.
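
As a quick check of how the array shapes listed above relate to each other, here is a small sketch that derives them from the inputs. The helper `state_shapes` is hypothetical and not part of bartz; it only encodes the shapes stated in this section, with d = len(p_nonterminal) + 1 because the argument has shape (d - 1,).

```python
def state_shapes(num_trees, n, p, p_nonterminal):
    """Shapes of the main state arrays, as listed in the Returns section."""
    # p_nonterminal has length d - 1, so the depth parameter is:
    d = len(p_nonterminal) + 1
    return {
        'leaf_trees': (num_trees, 2 ** d),
        'var_trees': (num_trees, 2 ** (d - 1)),
        'split_trees': (num_trees, 2 ** (d - 1)),
        'resid': (n,),
        'p_nonterminal': (d,),  # the input padded with one zero
        'p_propose_grow': (2 ** (d - 1),),
        'max_split': (p,),
        'X': (p, n),
        'leaf_indices': (num_trees, n),
    }

shapes = state_shapes(num_trees=200, n=1000, p=10,
                      p_nonterminal=[0.95, 0.24, 0.11])
```

With three probabilities, d is 4, so each leaf array has 16 entries and each decision array has 8.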

bartz.mcmcstep.step(bart, key)

Perform one full MCMC step on a BART state.

Parameters:
bartdict

A BART mcmc state, as created by init.

keyjax.dtypes.prng_key array

A jax random key.

Returns:
bartdict

The new BART mcmc state.

bartz.mcmcstep.sample_trees(bart, key)

Forest sampling step of BART MCMC.

Parameters:
bartdict

A BART mcmc state, as created by init.

keyjax.dtypes.prng_key array

A jax random key.

Returns:
bartdict

The new BART mcmc state.

Notes

This function zeroes the proposal counters.

bartz.mcmcstep.sample_moves(bart, key)

Propose moves for all the trees.

Parameters:
bartdict

BART mcmc state.

keyjax.dtypes.prng_key array

A jax random key.

Returns:
movesdict

A dictionary with fields:

‘allowed’bool array (num_trees,)

Whether the move is possible.

‘grow’bool array (num_trees,)

Whether the move is a grow move or a prune move.

‘num_growable’int array (num_trees,)

The number of growable leaves in the original tree.

‘node’int array (num_trees,)

The index of the leaf to grow or node to prune.

‘left’, ‘right’int array (num_trees,)

The indices of the children of ‘node’.

‘partial_ratio’float array (num_trees,)

A factor of the Metropolis-Hastings ratio of the move. It lacks the likelihood ratio and the probability of proposing the prune move. If the move is Prune, the ratio is inverted.

‘grow_var’int array (num_trees,)

The decision axes of the new rules.

‘grow_split’int array (num_trees,)

The decision boundaries of the new rules.

‘var_trees’int array (num_trees, 2 ** (d - 1))

The updated decision axes of the trees, valid whatever move.

‘logu’float array (num_trees,)

The logarithm of a uniform (0, 1] random variable to be used to accept the move. It’s in (-oo, 0].
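
The role of ‘logu’ can be sketched in pure Python. This is an illustration of the usual Metropolis-Hastings acceptance test, not the bartz code; the names `draw_logu` and `accept` are hypothetical, and the use of `<=` in the comparison is an assumption.

```python
import math
import random

def draw_logu(rng):
    """Log of a uniform variate on (0, 1]; the result lies in (-inf, 0]."""
    # rng.random() is in [0, 1), so 1 - rng.random() is in (0, 1].
    return math.log(1.0 - rng.random())

def accept(logu, log_ratio):
    """Accept with probability min(1, ratio), comparing on the log scale."""
    return logu <= log_ratio

rng = random.Random(0)
logu = draw_logu(rng)
```

Drawing from (0, 1] rather than [0, 1) keeps the logarithm finite, and a move with ratio zero (log ratio of minus infinity) is never accepted.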

bartz.mcmcstep.grow_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, p_propose_grow, key)

Tree structure grow move proposal of BART MCMC.

This move picks a leaf node and converts it to a non-terminal node with two leaf children. The move is not possible if all the leaves are already at maximum depth.

Parameters:
var_treearray (2 ** (d - 1),)

The variable indices of the tree.

split_treearray (2 ** (d - 1),)

The splitting points of the tree.

affluence_treebool array (2 ** (d - 1),) or None

Whether a leaf has enough points to be grown.

max_splitarray (p,)

The maximum split index for each variable.

p_nonterminalarray (d,)

The probability of a nonterminal node at each depth.

p_propose_growarray (2 ** (d - 1),)

The unnormalized probability of choosing a leaf to grow.

keyjax.dtypes.prng_key array

A jax random key.

Returns:
grow_movedict

A dictionary with fields:

‘num_growable’int

The number of growable leaves.

‘node’int

The index of the leaf to grow. 2 ** d if there are no growable leaves.

‘var’, ‘split’int

The decision axis and boundary of the new rule.

‘partial_ratio’float

A factor of the Metropolis-Hastings ratio of the move. It lacks the likelihood ratio and the probability of proposing the prune move.

‘var_tree’array (2 ** (d - 1),)

The updated decision axes of the tree.

bartz.mcmcstep.choose_leaf(split_tree, affluence_tree, p_propose_grow, key)

Choose a leaf node to grow in a tree.

Parameters:
split_treearray (2 ** (d - 1),)

The splitting points of the tree.

affluence_treebool array (2 ** (d - 1),) or None

Whether a leaf has enough points to be grown.

p_propose_growarray (2 ** (d - 1),)

The unnormalized probability of choosing a leaf to grow.

keyjax.dtypes.prng_key array

A jax random key.

Returns:
leaf_to_growint

The index of the leaf to grow. If num_growable == 0, return 2 ** d.

num_growableint

The number of leaf nodes that can be grown.

prob_choosefloat

The normalized probability of choosing the selected leaf.

num_prunableint

The number of leaf parents that could be pruned, after converting the selected leaf to a non-terminal node.

bartz.mcmcstep.growable_leaves(split_tree, affluence_tree)

Return a mask indicating the leaf nodes that can be proposed for growth.

Parameters:
split_treearray (2 ** (d - 1),)

The splitting points of the tree.

affluence_treebool array (2 ** (d - 1),) or None

Whether a leaf has enough points to be grown.

Returns:
is_growablebool array (2 ** (d - 1),)

The mask indicating the leaf nodes that can be proposed to grow, i.e., that are not at the bottom level and have at least two times the number of minimum points per leaf.
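
The following pure-Python sketch shows one way such a mask could be computed. It rests on assumptions that this page does not state: the tree is stored in heap order with the root at index 1 (index 0 unused), the children of node i are 2 * i and 2 * i + 1, and a split value of 0 marks a node without a decision rule. Under this layout the bottom level is not stored in split_tree at all, so every leaf found in the array is automatically above the bottom level.

```python
def growable_leaves(split_tree, affluence_tree=None):
    """Mask of leaves that can be grown, for a heap-ordered split array."""
    mask = []
    for index in range(len(split_tree)):
        # Assumed encoding: split 0 means "no decision rule at this node".
        has_no_rule = split_tree[index] == 0
        # A node exists if it is the root or if its parent has a rule.
        exists = index == 1 or (index > 1 and split_tree[index // 2] != 0)
        is_leaf = has_no_rule and exists
        # Without a minimum leaf size, every leaf counts as affluent.
        affluent = True if affluence_tree is None else bool(affluence_tree[index])
        mask.append(is_leaf and affluent)
    return mask

root_only = growable_leaves([0, 0, 0, 0])
grown_once = growable_leaves([0, 3, 0, 0])
restricted = growable_leaves([0, 3, 0, 0], [False, False, True, False])
```

In the last call only node 2 has enough datapoints, so it is the only leaf that remains growable.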

bartz.mcmcstep.categorical(key, distr)

Return a random integer from an arbitrary distribution.

Parameters:
keyjax.dtypes.prng_key array

A jax random key.

distrfloat array (n,)

An unnormalized probability distribution.

Returns:
uint

A random integer in the range [0, n). If all probabilities are zero, return n.
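
A minimal pure-Python sketch of this contract, using inverse-CDF sampling over the cumulative weights. It takes a `random.Random` generator in place of a jax key, and it is an illustration of the technique rather than the bartz implementation.

```python
import random

def categorical(rng, distr):
    """Draw an index with probability proportional to distr; n if all zero."""
    total = sum(distr)
    if total <= 0:
        return len(distr)
    # Pick a point in [0, total) and find the interval it falls into.
    target = rng.random() * total
    cumulative = 0.0
    for index, weight in enumerate(distr):
        cumulative += weight
        if target < cumulative:
            return index
    # Guard against roundoff: fall back to the last positive-weight index.
    return max(i for i, w in enumerate(distr) if w > 0)

rng = random.Random(0)
draws = [categorical(rng, [0.0, 2.0, 0.0, 1.0]) for _ in range(200)]
```

Entries with zero weight are never selected because they do not advance the cumulative sum.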

bartz.mcmcstep.choose_variable(var_tree, split_tree, max_split, leaf_index, key)

Choose a variable to split on for a new non-terminal node.

Parameters:
var_treeint array (2 ** (d - 1),)

The variable indices of the tree.

split_treeint array (2 ** (d - 1),)

The splitting points of the tree.

max_splitint array (p,)

The maximum split index for each variable.

leaf_indexint

The index of the leaf to grow.

keyjax.dtypes.prng_key array

A jax random key.

Returns:
varint

The index of the variable to split on.

Notes

The variable is chosen among the variables that have a non-empty range of allowed splits. If no variable has a non-empty range, return p.

bartz.mcmcstep.fully_used_variables(var_tree, split_tree, max_split, leaf_index)

Return a list of variables that have an empty split range at a given node.

Parameters:
var_treeint array (2 ** (d - 1),)

The variable indices of the tree.

split_treeint array (2 ** (d - 1),)

The splitting points of the tree.

max_splitint array (p,)

The maximum split index for each variable.

leaf_indexint

The index of the node, assumed to be valid for var_tree.

Returns:
var_to_ignoreint array (d - 2,)

The indices of the variables that have an empty split range. Since the number of such variables is not fixed, unused values in the array are filled with p. The fill values are not guaranteed to be placed in any particular order. Variables may appear more than once.

bartz.mcmcstep.ancestor_variables(var_tree, max_split, node_index)

Return the list of variables in the ancestors of a node.

Parameters:
var_treeint array (2 ** (d - 1),)

The variable indices of the tree.

max_splitint array (p,)

The maximum split index for each variable. Used only to get p.

node_indexint

The index of the node, assumed to be valid for var_tree.

Returns:
ancestor_varsint array (d - 2,)

The variable indices of the ancestors of the node, from the root to the parent. Unused spots are filled with p.
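
The traversal can be sketched in pure Python under the assumption, not stated on this page, that trees are stored in heap order with the root at index 1 so that the parent of node i is i // 2. Unlike the documented function, this sketch takes p directly instead of max_split, and where it places the fill values is its own choice.

```python
def ancestor_variables(var_tree, p, node_index):
    """Variables used by the ancestors of a node, from the root to the parent."""
    # var_tree has 2 ** (d - 1) entries, so a node has at most d - 2 ancestors.
    max_num_ancestors = len(var_tree).bit_length() - 2
    ancestors = []
    index = node_index
    while index > 1:
        index //= 2  # move to the parent
        ancestors.append(var_tree[index])
    ancestors.reverse()  # root first, parent last
    # Pad with the out-of-range value p up to the fixed output length.
    return ancestors + [p] * (max_num_ancestors - len(ancestors))

var_tree = [0, 2, 0, 1, 0, 0, 0, 0]  # d = 4, root splits on variable 2
deep = ancestor_variables(var_tree, 3, 6)
shallow = ancestor_variables(var_tree, 3, 3)
```

A fixed-length output with a fill value is what makes the result usable as an array of static shape.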

bartz.mcmcstep.split_range(var_tree, split_tree, max_split, node_index, ref_var)

Return the range of allowed splits for a variable at a given node.

Parameters:
var_treeint array (2 ** (d - 1),)

The variable indices of the tree.

split_treeint array (2 ** (d - 1),)

The splitting points of the tree.

max_splitint array (p,)

The maximum split index for each variable.

node_indexint

The index of the node, assumed to be valid for var_tree.

ref_varint

The variable for which to measure the split range.

Returns:
l, rint

The range of allowed splits is [l, r).
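
A pure-Python sketch of how the half-open range [l, r) could be derived by walking up the tree. The heap layout (root at index 1, even indices are left children) and the rule that the left child receives the points with x < split are assumptions of this sketch, not facts stated on this page.

```python
def split_range(var_tree, split_tree, max_split, node_index, ref_var):
    """Return (l, r) such that the allowed splits for ref_var are l <= s < r."""
    # Without constraints, splits go from 1 to max_split[ref_var] inclusive.
    l, r = 1, max_split[ref_var] + 1
    index = node_index
    while index > 1:
        parent = index // 2
        if var_tree[parent] == ref_var:
            split = split_tree[parent]
            if index % 2 == 0:
                # Left subtree (assumed x < split): new splits must be lower.
                r = min(r, split)
            else:
                # Right subtree (assumed x >= split): new splits must be higher.
                l = max(l, split + 1)
        index = parent
    return l, r

var_tree = [0, 0, 0, 0]    # root splits on variable 0
split_tree = [0, 3, 0, 0]  # at split index 3
max_split = [5, 4]
left_range = split_range(var_tree, split_tree, max_split, 2, 0)
right_range = split_range(var_tree, split_tree, max_split, 3, 0)
```

A variable whose range is empty (l >= r) cannot be used for a new rule at that node, which is the case fully_used_variables reports.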

bartz.mcmcstep.randint_exclude(key, sup, exclude)

Return a random integer in a range, excluding some values.

Parameters:
keyjax.dtypes.prng_key array

A jax random key.

supint

The exclusive upper bound of the range.

excludeint array (n,)

The values to exclude from the range. Values greater than or equal to sup are ignored. Values can appear more than once.

Returns:
uint

A random integer in the range [0, sup), and which satisfies u not in exclude. If all values in the range are excluded, return sup.
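
The contract above can be sketched in a few lines of pure Python. This version builds the list of allowed values explicitly, which is simple but not how a fixed-shape array implementation would do it; it uses a `random.Random` generator in place of a jax key.

```python
import random

def randint_exclude(rng, sup, exclude):
    """Draw uniformly from [0, sup) minus the excluded values; sup if empty."""
    excluded = set(exclude)  # duplicates and values >= sup are harmless
    allowed = [value for value in range(sup) if value not in excluded]
    if not allowed:
        return sup
    return rng.choice(allowed)

rng = random.Random(1)
only_choice = randint_exclude(rng, 3, [0, 2, 5])
nothing_left = randint_exclude(rng, 3, [0, 1, 2, 2, 7])
```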

bartz.mcmcstep.choose_split(var_tree, split_tree, max_split, leaf_index, key)

Choose a split point for a new non-terminal node.

Parameters:
var_treeint array (2 ** (d - 1),)

The variable indices of the tree.

split_treeint array (2 ** (d - 1),)

The splitting points of the tree.

max_splitint array (p,)

The maximum split index for each variable.

leaf_indexint

The index of the leaf to grow. It is assumed that var_tree already contains the target variable at this index.

keyjax.dtypes.prng_key array

A jax random key.

Returns:
splitint

The split point.

bartz.mcmcstep.compute_partial_ratio(prob_choose, num_prunable, p_nonterminal, leaf_to_grow)

Compute the product of the transition and prior ratios of a grow move.

Parameters:
prob_choosefloat

The normalized probability of choosing the leaf to grow.

num_prunableint

The number of leaf parents that could be pruned, after converting the leaf to be grown to a non-terminal node.

p_nonterminalarray (d,)

The probability of a nonterminal node at each depth.

leaf_to_growint

The index of the leaf to grow.

Returns:
ratiofloat

The transition ratio P(new tree -> old tree) / P(old tree -> new tree) times the prior ratio P(new tree) / P(old tree), but the transition ratio is missing the factor P(propose prune) in the numerator.
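
A hedged sketch of how such a ratio could be assembled, following the standard BART grow/prune construction. Several things here are assumptions rather than facts from this page: the heap layout with the root at index 1, the extra argument `p_grow` (the probability of proposing a grow move), the uniform choice among prunable nodes in the reverse move, and the cancellation of the variable and split choice factors between prior and proposal.

```python
def partial_ratio(prob_choose, num_prunable, p_nonterminal, leaf_to_grow, p_grow):
    """Transition ratio (without P(propose prune)) times the prior ratio."""
    # Depth of a node in a heap whose root has index 1 (assumed layout).
    depth = leaf_to_grow.bit_length() - 1
    p_parent = p_nonterminal[depth]
    p_child = p_nonterminal[depth + 1]
    # Prior ratio: the leaf becomes nonterminal and gains two terminal children.
    prior_ratio = p_parent * (1 - p_child) ** 2 / (1 - p_parent)
    # Reverse move: pick one of num_prunable leaf parents uniformly.
    # Forward move: propose a grow and pick this leaf with prob_choose.
    transition_ratio = (1 / num_prunable) / (p_grow * prob_choose)
    return transition_ratio * prior_ratio

# Growing the root of a depth-3 tree; p_nonterminal is padded with a zero.
ratio_root = partial_ratio(1.0, 1, [0.95, 0.5, 0.0], 1, 1.0)
```

The zero padding of p_nonterminal makes the children of the deepest growable nodes terminal with probability one.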

bartz.mcmcstep.prune_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, p_propose_grow, key)

Tree structure prune move proposal of BART MCMC.

Parameters:
var_treeint array (2 ** (d - 1),)

The variable indices of the tree.

split_treeint array (2 ** (d - 1),)

The splitting points of the tree.

affluence_treebool array (2 ** (d - 1),) or None

Whether a leaf has enough points to be grown.

max_splitint array (p,)

The maximum split index for each variable.

p_nonterminalfloat array (d,)

The probability of a nonterminal node at each depth.

p_propose_growfloat array (2 ** (d - 1),)

The unnormalized probability of choosing a leaf to grow.

keyjax.dtypes.prng_key array

A jax random key.

Returns:
prune_movedict

A dictionary with fields:

‘allowed’bool

Whether the move is possible.

‘node’int

The index of the node to prune. 2 ** d if no node can be pruned.

‘partial_ratio’float

A factor of the Metropolis-Hastings ratio of the move. It lacks the likelihood ratio and the probability of proposing the prune move. This ratio is inverted.

bartz.mcmcstep.choose_leaf_parent(split_tree, affluence_tree, p_propose_grow, key)

Pick a non-terminal node with leaf children to prune in a tree.

Parameters:
split_treearray (2 ** (d - 1),)

The splitting points of the tree.

affluence_treebool array (2 ** (d - 1),) or None

Whether a leaf has enough points to be grown.

p_propose_growarray (2 ** (d - 1),)

The unnormalized probability of choosing a leaf to grow.

keyjax.dtypes.prng_key array

A jax random key.

Returns:
node_to_pruneint

The index of the node to prune. If num_prunable == 0, return 2 ** d.

num_prunableint

The number of leaf parents that could be pruned.

prob_choosefloat

The normalized probability of choosing the node to prune as the leaf to grow.

bartz.mcmcstep.randint_masked(key, mask)

Return a random integer in a range, including only some values.

Parameters:
keyjax.dtypes.prng_key array

A jax random key.

maskbool array (n,)

The mask indicating the allowed values.

Returns:
uint

A random integer in the range [0, n), and which satisfies mask[u] == True. If all values in the mask are False, return n.
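
A minimal pure-Python sketch of this behaviour: count the allowed positions, draw a rank k uniformly, then return the k-th allowed index. It uses a `random.Random` generator in place of a jax key and is an illustration only.

```python
import random

def randint_masked(rng, mask):
    """Draw uniformly among the indices where mask is True; n if none."""
    num_allowed = sum(bool(m) for m in mask)
    if num_allowed == 0:
        return len(mask)
    # Pick the k-th allowed index, with k uniform in [0, num_allowed).
    k = rng.randrange(num_allowed)
    for index, allowed in enumerate(mask):
        if allowed:
            if k == 0:
                return index
            k -= 1

rng = random.Random(0)
draws = [randint_masked(rng, [True, False, True]) for _ in range(100)]
```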

bartz.mcmcstep.accept_moves_and_sample_leaves(bart, moves, key)

Accept or reject the proposed moves and sample the new leaf values.

Parameters:
bartdict

A BART mcmc state.

movesdict

The proposed moves, see sample_moves.

keyjax.dtypes.prng_key array

A jax random key.

Returns:
bartdict

The new BART mcmc state.

bartz.mcmcstep.accept_moves_parallel_stage(bart, moves, key)

Pre-computes quantities used to accept moves, in parallel across trees.

Parameters:
bartdict

A BART mcmc state.

movesdict

The proposed moves, see sample_moves.

keyjax.dtypes.prng_key array

A jax random key.

Returns:
bartdict

A partially updated BART mcmc state.

movesdict

The proposed moves, with the field ‘partial_ratio’ replaced by ‘log_trans_prior_ratio’.

count_treesarray (num_trees, 2 ** d)

The number of points in each potential or actual leaf node.

move_countsdict

The counts of the number of points in the nodes modified by the moves.

prelkv, prelk, prelfdict

Dictionaries with pre-computed terms of the likelihood ratios and leaf samples.

bartz.mcmcstep.apply_grow_to_indices(moves, leaf_indices, X)

Update the leaf indices to apply a grow move.

Parameters:
movesdict

The proposed moves, see sample_moves.

leaf_indicesarray (num_trees, n)

The index of the leaf each datapoint falls into.

Xarray (p, n)

The predictors matrix.

Returns:
grow_leaf_indicesarray (num_trees, n)

The updated leaf indices.

bartz.mcmcstep.compute_count_trees(leaf_indices, moves, batch_size)

Count the number of datapoints in each leaf.

Parameters:
leaf_indicesint array (num_trees, n)

The index of the leaf each datapoint falls into, if the grow move is accepted.

movesdict

The proposed moves, see sample_moves.

batch_sizeint or None

The data batch size to use for the summation.

Returns:
count_treesint array (num_trees, 2 ** d)

The number of points in each potential or actual leaf node.

countsdict

The counts of the number of points in the nodes modified by the moves, organized as two dictionaries ‘grow’ and ‘prune’, with subfields ‘left’, ‘right’, and ‘total’.

bartz.mcmcstep.count_datapoints_per_leaf(leaf_indices, tree_size, batch_size)

Count the number of datapoints in each leaf.

Parameters:
leaf_indicesint array (num_trees, n)

The index of the leaf each datapoint falls into.

tree_sizeint

The size of the leaf tree array (2 ** d).

batch_sizeint or None

The data batch size to use for the summation.

Returns:
count_treesint array (num_trees, 2 ** d)

The number of points in each leaf node.
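
The counting itself amounts to a scatter-add of ones, as in this pure-Python sketch. It omits the batch_size argument and works on nested lists rather than arrays, so it illustrates the operation rather than the bartz implementation.

```python
def count_datapoints_per_leaf(leaf_indices, tree_size):
    """Count, for each tree, how many datapoints fall into each node index."""
    count_trees = []
    for tree_leaf_indices in leaf_indices:
        counts = [0] * tree_size
        for leaf in tree_leaf_indices:
            counts[leaf] += 1  # scatter-add of ones
        count_trees.append(counts)
    return count_trees

# Two trees, five datapoints: the first tree has leaves 2 and 3,
# the second tree is a root only (leaf index 1).
counts = count_datapoints_per_leaf([[2, 3, 3, 2, 2], [1, 1, 1, 1, 1]], tree_size=4)
```

Each row of the result sums to the number of datapoints, since every datapoint falls into exactly one leaf per tree.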

bartz.mcmcstep.complete_ratio(moves, move_counts, min_points_per_leaf)

Complete non-likelihood MH ratio calculation.

This function adds the probability of choosing the prune move.

Parameters:
movesdict

The proposed moves, see sample_moves.

move_countsdict

The counts of the number of points in the nodes modified by the moves.

min_points_per_leafint or None

The minimum number of data points in a leaf node.

Returns:
movesdict

The updated moves, with the field ‘partial_ratio’ replaced by ‘log_trans_prior_ratio’.

bartz.mcmcstep.compute_p_prune(moves, left_count, right_count, min_points_per_leaf)

Compute the probability of proposing a prune move.

Parameters:
movesdict

The proposed moves, see sample_moves.

left_count, right_countint

The number of datapoints in the proposed children of the leaf to grow.

min_points_per_leafint or None

The minimum number of data points in a leaf node.

Returns:
p_prunefloat

The probability of proposing a prune move. For a grow move, this is the probability after the grow move has been accepted; for a prune move, it is the probability in the current state.

bartz.mcmcstep.adapt_leaf_trees_to_grow_indices(leaf_trees, moves)

Modify leaf values such that the indices of the grow moves work on the original tree.

Parameters:
leaf_treesfloat array (num_trees, 2 ** d)

The leaf values.

movesdict

The proposed moves, see sample_moves.

Returns:
leaf_treesfloat array (num_trees, 2 ** d)

The modified leaf values. The value of the leaf to grow is copied to the nodes that would be its children if the grow move were accepted.

bartz.mcmcstep.precompute_likelihood_terms(count_trees, sigma2, move_counts)

Pre-compute terms used in the likelihood ratio of the acceptance step.

Parameters:
count_treesarray (num_trees, 2 ** d)

The number of points in each potential or actual leaf node.

sigma2float

The noise variance.

move_countsdict

The counts of the number of points in the nodes modified by the moves.

Returns:
prelkvdict

Dictionary with pre-computed terms of the likelihood ratio, one per tree.

prelkdict

Dictionary with pre-computed terms of the likelihood ratio, shared by all trees.

bartz.mcmcstep.precompute_leaf_terms(count_trees, sigma2, key)

Pre-compute terms used to sample leaves from their posterior.

Parameters:
count_treesarray (num_trees, 2 ** d)

The number of points in each potential or actual leaf node.

sigma2float

The noise variance.

keyjax.dtypes.prng_key array

A jax random key.

Returns:
prelfdict

Dictionary with pre-computed terms of the leaf sampling, with fields:

‘mean_factor’float array (num_trees, 2 ** d)

The factor to be multiplied by the sum of residuals to obtain the posterior mean.

‘centered_leaves’float array (num_trees, 2 ** d)

The mean-zero normal values to be added to the posterior mean to obtain the posterior leaf samples.
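
For a single leaf, the two terms can be sketched with the usual normal-normal conjugate update. The prior variance of the leaf value, `sigma_mu2`, does not appear on this page and is a hypothetical parameter of this sketch, as is the function name.

```python
import math
import random

def leaf_posterior_terms(count, sigma2, sigma_mu2, rng):
    """Terms for sampling one leaf value from its normal posterior."""
    # Posterior precision = likelihood precision + prior precision.
    posterior_variance = 1.0 / (count / sigma2 + 1.0 / sigma_mu2)
    # Posterior mean = mean_factor * (sum of the residuals in the leaf).
    mean_factor = posterior_variance / sigma2
    # Mean-zero normal draw with the posterior standard deviation.
    centered_leaf = rng.gauss(0.0, math.sqrt(posterior_variance))
    return mean_factor, centered_leaf

rng = random.Random(0)
mean_factor, centered_leaf = leaf_posterior_terms(3, 1.0, 1.0, rng)
leaf_value = mean_factor * 2.4 + centered_leaf  # 2.4: hypothetical residual sum
```

Neither term depends on the residuals, which is why both can be computed ahead of the per-tree acceptance loop.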

bartz.mcmcstep.accept_moves_sequential_stage(bart, count_trees, moves, move_counts, prelkv, prelk, prelf)

The part of accepting the moves that has to be done one tree at a time.

Parameters:
bartdict

A partially updated BART mcmc state.

count_treesarray (num_trees, 2 ** d)

The number of points in each potential or actual leaf node.

movesdict

The proposed moves, see sample_moves.

move_countsdict

The counts of the number of points in the nodes modified by the moves.

prelkv, prelk, prelfdict

Dictionaries with pre-computed terms of the likelihood ratios and leaf samples.

Returns:
bartdict

A partially updated BART mcmc state.

movesdict

The proposed moves, with these additional fields:

‘acc’bool array (num_trees,)

Whether the move was accepted.

‘to_prune’bool array (num_trees,)

Whether, to reflect the acceptance status of the move, the state should be updated by pruning the leaves involved in the move.

bartz.mcmcstep.accept_move_and_sample_leaves(X, ntree, resid_batch_size, resid, min_points_per_leaf, save_ratios, prelk, leaf_tree, count_tree, move, move_counts, leaf_indices, prelkv, prelf)

Accept or reject a proposed move and sample the new leaf values.

Parameters:
Xint array (p, n)

The predictors.

ntreeint

The number of trees in the forest.

resid_batch_sizeint, None

The batch size for computing the sum of residuals in each leaf.

residfloat array (n,)

The residuals (data minus forest value).

min_points_per_leafint or None

The minimum number of data points in a leaf node.

save_ratiosbool

Whether to save the acceptance ratios.

prelkdict

The pre-computed terms of the likelihood ratio which are shared across trees.

leaf_treefloat array (2 ** d,)

The leaf values of the tree.

count_treeint array (2 ** d,)

The number of datapoints in each leaf.

movedict

The proposed move, see sample_moves.

leaf_indicesint array (n,)

The leaf indices for the largest version of the tree compatible with the move.

prelkv, prelfdict

The pre-computed terms of the likelihood ratio and leaf sampling which are specific to the tree.

Returns:
residfloat array (n,)

The updated residuals (data minus forest value).

leaf_treefloat array (2 ** d,)

The new leaf values of the tree.

accbool

Whether the move was accepted.

to_prunebool

Whether, to reflect the acceptance status of the move, the state should be updated by pruning the leaves involved in the move.

ratiosdict

The acceptance ratios for the moves. Empty if not to be saved.

bartz.mcmcstep.sum_resid(resid, leaf_indices, tree_size, batch_size)

Sum the residuals in each leaf.

Parameters:
residfloat array (n,)

The residuals (data minus forest value).

leaf_indicesint array (n,)

The leaf indices of the tree (the leaf each data point falls into).

tree_sizeint

The size of the tree array (2 ** d).

batch_sizeint, None

The data batch size for the aggregation. Batching increases numerical accuracy and parallelism.

Returns:
resid_treefloat array (2 ** d,)

The sum of the residuals at data points in each leaf.
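
A pure-Python sketch of the batched aggregation: each batch of datapoints is summed into its own partial result, and the partial results are then combined. It mirrors only the structure of the computation; the accuracy and parallelism benefits mentioned above apply to array implementations, not to this loop.

```python
def sum_resid(resid, leaf_indices, tree_size, batch_size=None):
    """Sum the residuals of the datapoints falling in each leaf."""
    n = len(resid)
    if batch_size is None:
        batch_size = max(n, 1)  # a single batch
    total = [0.0] * tree_size
    for start in range(0, n, batch_size):
        # Accumulate each batch separately, then add it to the running total.
        partial = [0.0] * tree_size
        stop = start + batch_size
        for r, leaf in zip(resid[start:stop], leaf_indices[start:stop]):
            partial[leaf] += r
        for leaf in range(tree_size):
            total[leaf] += partial[leaf]
    return total

unbatched = sum_resid([1.0, 2.0, 3.0, 4.0], [2, 3, 2, 3], tree_size=4)
batched = sum_resid([1.0, 2.0, 3.0, 4.0], [2, 3, 2, 3], tree_size=4, batch_size=2)
```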

bartz.mcmcstep.compute_likelihood_ratio(total_resid, left_resid, right_resid, prelkv, prelk)

Compute the likelihood ratio of a grow move.

Parameters:
total_residfloat

The sum of the residuals in the leaf to grow.

left_resid, right_residfloat

The sum of the residuals in the left/right child of the leaf to grow.

prelkv, prelkdict

The pre-computed terms of the likelihood ratio, see precompute_likelihood_terms.

Returns:
ratiofloat

The likelihood ratio P(data | new tree) / P(data | old tree).
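
The ratio can be written out explicitly for a normal likelihood with the leaf values integrated out under a mean-zero normal prior. The sketch below works on the log scale and takes the counts and variances directly instead of the pre-computed dictionaries; the leaf prior variance `sigma_mu2` and the function name are assumptions of this sketch, not part of the documented interface.

```python
import math

def log_likelihood_ratio(total_resid, left_resid, right_resid,
                         left_count, right_count, sigma2, sigma_mu2):
    """Log of P(data | grown tree) / P(data | original tree)."""
    total_count = left_count + right_count

    def log_marginal(resid_sum, count):
        # Leaf-dependent part of the log marginal likelihood of one leaf;
        # the remaining factors are common to both trees and cancel.
        denom = sigma2 + count * sigma_mu2
        return (0.5 * math.log(sigma2 / denom)
                + sigma_mu2 * resid_sum ** 2 / (2 * sigma2 * denom))

    return (log_marginal(left_resid, left_count)
            + log_marginal(right_resid, right_count)
            - log_marginal(total_resid, total_count))

# Two points with opposite residuals: splitting them apart fits better.
log_ratio = log_likelihood_ratio(0.0, 1.0, -1.0, 1, 1, 1.0, 1.0)
```

As a sanity check, a split that leaves one child empty does not change the likelihood, so its log ratio is zero.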

bartz.mcmcstep.accept_moves_final_stage(bart, moves)

The final part of accepting the moves, in parallel across trees.

Parameters:
bartdict

A partially updated BART mcmc state.

countsdict

The indicators of proposals and acceptances for grow and prune moves.

movesdict

The proposed moves (see sample_moves) as updated by accept_moves_sequential_stage.

Returns:
bartdict

The fully updated BART mcmc state.

bartz.mcmcstep.apply_moves_to_leaf_indices(leaf_indices, moves)

Vectorized version of apply_moves_to_leaf_indices. Takes similar arguments as apply_moves_to_leaf_indices but with additional array axes over which apply_moves_to_leaf_indices is mapped.

Original documentation:

Update the leaf indices to match the accepted move.

Parameters:
leaf_indicesint array (num_trees, n)

The index of the leaf each datapoint falls into, if the grow move was accepted.

movesdict

The proposed moves (see sample_moves), as updated by accept_moves_sequential_stage.

Returns:
leaf_indicesint array (num_trees, n)

The updated leaf indices.

bartz.mcmcstep.apply_moves_to_split_trees(split_trees, moves)

Vectorized version of apply_moves_to_split_trees. Takes similar arguments as apply_moves_to_split_trees but with additional array axes over which apply_moves_to_split_trees is mapped.

Original documentation:

Update the split trees to match the accepted move.

Parameters:
split_treesint array (num_trees, 2 ** (d - 1))

The cutpoints of the decision nodes in the initial trees.

movesdict

The proposed moves (see sample_moves), as updated by accept_moves_sequential_stage.

Returns:
split_treesint array (num_trees, 2 ** (d - 1))

The updated split trees.

bartz.mcmcstep.sample_sigma(bart, key)

Noise variance sampling step of BART MCMC.

Parameters:
bartdict

A BART mcmc state, as created by init.

keyjax.dtypes.prng_key array

A jax random key.

Returns:
bartdict

The new BART mcmc state.
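
With an inverse gamma prior on the noise variance and normally distributed residuals, the conditional posterior is again inverse gamma, so the step reduces to a conjugate update. The sketch below shows that update in pure Python; it takes the residuals and prior parameters directly rather than a state dictionary, uses a `random.Random` generator in place of a jax key, and is not the bartz implementation.

```python
import random

def sample_sigma2(resid, sigma2_alpha, sigma2_beta, rng):
    """Draw the noise variance from its inverse gamma conditional posterior."""
    n = len(resid)
    alpha = sigma2_alpha + n / 2
    beta = sigma2_beta + sum(r * r for r in resid) / 2
    # If G ~ Gamma(alpha, scale=1), then beta / G ~ InverseGamma(alpha, beta).
    return beta / rng.gammavariate(alpha, 1.0)

rng = random.Random(0)
resid = [1.0, -1.0, 2.0, 0.0]
draws = [sample_sigma2(resid, 3.0, 2.0, rng) for _ in range(20000)]
```

Here the posterior has shape 5 and scale 5, so the draws should average close to 5 / (5 - 1) = 1.25.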