Coverage for src/bartz/debug.py: 50%
113 statements
« prev ^ index » next coverage.py v7.6.4, created at 2024-11-05 18:54 +0000
« prev ^ index » next coverage.py v7.6.4, created at 2024-11-05 18:54 +0000
1import functools 1a
3import jax 1a
4from jax import numpy as jnp 1a
5from jax import lax 1a
7from . import grove 1a
8from . import mcmcstep 1a
9from . import jaxext 1a
def print_tree(leaf_tree, var_tree, split_tree, print_all=False):
    """
    Print a text drawing of one tree stored in heap layout.

    Nodes are visited starting from index 1; the children of node ``i`` are
    nodes ``2 * i`` and ``2 * i + 1``. A node is treated as a leaf when its
    entry in `split_tree` is 0; entries past the end of `var_tree` or
    `split_tree` are read as 0. Each printed line starts with the node index.

    Parameters
    ----------
    leaf_tree : array
        Per-node leaf values. Its length bounds the node indices visited.
    var_tree : array
        Per-node value printed first for decision nodes (presumably the index
        of the splitting variable -- confirm against the caller). Must
        support the JAX ``.at[...].get`` syntax.
    split_tree : array
        Per-node split value, 0 marking a leaf. Must support the JAX
        ``.at[...].get`` syntax.
    print_all : bool, default False
        If False, stop at leaves and print only decision rules and leaf
        values. If True, print every node up to ``len(leaf_tree) - 1``,
        labelling each as ``decision``, ``leaf`` or ``unused`` together with
        its raw ``(var, split, leaf)`` entries.
    """
    tee = '├──'
    corner = '└──'
    # The continuation segments must be exactly as wide as the branch
    # connectors, otherwise the children do not line up under the link
    # character of their parent.
    join = '│'.ljust(len(tee))
    space = ' ' * len(tee)
    down = '┐'
    bottom = '╢'
    blank = ' '

    # Invariants of the whole traversal, computed once.
    size = len(leaf_tree)
    ndigits = len(str(size - 1))

    def traverse_tree(index, indent, first_indent, next_indent, unused):
        if index >= size:
            return

        var = var_tree.at[index].get(mode='fill', fill_value=0)
        split = split_tree.at[index].get(mode='fill', fill_value=0)

        is_leaf = bool(split == 0)
        left_child = 2 * index
        right_child = left_child + 1
        has_children = left_child < size

        if print_all:
            if unused:
                category = 'unused'
            elif is_leaf:
                category = 'leaf'
            else:
                category = 'decision'
            node_str = f'{category}({var}, {split}, {leaf_tree[index]})'
        else:
            # Without print_all the recursion stops at leaves, so a node
            # below a leaf is never reached.
            assert not unused
            if is_leaf:
                node_str = f'{leaf_tree[index]:#.2g}'
            else:
                node_str = f'({var}: {split})'

        if not is_leaf or (print_all and has_children):
            link = down
        elif not print_all and not has_children:
            # Leaf sitting on the last level that the arrays can represent.
            link = bottom
        else:
            link = blank

        number = str(index).rjust(ndigits)
        print(f' {number} {indent}{first_indent}{link}{node_str}')

        # Everything below a leaf is unused storage.
        unused = unused or is_leaf
        if unused and not print_all:
            return

        indent += next_indent
        traverse_tree(left_child, indent, tee, join, unused)
        traverse_tree(right_child, indent, corner, space, unused)

    traverse_tree(1, '', '', '', False)
def tree_actual_depth(split_tree):
    """
    Return the largest depth among the actual leaves of one tree.

    The leaves are the nodes flagged by `grove.is_actual_leaf` (bottom level
    included) and the per-node depths come from `grove.tree_depths`. Nodes
    that are not leaves contribute 0 to the maximum.
    """
    leaf_mask = grove.is_actual_leaf(split_tree, add_bottom_level=True)
    node_depth = grove.tree_depths(leaf_mask.size)
    return jnp.where(leaf_mask, node_depth, 0).max()
def forest_depth_distr(split_trees):
    """
    Count how many trees of a forest have each actual depth.

    `tree_actual_depth` is mapped over the leading axis of `split_trees` and
    the results are histogrammed into ``grove.tree_depth(split_trees) + 1``
    bins, so entry ``d`` of the output is the number of trees of depth ``d``.
    """
    per_tree_depth = jax.vmap(tree_actual_depth)(split_trees)
    num_bins = grove.tree_depth(split_trees) + 1
    return jnp.bincount(per_tree_depth, length=num_bins)
def trace_depth_distr(split_trees_trace):
    """
    Compute the tree depth distribution of each forest in a trace.

    Applies `forest_depth_distr` along the leading axis of
    `split_trees_trace` and returns the stacked results.
    """
    per_forest = jax.vmap(forest_depth_distr)
    return per_forest(split_trees_trace)
def points_per_leaf_distr(var_tree, split_tree, X):
    """
    Histogram of the number of points falling in each leaf of one tree.

    The columns of `X` (axis 1) are the points. Each one is routed through
    the tree with `grove.traverse_tree`, the points reaching each node are
    counted, and the counts of the actual leaves are histogrammed. The
    output has ``X.shape[1] + 1`` entries: entry ``k`` is the number of
    leaves that contain exactly ``k`` points.
    """
    num_points = X.shape[1]

    # Node index reached by every point.
    locate = jax.vmap(grove.traverse_tree, in_axes=(1, None, None))
    reached = locate(X, var_tree, split_tree)

    # Number of points per node, in the smallest unsigned type that fits.
    counter_dtype = jaxext.minimal_unsigned_dtype(reached.size)
    points_in_node = jnp.zeros(2 * split_tree.size, dtype=counter_dtype)
    points_in_node = points_in_node.at[reached].add(1)

    # Weight 1 for actual leaves and 0 elsewhere, so that only leaves are
    # tallied by the histogram.
    leaf_mask = grove.is_actual_leaf(split_tree, add_bottom_level=True)
    leaf_weight = leaf_mask.view(jnp.uint8)
    return jnp.bincount(points_in_node, leaf_weight, length=num_points + 1)
def forest_points_per_leaf_distr(bart, X):
    """
    Sum the points-per-leaf histograms of all the trees in a forest.

    Scans over the leading axis of ``bart['var_trees']`` and
    ``bart['split_trees']``, adding up the output of `points_per_leaf_distr`
    for each tree. The result has ``X.shape[1] + 1`` entries.
    """
    def accumulate(total, var_and_split):
        var_tree, split_tree = var_and_split
        return total + points_per_leaf_distr(var_tree, split_tree, X), None

    initial = jnp.zeros(X.shape[1] + 1, int)
    forest = (bart['var_trees'], bart['split_trees'])
    total, _ = lax.scan(accumulate, initial, forest)
    return total
def trace_points_per_leaf_distr(bart, X):
    """
    Compute the points-per-leaf histogram for each state of a trace.

    Scans over the leading axis of the arrays in `bart`, calling
    `forest_points_per_leaf_distr` on every slice, and returns the stacked
    histograms.
    """
    def per_state(_, state):
        # No carry is needed; only the stacked per-state outputs are used.
        return None, forest_points_per_leaf_distr(state, X)

    _, stacked = lax.scan(per_state, None, bart)
    return stacked
def check_types(leaf_tree, var_tree, split_tree, max_split):
    """
    Check the dtypes of the tree arrays.

    `var_tree` must have the dtype that `jaxext.minimal_unsigned_dtype`
    gives for ``max_split.size - 1`` and `split_tree` must have the dtype of
    `max_split`. `leaf_tree` is accepted only so that all the checks share
    one signature.
    """
    var_dtype_ok = var_tree.dtype == jaxext.minimal_unsigned_dtype(max_split.size - 1)
    if not var_dtype_ok:
        return var_dtype_ok
    return split_tree.dtype == max_split.dtype
def check_sizes(leaf_tree, var_tree, split_tree, max_split):
    """
    Check that `leaf_tree` has twice the size of `var_tree` and `split_tree`.

    `max_split` is accepted only so that all the checks share one signature.
    """
    leaf_size = leaf_tree.size
    twice_var_size = 2 * var_tree.size
    return leaf_size == twice_var_size and twice_var_size == 2 * split_tree.size
def check_unused_node(leaf_tree, var_tree, split_tree, max_split):
    """
    Check that entry 0 of both `var_tree` and `split_tree` is 0.

    `leaf_tree` and `max_split` are accepted only so that all the checks
    share one signature.
    """
    var_is_zero = var_tree[0] == 0
    split_is_zero = split_tree[0] == 0
    return var_is_zero & split_is_zero
def check_leaf_values(leaf_tree, var_tree, split_tree, max_split):
    """
    Check that every entry of `leaf_tree` is finite (neither inf nor nan).

    The other arguments are accepted only so that all the checks share one
    signature.
    """
    return jnp.isfinite(leaf_tree).all()
def check_stray_nodes(leaf_tree, var_tree, split_tree, max_split):
    """
    Check that no decision node hangs below a leaf.

    A node is stray when its own split is nonzero while the split of its
    parent (node index shifted right by one) is zero. Node 1 is exempt
    because its parent slot is entry 0. Returns True when no stray node is
    found. Only `split_tree` is inspected.
    """
    num_nodes = 2 * split_tree.size
    node = jnp.arange(num_nodes, dtype=jaxext.minimal_unsigned_dtype(num_nodes - 1))
    parent = node >> 1

    # Indices past the end of split_tree read as 0, i.e. as leaves.
    splits_here = split_tree.at[node].get(mode='fill', fill_value=0) != 0
    parent_is_leaf = split_tree[parent] == 0

    stray = (splits_here & parent_is_leaf).at[1].set(False)
    return ~jnp.any(stray)
# Registry of the per-tree validation functions. The position of a function in
# this list is the bit that check_tree sets in its error code when that check
# fails, and describe_error decodes the code using the same positions, so the
# order is significant.
check_functions = [
    check_types,
    check_sizes,
    check_unused_node,
    check_leaf_values,
    check_stray_nodes,
]
def check_tree(leaf_tree, var_tree, split_tree, max_split):
    """
    Run every function in `check_functions` on one tree.

    Returns an integer error code in which bit ``i`` is set when the
    ``i``-th check fails, so 0 means that all the checks passed. The code
    starts as a zero of the unsigned type chosen by
    `jaxext.minimal_unsigned_dtype` to hold one bit per check.
    """
    code_dtype = jaxext.minimal_unsigned_dtype(2 ** len(check_functions) - 1)
    error = code_dtype(0)
    for position, check in enumerate(check_functions):
        passed = jnp.bool_(check(leaf_tree, var_tree, split_tree, max_split))
        # Raise the bit of this check when it did not pass.
        error = error | ((~passed) << position)
    return error
def describe_error(error):
    """
    Decode an error code into the names of the checks that failed.

    Bit ``i`` of `error` corresponds to the ``i``-th entry of
    `check_functions`. Returns the list of the ``__name__`` of the functions
    whose bit is set, in registry order.
    """
    failed = []
    for position, check in enumerate(check_functions):
        if error & (1 << position):
            failed.append(check.__name__)
    return failed
# Vectorized version of check_tree: the three tree arrays are mapped over
# their leading axis, while max_split is passed unchanged to every call.
check_forest = jax.vmap(check_tree, in_axes=(0, 0, 0, None))
@functools.partial(jax.vmap, in_axes=(0, None))
def check_trace(trace, state):
    """
    Run `check_forest` on every forest of a trace.

    `trace` is mapped over its leading axis and must provide the entries
    ``'leaf_trees'``, ``'var_trees'`` and ``'split_trees'``. `state` is not
    mapped and provides ``'max_split'``.
    """
    leaf_trees = trace['leaf_trees']
    var_trees = trace['var_trees']
    split_trees = trace['split_trees']
    return check_forest(leaf_trees, var_trees, split_trees, state['max_split'])