Debugging

Debugging utilities. The entry point is the class debug_gbart.

bartz.debug.format_tree(tree, *, print_all=False)[source]

Convert a tree to a human-readable string.

Parameters:
  • tree (TreeHeaps) – A single tree to format.

  • print_all (bool, default: False) – If True, also print the contents of unused node slots in the arrays.

Returns:

str – A string representation of the tree.
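
Examples

A minimal sketch, assuming the SamplePriorTrees returned by sample_prior_onetree (documented below in this module) satisfies the TreeHeaps interface expected by format_tree; the hyperparameter values are illustrative.

    import jax
    import jax.numpy as jnp

    from bartz import debug

    # illustrative prior settings: 5 predictors with 10 cutpoints each,
    # depth-wise non-terminal probabilities, unit leaf standard deviation
    max_split = jnp.full(5, 10, dtype=jnp.uint8)
    p_nonterminal = jnp.array([0.95, 0.5, 0.2], dtype=jnp.float32)
    sigma_mu = jnp.float32(1.0)

    tree = debug.sample_prior_onetree(
        jax.random.key(0), max_split, p_nonterminal, sigma_mu
    )
    print(debug.format_tree(tree))                   # decision rules and leaf values
    print(debug.format_tree(tree, print_all=True))   # also show unused node slots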

bartz.debug.tree_actual_depth(split_tree)[source]

Measure the depth of the tree.

Parameters:

split_tree (UInt[Array, '2**(d-1)']) – The cutpoints of the decision rules.

Returns:

Int32[Array, ''] – The depth of the deepest leaf in the tree. The root is at depth 0.

bartz.debug.forest_depth_distr(split_tree)[source]

Histogram the depths of a set of trees.

Parameters:

split_tree (UInt[Array, 'num_trees 2**(d-1)']) – The cutpoints of the decision rules of the trees.

Returns:

Int32[Array, 'd'] – An integer vector where the i-th element counts how many trees have depth i.

bartz.debug.trace_depth_distr(split_tree)[source]

Histogram the depths of a sequence of sets of trees.

Parameters:

split_tree (UInt[Array, 'trace_length num_trees 2**(d-1)']) – The cutpoints of the decision rules of the trees.

Returns:

Int32[Array, 'trace_length d'] – A matrix where element (t, i) counts how many trees have depth i in set t.
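
Examples

A sketch of how the depth histograms compose, using a short trace of trees sampled from the prior with sample_prior (documented below); the hyperparameter values are illustrative.

    import jax
    import jax.numpy as jnp

    from bartz import debug

    max_split = jnp.full(5, 10, dtype=jnp.uint8)
    p_nonterminal = jnp.array([0.95, 0.5, 0.2], dtype=jnp.float32)

    # 4 iterations of 20 trees each, sampled i.i.d. from the prior
    trees = debug.sample_prior(
        jax.random.key(0), 4, 20, max_split, p_nonterminal, jnp.float32(1.0)
    )

    print(debug.trace_depth_distr(trees.split_tree))        # shape (4, d)
    print(debug.forest_depth_distr(trees.split_tree[0]))    # first set of trees
    print(debug.tree_actual_depth(trees.split_tree[0, 0]))  # a single tree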

bartz.debug.points_per_decision_node_distr(var_tree, split_tree, X)[source]

Histogram points-per-node counts.

Count how many parent-of-leaf nodes in a tree select each possible number of points.

Parameters:
  • var_tree (UInt[Array, '2**(d-1)']) – The variables of the decision rules.

  • split_tree (UInt[Array, '2**(d-1)']) – The cutpoints of the decision rules.

  • X (UInt[Array, 'p n']) – The set of points to count.

Returns:

Int32[Array, 'n+1'] – A vector where the i-th element counts how many parent-of-leaf nodes have i points.

bartz.debug.forest_points_per_decision_node_distr(trees, X)[source]

Histogram points-per-node counts for a set of trees.

Count how many parent-of-leaf nodes in a set of trees select each possible number of points.

Parameters:
  • trees (TreeHeaps) – The set of trees. The variables must have broadcast shape (num_trees,).

  • X (UInt[Array, 'p n']) – The set of points to count.

Returns:

Int32[Array, 'n+1'] – A vector where the i-th element counts how many parent-of-leaf nodes have i points.

bartz.debug.trace_points_per_decision_node_distr(trace, X)[source]

Separately histogram points-per-node counts over a sequence of sets of trees.

For each set of trees, count how many parent-of-leaf nodes select each possible number of points.

Parameters:
  • trace (TreeHeaps) – The sequence of sets of trees. The variables must have broadcast shape (trace_length, num_trees).

  • X (UInt[Array, 'p n']) – The set of points to count.

Returns:

Int32[Array, 'trace_length n+1'] – A matrix where element (t, i) counts how many parent-of-leaf nodes have i points in set t.

bartz.debug.points_per_leaf_distr(var_tree, split_tree, X)[source]

Histogram points-per-leaf counts in a tree.

Count how many leaves in a tree select each possible number of points.

Parameters:
  • var_tree (UInt[Array, '2**(d-1)']) – The variables of the decision rules.

  • split_tree (UInt[Array, '2**(d-1)']) – The cutpoints of the decision rules.

  • X (UInt[Array, 'p n']) – The set of points to count.

Returns:

Int32[Array, 'n+1'] – A vector where the i-th element counts how many leaves have i points.

bartz.debug.forest_points_per_leaf_distr(trees, X)[source]

Histogram points-per-leaf counts over a set of trees.

Count how many leaves in a set of trees select each possible number of points.

Parameters:
  • trees (TreeHeaps) – The set of trees. The variables must have broadcast shape (num_trees,).

  • X (UInt[Array, 'p n']) – The set of points to count.

Returns:

Int32[Array, 'n+1'] – A vector where the i-th element counts how many leaves have i points.

bartz.debug.trace_points_per_leaf_distr(trace, X)[source]

Separately histogram points-per-leaf counts over a sequence of sets of trees.

For each set of trees, count how many leaves select each possible number of points.

Parameters:
  • trace (TreeHeaps) – The sequence of sets of trees. The variables must have broadcast shape (trace_length, num_trees).

  • X (UInt[Array, 'p n']) – The set of points to count.

Returns:

Int32[Array, 'trace_length n+1'] – A matrix where element (t, i) counts how many leaves have i points in set t.
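
Examples

A sketch of the points-per-leaf histogram for a set of trees sampled from the prior with sample_prior_forest (documented below); X is a random matrix of predictor bin indices here, purely for illustration.

    import jax
    import jax.numpy as jnp

    from bartz import debug

    p, n = 5, 100
    max_split = jnp.full(p, 10, dtype=jnp.uint8)
    p_nonterminal = jnp.array([0.95, 0.5, 0.2], dtype=jnp.float32)

    # a set of 20 prior trees and a random matrix of predictor bin indices
    keys = jax.random.split(jax.random.key(0), 20)
    trees = debug.sample_prior_forest(keys, max_split, p_nonterminal, jnp.float32(1.0))
    X = jax.random.randint(jax.random.key(1), (p, n), 0, 11).astype(jnp.uint8)

    counts = debug.forest_points_per_leaf_distr(trees, X)   # shape (n + 1,)
    print(counts)   # counts[i] = number of leaves containing exactly i points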

bartz.debug.check(func)[source]

Add a function to a list of functions used to check trees.

Use to decorate functions that check whether a tree is valid in some way. These functions are invoked automatically by check_tree, check_trace and debug_gbart.

Parameters:

func (Callable[[TreeHeaps, UInt[Array, 'p']], bool | Bool[Array, '']]) – The function to add to the list. It must accept a TreeHeaps and a max_split argument, and return a boolean scalar indicating whether the tree is valid.

Returns:

Callable[[TreeHeaps, UInt[Array, 'p']], bool | Bool[Array, '']] – The function unchanged.
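
Examples

For illustration, a hypothetical custom check registered with the decorator; the name check_leaf_magnitude and the property it verifies are made up, but the signature follows the required form (a TreeHeaps and max_split in, a boolean scalar out).

    import jax.numpy as jnp

    from bartz import debug

    @debug.check
    def check_leaf_magnitude(tree, max_split):
        """Hypothetical check: every leaf value has magnitude below 100."""
        return jnp.all(jnp.abs(tree.leaf_tree) < 100)

Once registered, the check runs automatically in check_tree, check_trace, and debug_gbart alongside the built-in checks.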

bartz.debug.check_types(tree, max_split)[source]

Check that integer types are as small as possible and coherent.

Return type:

bool

bartz.debug.check_sizes(tree, max_split)[source]

Check that array sizes are coherent.

Return type:

bool

bartz.debug.check_unused_node(tree, max_split)[source]

Check that the unused node slot at index 0 is not dirty.

Return type:

Bool[Array, '']

bartz.debug.check_leaf_values(tree, max_split)[source]

Check that no leaf value is inf or nan.

Return type:

Bool[Array, '']

bartz.debug.check_stray_nodes(tree, max_split)[source]

Check that no node marked as non-leaf has a parent marked as leaf.

Return type:

Bool[Array, '']

bartz.debug.check_rule_consistency(tree, max_split)[source]

Check that decision rules define proper subsets of ancestor rules.

Return type:

bool | Bool[Array, '']

bartz.debug.check_num_nodes(tree, max_split)[source]

Check that #leaves = 1 + #(internal nodes).

Return type:

Bool[Array, '']

bartz.debug.check_var_in_bounds(tree, max_split)[source]

Check that variables are in [0, max_split.size).

Return type:

Bool[Array, '']

bartz.debug.check_split_in_bounds(tree, max_split)[source]

Check that splits are in [0, max_split[var]].

Return type:

Bool[Array, '']

bartz.debug.check_tree(tree, max_split)[source]

Check the validity of a tree.

Use describe_error to parse the error code returned by this function.

Parameters:
  • tree (TreeHeaps) – The tree to check.

  • max_split (UInt[Array, 'p']) – The maximum split value for each variable.

Returns:

UInt[Array, ''] – An integer where each bit indicates whether a check failed.

bartz.debug.describe_error(error)[source]

Describe the error code returned by check_tree.

Parameters:

error (int | Integer[Array, '']) – The error code returned by check_tree.

Returns:

list[str] – A list of the function names that implement the failed checks.
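
Examples

A sketch of checking a single tree and decoding the resulting error code; the tree is sampled from the prior (see sample_prior_onetree below), so no check is expected to fail.

    import jax
    import jax.numpy as jnp

    from bartz import debug

    max_split = jnp.full(5, 10, dtype=jnp.uint8)
    p_nonterminal = jnp.array([0.95, 0.5, 0.2], dtype=jnp.float32)
    tree = debug.sample_prior_onetree(
        jax.random.key(0), max_split, p_nonterminal, jnp.float32(1.0)
    )

    error = debug.check_tree(tree, max_split)   # one bit per failed check
    print(debug.describe_error(error))          # names of failed checks, [] if valid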

bartz.debug.check_trace(trace, max_split)[source]

Check the validity of a sequence of sets of trees.

Use describe_error to parse the error codes returned by this function.

Parameters:
  • trace (TreeHeaps) – The sequence of sets of trees to check. The tree arrays must have broadcast shape (trace_length, num_trees). This object may have additional attributes beyond the tree arrays; they are ignored.

  • max_split (UInt[Array, 'p']) – The maximum split value for each variable.

Returns:

UInt[Array, 'trace_length num_trees'] – A matrix of error codes for each tree.
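
Examples

The batched variant works the same way over a whole trace; a short sketch with prior samples (parameters illustrative).

    import jax
    import jax.numpy as jnp

    from bartz import debug

    max_split = jnp.full(5, 10, dtype=jnp.uint8)
    p_nonterminal = jnp.array([0.95, 0.5, 0.2], dtype=jnp.float32)
    trace = debug.sample_prior(
        jax.random.key(0), 4, 20, max_split, p_nonterminal, jnp.float32(1.0)
    )

    errors = debug.check_trace(trace, max_split)   # shape (trace_length, num_trees)
    assert not errors.any()                        # prior samples should all be valid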

class bartz.debug.BARTTraceMeta(ndpost, ntree, numcut, heap_size)[source]

Metadata of R BART tree traces.

Parameters:
  • ndpost (int) – The number of posterior draws.

  • ntree (int) – The number of trees in the model.

  • numcut (UInt[Array, 'p']) – The maximum split value for each variable.

  • heap_size (int) – The size of the heap required to store the trees.

bartz.debug.scan_BART_trees(trees)[source]

Scan an R BART tree trace checking for errors and parsing metadata.

Parameters:

trees (str) – The string representation of a trace of trees of the R BART package. Can be accessed from mc_gbart(...).treedraws['trees'].

Returns:

BARTTraceMeta – An object containing the metadata.

Raises:

ValueError – If the string is malformed or contains leftover characters.

class bartz.debug.TraceWithOffset(leaf_tree, var_tree, split_tree, offset)[source]

Implementation of bartz.mcmcloop.Trace.

classmethod from_trees_trace(trees, offset)[source]

Create a TraceWithOffset from a TreeHeaps.

Return type:

TraceWithOffset

bartz.debug.trees_BART_to_bartz(trees, *, min_maxdepth=0, offset=None)[source]

Convert trees from the R BART format to the bartz format.

Parameters:
  • trees (str) – The string representation of a trace of trees of the R BART package. Can be accessed from mc_gbart(...).treedraws['trees'].

  • min_maxdepth (int, default: 0) – By default, the maximum tree depth of the output format is the maximum depth observed in the input trees. Use this parameter to require that the output format supports at least this maximum depth.

  • offset (float | Float[Any, ''] | None, default: None) – The trace returned by bartz.mcmcloop.run_mcmc contains an offset to be added to the sum of trees. To match that behavior, this function returns an offset as well; it is zero by default and can be set with this parameter.

Returns:

  • trace (TraceWithOffset) – A representation of the trees compatible with the trace returned by bartz.mcmcloop.run_mcmc.

  • meta (BARTTraceMeta) – The metadata of the trace, containing the number of iterations, trees, and the maximum split value.
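
Examples

A sketch of importing an R BART trace; obtaining the trees string on the R side (mc_gbart(...).treedraws['trees']) and transferring it to Python is assumed and not shown here.

    from pathlib import Path

    from bartz import debug

    # assumed: the R BART trees string was saved to a text file beforehand
    trees_str = Path('bart_trees.txt').read_text()

    meta = debug.scan_BART_trees(trees_str)   # raises ValueError if malformed
    print(meta.ndpost, meta.ntree, meta.heap_size)

    trace, meta = debug.trees_BART_to_bartz(trees_str, offset=0.0)
    # `trace` is compatible with the output of bartz.mcmcloop.run_mcmc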

class bartz.debug.SamplePriorStack(nonterminal, lower, upper, var, split)[source]

Represent the manually managed stack used in sample_prior.

Each level of the stack represents a recursion into a child node in a binary tree of maximum depth d.

Parameters:
  • nonterminal (Bool[Array, 'd-1']) – Whether the node is valid or the recursion is into unused node slots.

  • lower (UInt[Array, 'd-1 p'])

  • upper (UInt[Array, 'd-1 p']) – The available cutpoints along var are in the integer range [1 + lower[var], 1 + upper[var]).

  • var (UInt[Array, 'd-1'])

  • split (UInt[Array, 'd-1']) – The variable and cutpoint of a decision node.

classmethod initial(p_nonterminal, max_split)[source]

Initialize the stack.

Parameters:
  • p_nonterminal (Float32[Array, 'd-1']) – The prior probability of a node being non-terminal conditional on its ancestors and on having available decision rules, at each depth.

  • max_split (UInt[Array, 'p']) – The number of cutpoints along each variable.

Returns:

SamplePriorStack – A SamplePriorStack initialized to start the recursion.

class bartz.debug.SamplePriorTrees(leaf_tree, var_tree, split_tree)[source]

Object holding the trees generated by sample_prior.

Parameters:
  • leaf_tree (Float32[Array, '* 2**d'])

  • var_tree (UInt[Array, '* 2**(d-1)'])

  • split_tree (UInt[Array, '* 2**(d-1)']) – The arrays representing the trees, see bartz.grove.

classmethod initial(key, sigma_mu, p_nonterminal, max_split)[source]

Initialize the trees.

The leaves are already correct and do not need to be changed.

Parameters:
  • key (Key[Array, '']) – A jax random key.

  • sigma_mu (Float32[Array, '']) – The prior standard deviation of each leaf.

  • p_nonterminal (Float32[Array, 'd-1']) – The prior probability of a node being non-terminal conditional on its ancestors and on having available decision rules, at each depth.

  • max_split (UInt[Array, 'p']) – The number of cutpoints along each variable.

Returns:

SamplePriorTrees – Trees initialized with random leaves and stub tree structures.

class bartz.debug.SamplePriorCarry(key, stack, trees)[source]

Object holding values carried along the recursion in sample_prior.

Parameters:
  • key (Key[Array, '']) – A jax random key used to sample decision rules.

  • stack (SamplePriorStack) – The stack used to manage the recursion.

  • trees (SamplePriorTrees) – The output arrays.

classmethod initial(key, sigma_mu, p_nonterminal, max_split)[source]

Initialize the carry object.

Parameters:
  • key (Key[Array, '']) – A jax random key.

  • sigma_mu (Float32[Array, '']) – The prior standard deviation of each leaf.

  • p_nonterminal (Float32[Array, 'd-1']) – The prior probability of a node being non-terminal conditional on its ancestors and on having available decision rules, at each depth.

  • max_split (UInt[Array, 'p']) – The number of cutpoints along each variable.

Returns:

SamplePriorCarry – A SamplePriorCarry initialized to start the recursion.

class bartz.debug.SamplePriorX(node, depth, next_depth)[source]

Object representing the recursion scan in sample_prior.

The sequence of nodes to visit is pre-computed recursively once, unrolling the recursion schedule.

Parameters:
  • node (Int32[Array, '2**(d-1)-1']) – The heap index of the node to visit.

  • depth (Int32[Array, '2**(d-1)-1']) – The depth of the node.

  • next_depth (Int32[Array, '2**(d-1)-1']) – The depth of the next node to visit, either the left child or the right sibling of the node or of an ancestor.

classmethod initial(p_nonterminal)[source]

Initialize the sequence of nodes to visit.

Parameters:

p_nonterminal (Float32[Array, 'd-1']) – The prior probability of a node being non-terminal conditional on its ancestors and on having available decision rules, at each depth.

Returns:

SamplePriorX – A SamplePriorX initialized with the sequence of nodes to visit.

bartz.debug.sample_prior_onetree(key, max_split, p_nonterminal, sigma_mu)[source]

Sample a tree from the BART prior.

Parameters:
  • key (Key[Array, '']) – A jax random key.

  • max_split (UInt[Array, 'p']) – The maximum split value for each variable.

  • p_nonterminal (Float32[Array, 'd-1']) – The prior probability of a node being non-terminal conditional on its ancestors and on having available decision rules, at each depth.

  • sigma_mu (Float32[Array, '']) – The prior standard deviation of each leaf.

Returns:

SamplePriorTrees – An object containing a generated tree.

bartz.debug.sample_prior_forest(keys, max_split, p_nonterminal, sigma_mu)[source]

Sample a set of independent trees from the BART prior.

Parameters:
  • keys (Key[Array, 'num_trees']) – A sequence of jax random keys, one for each tree. This determines the number of trees sampled.

  • max_split (UInt[Array, 'p']) – The maximum split value for each variable.

  • p_nonterminal (Float32[Array, 'd-1']) – The prior probability of a node being non-terminal conditional on its ancestors and on having available decision rules, at each depth.

  • sigma_mu (Float32[Array, '']) – The prior standard deviation of each leaf.

Returns:

SamplePriorTrees – An object containing the generated trees.

bartz.debug.sample_prior(key, trace_length, num_trees, max_split, p_nonterminal, sigma_mu)[source]

Sample independent trees from the BART prior.

Parameters:
  • key (Key[Array, '']) – A jax random key.

  • trace_length (int) – The number of iterations.

  • num_trees (int) – The number of trees for each iteration.

  • max_split (UInt[Array, 'p']) – The number of cutpoints along each variable.

  • p_nonterminal (Float32[Array, 'd-1']) – The prior probability of a node being non-terminal conditional on its ancestors and on having available decision rules, at each depth. This determines the maximum depth of the trees.

  • sigma_mu (Float32[Array, '']) – The prior standard deviation of each leaf.

Returns:

SamplePriorTrees – An object containing the generated trees, with batch shape (trace_length, num_trees).

class bartz.debug.debug_gbart(*args, check_trees=True, **kw)[source]

A subclass of gbart that adds debugging functionality.

Parameters:
  • *args – Passed to gbart.

  • check_trees (bool, default: True) – If True, check all trees with check_trace after running the MCMC, and assert that they are all valid. Set to False to allow jax tracing.

  • **kw – Passed to gbart.
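
Examples

A sketch of a typical debugging session on synthetic data; the argument names (ntree, ndpost, nskip, seed) and the (p, n) orientation of the predictor matrix follow the gbart interface that this class wraps, so check the gbart documentation if they differ. The methods used here are documented below.

    import jax

    from bartz.debug import debug_gbart

    # synthetic regression data: p = 10 predictors, n = 100 points
    key_x, key_noise = jax.random.split(jax.random.key(0))
    X = jax.random.uniform(key_x, (10, 100))
    y = X[0] + 0.1 * jax.random.normal(key_noise, (100,))

    fit = debug_gbart(X, y, ntree=20, ndpost=100, nskip=100, seed=0)
    # the trees were already validated with check_trace because check_trees=True

    fit.show_tree(0, 0)       # print the first tree of the first posterior draw
    print(fit.avg_acc())      # average grow/prune acceptance rates
    print(fit.depth_distr())  # per-iteration histogram of tree depths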

show_tree(i_sample, i_tree, print_all=False)[source]

Print a single tree in human-readable format.

Parameters:
  • i_sample (int) – The index of the posterior sample.

  • i_tree (int) – The index of the tree in the sample.

  • print_all (bool, default: False) – If True, also print the content of unused node slots.

sigma_harmonic_mean(prior=False)[source]

Return the harmonic mean of the error variance.

Parameters:

prior (bool, default: False) – If True, use the prior distribution, otherwise use the full conditional at the last MCMC iteration.

Returns:

Float32[Array, ''] – The harmonic mean 1/E[1/sigma^2] in the selected distribution.

compare_resid()[source]

Re-compute residuals to compare them with the updated ones.

Returns:

  • resid1 (Float32[Array, 'n']) – The final state of the residuals updated during the MCMC.

  • resid2 (Float32[Array, 'n']) – The residuals computed from the final state of the trees.

avg_acc()[source]

Compute the average acceptance rates of tree moves.

Returns:

  • acc_grow (Float32[Array, '']) – The average acceptance rate of grow moves.

  • acc_prune (Float32[Array, '']) – The average acceptance rate of prune moves.

avg_prop()[source]

Compute the average proposal rate of grow and prune moves.

Returns:

  • prop_grow (Float32[Array, '']) – The fraction of times grow was proposed instead of prune.

  • prop_prune (Float32[Array, '']) – The fraction of times prune was proposed instead of grow.

Notes

This function does not take into account cases where no move was proposed.

avg_move()[source]

Compute the move rate.

Returns:

  • rate_grow (Float32[Array, '']) – The fraction of times a grow move was proposed and accepted.

  • rate_prune (Float32[Array, '']) – The fraction of times a prune move was proposed and accepted.

depth_distr()[source]

Histogram of tree depths for each state of the trees.

Returns:

Float32[Array, 'trace_length d'] – A matrix where each row contains a histogram of tree depths.

points_per_decision_node_distr()[source]

Histogram of number of points belonging to parent-of-leaf nodes.

Returns:

Float32[Array, 'trace_length n+1'] – A matrix where each row contains a histogram of number of points.

points_per_leaf_distr()[source]

Histogram of number of points belonging to leaves.

Returns:

Float32[Array, 'trace_length n+1'] – A matrix where each row contains a histogram of number of points.

check_trees()[source]

Apply check_trace to all the tree draws.

Return type:

UInt[Array, 'trace_length ntree']

tree_goes_bad()[source]

Find iterations where a tree becomes invalid.

Returns:

Bool[Array, 'trace_length ntree'] – A matrix where element (i, j) is True if tree j is invalid at iteration i but not at iteration i-1.