Coverage for src/bartz/debug.py: 86%
427 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-27 14:46 +0000
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-27 14:46 +0000
1# bartz/src/bartz/debug.py
2#
3# Copyright (c) 2024-2025, Giacomo Petrillo
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25"""Debugging utilities. The entry point is the class `debug_gbart`."""
27from collections.abc import Callable 1ab
28from dataclasses import replace 1ab
29from functools import partial 1ab
30from math import ceil, log2 1ab
31from re import fullmatch 1ab
33import numpy 1ab
34from equinox import Module, field 1ab
35from jax import jit, lax, random, vmap 1ab
36from jax import numpy as jnp 1ab
37from jax.tree_util import tree_map 1ab
38from jaxtyping import Array, Bool, Float32, Int32, Integer, Key, UInt 1ab
40from bartz.BART import FloatLike, gbart 1ab
41from bartz.grove import ( 1ab
42 TreeHeaps,
43 evaluate_forest,
44 is_actual_leaf,
45 is_leaves_parent,
46 traverse_tree,
47 tree_depth,
48 tree_depths,
49)
50from bartz.jaxext import minimal_unsigned_dtype, vmap_nodoc 1ab
51from bartz.jaxext import split as split_key 1ab
52from bartz.mcmcloop import TreesTrace 1ab
53from bartz.mcmcstep import randint_masked 1ab
def format_tree(tree: TreeHeaps, *, print_all: bool = False) -> str:
    """Render a single tree as a human-readable multi-line string.

    Parameters
    ----------
    tree
        A single tree to format.
    print_all
        If `True`, also print the contents of unused node slots in the arrays.

    Returns
    -------
    One line per visited node in depth-first order, each prefixed by the
    node's heap index.
    """
    tee = '├──'
    corner = '└──'
    join = '│ '
    space = ' '
    down = '┐'
    bottom = '╢'  # '┨' #

    num_slots = len(tree.leaf_tree)
    ndigits = len(str(num_slots - 1))
    lines = []

    # named `_visit` to avoid shadowing the module-level `traverse_tree`
    def _visit(index, indent, first_indent, next_indent, unused):
        if index >= num_slots:
            return

        var: int = tree.var_tree.at[index].get(mode='fill', fill_value=0).item()
        split: int = tree.split_tree.at[index].get(mode='fill', fill_value=0).item()
        is_leaf = split == 0
        left = 2 * index
        right = left + 1

        if not print_all:
            # without print_all the recursion never enters unused slots
            assert not unused
            node_str = f'{tree.leaf_tree[index]:#.2g}' if is_leaf else f'x{var} < {split}'
        else:
            if unused:
                category = 'unused'
            else:
                category = 'leaf' if is_leaf else 'decision'
            node_str = f'{category}({var}, {split}, {tree.leaf_tree[index]})'

        if not is_leaf or (print_all and left < num_slots):
            link = down
        elif not print_all and left >= num_slots:
            link = bottom
        else:
            link = ' '

        number = str(index).rjust(ndigits)
        lines.append(f' {number} {indent}{first_indent}{link}{node_str}')

        # descendants of a leaf occupy unused slots
        child_indent = indent + next_indent
        child_unused = unused or is_leaf
        if child_unused and not print_all:
            return
        _visit(left, child_indent, tee, join, child_unused)
        _visit(right, child_indent, corner, space, child_unused)

    _visit(1, '', '', '', False)
    return '\n'.join(lines)
def tree_actual_depth(split_tree: UInt[Array, ' 2**(d-1)']) -> Int32[Array, '']:
    """Compute the depth of the deepest leaf of a tree.

    Parameters
    ----------
    split_tree
        The cutpoints of the decision rules.

    Returns
    -------
    The maximum depth over the leaves of the tree, counting the root as depth 0.
    """
    # an equivalent computation could use just `split_tree != 0`
    leaf_mask = is_actual_leaf(split_tree, add_bottom_level=True)
    node_depths = tree_depths(leaf_mask.size)
    leaf_depths = jnp.where(leaf_mask, node_depths, 0)
    return jnp.max(leaf_depths)
def forest_depth_distr(
    split_tree: UInt[Array, 'num_trees 2**(d-1)'],
) -> Int32[Array, ' d']:
    """Count how many trees of a forest reach each depth.

    Parameters
    ----------
    split_tree
        The cutpoints of the decision rules of the trees.

    Returns
    -------
    An integer vector whose i-th entry is the number of trees of depth i.
    """
    per_tree_depth = vmap(tree_actual_depth)(split_tree)
    # one bin for each possible depth 0, ..., max depth
    num_bins = 1 + tree_depth(split_tree)
    return jnp.bincount(per_tree_depth, length=num_bins)
@jit
def trace_depth_distr(
    split_tree: UInt[Array, 'trace_length num_trees 2**(d-1)'],
) -> Int32[Array, 'trace_length d']:
    """Compute the depth histogram of each forest in a trace.

    Parameters
    ----------
    split_tree
        The cutpoints of the decision rules of the trees.

    Returns
    -------
    A matrix whose entry (t, i) is the number of trees of depth i in forest t.
    """
    histogram_each_forest = vmap(forest_depth_distr)
    return histogram_each_forest(split_tree)
def points_per_decision_node_distr(
    var_tree: UInt[Array, ' 2**(d-1)'],
    split_tree: UInt[Array, ' 2**(d-1)'],
    X: UInt[Array, 'p n'],
) -> Int32[Array, ' n+1']:
    """Histogram points-per-node counts.

    For each possible number of points, count the parent-of-leaf nodes of a
    tree that are reached by exactly that many points.

    Parameters
    ----------
    var_tree
        The variables of the decision rules.
    split_tree
        The cutpoints of the decision rules.
    X
        The set of points to count, one point per column.

    Returns
    -------
    A vector whose i-th entry is the number of next-to-leaf nodes with i points.
    """
    leaf_index = vmap(traverse_tree, in_axes=(1, None, None))(X, var_tree, split_tree)
    parent_index = leaf_index >> 1
    # slot 0 is unused (it is the "parent" of the root), so zero its count
    points_in_node = (
        jnp.zeros(split_tree.size, int).at[parent_index].add(1).at[0].set(0)
    )
    is_parent = is_leaves_parent(split_tree)
    histogram = jnp.zeros(X.shape[1] + 1, int)
    return histogram.at[points_in_node].add(is_parent)
def forest_points_per_decision_node_distr(
    trees: TreeHeaps, X: UInt[Array, 'p n']
) -> Int32[Array, ' n+1']:
    """Histogram points-per-node counts, summed over a set of trees.

    For each possible number of points, count the parent-of-leaf nodes in all
    the trees that are reached by exactly that many points.

    Parameters
    ----------
    trees
        The set of trees. The variables must have broadcast shape (num_trees,).
    X
        The set of points to count.

    Returns
    -------
    A vector whose i-th entry is the number of next-to-leaf nodes with i points.
    """

    def _accumulate(total, heaps: tuple[Array, Array]):
        var_tree, split_tree = heaps
        return total + points_per_decision_node_distr(var_tree, split_tree, X), None

    initial = jnp.zeros(X.shape[1] + 1, int)
    total, _ = lax.scan(_accumulate, initial, (trees.var_tree, trees.split_tree))
    return total
@jit
def trace_points_per_decision_node_distr(
    trace: TreeHeaps, X: UInt[Array, 'p n']
) -> Int32[Array, 'trace_length n+1']:
    """Histogram points-per-node counts separately for each forest of a trace.

    For each set of trees, count how many parent-of-leaf nodes are reached by
    each possible number of points.

    Parameters
    ----------
    trace
        The sequence of sets of trees. The variables must have broadcast shape
        (trace_length, num_trees).
    X
        The set of points to count.

    Returns
    -------
    A matrix whose entry (t, i) is the number of next-to-leaf nodes with i
    points in forest t.
    """

    def _per_forest(carry, forest):
        # nothing is carried across iterations; only the outputs are stacked
        return carry, forest_points_per_decision_node_distr(forest, X)

    _, histograms = lax.scan(_per_forest, None, trace)
    return histograms
def points_per_leaf_distr(
    var_tree: UInt[Array, ' 2**(d-1)'],
    split_tree: UInt[Array, ' 2**(d-1)'],
    X: UInt[Array, 'p n'],
) -> Int32[Array, ' n+1']:
    """Histogram points-per-leaf counts in a tree.

    For each possible number of points, count the leaves of the tree that
    are reached by exactly that many points.

    Parameters
    ----------
    var_tree
        The variables of the decision rules.
    split_tree
        The cutpoints of the decision rules.
    X
        The set of points to count, one point per column.

    Returns
    -------
    A vector whose i-th entry is the number of leaves with i points.
    """
    leaf_index = vmap(traverse_tree, in_axes=(1, None, None))(X, var_tree, split_tree)
    # leaves can be in the bottom level, hence twice the size of split_tree
    points_in_node = jnp.zeros(2 * split_tree.size, int).at[leaf_index].add(1)
    is_leaf = is_actual_leaf(split_tree, add_bottom_level=True)
    histogram = jnp.zeros(X.shape[1] + 1, int)
    return histogram.at[points_in_node].add(is_leaf)
def forest_points_per_leaf_distr(
    trees: TreeHeaps, X: UInt[Array, 'p n']
) -> Int32[Array, ' n+1']:
    """Histogram points-per-leaf counts, summed over a set of trees.

    For each possible number of points, count the leaves in all the trees
    that are reached by exactly that many points.

    Parameters
    ----------
    trees
        The set of trees. The variables must have broadcast shape (num_trees,).
    X
        The set of points to count.

    Returns
    -------
    A vector whose i-th entry is the number of leaves with i points.
    """

    def _accumulate(total, heaps: tuple[Array, Array]):
        var_tree, split_tree = heaps
        return total + points_per_leaf_distr(var_tree, split_tree, X), None

    initial = jnp.zeros(X.shape[1] + 1, int)
    total, _ = lax.scan(_accumulate, initial, (trees.var_tree, trees.split_tree))
    return total
@jit
def trace_points_per_leaf_distr(
    trace: TreeHeaps, X: UInt[Array, 'p n']
) -> Int32[Array, 'trace_length n+1']:
    """Histogram points-per-leaf counts separately for each forest of a trace.

    For each set of trees, count how many leaves are reached by each possible
    number of points.

    Parameters
    ----------
    trace
        The sequence of sets of trees. The variables must have broadcast shape
        (trace_length, num_trees).
    X
        The set of points to count.

    Returns
    -------
    A matrix whose entry (t, i) is the number of leaves with i points in
    forest t.
    """

    def _per_forest(carry, forest):
        # nothing is carried across iterations; only the outputs are stacked
        return carry, forest_points_per_leaf_distr(forest, X)

    _, histograms = lax.scan(_per_forest, None, trace)
    return histograms
# Signature of the functions registered with `check`: they take a tree and the
# `max_split` array and return whether the tree passes the check.
CheckFunc = Callable[[TreeHeaps, UInt[Array, ' p']], bool | Bool[Array, '']]

# Registry filled by the `check` decorator. The position of a function in this
# list is the bit that `check_tree` sets in the error code when it fails.
check_functions: list[CheckFunc] = []
def check(func: CheckFunc) -> CheckFunc:
    """Register a tree-validation function.

    Decorate with this the functions that verify some property of a tree.
    Registered functions are run automatically by `check_tree`,
    `check_trace` and `debug_gbart`.

    Parameters
    ----------
    func
        The function to register. It must accept a `TreeHeaps` and a
        `max_split` argument, and return a boolean scalar that is true if the
        tree is ok.

    Returns
    -------
    `func` itself, so this can be used as a decorator.
    """
    check_functions.extend((func,))
    return func
@check
def check_types(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> bool:
    """Check that integer types are as small as possible and coherent."""
    # the var dtype must be the smallest one that can index every variable
    var_ok = tree.var_tree.dtype == minimal_unsigned_dtype(max_split.size - 1)
    split_ok = tree.split_tree.dtype == max_split.dtype
    return var_ok and split_ok
@check
def check_sizes(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> bool:  # noqa: ARG001
    """Check that array sizes are coherent."""
    # the leaf array has one more level than the decision arrays
    num_leaf_slots = tree.leaf_tree.size
    return num_leaf_slots == 2 * tree.var_tree.size == 2 * tree.split_tree.size
@check
def check_unused_node(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> Bool[Array, '']:  # noqa: ARG001
    """Check that the unused node slot at index 0 is not dirty."""
    var_is_clean = tree.var_tree[0] == 0
    split_is_clean = tree.split_tree[0] == 0
    return var_is_clean & split_is_clean
@check
def check_leaf_values(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> Bool[Array, '']:  # noqa: ARG001
    """Check that all leaf values are finite, i.e., neither inf nor nan."""
    return jnp.isfinite(tree.leaf_tree).all()
@check
def check_stray_nodes(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> Bool[Array, '']:  # noqa: ARG001
    """Check if there is any marked-non-leaf node with a marked-leaf parent."""
    num_slots = tree.split_tree.size
    # cover the bottom level too, which has no slot in split_tree
    node = jnp.arange(2 * num_slots, dtype=minimal_unsigned_dtype(2 * num_slots - 1))
    node_is_decision = tree.split_tree.at[node].get(mode='fill', fill_value=0) != 0
    parent_is_leaf = tree.split_tree[node >> 1] == 0
    # the root (index 1) has the unused slot 0 as parent, so exclude it
    stray = (node_is_decision & parent_is_leaf).at[1].set(False)
    return ~stray.any()
@check
def check_rule_consistency(
    tree: TreeHeaps, max_split: UInt[Array, ' p']
) -> bool | Bool[Array, '']:
    """Check that decision rules define proper subsets of ancestor rules."""
    num_slots = tree.var_tree.size
    if num_slots < 4:
        # at most the root is stored, so there are no ancestor rules
        return True

    # The initial boundaries of the decision rules are extreme integers,
    # rather than 0 and max_split, to avoid checking if something is out of
    # bounds. The dtype is given explicitly because weakly typed arrays would
    # be implicitly converted to split.dtype (typically uint8) in the
    # comparisons.
    int32_min = jnp.iinfo(jnp.int32).min
    int32_max = jnp.iinfo(jnp.int32).max
    num_vars = max_split.size

    def _visit(node, lower, upper):
        var = tree.var_tree[node]
        split = tree.split_tree[node]

        # boundaries inherited from the ancestors; the fill values make an
        # out-of-bounds var pass here, that is checked by another function
        lo = lower.at[var].get(mode='fill', fill_value=int32_min)
        hi = upper.at[var].get(mode='fill', fill_value=int32_max)

        # a decision rule must lie strictly within its ancestors' range
        bad = jnp.where(split, (split <= lo) | (split >= hi), False)

        if node < num_slots // 2:
            # for a leaf, target the out-of-bounds index `num_vars`: JAX drops
            # out-of-bounds updates, so the boundaries are left unchanged
            target = jnp.where(split, var, num_vars)
            bad |= _visit(2 * node, lower, upper.at[target].set(split))
            bad |= _visit(2 * node + 1, lower.at[target].set(split), upper)
        return bad

    initial_lower = jnp.full(num_vars, int32_min, jnp.int32)
    initial_upper = jnp.full(num_vars, int32_max, jnp.int32)
    return ~_visit(1, initial_lower, initial_upper)
@check
def check_num_nodes(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> Bool[Array, '']:  # noqa: ARG001
    """Check that #leaves = 1 + #(internal nodes)."""
    leaf_mask = is_actual_leaf(tree.split_tree, add_bottom_level=True)
    leaf_count = jnp.count_nonzero(leaf_mask)
    decision_count = jnp.count_nonzero(tree.split_tree)
    return leaf_count == 1 + decision_count
@check
def check_var_in_bounds(
    tree: TreeHeaps, max_split: UInt[Array, ' p']
) -> Bool[Array, '']:
    """Check that variables are in [0, max_split.size)."""
    var = tree.var_tree
    is_decision = tree.split_tree.astype(bool)
    in_bounds = (var >= 0) & (var < max_split.size)
    # only decision nodes are constrained, other slots may hold anything
    return jnp.all(~is_decision | in_bounds)
@check
def check_split_in_bounds(
    tree: TreeHeaps, max_split: UInt[Array, ' p']
) -> Bool[Array, '']:
    """Check that splits are in [0, max_split[var]]."""
    # an out-of-bounds var gets the largest int32 as limit, so it passes here;
    # such vars are checked separately by `check_var_in_bounds`
    int32_max = jnp.iinfo(jnp.int32).max
    limit = (
        max_split.astype(jnp.int32)
        .at[tree.var_tree]
        .get(mode='fill', fill_value=int32_max)
    )
    split = tree.split_tree
    return jnp.all((split >= 0) & (split <= limit))
def check_tree(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> UInt[Array, '']:
    """Run all the registered checks on a tree.

    Use `describe_error` to decode the returned error code.

    Parameters
    ----------
    tree
        The tree to check.
    max_split
        The maximum split value for each variable.

    Returns
    -------
    An integer whose i-th bit is set if the i-th registered check failed.
    """
    # smallest unsigned type with one bit per registered check
    code_dtype = minimal_unsigned_dtype(2 ** len(check_functions) - 1)
    error = code_dtype(0)
    for bit_index, check_func in enumerate(check_functions):
        passed = jnp.bool_(check_func(tree, max_split))
        error |= (~passed) << bit_index
    return error
def describe_error(error: int | Integer[Array, '']) -> list[str]:
    """Decode an error code produced by `check_tree`.

    Parameters
    ----------
    error
        The error code returned by `check_tree`.

    Returns
    -------
    The names of the check functions whose bit is set in `error`, in
    registration order.
    """
    return [
        check_func.__name__
        for bit_index, check_func in enumerate(check_functions)
        if error & (1 << bit_index)
    ]
@jit
@partial(vmap_nodoc, in_axes=(0, None))
def check_trace(
    trace: TreeHeaps, max_split: UInt[Array, ' p']
) -> UInt[Array, 'trace_length num_trees']:
    """Run `check_tree` on every tree of a sequence of forests.

    Use `describe_error` to decode the returned error codes.

    Parameters
    ----------
    trace
        The sequence of sets of trees to check. The tree arrays must have
        broadcast shape (trace_length, num_trees). This object can have
        additional attributes beyond the tree arrays, they are ignored.
    max_split
        The maximum split value for each variable.

    Returns
    -------
    A matrix of error codes, one per tree.
    """
    check_each_tree = vmap(check_tree, in_axes=(0, None))
    # keep only the tree arrays, dropping any extra attributes
    return check_each_tree(TreesTrace.from_dataclass(trace), max_split)
586def _get_next_line(s: str, i: int) -> tuple[str, int]: 1ab
587 """Get the next line from a string and the new index."""
588 i_new = s.find('\n', i) 1ab
589 if i_new == -1: 589 ↛ 590line 589 didn't jump to line 590 because the condition on line 589 was never true1ab
590 return s[i:], len(s)
591 return s[i:i_new], i_new + 1 1ab
class BARTTraceMeta(Module):
    """Metadata of an R BART tree trace.

    Parameters
    ----------
    ndpost
        Number of posterior draws in the trace.
    ntree
        Number of trees in each draw.
    numcut
        The maximum split value for each variable.
    heap_size
        Size of the heap needed to store any tree of the trace.
    """

    # the integer fields are marked static, so they are not pytree leaves
    ndpost: int = field(static=True)
    ntree: int = field(static=True)
    numcut: UInt[Array, ' p']
    heap_size: int = field(static=True)
def scan_BART_trees(trees: str) -> BARTTraceMeta:
    """Scan an R BART tree trace checking for errors and parsing metadata.

    Parameters
    ----------
    trees
        The string representation of a trace of trees of the R BART package.
        Can be accessed from ``mc_gbart(...).treedraws['trees']``.

    Returns
    -------
    An object containing the metadata.

    Raises
    ------
    ValueError
        If the string is malformed or contains leftover characters, or if a
        node references a variable index not smaller than the number of
        variables declared in the header.
    """
    # parse first line: number of draws, number of trees, number of variables
    line, i_char = _get_next_line(trees, 0)
    i_line = 1
    match = fullmatch(r'(\d+) (\d+) (\d+)', line)
    if match is None:
        msg = f'Malformed header at {i_line=}'
        raise ValueError(msg)
    ndpost, ntree, p = map(int, match.groups())

    # initial values for maxima
    max_heap_index = 0
    numcut = numpy.zeros(p, int)

    # cycle over iterations and trees
    for i_iter in range(ndpost):
        for i_tree in range(ntree):
            # parse first line of tree definition: the number of nodes
            line, i_char = _get_next_line(trees, i_char)
            i_line += 1
            match = fullmatch(r'(\d+)', line)
            if match is None:
                msg = f'Malformed tree header at {i_iter=} {i_tree=} {i_line=}'
                raise ValueError(msg)
            num_nodes = int(line)

            # cycle over nodes
            for i_node in range(num_nodes):
                # parse node definition: heap index, variable, split, value
                line, i_char = _get_next_line(trees, i_char)
                i_line += 1
                match = fullmatch(
                    r'(\d+) (\d+) (\d+) (-?\d+(\.\d+)?(e(\+|-|)\d+)?)', line
                )
                if match is None:
                    msg = f'Malformed node definition at {i_iter=} {i_tree=} {i_node=} {i_line=}'
                    raise ValueError(msg)
                i_heap = int(match.group(1))
                var = int(match.group(2))
                split = int(match.group(3))

                # reject out-of-range variables explicitly, otherwise
                # `numcut[var]` would raise an undocumented IndexError
                if var >= p:
                    msg = f'Variable index {var} out of range for {p} variables at {i_iter=} {i_tree=} {i_node=} {i_line=}'
                    raise ValueError(msg)

                # update maxima
                numcut[var] = max(numcut[var], split)
                max_heap_index = max(max_heap_index, i_heap)

    assert i_char <= len(trees)
    if i_char < len(trees):
        msg = f'Leftover {len(trees) - i_char} characters in string'
        raise ValueError(msg)

    # determine minimal integer type for numcut
    numcut += 1  # because BART is 0-based
    split_dtype = minimal_unsigned_dtype(numcut.max())
    numcut = jnp.array(numcut.astype(split_dtype))

    # determine minimum heap size to store the trees
    heap_size = 2 ** ceil(log2(max_heap_index + 1))

    return BARTTraceMeta(ndpost=ndpost, ntree=ntree, numcut=numcut, heap_size=heap_size)
class TraceWithOffset(Module):
    """Implementation of `bartz.mcmcloop.Trace`."""

    leaf_tree: Float32[Array, 'ndpost ntree 2**d']
    var_tree: UInt[Array, 'ndpost ntree 2**(d-1)']
    split_tree: UInt[Array, 'ndpost ntree 2**(d-1)']
    offset: Float32[Array, ' ndpost']

    @classmethod
    def from_trees_trace(
        cls, trees: TreeHeaps, offset: Float32[Array, '']
    ) -> 'TraceWithOffset':
        """Create a `TraceWithOffset` from a `TreeHeaps`.

        Parameters
        ----------
        trees
            Tree arrays with three axes; the first one indexes the draws.
        offset
            The offset, repeated identically for every draw.

        Returns
        -------
        An object holding the same tree arrays as `trees` plus the offsets.
        """
        # unpacking also enforces that the leaf array has exactly three axes
        ndpost, _ntree, _heap_size = trees.leaf_tree.shape
        offsets = jnp.full(ndpost, offset)
        return cls(
            leaf_tree=trees.leaf_tree,
            var_tree=trees.var_tree,
            split_tree=trees.split_tree,
            offset=offsets,
        )
def trees_BART_to_bartz(
    trees: str, *, min_maxdepth: int = 0, offset: FloatLike | None = None
) -> tuple[TraceWithOffset, BARTTraceMeta]:
    """Convert trees from the R BART format to the bartz format.

    Parameters
    ----------
    trees
        The string representation of a trace of trees of the R BART package.
        Can be accessed from ``mc_gbart(...).treedraws['trees']``.
    min_maxdepth
        The maximum tree depth of the output will be set to the maximum
        observed depth in the input trees. Use this parameter to require at
        least this maximum depth in the output format.
    offset
        The trace returned by `bartz.mcmcloop.run_mcmc` contains an offset to be
        summed to the sum of trees. To match that behavior, this function
        returns an offset as well, zero by default. Set with this parameter
        otherwise.

    Returns
    -------
    trace : TraceWithOffset
        A representation of the trees compatible with the trace returned by
        `bartz.mcmcloop.run_mcmc`.
    meta : BARTTraceMeta
        The metadata of the trace, containing the number of iterations, trees,
        and the maximum split value.
    """
    # validate the whole string and collect the sizes first, so that the
    # parsing below can assume well-formed input
    meta = scan_BART_trees(trees)

    heap_size = max(meta.heap_size, 2**min_maxdepth)
    half_heap = heap_size // 2
    batch_shape = (meta.ndpost, meta.ntree)
    leaf_trees = numpy.zeros((*batch_shape, heap_size), dtype=numpy.float32)
    var_trees = numpy.zeros(
        (*batch_shape, half_heap),
        dtype=minimal_unsigned_dtype(meta.numcut.size - 1),
    )
    split_trees = numpy.zeros((*batch_shape, half_heap), dtype=meta.numcut.dtype)

    _, pos = _get_next_line(trees, 0)  # skip the header line

    for i_iter in range(meta.ndpost):
        for i_tree in range(meta.ntree):
            header, pos = _get_next_line(trees, pos)

            # a slot is internal iff some listed node has it as parent
            has_child = numpy.zeros(half_heap, dtype=bool)

            for _ in range(int(header)):
                line, pos = _get_next_line(trees, pos)
                heap_str, var_str, split_str, leaf_str = line.split()
                i_heap = int(heap_str)

                leaf_trees[i_iter, i_tree, i_heap] = float(leaf_str)
                has_child[i_heap // 2] = True
                if i_heap < half_heap:
                    var_trees[i_iter, i_tree, i_heap] = int(var_str)
                    # BART splits are 0-based, while in bartz 0 marks a leaf
                    split_trees[i_iter, i_tree, i_heap] = int(split_str) + 1

            # slot 0 is unused, it is the "parent" of the root
            has_child[0] = False
            split_trees[i_iter, i_tree, ~has_child] = 0

    if offset is None:
        offsets = jnp.zeros(meta.ndpost)
    else:
        offsets = jnp.full(meta.ndpost, offset)

    trace = TraceWithOffset(
        leaf_tree=jnp.array(leaf_trees),
        var_tree=jnp.array(var_trees),
        split_tree=jnp.array(split_trees),
        offset=offsets,
    )
    return trace, meta
class SamplePriorStack(Module):
    """Represent the manually managed stack used in `sample_prior`.

    Each level of the stack represents a recursion into a child node in a
    binary tree of maximum depth `d`.

    Parameters
    ----------
    nonterminal
        Whether the node is valid or the recursion is into unused node slots.
    lower
    upper
        The available cutpoints along ``var`` are in the integer range
        ``[1 + lower[var], 1 + upper[var])``.
    var
    split
        The variable and cutpoint of a decision node.
    """

    nonterminal: Bool[Array, ' d-1']
    lower: UInt[Array, 'd-1 p']
    upper: UInt[Array, 'd-1 p']
    var: UInt[Array, ' d-1']
    split: UInt[Array, ' d-1']

    @classmethod
    def initial(
        cls, p_nonterminal: Float32[Array, ' d-1'], max_split: UInt[Array, ' p']
    ) -> 'SamplePriorStack':
        """Build the stack at the start of the recursion.

        Parameters
        ----------
        p_nonterminal
            The prior probability of a node being non-terminal conditional on
            its ancestors and on having available decision rules, at each depth.
        max_split
            The number of cutpoints along each variable.

        Returns
        -------
        A stack where every level is non-terminal and every cutpoint of every
        variable is available.
        """
        levels = p_nonterminal.size
        num_vars = max_split.size
        split_dtype = max_split.dtype
        return cls(
            nonterminal=jnp.ones(levels, bool),
            lower=jnp.zeros((levels, num_vars), split_dtype),
            upper=jnp.broadcast_to(max_split, (levels, num_vars)),
            var=jnp.zeros(levels, minimal_unsigned_dtype(num_vars - 1)),
            split=jnp.zeros(levels, split_dtype),
        )
class SamplePriorTrees(Module):
    """Object holding the trees generated by `sample_prior`.

    Parameters
    ----------
    leaf_tree
    var_tree
    split_tree
        The arrays representing the trees, see `bartz.grove`.
    """

    leaf_tree: Float32[Array, '* 2**d']
    var_tree: UInt[Array, '* 2**(d-1)']
    split_tree: UInt[Array, '* 2**(d-1)']

    @classmethod
    def initial(
        cls,
        key: Key[Array, ''],
        sigma_mu: Float32[Array, ''],
        p_nonterminal: Float32[Array, ' d-1'],
        max_split: UInt[Array, ' p'],
    ) -> 'SamplePriorTrees':
        """Allocate the output arrays with random leaves and empty rules.

        The leaves are already correct and do not need to be changed.

        Parameters
        ----------
        key
            A jax random key.
        sigma_mu
            The prior standard deviation of each leaf.
        p_nonterminal
            The prior probability of a node being non-terminal conditional on
            its ancestors and on having available decision rules, at each depth.
        max_split
            The number of cutpoints along each variable.

        Returns
        -------
        Trees with Gaussian leaves and all-zero decision arrays.
        """
        heap_size = 2 ** (p_nonterminal.size + 1)
        leaves = sigma_mu * random.normal(key, (heap_size,))
        num_decision_slots = heap_size // 2
        var_dtype = minimal_unsigned_dtype(max_split.size - 1)
        return cls(
            leaf_tree=leaves,
            var_tree=jnp.zeros(num_decision_slots, dtype=var_dtype),
            split_tree=jnp.zeros(num_decision_slots, dtype=max_split.dtype),
        )
class SamplePriorCarry(Module):
    """Object holding values carried along the recursion in `sample_prior`.

    Parameters
    ----------
    key
        A jax random key used to sample decision rules.
    stack
        The stack used to manage the recursion.
    trees
        The output arrays.
    """

    key: Key[Array, '']
    stack: SamplePriorStack
    trees: SamplePriorTrees

    @classmethod
    def initial(
        cls,
        key: Key[Array, ''],
        sigma_mu: Float32[Array, ''],
        p_nonterminal: Float32[Array, ' d-1'],
        max_split: UInt[Array, ' p'],
    ) -> 'SamplePriorCarry':
        """Build the carry at the start of the recursion.

        Parameters
        ----------
        key
            A jax random key.
        sigma_mu
            The prior standard deviation of each leaf.
        p_nonterminal
            The prior probability of a node being non-terminal conditional on
            its ancestors and on having available decision rules, at each depth.
        max_split
            The number of cutpoints along each variable.

        Returns
        -------
        A `SamplePriorCarry` with a fresh stack and initialized trees.
        """
        keys = split_key(key)
        # the first key is kept for the decision rules, the second one is
        # consumed to sample the leaves
        rule_key = keys.pop()
        leaf_key = keys.pop()
        return cls(
            key=rule_key,
            stack=SamplePriorStack.initial(p_nonterminal, max_split),
            trees=SamplePriorTrees.initial(leaf_key, sigma_mu, p_nonterminal, max_split),
        )
class SamplePriorX(Module):
    """Object representing the recursion scan in `sample_prior`.

    The sequence of nodes to visit is pre-computed recursively once, unrolling
    the recursion schedule.

    Parameters
    ----------
    node
        The heap index of the node to visit.
    depth
        The depth of the node.
    next_depth
        The depth of the next node to visit, either the left child or the right
        sibling of the node or of an ancestor.
    """

    node: Int32[Array, ' 2**(d-1)-1']
    depth: Int32[Array, ' 2**(d-1)-1']
    next_depth: Int32[Array, ' 2**(d-1)-1']

    @classmethod
    def initial(cls, p_nonterminal: Float32[Array, ' d-1']) -> 'SamplePriorX':
        """Precompute the depth-first order of the nodes to visit.

        Parameters
        ----------
        p_nonterminal
            The prior probability of a node being non-terminal conditional on
            its ancestors and on having available decision rules, at each depth.

        Returns
        -------
        A `SamplePriorX` with one entry per visited node.
        """
        max_depth = p_nonterminal.size
        schedule = cls._sequence(max_depth)
        assert len(schedule) == 2**max_depth - 1
        nodes = [node for node, _ in schedule]
        depths = [depth for _, depth in schedule]
        # the last visit is followed by a sentinel at the maximum depth
        next_depths = [*depths[1:], max_depth]
        return cls(
            node=jnp.array(nodes),
            depth=jnp.array(depths),
            next_depth=jnp.array(next_depths),
        )

    @classmethod
    def _sequence(
        cls, max_depth: int, depth: int = 0, node: int = 1
    ) -> tuple[tuple[int, int], ...]:
        """Recursively generate a pre-order sequence [(node, depth), ...]."""
        if depth >= max_depth:
            return ()
        left_subtree = cls._sequence(max_depth, depth + 1, 2 * node)
        right_subtree = cls._sequence(max_depth, depth + 1, 2 * node + 1)
        return ((node, depth), *left_subtree, *right_subtree)
def sample_prior_onetree(
    key: Key[Array, ''],
    max_split: UInt[Array, ' p'],
    p_nonterminal: Float32[Array, ' d-1'],
    sigma_mu: Float32[Array, ''],
) -> SamplePriorTrees:
    """Draw one tree from the BART prior.

    The tree heap is filled node by node with a `lax.scan` over the
    depth-first schedule built by `SamplePriorX.initial`. For each depth, the
    scan carry keeps a stack recording:
    - whether nodes at that depth may be non-terminal;
    - the range of split values still free for each variable;
    - the decision rule chosen at the ancestor at that depth.

    Parameters
    ----------
    key
        A jax random key.
    max_split
        The maximum split value for each variable.
    p_nonterminal
        The prior probability of a node being non-terminal conditional on
        its ancestors and on having available decision rules, at each depth.
    sigma_mu
        The prior standard deviation of each leaf.

    Returns
    -------
    An object containing a generated tree.
    """
    init = SamplePriorCarry.initial(key, sigma_mu, p_nonterminal, max_split)
    schedule = SamplePriorX.initial(p_nonterminal)

    def step(carry: SamplePriorCarry, x: SamplePriorX):
        # the four subkeys are popped in a fixed order: variable, split value,
        # grow decision, key for the next step
        subkeys = split_key(carry.key, 4)
        key_var = subkeys.pop()
        key_split = subkeys.pop()
        key_grow = subkeys.pop()
        key_next = subkeys.pop()

        stack = carry.stack
        trees = carry.trees
        level = x.depth

        # state recorded for the current depth
        may_grow = stack.nonterminal[level]
        lo = stack.lower[level, :]
        hi = stack.upper[level, :]

        # draw a decision rule; the split value falls in [lo + 1, hi]
        free: Bool[Array, ' p'] = lo < hi
        any_free = jnp.any(free)
        rule_var = randint_masked(key_var, free)
        rule_split = 1 + random.randint(key_split, (), lo[rule_var], hi[rule_var])

        # store the rule with the dtypes of the tree arrays
        rule_var = rule_var.astype(trees.var_tree.dtype)
        rule_split = rule_split.astype(trees.split_tree.dtype)

        # the node is internal only if its ancestors allow it, the coin flip
        # says so, and some variable still has a free split value
        coin: Bool[Array, ''] = random.bernoulli(key_grow, p_nonterminal[level])
        internal = may_grow & (coin & any_free)

        # a zero split value marks the node as a leaf
        trees = replace(
            trees,
            var_tree=trees.var_tree.at[x.node].set(rule_var),
            split_tree=trees.split_tree.at[x.node].set(
                jnp.where(internal, rule_split, 0)
            ),
        )

        def descend() -> SamplePriorStack:
            """Stack for moving to the left child of the current node."""
            child = x.next_depth
            return replace(
                stack,
                nonterminal=stack.nonterminal.at[child].set(internal),
                lower=stack.lower.at[child, :].set(lo),
                upper=stack.upper.at[child, :].set(
                    hi.at[rule_var].set(rule_split - 1)
                ),
                var=stack.var.at[level].set(rule_var),
                split=stack.split.at[level].set(rule_split),
            )

        def backtrack() -> SamplePriorStack:
            """Stack for moving to a right child, possibly at a lower depth."""
            parent = x.next_depth - 1
            parent_var = stack.var[parent]
            parent_split = stack.split[parent]
            parent_lo = stack.lower[parent, :]
            parent_hi = stack.upper[parent, :]
            return replace(
                stack,
                lower=stack.lower.at[x.next_depth, :].set(
                    parent_lo.at[parent_var].set(parent_split)
                ),
                upper=stack.upper.at[x.next_depth, :].set(parent_hi),
            )

        new_stack = lax.cond(x.next_depth > level, descend, backtrack)
        new_carry = replace(carry, key=key_next, stack=new_stack, trees=trees)
        return new_carry, None

    final, _ = lax.scan(step, init, schedule)
    return final.trees
@partial(vmap_nodoc, in_axes=(0, None, None, None))
def sample_prior_forest(
    keys: Key[Array, ' num_trees'],
    max_split: UInt[Array, ' p'],
    p_nonterminal: Float32[Array, ' d-1'],
    sigma_mu: Float32[Array, ''],
) -> SamplePriorTrees:
    """Draw independent trees from the BART prior, one per random key.

    This is `sample_prior_onetree` mapped over `keys`; the remaining
    arguments are shared by all the trees.

    Parameters
    ----------
    keys
        A sequence of jax random keys, one per tree; its length determines
        how many trees are drawn.
    max_split
        The maximum split value for each variable.
    p_nonterminal
        The prior probability of a node being non-terminal conditional on
        its ancestors and on having available decision rules, at each depth.
    sigma_mu
        The prior standard deviation of each leaf.

    Returns
    -------
    An object containing the generated trees, batched over `keys`.
    """
    # inside the mapped function `keys` is a single key
    return sample_prior_onetree(
        key=keys,
        max_split=max_split,
        p_nonterminal=p_nonterminal,
        sigma_mu=sigma_mu,
    )
@partial(jit, static_argnums=(1, 2))
def sample_prior(
    key: Key[Array, ''],
    trace_length: int,
    num_trees: int,
    max_split: UInt[Array, ' p'],
    p_nonterminal: Float32[Array, ' d-1'],
    sigma_mu: Float32[Array, ''],
) -> SamplePriorTrees:
    """Draw a trace of independent forests from the BART prior.

    All ``trace_length * num_trees`` trees are drawn in one flat batch and
    then rearranged into a (trace_length, num_trees) grid.

    Parameters
    ----------
    key
        A jax random key.
    trace_length
        The number of iterations (static under `jit`).
    num_trees
        The number of trees for each iteration (static under `jit`).
    max_split
        The number of cutpoints along each variable.
    p_nonterminal
        The prior probability of a node being non-terminal conditional on
        its ancestors and on having available decision rules, at each depth.
        This determines the maximum depth of the trees.
    sigma_mu
        The prior standard deviation of each leaf.

    Returns
    -------
    An object containing the generated trees, with batch shape (trace_length, num_trees).
    """
    total = trace_length * num_trees
    flat = sample_prior_forest(
        random.split(key, total), max_split, p_nonterminal, sigma_mu
    )

    def to_trace_shape(leaf):
        # split the flat batch axis into (trace_length, num_trees) and
        # flatten whatever per-tree axes remain into the last one
        return leaf.reshape((trace_length, num_trees, -1))

    return tree_map(to_trace_shape, flat)
class debug_gbart(gbart):
    """A subclass of `gbart` that adds debugging functionality.

    Parameters
    ----------
    *args
        Passed to `gbart`.
    check_trees
        If `True`, check all trees with `check_trace` after running the MCMC,
        and assert that they are all valid. Set to `False` to allow jax tracing.
    **kw
        Passed to `gbart`.
    """

    def __init__(self, *args, check_trees: bool = True, **kw):
        super().__init__(*args, **kw)
        if check_trees:
            bad = self.check_trees()
            bad_count = jnp.count_nonzero(bad)
            # report how many draws failed; use `tree_goes_bad` to locate them
            assert bad_count == 0, f'{bad_count} invalid tree draws found'

    def show_tree(self, i_sample: int, i_tree: int, print_all: bool = False):
        """Print a single tree in human-readable format.

        Parameters
        ----------
        i_sample
            The index of the posterior sample.
        i_tree
            The index of the tree in the sample.
        print_all
            If `True`, also print the content of unused node slots.
        """
        tree = TreesTrace.from_dataclass(self._main_trace)
        tree = tree_map(lambda x: x[i_sample, i_tree, :], tree)
        s = format_tree(tree, print_all=print_all)
        print(s)  # noqa: T201, this method is intended for debug

    def sigma_harmonic_mean(self, prior: bool = False) -> Float32[Array, '']:
        """Return the square root of the harmonic mean of the error variance.

        Parameters
        ----------
        prior
            If `True`, use the prior distribution, otherwise use the full
            conditional at the last MCMC iteration.

        Returns
        -------
        sqrt(1/E[1/sigma^2]) in the selected distribution, i.e., the square
        root of the harmonic mean of the variance. This is on the scale of the
        error standard deviation sigma, not of the variance.
        """
        bart = self._mcmc_state
        assert bart.sigma2_alpha is not None
        assert bart.z is None
        if prior:
            alpha = bart.sigma2_alpha
            beta = bart.sigma2_beta
        else:
            # conjugate update with the residuals of the last iteration
            resid = bart.resid
            alpha = bart.sigma2_alpha + resid.size / 2
            norm2 = resid @ resid
            beta = bart.sigma2_beta + norm2 / 2
        # for sigma^2 ~ InvGamma(alpha, beta), 1/E[1/sigma^2] = beta / alpha
        sigma2 = beta / alpha
        return jnp.sqrt(sigma2)

    def compare_resid(self) -> tuple[Float32[Array, ' n'], Float32[Array, ' n']]:
        """Re-compute residuals to compare them with the updated ones.

        Returns
        -------
        resid1 : Float32[Array, 'n']
            The final state of the residuals updated during the MCMC.
        resid2 : Float32[Array, 'n']
            The residuals computed from the final state of the trees.
        """
        bart = self._mcmc_state
        resid1 = bart.resid

        trees = evaluate_forest(bart.X, bart.forest)

        # the residuals are relative to `z` when it is set, else to `y`
        if bart.z is not None:
            ref = bart.z
        else:
            ref = bart.y
        resid2 = ref - (trees + bart.offset)

        return resid1, resid2

    def avg_acc(self) -> tuple[Float32[Array, ''], Float32[Array, '']]:
        """Compute the average acceptance rates of tree moves.

        Returns
        -------
        acc_grow : Float32[Array, '']
            The average acceptance rate of grow moves.
        acc_prune : Float32[Array, '']
            The average acceptance rate of prune moves.

        Notes
        -----
        A rate is nan if the corresponding move was never proposed (0/0).
        """
        trace = self._main_trace

        def acc(prefix):
            acc = getattr(trace, f'{prefix}_acc_count')
            prop = getattr(trace, f'{prefix}_prop_count')
            return acc.sum() / prop.sum()

        return acc('grow'), acc('prune')

    def avg_prop(self) -> tuple[Float32[Array, ''], Float32[Array, '']]:
        """Compute the average proposal rate of grow and prune moves.

        Returns
        -------
        prop_grow : Float32[Array, '']
            The fraction of times grow was proposed instead of prune.
        prop_prune : Float32[Array, '']
            The fraction of times prune was proposed instead of grow.

        Notes
        -----
        This function does not take into account cases where no move was
        proposed. If no move was ever proposed, both fractions are nan (0/0).
        """
        trace = self._main_trace

        def prop(prefix):
            return getattr(trace, f'{prefix}_prop_count').sum()

        pgrow = prop('grow')
        pprune = prop('prune')
        total = pgrow + pprune
        return pgrow / total, pprune / total

    def avg_move(self) -> tuple[Float32[Array, ''], Float32[Array, '']]:
        """Compute the move rate.

        Returns
        -------
        rate_grow : Float32[Array, '']
            The fraction of times a grow move was proposed and accepted.
        rate_prune : Float32[Array, '']
            The fraction of times a prune move was proposed and accepted.
        """
        agrow, aprune = self.avg_acc()
        pgrow, pprune = self.avg_prop()
        return agrow * pgrow, aprune * pprune

    def depth_distr(self) -> Float32[Array, 'trace_length d']:
        """Histogram of tree depths for each state of the trees.

        Returns
        -------
        A matrix where each row contains a histogram of tree depths.
        """
        return trace_depth_distr(self._main_trace.split_tree)

    def points_per_decision_node_distr(self) -> Float32[Array, 'trace_length n+1']:
        """Histogram of number of points belonging to parent-of-leaf nodes.

        Returns
        -------
        A matrix where each row contains a histogram of number of points.
        """
        return trace_points_per_decision_node_distr(
            self._main_trace, self._mcmc_state.X
        )

    def points_per_leaf_distr(self) -> Float32[Array, 'trace_length n+1']:
        """Histogram of number of points belonging to leaves.

        Returns
        -------
        A matrix where each row contains a histogram of number of points.
        """
        return trace_points_per_leaf_distr(self._main_trace, self._mcmc_state.X)

    def check_trees(self) -> UInt[Array, 'trace_length ntree']:
        """Apply `check_trace` to all the tree draws.

        Returns
        -------
        The error flags of each tree draw, nonzero where the tree is invalid.
        """
        return check_trace(self._main_trace, self._mcmc_state.forest.max_split)

    def tree_goes_bad(self) -> Bool[Array, 'trace_length ntree']:
        """Find iterations where a tree becomes invalid.

        Returns
        -------
        A boolean matrix where (i, j) is `True` if tree j is invalid at
        iteration i but not at iteration i-1. At i = 0 it is `True` if tree j
        is invalid.
        """
        bad = self.check_trees().astype(bool)
        # shift down by one iteration, treating the state before the first
        # iteration as valid
        bad_before = jnp.pad(bad[:-1], [(1, 0), (0, 0)])
        return bad & ~bad_before