Coverage for src/bartz/debug.py: 87%
442 statements
coverage.py v7.10.1, created at 2025-07-31 16:09 +0000
1# bartz/src/bartz/debug.py
2#
3# Copyright (c) 2024-2025, Giacomo Petrillo
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25"""Debugging utilities. The main functionality is the class `debug_mc_gbart`."""
27from collections.abc import Callable 1ab
28from dataclasses import replace 1ab
29from functools import partial 1ab
30from math import ceil, log2 1ab
31from re import fullmatch 1ab
33import numpy 1ab
34from equinox import Module, field 1ab
35from jax import jit, lax, random, vmap 1ab
36from jax import numpy as jnp 1ab
37from jax.tree_util import tree_map 1ab
38from jaxtyping import Array, Bool, Float32, Int32, Integer, Key, UInt 1ab
40from bartz.BART import FloatLike, gbart, mc_gbart 1ab
41from bartz.grove import ( 1ab
42 TreeHeaps,
43 evaluate_forest,
44 is_actual_leaf,
45 is_leaves_parent,
46 traverse_tree,
47 tree_depth,
48 tree_depths,
49)
50from bartz.jaxext import minimal_unsigned_dtype, vmap_nodoc 1ab
51from bartz.jaxext import split as split_key 1ab
52from bartz.mcmcloop import TreesTrace 1ab
53from bartz.mcmcstep import randint_masked 1ab
56def format_tree(tree: TreeHeaps, *, print_all: bool = False) -> str: 1ab
57 """Convert a tree to a human-readable string.
59 Parameters
60 ----------
61 tree
62 A single tree to format.
63 print_all
64 If `True`, also print the contents of unused node slots in the arrays.
66 Returns
67 -------
68 A string representation of the tree.
69 """
70 tee = '├──' 1ab
71 corner = '└──' 1ab
72 join = '│ ' 1ab
73 space = ' ' 1ab
74 down = '┐' 1ab
75 bottom = '╢' # '┨' # 1ab
77 def traverse_tree( 1ab
78 lines: list[str],
79 index: int,
80 depth: int,
81 indent: str,
82 first_indent: str,
83 next_indent: str,
84 unused: bool,
85 ):
86 if index >= len(tree.leaf_tree): 86 ↛ 87 (line 86 didn't jump to line 87 because the condition was never true) 1ab
87 return
89 var: int = tree.var_tree.at[index].get(mode='fill', fill_value=0).item() 1ab
90 split: int = tree.split_tree.at[index].get(mode='fill', fill_value=0).item() 1ab
92 is_leaf = split == 0 1ab
93 left_child = 2 * index 1ab
94 right_child = 2 * index + 1 1ab
96 if print_all: 96 ↛ 97 (line 96 didn't jump to line 97 because the condition was never true) 1ab
97 if unused:
98 category = 'unused'
99 elif is_leaf:
100 category = 'leaf'
101 else:
102 category = 'decision'
103 node_str = f'{category}({var}, {split}, {tree.leaf_tree[index]})'
104 else:
105 assert not unused 1ab
106 if is_leaf: 1ab
107 node_str = f'{tree.leaf_tree[index]:#.2g}' 1ab
108 else:
109 node_str = f'x{var} < {split}' 1ab
111 if not is_leaf or (print_all and left_child < len(tree.leaf_tree)): 1ab
112 link = down 1ab
113 elif not print_all and left_child >= len(tree.leaf_tree): 1ab
114 link = bottom 1ab
115 else:
116 link = ' ' 1ab
118 max_number = len(tree.leaf_tree) - 1 1ab
119 ndigits = len(str(max_number)) 1ab
120 number = str(index).rjust(ndigits) 1ab
122 lines.append(f' {number} {indent}{first_indent}{link}{node_str}') 1ab
124 indent += next_indent 1ab
125 unused = unused or is_leaf 1ab
127 if unused and not print_all: 1ab
128 return 1ab
130 traverse_tree(lines, left_child, depth + 1, indent, tee, join, unused) 1ab
131 traverse_tree(lines, right_child, depth + 1, indent, corner, space, unused) 1ab
133 lines = [] 1ab
134 traverse_tree(lines, 1, 0, '', '', '', False) 1ab
135 return '\n'.join(lines) 1ab
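# Example usage (a sketch; `trace` is assumed to be a fitted model's main tree
# trace, with tree arrays of shape (nchains, ndpost, ntree, ...), and is not
# defined in this module):
#
#     trees = TreesTrace.from_dataclass(trace)
#     one_tree = tree_map(lambda x: x[0, -1, 0, :], trees)  # chain 0, last draw, tree 0
#     print(format_tree(one_tree))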
138def tree_actual_depth(split_tree: UInt[Array, ' 2**(d-1)']) -> Int32[Array, '']: 1ab
139 """Measure the depth of the tree.
141 Parameters
142 ----------
143 split_tree
144 The cutpoints of the decision rules.
146 Returns
147 -------
148 The depth of the deepest leaf in the tree. The root is at depth 0.
149 """
150 # this could be done just with split_tree != 0
151 is_leaf = is_actual_leaf(split_tree, add_bottom_level=True) 1ab
152 depth = tree_depths(is_leaf.size) 1ab
153 depth = jnp.where(is_leaf, depth, 0) 1ab
154 return jnp.max(depth) 1ab
157def forest_depth_distr( 1ab
158 split_tree: UInt[Array, 'num_trees 2**(d-1)'],
159) -> Int32[Array, ' d']:
160 """Histogram the depths of a set of trees.
162 Parameters
163 ----------
164 split_tree
165 The cutpoints of the decision rules of the trees.
167 Returns
168 -------
169 An integer vector where the i-th element counts how many trees have depth i.
170 """
171 depth = tree_depth(split_tree) + 1 1ab
172 depths = vmap(tree_actual_depth)(split_tree) 1ab
173 return jnp.bincount(depths, length=depth) 1ab
176@jit 1ab
177def trace_depth_distr( 1ab
178 split_tree: UInt[Array, 'trace_length num_trees 2**(d-1)'],
179) -> Int32[Array, 'trace_length d']:
180 """Histogram the depths of a sequence of sets of trees.
182 Parameters
183 ----------
184 split_tree
185 The cutpoints of the decision rules of the trees.
187 Returns
188 -------
189 A matrix where element (t,i) counts how many trees have depth i in set t.
190 """
191 return vmap(forest_depth_distr)(split_tree) 1ab
194@vmap_nodoc 1ab
195def chains_depth_distr( 1ab
196 split_tree: UInt[Array, 'nchains trace_length num_trees 2**(d-1)'],
197) -> Int32[Array, 'nchains trace_length d']:
198 """Histogram the depths of chains of forests of trees.
200 Parameters
201 ----------
202 split_tree
203 The cutpoints of the decision rules of the trees.
205 Returns
206 -------
207 A tensor where element (c,t,i) counts how many trees have depth i in forest t in chain c.
208 """
209 return trace_depth_distr(split_tree) 1ab
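# Example (a sketch; `split_tree` is assumed to be the split_tree array of a
# trace, with shape (nchains, trace_length, num_trees, 2**(d-1)), and is not
# defined here):
#
#     distr = chains_depth_distr(split_tree)    # shape (nchains, trace_length, d)
#     avg_hist = distr.mean(axis=(0, 1))        # average depth histogram over the trace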
212def points_per_decision_node_distr( 1ab
213 var_tree: UInt[Array, ' 2**(d-1)'],
214 split_tree: UInt[Array, ' 2**(d-1)'],
215 X: UInt[Array, 'p n'],
216) -> Int32[Array, ' n+1']:
217 """Histogram points-per-node counts.
219 Count how many parent-of-leaf nodes in a tree select each possible number
220 of points.
222 Parameters
223 ----------
224 var_tree
225 The variables of the decision rules.
226 split_tree
227 The cutpoints of the decision rules.
228 X
229 The set of points to count.
231 Returns
232 -------
233 A vector where the i-th element counts how many next-to-leaf nodes have i points.
234 """
235 traverse_tree_X = vmap(traverse_tree, in_axes=(1, None, None)) 1ab
236 indices = traverse_tree_X(X, var_tree, split_tree) 1ab
237 indices >>= 1 1ab
238 count_tree = jnp.zeros(split_tree.size, int).at[indices].add(1).at[0].set(0) 1ab
239 is_parent = is_leaves_parent(split_tree) 1ab
240 return jnp.zeros(X.shape[1] + 1, int).at[count_tree].add(is_parent) 1ab
243def forest_points_per_decision_node_distr( 1ab
244 trees: TreeHeaps, X: UInt[Array, 'p n']
245) -> Int32[Array, ' n+1']:
246 """Histogram points-per-node counts for a set of trees.
248 Count how many parent-of-leaf nodes in a set of trees select each possible
249 number of points.
251 Parameters
252 ----------
253 trees
254 The set of trees. The variables must have broadcast shape (num_trees,).
255 X
256 The set of points to count.
258 Returns
259 -------
260 A vector where the i-th element counts how many next-to-leaf nodes have i points.
261 """
262 distr = jnp.zeros(X.shape[1] + 1, int) 1ab
264 def loop(distr, heaps: tuple[Array, Array]): 1ab
265 return distr + points_per_decision_node_distr(*heaps, X), None 1ab
267 distr, _ = lax.scan(loop, distr, (trees.var_tree, trees.split_tree)) 1ab
268 return distr 1ab
271@jit 1ab
272@partial(vmap_nodoc, in_axes=(0, None)) 1ab
273def chains_points_per_decision_node_distr( 1ab
274 chains: TreeHeaps, X: UInt[Array, 'p n']
275) -> Int32[Array, 'nchains trace_length n+1']:
276 """Separately histogram points-per-node counts over chains of forests of trees.
278 For each set of trees, count how many parent-of-leaf nodes select each
279 possible number of points.
281 Parameters
282 ----------
283 chains
284 The chains of forests of trees. The variables must have broadcast shape
285 (nchains, trace_length, num_trees).
286 X
287 The set of points to count.
289 Returns
290 -------
291 A tensor where element (c,t,i) counts how many next-to-leaf nodes have i points in forest t in chain c.
292 """
294 def loop(_, forests): 1ab
295 return None, forest_points_per_decision_node_distr(forests, X) 1ab
297 _, distr = lax.scan(loop, None, chains) 1ab
298 return distr 1ab
301def points_per_leaf_distr( 1ab
302 var_tree: UInt[Array, ' 2**(d-1)'],
303 split_tree: UInt[Array, ' 2**(d-1)'],
304 X: UInt[Array, 'p n'],
305) -> Int32[Array, ' n+1']:
306 """Histogram points-per-leaf counts in a tree.
308 Count how many leaves in a tree select each possible amount of points.
310 Parameters
311 ----------
312 var_tree
313 The variables of the decision rules.
314 split_tree
315 The cutpoints of the decision rules.
316 X
317 The set of points to count.
319 Returns
320 -------
321 A vector where the i-th element counts how many leaves have i points.
322 """
323 traverse_tree_X = vmap(traverse_tree, in_axes=(1, None, None)) 1ab
324 indices = traverse_tree_X(X, var_tree, split_tree) 1ab
325 count_tree = jnp.zeros(2 * split_tree.size, int).at[indices].add(1) 1ab
326 is_leaf = is_actual_leaf(split_tree, add_bottom_level=True) 1ab
327 return jnp.zeros(X.shape[1] + 1, int).at[count_tree].add(is_leaf) 1ab
330def forest_points_per_leaf_distr( 1ab
331 trees: TreeHeaps, X: UInt[Array, 'p n']
332) -> Int32[Array, ' n+1']:
333 """Histogram points-per-leaf counts over a set of trees.
335 Count how many leaves in a set of trees select each possible number of points.
337 Parameters
338 ----------
339 trees
340 The set of trees. The variables must have broadcast shape (num_trees,).
341 X
342 The set of points to count.
344 Returns
345 -------
346 A vector where the i-th element counts how many leaves have i points.
347 """
348 distr = jnp.zeros(X.shape[1] + 1, int) 1ab
350 def loop(distr, heaps: tuple[Array, Array]): 1ab
351 return distr + points_per_leaf_distr(*heaps, X), None 1ab
353 distr, _ = lax.scan(loop, distr, (trees.var_tree, trees.split_tree)) 1ab
354 return distr 1ab
357@jit 1ab
358@partial(vmap_nodoc, in_axes=(0, None)) 1ab
359def chains_points_per_leaf_distr( 1ab
360 chains: TreeHeaps, X: UInt[Array, 'p n']
361) -> Int32[Array, 'nchains trace_length n+1']:
362 """Separately histogram points-per-leaf counts over chains of forests of trees.
364 For each set of trees, count how many leaves select each possible number of
365 points.
367 Parameters
368 ----------
369 chains
370 The chains of forests of trees. The variables must have broadcast shape
371 (nchains, trace_length, num_trees).
372 X
373 The set of points to count.
375 Returns
376 -------
377 A tensor where element (c,t,i) counts how many leaves have i points in forest t in chain c.
378 """
380 def loop(_, forests): 1ab
381 return None, forest_points_per_leaf_distr(forests, X) 1ab
383 _, distr = lax.scan(loop, None, chains) 1ab
384 return distr 1ab
387check_functions = [] 1ab
390CheckFunc = Callable[[TreeHeaps, UInt[Array, ' p']], bool | Bool[Array, '']] 1ab
393def check(func: CheckFunc) -> CheckFunc: 1ab
394 """Add a function to a list of functions used to check trees.
396 Use to decorate functions that check whether a tree is valid in some way.
397 These functions are invoked automatically by `check_tree`, `check_trace` and
398 `debug_gbart`.
400 Parameters
401 ----------
402 func
403 The function to add to the list. It must accept a `TreeHeaps` and a
404 `max_split` argument, and return a boolean scalar that indicates if the
405 tree is ok.
407 Returns
408 -------
409 The function unchanged.
410 """
411 check_functions.append(func) 1ab
412 return func 1ab
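# Example of registering an additional check (a sketch; the check below is
# hypothetical and not part of the module):
#
#     @check
#     def check_leaf_magnitude(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> Bool[Array, '']:  # noqa: ARG001
#         """Check that no leaf value is implausibly large."""
#         return jnp.all(jnp.abs(tree.leaf_tree) < 1e6)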
415@check 1ab
416def check_types(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> bool: 1ab
417 """Check that integer types are as small as possible and coherent."""
418 expected_var_dtype = minimal_unsigned_dtype(max_split.size - 1) 1ab
419 expected_split_dtype = max_split.dtype 1ab
420 return ( 1ab
421 tree.var_tree.dtype == expected_var_dtype
422 and tree.split_tree.dtype == expected_split_dtype
423 )
426@check 1ab
427def check_sizes(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> bool: # noqa: ARG001 1ab
428 """Check that array sizes are coherent."""
429 return tree.leaf_tree.size == 2 * tree.var_tree.size == 2 * tree.split_tree.size 1ab
432@check 1ab
433def check_unused_node(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> Bool[Array, '']: # noqa: ARG001 1ab
434 """Check that the unused node slot at index 0 is not dirty."""
435 return (tree.var_tree[0] == 0) & (tree.split_tree[0] == 0) 1ab
438@check 1ab
439def check_leaf_values(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> Bool[Array, '']: # noqa: ARG001 1ab
440 """Check that all leaf values are not inf of nan."""
441 return jnp.all(jnp.isfinite(tree.leaf_tree)) 1ab
444@check 1ab
445def check_stray_nodes(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> Bool[Array, '']: # noqa: ARG001 1ab
446 """Check if there is any marked-non-leaf node with a marked-leaf parent."""
447 index = jnp.arange( 1ab
448 2 * tree.split_tree.size,
449 dtype=minimal_unsigned_dtype(2 * tree.split_tree.size - 1),
450 )
451 parent_index = index >> 1 1ab
452 is_not_leaf = tree.split_tree.at[index].get(mode='fill', fill_value=0) != 0 1ab
453 parent_is_leaf = tree.split_tree[parent_index] == 0 1ab
454 stray = is_not_leaf & parent_is_leaf 1ab
455 stray = stray.at[1].set(False) 1ab
456 return ~jnp.any(stray) 1ab
459@check 1ab
460def check_rule_consistency( 1ab
461 tree: TreeHeaps, max_split: UInt[Array, ' p']
462) -> bool | Bool[Array, '']:
463 """Check that decision rules define proper subsets of ancestor rules."""
464 if tree.var_tree.size < 4: 464 ↛ 465 (line 464 didn't jump to line 465 because the condition was never true) 1ab
465 return True
467 # initial boundaries of decision rules. use extreme integers instead of 0,
468 # max_split to avoid checking if there is something out of bounds.
469 small = jnp.iinfo(jnp.int32).min 1ab
470 large = jnp.iinfo(jnp.int32).max 1ab
471 lower = jnp.full(max_split.size, small, jnp.int32) 1ab
472 upper = jnp.full(max_split.size, large, jnp.int32) 1ab
473 # specify the type explicitly, otherwise they are weakly typed and get
474 # implicitly converted to split.dtype (typically uint8) in the expressions
476 def _check_recursive(node, lower, upper): 1ab
477 # read decision rule
478 var = tree.var_tree[node] 1ab
479 split = tree.split_tree[node] 1ab
481 # get rule boundaries from ancestors. use fill value in case var is
482 # out of bounds, we don't want to check out of bounds in this function
483 lower_var = lower.at[var].get(mode='fill', fill_value=small) 1ab
484 upper_var = upper.at[var].get(mode='fill', fill_value=large) 1ab
486 # check rule is in bounds
487 bad = jnp.where(split, (split <= lower_var) | (split >= upper_var), False) 1ab
489 # recurse
490 if node < tree.var_tree.size // 2: 1ab
491 idx = jnp.where(split, var, max_split.size) 1ab
492 bad |= _check_recursive(2 * node, lower, upper.at[idx].set(split)) 1ab
493 bad |= _check_recursive(2 * node + 1, lower.at[idx].set(split), upper) 1ab
495 return bad 1ab
497 return ~_check_recursive(1, lower, upper) 1ab
500@check 1ab
501def check_num_nodes(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> Bool[Array, '']: # noqa: ARG001 1ab
502 """Check that #leaves = 1 + #(internal nodes)."""
503 is_leaf = is_actual_leaf(tree.split_tree, add_bottom_level=True) 1ab
504 num_leaves = jnp.count_nonzero(is_leaf) 1ab
505 num_internal = jnp.count_nonzero(tree.split_tree) 1ab
506 return num_leaves == num_internal + 1 1ab
509@check 1ab
510def check_var_in_bounds( 1ab
511 tree: TreeHeaps, max_split: UInt[Array, ' p']
512) -> Bool[Array, '']:
513 """Check that variables are in [0, max_split.size)."""
514 decision_node = tree.split_tree.astype(bool) 1ab
515 in_bounds = (tree.var_tree >= 0) & (tree.var_tree < max_split.size) 1ab
516 return jnp.all(in_bounds | ~decision_node) 1ab
519@check 1ab
520def check_split_in_bounds( 1ab
521 tree: TreeHeaps, max_split: UInt[Array, ' p']
522) -> Bool[Array, '']:
523 """Check that splits are in [0, max_split[var]]."""
524 max_split_var = ( 1ab
525 max_split.astype(jnp.int32)
526 .at[tree.var_tree]
527 .get(mode='fill', fill_value=jnp.iinfo(jnp.int32).max)
528 )
529 return jnp.all((tree.split_tree >= 0) & (tree.split_tree <= max_split_var)) 1ab
532def check_tree(tree: TreeHeaps, max_split: UInt[Array, ' p']) -> UInt[Array, '']: 1ab
533 """Check the validity of a tree.
535 Use `describe_error` to parse the error code returned by this function.
537 Parameters
538 ----------
539 tree
540 The tree to check.
541 max_split
542 The maximum split value for each variable.
544 Returns
545 -------
546 An integer where each bit indicates whether a check failed.
547 """
548 error_type = minimal_unsigned_dtype(2 ** len(check_functions) - 1) 1ab
549 error = error_type(0) 1ab
550 for i, func in enumerate(check_functions): 1ab
551 ok = func(tree, max_split) 1ab
552 ok = jnp.bool_(ok) 1ab
553 bit = (~ok) << i 1ab
554 error |= bit 1ab
555 return error 1ab
558def describe_error(error: int | Integer[Array, '']) -> list[str]: 1ab
559 """Describe the error code returned by `check_tree`.
561 Parameters
562 ----------
563 error
564 The error code returned by `check_tree`.
566 Returns
567 -------
568 A list of the function names that implement the failed checks.
569 """
570 return [func.__name__ for i, func in enumerate(check_functions) if error & (1 << i)]
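# Example (a sketch; `tree` and `max_split` are assumed to be available and are
# not defined here):
#
#     error = check_tree(tree, max_split)
#     if error:
#         print('failed checks:', describe_error(error))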
573@jit 1ab
574@partial(vmap_nodoc, in_axes=(0, None)) 1ab
575def check_trace( 1ab
576 trace: TreeHeaps, max_split: UInt[Array, ' p']
577) -> UInt[Array, 'trace_length num_trees']:
578 """Check the validity of a sequence of sets of trees.
580 Use `describe_error` to parse the error codes returned by this function.
582 Parameters
583 ----------
584 trace
585 The sequence of sets of trees to check. The tree arrays must have
586 broadcast shape (trace_length, num_trees). This object can have
587 additional attributes beyond the tree arrays; they are ignored.
588 max_split
589 The maximum split value for each variable.
591 Returns
592 -------
593 A matrix of error codes for each tree.
594 """
595 trees = TreesTrace.from_dataclass(trace) 1ab
596 return lax.map(partial(check_tree, max_split=max_split), trees) 1ab
599@partial(vmap_nodoc, in_axes=(0, None)) 1ab
600def check_chains( 1ab
601 chains: TreeHeaps, max_split: UInt[Array, ' p']
602) -> UInt[Array, 'nchains trace_length num_trees']:
603 """Check the validity of sequences of sets of trees.
605 Use `describe_error` to parse the error codes returned by this function.
607 Parameters
608 ----------
609 chains
610 The sequences of sets of trees to check. The tree arrays must have
611 broadcast shape (nchains, trace_length, num_trees). This object can have
612 additional attributes beyond the tree arrays; they are ignored.
613 max_split
614 The maximum split value for each variable.
616 Returns
617 -------
618 A tensor of error codes for each tree.
619 """
620 return check_trace(chains, max_split) 1ab
623def _get_next_line(s: str, i: int) -> tuple[str, int]: 1ab
624 """Get the next line from a string and the new index."""
625 i_new = s.find('\n', i) 1ab
626 if i_new == -1: 626 ↛ 627 (line 626 didn't jump to line 627 because the condition was never true) 1ab
627 return s[i:], len(s)
628 return s[i:i_new], i_new + 1 1ab
631class BARTTraceMeta(Module): 1ab
632 """Metadata of R BART tree traces.
634 Parameters
635 ----------
636 ndpost
637 The number of posterior draws.
638 ntree
639 The number of trees in the model.
640 numcut
641 The maximum split value for each variable.
642 heap_size
643 The size of the heap required to store the trees.
644 """
646 ndpost: int = field(static=True) 1ab
647 ntree: int = field(static=True) 1ab
648 numcut: UInt[Array, ' p'] 1ab
649 heap_size: int = field(static=True) 1ab
652def scan_BART_trees(trees: str) -> BARTTraceMeta: 1ab
653 """Scan an R BART tree trace checking for errors and parsing metadata.
655 Parameters
656 ----------
657 trees
658 The string representation of a trace of trees of the R BART package.
659 Can be accessed from ``mc_gbart(...).treedraws['trees']``.
661 Returns
662 -------
663 An object containing the metadata.
665 Raises
666 ------
667 ValueError
668 If the string is malformed or contains leftover characters.
669 """
670 # parse first line
671 line, i_char = _get_next_line(trees, 0) 1ab
672 i_line = 1 1ab
673 match = fullmatch(r'(\d+) (\d+) (\d+)', line) 1ab
674 if match is None: 674 ↛ 675 (line 674 didn't jump to line 675 because the condition was never true) 1ab
675 msg = f'Malformed header at {i_line=}'
676 raise ValueError(msg)
677 ndpost, ntree, p = map(int, match.groups()) 1ab
679 # initial values for maxima
680 max_heap_index = 0 1ab
681 numcut = numpy.zeros(p, int) 1ab
683 # cycle over iterations and trees
684 for i_iter in range(ndpost): 1ab
685 for i_tree in range(ntree): 1ab
686 # parse first line of tree definition
687 line, i_char = _get_next_line(trees, i_char) 1ab
688 i_line += 1 1ab
689 match = fullmatch(r'(\d+)', line) 1ab
690 if match is None: 690 ↛ 691 (line 690 didn't jump to line 691 because the condition was never true) 1ab
691 msg = f'Malformed tree header at {i_iter=} {i_tree=} {i_line=}'
692 raise ValueError(msg)
693 num_nodes = int(line) 1ab
695 # cycle over nodes
696 for i_node in range(num_nodes): 1ab
697 # parse node definition
698 line, i_char = _get_next_line(trees, i_char) 1ab
699 i_line += 1 1ab
700 match = fullmatch( 1ab
701 r'(\d+) (\d+) (\d+) (-?\d+(\.\d+)?(e(\+|-|)\d+)?)', line
702 )
703 if match is None: 703 ↛ 704 (line 703 didn't jump to line 704 because the condition was never true) 1ab
704 msg = f'Malformed node definition at {i_iter=} {i_tree=} {i_node=} {i_line=}'
705 raise ValueError(msg)
706 i_heap = int(match.group(1)) 1ab
707 var = int(match.group(2)) 1ab
708 split = int(match.group(3)) 1ab
710 # update maxima
711 numcut[var] = max(numcut[var], split) 1ab
712 max_heap_index = max(max_heap_index, i_heap) 1ab
714 assert i_char <= len(trees) 1ab
715 if i_char < len(trees): 715 ↛ 716 (line 715 didn't jump to line 716 because the condition was never true) 1ab
716 msg = f'Leftover {len(trees) - i_char} characters in string'
717 raise ValueError(msg)
719 # determine minimal integer type for numcut
720 numcut += 1 # because BART is 0-based 1ab
721 split_dtype = minimal_unsigned_dtype(numcut.max()) 1ab
722 numcut = jnp.array(numcut.astype(split_dtype)) 1ab
724 # determine minimum heap size to store the trees
725 heap_size = 2 ** ceil(log2(max_heap_index + 1)) 1ab
727 return BARTTraceMeta(ndpost=ndpost, ntree=ntree, numcut=numcut, heap_size=heap_size) 1ab
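# The text format accepted above, as implied by the parser (illustrative values):
#
#     2 1 3            <- header: ndpost ntree p
#     3                <- number of nodes in the tree of draw 1
#     1 0 4 0.25       <- node: heap index, variable, cutpoint, leaf value
#     2 0 0 -0.1
#     3 0 0 0.3
#     1                <- number of nodes in the tree of draw 2
#     1 0 0 0.05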
730class TraceWithOffset(Module): 1ab
731 """Implementation of `bartz.mcmcloop.Trace`."""
733 leaf_tree: Float32[Array, 'ndpost ntree 2**d'] 1ab
734 var_tree: UInt[Array, 'ndpost ntree 2**(d-1)'] 1ab
735 split_tree: UInt[Array, 'ndpost ntree 2**(d-1)'] 1ab
736 offset: Float32[Array, ' ndpost'] 1ab
738 @classmethod 1ab
739 def from_trees_trace( 1ab
740 cls, trees: TreeHeaps, offset: Float32[Array, '']
741 ) -> 'TraceWithOffset':
742 """Create a `TraceWithOffset` from a `TreeHeaps`."""
743 ndpost, _, _ = trees.leaf_tree.shape 1ab
744 return cls( 1ab
745 leaf_tree=trees.leaf_tree,
746 var_tree=trees.var_tree,
747 split_tree=trees.split_tree,
748 offset=jnp.full(ndpost, offset),
749 )
752def trees_BART_to_bartz( 1ab
753 trees: str, *, min_maxdepth: int = 0, offset: FloatLike | None = None
754) -> tuple[TraceWithOffset, BARTTraceMeta]:
755 """Convert trees from the R BART format to the bartz format.
757 Parameters
758 ----------
759 trees
760 The string representation of a trace of trees of the R BART package.
761 Can be accessed from ``mc_gbart(...).treedraws['trees']``.
762 min_maxdepth
763 By default, the maximum tree depth of the output format is the maximum
764 depth observed in the input trees. Use this parameter to require at
765 least this maximum depth in the output format.
766 offset
767 The trace returned by `bartz.mcmcloop.run_mcmc` contains an offset to be
768 added to the sum of trees. To match that behavior, this function
769 returns an offset as well, zero by default. Use this parameter to set a
770 different value.
772 Returns
773 -------
774 trace : TraceWithOffset
775 A representation of the trees compatible with the trace returned by
776 `bartz.mcmcloop.run_mcmc`.
777 meta : BARTTraceMeta
778 The metadata of the trace, containing the number of iterations, trees,
779 and the maximum split value.
780 """
781 # scan all the string checking for errors and determining sizes
782 meta = scan_BART_trees(trees) 1ab
784 # skip first line
785 _, i_char = _get_next_line(trees, 0) 1ab
787 heap_size = max(meta.heap_size, 2**min_maxdepth) 1ab
788 leaf_trees = numpy.zeros((meta.ndpost, meta.ntree, heap_size), dtype=numpy.float32) 1ab
789 var_trees = numpy.zeros( 1ab
790 (meta.ndpost, meta.ntree, heap_size // 2),
791 dtype=minimal_unsigned_dtype(meta.numcut.size - 1),
792 )
793 split_trees = numpy.zeros( 1ab
794 (meta.ndpost, meta.ntree, heap_size // 2), dtype=meta.numcut.dtype
795 )
797 # cycle over iterations and trees
798 for i_iter in range(meta.ndpost): 1ab
799 for i_tree in range(meta.ntree): 1ab
800 # parse first line of tree definition
801 line, i_char = _get_next_line(trees, i_char) 1ab
802 num_nodes = int(line) 1ab
804 is_internal = numpy.zeros(heap_size // 2, dtype=bool) 1ab
806 # cycle over nodes
807 for _ in range(num_nodes): 1ab
808 # parse node definition
809 line, i_char = _get_next_line(trees, i_char) 1ab
810 values = line.split() 1ab
811 i_heap = int(values[0]) 1ab
812 var = int(values[1]) 1ab
813 split = int(values[2]) 1ab
814 leaf = float(values[3]) 1ab
816 # update values
817 leaf_trees[i_iter, i_tree, i_heap] = leaf 1ab
818 is_internal[i_heap // 2] = True 1ab
819 if i_heap < heap_size // 2: 1ab
820 var_trees[i_iter, i_tree, i_heap] = var 1ab
821 split_trees[i_iter, i_tree, i_heap] = split + 1 1ab
823 is_internal[0] = False 1ab
824 split_trees[i_iter, i_tree, ~is_internal] = 0 1ab
826 return TraceWithOffset( 1ab
827 leaf_tree=jnp.array(leaf_trees),
828 var_tree=jnp.array(var_trees),
829 split_tree=jnp.array(split_trees),
830 offset=jnp.zeros(meta.ndpost)
831 if offset is None
832 else jnp.full(meta.ndpost, offset),
833 ), meta
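# Example usage (a sketch; `fit` is assumed to be an R BART fit object exposed
# to Python, e.g. via rpy2, and is not defined here):
#
#     trace, meta = trees_BART_to_bartz(fit.treedraws['trees'])
#     # trace.leaf_tree has shape (meta.ndpost, meta.ntree, meta.heap_size)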
836class SamplePriorStack(Module): 1ab
837 """Represent the manually managed stack used in `sample_prior`.
839 Each level of the stack represents a recursion into a child node in a
840 binary tree of maximum depth `d`.
842 Parameters
843 ----------
844 nonterminal
845 Whether the node is valid or the recursion is into unused node slots.
846 lower
847 upper
848 The available cutpoints along ``var`` are in the integer range
849 ``[1 + lower[var], 1 + upper[var])``.
850 var
851 split
852 The variable and cutpoint of a decision node.
853 """
855 nonterminal: Bool[Array, ' d-1'] 1ab
856 lower: UInt[Array, 'd-1 p'] 1ab
857 upper: UInt[Array, 'd-1 p'] 1ab
858 var: UInt[Array, ' d-1'] 1ab
859 split: UInt[Array, ' d-1'] 1ab
861 @classmethod 1ab
862 def initial( 1ab
863 cls, p_nonterminal: Float32[Array, ' d-1'], max_split: UInt[Array, ' p']
864 ) -> 'SamplePriorStack':
865 """Initialize the stack.
867 Parameters
868 ----------
869 p_nonterminal
870 The prior probability of a node being non-terminal conditional on
871 its ancestors and on having available decision rules, at each depth.
872 max_split
873 The number of cutpoints along each variable.
875 Returns
876 -------
877 A `SamplePriorStack` initialized to start the recursion.
878 """
879 var_dtype = minimal_unsigned_dtype(max_split.size - 1) 1ab
880 return cls( 1ab
881 nonterminal=jnp.ones(p_nonterminal.size, bool),
882 lower=jnp.zeros((p_nonterminal.size, max_split.size), max_split.dtype),
883 upper=jnp.broadcast_to(max_split, (p_nonterminal.size, max_split.size)),
884 var=jnp.zeros(p_nonterminal.size, var_dtype),
885 split=jnp.zeros(p_nonterminal.size, max_split.dtype),
886 )
889class SamplePriorTrees(Module): 1ab
890 """Object holding the trees generated by `sample_prior`.
892 Parameters
893 ----------
894 leaf_tree
895 var_tree
896 split_tree
897 The arrays representing the trees, see `bartz.grove`.
898 """
900 leaf_tree: Float32[Array, '* 2**d'] 1ab
901 var_tree: UInt[Array, '* 2**(d-1)'] 1ab
902 split_tree: UInt[Array, '* 2**(d-1)'] 1ab
904 @classmethod 1ab
905 def initial( 1ab
906 cls,
907 key: Key[Array, ''],
908 sigma_mu: Float32[Array, ''],
909 p_nonterminal: Float32[Array, ' d-1'],
910 max_split: UInt[Array, ' p'],
911 ) -> 'SamplePriorTrees':
912 """Initialize the trees.
914 The leaves are already correct and do not need to be changed.
916 Parameters
917 ----------
918 key
919 A jax random key.
920 sigma_mu
921 The prior standard deviation of each leaf.
922 p_nonterminal
923 The prior probability of a node being non-terminal conditional on
924 its ancestors and on having available decision rules, at each depth.
925 max_split
926 The number of cutpoints along each variable.
928 Returns
929 -------
930 Trees initialized with random leaves and stub tree structures.
931 """
932 heap_size = 2 ** (p_nonterminal.size + 1) 1ab
933 return cls( 1ab
934 leaf_tree=sigma_mu * random.normal(key, (heap_size,)),
935 var_tree=jnp.zeros(
936 heap_size // 2, dtype=minimal_unsigned_dtype(max_split.size - 1)
937 ),
938 split_tree=jnp.zeros(heap_size // 2, dtype=max_split.dtype),
939 )
942class SamplePriorCarry(Module): 1ab
943 """Object holding values carried along the recursion in `sample_prior`.
945 Parameters
946 ----------
947 key
948 A jax random key used to sample decision rules.
949 stack
950 The stack used to manage the recursion.
951 trees
952 The output arrays.
953 """
955 key: Key[Array, ''] 1ab
956 stack: SamplePriorStack 1ab
957 trees: SamplePriorTrees 1ab
959 @classmethod 1ab
960 def initial( 1ab
961 cls,
962 key: Key[Array, ''],
963 sigma_mu: Float32[Array, ''],
964 p_nonterminal: Float32[Array, ' d-1'],
965 max_split: UInt[Array, ' p'],
966 ) -> 'SamplePriorCarry':
967 """Initialize the carry object.
969 Parameters
970 ----------
971 key
972 A jax random key.
973 sigma_mu
974 The prior standard deviation of each leaf.
975 p_nonterminal
976 The prior probability of a node being non-terminal conditional on
977 its ancestors and on having available decision rules, at each depth.
978 max_split
979 The number of cutpoints along each variable.
981 Returns
982 -------
983 A `SamplePriorCarry` initialized to start the recursion.
984 """
985 keys = split_key(key) 1ab
986 return cls( 1ab
987 keys.pop(),
988 SamplePriorStack.initial(p_nonterminal, max_split),
989 SamplePriorTrees.initial(keys.pop(), sigma_mu, p_nonterminal, max_split),
990 )
993class SamplePriorX(Module): 1ab
994 """Object representing the recursion scan in `sample_prior`.
996 The sequence of nodes to visit is pre-computed recursively once, unrolling
997 the recursion schedule.
999 Parameters
1000 ----------
1001 node
1002 The heap index of the node to visit.
1003 depth
1004 The depth of the node.
1005 next_depth
1006 The depth of the next node to visit, either the left child or the right
1007 sibling of the node or of an ancestor.
1008 """
1010 node: Int32[Array, ' 2**(d-1)-1'] 1ab
1011 depth: Int32[Array, ' 2**(d-1)-1'] 1ab
1012 next_depth: Int32[Array, ' 2**(d-1)-1'] 1ab
1014 @classmethod 1ab
1015 def initial(cls, p_nonterminal: Float32[Array, ' d-1']) -> 'SamplePriorX': 1ab
1016 """Initialize the sequence of nodes to visit.
1018 Parameters
1019 ----------
1020 p_nonterminal
1021 The prior probability of a node being non-terminal conditional on
1022 its ancestors and on having available decision rules, at each depth.
1024 Returns
1025 -------
1026 A `SamplePriorX` initialized with the sequence of nodes to visit.
1027 """
1028 seq = cls._sequence(p_nonterminal.size) 1ab
1029 assert len(seq) == 2**p_nonterminal.size - 1 1ab
1030 node = [node for node, depth in seq] 1ab
1031 depth = [depth for node, depth in seq] 1ab
1032 next_depth = depth[1:] + [p_nonterminal.size] 1ab
1033 return cls( 1ab
1034 node=jnp.array(node),
1035 depth=jnp.array(depth),
1036 next_depth=jnp.array(next_depth),
1037 )
1039 @classmethod 1ab
1040 def _sequence( 1ab
1041 cls, max_depth: int, depth: int = 0, node: int = 1
1042 ) -> tuple[tuple[int, int], ...]:
1043 """Recursively generate a sequence [(node, depth), ...]."""
1044 if depth < max_depth: 1ab
1045 out = ((node, depth),) 1ab
1046 out += cls._sequence(max_depth, depth + 1, 2 * node) 1ab
1047 out += cls._sequence(max_depth, depth + 1, 2 * node + 1) 1ab
1048 return out 1ab
1049 return () 1ab
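# For example, _sequence(2) returns ((1, 0), (2, 1), (3, 1)): the root at depth
# 0, then its left and right children at depth 1, i.e. a pre-order traversal of
# the heap indices truncated at max_depth.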
1052def sample_prior_onetree( 1ab
1053 key: Key[Array, ''],
1054 max_split: UInt[Array, ' p'],
1055 p_nonterminal: Float32[Array, ' d-1'],
1056 sigma_mu: Float32[Array, ''],
1057) -> SamplePriorTrees:
1058 """Sample a tree from the BART prior.
1060 Parameters
1061 ----------
1062 key
1063 A jax random key.
1064 max_split
1065 The maximum split value for each variable.
1066 p_nonterminal
1067 The prior probability of a node being non-terminal conditional on
1068 its ancestors and on having available decision rules, at each depth.
1069 sigma_mu
1070 The prior standard deviation of each leaf.
1072 Returns
1073 -------
1074 An object containing a generated tree.
1075 """
1076 carry = SamplePriorCarry.initial(key, sigma_mu, p_nonterminal, max_split) 1ab
1077 xs = SamplePriorX.initial(p_nonterminal) 1ab
1079 def loop(carry: SamplePriorCarry, x: SamplePriorX): 1ab
1080 keys = split_key(carry.key, 4) 1ab
1082 # get variables at current stack level
1083 stack = carry.stack 1ab
1084 nonterminal = stack.nonterminal[x.depth] 1ab
1085 lower = stack.lower[x.depth, :] 1ab
1086 upper = stack.upper[x.depth, :] 1ab
1088 # sample a random decision rule
1089 available: Bool[Array, ' p'] = lower < upper 1ab
1090 allowed = jnp.any(available) 1ab
1091 var = randint_masked(keys.pop(), available) 1ab
1092 split = 1 + random.randint(keys.pop(), (), lower[var], upper[var]) 1ab
1094 # cast to shorter integer types
1095 var = var.astype(carry.trees.var_tree.dtype) 1ab
1096 split = split.astype(carry.trees.split_tree.dtype) 1ab
1098 # decide whether to try to grow the node if it is growable
1099 pnt = p_nonterminal[x.depth] 1ab
1100 try_nonterminal: Bool[Array, ''] = random.bernoulli(keys.pop(), pnt) 1ab
1101 nonterminal &= try_nonterminal & allowed 1ab
1103 # update trees
1104 trees = carry.trees 1ab
1105 trees = replace( 1ab
1106 trees,
1107 var_tree=trees.var_tree.at[x.node].set(var),
1108 split_tree=trees.split_tree.at[x.node].set(
1109 jnp.where(nonterminal, split, 0)
1110 ),
1111 )
1113 def write_push_stack() -> SamplePriorStack: 1ab
1114 """Update the stack to go to the left child."""
1115 return replace( 1ab
1116 stack,
1117 nonterminal=stack.nonterminal.at[x.next_depth].set(nonterminal),
1118 lower=stack.lower.at[x.next_depth, :].set(lower),
1119 upper=stack.upper.at[x.next_depth, :].set(upper.at[var].set(split - 1)),
1120 var=stack.var.at[x.depth].set(var),
1121 split=stack.split.at[x.depth].set(split),
1122 )
1124 def pop_push_stack() -> SamplePriorStack: 1ab
1125 """Update the stack to go to the right sibling, possibly at lower depth."""
1126 var = stack.var[x.next_depth - 1] 1ab
1127 split = stack.split[x.next_depth - 1] 1ab
1128 lower = stack.lower[x.next_depth - 1, :] 1ab
1129 upper = stack.upper[x.next_depth - 1, :] 1ab
1130 return replace( 1ab
1131 stack,
1132 lower=stack.lower.at[x.next_depth, :].set(lower.at[var].set(split)),
1133 upper=stack.upper.at[x.next_depth, :].set(upper),
1134 )
1136 # update stack
1137 stack = lax.cond(x.next_depth > x.depth, write_push_stack, pop_push_stack) 1ab
1139 # update carry
1140 carry = replace(carry, key=keys.pop(), stack=stack, trees=trees) 1ab
1141 return carry, None 1ab
1143 carry, _ = lax.scan(loop, carry, xs) 1ab
1144 return carry.trees 1ab
1147@partial(vmap_nodoc, in_axes=(0, None, None, None)) 1ab
1148def sample_prior_forest( 1ab
1149 keys: Key[Array, ' num_trees'],
1150 max_split: UInt[Array, ' p'],
1151 p_nonterminal: Float32[Array, ' d-1'],
1152 sigma_mu: Float32[Array, ''],
1153) -> SamplePriorTrees:
1154 """Sample a set of independent trees from the BART prior.
1156 Parameters
1157 ----------
1158 keys
1159 A sequence of jax random keys, one for each tree. This determines the
1160 number of trees sampled.
1161 max_split
1162 The maximum split value for each variable.
1163 p_nonterminal
1164 The prior probability of a node being non-terminal conditional on
1165 its ancestors and on having available decision rules, at each depth.
1166 sigma_mu
1167 The prior standard deviation of each leaf.
1169 Returns
1170 -------
1171 An object containing the generated trees.
1172 """
1173 return sample_prior_onetree(keys, max_split, p_nonterminal, sigma_mu) 1ab
1176@partial(jit, static_argnums=(1, 2)) 1ab
1177def sample_prior( 1ab
1178 key: Key[Array, ''],
1179 trace_length: int,
1180 num_trees: int,
1181 max_split: UInt[Array, ' p'],
1182 p_nonterminal: Float32[Array, ' d-1'],
1183 sigma_mu: Float32[Array, ''],
1184) -> SamplePriorTrees:
1185 """Sample independent trees from the BART prior.
1187 Parameters
1188 ----------
1189 key
1190 A jax random key.
1191 trace_length
1192 The number of iterations.
1193 num_trees
1194 The number of trees for each iteration.
1195 max_split
1196 The number of cutpoints along each variable.
1197 p_nonterminal
1198 The prior probability of a node being non-terminal conditional on
1199 its ancestors and on having available decision rules, at each depth.
1200 This determines the maximum depth of the trees.
1201 sigma_mu
1202 The prior standard deviation of each leaf.
1204 Returns
1205 -------
1206 An object containing the generated trees, with batch shape (trace_length, num_trees).
1207 """
1208 keys = random.split(key, trace_length * num_trees) 1ab
1209 trees = sample_prior_forest(keys, max_split, p_nonterminal, sigma_mu) 1ab
1210 return tree_map(lambda x: x.reshape(trace_length, num_trees, -1), trees) 1ab
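# Example (a sketch; the argument values below are illustrative):
#
#     key = random.key(0)
#     max_split = jnp.full(5, 100, jnp.uint8)        # 5 predictors, 100 cutpoints each
#     p_nonterminal = jnp.array([0.95, 0.75, 0.4])   # non-terminal probability at depths 0, 1, 2
#     trees = sample_prior(key, 10, 200, max_split, p_nonterminal, jnp.float32(0.1))
#     # trees.leaf_tree has shape (10, 200, 16), i.e. (trace_length, num_trees, 2**(d))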
1213@partial(jit, static_argnames=('sum_trees',)) 1ab
1214def evaluate_forests( 1ab
1215 X: UInt[Array, 'p n'], trees: TreeHeaps, *, sum_trees: bool = True
1216) -> Float32[Array, 'nforests n'] | Float32[Array, 'nforests num_trees n']:
1217 """
1218 Evaluate ensembles of trees at an array of points.
1220 Parameters
1221 ----------
1222 X
1223 The coordinates to evaluate the trees at.
1224 trees
1225 The tree heaps, with batch shape (nforests, num_trees).
1226 sum_trees
1227 Whether to sum the values in each forest.
1229 Returns
1230 -------
1231 The (sum of) the values of the trees at the points in `X`.
1232 """
1234 @partial(vmap, in_axes=(None, 0)) 1ab
1235 def _evaluate_forests(X, trees): 1ab
1236 return evaluate_forest(X, trees, sum_trees=sum_trees) 1ab
1238 return _evaluate_forests(X, trees) 1ab
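# Continuing the sample_prior sketch above (X is assumed to be predictors in the
# integer-encoded (p, n) format used by bartz, and is not defined here):
#
#     f = evaluate_forests(X, trees)                           # shape (10, n), sums over trees
#     per_tree = evaluate_forests(X, trees, sum_trees=False)   # shape (10, 200, n)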
1241class debug_mc_gbart(mc_gbart): 1ab
1242 """A subclass of `mc_gbart` that adds debugging functionality.
1244 Parameters
1245 ----------
1246 *args
1247 Passed to `mc_gbart`.
1248 check_trees
1249 If `True`, check all trees with `check_trace` after running the MCMC,
1250 and assert that they are all valid. Set to `False` to allow jax tracing.
1251 **kw
1252 Passed to `mc_gbart`.
1253 """
1255 def __init__(self, *args, check_trees: bool = True, **kw): 1ab
1256 super().__init__(*args, **kw) 1ab
1257 if check_trees: 1ab
1258 bad = self.check_trees() 1ab
1259 bad_count = jnp.count_nonzero(bad) 1ab
1260 assert bad_count == 0 1ab
1262 def print_tree( 1ab
1263 self, i_chain: int, i_sample: int, i_tree: int, print_all: bool = False
1264 ):
1265 """Print a single tree in human-readable format.
1267 Parameters
1268 ----------
1269 i_chain
1270 The index of the MCMC chain.
1271 i_sample
1272 The index of the (post-burnin) sample in the chain.
1273 i_tree
1274 The index of the tree in the sample.
1275 print_all
1276 If `True`, also print the content of unused node slots.
1277 """
1278 tree = TreesTrace.from_dataclass(self._main_trace)
1279 tree = tree_map(lambda x: x[i_chain, i_sample, i_tree, :], tree)
1280 s = format_tree(tree, print_all=print_all)
1281 print(s) # noqa: T201, this method is intended for debug
1283 def sigma_harmonic_mean(self, prior: bool = False) -> Float32[Array, ' mc_cores']: 1ab
1284 """Return the harmonic mean of the error variance.
1286 Parameters
1287 ----------
1288 prior
1289 If `True`, use the prior distribution, otherwise use the full
1290 conditional at the last MCMC iteration.
1292 Returns
1293 -------
1294 The square root of the harmonic mean 1/E[1/sigma^2] under the selected distribution.
1295 """
1296 bart = self._mcmc_state
1297 assert bart.sigma2_alpha is not None
1298 assert bart.z is None
1299 if prior:
1300 alpha = bart.sigma2_alpha
1301 beta = bart.sigma2_beta
1302 else:
1303 alpha = bart.sigma2_alpha + bart.resid.size / 2
1304 norm2 = jnp.einsum('ij,ij->i', bart.resid, bart.resid)
1305 beta = bart.sigma2_beta + norm2 / 2
1306 sigma2 = beta / alpha
1307 return jnp.sqrt(sigma2)
1309 def compare_resid( 1ab
1310 self,
1311 ) -> tuple[Float32[Array, 'mc_cores n'], Float32[Array, 'mc_cores n']]:
1312 """Re-compute residuals to compare them with the updated ones.
1314 Returns
1315 -------
1316 resid1 : Float32[Array, 'mc_cores n']
1317 The final state of the residuals updated during the MCMC.
1318 resid2 : Float32[Array, 'mc_cores n']
1319 The residuals computed from the final state of the trees.
1320 """
1321 bart = self._mcmc_state 1ab
1322 resid1 = bart.resid 1ab
1324 forests = TreesTrace.from_dataclass(bart.forest) 1ab
1325 trees = evaluate_forests(bart.X, forests) 1ab
1327 if bart.z is not None: 1ab
1328 ref = bart.z 1ab
1329 else:
1330 ref = bart.y 1ab
1331 resid2 = ref - (trees + bart.offset) 1ab
1333 return resid1, resid2 1ab
1335 def avg_acc( 1ab
1336 self,
1337 ) -> tuple[Float32[Array, ' mc_cores'], Float32[Array, ' mc_cores']]:
1338 """Compute the average acceptance rates of tree moves.
1340 Returns
1341 -------
1342 acc_grow : Float32[Array, 'mc_cores']
1343 The average acceptance rate of grow moves.
1344 acc_prune : Float32[Array, 'mc_cores']
1345 The average acceptance rate of prune moves.
1346 """
1347 trace = self._main_trace
1349 def acc(prefix):
1350 acc = getattr(trace, f'{prefix}_acc_count')
1351 prop = getattr(trace, f'{prefix}_prop_count')
1352 return acc.sum(axis=1) / prop.sum(axis=1)
1354 return acc('grow'), acc('prune')
1356 def avg_prop( 1ab
1357 self,
1358 ) -> tuple[Float32[Array, ' mc_cores'], Float32[Array, ' mc_cores']]:
1359 """Compute the average proposal rate of grow and prune moves.
1361 Returns
1362 -------
1363 prop_grow : Float32[Array, 'mc_cores']
1364 The fraction of times grow was proposed instead of prune.
1365 prop_prune : Float32[Array, 'mc_cores']
1366 The fraction of times prune was proposed instead of grow.
1368 Notes
1369 -----
1370 This function does not take into account cases where no move was
1371 proposed.
1372 """
1373 trace = self._main_trace
1375 def prop(prefix):
1376 return getattr(trace, f'{prefix}_prop_count').sum(axis=1)
1378 pgrow = prop('grow')
1379 pprune = prop('prune')
1380 total = pgrow + pprune
1381 return pgrow / total, pprune / total
1383 def avg_move( 1ab
1384 self,
1385 ) -> tuple[Float32[Array, ' mc_cores'], Float32[Array, ' mc_cores']]:
1386 """Compute the move rate.
1388 Returns
1389 -------
1390 rate_grow : Float32[Array, 'mc_cores']
1391 The fraction of times a grow move was proposed and accepted.
1392 rate_prune : Float32[Array, 'mc_cores']
1393 The fraction of times a prune move was proposed and accepted.
1394 """
1395 agrow, aprune = self.avg_acc()
1396 pgrow, pprune = self.avg_prop()
1397 return agrow * pgrow, aprune * pprune
1399 def depth_distr(self) -> Float32[Array, 'mc_cores ndpost/mc_cores d']: 1ab
1400 """Histogram of tree depths for each state of the trees.
1402 Returns
1403 -------
1404 A matrix where each row contains a histogram of tree depths.
1405 """
1406 return chains_depth_distr(self._main_trace.split_tree) 1ab
1408 def points_per_decision_node_distr( 1ab
1409 self,
1410 ) -> Float32[Array, 'mc_cores ndpost/mc_cores n+1']:
1411 """Histogram of number of points belonging to parent-of-leaf nodes.
1413 Returns
1414 -------
1415 A matrix where each row contains a histogram of number of points.
1416 """
1417 return chains_points_per_decision_node_distr( 1ab
1418 self._main_trace, self._mcmc_state.X
1419 )
1421 def points_per_leaf_distr(self) -> Float32[Array, 'mc_cores ndpost/mc_cores n+1']: 1ab
1422 """Histogram of number of points belonging to leaves.
1424 Returns
1425 -------
1426 A matrix where each row contains a histogram of number of points.
1427 """
1428 return chains_points_per_leaf_distr(self._main_trace, self._mcmc_state.X) 1ab
1430 def check_trees(self) -> UInt[Array, 'mc_cores ndpost/mc_cores ntree']: 1ab
1431 """Apply `check_trace` to all the tree draws."""
1432 return check_chains(self._main_trace, self._mcmc_state.forest.max_split) 1ab
1434 def tree_goes_bad(self) -> Bool[Array, 'mc_cores ndpost/mc_cores ntree']: 1ab
1435 """Find iterations where a tree becomes invalid.
1437 Returns
1438 -------
1439 An array where element (c,i,j) is `True` if tree j of chain c is invalid at iteration i but not at iteration i-1.
1440 """
1441 bad = self.check_trees().astype(bool)
1442 bad_before = jnp.pad(bad[:, :-1, :], [(0, 0), (1, 0), (0, 0)])
1443 return bad & ~bad_before
1446class debug_gbart(debug_mc_gbart, gbart): 1ab
1447 """A subclass of `gbart` that adds debugging functionality.
1449 Parameters
1450 ----------
1451 *args
1452 Passed to `gbart`.
1453 check_trees
1454 If `True`, check all trees with `check_trace` after running the MCMC,
1455 and assert that they are all valid. Set to `False` to allow jax tracing.
1456 **kw
1457 Passed to `gbart`.
1458 """