Coverage for src/bartz/debug.py: 49%
111 statements
« prev ^ index » next coverage.py v7.8.2, created at 2025-05-29 23:01 +0000
« prev ^ index » next coverage.py v7.8.2, created at 2025-05-29 23:01 +0000
1import functools 1ab
3import jax 1ab
4from jax import lax 1ab
5from jax import numpy as jnp 1ab
7from . import grove, jaxext 1ab
def print_tree(leaf_tree, var_tree, split_tree, print_all=False):
    """Print a tree as an indented box-drawing diagram.

    Nodes are visited depth-first starting from index 1, left child before
    right child, where the children of node ``i`` are ``2 * i`` and
    ``2 * i + 1``. Each printed line shows the node index, the branch
    drawing, and the node description.

    Parameters
    ----------
    leaf_tree : array
        Value stored at each node; its length bounds the traversal.
    var_tree, split_tree : array
        Decision variable and split value of each node. Indices past the end
        of these arrays are read as 0 (``mode='fill'``). A split of 0 marks
        the node as a leaf.
    print_all : bool, default False
        If False, print ``(var: split)`` for decision nodes and the leaf value
        for leaves, and stop descending at leaves. If True, also visit the
        nodes below leaves (labelled ``unused``) and print
        ``category(var, split, leaf_value)`` for every node.
    """
    tee = '├──'
    corner = '└──'
    # The continuation prefixes must be exactly as wide as `tee`/`corner`:
    # a child's branch glyph is printed at column len(indent) + len(next_indent),
    # and must land under the parent's `down` glyph at
    # len(indent) + len(first_indent).
    join = '│  '
    space = '   '
    down = '┐'
    bottom = '╢'  # '┨' #

    # Width used to right-align node indices; invariant across the traversal.
    max_number = len(leaf_tree) - 1
    ndigits = len(str(max_number))

    def traverse_tree(index, indent, first_indent, next_indent, unused):
        if index >= len(leaf_tree):
            return

        var = var_tree.at[index].get(mode='fill', fill_value=0)
        split = split_tree.at[index].get(mode='fill', fill_value=0)

        is_leaf = split == 0
        left_child = 2 * index
        right_child = 2 * index + 1

        if print_all:
            if unused:
                category = 'unused'
            elif is_leaf:
                category = 'leaf'
            else:
                category = 'decision'
            node_str = f'{category}({var}, {split}, {leaf_tree[index]})'
        else:
            # without print_all we never descend below a leaf (see below)
            assert not unused
            if is_leaf:
                node_str = f'{leaf_tree[index]:#.2g}'
            else:
                node_str = f'({var}: {split})'

        if not is_leaf or (print_all and left_child < len(leaf_tree)):
            link = down
        elif not print_all and left_child >= len(leaf_tree):
            # leaf on the last level of the array: no children slots at all
            link = bottom
        else:
            link = ' '

        number = str(index).rjust(ndigits)

        print(f' {number} {indent}{first_indent}{link}{node_str}')

        indent += next_indent
        unused = unused or is_leaf

        if unused and not print_all:
            return

        traverse_tree(left_child, indent, tee, join, unused)
        traverse_tree(right_child, indent, corner, space, unused)

    traverse_tree(1, '', '', '', False)
def tree_actual_depth(split_tree):
    """Return the largest depth among the actual leaves of a tree.

    Leaves are identified by ``grove.is_actual_leaf`` (with the bottom level
    included); non-leaf nodes contribute 0 to the maximum.
    """
    leaf_mask = grove.is_actual_leaf(split_tree, add_bottom_level=True)
    node_depths = grove.tree_depths(leaf_mask.size)
    return jnp.max(jnp.where(leaf_mask, node_depths, 0))
def forest_depth_distr(split_trees):
    """Count how many trees of the forest reach each actual depth.

    Returns an array of length ``grove.tree_depth(split_trees) + 1`` whose
    entry ``d`` is the number of trees whose ``tree_actual_depth`` is ``d``.
    """
    per_tree_depth = jax.vmap(tree_actual_depth)(split_trees)
    nbins = grove.tree_depth(split_trees) + 1
    return jnp.bincount(per_tree_depth, length=nbins)
def trace_depth_distr(split_trees_trace):
    """Apply ``forest_depth_distr`` to each forest along the leading axis."""
    per_forest = jax.vmap(forest_depth_distr)
    return per_forest(split_trees_trace)
def points_per_leaf_distr(var_tree, split_tree, X):
    """Histogram of the number of datapoints falling into each leaf of a tree.

    Parameters
    ----------
    var_tree, split_tree : array
        Decision variables and split values of the tree.
    X : array
        Predictors; ``grove.traverse_tree`` is vmapped over axis 1, so each
        column is one datapoint.

    Returns
    -------
    distr : array of shape (X.shape[1] + 1,)
        ``distr[k]`` is the number of actual leaves (per
        ``grove.is_actual_leaf`` with the bottom level added) that contain
        exactly ``k`` datapoints.
    """
    traverse_tree = jax.vmap(grove.traverse_tree, in_axes=(1, None, None))
    indices = traverse_tree(X, var_tree, split_tree)
    # per-node count of datapoints; indices.size is the number of datapoints,
    # which bounds any single count
    count_tree = jnp.zeros(
        2 * split_tree.size, dtype=jaxext.minimal_unsigned_dtype(indices.size)
    )
    count_tree = count_tree.at[indices].add(1)
    is_leaf = grove.is_actual_leaf(split_tree, add_bottom_level=True)
    # `jnp.bincount` returns an array with the dtype of `weights`. Weights of
    # uint8 would wrap around if more than 255 leaves shared the same count,
    # so use a dtype sized for the total number of node slots (an upper bound
    # on the number of leaves in one bin).
    weights = is_leaf.astype(jaxext.minimal_unsigned_dtype(is_leaf.size))
    return jnp.bincount(count_tree, weights, length=X.shape[1] + 1)
def forest_points_per_leaf_distr(bart, X):
    """Sum ``points_per_leaf_distr`` over all trees of a forest.

    ``bart['var_trees']`` and ``bart['split_trees']`` are scanned together
    along their leading axis; the result has length ``X.shape[1] + 1``.
    """

    def accumulate(total, tree):
        var_tree, split_tree = tree
        return total + points_per_leaf_distr(var_tree, split_tree, X), None

    init = jnp.zeros(X.shape[1] + 1, int)
    forest = (bart['var_trees'], bart['split_trees'])
    total, _ = lax.scan(accumulate, init, forest)
    return total
def trace_points_per_leaf_distr(bart, X):
    """Stack ``forest_points_per_leaf_distr`` for each forest in a trace.

    ``bart`` is scanned along its leading axis with no carried state.
    """

    def per_iteration(carry, forest):
        return carry, forest_points_per_leaf_distr(forest, X)

    _, stacked = lax.scan(per_iteration, None, bart)
    return stacked
def check_types(leaf_tree, var_tree, split_tree, max_split):
    """Check the dtypes of the variable and split arrays.

    ``var_tree`` must use ``jaxext.minimal_unsigned_dtype(max_split.size - 1)``
    and ``split_tree`` must share the dtype of ``max_split``. ``leaf_tree`` is
    not inspected.
    """
    var_ok = var_tree.dtype == jaxext.minimal_unsigned_dtype(max_split.size - 1)
    split_ok = split_tree.dtype == max_split.dtype
    return var_ok and split_ok
def check_sizes(leaf_tree, var_tree, split_tree, max_split):
    """Check that ``leaf_tree`` is twice as long as ``var_tree`` and ``split_tree``."""
    doubled_var = 2 * var_tree.size
    return leaf_tree.size == doubled_var and doubled_var == 2 * split_tree.size
def check_unused_node(leaf_tree, var_tree, split_tree, max_split):
    """Check that slot 0 of the variable and split arrays holds 0."""
    var_is_zero = var_tree[0] == 0
    split_is_zero = split_tree[0] == 0
    return var_is_zero & split_is_zero
def check_leaf_values(leaf_tree, var_tree, split_tree, max_split):
    """Check that every entry of ``leaf_tree`` is finite (no NaN or inf)."""
    finite_mask = jnp.isfinite(leaf_tree)
    return finite_mask.all()
def check_stray_nodes(leaf_tree, var_tree, split_tree, max_split):
    """Check that no node with a nonzero split hangs below a leaf.

    A node is "stray" if its split is nonzero while its parent (index
    ``i >> 1``) has split 0. Indices past the end of ``split_tree`` are read
    as split 0. Node 1 is exempt, since its parent index is 0.
    """
    n_nodes = 2 * split_tree.size
    node = jnp.arange(n_nodes, dtype=jaxext.minimal_unsigned_dtype(n_nodes - 1))
    has_split = split_tree.at[node].get(mode='fill', fill_value=0) != 0
    parent_has_no_split = split_tree[node >> 1] == 0
    orphan = (has_split & parent_has_no_split).at[1].set(False)
    return ~jnp.any(orphan)
# Tree checks run by `check_tree`; the position of each function in this list
# is the bit index of its failure flag in the returned error mask.
check_functions = [check_types, check_sizes, check_unused_node, check_leaf_values, check_stray_nodes]
def check_tree(leaf_tree, var_tree, split_tree, max_split):
    """Run every function in ``check_functions`` on one tree.

    Returns an unsigned integer mask in which bit ``i`` is set when
    ``check_functions[i]`` reports failure.
    """
    n_checks = len(check_functions)
    mask_dtype = jaxext.minimal_unsigned_dtype(2**n_checks - 1)
    error = mask_dtype(0)
    for bit_index, check in enumerate(check_functions):
        passed = jnp.bool_(check(leaf_tree, var_tree, split_tree, max_split))
        error |= (~passed) << bit_index
    return error
def describe_error(error):
    """Return the names of the checks whose bits are set in ``error``."""
    names = (check.__name__ for check in check_functions)
    return [name for bit, name in enumerate(names) if error & (1 << bit)]
# `check_tree` mapped over the leading axis of the three tree arrays, with a
# single `max_split` shared by all trees.
check_forest = functools.partial(jax.vmap, in_axes=(0, 0, 0, None))(check_tree)
@functools.partial(jax.vmap, in_axes=(0, None))
def check_trace(trace, state):
    """Run ``check_forest`` on each forest of a trace.

    The trace dict is mapped along its leading axis; ``state.max_split`` is
    shared across all iterations.
    """
    leaf_trees = trace['leaf_trees']
    var_trees = trace['var_trees']
    split_trees = trace['split_trees']
    return check_forest(leaf_trees, var_trees, split_trees, state.max_split)